In this experiment I will run a multi-class segmentation model in a simulated VR environment.

The Cholec8K dataset is used for the experiment.



In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra



import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os



# Data collection

In [None]:
# Walk the Kaggle input tree and record the full path of every file found.
paths = []
for dirname, _, filenames in os.walk('/kaggle/input'):
    paths.extend(os.path.join(dirname, name) for name in filenames)

# One row per file; the path strings end up in the default column 0.
files = pd.DataFrame(paths)



In [None]:
# Sort every collected path into a mask column and an image column.
# A row gets None in a column when its path does not contain the marker.
all_paths = files[0]
files['mask_path'] = all_paths.map(lambda path: path if 'endo_mask' in path else None)
files['img_path'] = all_paths.map(lambda path: path if 'endo.png' in path else None)

# Identifiers are read from fixed positions of the '/'-separated path:
# component 5 becomes 'video', and the second '_'-separated token of
# component 6 becomes 'frame'.
files['video'] = all_paths.map(lambda path: path.split('/')[5])
files['frame'] = all_paths.map(lambda path: path.split('/')[6].split('_')[1])
files

In [None]:
# Collapse the per-file rows into a single row per (video, frame).
# Each path column is expected to hold one real value per group and None in
# the other rows, so 'first' (first non-null entry) selects that value.
# 'sum' produced the same result only through string concatenation and would
# glue several paths into one invalid path if a group ever had two matches.
data = files.groupby(['video', 'frame'], as_index=False).agg({'mask_path': 'first', 'img_path': 'first'})

In [None]:
data

In [None]:
import os, cv2
import numpy as np
import pandas as pd
import random, tqdm
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as album

In [None]:
# I am using segmentation-models-pytorch for experiments
!pip install -q -U segmentation-models-pytorch albumentations
import segmentation_models_pytorch as smp

In [None]:
# Train / validation split
#
# The split is made per video rather than per frame: a video name appears in
# exactly one of the two lists below.
train_videos = [
    'video18', 'video09', 'video35', 'video20',
    'video01', 'video17', 'video52', 'video43',
    'video55', 'video28', 'video48', 'video27',
]
valid_videos = ['video25', 'video12', 'video37', 'video24', 'video26']

# The class ids used in the masks are not consecutive, so each id is mapped
# to its position in this list (0..15).
ids = [0, 5, 11, 12, 13, 21, 22, 23, 24, 25, 31, 32, 33, 35, 36, 50]
replace = dict(zip(ids, range(len(ids))))

In [None]:
def mp(entry):
    """Translate an old class id to its new id.

    Ids that have no entry in ``replace`` are returned unchanged.
    """
    return replace.get(entry, entry)

In [None]:
# Wrap the scalar lookup in np.vectorize so that it can be called on a whole
# array and is applied to every element.
# NOTE(review): this rebinds the name `mp`, so re-running the cell wraps the
# already-vectorized callable a second time.
mp = np.vectorize(mp)


In [None]:
from tensorflow.keras.utils import to_categorical


class EndoscopyDataset(torch.utils.data.Dataset):
    """Dataset of endoscopy frames paired with one-hot segmentation masks.

    Args:
        df: table with ``img_path`` and ``mask_path`` columns, one row per
            sample.
        n_classes: number of channels of the one-hot mask.
        augmentation: optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.
        preprocessing: optional callable with the same calling convention,
            applied after ``augmentation``.
        distortion: when True, the same optical distortion (``k=4``) is
            applied to the image and to the mask before augmentation.
    """

    def __init__(
            self,
            df,
            n_classes=16,
            augmentation=None,
            preprocessing=None,
            distortion=False,
    ):
        self.image_paths = df['img_path'].tolist()
        self.mask_paths = df['mask_path'].tolist()
        self.n_classes = n_classes

        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.distortion = distortion

    @staticmethod
    def _read_rgb(path):
        """Load ``path`` with OpenCV and return the pixels in RGB order.

        Raises:
            OSError: if OpenCV could not load the file.
        """
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread reports failure by returning None instead of raising,
            # which would otherwise surface as an opaque cvtColor error.
            raise OSError(f"cv2.imread could not load {path!r}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair for sample ``i``."""
        image = self._read_rgb(self.image_paths[i])
        mask = self._read_rgb(self.mask_paths[i])
        # the mask has 3 channels, but the value is the same in all channels
        # so using only one channel is fine
        mask = mp(mask[:, :, 0])
        # using to_categorical to transfer a single mask to multiple masks
        # (float32 while distorting; presumably because the warp does not
        # accept int32 input - TODO confirm)
        mask = to_categorical(
            mask,
            num_classes=self.n_classes,
            dtype="float32" if self.distortion else 'int32',
        )
        if self.distortion:
            image = album.augmentations.functional.optical_distortion(
                image, k=4, dx=0, dy=0,
                interpolation=cv2.INTER_LINEAR,
                border_mode=cv2.BORDER_CONSTANT,
                value=None,
            )
            # Labels must not be blended: with bilinear interpolation the
            # one-hot values become fractional and the int32 cast below
            # truncates them to 0, leaving pixels without any class.
            # Nearest-neighbour keeps every value at exactly 0 or 1.
            mask = album.augmentations.functional.optical_distortion(
                mask, k=4, dx=0, dy=0,
                interpolation=cv2.INTER_NEAREST,
                border_mode=cv2.BORDER_CONSTANT,
                value=None,
            )
            mask = mask.astype('int32')
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

In [None]:
def get_training_augmentation():
    """Build the training pipeline: horizontal flip with p=0.5, then resize to 256x256."""
    return album.Compose([
        album.HorizontalFlip(p=0.5),
        album.Resize(256, 256),  # it is important to resize images in this dataset
    ])


def get_validation_augmentation():
    """Build the validation pipeline: only the resize to 256x256, nothing else."""
    return album.Compose([album.Resize(256, 256)])


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn=None):
    """Compose the preprocessing pipeline.

    If ``preprocessing_fn`` is given it is applied to the image first; the
    final step always runs ``to_tensor`` on both the image and the mask.
    """
    steps = [album.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    steps.append(album.Lambda(image=to_tensor, mask=to_tensor))
    return album.Compose(steps)

In [None]:
# Baseline model configuration
ENCODER, ENCODER_WEIGHTS = 'resnet50', 'imagenet'
CLASSES = 16
# This is multi-class segmentation, so softmax is the preferred activation.
ACTIVATION = 'softmax2d'

model_kwargs = dict(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=CLASSES,
    activation=ACTIVATION,
)
model = smp.DeepLabV3Plus(**model_kwargs)

# Input preprocessing function for the chosen encoder / weights pair.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# 'vid' is the part of 'video' that precedes the first underscore.
data['vid'] = data['video'].map(lambda clip: clip.partition('_')[0])

In [None]:
# Select the rows of each split by video name, then wrap them in datasets.
train_rows = data.loc[data['vid'].isin(train_videos)]
valid_rows = data.loc[data['vid'].isin(valid_videos)]

train_dataset = EndoscopyDataset(
    train_rows,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)
valid_dataset = EndoscopyDataset(
    valid_rows,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)


In [None]:
# Both loaders share batch size and worker count; only training is shuffled.
loader_kwargs = dict(batch_size=16, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Training setup for the baseline run
TRAINING = True
EPOCHS = 10

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Loss function
loss = smp.utils.losses.DiceLoss()

# Metrics, all evaluated with a 0.5 threshold
metrics = [
    metric_cls(threshold=0.5)
    for metric_cls in (
        smp.utils.metrics.IoU,
        smp.utils.metrics.Fscore,
        smp.utils.metrics.Accuracy,
        smp.utils.metrics.Recall,
        smp.utils.metrics.Precision,
    )
]

# Optimizer: AdamW over all model parameters in a single parameter group
optimizer = torch.optim.AdamW([{'params': model.parameters(), 'lr': 0.0001}])



In [None]:
# The training and validation runners share everything except the optimizer.
runner_kwargs = dict(loss=loss, metrics=metrics, device=DEVICE, verbose=True)
train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **runner_kwargs)
valid_epoch = smp.utils.train.ValidEpoch(model, **runner_kwargs)

In [None]:
if TRAINING:

    best_iou_score = 0.0
    train_logs_list, valid_logs_list = [], []

    for i in range(EPOCHS):

        # One pass over the training data, then one over the validation data
        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)
        train_logs_list.append(train_logs)
        valid_logs_list.append(valid_logs)

        # Keep a checkpoint only when the validation IoU improves
        current_iou = valid_logs['iou_score']
        if current_iou > best_iou_score:
            best_iou_score = current_iou
            torch.save(model, './best_model.pth')
            print('Model saved!')

# Model with distortion


In [None]:
# Model configuration for the distortion run; a new model object is built
# with the same settings as the baseline one.
ENCODER, ENCODER_WEIGHTS = 'resnet50', 'imagenet'
CLASSES = 16
ACTIVATION = 'softmax2d'

model_kwargs = dict(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=CLASSES,
    activation=ACTIVATION,
)
model = smp.DeepLabV3Plus(**model_kwargs)

# Input preprocessing function for the chosen encoder / weights pair.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# Same dataset construction as in the baseline part, except that
# distortion=True is passed to both datasets.
train_rows = data.loc[data['vid'].isin(train_videos)]
valid_rows = data.loc[data['vid'].isin(valid_videos)]

train_dataset = EndoscopyDataset(
    train_rows,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    distortion=True,
)
valid_dataset = EndoscopyDataset(
    valid_rows,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    distortion=True,
)



In [None]:
# Both loaders share batch size and worker count; only training is shuffled.
loader_kwargs = dict(batch_size=16, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Training setup for the distortion run
TRAINING = True
EPOCHS = 10
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): the original comment said this DiceLoss is multiclass by
# default - confirm against the segmentation-models-pytorch documentation.
loss = smp.utils.losses.DiceLoss()

# Metrics, all evaluated with a 0.5 threshold
metrics = [
    metric_cls(threshold=0.5)
    for metric_cls in (
        smp.utils.metrics.IoU,
        smp.utils.metrics.Fscore,
        smp.utils.metrics.Accuracy,
        smp.utils.metrics.Recall,
        smp.utils.metrics.Precision,
    )
]

# NOTE(review): the baseline run configures AdamW with lr=0.0001, while this
# run uses Adam with lr=0.0005, so distortion is not the only difference
# between the two runs - confirm this is intended.
optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': 0.0005}])




In [None]:
# The training and validation runners share everything except the optimizer.
runner_kwargs = dict(loss=loss, metrics=metrics, device=DEVICE, verbose=True)
train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **runner_kwargs)
valid_epoch = smp.utils.train.ValidEpoch(model, **runner_kwargs)

In [None]:
if TRAINING:

    best_iou_score = 0.0
    train_logs_list, valid_logs_list = [], []

    for i in range(EPOCHS):

        # One pass over the training data, then one over the validation data
        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)
        train_logs_list.append(train_logs)
        valid_logs_list.append(valid_logs)

        # Keep a checkpoint only when the validation IoU improves.
        # NOTE(review): the baseline training loop saves to this same path,
        # so its checkpoint gets overwritten here - confirm this is intended.
        current_iou = valid_logs['iou_score']
        if current_iou > best_iou_score:
            best_iou_score = current_iou
            torch.save(model, './best_model.pth')
            print('Model saved!')