# Import Libraries and Download Data

In [None]:
# Install the third-party packages this notebook relies on.
!pip install imutils
!pip install segmentation_models_pytorch
# NOTE(review): captum is installed here but never imported in the visible cells.
!pip install captum
!pip install albumentations
!pip install gdown 
import gdown 
# Google Drive link of the dataset archive, and the local file name to save it under.
url = 'https://drive.google.com/uc?id=1LfVKeX5eY2pPrbQxxIjcLljls5uoexwB' 
output = 'data.zip'
gdown.download(url, output)
# Extract the archive into the current directory.
# Presumably this creates the train/, val/ and test/ folders read later - TODO confirm.
!unzip data.zip

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os
from collections import defaultdict, OrderedDict
import shutil
import time
import copy
import math
import random
from imutils import paths

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from numpy import unravel_index

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

from torchvision import transforms
from torchvision import datasets

from PIL import *
import albumentations as A
import skimage

import segmentation_models_pytorch as smp

# Report whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())

# Load Data

In [None]:
def visualize(**images):
    """Show the given images side by side in a single row.

    Each keyword argument becomes one panel. The keyword is used as the panel
    title (underscores replaced by spaces, then title-cased) and the value is
    the image to draw.

    Values whose first dimension has size 2 or 3 are treated as channel-first
    tensors and permuted to channel-last before plotting. Every other value is
    passed to ``imshow`` after singleton dimensions are squeezed away.

    Calling the function without any image is a no-op.
    """
    n_images = len(images)
    if n_images == 0:
        # plt.subplots cannot build a grid with zero columns.
        return

    # squeeze=False keeps the axes container two-dimensional even when there
    # is a single panel; by default a lone Axes object is returned, which
    # cannot be indexed.
    _, axarr = plt.subplots(1, n_images, figsize=(4 * n_images, 4), squeeze=False)
    for axis, (name, image) in zip(axarr[0], images.items()):
        if image.shape[0] in (2, 3):
            axis.imshow(np.squeeze(image.permute(1, 2, 0)))
        else:
            axis.imshow(np.squeeze(image))
        axis.set_title(name.replace('_', ' ').title(), fontsize=20)
    plt.show()
    
class EndoscopyDataset(Dataset):
    """Dataset of (image, mask) pairs read from two parallel lists of file paths.

    Args:
        images: sequence of image file paths.
        masks: sequence of mask file paths; ``masks[i]`` is used together with
            ``images[i]``.
        augmentations: optional callable invoked as
            ``augmentations(image=..., mask=...)`` on channel-last numpy arrays
            and returning a mapping with ``'image'`` and ``'mask'`` entries.

    Each item is a two-element list ``[img, mask]`` of float tensors in
    channel-first layout: the image has 3 channels, the mask has 1, and both
    are resized to 400x400 with nearest-neighbour interpolation.
    """

    def __init__(self, images, masks, augmentations=None):
        self.input_images = images
        self.target_masks = masks
        self.augmentations = augmentations

    def __len__(self):
        """Return the number of image paths."""
        return len(self.input_images)

    @staticmethod
    def _load_resized(path, grayscale=False):
        """Read ``path`` and return it as a 400x400 channel-last numpy array.

        The file is decoded as RGB, resized, optionally reduced to a single
        grayscale channel, and scaled by ``ToTensor``.
        """
        # Imported locally so the class does not depend on the file-level
        # `from PIL import *` happening to expose the `Image` submodule.
        from PIL import Image

        steps = [transforms.Resize((400, 400), interpolation=transforms.InterpolationMode.NEAREST)]
        if grayscale:
            steps.append(transforms.Grayscale())
        steps.append(transforms.ToTensor())

        # The context manager closes the file deterministically instead of
        # leaving that to Pillow internals or garbage collection.
        with Image.open(path) as handle:
            tensor = transforms.Compose(steps)(handle.convert('RGB'))
        # The augmentation callable is given channel-last numpy arrays.
        return tensor.permute((1, 2, 0)).numpy()

    def __getitem__(self, idx):
        """Load, resize and (optionally) augment the pair at position ``idx``."""
        img = self._load_resized(self.input_images[idx])
        mask = self._load_resized(self.target_masks[idx], grayscale=True)

        if self.augmentations:
            augmented = self.augmentations(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        # Back to channel-first float tensors.
        img = torch.tensor(img, dtype=torch.float).permute((2, 0, 1))
        mask = torch.tensor(mask, dtype=torch.float).permute((2, 0, 1))

        return [img, mask]
    
# Mini-batch size for each split, and the worker count handed to every DataLoader.
train_batch_size, val_batch_size, test_batch_size = 8, 4, 4
num_workers = 2

# Directory that contains the train/, val/ and test/ folders.
main_dir = './'

# Sorted image paths per split, collected with the contains="jpg" filter.
train_images, val_images, test_images = (
    sorted(paths.list_files(main_dir + folder, contains="jpg"))
    for folder in ('train/images/', 'val/images/', 'test/images/')
)

# Sorted mask paths per split. The dataset pairs images and masks by position,
# so both listings are sorted the same way.
train_masks, val_masks, test_masks = (
    sorted(paths.list_files(main_dir + folder, contains="jpg"))
    for folder in ('train/masks/', 'val/masks/', 'test/masks/')
)

# Training-time augmentation pipeline, applied jointly to image and mask.
# The transforms are passed as a list so they run in the order written here;
# a set literal has no defined iteration order.
augmentations = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=(-90, 90)),
    A.VerticalFlip(p=0.5),
    A.Transpose(p=0.5),
    A.GaussianBlur(p=0.5),
    A.augmentations.geometric.transforms.Affine(scale=(0.9, 1.1), translate_percent=0.1),
])

# One dataset per split; only the training split is augmented.
dataset = dict(
    train=EndoscopyDataset(train_images, train_masks, augmentations),
    val=EndoscopyDataset(val_images, val_masks, None),
    test=EndoscopyDataset(test_images, test_masks, None),
)

# Matching loaders. The train and val loaders shuffle; the test loader keeps
# the dataset order.
dataloader = dict(
    train=DataLoader(
        dataset['train'],
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=num_workers,
    ),
    val=DataLoader(
        dataset['val'],
        batch_size=val_batch_size,
        shuffle=True,
        num_workers=num_workers,
    ),
    test=DataLoader(
        dataset['test'],
        batch_size=test_batch_size,
        shuffle=False,
        num_workers=num_workers,
    ),
)

# Draw one random training sample (augmentations included) as a sanity check:
# print tensor shapes and value ranges, then show the image, its mask, and the
# mask outline drawn on top of the image.
image, mask = dataset['train'][random.randint(0, len(dataset['train'])-1)]
print(image.shape, image.min(), image.max())
print(mask.shape, mask.min(), mask.max())
visualize(
    original_image = image,
    ground_truth_mask = mask,
    polyp = skimage.segmentation.mark_boundaries(
        image.permute(1, 2, 0).detach().cpu().numpy(),
        # Threshold before casting: the mask holds floats in [0, 1], and a plain
        # integer cast would truncate every value below 1.0 to background.
        (mask.detach().cpu().numpy()[0] > 0.5).astype(np.int64),
        color=(0, 0, 1),
        mode='outer',
    ),
)

# Train Model

In [None]:
class UNet(nn.Module):
    """U-Net with four down-sampling stages, a bottleneck and four up-sampling stages.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output map.
        init_features: width of the first encoder stage; each deeper stage
            doubles it.

    The forward pass returns the sigmoid of a 1x1 convolution applied to the
    last decoder stage.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super().__init__()

        width = init_features

        # Contracting path: conv block followed by 2x2 max-pooling.
        self.encoder1 = UNet._block(in_channels, width, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(width, width * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(width * 2, width * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(width * 4, width * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(width * 8, width * 16, name="bottleneck")

        # Expanding path: transposed conv doubles the resolution, then a conv
        # block consumes the up-sampled map concatenated with the skip
        # connection (hence twice the channel count on its input).
        self.upconv4 = nn.ConvTranspose2d(width * 16, width * 8, kernel_size=2, stride=2)
        self.decoder4 = UNet._block(width * 16, width * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(width * 8, width * 4, kernel_size=2, stride=2)
        self.decoder3 = UNet._block(width * 8, width * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(width * 4, width * 2, kernel_size=2, stride=2)
        self.decoder2 = UNet._block(width * 4, width * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2)
        self.decoder1 = UNet._block(width * 2, width, name="dec1")

        # 1x1 convolution mapping the last feature map to the output channels.
        self.conv = nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return sigmoid-activated predictions."""
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        skip4 = self.encoder4(self.pool3(skip3))

        out = self.bottleneck(self.pool4(skip4))

        # Walk back up, deepest stage first, merging each skip connection.
        stages = (
            (self.upconv4, self.decoder4, skip4),
            (self.upconv3, self.decoder3, skip3),
            (self.upconv2, self.decoder2, skip2),
            (self.upconv1, self.decoder1, skip1),
        )
        for upsample, decode, skip in stages:
            out = decode(torch.cat((upsample(out), skip), dim=1))

        return torch.sigmoid(self.conv(out))

    @staticmethod
    def _block(in_channels, features, name):
        """Build two (3x3 conv -> batch norm -> ReLU) stages as a named Sequential.

        Layer names are prefixed with ``name``; the first convolution maps
        ``in_channels`` to ``features`` and the second keeps ``features``.
        """
        layers = OrderedDict()
        stage_specs = (
            ("conv1", "norm1", "relu1", in_channels),
            ("conv2", "norm2", "relu2", features),
        )
        for conv_key, norm_key, relu_key, fan_in in stage_specs:
            layers[name + conv_key] = nn.Conv2d(
                in_channels=fan_in,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            )
            layers[name + norm_key] = nn.BatchNorm2d(num_features=features)
            layers[name + relu_key] = nn.ReLU(inplace=True)
        return nn.Sequential(layers)

class Loss(smp.utils.base.Loss):
    """Combined segmentation loss: (1 - IoU) + (1 - F1) + binary cross-entropy.

    Args:
        eps: smoothing constant forwarded to the IoU and F-score functions.
        activation: name of the activation applied to the predictions before
            the loss is computed (``None`` leaves them unchanged).
        threshold: optional cut-off applied to the predictions inside the IoU
            and F-score terms. The default ``None`` keeps the predictions
            soft, so both terms stay differentiable. Binarising predictions
            gives those two terms zero gradient, leaving only the
            cross-entropy term to drive training. Pass ``0.3`` to reproduce
            the previous hard-thresholded value.
        **kwargs: forwarded to the base class.
    """

    def __init__(self, eps=1.0, activation=None, threshold=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.threshold = threshold
        self.activation = smp.base.modules.Activation(activation)
        self._name = 'loss'

    def forward(self, y_pr, y_gt):
        """Return the summed loss for predictions ``y_pr`` against targets ``y_gt``."""
        y_pr = self.activation(y_pr)
        iou_term = 1 - smp.utils.functional.iou(
            y_pr, y_gt, eps=self.eps, threshold=self.threshold
        )
        f_score_term = 1 - smp.utils.functional.f_score(
            y_pr, y_gt, beta=1, eps=self.eps, threshold=self.threshold
        )
        # binary_cross_entropy expects probabilities, i.e. already-activated output.
        bce_term = nn.functional.binary_cross_entropy(y_pr, y_gt)
        return iou_term + f_score_term + bce_term

In [None]:
# Set to False to skip the training loop below.
training = True
epochs = 500
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained U-Net fetched through torch.hub; the UNet class defined earlier
# in this file is not instantiated here.
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', in_channels=3, out_channels=1, init_features=32, pretrained=True)

loss = Loss()

# Metrics reported each epoch; predictions are binarised at 0.3 before scoring.
metrics = [
    smp.utils.metrics.IoU(threshold=0.3),
    smp.utils.metrics.Fscore(threshold=0.3),
]

optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.0005),
])

# NOTE(review): this rebinds `lr_scheduler`, shadowing the module imported from
# torch.optim at the top of the file. No `.step()` call on it appears in the
# visible cells, so the learning rate looks like it never changes - confirm.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=1, eta_min=0.00001,
)

# Runner for one pass over the training data.
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=device,
    verbose=True,
)

# Runner for one pass over the validation data (no optimizer).
valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=device,
    verbose=True,
)

if training:

    # Validation scores of the best checkpoint so far, and the per-epoch logs.
    best_model = dict.fromkeys(('loss', 'iou_score', 'fscore'), 0.0)
    train_logs_list = []
    valid_logs_list = []

    for i in range(epochs):
        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(dataloader['train'])
        valid_logs = valid_epoch.run(dataloader['val'])
        train_logs_list.append(train_logs)
        valid_logs_list.append(valid_logs)

        # Save the model only when the validation IoU strictly improves.
        if valid_logs['iou_score'] > best_model['iou_score']:
            torch.save(model, main_dir + 'best_model.pth')
            best_model.update(
                {key: valid_logs[key] for key in ('loss', 'iou_score', 'fscore')}
            )
            print('Model saved!')

In [None]:
# Validation loss, IoU and F-score recorded when the last checkpoint was saved.
print(best_model)

In [None]:
# Turn the per-epoch log lists into tables for plotting.
train_logs_df, valid_logs_df = (
    pd.DataFrame(history) for history in (train_logs_list, valid_logs_list)
)

In [None]:
# Training vs. validation loss per epoch; the figure is also saved to loss_plot.png.
plt.figure(figsize=(10, 10))
for split_label, split_logs in (('Train', train_logs_df), ('Valid', valid_logs_df)):
    plt.plot(
        split_logs.index.tolist(),
        split_logs['loss'].tolist(),
        lw=1,
        label=split_label,
    )
plt.title('Loss Plot', fontsize=20)
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('Loss', fontsize=20)
plt.grid()
plt.legend(loc='best', fontsize=16)
plt.savefig('loss_plot.png')
plt.show()

In [None]:
# Training vs. validation IoU per epoch; the figure is also saved to iou_score_plot.png.
plt.figure(figsize=(10, 10))
for split_label, split_logs in (('Train', train_logs_df), ('Valid', valid_logs_df)):
    plt.plot(
        split_logs.index.tolist(),
        split_logs['iou_score'].tolist(),
        lw=1,
        label=split_label,
    )
plt.title('IoU Score Plot', fontsize=20)
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('IoU Score', fontsize=20)
plt.grid()
plt.legend(loc='best', fontsize=16)
plt.savefig('iou_score_plot.png')
plt.show()

In [None]:
# Training vs. validation F1 per epoch; the figure is also saved to fscore_plot.png.
plt.figure(figsize=(10, 10))
for split_label, split_logs in (('Train', train_logs_df), ('Valid', valid_logs_df)):
    plt.plot(
        split_logs.index.tolist(),
        split_logs['fscore'].tolist(),
        lw=1,
        label=split_label,
    )
plt.title('F1 Score Plot', fontsize=20)
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('F1 Score', fontsize=20)
plt.grid()
plt.legend(loc='best', fontsize=16)
plt.savefig('fscore_plot.png')
plt.show()

# Test Model

In [None]:
# Load the saved model object for evaluation. map_location moves the stored
# tensors onto the active device, so a checkpoint written on a GPU machine
# also loads on a CPU-only one.
# NOTE: torch.load unpickles an arbitrary Python object here - only load
# checkpoint files from a trusted source.
# NOTE(review): the training cell saves to main_dir + 'best_model.pth', which is
# a different location from this path - confirm which file is intended.
best_model = torch.load('./HyperKvasir/best_model.pth', map_location=device)

In [None]:
%matplotlib inline

# Evaluate the loaded model on the test split: collect per-image IoU and F1
# and display every prediction next to its ground truth.
best_model.eval()

IOUs = []
F1s = []

with torch.no_grad():
    for inputs, labels in dataloader['test']:
        inputs = inputs.to(device)
        labels = labels.to(device)

        pred_mask = best_model(inputs)

        # Score and display each sample of the batch individually.
        for test_image, test_mask, predMask in zip(inputs, labels, pred_mask):
            iou = smp.utils.functional.iou(predMask, test_mask, threshold=0.3)
            # Store plain Python floats so the lists can be averaged directly.
            IOUs.append(iou.item())

            f1 = smp.utils.functional.f_score(predMask, test_mask, threshold=0.3)
            F1s.append(f1.item())

            # Binarise the prediction for display.
            predMask = torch.where(predMask >= 0.3, 1, 0)

            visualize(
                original_image = test_image.cpu(),
                ground_truth_mask = test_mask.cpu(),
                predicted_mask = predMask.cpu(),
                # Outline of the predicted mask drawn on the current test image.
                polyp = skimage.segmentation.mark_boundaries(
                    test_image.permute(1, 2, 0).cpu().numpy(),
                    predMask[0].cpu().numpy().astype(np.int64),
                    color=(0, 0, 1),
                    mode='outer',
                ),
            )

In [None]:
# Mean of the per-image scores collected over the test split.
for label, scores in (('Test IOU: ', IOUs), ('Test F1: ', F1s)):
    print(label + str(np.mean(scores)))