In [None]:
## Imports
import random
random.seed(10)
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
from torchmetrics import JaccardIndex
from matplotlib.colors import ListedColormap

In [None]:
## DataLoader
class CarSegData(Dataset):
    """Car-part segmentation samples stored in a single ``.npy`` array.

    The array is indexed as ``array[idx, :, :, :]``; the first three entries of
    the last axis are used as the image and entry ``3`` as the raw mask.

    Args:
        data_root: Path to the ``.npy`` file (passed straight to ``np.load``).
        transform: Optional callable applied to the PIL image of each sample.
    """

    def __init__(self, data_root, transform=None):
        self.data_root = data_root
        self.transform = transform
        # Raw mask value -> training label. Raw value 90 is folded into label 0,
        # and any raw value not listed here also ends up as 0.
        self.class_labels = {
            10: 1,
            20: 2,
            30: 3,
            40: 4,
            50: 5,
            60: 6,
            70: 7,
            80: 8,
            90: 0
        }
        # Training label -> human-readable name.
        # NOTE(review): key 9 is never produced by `class_labels` (90 maps to 0),
        # so "rest of car" looks like a stale entry - confirm before relying on it.
        self.classes = {
            1: "hood",
            2: "front door",
            3: "rear door",
            4: "frame",
            5: "rear quarter panel",
            6: "trunk lid",
            7: "fender",
            8: "bumper",
            9: "rest of car"
        }

        # The whole dataset lives in one .npy file that is loaded eagerly.
        self.array_files = np.load(data_root)

    def __len__(self):
        """Return the number of samples (size of the first array axis)."""
        return np.shape(self.array_files)[0]

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        ``image`` is the (optionally transformed) PIL image built from the first
        three channels; ``target`` is a tensor of remapped class labels.
        """
        array_data = self.array_files[idx, :, :, :]
        image_data = array_data[:, :, :3]
        raw_mask = array_data[:, :, 3]

        # Remap raw mask values to labels, then zero anything that is not a known class.
        target_data = self.map_to_class_labels(raw_mask)
        target_data = self.map_to_classes(target_data)

        # Convert to PIL image
        image = Image.fromarray(image_data.astype('uint8'))

        if self.transform:
            image = self.transform(image)

        return image, target_data

    @staticmethod
    def _as_numpy(target_data):
        """Return ``target_data`` as an ndarray, accepting tensors or array-likes."""
        if isinstance(target_data, torch.Tensor):
            return target_data.detach().cpu().numpy()
        return np.asarray(target_data)

    def map_to_classes(self, target_data):
        """Keep values that are keys of ``self.classes``; set everything else to 0.

        Accepts a NumPy array or a tensor (``__getitem__`` passes the tensor
        returned by ``map_to_class_labels``) and returns a tensor of the same dtype.
        """
        labels = self._as_numpy(target_data)
        kept = np.zeros_like(labels)
        for class_value in self.classes:
            kept[labels == class_value] = class_value
        return torch.from_numpy(kept)

    def map_to_class_labels(self, target_data):
        """Translate raw mask values to labels using ``self.class_labels``.

        Values without an entry in the mapping become 0. Accepts a NumPy array or
        a tensor and returns a tensor of the same dtype.
        """
        raw = self._as_numpy(target_data)
        remapped = np.zeros_like(raw)
        for old_label, new_label in self.class_labels.items():
            remapped[raw == old_label] = new_label
        return torch.from_numpy(remapped)

## Baseline

In [None]:
# Unet 
class DoubleConv(nn.Module):
    """Two consecutive 3x3 Conv2d -> BatchNorm2d -> ReLU -> Dropout2d stages.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage changes the channel count; the second keeps it.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout2d(),
            ])
        # Stored as one Sequential called `double_conv`, so parameter names keep
        # the `double_conv.<index>.*` form.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages and return the result."""
        return self.double_conv(x)

class UNet_new(nn.Module):
    """U-Net with four down/up levels built from ``DoubleConv`` blocks.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output map (one per class).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Contracting path: DoubleConv followed by a 2x2 max-pool at each level.
        self.enc1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck between the two paths.
        self.middle = DoubleConv(512, 1024)

        # Expansive path: a transposed conv halves the channels, then a DoubleConv
        # consumes the upsampled map concatenated with the matching encoder output.
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(128, 64)

        # 1x1 convolution producing the per-class output channels.
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Return the ``out_channels``-channel output map for input batch ``x``."""
        down_levels = (
            (self.enc1, self.pool1),
            (self.enc2, self.pool2),
            (self.enc3, self.pool3),
            (self.enc4, self.pool4),
        )
        up_levels = (
            (self.up1, self.dec1),
            (self.up2, self.dec2),
            (self.up3, self.dec3),
            (self.up4, self.dec4),
        )

        # Encoder: remember each pre-pool feature map for the skip connections.
        skips = []
        features = x
        for encode, pool in down_levels:
            features = encode(features)
            skips.append(features)
            features = pool(features)

        features = self.middle(features)

        # Decoder: the deepest skip is consumed first, hence the pop().
        for upsample, decode in up_levels:
            features = upsample(features)
            features = decode(torch.cat((features, skips.pop()), dim=1))

        return self.out(features)

## Pix2Pix

In [None]:
class EncoderBlock(nn.Module):
    """Conv2d -> optional BatchNorm2d -> LeakyReLU(0.2) block.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        device: Device on which the layers are created.
        kernel_size, stride, padding: Convolution hyper-parameters.
        norm: When false, the batch-norm step is skipped.
    """

    def __init__(self, in_channels, out_channels, device, kernel_size=4, stride=2, padding=1, norm=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, device=device)
        # The choice 0.2 is from the paper
        self.act = nn.LeakyReLU(0.2, inplace=False)

        self.use_norm = norm
        self.bn = nn.BatchNorm2d(out_channels, device=device) if norm else None

    def forward(self, x):
        """Apply convolution, batch norm (if enabled) and the activation."""
        features = self.conv(x)
        if self.use_norm:
            features = self.bn(features)
        return self.act(features)

    
class DecoderBlock(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> ReLU -> optional Dropout2d block.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the transposed convolution.
        device: Device on which the layers are created.
        kernel_size, stride, padding: Transposed-convolution hyper-parameters.
        dropout: When truthy, a ``Dropout2d(p=0.5)`` is applied after the
            activation. ``False`` (the default) and ``None`` both disable it.
    """

    def __init__(self, in_channels, out_channels, device,
                 kernel_size=4, stride=2, padding=1, dropout=False):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, device=device)

        self.act = nn.ReLU(inplace=False)
        self.bn = nn.BatchNorm2d(out_channels, device=device)

        # Truthiness check: the previous `dropout is not None` test was also true
        # for the default `dropout=False`, so dropout was enabled on every block.
        if dropout:
            self.dropout = nn.Dropout2d(p=0.5, inplace=False)  # p = 0.5 is from the paper
        else:
            self.dropout = None

    def forward(self, x):
        """Apply transposed conv, batch norm, ReLU and dropout (if enabled)."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)

        if self.dropout is not None:
            x = self.dropout(x)

        return x

In [None]:
# Encoder (layer index shown on the line below):
# C64-C128-C256-C512-C512-C512-C512-C512
#  1    2    3    4    5    6    7    8
# Decoder:
# CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
class UNet_Generator(nn.Module):
    """Eight-level encoder/decoder generator with skip connections.

    Args:
        device: Device on which all layers are created.
        input_channels: Channels of the input image.
        out_channels: Channels of the generated output map.
    """

    def __init__(self, device, input_channels=3, out_channels=9):
        super().__init__()
        # Encoder: the first and last blocks are built without batch norm.
        self.EB1 = EncoderBlock(input_channels, 64, norm=False, device=device)
        self.EB2 = EncoderBlock(64, 128, device=device)
        self.EB3 = EncoderBlock(128, 256, device=device)
        self.EB4 = EncoderBlock(256, 512, device=device)
        self.EB5 = EncoderBlock(512, 512, device=device)
        self.EB6 = EncoderBlock(512, 512, device=device)
        self.EB7 = EncoderBlock(512, 512, device=device)
        self.EB8 = EncoderBlock(512, 512, norm=False, device=device)

        # Decoder: every block after DB8 receives its input concatenated with the
        # matching encoder output, hence the doubled input channel counts.
        self.DB8 = DecoderBlock(512, 512, dropout=True, device=device)
        self.DB7 = DecoderBlock(2 * 512, 512, dropout=True, device=device)
        self.DB6 = DecoderBlock(2 * 512, 512, dropout=True, device=device)
        self.DB5 = DecoderBlock(2 * 512, 512, device=device)
        self.DB4 = DecoderBlock(2 * 512, 256, device=device)
        self.DB3 = DecoderBlock(2 * 256, 128, device=device)
        self.DB2 = DecoderBlock(2 * 128, 64, device=device)
        self.DB1 = nn.ConvTranspose2d(2 * 64, out_channels, kernel_size=4, stride=2, padding=1, device=device)

    def forward(self, x):
        """Return the ``out_channels``-channel output for input batch ``x``."""
        encoders = (self.EB1, self.EB2, self.EB3, self.EB4,
                    self.EB5, self.EB6, self.EB7, self.EB8)
        skip_decoders = (self.DB7, self.DB6, self.DB5, self.DB4,
                         self.DB3, self.DB2, self.DB1)

        # Encoder pass, keeping every intermediate output.
        encoded = []
        features = x
        for encoder in encoders:
            features = encoder(features)
            encoded.append(features)

        # The innermost encoding feeds DB8 directly; the rest serve as skips.
        bottleneck = encoded.pop()
        features = self.DB8(bottleneck)

        # Remaining decoders consume the skips from deepest to shallowest.
        for decoder in skip_decoders:
            features = decoder(torch.cat([features, encoded.pop()], dim=1))

        return features

In [None]:
#C64-C128-C256-C512
class DiscriminatorBlock(nn.Module):
    """Conv2d -> optional BatchNorm2d -> in-place LeakyReLU(0.2) block.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        device: Device on which the layers are created.
        kernel_size, stride, padding: Convolution hyper-parameters.
        norm: When false, the batch-norm step is skipped.
    """

    def __init__(self, in_channels, out_channels, device, kernel_size=4, stride=2, padding=1, norm=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, device=device)
        # The choice 0.2 is from the paper
        self.act = nn.LeakyReLU(0.2, inplace=True)

        self.use_norm = norm
        self.bn = nn.BatchNorm2d(out_channels, device=device) if norm else None

    def forward(self, x):
        """Apply convolution, batch norm (if enabled) and the activation."""
        features = self.conv(x)
        if self.use_norm:
            features = self.bn(features)
        return self.act(features)
        

class PatchGan_Discriminator(nn.Module):
    """Patch-wise discriminator scoring a mask together with its image.

    Args:
        device: Device on which all layers are created.
    """

    def __init__(self, device):
        super().__init__()
        # 3 + 1 input channels: the image concatenated with a single-channel mask.
        self.l1 = DiscriminatorBlock(3 + 1, 64, norm=False, device=device)
        self.l2 = DiscriminatorBlock(64, 128, device=device)
        self.l3 = DiscriminatorBlock(128, 256, device=device)
        self.l4 = DiscriminatorBlock(256, 512, device=device)
        self.l5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1, device=device)

    def forward(self, mask, image):
        """Return sigmoid scores for ``mask`` conditioned on ``image``."""
        # The discriminator is conditioned on the true image
        if mask.shape[1] > 1:
            # NOTE(review): `masker` is not defined in the visible code; presumably
            # it reduces a multi-channel mask to one channel - confirm it is in
            # scope before this branch is reached.
            mask = masker(mask)

        scores = torch.cat([mask, image], dim=1)
        for layer in (self.l1, self.l2, self.l3, self.l4, self.l5):
            scores = layer(scores)

        # The sigmoid squashes the final scores into the range 0..1
        return torch.sigmoid(scores)
    

## Plot results

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
in_channels = 3  # Input channels (RGB)
out_channels = 9  # Number of classes

# Baseline U-Net: rebuild the architecture, then restore the trained weights.
model_only_real = UNet_new(in_channels, out_channels)
model_only_real.to(device)

PATH_baseline = "/kaggle/input/model-eval/best_model_weights_only_real_data_batch_norm_dropout_no_rest_of_car"
# map_location remaps the stored tensors onto `device`, so a checkpoint written
# from CUDA tensors can still be restored when only the CPU is available.
checkpoint1 = torch.load(PATH_baseline, map_location=device)
model_only_real.load_state_dict(checkpoint1['model_state_dict'])

# Pix2pix generators: these files are loaded and used directly as models.
# NOTE: torch.load unpickles the file contents - only load checkpoints you trust.
PATH_p2p_FL = "/kaggle/input/model-eval/generator_FL_2000.pt"  # Path to pretrained models
g_FL = torch.load(PATH_p2p_FL, map_location=device)

PATH_p2p_CE = "/kaggle/input/model-eval/generator_CE.pt"  # Path to pretrained models
g_CE = torch.load(PATH_p2p_CE, map_location=device)

# Transformation to use on the test data
transform = transforms.Compose([
    transforms.ToTensor()  # Converts the image to a tensor
])

# Load the test data
path_test = '/kaggle/input/data-deloitte/Prossed_data_test_ny.npy'  ######### Change this ##########
testdata = CarSegData(data_root=path_test, transform=transform)
# No shuffling for evaluation: samples are visited in file order, so the plots
# come out in the same order on every run.
test_loader = DataLoader(testdata, batch_size=1, shuffle=False, num_workers=2)

In [None]:
# Colour used for each class index when drawing masks (index 0 is drawn black)
class_colors = [
    "black",
    "orange",
    "darkgreen",
    "yellow",
    "cyan",
    "purple",
    "lightgreen",
    "blue",
    "magenta",
]
# One integer class value per colour: 0 .. len(class_colors) - 1
class_values = list(range(len(class_colors)))
cmap = ListedColormap(class_colors, name='custom_colormap', N=len(class_colors))

In [None]:
def _show_predictions(image, target, panels):
    """Draw the input image, the target mask and one panel per prediction.

    Args:
        image: ``(H, W, C)`` array shown as the original image.
        target: 2-D array of target class labels.
        panels: Sequence of ``(title, 2-D label array)`` pairs.
    """
    # Pin the colour scale so class k always uses colour k of `cmap`; without
    # vmin/vmax imshow rescales each panel to its own min/max label.
    mask_style = dict(cmap=cmap, vmin=-0.5, vmax=cmap.N - 0.5, interpolation="nearest")

    fig, axes = plt.subplots(1, 2 + len(panels), figsize=(15, 15))
    axes[0].imshow(image)
    axes[0].set_title('Original image')
    axes[1].imshow(target, **mask_style)
    axes[1].set_title('Targets')
    for ax, (title, labels) in zip(axes[2:], panels):
        ax.imshow(labels, **mask_style)
        ax.set_title(title)
    for ax in axes:
        ax.axis('off')
    plt.show()
    plt.close(fig)  # release the figure; one is created per batch


def check_accuracy(loader, baseline, P2P_CE, P2P_FL, plot=False, num_classes=9):
    """Evaluate three segmentation models on ``loader`` with the Jaccard index.

    Each model is put in eval mode and run on every batch; the prediction is
    the argmax over the channel dimension of the model output.

    Args:
        loader: Iterable yielding ``(image_batch, target_batch)`` pairs.
        baseline: Baseline model.
        P2P_CE: Pix2Pix generator labelled "CE" in the plots.
        P2P_FL: Pix2Pix generator labelled "FL" in the plots.
        plot: When true, plot the first sample of every batch.
        num_classes: Number of classes passed to the Jaccard metric.

    Returns:
        dict mapping ``"baseline"``, ``"P2P_CE"`` and ``"P2P_FL"`` to the mean of
        the per-batch macro IoU values (``nan`` when the loader is empty).
    """
    models = {"baseline": baseline, "P2P_CE": P2P_CE, "P2P_FL": P2P_FL}
    plot_labels = {"baseline": "Baseline", "P2P_CE": "Pix2Pix CE", "P2P_FL": "Pix2Pix FL"}
    for model in models.values():
        model.eval()

    # Built once; calling the metric returns the score of the batch passed in.
    jaccard = JaccardIndex(task='multiclass', num_classes=num_classes, average='macro').to(device)
    scores = {name: [] for name in models}

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).long()  # integer class indices for the metric

            panels = []
            for name, model in models.items():
                # argmax of the raw outputs: softmax is monotonic, so applying it
                # first would not change which channel wins.
                preds = torch.argmax(model(x), dim=1)
                current = jaccard(preds, y).item()
                scores[name].append(current)
                if plot:
                    panels.append((f'{plot_labels[name]}, IoU: {current:.3}', preds[0].cpu().numpy()))

            if plot:
                _show_predictions(
                    x[0].cpu().numpy().transpose(1, 2, 0),
                    y[0].cpu().numpy(),
                    panels,
                )

    return {
        name: (sum(values) / len(values) if values else float("nan"))
        for name, values in scores.items()
    }

In [None]:
# Evaluate the U-Net trained on only real data and both Pix2Pix generators on the test data
check_accuracy(
    test_loader,
    baseline=model_only_real,
    P2P_CE=g_CE,
    P2P_FL=g_FL,
    plot=True,  # set plot=True to see the predictions
)