In [None]:
!pip install git+https://github.com/qubvel/segmentation_models.pytorch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/qubvel/segmentation_models.pytorch
  Cloning https://github.com/qubvel/segmentation_models.pytorch to /tmp/pip-req-build-ebl2ej38
  Running command git clone -q https://github.com/qubvel/segmentation_models.pytorch /tmp/pip-req-build-ebl2ej38
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone


In [None]:
!pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@release/1.5.x --upgrade

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
import os
import torch
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import segmentation_models_pytorch as smp

from pprint import pprint
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import random
from PIL import Image
import numpy as np

## Dataset

In [None]:
#!unzip /content/drive/MyDrive/AI_Summer_Project/data_semantics.zip -d /content/drive/MyDrive/AI_Summer_Project/

In [None]:
# Grayscale value (after converting a semantic_rgb mask to PIL mode 'L') of the single class that
# `KITTI.encode_segmap` treats as foreground; every other pixel value becomes background (0).
# NOTE(review): 120 presumably corresponds to the Cityscapes/KITTI "sidewalk" colour (244, 35, 232)
# after RGB->L conversion — confirm against the label palette.
DEFAULT_VALID_LABEL: int = 120

In [None]:
# Sweep parameters (defaults).
# `data_path` can be overridden with the KITTI_DATA_PATH environment variable so the notebook
# runs outside Colab / Google Drive without editing this cell.
hyperparameter_defaults = {
    'data_path': os.environ.get('KITTI_DATA_PATH', '/content/drive/MyDrive/AI_Summer_Project'),
    'batch_size': 2,
}

In [None]:
class KITTI(Dataset):
    '''
    Binary semantic-segmentation dataset built on the KITTI Semantic Segmentation Benchmark.
    Dataset link - http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015

    Images are read from `<data_path>/training/image_2` and colour masks from
    `<data_path>/training/semantic_rgb`. Images and masks are paired by their position in the
    sorted directory listings, so both folders must contain the same number of files with
    matching names.

    Each colour mask is converted to grayscale (PIL mode 'L'). `encode_segmap` then turns it into
    a binary mask: pixels equal to `valid_label` become `valid_index` (1), and every other pixel
    becomes `ignore_index` (0).

    The file list is split deterministically (fixed seed). The hold-out part contains
    `n_items // 5` items and the training part contains the rest. `split == 'train'` selects the
    training part; any other value selects the hold-out part.

    `__getitem__` resizes the image and mask to `img_size` (PIL order: (width, height)). It applies
    `transform` (if any) to the image only and returns `(image, mask)`, with the mask shaped
    (1, height, width).
    '''
    IMAGE_PATH = os.path.join('training', 'image_2')
    MASK_PATH = os.path.join('training', 'semantic_rgb')

    def __init__(
        self,
        data_path,
        split,
        img_size=(1248, 384),
        valid_label=DEFAULT_VALID_LABEL,
        transform=None,
    ):
        self.img_size = img_size
        self.valid_label = valid_label
        self.ignore_index = 0
        self.valid_index = 1
        self.transform = transform

        self.split = split
        self.data_path = data_path
        self.img_path = os.path.join(self.data_path, self.IMAGE_PATH)
        self.mask_path = os.path.join(self.data_path, self.MASK_PATH)
        self.img_list = self.get_filenames(self.img_path)
        self.mask_list = self.get_filenames(self.mask_path)
        # Images and masks are paired by index; differing counts would silently mispair them.
        if len(self.img_list) != len(self.mask_list):
            raise ValueError(
                f"Found {len(self.img_list)} images in {self.img_path} but "
                f"{len(self.mask_list)} masks in {self.mask_path}"
            )

        # Split between train and valid set
        random_inst = random.Random(12345)  # fixed seed -> identical split on every run
        n_items = len(self.img_list)
        holdout_idxs = random_inst.sample(range(n_items), n_items // 5)
        if self.split == 'train':
            holdout_set = set(holdout_idxs)  # O(1) membership instead of scanning a list
            idxs = [idx for idx in range(n_items) if idx not in holdout_set]
        else:
            idxs = holdout_idxs
        self.img_list = [self.img_list[i] for i in idxs]
        self.mask_list = [self.mask_list[i] for i in idxs]

    def __len__(self):
        '''Number of (image, mask) pairs in this split.'''
        return len(self.img_list)

    def __getitem__(self, idx):
        '''Load, resize and encode the `idx`-th pair; returns `(image, mask)`.'''
        # Context managers close the underlying file handles once the pixels are copied out.
        with Image.open(self.img_list[idx]) as img_file:
            img = np.array(img_file.resize(self.img_size))

        with Image.open(self.mask_list[idx]) as mask_file:
            # NEAREST keeps label values intact: the default filter for mode 'L' interpolates,
            # which drops boundary pixels and can fabricate `valid_label` where colours mix.
            mask_img = mask_file.convert('L').resize(self.img_size, resample=Image.NEAREST)
            mask = np.array(mask_img)
        mask = self.encode_segmap(mask)
        # (height, width) -> (1, height, width)
        mask = mask[np.newaxis, ...]

        if self.transform:
            img = self.transform(img)

        return img, mask

    def encode_segmap(self, mask):
        '''
        Return a binary copy of `mask`: pixels equal to `valid_label` become `valid_index`,
        all others become `ignore_index`. The result keeps the input's dtype.
        '''
        # A single np.where stays correct even if valid_label happens to equal ignore_index.
        return np.where(mask == self.valid_label, self.valid_index, self.ignore_index).astype(mask.dtype)

    def get_filenames(self, path):
        '''
        Return the sorted paths of all non-hidden entries directly inside `path`
        (hidden entries such as `.DS_Store` would otherwise break image/mask pairing).
        '''
        return [
            os.path.join(path, filename)
            for filename in sorted(os.listdir(path))
            if not filename.startswith('.')
        ]


In [None]:
class KittiDataModule(pl.LightningDataModule):
    '''
    Kitti Data Module: builds the train/valid `KITTI` datasets and their DataLoaders.

    `hparams` must provide `data_path` and `batch_size`. Optional keys:
      * `normalize` (default True): append a Normalize transform using `MEAN`/`STD` below
        (described as this dataset's statistics). `PetModel.forward` normalizes its input again
        with the encoder's statistics, so set this to False to feed the model plain [0, 1] tensors.
      * `num_workers` (default 0): number of DataLoader worker processes.

    `setup` is called from `__init__`, so the datasets and dataloaders can be used without a Trainer.
    '''
    # Per-channel statistics used by the optional Normalize transform.
    MEAN = [0.35675976, 0.37380189, 0.3764753]
    STD = [0.32064945, 0.32098866, 0.32325324]

    def __init__(self, hparams):
        super().__init__()
        self.data_path = hparams['data_path']
        self.batch_size = hparams['batch_size']
        self.num_workers = hparams.get('num_workers', 0)

        steps = [transforms.ToTensor()]  # HWC uint8 array -> CHW float tensor in [0, 1]
        if hparams.get('normalize', True):
            steps.append(transforms.Normalize(mean=self.MEAN, std=self.STD))
        self.transform = transforms.Compose(steps)
        self.setup()

    def setup(self, stage=None):
        '''Build both splits; `stage` is ignored because the split is cheap and deterministic.'''
        self.trainset = KITTI(self.data_path, split='train', transform=self.transform)
        self.validset = KITTI(self.data_path, split='valid', transform=self.transform)

    def train_dataloader(self):
        '''Shuffled loader over the training split.'''
        return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        '''Loader over the validation split, in a fixed order.'''
        return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers)

In [None]:
# Build the data module (datasets are created immediately by its constructor).
kittiData = KittiDataModule(hparams=hyperparameter_defaults)

In [None]:


train_dataloader = kittiData.train_dataloader()
valid_dataloader = kittiData.val_dataloader()

train_dataset = kittiData.trainset
valid_dataset = kittiData.validset

# Sanity check: the deterministic split must not leak images between train and valid.
assert set(train_dataset.img_list).isdisjoint(valid_dataset.img_list), "train/valid splits overlap"

print(len(train_dataset))
print(len(valid_dataset))

In [None]:
# Let's look at some samples: one from the training split, one from the validation split.

def show_sample(sample):
    """Plot an `(image, mask)` pair side by side.

    The image is a (C, H, W) tensor. It is min-max rescaled for display only, because the dataset
    transform may normalize it outside [0, 1] and imshow clips RGB floats outside that range.
    The mask's leading singleton channel axis is squeezed away.
    """
    image, mask = sample
    image_hwc = image.permute(1, 2, 0)  # CHW -> HWC for matplotlib
    lo, hi = image_hwc.min(), image_hwc.max()
    image_hwc = (image_hwc - lo) / (hi - lo).clamp(min=1e-8)

    plt.subplot(1, 2, 1)
    plt.imshow(image_hwc)
    plt.subplot(1, 2, 2)
    plt.imshow(mask.squeeze())
    plt.show()


sample = train_dataset[1]
show_sample(sample)

print(sample[0].shape)
print(sample[0].permute(1, 2, 0).shape)
print(sample[1].shape)
print(sample[1].squeeze().shape)

sample = valid_dataset[0]
show_sample(sample)

In [None]:
# # sample = train_dataset[0]
# # plt.subplot(1,2,1)
# # plt.imshow(sample[0].permute(1, 2, 0)) # for visualization we have to transpose back to HWC
# # plt.subplot(1,2,2)
# # plt.imshow(sample[1].squeeze())  # for visualization we have to remove 3rd dimension of mask
# # plt.show()

# k = 10

# mask = train_dataset[k][0]
# plt.imshow(mask.permute(1, 2, 0))  # for visualization we have to remove 3rd dimension of mask
# plt.show()

# mask = train_dataset[k][1]
# plt.imshow(mask.squeeze())  # for visualization we have to remove 3rd dimension of mask
# plt.show()


# x = mask.shape[1]
# y = mask.shape[2]

# #lst = np.unique(mask)
# lst = np.array([120])

# for k in range(len(train_dataset)):
#   mask = train_dataset[k][0]
#   plt.imshow(mask.permute(1, 2, 0))  # for visualization we have to remove 3rd dimension of mask
#   plt.show()

#   mask = train_dataset[k][1]
#   plt.imshow(mask.squeeze())  # for visualization we have to remove 3rd dimension of mask
#   plt.show()

#   for i in lst:
#     print(i)
#     for xi in range(x):
#       for yi in range(y):
#         if mask[0][xi][yi]!=i:
#           mask[0][xi][yi]=0
#     plt.imshow(mask.squeeze())  # for visualization we have to remove 3rd dimension of mask
#     plt.show()

In [None]:
import albumentations as albu

In [None]:
def get_training_augmentation():
    """Build the training-time albumentations pipeline.

    It applies a random horizontal flip and shift/scale, then pads to at least 320 x 320 and takes
    a 320 x 320 random crop. After that come Gaussian noise and a perspective warp, followed by
    one-of blocks of colour, sharpness/blur and contrast/hue perturbations.

    The imgaug-based transforms (`IAAAdditiveGaussianNoise`, `IAAPerspective`, `IAASharpen`) and
    `RandomBrightness` / `RandomContrast` are deprecated and absent from current albumentations
    releases. They are replaced by their native equivalents. `always_apply=True` is dropped
    because `PadIfNeeded` and `RandomCrop` already default to `p=1.0`.
    """
    train_transform = [
        albu.HorizontalFlip(p=0.5),

        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),

        albu.PadIfNeeded(min_height=320, min_width=320, border_mode=0),
        albu.RandomCrop(height=320, width=320),

        albu.GaussNoise(p=0.2),
        albu.Perspective(p=0.5),

        albu.OneOf(
            [
                albu.CLAHE(p=1),
                # brightness-only jitter (replacement for RandomBrightness)
                albu.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=1),
                albu.RandomGamma(p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                albu.Sharpen(p=1),
                albu.Blur(blur_limit=3, p=1),
                albu.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                # contrast-only jitter (replacement for RandomContrast)
                albu.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=1),
                albu.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),
    ]
    return albu.Compose(train_transform)


def get_validation_augmentation():
    """Pad images (and masks) so they are at least 384 x 480 (height x width).

    An input no larger than that in either dimension (e.g. 360 x 480) comes out exactly
    384 x 480, which is divisible by 32. Larger dimensions are left unchanged, so this does
    not enforce divisibility in general.
    """
    return albu.Compose([albu.PadIfNeeded(min_height=384, min_width=480)])


def to_tensor(x, **kwargs):
    """Convert an HWC image (or HW mask) array to a float32 channel-first array.

    Meant as an albumentations `Lambda` target function, hence the ignored ``**kwargs``.
    A 2-D input gets a leading channel axis, (H, W) -> (1, H, W), instead of failing in
    ``transpose``; a 3-D input is transposed (H, W, C) -> (C, H, W).
    """
    if x.ndim == 2:
        return x[np.newaxis, ...].astype('float32')
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline applied after augmentation.

    Args:
        preprocessing_fn (callable): data normalization function for the image
            (can be specific to each pretrained neural network).
    Returns:
        albumentations.Compose: applies ``preprocessing_fn`` to the image, then converts
        both image and mask to float32 channel-first arrays with ``to_tensor``.
    """
    normalize_step = albu.Lambda(image=preprocessing_fn)
    channels_first_step = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize_step, channels_first_step])

## Model

In [None]:
class PetModel(pl.LightningModule):
    '''
    LightningModule wrapping a `segmentation_models_pytorch` model for binary segmentation
    (used in this notebook on the binary KITTI masks, despite the class name).

    Args:
        arch: architecture name passed to `smp.create_model` (e.g. "FPN").
        encoder_name: encoder name (e.g. "resnet34"); also selects the mean/std used in `forward`.
        in_channels: number of input image channels.
        out_classes: number of output mask channels (1 for binary segmentation).
        learning_rate: Adam learning rate used by `configure_optimizers` (default 1e-4).
        **kwargs: extra keyword arguments forwarded to `smp.create_model`.
    '''

    def __init__(self, arch, encoder_name, in_channels, out_classes, learning_rate=1e-4, **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        )
        self.learning_rate = learning_rate

        # Preprocessing parameters of the encoder, registered as buffers so they follow the module
        # across devices; shaped (1, 3, 1, 1) to broadcast over (batch, channel, height, width).
        # NOTE(review): the view hard-codes 3 channels, so in_channels != 3 would fail here.
        params = smp.encoders.get_preprocessing_params(encoder_name)
        self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
        self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))

        # Dice loss computed on raw logits (from_logits=True).
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

    def forward(self, image):
        '''Normalize `image` with the encoder mean/std and return the raw mask logits.'''
        # NOTE(review): the KITTI data module's transform already applies its own Normalize, so
        # images reaching this point are normalized twice — confirm this is intended.
        image = (image - self.mean) / self.std
        return self.model(image)

    def shared_step(self, batch, stage):
        '''
        Process one batch for `stage` ("train", "valid" or "test"). Computes the Dice loss, logs it
        as `<stage>_loss`, and returns it along with per-image tp/fp/fn/tn pixel counts that
        `shared_epoch_end` turns into IoU scores.
        '''
        image = batch[0]

        # Image must be (batch_size, num_channels, height, width); for grayscale images expand
        # the channel dim to get (batch_size, 1, height, width).
        assert image.ndim == 4

        # Height and width must be divisible by 32. The encoder downsamples by a factor of 2 up to
        # 5 times (2^5 = 32), and the decoder concatenates skip connections at each scale. For
        # example, an 84x84 input gives encoder features 42, 21, 10, 5. Upsampling from 5 gives
        # 10, 20, 40, 80, and the 20 vs 21 mismatch makes the concatenation fail.
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch[1]

        # Mask must be (batch_size, num_classes, height, width); num_classes = 1 for binary.
        assert mask.ndim == 4

        # Binary targets must lie in [0, 1], NOT [0, 255].
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)

        # Predicted mask contains logits, and loss_fn was built with from_logits=True.
        loss = self.loss_fn(logits_mask, mask)
        self.log(f"{stage}_loss", loss, prog_bar=True)

        # Convert logits to probabilities, then threshold at 0.5 to get hard predictions.
        pred_mask = (logits_mask.sigmoid() > 0.5).float()

        # Per-image, per-class pixel counts; aggregated at the end of the epoch.
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        '''
        Concatenate the tp/fp/fn/tn counts returned by `shared_step` for every batch and log
        `<stage>_per_image_iou` and `<stage>_dataset_iou`. Does nothing if `outputs` is empty.
        '''
        if not outputs:
            # No batches ran for this stage; torch.cat([]) would raise.
            return

        # aggregate step metrics
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # Per-image IoU: compute the IoU of each image first, then average over images.
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")

        # Dataset IoU: aggregate intersection and union over the whole dataset, then compute IoU.
        # "Empty" images (no target pixels) weigh heavily on per-image IoU but little on dataset
        # IoU, so the two can diverge noticeably on such data.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
        }

        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")

    def training_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")

    def test_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        '''Adam over all parameters with the configured learning rate.'''
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
# FPN decoder on a ResNet-34 encoder: RGB input, one output channel (binary mask logits).
model = PetModel(arch="FPN", encoder_name="resnet34", in_channels=3, out_classes=1)

## Training

In [None]:
from pytorch_lightning.loggers import TensorBoardLogger

# Log next to the configured data. The run name is taken from the wrapped smp model's class
# (e.g. "FPN"), so it always matches the architecture actually being trained.
logger = TensorBoardLogger(
    save_dir=hyperparameter_defaults['data_path'],
    name=type(model.model).__name__,
)

In [None]:
trainer = pl.Trainer(
    # Use the GPU when present; fall back to CPU instead of failing on CPU-only runtimes.
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=5,
    logger=logger,
)

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

## Validation metrics

In [None]:
# Evaluate the trained model on the validation split and show the metrics it logs.
valid_metrics = trainer.validate(
    model,
    dataloaders=valid_dataloader,
    verbose=False,
)
pprint(valid_metrics)

## Result visualization

In [None]:
# Plot image, ground truth and predicted foreground probability for every validation sample.
model.eval()  # once, outside the loop: inference-mode behaviour for dropout / batch norm
device = next(model.parameters()).device  # run inference wherever the model currently lives

for batch in valid_dataloader:
    images, gt_masks = batch[0], batch[1]
    with torch.no_grad():
        logits = model(images.to(device))
    pr_masks = logits.sigmoid().cpu()

    for image, gt_mask, pr_mask in zip(images, gt_masks, pr_masks):
        # The dataset transform may normalize images outside [0, 1]; rescale for display only,
        # otherwise imshow clips the RGB values.
        image_hwc = image.numpy().transpose(1, 2, 0)  # convert CHW -> HWC
        value_range = max(image_hwc.max() - image_hwc.min(), 1e-8)
        image_hwc = (image_hwc - image_hwc.min()) / value_range

        plt.figure(figsize=(10, 5))

        plt.subplot(1, 3, 1)
        plt.imshow(image_hwc)
        plt.title("Image")
        plt.axis("off")

        plt.subplot(1, 3, 2)
        plt.imshow(gt_mask.numpy().squeeze())  # squeeze the single class dim
        plt.title("Ground truth")
        plt.axis("off")

        plt.subplot(1, 3, 3)
        plt.imshow(pr_mask.numpy().squeeze())  # squeeze the single class dim
        plt.title("Prediction")
        plt.axis("off")

        plt.show()