## Preparing data

In [1]:
import copy
import gc
import os
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython.display import clear_output
from matplotlib import rcParams
from skimage.io import imread
from skimage.transform import resize
from torch.utils.data import DataLoader
from torchvision import models

  from .autonotebook import tqdm as notebook_tqdm


### 1. Loading dataset

In [3]:
# Load every dermoscopic image and its lesion mask from the PH2 folder layout.
images = []
lesions = []
root = 'PH2Dataset'
data_dir = os.path.join(root, 'PH2 Dataset images')

# os.walk silently yields nothing for a missing directory, which would otherwise
# only surface downstream as an empty dataset; fail here with a clear message.
if not os.path.isdir(data_dir):
    raise FileNotFoundError(f'PH2 image directory not found: {data_dir!r}')

# Use `dirpath` (not `root`) so the configured dataset root is not clobbered.
for dirpath, dirnames, filenames in os.walk(data_dir):
    # Sorting in place fixes the traversal order, so sample order (and hence the
    # seeded split below) is reproducible across filesystems.
    dirnames.sort()
    if dirpath.endswith('_Dermoscopic_Image'):
        images.append(imread(os.path.join(dirpath, filenames[0])))
    if dirpath.endswith('_lesion'):
        lesions.append(imread(os.path.join(dirpath, filenames[0])))

if len(images) != len(lesions):
    raise ValueError(f'found {len(images)} images but {len(lesions)} lesion masks')

### 2. Resizing images

In [4]:
size = (256, 256)
# Images get an anti-aliased downscale; masks are interpolated without
# anti-aliasing and then re-binarised at 0.5.
X = list(map(lambda img: resize(img, size, mode='constant', anti_aliasing=True), images))
Y = list(map(lambda msk: resize(msk, size, mode='constant', anti_aliasing=False) > 0.5, lesions))

In [5]:
# Stack the per-sample lists into float32 arrays (masks become 0.0 / 1.0).
X, Y = (np.array(arr, np.float32) for arr in (X, Y))
print(f'Loaded {len(X)} images')

Loaded 0 images


#### Showing a few samples (images on top, lesion masks below)

In [None]:
# Top row: the first six images; bottom row: their lesion masks.
plt.figure(figsize=(18, 6))
for col in range(6):
    for row, stack in enumerate((X, Y)):
        plt.subplot(2, 6, row * 6 + col + 1)
        plt.axis("off")
        plt.imshow(stack[col])
plt.show();

### 3. Splitting data

In [None]:
# Fixed seed so the train/val/test partition is reproducible across runs.
SPLIT_SEED = 42
N_TRAIN, N_VAL = 100, 50  # the remaining samples form the test set

ix = np.random.default_rng(SPLIT_SEED).permutation(len(X))
tr, val, ts = np.split(ix, [N_TRAIN, N_TRAIN + N_VAL])
if min(len(tr), len(val), len(ts)) == 0:
    raise ValueError(f'{len(X)} samples are too few for a {N_TRAIN}/{N_VAL}/rest split')

#### Creating DataLoaders

In [None]:
batch_size = 25


def _make_loader(idx, shuffle):
    """Build a DataLoader of (image CHW, mask 1HW) pairs for the samples at `idx`."""
    # X is stored as (N, H, W, 3): move channels to axis 1 for PyTorch.
    # Y is (N, H, W): add a singleton channel axis to match the 1-channel logits.
    pairs = list(zip(np.rollaxis(X[idx], 3, 1), Y[idx, np.newaxis]))
    return DataLoader(pairs, batch_size=batch_size, shuffle=shuffle)


data_tr = _make_loader(tr, shuffle=True)
# Evaluation loaders keep a fixed order, so val/test scores and the
# validation samples plotted during training are deterministic.
data_val = _make_loader(val, shuffle=False)
data_ts = _make_loader(ts, shuffle=False)

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Creating models

In [None]:
rcParams.update({'figure.figsize': (15, 4)})  # default size for the training-progress plots

### 1. SegNet

In [None]:
class SegNet(nn.Module):
    """SegNet encoder/decoder producing one-channel segmentation logits.

    Each encoder stage is a stack of conv3x3 -> BatchNorm -> ReLU layers followed
    by 2x2 max-pooling that also returns the argmax indices. A bottleneck adds one
    more conv stack + pool/unpool pair. Each decoder stage max-unpools with the
    indices of its mirror encoder stage, then applies its own conv stack. The last
    layer is a bare Conv2d with 1 output channel, so the output is raw logits
    (no sigmoid). The first conv expects 3 input channels.
    """

    @staticmethod
    def _conv_stack(channels, last_activation=True):
        """conv3x3(pad 1) -> BatchNorm2d -> ReLU for each consecutive pair in `channels`.

        With last_activation=False the final pair is a bare Conv2d (logit head).
        Layer order matches the original hand-written nn.Sequential blocks, so
        state_dict keys (e.g. 'enc_conv0.3.weight') are unchanged.
        """
        pairs = list(zip(channels[:-1], channels[1:]))
        layers = []
        for i, (c_in, c_out) in enumerate(pairs):
            layers.append(nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1))
            if last_activation or i < len(pairs) - 1:
                layers += [nn.BatchNorm2d(c_out), nn.ReLU()]
        return nn.Sequential(*layers)

    @staticmethod
    def _pool():
        """2x2 / stride-2 max-pool that returns indices for later unpooling."""
        return nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    @staticmethod
    def _unpool():
        """Inverse of `_pool`: scatters values back to the remembered argmax positions."""
        return nn.MaxUnpool2d(kernel_size=2, stride=2)

    def __init__(self):
        super().__init__()

        # encoder
        self.enc_conv0 = self._conv_stack([3, 64, 64])
        self.pool0 = self._pool()
        self.enc_conv1 = self._conv_stack([64, 128, 128])
        self.pool1 = self._pool()
        self.enc_conv2 = self._conv_stack([128, 256, 256, 256])
        self.pool2 = self._pool()
        self.enc_conv3 = self._conv_stack([256, 512, 512, 512])
        self.pool3 = self._pool()

        # bottleneck
        self.bottleneck_conv_enc = self._conv_stack([512, 512, 512, 512])
        self.bottleneck_pool = self._pool()
        self.bottleneck_upsample = self._unpool()
        self.bottleneck_conv_dec = self._conv_stack([512, 512, 512, 512])

        # decoder
        self.upsample0 = self._unpool()
        self.dec_conv0 = self._conv_stack([512, 512, 512, 256])
        self.upsample1 = self._unpool()
        self.dec_conv1 = self._conv_stack([256, 256, 256, 128])
        self.upsample2 = self._unpool()
        self.dec_conv2 = self._conv_stack([128, 128, 64])
        self.upsample3 = self._unpool()
        self.dec_conv3 = self._conv_stack([64, 64, 1], last_activation=False)

    def forward(self, x):
        """Return (N, 1, H, W) logits for an (N, 3, H, W) input batch."""
        # Record each pooling input's size so unpooling restores it exactly.
        # MaxUnpool2d only uses the trailing (H, W) of output_size.
        size_0 = x.size()
        e0, id0 = self.pool0(self.enc_conv0(x))
        size_1 = e0.size()
        e1, id1 = self.pool1(self.enc_conv1(e0))
        size_2 = e1.size()
        e2, id2 = self.pool2(self.enc_conv2(e1))
        size_3 = e2.size()
        e3, id3 = self.pool3(self.enc_conv3(e2))
        size_4 = e3.size()

        b, b_id = self.bottleneck_pool(self.bottleneck_conv_enc(e3))
        b = self.bottleneck_conv_dec(self.bottleneck_upsample(b, b_id, output_size=size_4))

        d0 = self.dec_conv0(self.upsample0(b, id3, output_size=size_3))
        d1 = self.dec_conv1(self.upsample1(d0, id2, output_size=size_2))
        d2 = self.dec_conv2(self.upsample2(d1, id1, output_size=size_1))
        d3 = self.dec_conv3(self.upsample3(d2, id0, output_size=size_0))
        return d3

#### Metrics (IoU)

In [None]:
def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor):
    """Per-sample *thresholded* IoU score (not the raw IoU).

    Both tensors are squeezed on dim 1 and cast to uint8, then summed over the
    remaining dims (1, 2); in this notebook they are (B, 1, H, W) 0/1 masks.
    The raw IoU uses a tiny smoothing term, so two empty masks score 1. It is
    then mapped to ceil(clamp(20 * (iou - 0.5), 0, 10)) / 10. This gives 0 for
    IoU <= 0.5 and rises in 0.1 steps, reaching 1.0 for IoU above 0.95.

    Returns:
        Float tensor of shape (B,).
    """
    smooth = 1e-8
    pred_mask = outputs.squeeze(1).byte()
    true_mask = labels.squeeze(1).byte()
    spatial = (1, 2)

    overlap = (pred_mask & true_mask).float().sum(spatial)
    combined = (pred_mask | true_mask).float().sum(spatial)
    raw_iou = (overlap + smooth) / (combined + smooth)

    return (20 * (raw_iou - 0.5)).clamp(0, 10).ceil() / 10

#### Training model

In [None]:
def _plot_val_batch(inputs, preds, loss, iou, n_cols=5):
    """Plot up to `n_cols` validation inputs (top row) over their predicted masks (bottom row)."""
    inputs = inputs.detach().cpu()
    preds = preds.detach().cpu()
    clear_output(wait=True)
    # Guard against a last batch smaller than the grid width.
    for k in range(min(n_cols, inputs.size(0))):
        plt.subplot(2, n_cols, k + 1)
        plt.imshow(np.rollaxis(inputs[k].numpy(), 0, 3), cmap='gray')
        plt.title('Real', fontsize=12)
        plt.axis('off')

        plt.subplot(2, n_cols, k + 1 + n_cols)
        plt.imshow(preds[k, 0], cmap='gray')
        plt.title('Output', fontsize=12)
        plt.axis('off')
    plt.suptitle('val Loss: {:.4f} val IOU: {:.4f}'.format(loss, iou), fontsize=15)
    plt.show()


def train(model, opt, loss_fn, epochs, data_tr, data_val):
    """Train `model`, keeping the weights with the best validation IoU.

    Each epoch runs a training pass over `data_tr` and an evaluation pass over
    `data_val`, prints progress, and plots samples from the last validation batch.

    Args:
        model: segmentation network returning logits. Batches are moved to the
            global `device`, so the model is expected to live there too.
        opt: optimizer over `model.parameters()`.
        loss_fn: callable ``loss_fn(y_real, logits) -> scalar tensor``.
        epochs: number of epochs to run.
        data_tr, data_val: iterables of (image batch, mask batch) pairs.

    Returns:
        ``(model, (train_loss, train_iou, val_loss, val_iou))``: the model with
        its best-validation-IoU weights loaded, plus per-epoch histories.

    Raises:
        ValueError: if either loader yields no batches.
    """
    since = time()
    dataloaders = {'train': data_tr, 'val': data_val}

    best_model_wts = copy.deepcopy(model.state_dict())
    best_iou = 0.0

    epoch_train_loss, epoch_train_iou = [], []
    epoch_val_loss, epoch_val_iou = [], []

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 10)

        for phase in ('train', 'val'):
            is_train = phase == 'train'
            model.train(is_train)

            running_loss = 0.0  # sum of per-sample losses
            running_iou = 0.0   # sum of per-batch mean IoU
            n_samples = 0
            n_batches = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    opt.zero_grad()
                with torch.set_grad_enabled(is_train):
                    logits = model(inputs)
                    loss = loss_fn(labels, logits)
                    if is_train:
                        loss.backward()
                        opt.step()

                running_loss += loss.item() * inputs.size(0)
                n_samples += inputs.size(0)
                n_batches += 1

                with torch.no_grad():
                    Y_pred = torch.where(torch.sigmoid(logits) > 0.5, 1, 0)
                    running_iou += iou_pytorch(Y_pred, labels).mean().item()

            if n_batches == 0:
                raise ValueError('{} loader yielded no batches'.format(phase))

            # The loss was accumulated per sample, so normalise by the sample count;
            # dividing by the number of batches would inflate it by ~batch_size.
            epoch_loss = running_loss / n_samples
            epoch_iou = running_iou / n_batches

            if is_train:
                epoch_train_loss.append(epoch_loss)
                epoch_train_iou.append(epoch_iou)
            else:
                epoch_val_loss.append(epoch_loss)
                epoch_val_iou.append(epoch_iou)
                _plot_val_batch(inputs, Y_pred, epoch_loss, epoch_iou)
                if epoch_iou > best_iou:
                    best_iou = epoch_iou
                    best_model_wts = copy.deepcopy(model.state_dict())

            # Drop references to the last batch's tensors before releasing cached memory.
            del inputs, labels, logits, loss, Y_pred
            torch.cuda.empty_cache()
            gc.collect()

        print()

    time_elapsed = time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
    print('Best val IOU: {:.4f}'.format(best_iou))

    model.load_state_dict(best_model_wts)
    statistics = epoch_train_loss, epoch_train_iou, epoch_val_loss, epoch_val_iou
    return model, statistics

In [None]:
def predict(model, data, threshold=0.5):
    """Binary mask predictions of `model` for every batch in `data`.

    Args:
        model: network returning logits; it is switched to eval mode.
        data: iterable of (X_batch, label) pairs, e.g. a DataLoader. Labels are ignored.
        threshold: probability cut-off applied after the sigmoid.

    Returns:
        np.ndarray of 0/1 integers with all batches concatenated along axis 0
        (presumably (N, 1, H, W) for the segmentation models here). An empty
        `data` gives an empty array.
    """
    model.eval()
    # Run on whatever device the model's weights live on.
    model_device = next(model.parameters()).device
    batches = []
    with torch.no_grad():
        for X_batch, _ in data:
            probs = torch.sigmoid(model(X_batch.to(model_device)))
            batches.append(torch.where(probs > threshold, 1, 0).cpu())
    if not batches:
        return np.empty((0,), dtype=np.int64)
    return torch.cat(batches).numpy()

In [None]:
def score_model(model, metric, data):
    """Mean per-sample `metric` of the model's thresholded predictions over `data`.

    Args:
        model: network returning logits; it is switched to eval mode.
        metric: callable ``metric(pred_mask, label) -> tensor`` of per-sample
            scores (e.g. ``iou_pytorch``). A scalar result counts as one sample.
        data: iterable of (X_batch, Y_label) pairs.

    Returns:
        float: scores averaged over samples. Averaging per-batch means would
        over-weight a smaller final batch.

    Raises:
        ValueError: if `data` yields no samples.
    """
    model.eval()
    model_device = next(model.parameters()).device
    total = 0.0
    n_samples = 0
    with torch.no_grad():
        for X_batch, Y_label in data:
            logits = model(X_batch.to(model_device))
            Y_pred = torch.where(torch.sigmoid(logits) > 0.5, 1, 0)
            per_sample = metric(Y_pred, Y_label.to(model_device))
            total += per_sample.sum().item()
            n_samples += per_sample.numel()

    if n_samples == 0:
        raise ValueError('score_model: `data` yielded no samples')
    return total / n_samples

### Results with different loss functions

#### BCE

In [None]:
def bce_loss(y_real, y_pred):
    """Mean binary cross-entropy between target masks and raw logits.

    Args:
        y_real: target mask with values in [0, 1], same shape as `y_pred`.
        y_pred: raw model logits (the sigmoid is applied internally).

    Returns:
        Scalar tensor: BCE averaged over every element.

    Uses the log-sum-exp form (binary_cross_entropy_with_logits) rather than
    log(sigmoid(x) + eps). For confidently wrong predictions, 1 - sigmoid(x)
    rounds to 0 in float32. The eps form then caps the loss near
    -log(1e-45) ~= 103 and its gradient vanishes.
    """
    return F.binary_cross_entropy_with_logits(y_pred, y_real.type_as(y_pred))
    
# The SegNet for the BCE experiment is created in the training cell that follows; instantiating it here as well only allocated a second, immediately discarded model.

In [None]:
# BCE experiment: a fresh SegNet trained with bce_loss.
model = SegNet().to(device)
max_epochs = 50
# Keep the optimizer under its own name: assigning to `optim` would rebind the
# `torch.optim` module alias imported at the top and break later `optim.Adam(...)` calls.
optimizer_bce = torch.optim.Adam(model.parameters(), lr=1e-4)
# `train` returns the best-val-IoU model plus per-epoch loss/IoU histories.
model, stats_bce = train(model, optimizer_bce, bce_loss, max_epochs, data_tr, data_val)

In [None]:
score_model(model=model, metric=iou_pytorch, data=data_val)

#### Dice

In [None]:
def dice_loss(y_real, y_pred):
    """Soft Dice loss between target masks and raw logits.

    Dice is computed once over all elements of the batch (not per sample):
    ``1 - 2 * sum(p * t) / (sum(p + t) + smooth)`` with ``p = sigmoid(y_pred)``.

    Args:
        y_real: target mask with values in [0, 1].
        y_pred: raw model logits, same number of elements as `y_real`.

    Returns:
        Scalar tensor in [0, 1]; 0 means perfect overlap.
    """
    smooth = 1e-7  # only guards against division by zero
    probs = torch.flatten(torch.sigmoid(y_pred))
    targets = torch.flatten(y_real).type_as(probs)

    intersection = torch.sum(probs * targets)
    cardinality = torch.sum(probs + targets)

    dice_score = 2. * intersection / (cardinality + smooth)
    return 1. - dice_score

In [None]:
# Dice experiment: a fresh SegNet trained with dice_loss.
model_dice = SegNet().to(device)
max_epochs = 50
# torch.optim is referenced directly so this cell does not depend on the
# module-level name `optim` still being bound to the torch.optim module.
optimizer_dice = torch.optim.Adam(model_dice.parameters(), lr=1e-3)
model_dice, stats_dice = train(model_dice, optimizer_dice, dice_loss, max_epochs, data_tr, data_val)

In [None]:
score_model(model=model_dice, metric=iou_pytorch, data=data_val)

#### Focal

In [None]:
def focal_loss(y_real, y_pred, eps=1e-8, gamma=2):
    """Binary focal loss between target masks and raw logits.

    ``-mean(t * (1-p)^gamma * log(p) + (1-t) * p^gamma * log(1-p))`` with
    ``p = sigmoid(y_pred)``. Both classes get their focusing weight; the negative
    term previously lacked ``p^gamma``, so easy negatives were not down-weighted.
    With gamma=0 this reduces to plain BCE.

    Args:
        y_real: target mask with values in [0, 1].
        y_pred: raw model logits, broadcast-compatible with `y_real`.
        eps: added inside the logs to avoid log(0).
        gamma: focusing exponent.

    Returns:
        Scalar tensor.
    """
    p = torch.sigmoid(y_pred)
    positive = y_real * (1 - p) ** gamma * torch.log(p + eps)
    negative = (1 - y_real) * p ** gamma * torch.log(1 - p + eps)
    return -torch.mean(positive + negative)

In [None]:
# Focal-loss experiment: a fresh SegNet trained with focal_loss.
model_focal = SegNet().to(device)

max_epochs = 40
# torch.optim is referenced directly so this cell does not depend on the
# module-level name `optim` still being bound to the torch.optim module.
optimizer_focal = torch.optim.Adam(model_focal.parameters(), lr=1e-4)
model_focal, stats_focal = train(model_focal, optimizer_focal, focal_loss, max_epochs, data_tr, data_val)

In [None]:
score_model(model=model_focal, metric=iou_pytorch, data=data_val)

### 2. UNet

In [None]:
class UNet(nn.Module):
    """U-Net producing one-channel segmentation logits.

    The model has four encoder stages, each with two conv3x3 -> BatchNorm -> ReLU
    layers followed by 2x2 max-pooling, and a 1024-channel bottleneck. It then has
    four decoder stages; each upsamples by 2 (nearest neighbour), concatenates
    the matching encoder feature map along channels, and applies two conv layers.
    The final layer is a bare 1-channel Conv2d, so the output is raw logits.

    Upsampling uses scale_factor=2 instead of fixed target sizes, so the model
    is not tied to 256x256 inputs. Height and width must be divisible by 16 so
    the skip connections line up after four poolings.
    """

    @staticmethod
    def _conv_stack(channels, last_activation=True):
        """conv3x3(pad 1) -> BatchNorm2d -> ReLU for each consecutive pair in `channels`.

        With last_activation=False the final pair is a bare Conv2d (logit head).
        Layer order matches the original hand-written nn.Sequential blocks, so
        state_dict keys are unchanged.
        """
        pairs = list(zip(channels[:-1], channels[1:]))
        layers = []
        for i, (c_in, c_out) in enumerate(pairs):
            layers.append(nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1))
            if last_activation or i < len(pairs) - 1:
                layers += [nn.BatchNorm2d(c_out), nn.ReLU()]
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()

        # encoder
        self.enc_conv0 = self._conv_stack([3, 64, 64])
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv1 = self._conv_stack([64, 128, 128])
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv2 = self._conv_stack([128, 256, 256])
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv3 = self._conv_stack([256, 512, 512])
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # bottleneck
        self.bottleneck_conv = self._conv_stack([512, 1024, 1024])

        # decoder: input channels = upsampled features + skip connection
        self.upsample0 = nn.Upsample(scale_factor=2)
        self.dec_conv0 = self._conv_stack([1024 + 512, 512, 512])
        self.upsample1 = nn.Upsample(scale_factor=2)
        self.dec_conv1 = self._conv_stack([512 + 256, 256, 256])
        self.upsample2 = nn.Upsample(scale_factor=2)
        self.dec_conv2 = self._conv_stack([256 + 128, 128, 128])
        self.upsample3 = nn.Upsample(scale_factor=2)
        self.dec_conv3 = self._conv_stack([128 + 64, 64, 1], last_activation=False)

    def forward(self, x):
        """Return (N, 1, H, W) logits for an (N, 3, H, W) input batch."""
        # encoder
        e0 = self.enc_conv0(x)
        e1 = self.enc_conv1(self.pool0(e0))
        e2 = self.enc_conv2(self.pool1(e1))
        e3 = self.enc_conv3(self.pool2(e2))

        # bottleneck
        b = self.bottleneck_conv(self.pool3(e3))

        # decoder: upsample, concatenate the skip connection, convolve
        d0 = self.dec_conv0(torch.cat([self.upsample0(b), e3], dim=1))
        d1 = self.dec_conv1(torch.cat([self.upsample1(d0), e2], dim=1))
        d2 = self.dec_conv2(torch.cat([self.upsample2(d1), e1], dim=1))
        d3 = self.dec_conv3(torch.cat([self.upsample3(d2), e0], dim=1))  # no activation

        return d3

In [None]:
unet_model = UNet().to(device=device)

In [None]:
# torch.optim is referenced directly so this cell does not depend on the module-level name `optim` still being bound to the torch.optim module; Adam's default lr (1e-3) is used.
unet_optimizer = torch.optim.Adam(unet_model.parameters())
unet_model, stats_unet = train(unet_model, unet_optimizer, bce_loss, 20, data_tr, data_val)

In [None]:
score_model(model=unet_model, metric=iou_pytorch, data=data_val)