In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms

#from torchmetrics.functional import dice
from glob import glob
# import cv2
#from torchmetrics.classification import Dice
#from torchmetrics.classification import BinaryF1Score
import torchvision.transforms as T
from torch.optim import Adam
from tqdm import tqdm
import os
from PIL import Image
#import wandb used to parameter tuning
#from torchvision.transforms import v2
import time
from torchmetrics import JaccardIndex
import random
from torch.utils.data import Subset
import copy

In [None]:
seed_value = 42

# Seed every RNG this notebook draws from, in a fixed order: PyTorch (CPU),
# NumPy's global generator and Python's `random`.
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(seed_value)

# Also seed every visible CUDA device when a GPU is present.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_value)

In [None]:

class CarSegData(Dataset):
    """Car-part segmentation dataset read from one stacked ``.npy`` array.

    The whole array is loaded eagerly with ``np.load(data_root)``. Each sample
    ``array[idx]`` is an (H, W, >=4) slice: channels 0-2 form the image and
    channel 3 holds the raw mask values (10, 20, ..., 90).

    Args:
        data_root: path to the ``.npy`` file (a single file, not a directory).
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, data_root, transform=None):
        self.data_root = data_root
        self.transform = transform
        # Raw mask value -> training label. 90 maps to 0, i.e. it shares
        # label 0 with every raw value that is not listed here.
        self.class_labels = {10: 1, 20: 2, 30: 3, 40: 4, 50: 5,
                             60: 6, 70: 7, 80: 8, 90: 0}
        # Training label -> human-readable name.
        # NOTE(review): label 9 ("rest of car") is never produced, because
        # class_labels sends raw value 90 to 0 -- confirm this is intended.
        self.classes = {1: "hood", 2: "front door", 3: "rear door",
                        4: "frame", 5: "rear quarter panel", 6: "trunk lid",
                        7: "fender", 8: "bumper", 9: "rest of car"}

        # One in-memory array holding every sample (image + mask channels).
        self.array_files = np.load(data_root)

    def __len__(self):
        """Number of samples, i.e. the size of the array's first axis."""
        return self.array_files.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        ``image`` is a PIL image built from channels 0-2 (cast to uint8), passed
        through ``transform`` when one is set. ``mask`` is a tensor of training
        labels with the same dtype as the stored array.
        """
        sample = self.array_files[idx, :, :, :]
        pixels = sample[..., :3]
        mask = self.map_to_classes(self.map_to_class_labels(sample[..., 3]))

        image = Image.fromarray(pixels.astype('uint8'))
        if self.transform:
            image = self.transform(image)
        return image, mask

    def map_to_classes(self, target_data):
        """Keep values that are keys of ``self.classes`` (1-9) and zero the rest.

        Applied after ``map_to_class_labels`` this leaves labels 0-8 unchanged.
        Returns a tensor with the input's dtype.
        """
        remapped = np.zeros_like(target_data)
        for label in self.classes:
            remapped[target_data == label] = label
        return torch.from_numpy(remapped)

    def map_to_class_labels(self, target_data):
        """Translate raw mask values into training labels via ``self.class_labels``.

        Values missing from the mapping become 0. Returns a tensor with the
        input's dtype.
        """
        remapped = np.zeros_like(target_data)
        for raw_value, label in self.class_labels.items():
            remapped[target_data == raw_value] = label
        return torch.from_numpy(remapped)

In [None]:
class EncoderBlock(nn.Module):
    """Down-sampling unit: Conv2d -> optional BatchNorm2d -> LeakyReLU(0.2).

    With the default geometry (kernel 4, stride 2, padding 1) the spatial size
    is halved.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        device: device the layers are created on.
        kernel_size, stride, padding: convolution geometry.
        norm: when True, apply BatchNorm2d after the convolution; otherwise
            ``self.bn`` is None.
    """

    def __init__(self, in_channels, out_channels, device, kernel_size=4, stride=2, padding=1, norm=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, device=device)
        # Negative slope 0.2 follows the pix2pix paper.
        self.act = nn.LeakyReLU(0.2, inplace=False)
        self.use_norm = norm
        self.bn = nn.BatchNorm2d(out_channels, device=device) if norm else None

    def forward(self, x):
        h = self.conv(x)
        h = self.bn(h) if self.use_norm else h
        return self.act(h)

    
class DecoderBlock(nn.Module):
    """Up-sampling unit: ConvTranspose2d -> BatchNorm2d -> ReLU -> optional Dropout2d.

    With the default geometry (kernel 4, stride 2, padding 1) the spatial size
    is doubled.

    Args:
        in_channels, out_channels: channel counts of the transposed convolution.
        device: device the layers are created on.
        kernel_size, stride, padding: transposed-convolution geometry.
        dropout: when truthy, append Dropout2d(p=0.5) after the activation;
            ``False`` (the default) or ``None`` means no dropout and
            ``self.dropout`` is None.
    """

    def __init__(self, in_channels, out_channels, device,
                 kernel_size=4, stride=2, padding=1, dropout=False):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, device=device)
        self.act = nn.ReLU(inplace=False)
        self.bn = nn.BatchNorm2d(out_channels, device=device)

        # Test truthiness, not `is not None`: the default `dropout=False` is
        # not None, which would enable dropout in every block.
        if dropout:
            self.dropout = nn.Dropout2d(p=0.5, inplace=False)  # p = 0.5 is from the paper
        else:
            self.dropout = None

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return x

In [None]:
class UNet_Generator(nn.Module):
    """Pix2pix-style U-Net generator producing per-class score maps.

    Paper notation (decoder channel counts include the skip concatenation):
        encoder  C64-C128-C256-C512-C512-C512-C512-C512   (EB1 .. EB8)
        decoder  CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128  (DB8 .. DB2)
    DB8, DB7 and DB6 are built with ``dropout=True``; DB1 is a bare
    ConvTranspose2d emitting ``out_channels`` raw scores (no activation).

    Every encoder block halves H and W and every decoder block doubles them,
    so H and W should be multiples of 256 for the skip concatenations to line up.

    Args:
        device: device the layers are created on.
        input_channels: channels of the input image.
        out_channels: number of output score channels (classes).
    """

    def __init__(self, device, input_channels=3, out_channels=9):
        super().__init__()
        # Layers are created in this exact order so parameter initialisation
        # consumes the RNG the same way and state_dict keys stay the same.
        # Encoder: EB1 and EB8 skip BatchNorm.
        self.EB1 = EncoderBlock(input_channels, 64, norm=False, device=device)
        self.EB2 = EncoderBlock(64, 128, device=device)
        self.EB3 = EncoderBlock(128, 256, device=device)
        self.EB4 = EncoderBlock(256, 512, device=device)
        self.EB5 = EncoderBlock(512, 512, device=device)
        self.EB6 = EncoderBlock(512, 512, device=device)
        self.EB7 = EncoderBlock(512, 512, device=device)
        self.EB8 = EncoderBlock(512, 512, norm=False, device=device)

        # Decoder: inputs after DB8 are doubled by the skip concatenation.
        self.DB8 = DecoderBlock(512, 512, dropout=True, device=device)
        self.DB7 = DecoderBlock(2 * 512, 512, dropout=True, device=device)
        self.DB6 = DecoderBlock(2 * 512, 512, dropout=True, device=device)
        self.DB5 = DecoderBlock(2 * 512, 512, device=device)
        self.DB4 = DecoderBlock(2 * 512, 256, device=device)
        self.DB3 = DecoderBlock(2 * 256, 128, device=device)
        self.DB2 = DecoderBlock(2 * 128, 64, device=device)
        self.DB1 = nn.ConvTranspose2d(2 * 64, out_channels, kernel_size=4, stride=2, padding=1, device=device)

    def forward(self, x):
        encoders = (self.EB1, self.EB2, self.EB3, self.EB4,
                    self.EB5, self.EB6, self.EB7, self.EB8)
        decoders = (self.DB7, self.DB6, self.DB5, self.DB4,
                    self.DB3, self.DB2, self.DB1)

        # Contracting path: keep every encoder output (e1 .. e8) for the skips.
        features = []
        h = x
        for enc in encoders:
            h = enc(h)
            features.append(h)

        # Expanding path: the bottleneck e8 feeds DB8 directly; every later
        # decoder sees [previous decoder output, mirrored encoder output],
        # popped from the end so DB7 gets e7, ..., DB1 gets e1.
        h = self.DB8(features.pop())
        for dec in decoders:
            h = dec(torch.cat([h, features.pop()], dim=1))
        return h

In [None]:
# PatchGAN discriminator body: C64-C128-C256-C512
class DiscriminatorBlock(nn.Module):
    """Discriminator unit: Conv2d -> optional BatchNorm2d -> in-place LeakyReLU(0.2).

    Args:
        in_channels, out_channels: channel counts of the convolution.
        device: device the layers are created on.
        kernel_size, stride, padding: convolution geometry (defaults halve H and W).
        norm: when True, apply BatchNorm2d; otherwise ``self.bn`` is None.
    """

    def __init__(self, in_channels, out_channels, device, kernel_size=4, stride=2, padding=1, norm=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, device=device)
        # Negative slope 0.2 follows the pix2pix paper.
        self.act = nn.LeakyReLU(0.2, inplace=True)
        self.use_norm = norm
        self.bn = nn.BatchNorm2d(out_channels, device=device) if norm else None

    def forward(self, x):
        h = self.conv(x)
        if not self.use_norm:
            return self.act(h)
        return self.act(self.bn(h))
        

class PatchGan_Discriminator(nn.Module):
    """Conditional PatchGAN: scores (label map, image) pairs patch by patch.

    Input is the 1-channel label map concatenated with the 3-channel image;
    output is a grid of per-patch probabilities in (0, 1) (sigmoid).
    """

    def __init__(self, device):
        super().__init__()
        # Created in this order to keep RNG consumption and state_dict keys stable.
        self.l1 = DiscriminatorBlock(3 + 1, 64, norm=False, device=device)  # image (3) + label map (1)
        self.l2 = DiscriminatorBlock(64, 128, device=device)
        self.l3 = DiscriminatorBlock(128, 256, device=device)
        self.l4 = DiscriminatorBlock(256, 512, device=device)
        self.l5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1, device=device)

    def forward(self, mask, image):
        # A multi-channel `mask` is generator output; collapse it to one label
        # channel. NOTE(review): masker takes an argmax, which carries no
        # gradient, so the adversarial term cannot train the generator through
        # this path.
        if mask.shape[1] > 1:
            mask = masker(mask)
        # Conditioning: the discriminator always sees the true image.
        h = torch.cat([mask, image], dim=1)
        for layer in (self.l1, self.l2, self.l3, self.l4, self.l5):
            h = layer(h)
        return torch.sigmoid(h)
    

In [None]:
def masker(fake_mask):
    """Turn per-class scores into a hard label map.

    Args:
        fake_mask: tensor of shape (N, C, H, W); it is cast to float first.

    Returns:
        int64 tensor of shape (N, 1, H, W) with the arg-max class per pixel.
        argmax is not differentiable, so no gradient flows back through it.
    """
    probabilities = torch.softmax(fake_mask.float(), dim=1)
    batch, _, height, width = probabilities.shape
    return probabilities.argmax(dim=1).reshape(batch, 1, height, width)

def focal_loss(y_true, y_pred, alpha=(1.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0), gamma=2.0):
    """Class-weighted focal loss, averaged over every element.

    Args:
        y_true: integer class indices shaped as ``F.cross_entropy`` expects for
            ``y_pred`` (e.g. (N, L) for a (N, C, L) ``y_pred``).
        y_pred: unnormalised class scores (logits), class axis at dim 1.
            ``F.cross_entropy`` applies log_softmax itself, so do NOT pass
            probabilities.
        alpha: one weight per class (sequence of floats). The weights go through
            softmax, so only their differences matter.
        gamma: focusing exponent; ``gamma=0`` gives plain class-weighted CE.

    Returns:
        Scalar tensor: mean of ``alpha_t * (1 - p_t) ** gamma * CE``.
    """
    ce_loss = F.cross_entropy(y_pred, y_true, reduction='none')
    # CE = -log(p_t), so exp(-CE) recovers the probability of the true class.
    pt = torch.exp(-ce_loss)
    class_weights = torch.softmax(torch.tensor(alpha, device=y_true.device), dim=0)
    # Per-element weight of each element's true class (same shape as y_true).
    alpha_t = class_weights[y_true.long()]
    return (alpha_t * (1 - pt) ** gamma * ce_loss).mean()

def generator_loss(d_out,fake_mask, true_mask, lambda_ = 100, loss = "CE", with_logits = False):
    """Pix2pix generator objective: adversarial term + ``lambda_`` * reconstruction term.

    Args:
        d_out: discriminator output for (generated mask, image). Probabilities
            when ``with_logits`` is False (binary_cross_entropy), logits otherwise.
        fake_mask: generator output of shape (N, C, H, W); used as raw logits
            by the CE-based losses.
        true_mask: ground-truth labels with N*H*W elements (e.g. (N, 1, H, W)).
        lambda_: weight of the reconstruction term.
        loss: reconstruction loss, one of "L1", "CE", "WCE", "FL".
        with_logits: whether ``d_out`` holds logits.

    Returns:
        Scalar tensor ``adv_loss + lambda_ * recon_loss``.
    """
    assert loss in ["L1","CE", "FL","WCE"]
    # The generator wants its output labelled as real (1) by the discriminator.
    fake_target = torch.ones_like(d_out)
    adv_fn = F.binary_cross_entropy_with_logits if with_logits else F.binary_cross_entropy
    adv_loss = adv_fn(d_out, fake_target)

    if loss == "L1":
        # NOTE(review): masker takes an argmax, which has no gradient, so this
        # reconstruction term does not train the generator.
        fake_labels = masker(fake_mask)
        recon_loss = F.l1_loss(fake_labels.float(), true_mask.float())
    else:
        N, C, H, W = fake_mask.shape
        y_fake = fake_mask.reshape(N, C, H * W)
        y = true_mask.reshape(N, H * W).long()
        if loss == "CE":
            recon_loss = F.cross_entropy(y_fake, y)
        elif loss == "WCE":
            # Class 0 gets relative weight 1, every other class 5; softmax
            # normalises the weights (identical to the 9-class list for C == 9).
            weights = torch.tensor([1.0] + [5.0] * (C - 1), device=fake_mask.device)
            weights = torch.softmax(weights, dim=0)
            recon_loss = F.cross_entropy(y_fake, y, weight=weights)
        else:  # "FL"
            # focal_loss runs F.cross_entropy, which applies log_softmax itself,
            # so it must receive raw logits; softmaxing here would normalise twice.
            recon_loss = focal_loss(y, y_fake)

    return adv_loss + lambda_ * recon_loss


def discriminator_loss(real_pred,fake_pred, slow_down = 2):
    """Binary cross-entropy of the discriminator on real (target 1) and generated (target 0) pairs.

    Args:
        real_pred: discriminator probabilities for real pairs.
        fake_pred: discriminator probabilities for generated pairs.
        slow_down: divisor applied to the summed loss to damp the
            discriminator's updates (2 halves them).

    Returns:
        Scalar tensor ``(BCE(real, 1) + BCE(fake, 0)) / slow_down``.
    """
    fake_term = F.binary_cross_entropy(fake_pred, torch.zeros_like(fake_pred))
    real_term = F.binary_cross_entropy(real_pred, torch.ones_like(real_pred))
    return (real_term + fake_term) / slow_down
    
    
def combined_loss(gen_loss, disc_loss, gen_weight=0.5, disc_weight=0.5):
    """Weighted mix of each loss's share of the total ``gen_loss + disc_loss``.

    Raises ZeroDivisionError (for Python floats) when the total is zero.
    """
    total = gen_loss + disc_loss
    gen_share = gen_loss / total
    disc_share = disc_loss / total
    return gen_weight * gen_share + disc_weight * disc_share


In [None]:
PATH = "/kaggle/input/data-deloitte/Prossed_data_train_ny.npy"

# Dedicated generator so the train/val split is fixed by its own seed (42).
generator_seed = torch.Generator().manual_seed(42)

transform = transforms.Compose([transforms.ToTensor()])
dataset = CarSegData(data_root=PATH, transform=transform)

# Only the last N_SUBSET samples of the array are used; N_VAL of them are
# held out for validation, the rest are used for training.
N_SUBSET = 138
N_VAL = 8

dataset_length = len(dataset)
last_indices = range(dataset_length - N_SUBSET, dataset_length)
dataset2 = Subset(dataset, last_indices)

train_set, val_set = torch.utils.data.random_split(
    dataset2, [N_SUBSET - N_VAL, N_VAL], generator=generator_seed
)

# Alternative pre-split data:
#train_set = CarSegData(data_root='/kaggle/input/10percent/train10.npy', transform=transform)
#val_set = CarSegData(data_root='/kaggle/input/10percent/val_real.npy', transform=transform)
#print(len(train_set))

## Training

In [None]:
num_epochs = 1000

# Best-so-far trackers start at the worst possible value, so the first
# validated epoch always yields a checkpoint. A finite sentinel (e.g. 1000)
# might never be beaten, leaving nothing to save at the end of training.
best_val = float("inf")
best_mIoU = float("-inf")
loss = "FL"  # reconstruction loss for generator_loss: "L1", "CE", "WCE" or "FL"

wd = 0  # Adam weight decay

BATCH_SIZE = 1
LAMBDA = 100  # weight of the reconstruction term in the generator loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 9

generator = UNet_Generator(device, out_channels=num_classes).to(device)
discriminator = PatchGan_Discriminator(device).to(device)

# Running sums of per-batch losses over the whole run.
ge_loss = 0
de_loss = 0

train_loss_g = []
train_loss_d = []
IoU_loss = []

generator.train()
discriminator.train()

batch_size = BATCH_SIZE  # the pix2pix paper uses 1

# NOTE: not passed to the DataLoaders below, so batches are loaded in the
# main process (and worker_init_fn is never called).
num_workers = 2

# Separate learning rates and Adam optimizers for generator and discriminator.
gen_learning_rate = 0.001
disc_learning_rate = 0.00002

g_optim = torch.optim.Adam(generator.parameters(), lr=gen_learning_rate, betas=(0.5, 0.999), weight_decay=wd)
d_optim = torch.optim.Adam(discriminator.parameters(), lr=disc_learning_rate, betas=(0.5, 0.999), weight_decay=wd)

# Per-epoch means. Rows: generator loss, discriminator loss, IoU (train).
mean_losses = np.zeros(shape=(3, num_epochs))
# Per-epoch means. Rows: generator loss, IoU (validation).
mean_val_losses = np.zeros(shape=(2, num_epochs))

metric = JaccardIndex(task="multiclass", num_classes=num_classes)
metric = metric.to(device)


def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy and ``random`` from torch's per-worker seed.

    ``worker_init_fn`` must be a callable; an int such as 42 raises a
    TypeError as soon as worker processes are used.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, worker_init_fn=seed_worker)
# Validation metrics are averaged over batches, so a fixed order loses nothing.
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, worker_init_fn=seed_worker)

In [None]:
def snapshot_generator(epoch, val_loss, val_iou):
    """Return a checkpoint dict holding a frozen copy of the current generator.

    ``state_dict()`` of a module or optimizer returns references to live
    tensors that later optimizer steps update in place. Without copying,
    every "best" checkpoint would end up holding the final weights.
    """
    frozen = copy.deepcopy(generator)
    return {'epoch': epoch + 1,
            'generator': frozen,
            'loss': val_loss,
            'IoU': val_iou,
            # Shares storage with `frozen`, which is never trained further.
            'model_state_dict': frozen.state_dict(),
            'optimizer_state_dict': copy.deepcopy(g_optim.state_dict()),
            'loss_type': loss}


# Best-so-far checkpoints; they stay None if no epoch improves on
# best_val / best_mIoU from the config cell.
model_val = None
model_IoU = None
LOG_EVERY = 5  # print a progress line every LOG_EVERY epochs

for epoch in range(num_epochs):
    epoch_loss_g = []
    epoch_loss_d = []
    epoch_loss_IoU = []

    start_time = time.time()
    for img, mask in train_loader:
        img = img.to(device)
        # Add a channel axis: (N, H, W) labels -> (N, 1, H, W).
        mask = mask.to(device).unsqueeze(1)

        # --- Discriminator step: real pairs -> 1, generated pairs -> 0 ---
        fake_mask = generator(img)
        fake_pred = discriminator(fake_mask.detach(), img)
        real_pred = discriminator(mask, img)
        d_loss = discriminator_loss(real_pred, fake_pred)
        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        # --- Generator step: adversarial + LAMBDA * reconstruction ---
        # The discriminator is re-run after its update. g_loss.backward() also
        # fills discriminator grads; d_optim.zero_grad() clears them before the
        # next discriminator update.
        fake_pred = discriminator(fake_mask, img)
        g_loss = generator_loss(fake_pred, fake_mask, mask, lambda_=LAMBDA, loss=loss)
        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        ge_loss += g_loss.item()
        de_loss += d_loss.item()

        with torch.no_grad():
            IoU = metric(masker(fake_mask), mask)

        epoch_loss_g.append(g_loss.item())
        epoch_loss_d.append(d_loss.item())
        epoch_loss_IoU.append(IoU.item())

    # Epoch statistics are computed once, after the last batch.
    epoch_time = time.time() - start_time
    mean_losses[0][epoch] = np.mean(epoch_loss_g)
    mean_losses[1][epoch] = np.mean(epoch_loss_d)
    mean_losses[2][epoch] = np.mean(epoch_loss_IoU)
    IoU_loss.append(mean_losses[2][epoch])

    # Validation. The networks stay in train mode (the notebook never calls
    # .eval()), so dropout is active and BatchNorm uses batch statistics, as in
    # training. no_grad only avoids building an autograd graph.
    mIoU_val = []
    val_loss = []
    with torch.no_grad():
        for img, mask in val_loader:
            img = img.to(device)
            mask = mask.to(device).unsqueeze(1)

            fake_mask = generator(img)
            fake_pred = discriminator(fake_mask, img)
            g_loss_val = generator_loss(fake_pred, fake_mask, mask, lambda_=LAMBDA, loss=loss)
            IoU = metric(masker(fake_mask), mask)

            mIoU_val.append(IoU.item())
            val_loss.append(g_loss_val.item())

    mIoU_val = np.mean(mIoU_val)
    val_loss = np.mean(val_loss)

    # Keep the generator with the lowest validation loss ...
    if val_loss < best_val:
        model_val = snapshot_generator(epoch, val_loss, mIoU_val)
        best_val = val_loss
    # ... and the one with the highest validation IoU.
    if mIoU_val > best_mIoU:
        model_IoU = snapshot_generator(epoch, val_loss, mIoU_val)
        best_mIoU = mIoU_val

    mean_val_losses[0][epoch] = val_loss
    mean_val_losses[1][epoch] = mIoU_val

    if epoch % LOG_EVERY == 0:
        print("Epoch: {:3d} | Time: {:5.2f}s | ge_loss: {:5.3f} | de_loss: {:5.3f} | IoU: {:5.3f} | Val ge_loss : {:5.3f} | Val IoU: {:5.3f}".format(epoch, epoch_time,  mean_losses[0][epoch] , mean_losses[1][epoch], mean_losses[2][epoch],
                                                                                                                                                     mean_val_losses[0][epoch], mean_val_losses[1][epoch]))


# Save best models (skipped, with a message, if no checkpoint was recorded).
for checkpoint, tag in ((model_val, "val"), (model_IoU, "IoU")):
    if checkpoint is None:
        print(f"No '{tag}' checkpoint was recorded; nothing saved.")
    else:
        torch.save(checkpoint, "generator_" + tag + "_" + loss + ".pt")

# Save final models and training curves, all named after the active loss.
torch.save(generator, "generator_" + loss + ".pt")
torch.save(discriminator, "discriminator_" + loss + ".pt")
torch.save(IoU_loss, "IoU_" + loss + ".pt")  # per-epoch mean training IoU
torch.save(mean_losses, "mean_losses_" + loss + ".pt")
torch.save(mean_val_losses, "mean_val_losses_" + loss + ".pt")

In [None]:
def check_accuracy(loader, generator, cmap="tab10", num_classes=9):
    """Plot image / target / prediction for every batch of ``loader`` and return the mean IoU.

    Only the first sample of each batch is drawn; the IoU in its title is the
    notebook-level ``metric`` evaluated on the whole batch. The generator is
    used in whatever mode it is currently in (the training loop never
    switches it to eval mode).

    Args:
        loader: yields ``(x, y)`` with images (N, 3, H, W) and labels (N, H, W).
        generator: model mapping images to per-class scores (N, C, H, W).
        cmap: matplotlib colormap for the label maps.
        num_classes: fixes the colour scale to [0, num_classes - 1], so a given
            label has the same colour in the target and prediction panels.

    Returns:
        float: mean of the per-batch IoU values, or nan if ``loader`` is empty.
    """
    batch_ious = []
    label_style = dict(cmap=cmap, vmin=0, vmax=num_classes - 1)
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            # Integer class labels, as expected by the multiclass IoU metric.
            y = y.to(device).long()
            preds = torch.argmax(torch.softmax(generator(x), dim=1), dim=1)
            IoU = metric(preds, y).item()
            batch_ious.append(IoU)

            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 15))
            ax1.imshow(x[0].cpu().numpy().transpose(1, 2, 0))
            ax1.set_title('Original image')
            ax1.axis('off')
            ax2.imshow(y[0].cpu().numpy(), **label_style)
            ax2.set_title('Targets')
            ax2.axis('off')
            ax3.imshow(preds[0].cpu().numpy(), **label_style)
            ax3.set_title(f'Predictions, IoU: {IoU:.3}')
            ax3.axis('off')
            plt.show()
            # Free the figure; this loop can create one per batch.
            plt.close(fig)
    return float(np.mean(batch_ious)) if batch_ious else float("nan")

In [None]:
# Visual check on the training split: image / target / prediction per batch.
check_accuracy(loader=train_loader, generator=generator)

In [None]:
# Same visual check on the held-out validation split.
check_accuracy(loader=val_loader, generator=generator)