In [None]:
import mlflow

import os

# Tracking server for logging. Honour the standard MLFLOW_TRACKING_URI
# environment variable so the notebook can target another server without a
# code edit; fall back to the local development server otherwise.
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:8005")
mlflow.set_tracking_uri(uri=MLFLOW_TRACKING_URI)

# Activate the experiment runs are logged under (MLflow creates it if missing)
mlflow.set_experiment("Car Segmentation")

# Autologging stays off: mlflow.pytorch.autolog() only hooks PyTorch Lightning
# training, not the hand-written loop in this notebook, so metrics are logged
# manually in train_fn / val_fn instead.
# mlflow.pytorch.autolog()

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import os
from tqdm import tqdm
import torch.optim as optim
import multiprocessing
import numpy as np
import kornia
from kornia.augmentation import *
import cv2
import copy
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from typing import Any, Dict, Optional
from monai.losses.dice import *

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Hyperparameters
dataset_path = "../Data/car_dataset"

# Seed the RNGs this notebook draws from so runs are repeatable: torch drives
# weight init, Kornia's augmentation sampling and DataLoader shuffling; numpy
# is seeded for completeness. (cuDNN kernels may still be non-deterministic.)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
NUM_EPOCHS = 150
NUM_WORKERS = multiprocessing.cpu_count()  # NOTE(review): main() currently passes num_workers=0 instead
IMAGE_SIZE = 512  # images and masks are resized to IMAGE_SIZE x IMAGE_SIZE
PIN_MEMORY = True  # NOTE(review): main() currently passes pin_memory=False instead
LOAD_MODEL = True  # resume model weights from the checkpoint file when it exists

LOSS_NAME = 'diceceloss'  # also part of the checkpoint file name
# MONAI DiceCELoss on raw logits: softmax=True applies a softmax before the
# Dice term (the CE term consumes logits directly); total = 0.1*Dice + 0.9*CE.
LOSS_FUNCTION = DiceCELoss(include_background = True, softmax= True, lambda_dice=0.1, lambda_ce=0.9)

# Augmentation config consumed by DataAugmentation: each dict's "name" is the
# class to instantiate (a Kornia augmentation from the star import, or
# CustomPadding); every other key is passed to it as a keyword argument.
AUGMENTATIONS = [
            {
                "name":"RandomAffine",
                "degrees":360,
                "align_corners":True,
                "p":0.6
            },
            {
                "name":"RandomHorizontalFlip",
                "p":0.6
            },
            {
                "name":"RandomVerticalFlip",
                "p":0.6
            },
            {
                "name":"RandomRotation",
                "degrees":360,
                "p":0.6
            }
            # {
            #     "name":"CustomPadding",
            #     "padding":100,
            #     "p":0.6
            # }
            ]

In [None]:
# Pre-processing
class PreProcess(torch.nn.Module):
    '''
    Convert a numpy image array into a float32 torch tensor.

    A channel-last (H x W x C) array becomes channel-first (C x H x W);
    a 2-D (H x W) array gets a leading channel axis (1 x H x W).
    No batch dimension is added.
    '''

    def __init__(self):
        super().__init__()

    @torch.no_grad()  # pure data conversion; no autograd graph needed
    def forward(self, x: np.ndarray) -> torch.Tensor:
        # np.asarray does not copy when x is already an ndarray
        image_hwc: np.ndarray = np.asarray(x)
        # keepdim=True: keep (C, H, W) instead of prepending a batch axis
        tensor_chw: torch.Tensor = kornia.image_to_tensor(image_hwc, keepdim=True)

        return tensor_chw.float()


In [None]:
# Dataset Class
class SegmentationDataset(Dataset):
    '''
    Image / mask dataset for multi-class car-part segmentation.

    Images and masks are paired by position after sorting the file names in
    ``<dirPath>/<imageDir>`` and ``<dirPath>/<masksDir>``. Both folders must
    therefore hold the same number of files, in matching sort order.

    Each item is ``(img, output_mask)``:
      - img: float tensor (3, img_size, img_size), min-max scaled to [0, 1],
        channels in OpenCV's BGR order.
      - output_mask: float one-hot tensor (num_classes, img_size, img_size).

    Class ids stored in the mask files:
        0 - background
        1 - car
        2 - wheel
        3 - light
        4 - windows
    '''

    def __init__(self, dirPath= r'../data', imageDir='images', masksDir='masks', img_size=512, num_classes=5):
        self.imgDirPath = os.path.join(dirPath, imageDir)
        self.maskDirPath = os.path.join(dirPath, masksDir)
        self.img_size = img_size
        self.num_classes = num_classes
        self.nameImgFile = sorted(os.listdir(self.imgDirPath))
        self.nameMaskFile = sorted(os.listdir(self.maskDirPath))
        # Pairing is purely positional, so a count mismatch would silently
        # shift images onto the wrong masks.
        if len(self.nameImgFile) != len(self.nameMaskFile):
            raise ValueError(
                f"Found {len(self.nameImgFile)} images in {self.imgDirPath} but "
                f"{len(self.nameMaskFile)} masks in {self.maskDirPath}"
            )
        self.preprocess = PreProcess()

    def __len__(self):
        return len(self.nameImgFile)

    def __getitem__(self, index):
        imgPath = os.path.join(self.imgDirPath, self.nameImgFile[index])
        maskPath = os.path.join(self.maskDirPath, self.nameMaskFile[index])

        img = cv2.imread(imgPath, cv2.IMREAD_COLOR)
        if img is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError(f"Could not read image: {imgPath}")
        resized_img = cv2.resize(img, (self.img_size, self.img_size))

        # Min-max scaling to [0, 1]. A constant image would divide by zero
        # (NaNs everywhere), so map it to all zeros instead.
        imin, imax = resized_img.min(), resized_img.max()
        if imax > imin:
            resized_img = (resized_img - imin) / (imax - imin)
        else:
            resized_img = np.zeros(resized_img.shape, dtype=np.float64)

        img = self.preprocess(resized_img)

        mask = cv2.imread(maskPath, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {maskPath}")
        # Nearest-neighbour keeps every pixel a valid class id. cv2's default
        # bilinear interpolation blends ids at class borders (e.g. 1 and 3 -> 2),
        # inventing labels that are not in the annotation.
        resized_mask = cv2.resize(mask, (self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)

        mask = self.preprocess(resized_mask)

        # One-hot encode: channel i is 1 where the mask holds class id i.
        output_mask = torch.zeros((self.num_classes, self.img_size, self.img_size), dtype=torch.float)
        for i in range(self.num_classes):
            output_mask[i] = (mask == i).float()

        return img, output_mask


In [None]:
class CustomPadding(AugmentationBase2D):
    """
    Custom augmentation: zero-pad every side of the image (and mask), then
    resize back to the original spatial size, i.e. a "zoom out".

    Output stays on the input's device and keeps its dtype, so the
    augmentation works on CPU as well as GPU.

    Args:
        padding: number of pixels added to each of the four sides.
        p: probability of applying the augmentation to each sample.
    """
    def __init__(self, padding: int, p: float = 1.0):
        super(CustomPadding, self).__init__(p=p)
        self.padding = padding

    def _zero_pad_and_resize(self, x: torch.Tensor, mode: str) -> torch.Tensor:
        """Zero-pad a (B, C, H, W) tensor by ``self.padding`` on all sides and resize it back to (H, W)."""
        b, c, h, w = x.size()
        pad = self.padding

        # Zero canvas on the same device/dtype as the input
        padded = torch.zeros(b, c, h + 2 * pad, w + 2 * pad, device=x.device, dtype=x.dtype)

        # Insert the original in the centre of the padded canvas
        padded[:, :, pad:h + pad, pad:w + pad] = x

        # Resize back to the input size so transformed and untouched samples
        # of a batch keep matching shapes.
        if mode == "nearest":
            # align_corners may not be passed at all for 'nearest'
            return torch.nn.functional.interpolate(padded, size=(h, w), mode="nearest")
        return torch.nn.functional.interpolate(padded, size=(h, w), mode=mode, align_corners=False)

    def apply_transform(self, img: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, Any], transform: Optional[torch.Tensor] = None
    ) -> torch.Tensor:

        return self._zero_pad_and_resize(img, mode="bilinear")

    def apply_non_transform(self, img: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, Any], transform: Optional[torch.Tensor] = None
    ) -> torch.Tensor:

        # Samples not selected for augmentation pass through unchanged
        return img

    def apply_transform_mask(self, mask: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, Any], transform: Optional[torch.Tensor] = None
    ) -> torch.Tensor:

        # Nearest-neighbour keeps mask values discrete (one-hot stays one-hot)
        return self._zero_pad_and_resize(mask, mode="nearest")

    def apply_non_transform_mask(self, mask: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, Any], transform: Optional[torch.Tensor] = None
    ) -> torch.Tensor:

        return mask

In [None]:
# Augmentation
class DataAugmentation(torch.nn.Module):
    '''
    Joint image + mask augmentation pipeline built on Kornia.

    - ``augmentations`` is a list of dicts. Each dict's ``"name"`` is the name
      of an augmentation class in this module's global namespace, e.g. Kornia
      classes from ``from kornia.augmentation import *``, or ``CustomPadding``.
      All other keys are passed to that class as keyword arguments.
    - The config dicts are never modified.
    - With an empty (or ``None``) list the module is a no-op: forward returns
      its inputs unchanged.
    '''

    def __init__(self, augmentations):
        super().__init__()

        self.augmentations = torch.nn.Identity()

        if augmentations:
            self.augmentations = self._createAugmentationObject(augmentations)

    def _createAugmentationObject(self, augs):
        aug_object_list = []
        for aug in augs:
            aug_name = aug['name']
            aug_cls = globals().get(aug_name)
            if aug_cls is None:
                raise KeyError(f"Unknown augmentation {aug_name!r}: no class of that name in the global namespace")
            # Build kwargs from a filtered copy instead of popping 'name', so the
            # caller's config dict is left intact even if construction fails.
            aug_kwargs = {key: value for key, value in aug.items() if key != 'name'}
            aug_object_list.append(aug_cls(**aug_kwargs))
        # data_keys: forward receives (image, mask) and the same sampled
        # transform is applied to both.
        aug_container = kornia.augmentation.container.AugmentationSequential(*aug_object_list, data_keys=['input', 'mask'])
        return aug_container

    @torch.no_grad()  # augmentation is preprocessing; no gradients needed
    def forward(self, img, mask):
        # nn.Identity.forward takes a single tensor, so the no-op case must be
        # handled here rather than passing (img, mask) to it.
        if isinstance(self.augmentations, torch.nn.Identity):
            return img, mask
        img, mask = self.augmentations(img, mask)
        return img, mask



#### Plot

In [None]:
def plot_image_mask_using_dataset_class(dataset, number_of_images = 10):
    '''
    Show augmented samples from a SegmentationDataset-style dataset.

    The first ``number_of_images`` items (capped at the dataset length) are
    passed through the AUGMENTATIONS pipeline. For each one, three panels are
    drawn: the image, the ground-truth class map, and the two overlaid.

    Assumes each item is ``(img, one_hot_mask)``, with img shaped (3, H, W),
    in OpenCV BGR channel order, with values in [0, 1], as SegmentationDataset
    produces.
    '''
    ag = DataAugmentation(augmentations = copy.deepcopy(AUGMENTATIONS))

    for idx in range(min(number_of_images, len(dataset))):
        img, mask = dataset[idx]

        img, mask = ag(img, mask)

        # Kornia returns batched (1, C, H, W) tensors; flatten any leading batch
        # dims and take the first sample so batched and unbatched inputs work.
        img = img.reshape(-1, *img.shape[-3:])[0].cpu()
        mask = mask.reshape(-1, *mask.shape[-3:])[0].cpu()

        rgb = img[[2, 1, 0]].permute(1, 2, 0).clamp(0, 1).numpy()  # BGR -> RGB, CHW -> HWC
        class_map = mask.argmax(dim=0).numpy()  # one-hot -> class id per pixel
        max_class = mask.shape[0] - 1

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(rgb)
        axes[0].set_title("Image", fontsize = 12)
        axes[1].imshow(class_map, cmap="copper", vmin=0, vmax=max_class)
        axes[1].set_title("Ground Truth", fontsize = 12)
        axes[2].imshow(rgb)
        axes[2].imshow(class_map, alpha = 0.5, cmap = 'copper', vmin=0, vmax=max_class)
        axes[2].set_title("Overlapped View", fontsize = 12)
        for ax in axes:
            ax.axis("off")

        plt.show()
        plt.close(fig)  # avoid accumulating open figures when plotting many samples

# Example usage:
# plot_image_mask_using_dataset_class(SegmentationDataset(dirPath=dataset_path, img_size=IMAGE_SIZE), 5)

In [None]:
# Function to plot images and masks
def plot_images_and_masks_dataloader(images, masks, num_images):
    '''
    Plot the first ``num_images`` (image, mask) pairs of a batch side by side.

    Args:
        images: image batch shaped (B, C, H, W); shown as-is (no channel reordering).
        masks: mask batch shaped (B, K, H, W). A one-hot mask (K > 1) is shown
            as its per-pixel class index; a single-channel mask is shown directly.
        num_images: number of rows to draw; capped at the batch size. Nothing
            is drawn when this is 0 or the batch is empty.
    '''
    num_images = min(num_images, len(images))
    if num_images <= 0:
        return

    # squeeze=False keeps `axs` 2-D even for a single row, so axs[i, j] always works
    fig, axs = plt.subplots(num_images, 2, figsize=(10, 5 * num_images), squeeze=False)

    for i in range(num_images):
        # Plot the image: CHW -> HWC for imshow
        axs[i, 0].imshow(images[i].detach().permute(1, 2, 0).cpu().numpy())
        axs[i, 0].set_title(f'Image {i + 1}')
        axs[i, 0].axis('off')

        # Plot the mask: collapse a one-hot (K, H, W) mask into a class-index map
        mask = masks[i].detach()
        if mask.dim() == 3 and mask.shape[0] > 1:
            mask = mask.argmax(dim=0)
        axs[i, 1].imshow(mask.squeeze().cpu().numpy(), cmap='gray')
        axs[i, 1].set_title(f'Mask {i + 1}')
        axs[i, 1].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Debug helper: iterate a DataLoader and report (and optionally plot) its batches.
# NOTE: keep the DataLoader's num_workers at 0 when iterating it from here.
def inspect_dataloader_batches(dataloader, max_batches=None, plot=False, num_plot=2):
    '''
    Print the shapes of the (image, mask) batches yielded by ``dataloader``.

    Args:
        dataloader: any iterable of ``(img, mask)`` batches, e.g. a DataLoader
            over SegmentationDataset.
        max_batches: stop after this many batches (``None`` = all of them).
        plot: if True, also draw the first ``num_plot`` samples of each batch
            with ``plot_images_and_masks_dataloader``.
        num_plot: samples per batch to plot when ``plot`` is True.

    Returns:
        The number of batches inspected.
    '''
    inspected = 0
    for batch_idx, (img, mask) in enumerate(dataloader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        print(f"Batch {batch_idx + 1}:")
        print("Image batch shape:", img.shape)
        print("Mask batch shape:", mask.shape)
        if plot:
            plot_images_and_masks_dataloader(img, mask, min(num_plot, img.size(0)))
        inspected += 1
    return inspected

# Usage (e.g. inside main(), where train_dataloader exists):
# inspect_dataloader_batches(train_dataloader, max_batches=3, plot=True)

### Model

In [None]:
# Model
class DoubleConv(nn.Module):
    """
    Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Stride 1 with padding 1 keeps the spatial size; only the channel count
    changes, from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                # No conv bias: the following BatchNorm subtracts the per-channel
                # mean, which would cancel any constant bias anyway.
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # The attribute name `conv` and the flat layer order define the
        # state_dict keys (conv.0.weight, conv.1.running_mean, ...) that
        # checkpoints rely on.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class UNET(nn.Module):
    '''
    U-Net for semantic segmentation.

    Args:
        in_channels: channels of the input image.
        out_channels: number of output maps. The network returns raw logits;
            there is no activation after the final 1x1 conv.
        features: channel widths of the encoder levels, shallowest first. The
            bottleneck uses ``2 * features[-1]`` channels.

    If H/W are not divisible by ``2 ** len(features)``, each upsampled map is
    resized to its skip connection's size before concatenation.
    '''
    def __init__(self, in_channels = 3, out_channels = 1, features = (64, 128, 256, 512)):
        super(UNET, self).__init__()
        # Attribute creation order fixes the order of model.parameters() (and
        # hence of saved optimizer state): ups, downs, pool, bottleneck, finalconv.
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)

        # Encoder: one DoubleConv per level
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder, deepest level first. Per level: a 2x transposed-conv upsample,
        # then a DoubleConv over [skip, upsampled] (hence feature*2 inputs).
        # `ups` therefore alternates [upsample, conv, upsample, conv, ...].
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size = 2, stride = 2))
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.finalconv = nn.Conv2d(features[0], out_channels, kernel_size = 1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Pair each (upsample, conv) decoder stage with the skip from the
        # matching encoder level, deepest first.
        decoder_stages = list(self.ups)
        for upsample, decode, skip_connection in zip(decoder_stages[0::2], decoder_stages[1::2], reversed(skip_connections)):
            x = upsample(x)

            # Pooling an odd size drops a pixel; resize back to the skip's size
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size = skip_connection.shape[2:])

            x = decode(torch.cat((skip_connection, x), dim=1))

        return self.finalconv(x)

### Dice Score Function

In [None]:
def dice_coefficient(pred, gt_mask, threshold=0.5, eps=1e-8):
    '''
    Dice coefficient between a thresholded prediction and a ground-truth mask.

    All elements (every batch item and channel) are pooled into one score.

    Args:
        pred: prediction tensor, binarised with ``pred > threshold``. Pass
            probabilities (e.g. softmax output), not raw logits.
        gt_mask: ground-truth mask with the same shape, values in {0, 1}.
        threshold: binarisation threshold for ``pred``.
        eps: small constant keeping the division finite.

    Returns:
        Dice score as a Python float. Returns 1.0 when prediction and ground
        truth are both empty, since they agree perfectly.
    '''
    pred = (pred.detach() > threshold).float()
    gt_mask = gt_mask.detach().float()

    denominator = pred.sum() + gt_mask.sum()
    if denominator.item() == 0:
        return 1.0

    intersection = (pred * gt_mask).sum()
    dice = (2. * intersection) / (denominator + eps)
    return dice.item()

### Training and Validation Functions

In [None]:
def train_fn(loader, augmentations_Obj, model, optimizer, loss_fn, scheduler, scaler, step=None):
    '''
    Run one training epoch and log its average loss and Dice score to MLflow.

    Args:
        loader: iterable of (image, mask) batches.
        augmentations_Obj: callable (image, mask) -> (image, mask), applied to
            every batch before it is moved to DEVICE.
        model: network producing per-class logits, shaped (B, classes, H, W).
        optimizer: optimizer updated once per batch.
        loss_fn: loss taking (logits, mask).
        scheduler: stepped once per epoch with the average *training* loss.
            NOTE(review): ReduceLROnPlateau is usually driven by validation loss.
        scaler: AMP GradScaler used for the backward pass / optimizer step.
        step: optional MLflow step (e.g. the epoch index) for the logged metrics.

    Returns:
        (avg_loss, avg_dice) for the epoch.

    Raises:
        ValueError: if ``loader`` yields no batches.
    '''
    model.train()
    loop = tqdm(loader, leave=False)
    loss_total, dice_total, num_batches = 0.0, 0.0, 0
    # Autocast only applies on CUDA; on CPU it would just warn and disable itself.
    use_amp = str(DEVICE).startswith("cuda")

    for image, mask in loop:
        # Apply Augmentations
        image, mask = augmentations_Obj(image, mask)

        image = image.to(device=DEVICE)
        mask = mask.float().to(device=DEVICE)

        # Forward pass
        with torch.cuda.amp.autocast(enabled=use_amp):
            predictions = model(image)
            loss = loss_fn(predictions, mask)

        # Backward pass
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # The model emits logits: turn them into per-class probabilities
        # (softmax over the channel dim) before Dice thresholds them at 0.5.
        probabilities = torch.softmax(predictions.detach().float(), dim=1)
        dice_total += dice_coefficient(probabilities, mask)
        loss_total += loss.item()
        num_batches += 1
        # Running sums keep the progress-bar average O(1) per batch
        loop.set_postfix(loss=loss_total / num_batches, diceScore=dice_total / num_batches)

    if num_batches == 0:
        raise ValueError("train_fn: the training loader yielded no batches")

    avg_loss = loss_total / num_batches
    avg_dice = dice_total / num_batches

    scheduler.step(avg_loss)

    mlflow.log_metric("train_loss", avg_loss, step=step)
    mlflow.log_metric("train_dice_score", avg_dice, step=step)

    print(f"Average training loss: {avg_loss:.5f}")
    print(f"Average training Dice coefficient Score: {avg_dice:.5f}")
    return avg_loss, avg_dice

In [None]:
def val_fn(loader, model, loss_fn, step=None):
    '''
    Evaluate ``model`` on ``loader`` without gradients and log the averages to MLflow.

    Args:
        loader: iterable of (image, mask) batches (no augmentation is applied).
        model: network producing per-class logits, shaped (B, classes, H, W).
        loss_fn: loss taking (logits, mask).
        step: optional MLflow step (e.g. the epoch index) for the logged metrics.

    Returns:
        (avg_val_loss, avg_dice).

    Raises:
        ValueError: if ``loader`` yields no batches.
    '''
    model.eval()
    loss_total, dice_total, num_batches = 0.0, 0.0, 0

    loop = tqdm(loader, leave=False)

    with torch.no_grad():
        for image, mask in loop:
            image = image.to(device=DEVICE)
            mask = mask.float().to(device=DEVICE)

            predictions = model(image)
            loss = loss_fn(predictions, mask)

            # Logits -> per-class probabilities (channel dim) before Dice thresholds them
            dice_total += dice_coefficient(torch.softmax(predictions, dim=1), mask)
            loss_total += loss.item()
            num_batches += 1
            loop.set_postfix(loss=loss_total / num_batches, diceScore=dice_total / num_batches)

    if num_batches == 0:
        raise ValueError("val_fn: the validation loader yielded no batches")

    avg_val_loss = loss_total / num_batches
    avg_dice = dice_total / num_batches

    mlflow.log_metric("val_loss", avg_val_loss, step=step)
    mlflow.log_metric("val_dice_score", avg_dice, step=step)

    print(f"Average validation loss: {avg_val_loss:.5f}")
    print(f"Average validation Dice coefficient Score: {avg_dice:.5f}")
    return avg_val_loss, avg_dice
    

#### Checkpoints

In [None]:
def save_checkpoint(state, filename=f"../Model Checkpoints/car_segmentation_{LOSS_NAME}_checkpoint.pth.tar"):
    '''
    Serialize ``state`` (e.g. model / optimizer state dicts) with torch.save.

    When ``filename`` is a path, its parent directory is created if missing,
    so the first run does not fail on a nonexistent checkpoint folder.
    Note: the default path is built from LOSS_NAME when this cell runs, not
    at call time.
    '''
    print("=> Saving checkpoint")
    if isinstance(filename, (str, os.PathLike)):
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
    torch.save(state, filename)

def load_checkpoint(checkpoint, model, optimizer=None):
    '''
    Restore model weights, and optionally optimizer state, from a checkpoint dict.

    Args:
        checkpoint: dict with a "state_dict" entry and optionally an
            "optimizer" entry (the layout main() saves).
        model: module that receives ``checkpoint["state_dict"]``.
        optimizer: if given and the checkpoint has an "optimizer" entry, that
            state is restored too, e.g. to resume with the saved learning rate
            and moment estimates.
    '''
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    if optimizer is not None and "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])

In [None]:
def main():
    '''
    Train the car-part UNET and log the run to MLflow.

    - Splits the dataset 80/20 into train/val using a fixed-seed generator, so
      the split is identical across runs. Otherwise a run resumed from the
      checkpoint (LOAD_MODEL) would be validated on images it trained on before.
    - Optionally resumes model weights from the checkpoint file.
    - Trains for NUM_EPOCHS, validating and overwriting the checkpoint after
      every epoch.
    '''
    checkpoint_path = f"../Model Checkpoints/car_segmentation_{LOSS_NAME}_checkpoint.pth.tar"
    split_seed = 42

    # Define datasets
    whole_dataset = SegmentationDataset(dirPath=dataset_path, imageDir='images/', masksDir='masks/', img_size=IMAGE_SIZE)
    augmentations_Obj = DataAugmentation(augmentations=copy.deepcopy(AUGMENTATIONS))

    train_size = int(0.8 * len(whole_dataset))
    val_size = len(whole_dataset) - train_size
    train_ds, val_ds = torch.utils.data.random_split(
        whole_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(split_seed),
    )

    print(f"Length of Train dataset = {len(train_ds)}")
    print(f"Length of Val dataset = {len(val_ds)}")

    # Define DataLoaders
    # NOTE(review): NUM_WORKERS / PIN_MEMORY from the config cell are not used here.
    train_dataloader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=False)
    val_dataloader = DataLoader(val_ds, BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)

    print("Created Train and Val dataloaders")

    model = UNET(in_channels=3, out_channels=5).to(DEVICE)
    loss_fn = LOSS_FUNCTION
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,factor=0.1,patience=1,verbose=True)
    # AMP loss scaling only applies on CUDA; disabled, the scaler is a pass-through.
    scaler = torch.cuda.amp.GradScaler(enabled=str(DEVICE).startswith("cuda"))

    if LOAD_MODEL and os.path.exists(checkpoint_path):
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine
        load_checkpoint(torch.load(checkpoint_path, map_location=DEVICE), model)

    # Make sure the checkpoint folder exists before the first save
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    with mlflow.start_run() as run:

        mlflow.log_params({"learning_rate": LEARNING_RATE ,"batch_size": BATCH_SIZE, "number_of_epochs": NUM_EPOCHS, "image_size": IMAGE_SIZE, "loss function": LOSS_FUNCTION, "split_seed": split_seed})

        for epoch in range(NUM_EPOCHS):
            print(f"\nEPOCH [{epoch+1}/{NUM_EPOCHS}]")
            train_fn(train_dataloader, augmentations_Obj, model, optimizer, loss_fn, scheduler, scaler)
            val_fn(val_dataloader, model, loss_fn)

            # save model (overwrites the previous epoch's checkpoint)
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer":optimizer.state_dict(),
            }

            save_checkpoint(checkpoint, filename=checkpoint_path)
    


In [None]:

# Entry point: a notebook kernel runs cells as "__main__", so executing this
# cell starts training; importing this code as a module would not.
if __name__ == "__main__":
    main()