In [None]:
import os
import torch
import numpy as np
import kornia
from kornia.augmentation import *
import cv2
import multiprocessing
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

from tqdm import tqdm
import copy
from typing import Any, Dict, List, Optional, Union


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_accelerator = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_accelerator)
# Last expression is displayed by Jupyter: True when CUDA is usable.
device.type == "cuda"

In [None]:
# Dataset root; the SegmentationDataset call below reads its `images` and `masks` sub-folders.
DATASET_ROOT = "../Data/car_dataset"
dataset_path = DATASET_ROOT

In [None]:
# for dir_name, _ , filenames in os.walk(dataset_path):
#     for filename in filenames:
#         print(os.path.join(dir_name, filename))

In [None]:
# len(sorted(os.listdir(dataset_path + "/images")))

#### Pre-Processing Images

In [None]:
class PreProcess(torch.nn.Module):
    '''
    Convert an image stored as a numpy array into a float32 torch tensor.

    An ``HxWxC`` array becomes a ``CxHxW`` tensor via ``kornia.image_to_tensor``
    with ``keepdim=True`` (no batch dimension is added). A 2-D ``HxW`` array
    (e.g. a label mask) presumably comes back as ``1xHxW`` -- see the kornia
    ``image_to_tensor`` docs.
    '''

    def __init__(self):
        super().__init__()

    @torch.no_grad()  # pure data conversion; no autograd graph is needed
    def forward(self, x: np.ndarray) -> torch.Tensor:
        """
        Args:
            x: image data; anything ``np.asarray`` accepts, laid out ``HxWxC``.

        Returns:
            The image as a ``CxHxW`` tensor cast to float32.
        """
        temp: np.ndarray = np.asarray(x)  # HxWxC
        out: torch.Tensor = kornia.image_to_tensor(temp, keepdim=True)  # CxHxW

        return out.float()

#### Dataset Class

In [None]:
# Dataset Class
class SegmentationDataset(Dataset):
    """
    Image/mask segmentation dataset read from two sub-folders of ``dirPath``.

    Images and masks are paired by position after sorting each folder's file
    names, so the i-th image file must correspond to the i-th mask file.

    Each item is ``(img, output_mask)``:
      * ``img``: float32 tensor ``(3, img_size, img_size)``, min-max scaled to
        [0, 1]; channel order is BGR, as returned by ``cv2.imread``.
      * ``output_mask``: float32 one-hot tensor ``(num_classes, img_size, img_size)``;
        channel ``i`` is 1 where the mask pixel equals ``i``.

    NOTE(review): assumes mask files are single-channel label images whose
    values are in ``0..num_classes-1`` -- confirm against the data.

    Args:
        dirPath: dataset root folder.
        imageDir: name of the image sub-folder.
        masksDir: name of the mask sub-folder.
        img_size: side length images and masks are resized to.
        num_classes: number of one-hot channels (defaults to 5, the previously
            hard-coded value).
    """

    def __init__(self, dirPath= r'../data', imageDir='images', masksDir='masks', img_size=512, num_classes=5):
        self.imgDirPath = os.path.join(dirPath, imageDir)
        self.maskDirPath = os.path.join(dirPath, masksDir)
        self.img_size = img_size
        self.num_classes = num_classes
        self.nameImgFile = sorted(os.listdir(self.imgDirPath))
        self.nameMaskFile = sorted(os.listdir(self.maskDirPath))
        self.preprocess = PreProcess()

    def __len__(self):
        return len(self.nameImgFile)

    def __getitem__(self, index):
        imgPath = os.path.join(self.imgDirPath, self.nameImgFile[index])
        maskPath = os.path.join(self.maskDirPath, self.nameMaskFile[index])
        size = (self.img_size, self.img_size)

        img = cv2.imread(imgPath, cv2.IMREAD_COLOR)
        if img is None:  # cv2 signals unreadable/missing files by returning None
            raise FileNotFoundError(f"Could not read image: {imgPath}")
        resized_img = cv2.resize(img, size).astype(np.float64)

        # Min-max scaling. A constant image has zero range; map it to zeros
        # instead of dividing by zero (which produced NaNs).
        imin, imax = resized_img.min(), resized_img.max()
        value_range = imax - imin
        if value_range > 0:
            resized_img = (resized_img - imin) / value_range
        else:
            resized_img = np.zeros_like(resized_img)

        img = self.preprocess(resized_img)

        mask = cv2.imread(maskPath, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {maskPath}")
        # Nearest-neighbour keeps label values intact; cv2's default bilinear
        # interpolation invents in-between labels along class borders
        # (e.g. halfway between 0 and 4 becomes 2), corrupting the one-hot encoding.
        resized_mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)

        mask = self.preprocess(resized_mask)

        # One-hot encode: compare the label map against every class id at once,
        # giving a (num_classes, img_size, img_size) tensor.
        class_ids = torch.arange(self.num_classes, dtype=mask.dtype).view(-1, 1, 1)
        output_mask = (mask == class_ids).float()

        return img, output_mask


#### Data Augmentation

In [None]:
class CustomPadding(AugmentationBase2D):
    """
    Custom augmentation that zero-pads all four sides of an image/mask and then
    resizes the padded tensor to ``output_size``.

    Args:
        padding: number of zero pixels added on each side (top, bottom, left, right).
        p: probability of applying the augmentation (forwarded to ``AugmentationBase2D``).
        output_size: ``(height, width)`` of the result. Defaults to ``(512, 512)``,
            the size that used to be hard-coded.
        device: device results are moved to. ``None`` (default) moves them to CUDA
            when it is available, as before, and otherwise keeps the input's device
            instead of failing on CPU-only machines.

    Images are resized bilinearly. Masks use nearest-neighbour so one-hot values
    stay 0/1 instead of being blended along class borders.

    NOTE(review): the untransformed path returns the input at its original size,
    so with ``p < 1`` transformed and untransformed samples only have matching
    sizes when the input is already ``output_size`` -- confirm kornia can batch them.
    """

    def __init__(self, padding: int, p: float = 1.0, output_size=(512, 512),
                 device: Optional[Union[str, torch.device]] = None):
        super(CustomPadding, self).__init__(p=p)
        self.padding = padding
        self.output_size = tuple(output_size)
        self.target_device = device

    def _resolve_device(self, tensor: torch.Tensor) -> torch.device:
        """Return the device results go to (see ``device`` in the class docstring)."""
        if self.target_device is not None:
            return torch.device(self.target_device)
        return torch.device("cuda") if torch.cuda.is_available() else tensor.device

    def _pad_and_resize(self, tensor: torch.Tensor, mode: str) -> torch.Tensor:
        """Zero-pad a (B, C, H, W) tensor by ``self.padding`` per side, then resize."""
        p = self.padding
        # F.pad takes (left, right, top, bottom) for the last two dims. Zero fill and
        # the float32 cast match the torch.zeros canvas used previously, but the pad
        # now happens on the input's own device instead of a fresh CPU buffer.
        padded = torch.nn.functional.pad(tensor.float(), (p, p, p, p), mode="constant", value=0.0)
        if mode == "nearest":
            # align_corners is only accepted by the linear-family modes.
            resized = torch.nn.functional.interpolate(padded, size=self.output_size, mode="nearest")
        else:
            resized = torch.nn.functional.interpolate(
                padded, size=self.output_size, mode=mode, align_corners=False
            )
        return resized.to(self._resolve_device(tensor))

    def apply_transform(self, img: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, Any], transform: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return self._pad_and_resize(img, mode="bilinear")

    def apply_non_transform(self, img: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, Any], transform: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return img.to(self._resolve_device(img))

    def apply_transform_mask(self, mask: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, Any], transform: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return self._pad_and_resize(mask, mode="nearest")

    def apply_non_transform_mask(self, mask: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, Any], transform: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return mask.to(self._resolve_device(mask))

In [None]:
class DataAugmentation(torch.nn.Module):
    '''
    Build and apply a Kornia augmentation pipeline to image/mask pairs.

    - ``augmentations`` is a list of dicts. Each dict's ``'name'`` key names an
      augmentation class. The class is looked up first in this notebook's globals
      (e.g. the kornia star import or ``CustomPadding``) and then in
      ``kornia.augmentation``. The remaining keys are passed to the class's
      constructor as keyword arguments. The dicts are not modified.
    - With an empty (or ``None``) list, the module is a no-op and returns
      ``(img, mask)`` unchanged.
    '''

    def __init__(self, augmentations):
        super().__init__()

        self.augmentations = torch.nn.Identity()

        if augmentations:
            self.augmentations = self._createAugmentationObject(augmentations)

    @staticmethod
    def _resolve_augmentation(name):
        """Return the augmentation class called ``name``; raise KeyError if unknown."""
        aug_cls = globals().get(name)
        if aug_cls is None:
            aug_cls = getattr(kornia.augmentation, name, None)
        if aug_cls is None:
            raise KeyError(f"Unknown augmentation name: {name!r}")
        return aug_cls

    def _createAugmentationObject(self, augs):
        aug_object_list = []
        for aug in augs:
            aug_name = aug['name']
            # Copy the kwargs rather than popping 'name', so the caller's dict is
            # never mutated (previously 'name' was lost if a constructor raised).
            kwargs = {key: value for key, value in aug.items() if key != 'name'}
            aug_object_list.append(self._resolve_augmentation(aug_name)(**kwargs))
        aug_container = kornia.augmentation.container.AugmentationSequential(*aug_object_list, data_keys=['input', 'mask'])
        return aug_container

    @torch.no_grad()  # augmentation is preprocessing; disable gradients for efficiency
    def forward(self, img, mask):
        if isinstance(self.augmentations, torch.nn.Identity):
            # nn.Identity.forward accepts a single tensor, so calling it with
            # (img, mask) raised a TypeError; pass both through unchanged instead.
            return img, mask
        img, mask = self.augmentations(img, mask)
        return img, mask

In [None]:
# --- Notebook configuration ---------------------------------------------------
NUM_CORES = multiprocessing.cpu_count()  # CPU count reported by multiprocessing
IMG_SIZE = 256  # square side length images/masks are resized to

# Augmentation pipeline consumed by DataAugmentation: "name" selects the
# augmentation class, every other key is a constructor keyword argument.
AUGMENTATIONS = [
    {"name": "RandomAffine", "degrees": 360, "p": 0.6},
    # Disabled options, kept for reference:
    # {"name": "RandomCrop", "size": (256, 256), "p": 1.0},
    #     -> not working: the original image must be larger than the crop size.
    # {"name": "CustomPadding", "padding": 100, "p": 1.0},
]

#### Plot Dataset

In [None]:
N_PLOT_SAMPLES = 2  # how many dataset items to visualise
# Titles for the one-hot mask channels; channel 0 is the background class.
MASK_TITLES = ("Background", "Class 1", "Class 2", "Class 3", "Class 4")

ds = SegmentationDataset(dirPath=dataset_path, imageDir='images', masksDir='masks', img_size=IMG_SIZE)
ag = DataAugmentation(augmentations=copy.deepcopy(AUGMENTATIONS))


def plot_sample(img: torch.Tensor, mask: torch.Tensor, titles=MASK_TITLES) -> None:
    """
    Show one sample: the image, every mask channel, and an overlay.

    Args:
        img: (3, H, W) CPU image tensor in BGR channel order (as loaded by cv2).
        mask: (K, H, W) CPU one-hot mask tensor.
        titles: one title per mask channel; missing entries fall back to "Channel k".
    """
    # cv2 loads BGR; reverse the channel axis for an RGB display, and clip any
    # interpolation overshoot so imshow accepts the float image.
    rgb = np.clip(img.permute(1, 2, 0).numpy()[..., ::-1], 0.0, 1.0)
    channels = mask.numpy()
    n_channels = channels.shape[0]

    fig, axes = plt.subplots(1, n_channels + 2, figsize=(3 * (n_channels + 2), 3))
    axes[0].imshow(rgb)
    axes[0].set_title("Original image", fontsize=12)

    for k in range(n_channels):
        title = titles[k] if k < len(titles) else f"Channel {k}"
        axes[k + 1].imshow(channels[k], cmap="copper", vmin=0, vmax=1)
        axes[k + 1].set_title(title, fontsize=12)

    # Overlay every non-background channel. Masking zero pixels keeps the image
    # visible where a class is absent instead of darkening the whole frame.
    axes[-1].imshow(rgb)
    for k in range(1, n_channels):
        axes[-1].imshow(np.ma.masked_equal(channels[k], 0), alpha=0.5, cmap="copper", vmin=0, vmax=1)
    axes[-1].set_title("Overlapped View", fontsize=12)

    for ax in axes:
        ax.axis("off")
    plt.show()
    plt.close(fig)  # free the figure when plotting many samples in a loop


for idx in range(N_PLOT_SAMPLES):
    img, mask = ds[idx]
    img, mask = ag(img, mask.unsqueeze(0))  # mask gets an explicit batch dimension

    img = img.cpu().squeeze(0)  # drop the batch dim if present -> (3, H, W)
    mask = mask.cpu()[0]        # first (only) batch item -> (K, H, W)

    print(torch.unique(mask))  # sanity check: one-hot masks should hold only 0 and 1
    plot_sample(img, mask)
