#### Import

In [None]:
import os
import numpy as np
import glob
import PIL.Image as Image
import cv2
# pip install torchsummary

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import models
from torchsummary import summary
import torch.optim as optim
from time import time

import matplotlib.pyplot as plt
from IPython.display import clear_output
from sklearn.model_selection import train_test_split


In [None]:
class config:
    """Hyper-parameters shared by the data pipeline and the training cells."""

    def __init__(self):
        # Side length (pixels) every image is resized to; the original note
        # says the source images average about 575 x 766.
        self.img_size = 128
        # Samples per mini-batch.
        self.batch_size = 6
        # Seed used for the dataset splits and the loader generators.
        self.seed = 42
        # Number of DataLoader worker processes.
        self.workers = 3
        # Learning rate.
        self.lr = 0.001


# From here on `config` names the single settings instance, not the class.
config = config()

In [None]:
class PhC(torch.utils.data.Dataset):
    """Dataset pairing each dermoscopic image with its lesion mask.

    File paths are discovered with glob under ``data_path`` and sorted, so the
    i-th image is paired with the i-th mask purely by sort order.
    """

    def __init__(self,transform):
        """Store the transform and collect the sorted image / mask paths.

        Args:
            transform: callable applied to both the image and the mask.
        """
        self.transform = transform
        self.data_path = "/dtu/datasets1/02516/PH2_Dataset_images"
        image_pattern = f"{self.data_path}/*/*_Dermoscopic_Image/*.bmp"
        self.image_paths = sorted(glob.glob(image_pattern))
        label_pattern = f"{self.data_path}/*/*_lesion/*_lesion.bmp"
        self.label_paths = sorted(glob.glob(label_pattern))

    def __len__(self):
        """Return the number of images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(X, Y)`` = (image, mask), both transformed."""
        image_path, label_path = self.image_paths[idx], self.label_paths[idx]
        image, label = Image.open(image_path), Image.open(label_path)
        # The mask goes through the transform before the image.
        Y = self.transform(label)
        X = self.transform(image)
        return X, Y
    

size = config.img_size

# Train and test preprocessing are identical: resize to a square of side
# `size`, then convert the PIL image to a tensor.
train_transform = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.ToTensor(),
])
test_transform = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.ToTensor(),
])

# Seeded two-stage split: 120 training samples, then the remaining 80 are
# halved into 40 validation and 40 test samples.
trainset = PhC(transform=train_transform)
trainset, val_test_set = torch.utils.data.random_split(
    trainset,
    [120, 80],
    generator=torch.Generator().manual_seed(config.seed),
)
valset, testset = torch.utils.data.random_split(
    val_test_set,
    [40, 40],
    generator=torch.Generator().manual_seed(config.seed),
)

# Only the training loader shuffles; every loader gets its own seeded generator.
train_loader = DataLoader(
    trainset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.workers,
    generator=torch.Generator().manual_seed(config.seed),
)
val_loader = DataLoader(
    valset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.workers,
    generator=torch.Generator().manual_seed(config.seed),
)
test_loader = DataLoader(
    testset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.workers,
    generator=torch.Generator().manual_seed(config.seed),
)

#### EncDec

In [None]:
class EncDec(nn.Module):
    """
    Plain encoder-decoder network for image segmentation (no skip connections).

    The encoder applies four conv + ReLU + 2x2 max-pool stages, a bottleneck
    conv follows, and the decoder applies four 2x upsample + conv stages. The
    last conv has no activation, so the network returns raw logits.

    Because of the four halvings and four doublings, the output has the same
    height and width as the input only when both are multiples of 16.

    Args:
        in_ch (int): Number of input image channels. Defaults to 3.
        out_ch (int): Number of output channels (logit maps). Defaults to 1.
        width (int | None): Number of feature channels used by every hidden
            conv layer. ``None`` keeps the original behaviour of using
            ``config.img_size // 2``.

    Attributes:
        enc_conv0..enc_conv3 (nn.Conv2d): Encoder convolutions.
        pool0..pool3 (nn.MaxPool2d): 2x2 max pooling, each halves H and W.
        bottleneck_conv (nn.Conv2d): Convolution at the coarsest resolution.
        upsample0..upsample3 (nn.Upsample): 2x upsampling, each doubles H and W.
        dec_conv0..dec_conv3 (nn.Conv2d): Decoder convolutions; ``dec_conv3``
            maps to ``out_ch`` channels.

    Methods:
        forward(x): Performs forward pass through the network.

    """

    def __init__(self, in_ch=3, out_ch=1, width=None):
        super().__init__()

        # Historical default: the feature width is derived from the configured
        # image size (64 channels for 128-pixel images).
        inp_size = config.img_size//2 if width is None else width

        # encoder (downsampling); each pool halves the spatial size
        self.enc_conv0 = nn.Conv2d(in_ch, inp_size, 3, padding=1)
        self.pool0 = nn.MaxPool2d(2, 2)
        self.enc_conv1 = nn.Conv2d(inp_size, inp_size, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.enc_conv2 = nn.Conv2d(inp_size, inp_size, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.enc_conv3 = nn.Conv2d(inp_size, inp_size, 3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)

        # bottleneck
        self.bottleneck_conv = nn.Conv2d(inp_size, inp_size, 3, padding=1)

        # decoder (upsampling); each upsample doubles the spatial size
        self.upsample0 = nn.Upsample(scale_factor=2)
        self.dec_conv0 = nn.Conv2d(inp_size, inp_size, 3, padding=1)
        self.upsample1 = nn.Upsample(scale_factor=2)
        self.dec_conv1 = nn.Conv2d(inp_size, inp_size, 3, padding=1)
        self.upsample2 = nn.Upsample(scale_factor=2)
        self.dec_conv2 = nn.Conv2d(inp_size, inp_size, 3, padding=1)
        self.upsample3 = nn.Upsample(scale_factor=2)
        self.dec_conv3 = nn.Conv2d(inp_size, out_ch, 3, padding=1)

    def forward(self, x):
        """Return the logit map(s) for the image batch ``x``."""
        # encoder
        e0 = self.pool0(F.relu(self.enc_conv0(x)))
        e1 = self.pool1(F.relu(self.enc_conv1(e0)))
        e2 = self.pool2(F.relu(self.enc_conv2(e1)))
        e3 = self.pool3(F.relu(self.enc_conv3(e2)))

        # bottleneck
        b = F.relu(self.bottleneck_conv(e3))

        # decoder
        d0 = F.relu(self.dec_conv0(self.upsample0(b)))
        d1 = F.relu(self.dec_conv1(self.upsample1(d0)))
        d2 = F.relu(self.dec_conv2(self.upsample2(d1)))
        d3 = self.dec_conv3(self.upsample3(d2))  # no activation: logits
        return d3

#### GeneralUNet

In [None]:
class UNET(nn.Module):
    """U-Net style segmentation network with four down / up levels and skip connections.

    Each level is built with ``create_block``, which is not defined in the
    visible part of this file and must be in scope before the class is
    instantiated.
    NOTE(review): presumably it returns a module mapping ``in_channels`` to
    ``out_channels`` while keeping the spatial size - confirm.

    Args:
        in_ch: Number of input image channels.
        out_ch: Number of output channels produced by the final 1x1 conv.
        upsample_type: ``'transpose'`` (learned ``ConvTranspose2d`` upsampling)
            or ``'bilinear'`` (``nn.Upsample``; block widths are halved so the
            concatenated skip tensors keep the expected channel counts).

    Raises:
        ValueError: If ``upsample_type`` is not one of the two supported values.
    """

    def __init__(self, in_ch:int=3, out_ch:int=1,upsample_type='transpose'):
        super().__init__()
        # Fail fast: an unknown type used to build a model without u1..u4,
        # which only surfaced as an AttributeError inside forward().
        if upsample_type not in ('transpose', 'bilinear'):
            raise ValueError(
                f"upsample_type must be 'transpose' or 'bilinear', got {upsample_type!r}"
            )
        self.in_ch = in_ch
        self.out_ch = out_ch
        # Bilinear upsampling does not change the channel count, so the blocks
        # feeding it emit half as many channels.
        self.factor = 2 if upsample_type == 'bilinear' else 1

        # Contracting path.
        self.d1 = create_block(in_channels=in_ch,out_channels=64,padding=1)
        self.d2 = create_block(in_channels=64,out_channels=128,padding=1)
        self.d3 = create_block(in_channels=128,out_channels=256,padding=1)
        self.d4 = create_block(in_channels=256,out_channels=512,padding=1)
        self.bottleneck = create_block(in_channels=512,out_channels=1024//self.factor,padding=1)

        # One pooling module is reused after every encoder level.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # Blocks applied after each skip concatenation on the expanding path.
        self.u1_conv = create_block(in_channels=1024,out_channels=512//self.factor,padding=1)
        self.u2_conv = create_block(in_channels=512,out_channels=256//self.factor,padding=1)
        self.u3_conv = create_block(in_channels=256,out_channels=128//self.factor,padding=1)
        self.u4_conv = create_block(in_channels=128,out_channels=64,padding=1)

        if upsample_type == 'transpose':
            self.u1 = nn.ConvTranspose2d(1024, 512//self.factor, kernel_size=2, stride=2)
            self.u2 = nn.ConvTranspose2d(512, 256//self.factor, kernel_size=2, stride=2)
            self.u3 = nn.ConvTranspose2d(256, 128//self.factor, kernel_size=2, stride=2)
            self.u4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        else:  # 'bilinear'
            self.u1 = nn.Upsample(scale_factor=2,mode='bilinear')
            self.u2 = nn.Upsample(scale_factor=2,mode='bilinear')
            self.u3 = nn.Upsample(scale_factor=2,mode='bilinear')
            self.u4 = nn.Upsample(scale_factor=2,mode='bilinear')

        # 1x1 conv producing the output channels (no activation follows it).
        self.out = nn.Conv2d(64, out_ch, kernel_size=1)

        # A single dropout module is shared by all levels that use it.
        self.dropout = nn.Dropout(p=0.2)


    def forward(self, x):
        """Return the output of the final 1x1 conv for the batch ``x``."""
        # Encoder: keep each pre-pool activation for the skip connections.
        down1 = self.d1(x)
        x = self.maxpool(down1)

        down2 = self.d2(x)
        x = self.maxpool(down2)
        x = self.dropout(x)

        down3 = self.d3(x)
        x = self.maxpool(down3)
        x = self.dropout(x)

        down4 = self.d4(x)
        x = self.maxpool(down4)
        x = self.dropout(x)

        bottleneck = self.bottleneck(x)

        # Decoder: upsample, concatenate the matching skip on the channel
        # axis, then run the level's block.
        up1 = self.u1(bottleneck)
        x = torch.cat([up1, down4], dim=1)
        x = self.u1_conv(x)
        x = self.dropout(x)

        up2 = self.u2(x)
        x = torch.cat([up2, down3], dim=1)
        x = self.u2_conv(x)
        x = self.dropout(x)

        up3 = self.u3(x)
        x = torch.cat([up3, down2], dim=1)
        x = self.u3_conv(x)
        x = self.dropout(x)

        up4 = self.u4(x)
        x = torch.cat([up4, down1], dim=1)
        x = self.u4_conv(x)

        # output layer
        output = self.out(x)
        return output

#### Metrics class

In [None]:
class metrics:
    """Binary segmentation metrics computed from a thresholded prediction and a mask.

    Args:
        eps: Small constant added to denominators that can be zero, so empty
            masks / empty predictions give 0.0 instead of NaN.
    """

    def __init__(self,eps:float=1e-8):
        self.eps = eps

    def get_confusion(self,y_hat,mask):
        """Store the confusion-matrix counts as ``self.TP/FN/FP/TN``.

        ``y_hat`` is treated as logits (and passed through a sigmoid) only when
        some value lies outside [0, 1]; otherwise it is treated as
        probabilities. Predictions are then thresholded at 0.5.
        NOTE(review): logits that happen to lie entirely inside [0, 1] are
        therefore thresholded without a sigmoid.
        """
        if y_hat.min().item() < 0.0 or (y_hat.max().item() > 1.0):
            y_hat = torch.sigmoid(y_hat)
        y_hat = (y_hat > 0.50).float()

        self.TP = (y_hat.flatten() * mask.flatten()).sum()
        self.FN = mask[y_hat == 0].sum()
        self.FP = y_hat[mask == 0].sum()
        # Whatever is not TP / FN / FP must be a true negative.
        self.TN = y_hat.numel() - self.TP - self.FN - self.FP

    def get_metrics(self,y_hat,mask):
        """Return a dict of Python floats: dice, iou, acc, sensitivity, specificity."""
        self.get_confusion(y_hat,mask)
        tp, tn, fp, fn = self.TP, self.TN, self.FP, self.FN
        # eps guards every ratio whose denominator can legitimately be zero
        # (no positives in the mask and/or in the prediction).
        dice = ((2 * tp) / (2 * tp + fn + fp + self.eps)).item()
        iou = (tp / (tp + fn + fp + self.eps)).item()
        acc = ((tp + tn) / (tp + tn + fp + fn)).item()
        sensitivity = (tp / (tp + fn + self.eps)).item()
        specificity = (tn / (tn + fp + self.eps)).item()
        self.metric_dict = {'dice':dice,'iou':iou,'acc':acc,'sensitivity':sensitivity,'specificity':specificity}
        return self.metric_dict

    def print_my_metrics(self,y_hat,mask,type_):
        """Compute the metrics and print one ``"<type_> <name>: <value>"`` line each."""
        metric_dict = self.get_metrics(y_hat,mask)
        for key in metric_dict:
            print(f"{type_} {key}: {metric_dict[key]}")

# Shared instance used by the training loop.
evaluator = metrics()

### loss_func

In [None]:
class loss_func:
    """Selectable segmentation loss operating on raw logits.

    Args:
        type_: One of ``'BCE'``, ``'FocalLoss'`` or ``'WeightedBCE'``.
    """

    def __init__(self,type_:str='BCE'):
        self.type_ = type_
        # Exponent of the focal modulating factor.
        self.gamma = 2
        # Weight of the positive class for WeightedBCE (approximate). Stored
        # as a float so it matches floating-point logits.
        self.pos_weights = torch.tensor(2.0)

    def BCE(self):
        """Return a fresh ``BCEWithLogitsLoss`` module."""
        return torch.nn.BCEWithLogitsLoss()

    def FocalLoss(self,y_hat,mask):
        """Return the mean focal-style loss of logits ``y_hat`` against ``mask``.

        Computes ``-mean((1-p)^gamma * y * log(p) + (1-y) * log(1-p))`` with
        ``p = sigmoid(y_hat)``. The logs are evaluated with ``logsigmoid`` so
        saturated logits no longer produce ``0 * -inf = nan``.
        NOTE(review): only the positive term carries the modulating factor;
        this is kept exactly as originally written.
        """
        log_p = torch.nn.functional.logsigmoid(y_hat)        # log(sigmoid(x))
        log_not_p = torch.nn.functional.logsigmoid(-y_hat)   # log(1 - sigmoid(x))
        p = torch.sigmoid(y_hat)
        return - torch.mean((1-p)**self.gamma * mask * log_p + (1-mask) * log_not_p)

    def WeightedBCE(self):
        """Return a fresh ``BCEWithLogitsLoss`` weighting positives by ``pos_weights``."""
        return torch.nn.BCEWithLogitsLoss(pos_weight=self.pos_weights)

    def get_loss(self,y_hat,mask):
        """Evaluate the configured loss.

        Raises:
            ValueError: If ``type_`` is not a supported loss name (previously
                this silently returned ``None``).
        """
        if self.type_ == 'BCE':
            return self.BCE()(y_hat,mask)
        if self.type_ == 'FocalLoss':
            return self.FocalLoss(y_hat,mask)
        if self.type_ == 'WeightedBCE':
            # Move the module (and its pos_weight buffer) next to the logits.
            return self.WeightedBCE().to(y_hat.device)(y_hat,mask)
        raise ValueError(
            f"Unknown loss type {self.type_!r}; expected 'BCE', 'FocalLoss' or 'WeightedBCE'"
        )

    def loss_name(self):
        """Return the configured loss name."""
        return self.type_

#### trainer

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Metric names produced by evaluator.get_metrics, in the order they are reported.
_METRIC_KEYS = ('acc', 'dice', 'iou', 'sensitivity', 'specificity')


def _run_epoch(model, criterion, loader, opt=None):
    """Run one pass over ``loader`` and return batch-averaged loss and metrics.

    With ``opt`` the model is put in train mode and updated after every batch;
    without it the model is put in eval mode and gradients are disabled.
    """
    training = opt is not None
    model.train(training)
    totals = dict.fromkeys(('loss',) + _METRIC_KEYS, 0.0)
    n_batches = len(loader)

    with torch.set_grad_enabled(training):
        for X_batch, Y_batch in loader:
            X_batch = X_batch.to(device)
            Y_batch = Y_batch.to(device)

            if training:
                opt.zero_grad()
            Y_pred = model(X_batch)
            loss = criterion.get_loss(Y_pred, Y_batch)
            if training:
                loss.backward()
                opt.step()

            batch_metrics = evaluator.get_metrics(torch.sigmoid(Y_pred.detach()), Y_batch)
            # .item()/float() keep plain numbers, so no autograd graph or GPU
            # tensor is retained across batches.
            totals['loss'] += loss.item() / n_batches
            for key in _METRIC_KEYS:
                totals[key] += float(batch_metrics[key]) / n_batches
    return totals


def _plot_predictions(model, X, Y, title, save_path=None):
    """Plot up to six samples: image, mask, predicted probability, thresholded image.

    The figure is written to ``save_path`` (when given) before it is shown,
    because showing a figure first can leave nothing to save.
    """
    model.eval()
    with torch.no_grad():
        Y_hat = torch.sigmoid(model(X.to(device))).cpu()

    n_cols = min(6, X.shape[0])
    cmap = 'jet'
    thres = 0.5
    fig, ax = plt.subplots(nrows=4, ncols=n_cols, figsize=(8.5, 5), dpi=400, squeeze=False)
    implt = None
    for k in range(n_cols):
        im = np.rollaxis(X[k].numpy(), 0, 3)  # channels-first -> channels-last
        ax[0,k].imshow(im)
        ax[1,k].imshow(Y[k, 0], cmap=cmap)
        ax[2,k].imshow(Y_hat[k, 0], cmap=cmap, vmin=0, vmax=1)
        # Black out every pixel predicted at or below the threshold.
        maskim = im.copy()
        maskim[(Y_hat[k, 0] <= thres).numpy()] = [0,0,0]
        implt = ax[3,k].imshow(maskim, cmap=cmap, vmin=0, vmax=1)

    titles = ['Real','Mask','Output','Threshold']
    for i in range(4):
        ax[i,0].text(-0.1, 0.5, titles[i],rotation=90,ha='center',va='center',transform=ax[i,0].transAxes)
    for a in ax.flat:
        a.set_axis_off()
    plt.suptitle(title)
    if implt is not None:
        cbar = fig.colorbar(implt, ax=ax.ravel().tolist(), shrink=0.97)
        cbar.set_ticks([0,thres,1])
        cbar.set_ticklabels(['0','0.5','1'])
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()


def train(model, opt, criterion, epochs, train_loader, val_loader, test_loader=None, save_text=None):
    """Train ``model``, checkpoint the best validation dice, then evaluate on the test set.

    After every epoch the first validation batch is plotted. The state dict of
    the epoch with the highest validation dice is written to
    ``best_model.pth`` and reloaded into ``model`` before testing.

    Args:
        model: Network returning logits.
        opt: Optimizer over ``model``'s parameters.
        criterion: Object exposing ``get_loss(logits, mask)``.
        epochs: Number of training epochs.
        train_loader, val_loader: Loaders yielding ``(image, mask)`` batches.
        test_loader: Optional test loader; when ``None`` the test phase is skipped.
        save_text: Optional file-name stem. When given, the last validation
            figure is saved as ``<save_text>.pdf`` and the test figure as
            ``<save_text>_test.pdf``.

    Returns:
        ``(best_train_metrics, best_val_metrics, test_metrics)``: dicts of
        floats keyed by loss/acc/dice/iou/sensitivity/specificity. The first
        two are ``None`` when no epoch ran; the last is ``None`` without a
        test loader.
    """
    X_val, Y_val = next(iter(val_loader))

    best_model_path = "best_model.pth"
    # -inf so the first epoch always produces a checkpoint, even at dice 0.
    best_val_dice = float('-inf')
    best_train_final_metrics = None
    best_val_final_metrics = None

    for epoch in range(epochs):
        print('* Epoch %d/%d' % (epoch+1, epochs))

        train_final_metrics = _run_epoch(model, criterion, train_loader, opt=opt)
        print(' - loss: %f' % train_final_metrics['loss'])

        val_final_metrics = _run_epoch(model, criterion, val_loader)

        if val_final_metrics['dice'] > best_val_dice:
            best_val_dice = val_final_metrics['dice']
            torch.save(model.state_dict(), best_model_path)
            best_val_final_metrics = val_final_metrics
            best_train_final_metrics = train_final_metrics

        # show intermediate results; only the final epoch's figure is saved
        clear_output(wait=True)
        last_epoch = epoch + 1 == epochs
        save_path = save_text + '.pdf' if (last_epoch and save_text is not None) else None
        _plot_predictions(model, X_val, Y_val, "Validation", save_path)

    # Now test on best val model
    test_final_metrics = None
    if test_loader is not None:
        if best_val_final_metrics is not None:
            # The checkpoint holds a state dict, not a pickled model.
            model.load_state_dict(torch.load(best_model_path, map_location=device))
        test_final_metrics = _run_epoch(model, criterion, test_loader)

        X_test, Y_test = next(iter(test_loader))
        save_path = save_text + '_test' + '.pdf' if save_text is not None else None
        _plot_predictions(model, X_test, Y_test, "Test on best validation", save_path)

    return best_train_final_metrics,best_val_final_metrics,test_final_metrics

In [None]:
# Short EncDec run: 2 epochs with Adam and the loss selected by `loss_type`.
loss_type = 'BCE'
criterion = loss_func(type_=loss_type)
model = EncDec()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
train_metrics, val_metrics, test_metrics = train(
    model=model,
    opt=optimizer,
    criterion=criterion,
    epochs=2,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    save_text=f"EncDec_{loss_type}",
)

In [None]:
# Bare expression: as the last line of a notebook cell it displays the three
# metric results assigned by the preceding train(...) call.
train_metrics,val_metrics,test_metrics

In [None]:
import pandas as pd


In [None]:
# UNET run: 50 epochs with Adam and plain BCE.
criterion = loss_func(type_='BCE')
model = UNET().to(device)
optimizer = optim.Adam(model.parameters(),lr=0.001)
# test_loader and save_text are passed explicitly: train() needs the test
# loader for its final evaluation and the stem for the figures it saves.
train(model=model,opt=optimizer,criterion=criterion,epochs=50,train_loader=train_loader,val_loader=val_loader,test_loader=test_loader,save_text="UNET_BCE")