#1 - Implementing the model





In [1]:
import os
import glob
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
from matplotlib.colors import ListedColormap
from torch.utils.tensorboard import SummaryWriter
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Marks this notebook as a Google Colab run.
# NOTE(review): nothing in the visible code reads this flag — confirm it is still needed.
use_colab = True

Loading images

In [None]:
# Mount Google Drive at /content/drive so the dataset paths used below resolve.
# This import only exists inside a Google Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:

from matplotlib.colors import ListedColormap

# Work from the Drive root; os.chdir raises if the drive is not mounted.
os.chdir(r'/content/drive/MyDrive')

array_path = r'/content/drive/MyDrive/carseg_data/complete_arrays'
val_array_path = r'/content/drive/MyDrive/carseg_data/val_arrays'
# os.chdir() returns None, so report whether the data directories actually exist.
print("Checking existence ", os.path.isdir(array_path), os.path.isdir(val_array_path))

#CMAP for our segmentation according to the colours provided
# The colours are specified as 8-bit RGB triples; matplotlib requires floats
# in the 0-1 range, so they are rescaled before building the colormap.
_car_colours_8bit = [
    (255, 255, 255),
    (250, 149, 10), (19, 98, 19), (249, 249, 10),
    (10, 248, 250), (149, 7, 149), (5, 249, 9),
    (20, 19, 249), (249, 9, 250),
    (0, 0, 0),
]
car_cmap = ListedColormap(
    [tuple(channel / 255 for channel in colour) for colour in _car_colours_8bit]
)

## 2 - Defining the Dataset and making DataLoaders

In [5]:
class NumpyArrayDataset(Dataset):
    """Dataset over a directory of arrays that hold an image and its mask.

    Every entry of ``array_dir`` is loaded with ``np.load`` and indexed as
    ``[:, :, :3]`` (image channels) and ``[:, :, 3]`` (segmentation channel).

    Args:
        array_dir: Directory whose entries are the array files.
    """

    def __init__(self, array_dir):
        self.array_dir = array_dir
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping (and any seeded split built on it) reproducible.
        self.array_files = sorted(os.listdir(array_dir))

    def __len__(self):
        """Return the number of files found in ``array_dir``."""
        return len(self.array_files)

    def __getitem__(self, idx):
        """Load file ``idx`` and return ``{'car': ..., 'segm_mask': ...}``.

        ``'car'`` is a float tensor with the three image channels moved to the
        front (C, H, W); ``'segm_mask'`` is a long tensor of shape (H, W).
        """
        array_path = os.path.join(self.array_dir, self.array_files[idx])

        # Load the numpy array
        numpy_array = np.load(array_path)

        # Split the array into RGB and segmentation channels
        rgb_channels = numpy_array[:, :, :3]
        segm_channel = numpy_array[:, :, 3]

        # Channels-first float image and integer mask, as PyTorch expects
        rgb_channels_tensor = torch.from_numpy(rgb_channels.transpose(2, 0, 1)).float()
        segm_channel_tensor = torch.from_numpy(segm_channel).long()

        return {'car': rgb_channels_tensor, 'segm_mask': segm_channel_tensor}



In [6]:
#Define batch size
batch_size = 3

#Create an instance of the NumpyArrayDataset
numpy_dataset = NumpyArrayDataset(array_path)
val_dataset = NumpyArrayDataset(val_array_path)

#Split into training and validation subsets (75% / 25%).
# Only the indices are split: passing the dataset itself to train_test_split
# would load every array into memory up front. Subset keeps loading lazy and,
# with the same random_state, selects the same samples.
train_indices, holdout_indices = train_test_split(
    list(range(len(numpy_dataset))), train_size=0.75, random_state=42
)
train_dataset_np = torch.utils.data.Subset(numpy_dataset, train_indices)
val_dataset_np = torch.utils.data.Subset(numpy_dataset, holdout_indices)

# Create data loaders for the datasets
train_dl_np = DataLoader(train_dataset_np, batch_size=batch_size, shuffle=True, num_workers=2)
# Held-out quarter of the main dataset, used for per-epoch validation
validation_dl_np = DataLoader(val_dataset_np, batch_size=batch_size, shuffle=True, num_workers=2)
# Separate validation directory
val_dl_np = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:

# Display a few images together with their segmentation masks
num_images_to_display = 5
fig, axs = plt.subplots(2, num_images_to_display, figsize=(16, 8))

for i in range(num_images_to_display):
    # The dataset itself is not shuffled (only the dataloaders are), so stride
    # through it to see different images. The modulo keeps the index valid
    # when the dataset has fewer samples than the stride would reach.
    sample_index = (i + 700 * i) % len(numpy_dataset)
    data = numpy_dataset[sample_index]

    # Extract image (back to H, W, C for plotting) and segmentation mask
    image_data = data['car'].numpy().transpose(1, 2, 0)
    segmentation_mask = data['segm_mask'].numpy()

    # Display original image
    axs[0, i].set_title(f'Image {i+1}')
    axs[0, i].imshow(image_data.astype(np.uint8))

    # Display segmentation mask using the provided colormap
    axs[1, i].set_title(f'Segmentation Mask {i+1}')
    axs[1, i].imshow(segmentation_mask, cmap=car_cmap)

plt.tight_layout()
plt.show()

# 3 - U-Net implementation

In [32]:
class ConvLayer(nn.Module):
    """Down-sampling unit: Conv2d -> BatchNorm2d -> LeakyReLU(0.2).

    With the default kernel/stride/padding (4, 2, 1) the spatial size is halved.
    The ``activation`` argument is accepted but not used by this layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False, activation=None ):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        normalisation = nn.BatchNorm2d(out_channels)
        nonlinearity = nn.LeakyReLU(0.2, inplace=True)
        self.conv = nn.Sequential(convolution, normalisation, nonlinearity)

    def forward(self, x):
        """Apply the convolution, batch norm and leaky ReLU to ``x``."""
        out = self.conv(x)
        return out


class TransposeConvLayer(nn.Module):
    """Up-sampling unit: ConvTranspose2d -> BatchNorm2d, optionally followed by ReLU.

    With the default kernel/stride/padding (4, 2, 1) the spatial size is doubled.
    A ReLU is appended only when ``flag == 1``.
    """

    def __init__(self, in_channels, out_channels, flag, kernel_size=4, stride=2, padding=1, bias=False):
        super().__init__()
        upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        stages = [upsample, nn.BatchNorm2d(out_channels)]
        stages += [nn.ReLU(inplace=True)] if flag == 1 else []
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the transposed convolution stack to ``x``."""
        out = self.conv(x)
        return out


class UnetBlock(nn.Module):
    """One nesting level of the U-Net: down-sample, run the inner block, up-sample.

    Args:
        nf: Channel count produced by this block's up-sampling path (and by
            its 1x1 shortcut).
        ni: Channel count produced by this block's down-sampling convolution.
        submodule: The next-deeper block, placed between the down and up paths.
        input_c: Input channel count; defaults to ``nf``.
        dropout: Append ``Dropout(0.2)`` to the up path (intermediate blocks only).
        innermost: Build the bottleneck block (no submodule).
        outermost: Build the top block, whose output is returned without a shortcut.

    Every block except the outermost returns the 1x1-convolved input
    concatenated with the block output along the channel axis, i.e. ``2 * nf``
    channels.
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False, innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        input_c = nf if input_c is None else input_c

        downconv = ConvLayer(input_c, ni)
        if outermost:
            # Up path ends in ReLU (flag=1) followed by Tanh.
            stages = [
                downconv,
                submodule,
                TransposeConvLayer(ni * 2, nf, flag = 1),
                nn.Tanh(),
            ]
        elif innermost:
            # Bottleneck: nothing nested, so the up path takes ni channels.
            stages = [
                downconv,
                TransposeConvLayer(ni, nf, bias=False, flag = 0),
                nn.BatchNorm2d(nf),
            ]
        else:
            # The submodule returns a concatenation, hence ni * 2 input channels.
            stages = [
                downconv,
                nn.BatchNorm2d(ni),
                submodule,
                TransposeConvLayer(ni * 2, nf, bias=False, flag = 0),
                nn.BatchNorm2d(nf),
            ]
            if dropout:
                stages.append(nn.Dropout(0.2)) #tweak this
        self.model = nn.Sequential(*stages)

        # 1x1 shortcut projecting the block input to nf channels. It is
        # registered for every block, although forward() skips it when
        # outermost is set.
        self.residual = nn.Sequential(
            nn.Conv2d(input_c, nf, kernel_size=1, stride=1),
            nn.BatchNorm2d(nf)
        )

    def forward(self, x):
        """Run the block; non-outermost blocks prepend the shortcut channels."""
        if self.outermost:
            return self.model(x)

        shortcut = self.residual(x.clone())
        processed = self.model(x)
        return torch.cat([shortcut, processed], 1)
class Unet(nn.Module):
    """U-Net assembled from nested ``UnetBlock`` levels, built inside-out.

    Args:
        input_c: Number of input image channels.
        output_c: Number of output channels (one per segmentation class).
        n_down: Total number of down-sampling levels. One innermost block,
            three channel-halving blocks and the outermost block are always
            present; the remaining ``n_down - 5`` levels are constant-width
            blocks with dropout. The default of 8 reproduces the previous
            hard-coded depth.
        num_filters: Channel width of the outermost down-sampling convolution.

    Raises:
        ValueError: If ``n_down`` is smaller than the five fixed levels.
    """

    def __init__(self, input_c=3, output_c=10, n_down=8, num_filters=64):
        super(Unet, self).__init__()
        if n_down < 5:
            raise ValueError(f"n_down must be at least 5, got {n_down}")

        # Bottleneck at the widest channel count
        unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
        # Constant-width levels; n_down used to be ignored here (fixed at 3)
        for _ in range(n_down - 5):
            unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
        # Three levels that halve the channel count on the way out
        out_filters = num_filters * 8
        for _ in range(3):
            unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
            out_filters //= 2
        self.model = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)

    def forward(self, x):
        """Return the network output for the image batch ``x``."""
        return self.model(x)


## 4 - Weight initialization

In [33]:
def init_weights(net, init='kaiming', gain=1.414, print_message=True):
    """Re-initialise the parameters of ``net`` in place and return it.

    For modules whose class name contains ``Conv`` and that have a ``weight``:
    ``'norm'`` draws from N(0, gain), ``'xavier'`` uses Xavier-normal with
    ``gain`` and ``'kaiming'`` uses Kaiming-normal (fan-in, a=0). Any other
    ``init`` value leaves those weights untouched. Every non-None ``bias`` is
    set to zero, and ``BatchNorm2d`` weights are drawn from N(1, gain).
    """

    def init_func(m):
        classname = m.__class__.__name__
        is_conv = 'Conv' in classname and hasattr(m, 'weight')

        if is_conv and init == 'norm':
            nn.init.normal_(m.weight.data, mean=0.0, std=gain)
        elif is_conv and init == 'xavier':
            nn.init.xavier_normal_(m.weight.data, gain=gain)
        elif is_conv and init == 'kaiming':
            nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')

        # Zero every bias that exists, whatever the module type
        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias.data, 0.0)

        if 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    if print_message:
        print(f"model initialized with {init} initialization")
    return net

def init_model(model, device, init_type='kaiming', gain=1.414):
    """Move ``model`` to ``device``, initialise its weights and return it."""
    return init_weights(model.to(device), init=init_type, gain=gain)





## 5 - Losses for intermediate testing

In [34]:
class FocalLoss(nn.Module):
    """Focal loss built on per-element cross entropy.

    Each element's cross entropy ``ce`` is scaled by ``(1 - exp(-ce)) ** gamma``
    and, when ``alpha`` is given, multiplied by ``alpha``. The result is the
    mean over all elements.
    """

    def __init__(self, gamma=2, alpha=None):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, input, target):
        """Return the mean focal loss of logits ``input`` against ``target``."""
        per_element_ce = F.cross_entropy(input, target, reduction='none')
        # exp(-ce) is the probability the model assigned to the true class
        true_class_prob = torch.exp(-per_element_ce)
        modulation = (1 - true_class_prob) ** self.gamma
        weighted = modulation * per_element_ce

        if self.alpha is not None:
            weighted = self.alpha * weighted

        return weighted.mean()

In [35]:
class DiceLoss(nn.Module):
    """Mean per-class Dice score between arg-max predictions and integer labels.

    Despite the name this is an evaluation score (1.0 is a perfect match), not
    a differentiable loss: it uses a hard ``argmax`` and plain Python numbers.

    Args:
        num_classes: Class ids ``0 .. num_classes - 1`` are scored.
    """

    def __init__(self, num_classes):
        super(DiceLoss, self).__init__()
        self.num_classes = num_classes

    def forward(self, pred, label):
        """Return the Dice score averaged over the classes present in ``label``.

        Args:
            pred: Class scores with the class axis at dimension 1.
            label: Integer class ids, one per predicted position.

        Returns:
            ``numpy.float64`` mean of ``2 * |P & T| / (|P| + |T|)`` over the
            classes that occur in ``label``; NaN when none occurs.
        """
        # argmax over raw scores equals argmax over their softmax, so the
        # softmax is skipped. reshape(-1) also handles inputs with no spatial
        # dimensions, where the former squeeze(1) raised.
        predicted = torch.argmax(pred, dim=1).reshape(-1)
        target = label.reshape(-1)

        present_dice_list = []
        for sem_class in range(self.num_classes):
            target_inds = target == sem_class
            target_count = target_inds.sum().item()
            if target_count == 0:
                # Class absent from the labels: excluded from the average
                continue
            pred_inds = predicted == sem_class
            overlap = (pred_inds & target_inds).sum().item()
            total = pred_inds.sum().item() + target_count
            present_dice_list.append(2.0 * overlap / total)

        if not present_dice_list:
            # Same value np.mean([]) would give, without the RuntimeWarning
            return np.float64('nan')
        return np.mean(present_dice_list)



In [36]:
class IOU_Loss(nn.Module):
    """Mean per-class intersection-over-union of arg-max predictions vs labels.

    Despite the name this is an evaluation score (1.0 is a perfect match), not
    a differentiable loss: it uses a hard ``argmax`` and plain Python numbers.

    Args:
        num_classes: Class ids ``0 .. num_classes - 1`` are scored.
    """

    def __init__(self, num_classes):
        super(IOU_Loss, self).__init__()
        self.num_classes = num_classes

    def forward(self, pred, label):
        """Return the IoU averaged over the classes present in ``label``.

        Args:
            pred: Class scores with the class axis at dimension 1.
            label: Integer class ids, one per predicted position.

        Returns:
            ``numpy.float64`` mean of ``|P & T| / |P | T|`` over the classes
            that occur in ``label``; NaN when none occurs.
        """
        # argmax over raw scores equals argmax over their softmax, so the
        # softmax is skipped. reshape(-1) also handles inputs with no spatial
        # dimensions, where the former squeeze(1) raised.
        predicted = torch.argmax(pred, dim=1).reshape(-1)
        target = label.reshape(-1)

        present_iou_list = []
        for sem_class in range(self.num_classes):
            target_inds = target == sem_class
            target_count = target_inds.sum().item()
            if target_count == 0:
                # Class absent from the labels: excluded from the average
                continue
            pred_inds = predicted == sem_class
            intersection = (pred_inds & target_inds).sum().item()
            union = pred_inds.sum().item() + target_count - intersection
            present_iou_list.append(float(intersection) / float(union))

        if not present_iou_list:
            # Same value np.mean([]) would give, without the RuntimeWarning
            return np.float64('nan')
        return np.mean(present_iou_list)


## 6 - Main model

In [37]:
class MainModel(nn.Module):
    """Bundles the segmentation network, its optimiser and its loss functions.

    Args:
        net_G: Network to train; a freshly initialised ``Unet`` with 3 input
            and 10 output channels is built when omitted.
        lr_G: Learning rate for the SGD optimiser (momentum 0.9).
        beta1, beta2: Accepted for compatibility; not used by the SGD optimiser.
        lambda_L1: Stored on the instance; not read by any method of this class.
    """

    def __init__(self, net_G=None, lr_G=0.00001,
                 beta1=0.5, beta2=0.999, lambda_L1=1):
        super().__init__()
        self.lr = lr_G
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is None:
            self.net_G = init_model(Unet(input_c=3, output_c=10, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)

        self.opt_G = torch.optim.SGD(self.net_G.parameters(), lr=self.lr, momentum=0.9)
        print(self.opt_G.param_groups[-1]['lr'])

        # Focal loss drives the gradient; IoU is tracked as a metric only
        self.LossFunction2 = FocalLoss(gamma = 2, alpha = 0.25)
        self.LossFunction3 = IOU_Loss(num_classes = 10)

    def set_requires_grad(self, requires_grad=True):
        """Enable or disable gradients for every parameter of this module."""
        for p in self.parameters():
            # Previously hard-coded to True, which ignored the argument
            p.requires_grad = requires_grad

    def setup_input(self, data):
        """Move the ``'car'`` images and ``'segm_mask'`` labels of a batch to the device."""
        self.car = data['car'].to(self.device)
        self.segm_mask = data['segm_mask'].to(self.device)

    def forward(self):
        """Run the network on the current batch; stores ``predicted_segm_mask``."""
        self.predicted_segm_mask = self.net_G(self.car)

    def backward_G(self):
        """Compute focal loss and IoU for the current batch and backpropagate the focal loss."""
        # Labels are divided by 10 and truncated to integer class ids.
        # NOTE(review): presumably the masks store classes as 0, 10, ..., 90 — confirm.
        # self.segm_mask itself is left untouched so a second call does not
        # rescale the labels again.
        target = (self.segm_mask / 10).long()
        # .float()/.long() keep the tensors on their current device, unlike
        # .type(torch.FloatTensor), which copied them to the CPU.
        logits = self.predicted_segm_mask.float()
        self.focal_loss = self.LossFunction2(logits, target)
        self.iou_loss = self.LossFunction3(logits, target)
        self.focal_loss.backward()

    def optimize(self):
        """Run one optimisation step on the batch given to ``setup_input``."""
        # Switch to training mode before the forward pass so batch norm and
        # dropout behave correctly even if the model was last in eval mode.
        self.net_G.train()
        self.forward()
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()



The MeasureClass class is a utility class for computing and tracking the average of a value over multiple iterations. It keeps track of the count, sum, and average of the values. The reset method resets the meter, while the update method updates the meter with a new value and count.



## 7 - MeasureClass

In [38]:
class MeasureClass:
    """Running weighted average: tracks ``count``, ``sum`` and ``avg``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set count, average and sum back to 0.0."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Add ``val`` with weight ``count`` and refresh the average."""
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

def create_loss_meters():
    """Return fresh meters keyed by the loss attribute names of the model.

    The keys (``'focal_loss'``, ``'iou_loss'``) are looked up on the model
    with ``getattr`` by ``update_losses``. The unused ``seg_loss`` meter that
    was built and discarded has been removed.
    """
    return {
        'focal_loss': MeasureClass(),
        'iou_loss': MeasureClass(),
    }

def update_losses(model, loss_meter_dict, count, writer, step):
    """Feed the model's current loss values into the meters and log the averages.

    For every key in ``loss_meter_dict`` the attribute of the same name is
    read from ``model``, added to the meter with weight ``count``, and the
    meter's running average is written to ``writer`` at ``step``.
    """
    for loss_name in loss_meter_dict:
        meter = loss_meter_dict[loss_name]
        current_value = getattr(model, loss_name).item()
        meter.update(current_value, count=count)
        writer.add_scalar(loss_name, meter.avg, step)

def log_results(loss_meter_dict, step, writer, num_epoch):
    """Write every meter's running average to ``writer`` at ``step``.

    ``num_epoch`` is accepted but not used.
    """
    for tag in loss_meter_dict:
        running_average = loss_meter_dict[tag].avg
        writer.add_scalar(tag, running_average, step)

def plot_results(loss_list, val_loss_list, num_epochs):
    """Plot training loss and the first validation-loss series against epoch number.

    ``loss_list`` and ``val_loss_list[0]`` must each hold ``num_epochs`` values.
    """
    epoch_axis = range(1, num_epochs + 1)
    validation_series = val_loss_list[0]

    plt.figure(figsize=(8, 6))
    plt.plot(epoch_axis, loss_list, label='Focal Loss')
    plt.plot(epoch_axis, validation_series, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Total Loss Over Epochs')
    plt.legend()
    plt.show()

## 8 - Training the model

In [39]:
def validate_model(model, val_loader,writer,step):
    """Evaluate ``model`` on ``val_loader`` and return ``(focal_avg, iou_avg)``.

    The losses are computed here for every validation batch. Previously only
    the forward pass ran, so the meters were fed the stale ``focal_loss`` /
    ``iou_loss`` left over from the last training batch.

    Args:
        model: Object exposing ``setup_input``, ``forward``, ``segm_mask``,
            ``predicted_segm_mask``, ``LossFunction2`` and ``LossFunction3``.
        val_loader: Iterable of batches with ``'car'`` and ``'segm_mask'``.
        writer: Summary writer; averages are logged under ``val_<loss name>``
            so they do not overwrite the training curves.
        step: Global step used for the logged scalars.

    Returns:
        Tuple of sample-weighted average focal loss and IoU; ``(0.0, 0.0)``
        when the loader yields no batches.
    """
    model.eval()  # Set model to evaluation mode
    val_loss_meter = create_loss_meters()  # Create loss meter for validation

    with torch.no_grad():  # no gradients are needed for evaluation
        for data in val_loader:
            model.setup_input(data)
            model.forward()

            # Same label conversion as MainModel.backward_G: divide by 10 and
            # truncate to integer class ids.
            target = (model.segm_mask / 10).long()
            logits = model.predicted_segm_mask.float()
            batch_losses = {
                'focal_loss': model.LossFunction2(logits, target),
                'iou_loss': model.LossFunction3(logits, target),
            }

            batch_count = data['car'].size(0)
            for loss_name, loss_meter in val_loss_meter.items():
                loss_meter.update(batch_losses[loss_name].item(), count=batch_count)

    for loss_name, loss_meter in val_loss_meter.items():
        writer.add_scalar(f'val_{loss_name}', loss_meter.avg, step)

    vl1 = val_loss_meter['focal_loss'].avg
    vl3 = val_loss_meter['iou_loss'].avg
    return (vl1,vl3)

In [None]:
loss_list = [] #For plotting the loss
val_loss_list = [[],[]]  # [focal losses, iou scores], one entry per epoch


def _set_learning_rate(model, lr):
    """Store ``lr`` on the model and apply it to the existing optimiser in place."""
    model.lr = lr
    # Editing param_groups keeps the optimiser's momentum buffers; building a
    # new SGD instance on every change discarded them.
    for group in model.opt_G.param_groups:
        group['lr'] = lr


def train_model(model, train_loader, epochs, val_loader=None):
    """Train ``model`` for ``epochs`` epochs, validating and plotting after each.

    Per-epoch results are appended to the module-level ``loss_list`` and
    ``val_loss_list``. The learning rate is multiplied by 5 (while staying
    below 0.1) after ``patience`` epochs without a 0.02 improvement, and
    halved every 5 epochs. A checkpoint is saved every 10th epoch after
    epoch 50.

    Args:
        model: ``MainModel``-like object (``setup_input``, ``optimize``,
            ``opt_G``, ``lr`` ...).
        train_loader: Iterable of training batches.
        epochs: Number of epochs to run.
        val_loader: Validation batches; defaults to the module-level
            ``validation_dl_np`` that was previously hard-wired.
    """
    if val_loader is None:
        val_loader = validation_dl_np

    writer = SummaryWriter()
    step = 0
    best_loss = 100
    counter = 0  # epochs since the last improvement
    lr_ctr = 0   # epochs since the last scheduled halving
    patience = 3
    try:
        for epoch in range(1, epochs + 1):
            model.train()
            print('Epoch ', epoch)
            loss_meter_dict = create_loss_meters()
            avg_loss_list = []
            for data in tqdm(train_loader):
                model.set_requires_grad()
                model.setup_input(data)
                model.optimize()
                update_losses(model, loss_meter_dict, count=data['car'].size(0), writer=writer, step=step)
                step += 1
                avg_loss_list.append(loss_meter_dict['focal_loss'].avg)
                log_results(loss_meter_dict, step, writer, epoch)

            # NOTE(review): this is the mean of the per-step running averages,
            # not the plain epoch mean — kept as is because the LR schedule
            # thresholds below were tuned against it.
            loss_epoch = sum(avg_loss_list)/len(avg_loss_list)

            loss_list.append(loss_epoch)

            (a,c) = validate_model(model, val_loader, writer,step)
            print(f'Validation Losses (focal, iou): {a,c}')
            val_loss_list[0].append(a)
            val_loss_list[1].append(c)

            print('Epoch ', epoch)
            lr_ctr+=1
            print(loss_epoch, best_loss)
            if (epoch%10 == 0) and epoch > 50:
                torch.save(model.state_dict(), os.path.join(r'/content/drive/MyDrive/', 'epoch-{}.pt'.format(epoch)))
            if loss_epoch < best_loss - 0.02:
                best_loss = loss_epoch
                print(best_loss, loss_epoch)
                counter = 0  # Reset counter
            else:
                counter += 1

            # Plateau: raise the learning rate, capped below 0.1
            if counter >= patience:
                if model.lr*5 < 0.1:
                    _set_learning_rate(model, model.lr*5)
                    print('mult learning rate ', model.lr)
                counter = 0

            # Scheduled decay every 5 epochs
            if lr_ctr == 5:
                lr_ctr = 0
                _set_learning_rate(model, model.lr/2)
                print('div lr ', model.lr)
            plot_results(loss_list, val_loss_list, epoch)
    finally:
        # Flush and release the event file even if training is interrupted
        writer.close()


# Build the model with its default network and run training.
model = MainModel()
# NOTE(review): train_model only saves checkpoints for epochs above 50, so none
# are written during a 30-epoch run — confirm this is intended.
num_epochs = 30
train_model(model, train_dl_np, num_epochs)

# Final loss curves over all epochs
plot_results(loss_list, val_loss_list, num_epochs)

In [41]:
# Persist the trained weights to Google Drive.
torch.save(model.state_dict(), os.path.join(r'/content/drive/MyDrive/', 'model_trained.pt'))

## 9 - Visualisation of the results

In [None]:

def visualize_segmentation(valid_dl, num_samples=5, checkpoint_path="/content/drive/MyDrive/epoch-60.pt"):
    """Plot image, ground-truth mask and predicted mask for ``num_samples`` samples.

    A ``MainModel`` is built, loaded from ``checkpoint_path`` and run on
    batches from ``valid_dl``. IoU and Dice are printed per displayed sample.

    Args:
        valid_dl: Iterable of batches with ``'car'`` and ``'segm_mask'``.
        num_samples: Number of samples to display. Previously ignored: a fixed
            15 batches x 2 samples were shown.
        checkpoint_path: State-dict file to load; defaults to the path that
            used to be hard-coded.
    """
    num_classes = 10
    model1 = MainModel()
    # map_location lets a GPU-saved checkpoint load on a CPU-only runtime
    model1.load_state_dict(torch.load(checkpoint_path, map_location=model1.device))
    model1.eval()
    iou_metric = IOU_Loss(num_classes = num_classes)
    dice_metric = DiceLoss(num_classes = num_classes)

    shown = 0
    with torch.no_grad():
        for data in valid_dl:
            if shown >= num_samples:
                break
            # One forward pass per batch (it used to be repeated per sample)
            model1.setup_input(data)
            model1.forward()
            logits = model1.predicted_segm_mask.float().cpu()
            # Same label conversion as training: divide by 10, truncate to
            # integer class ids. The metrics were previously fed the unscaled
            # mask, which cannot match predicted ids 0..9.
            # NOTE(review): presumably masks store classes as 0, 10, ..., 90 — confirm.
            targets = (model1.segm_mask / 10).long().cpu()

            for i in range(logits.size(0)):
                if shown >= num_samples:
                    break
                shown += 1

                car_scans = data['car'][i].cpu().numpy()
                segms = data['segm_mask'][i].cpu().numpy()
                # Metrics for this sample only, matching the figure shown below
                print('IOU metric: ', iou_metric(logits[i:i + 1], targets[i:i + 1]))
                print('Dice metric: ', dice_metric(logits[i:i + 1], targets[i:i + 1]))

                # Fixed colour range so both mask panels map a class to the
                # same colour even when a class is missing from one of them.
                mask_max = 10 * (num_classes - 1)

                plt.figure(figsize=(12, 4))
                # Display the original car image
                plt.subplot(1, 3, 1)
                plt.title('Car image')
                plt.imshow(car_scans.transpose(1,2,0).astype(np.uint8))
                plt.axis('off')

                # Display the actual segmentation
                plt.subplot(1, 3, 2)
                plt.title('Actual segmentation')
                plt.imshow(segms, cmap=car_cmap, vmin=0, vmax=mask_max)
                plt.colorbar()
                plt.axis('off')

                # Display the predicted segmentation mask, rescaled to the
                # ground-truth value range
                plt.subplot(1, 3, 3)
                plt.title('Predicted segmentation')
                predicted_masks = np.argmax(logits[i].numpy(), axis = 0)*10
                plt.imshow(predicted_masks, cmap=car_cmap, vmin=0, vmax=mask_max)
                plt.axis('on')
                plt.colorbar()
                plt.tight_layout()
                plt.show()

# NOTE(review): this passes the *training* loader although the parameter is
# named valid_dl — confirm whether validation_dl_np / val_dl_np was intended.
visualize_segmentation(train_dl_np, num_samples=5)

