In [1]:
!pip install imutils
!pip install segmentation_models_pytorch
!pip install captum
!pip install albumentations
!pip install gdown 
import gdown 
url = 'https://drive.google.com/uc?id=1LfVKeX5eY2pPrbQxxIjcLljls5uoexwB' 
output = 'data.zip'
gdown.download(url, output)
url = 'https://drive.google.com/uc?id=1Y93BKeKrvlwrieQc9aPkrvzwaUjWwE2Z' 
output = 'best_model.pth'
gdown.download(url, output)
!unzip data.zip

In [2]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os
from collections import defaultdict, OrderedDict
import shutil
import time
import copy
import math
import random
from imutils import paths

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from numpy import unravel_index

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

from torchvision import transforms
from torchvision import datasets

from PIL import *
import albumentations as A
import skimage

import segmentation_models_pytorch as smp

print(torch.cuda.is_available())

In [3]:
# Deterministic split of the 1000 numbered samples by filename:
# 0001-0800 -> train, 0801-0900 -> val, 0901-1000 -> test.
# The split is sequential, not shuffled; substitute a seeded permutation of `a` for a random split.
a = np.array(range(1, 1001))
train, val, test = a[:800], a[800:900], a[900:]

data_root = './HyperKvasir'
for split_name, split_ids in (('train', train), ('val', val), ('test', test)):
    for kind in ('images', 'masks'):
        # exist_ok lets the cell be re-run without FileExistsError.
        os.makedirs(os.path.join(data_root, split_name, kind), exist_ok=True)
    for i in split_ids:
        file_name = '{:04d}.jpg'.format(i)
        # Every image has a mask with the same filename; copy both into the split folder.
        for kind in ('images', 'masks'):
            shutil.copy(os.path.join(data_root, kind, file_name),
                        os.path.join(data_root, split_name, kind, file_name))

In [4]:
def visualize(**images):
    """Plot images side by side, one titled panel per keyword argument.

    Keyword names become panel titles ('ground_truth_mask' -> 'Ground Truth Mask').
    Values may be torch tensors (any device, with or without grad) or numpy arrays.
    Inputs whose first dimension is 2 or 3 are treated as channels-first and moved to
    channels-last; singleton dimensions are squeezed before ``imshow``.
    Calling it with no images does nothing.
    """
    n_images = len(images)
    if n_images == 0:
        return
    # squeeze=False always returns a 2-D array of Axes, so one image indexes the same way as many
    # (plain subplots(1, 1) returns a bare Axes that cannot be indexed).
    f, axarr = plt.subplots(1, n_images, figsize=(4 * n_images, 4), squeeze=False)
    for ax, (name, image) in zip(axarr[0], images.items()):
        if isinstance(image, torch.Tensor):
            # detach/cpu so CUDA or grad-tracking tensors can be handed to matplotlib.
            image = image.detach().cpu().numpy()
        image = np.asarray(image)
        if image.shape[0] == 3 or image.shape[0] == 2:
            image = np.moveaxis(image, 0, -1)
        ax.imshow(np.squeeze(image))
        ax.set_title(name.replace('_', ' ').title(), fontsize=20)
    plt.show()
    
class EndoscopyDataset(Dataset):
    """Paired endoscopy image / segmentation-mask dataset.

    Args:
        images: list of image file paths.
        masks: list of mask file paths, index-aligned with ``images``.
        augmentations: optional albumentations-style callable, invoked as
            ``augmentations(image=..., mask=...)`` on channels-last numpy arrays and
            expected to return a dict with ``'image'`` and ``'mask'`` entries.
        image_size: (height, width) that both image and mask are resized to
            (nearest-neighbour). Defaults to (400, 400).

    Raises:
        ValueError: if ``images`` and ``masks`` have different lengths.

    Each item is ``[img, mask]``: float tensors of shape (3, H, W) and (1, H, W).
    Both start in [0, 1] from ``ToTensor``; augmentations may alter the values.
    """

    def __init__(self, images, masks, augmentations=None, image_size=(400, 400)):
        if len(images) != len(masks):
            raise ValueError('got {} images but {} masks'.format(len(images), len(masks)))
        self.input_images = images
        self.target_masks = masks
        self.augmentations = augmentations
        # Pipelines are built once here rather than on every __getitem__ call.
        self.image_transform = transforms.Compose([
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])
        # Masks are collapsed to a single grayscale channel.
        self.mask_transform = transforms.Compose([
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        # The context managers close the underlying files once the pixels are converted.
        with Image.open(self.input_images[idx]) as raw_img:
            img = raw_img.convert('RGB')
        with Image.open(self.target_masks[idx]) as raw_mask:
            mask = raw_mask.convert('RGB')

        # (C, H, W) tensors -> (H, W, C) numpy arrays, the layout the augmentations receive.
        img = self.image_transform(img).permute(1, 2, 0).numpy()
        mask = self.mask_transform(mask).permute(1, 2, 0).numpy()

        if self.augmentations:
            augmented = self.augmentations(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        # ascontiguousarray guards against negative-stride views (e.g. from flips),
        # which torch may refuse to wrap.
        img = torch.tensor(np.ascontiguousarray(img), dtype=torch.float).permute(2, 0, 1)
        mask = torch.tensor(np.ascontiguousarray(mask), dtype=torch.float).permute(2, 0, 1)

        return [img, mask]
    
train_batch_size = 8
val_batch_size = 4
test_batch_size = 4
num_workers = 2

# Root of the train/val/test split; each split has images/ and masks/ subfolders.
main_dir = './HyperKvasir/'


def list_split_files(split, kind):
    """Return the sorted .jpg paths in ``main_dir/<split>/<kind>/`` (kind: 'images' or 'masks')."""
    return sorted(paths.list_files(main_dir + split + '/' + kind + '/', contains="jpg"))


train_images = list_split_files('train', 'images')
val_images = list_split_files('val', 'images')
test_images = list_split_files('test', 'images')

train_masks = list_split_files('train', 'masks')
val_masks = list_split_files('val', 'masks')
test_masks = list_split_files('test', 'masks')

# Images and masks are paired purely by sorted position, so their filenames must match one-to-one.
for split_images, split_masks in ((train_images, train_masks),
                                  (val_images, val_masks),
                                  (test_images, test_masks)):
    assert [os.path.basename(p) for p in split_images] == [os.path.basename(p) for p in split_masks], \
        'image/mask filenames do not line up'

# Training-time augmentations, applied jointly to image and mask by EndoscopyDataset.
# A list (not a set literal): Compose runs transforms in iteration order, and a set of transform
# objects iterates in hash order, which can change from run to run.
augmentations = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=(-90, 90)),
    A.VerticalFlip(p=0.5),
    A.Transpose(p=0.5),
    A.GaussianBlur(p=0.5),
    A.Affine(scale=(0.9, 1.1), translate_percent=0.1),
])

# One EndoscopyDataset per split; only the training split gets augmentations.
dataset = {
    split: EndoscopyDataset(split_images, split_masks, split_augs)
    for split, split_images, split_masks, split_augs in (
        ('train', train_images, train_masks, augmentations),
        ('val', val_images, val_masks, None),
        ('test', test_images, test_masks, None),
    )
}

# Train and val loaders shuffle; the test loader keeps file order so per-sample results
# can be matched back to dataset['test'] by position.
dataloader = {
    split: DataLoader(dataset[split], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    for split, batch_size, shuffle in (
        ('train', train_batch_size, True),
        ('val', val_batch_size, True),
        ('test', test_batch_size, False),
    )
}

# Sanity-check one random (augmented) training sample: shapes, value ranges, and the mask outline.
image, mask = dataset['train'][random.randint(0, len(dataset['train']) - 1)]
print(image.shape, image.min(), image.max())
print(mask.shape, mask.min(), mask.max())
# Masks are stored as .jpg, so values near the polyp edge are generally not exactly 0 or 1.
# Threshold at 0.5 instead of truncating with astype, which would drop every pixel below 1.0.
mask_labels = (mask[0] >= 0.5).detach().cpu().numpy().astype(np.int64)
visualize(
    original_image=image,
    ground_truth_mask=mask,
    polyp=skimage.segmentation.mark_boundaries(image.permute(1, 2, 0).detach().cpu().numpy(), mask_labels, color=(0, 0, 1), mode='outer'),
)

In [62]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The hub result itself is discarded, but the call is kept: presumably best_model.pth is a whole
# pickled mateuszbuda UNet (later cells use best_model.encoder1), and loading the hub repo makes
# its UNet class importable for unpickling. pretrained=False skips downloading weights that are
# thrown away anyway.
torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', in_channels=3, out_channels=1, init_features=32, pretrained=False)

# weights_only=False: the file holds a full pickled model rather than a state_dict, and recent
# torch versions default to weights_only=True. This unpickles arbitrary objects -- only load
# checkpoints you trust.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine; .to(device) puts the
# weights where the evaluation loop sends its inputs.
best_model = torch.load('./best_model.pth', map_location=device, weights_only=False).to(device)

In [90]:
%matplotlib inline

best_model.eval()

IOUs = []
F1s = []
predictions = []

with torch.no_grad():
    for i, (inputs, labels) in enumerate(dataloader['test']):
        inputs = inputs.to(device)
        labels = labels.to(device)

        pred_mask = best_model(inputs)

        for i in range(len(inputs)):
            test_image = inputs[i]
            test_mask = labels[i]
            predMask = pred_mask[i]
            
            iou = smp.utils.functional.iou(predMask, test_mask, threshold=0.3)
            IOUs.append(iou.cpu().detach())

            f1 = smp.utils.functional.f_score(predMask, test_mask, threshold=0.3)
            F1s.append(f1.cpu().detach())
            
            predMask = torch.where(predMask >= 0.3, 1, 0)
            predictions.append(predMask)

#             visualize(
#                 original_image = test_image.cpu(),
#                 ground_truth_mask = test_mask.cpu(),
#                 predicted_mask = predMask.cpu(),
#                 polyp = skimage.segmentation.mark_boundaries(test_image.permute(1, 2, 0).detach().cpu().numpy(), predMask.detach().cpu().numpy()[0].astype(np.int64), color=(0, 0, 1), mode='outer')
#             )

In [91]:
# Mean of the per-image scores collected over the test set.
for metric_name, scores in (('IOU', IOUs), ('F1', F1s)):
    print('Test ' + metric_name + ':', np.mean(scores))

In [114]:
def segmentation_wrapper(inp, threshold=0.3, model=None):
    """Reduce the segmentation output to per-class scores that Captum can attribute.

    The single-channel prediction ``p`` is expanded to two classes (background ``1 - p``,
    polyp ``p``). Each class score is then summed over the pixels predicted as that class
    (``p >= threshold`` -> polyp). The result stays connected to the autograd graph, so
    gradient-based methods such as LayerGradCam receive real gradients.

    Args:
        inp: image tensor, (3, H, W) or batched (N, 3, H, W). This is the input that is
            evaluated -- Captum perturbs or differentiates through it.
        threshold: cut-off that decides which pixels count as polyp. Defaults to 0.3,
            matching the evaluation cell.
        model: network to run; defaults to the notebook-level ``best_model``. Assumed to
            map (N, 3, H, W) to probabilities of shape (N, 1, H, W) -- TODO confirm.

    Returns:
        Tensor of shape (N, 2). Column 0 is the background score and column 1 the polyp
        score, so Captum's ``target=1`` selects the polyp class.
    """
    if model is None:
        model = best_model
    if inp.dim() == 3:
        # Add the batch dimension the network expects.
        inp = inp.unsqueeze(0)
    # Follow the model's own device so the input and the weights always match.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        inp = inp.to(first_param.device)
    prob = model(inp)                                # (N, 1, H, W)
    two_class = torch.cat((1 - prob, prob), dim=1)   # (N, 2, H, W): background, polyp
    pred_mask = (prob >= threshold).long()           # (N, 1, H, W): predicted class per pixel
    # One-hot of the predicted class. The hard threshold only selects pixels; gradients flow
    # through the selected scores.
    selected = torch.zeros_like(two_class).scatter_(1, pred_mask, 1)
    return (two_class * selected).sum(dim=(2, 3))

In [96]:
# Test sample with the highest IoU. The argmax position is used as a dataset index, which only
# works if IOUs holds exactly one score per test sample, in dataset order (the test loader does
# not shuffle).
assert len(IOUs) == len(dataset['test']), 'IOUs is stale or incomplete; re-run the evaluation cell'
index = np.argmax(IOUs)
img, mask = dataset['test'][index]

In [44]:
from captum.attr import visualization as viz
from captum.attr import LayerGradCam, FeatureAblation, LayerActivation, LayerAttribution

In [118]:
# GradCAM on the first encoder block, attributing output index 1 of segmentation_wrapper
# (the polyp channel). Captum treats dim 0 as the batch, so the batch dimension is added
# explicitly and the input is placed on the model's device.
gc_input = img.unsqueeze(0).to(device)
lgc = LayerGradCam(segmentation_wrapper, best_model.encoder1)
gc_attr = lgc.attribute(gc_input, target=1)

In [None]:
# Raw activations of the same encoder block, for comparison with the GradCAM map.
# The input is batched and moved to the model's device, as Captum expects batch-first inputs.
la = LayerActivation(segmentation_wrapper, best_model.encoder1)
activation = la.attribute(img.unsqueeze(0).to(device))
print("Input Shape:", img.shape)
print("Layer Activation Shape:", activation.shape)
print("Layer GradCAM Shape:", gc_attr.shape)

In [None]:
# Layer-resolution GradCAM map for the first sample, converted to channels-last for Captum's plotter.
gc_attr_map = gc_attr[0].detach().cpu().permute(1, 2, 0).numpy()
viz.visualize_image_attr(gc_attr_map, sign="all")

In [None]:
# Resize the GradCAM map to the input image's spatial size (H, W) so it can be overlaid.
# img is the (3, H, W) test image selected above; normalized_inp was never defined in this notebook.
upsampled_gc_attr = LayerAttribution.interpolate(gc_attr, tuple(img.shape[-2:]))
print("Upsampled Shape:", upsampled_gc_attr.shape)

In [None]:
# The original image next to the upsampled GradCAM map blended over it: all, positive and negative
# attributions. The background is the selected test image (preproc_img was never defined here).
viz.visualize_image_attr_multiple(
    upsampled_gc_attr[0].detach().cpu().permute(1, 2, 0).numpy(),
    original_image=img.detach().cpu().permute(1, 2, 0).numpy(),
    signs=["all", "positive", "negative"],
    methods=["original_image", "blended_heat_map", "blended_heat_map"],
)

In [None]:
img_without_train = (1 - (out_max == 19).float())[0].cpu() * preproc_img
plt.imshow(img_without_train.permute(1,2,0))

In [None]:
# FeatureAblation using the predicted mask as feature groups: the background (0) and polyp (1)
# regions are each ablated as a whole, and the effect on the polyp score (output index 1 of
# segmentation_wrapper) is attributed to them.
fa_input = img.unsqueeze(0).to(device)
with torch.no_grad():
    # (1, 1, H, W) binary prediction; it broadcasts over the 3 input channels as a feature mask.
    out_max = (best_model(fa_input) >= 0.3).long()
fa = FeatureAblation(segmentation_wrapper)
fa_attr = fa.attribute(fa_input, feature_mask=out_max, perturbations_per_eval=2, target=1)

In [None]:
# Feature-ablation attribution for the first sample, channels-last for Captum's plotter.
fa_attr_map = fa_attr[0].detach().cpu().permute(1, 2, 0).numpy()
viz.visualize_image_attr(fa_attr_map, sign="all")

In [None]:
# Zero the attribution inside the predicted polyp region, keeping only what is assigned to
# background pixels. Assumes out_max holds the binary prediction (0 = background, 1 = polyp);
# 6 is not a class of this single-output model.
fa_attr_without_max = (1 - (out_max == 1).float())[0] * fa_attr

In [None]:
# Background-only feature-ablation attribution, channels-last for Captum's plotter.
fa_background_map = fa_attr_without_max[0].detach().cpu().permute(1, 2, 0).numpy()
viz.visualize_image_attr(fa_background_map, sign="all")