In [None]:
## Imports
import random
random.seed(10)
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import matplotlib as plt
from torchmetrics import JaccardIndex
from matplotlib.colors import ListedColormap

In [None]:
## Dataset
class CarSegData(Dataset):
    """Car-part segmentation dataset backed by a single ``.npy`` array.

    Each sample is ``array[idx, :, :, :]``. Channels ``0:3`` of the last axis are
    used as the image and channel ``3`` as the raw segmentation mask (layout
    presumably ``(N, H, W, 4)`` -- TODO confirm against the preprocessing step).

    Args:
        data_root: Path (or file-like object) accepted by ``np.load``.
        transform: Optional callable applied to the PIL image in ``__getitem__``.
    """

    def __init__(self, data_root, transform=None):
        self.data_root = data_root
        self.transform = transform
        # Raw mask values -> class ids. Raw 90 ("rest of car") is folded into 0,
        # so class id 9 is never produced by this mapping.
        self.class_labels = {
            10: 1,
            20: 2,
            30: 3,
            40: 4,
            50: 5,
            60: 6,
            70: 7,
            80: 8,
            90: 0
        }
        # Human-readable names of the class ids (id 0 has no entry here).
        self.classes = {
            1: "hood",
            2: "front door",
            3: "rear door",
            4: "frame",
            5: "rear quarter panel",
            6: "trunk lid",
            7: "fender",
            8: "bumper",
            9: "rest of car"
        }

        # The whole dataset is a single .npy array, loaded eagerly into memory.
        self.array_files = np.load(data_root)

    def __len__(self):
        """Number of samples (size of the first array axis)."""
        return len(self.array_files)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        ``image`` is a PIL image, or whatever ``self.transform`` returns for it.
        ``target`` is a torch tensor of class ids with the source array's dtype.
        """
        array_data = self.array_files[idx, :, :, :]
        image_data = array_data[:, :, :3]
        target_data = array_data[:, :, 3]

        # Raw mask values -> class ids, then zero anything that is not a known id.
        target_data = self.map_to_class_labels(target_data)
        target_data = self.map_to_classes(target_data)

        # PIL expects 8-bit pixel data.
        image = Image.fromarray(image_data.astype('uint8'))

        if self.transform:
            image = self.transform(image)

        return image, target_data

    def map_to_classes(self, target_data):
        """Keep pixels whose value is a key of ``self.classes``; set all others to 0.

        Accepts a NumPy array or a CPU torch tensor. Returns a torch tensor with
        the same dtype as the input.
        """
        # Normalise to NumPy so we never index a NumPy array with a torch tensor.
        target_np = np.asarray(target_data)
        class_labels = np.zeros_like(target_np)
        for class_value in self.classes:
            class_labels[target_np == class_value] = class_value
        return torch.from_numpy(class_labels)

    def map_to_class_labels(self, target_data):
        """Translate raw mask values via ``self.class_labels``; unknown values become 0.

        Accepts a NumPy array or a CPU torch tensor. Returns a torch tensor with
        the same dtype as the input.
        """
        target_np = np.asarray(target_data)
        class_labels = np.zeros_like(target_np)
        for old_label, new_label in self.class_labels.items():
            class_labels[target_np == old_label] = new_label
        return torch.from_numpy(class_labels)

In [None]:
# U-Net model
class DoubleConv(nn.Module):
    """Two consecutive (Conv3x3 -> BatchNorm -> ReLU -> Dropout2d) stages.

    ``padding=1`` keeps the spatial size unchanged. The layers sit in
    ``self.double_conv`` at indices 0-7, so checkpoint ``state_dict`` keys
    (``double_conv.0.weight``, ``double_conv.1.running_mean``, ...) are stable.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout2d(),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages."""
        features = self.double_conv(x)
        return features

class UNet_new(nn.Module):
    """U-Net with four down-/up-sampling stages and skip connections.

    The submodule attribute names (``enc1`` ... ``out``) form the checkpoint
    ``state_dict`` keys and must not be renamed.

    Inputs of any spatial size are accepted. When H or W is not divisible by 16,
    the upsampled maps are zero-padded to match the encoder skip tensors, so the
    output keeps the input's spatial size.

    Args:
        in_channels: Number of input image channels.
        out_channels: Number of output maps (one logit map per class).
    """

    def __init__(self, in_channels, out_channels):
        super(UNet_new, self).__init__()

        # Encoder (contracting path)
        self.enc1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Middle layer
        self.middle = DoubleConv(512, 1024)

        # Decoder (expansive path); each dec block sees upsampled + skip channels
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(128, 64)

        # Output layer: 1x1 conv to per-class logits
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _up_and_concat(up, skip):
        """Pad ``up`` to ``skip``'s H/W if needed, then concatenate ``(up, skip)`` on channels.

        MaxPool2d floors odd sizes, so ``up`` can be one pixel smaller than
        ``skip``. Padding is a no-op when the sizes already match.
        """
        diff_y = skip.size(2) - up.size(2)
        diff_x = skip.size(3) - up.size(3)
        if diff_y or diff_x:
            up = nn.functional.pad(
                up, [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2]
            )
        # Order (up, skip) must match the order used when the weights were trained.
        return torch.cat((up, skip), dim=1)

    def forward(self, x):
        """Return per-class logits with the same spatial size as ``x``."""
        # Encoder: keep each stage's output for the skip connections
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        enc3 = self.enc3(self.pool2(enc2))
        enc4 = self.enc4(self.pool3(enc3))

        # Middle layer
        middle = self.middle(self.pool4(enc4))

        # Decoder with skip connections
        dec1 = self.dec1(self._up_and_concat(self.up1(middle), enc4))
        dec2 = self.dec2(self._up_and_concat(self.up2(dec1), enc3))
        dec3 = self.dec3(self._up_and_concat(self.up3(dec2), enc2))
        dec4 = self.dec4(self._up_and_concat(self.up4(dec3), enc1))

        # Output layer
        return self.out(dec4)

In [None]:
# Evaluation setup: build both U-Nets, load their trained weights, and create the test loader.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
in_channels = 3  # Input channels (RGB)
out_channels = 9  # Number of classes
model_only_real = UNet_new(in_channels, out_channels)
model_10p_fake = UNet_new(in_channels, out_channels)
model_only_real.to(device)
model_10p_fake.to(device)

# Directory holding the checkpoint files
previous_model_file_root = "data_handin" ######### Change this ##########

# Checkpoint file names; each name says which training setup it belongs to
previous_model1 = 'best_model_weights_only_real_data_batch_norm_dropout_no_rest_of_car'
previous_model2 = 'best_model_weights_10p_real_val'

# map_location remaps tensors saved on a GPU to the current device, so the
# checkpoints also load on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint1 = torch.load(f"{previous_model_file_root}/{previous_model1}", map_location=device)
checkpoint2 = torch.load(f"{previous_model_file_root}/{previous_model2}", map_location=device)

# Give each model the checkpoint whose file name matches it:
# previous_model1 ("only_real_data") -> model_only_real,
# previous_model2 ("10p") -> model_10p_fake.
model_only_real.load_state_dict(checkpoint1['model_state_dict'])
model_10p_fake.load_state_dict(checkpoint2['model_state_dict'])

# Transformation applied to the test images (PIL image -> tensor)
transform = transforms.Compose([
    transforms.ToTensor()  # Converts the image to a tensor
])

# Load the test data. Forward slashes work on every OS and avoid the invalid "\P" escape.
path_test = 'data_handin/Prossed_data_test_ny.npy' ######### Change this ##########
testdata = CarSegData(data_root=path_test, transform=transform)
# Evaluation needs no shuffling, and a fixed order keeps plotted samples reproducible.
# num_workers=0 loads in this process: workers started with "spawn" (Windows/macOS)
# cannot re-import classes that are defined inside a notebook.
test_loader = DataLoader(testdata, batch_size=1, shuffle=False, num_workers=0)

In [None]:
# Colormap for segmentation masks: one colour per class id 0..8, in order.
class_colors = ["black","orange", "darkgreen", "yellow", "cyan", "purple", "lightgreen", "blue", "magenta"]
n_class_colors = len(class_colors)
class_values = list(range(n_class_colors))
cmap = ListedColormap(class_colors, name='custom_colormap', N=n_class_colors)

def check_accuracy(loader, model, plot=False, num_classes=9):
    """Evaluate ``model`` on ``loader`` and print the mean per-batch IoU.

    For each batch, predictions are the arg-max over dim 1 of the model output.
    A macro-averaged multiclass Jaccard index (IoU) is computed against the
    target. The printed score is the mean of these per-batch values, which is
    per-image when ``batch_size=1``.

    Args:
        loader: Iterable of ``(image, target)`` batches.
        model: Network whose output holds one logit map per class on dim 1.
        plot: If True, show image / target / prediction for the first sample of
            every batch.
        num_classes: Number of classes passed to the IoU metric.

    Returns:
        The mean IoU as a float, or ``nan`` if ``loader`` yielded no batches.
    """
    # Local import: the notebook's top-level `plt` is bound to the `matplotlib`
    # package, which has no `subplots`/`show`; plotting needs `matplotlib.pyplot`.
    import matplotlib.pyplot as plt

    model.eval()
    # Build the metric once; calling it on a batch returns that batch's score.
    jaccard = JaccardIndex(task='multiclass', num_classes=num_classes, average='macro').to(device)
    # Fixed colour range: class id k always gets colour k in both mask panels,
    # whichever classes happen to occur in a given image.
    mask_style = dict(cmap=cmap, vmin=-0.5, vmax=cmap.N - 0.5)
    IoU = []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            # Targets are class indices, so keep them integral for the metric.
            y = y.to(device).long()
            # argmax of the logits equals argmax of their softmax (softmax is monotonic).
            preds = torch.argmax(model(x), dim=1)
            curr = jaccard(preds, y).item()
            IoU.append(curr)
            if plot:
                fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 15))
                ax1.imshow(x[0, :, :, :].cpu().numpy().transpose(1, 2, 0))
                ax1.set_title('Original image')
                ax1.axis('off')
                ax2.imshow(y[0, :, :].cpu().numpy(), **mask_style)
                ax2.set_title('Targets')
                ax2.axis('off')
                ax3.imshow(preds[0, :, :].cpu().numpy(), **mask_style)
                ax3.set_title(f'Predictions, IoU: {curr:.3}')
                ax3.axis('off')
                plt.show()
                plt.close(fig)  # avoid accumulating open figures across the loop
    # Guard against an empty loader instead of dividing by zero.
    mean_iou = float(np.mean(IoU)) if IoU else float('nan')
    print(f"Average IoU score: {np.round(mean_iou, 3)}")
    return mean_iou

In [None]:
# Evaluate the U-Net trained on real data only, on the test split (pass plot=True to visualise predictions)
check_accuracy(loader=test_loader, model=model_only_real, plot=False)

In [None]:
# Evaluate the U-Net trained with 10% CAD data, on the test split (pass plot=True to visualise predictions)
check_accuracy(loader=test_loader, model=model_10p_fake, plot=False)