# Import requirements

In [1]:
import torch
import torch.nn as nn

# U-Net

In [2]:
class Encoder(nn.Module):
    """Contracting path of the U-Net.

    Five stages, each made of two 3x3 convolutions (padding 1) followed by
    ReLU, with channel widths 64, 128, 256, 512 and 1024. Stages one to four
    are followed by a 2x2 max pooling. ``Dropout2d`` is applied to the
    stage-four output before pooling (the returned stage-four skip map is the
    pre-dropout activation) and to the stage-five output.

    Args:
        in_channels: Number of channels of the input tensor. Defaults to 3.
        dropout: Probability passed to ``nn.Dropout2d``. Defaults to 0.5.
    """

    def __init__(self, in_channels=3, dropout=0.5):
        super().__init__()
        # Layer names and creation order are kept stable so state_dict keys
        # and seeded initialisation do not change.
        # Stage one
        self.conv_1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # Stage two
        self.conv_3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv_4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # Stage three
        self.conv_5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv_6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        # Stage four
        self.conv_7 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv_8 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        # Stage five
        self.conv_9 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
        self.conv_10 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)

        self.drop_out = nn.Dropout2d(dropout)
        self.maxpool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()

    def _double_conv(self, x, first, second):
        """Apply ``first`` -> ReLU -> ``second`` -> ReLU to ``x``."""
        return self.relu(second(self.relu(first(x))))

    def forward(self, x):
        """Run the contracting path.

        Args:
            x: Input tensor with ``in_channels`` channels.

        Returns:
            Tuple ``(y_4, y_3, y_2, y_1, x)``: the stage-four to stage-one
            activations (taken before pooling) followed by the stage-five
            output after dropout.
        """
        # Stage one
        y_1 = self._double_conv(x, self.conv_1, self.conv_2)

        # Stage two
        y_2 = self._double_conv(self.maxpool(y_1), self.conv_3, self.conv_4)

        # Stage three
        y_3 = self._double_conv(self.maxpool(y_2), self.conv_5, self.conv_6)

        # Stage four: the skip map y_4 is taken before dropout.
        y_4 = self._double_conv(self.maxpool(y_3), self.conv_7, self.conv_8)
        x = self.maxpool(self.drop_out(y_4))

        # Stage five
        x = self.drop_out(self._double_conv(x, self.conv_9, self.conv_10))

        return y_4, y_3, y_2, y_1, x


In [3]:
import torchvision

class Decoder(nn.Module):
    """Expanding path of the U-Net.

    Each of the four stages upsamples with a 2x2 transposed convolution
    (stride 2), centre-crops the matching skip map to the upsampled size,
    concatenates both along the channel axis and applies two 3x3
    convolutions (padding 1) with ReLU. A final 1x1 convolution followed by
    a sigmoid produces a single-channel output.
    """

    def __init__(self):
        super().__init__()

        # Stage four
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv_up4_2 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.conv_up4_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        # Stage three
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv_up3_2 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.conv_up3_1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        # Stage two
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_up2_2 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.conv_up2_1 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # Stage one
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv_up1_2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.conv_up1_1 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv_up_1x1 = nn.Conv2d(64, 1, kernel_size=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _center_crop(skip, height, width):
        """Centre-crop the last two dims of ``skip`` to ``(height, width)``.

        A dimension that is smaller than the target is zero-padded on both
        sides first, so the result always has the requested spatial size.
        """
        pad_h = max(height - skip.size(-2), 0)
        pad_w = max(width - skip.size(-1), 0)
        if pad_h or pad_w:
            # F.pad order for the last two dims: (left, right, top, bottom).
            skip = nn.functional.pad(
                skip, (pad_w // 2, (pad_w + 1) // 2, pad_h // 2, (pad_h + 1) // 2)
            )
        # int(round(.)) is intended to give the same offsets as the
        # torchvision CenterCrop call this helper replaces.
        top = int(round((skip.size(-2) - height) / 2.0))
        left = int(round((skip.size(-1) - width) / 2.0))
        return skip[..., top:top + height, left:left + width]

    def _up_stage(self, x, skip, upconv, conv_a, conv_b):
        """Upsample ``x``, merge it with ``skip`` and apply two conv+ReLU."""
        x = upconv(x)
        # Crop to the real (height, width) of x: using the width for both
        # sides breaks the concatenation for non-square inputs.
        skip = self._center_crop(skip, x.size(-2), x.size(-1))
        x = torch.cat([x, skip], dim=1)
        x = self.relu(conv_a(x))
        return self.relu(conv_b(x))

    def forward(self, y_4, y_3, y_2, y_1, x):
        """Run the expanding path.

        Args:
            y_4, y_3, y_2, y_1: Skip maps with 512, 256, 128 and 64 channels
                (the channel counts the stage convolutions are built for).
            x: Bottleneck tensor with 1024 channels.

        Returns:
            Single-channel tensor with values in [0, 1] (sigmoid output).
        """
        # Stage four
        x = self._up_stage(x, y_4, self.upconv4, self.conv_up4_2, self.conv_up4_1)

        # Stage three
        x = self._up_stage(x, y_3, self.upconv3, self.conv_up3_2, self.conv_up3_1)

        # Stage two
        x = self._up_stage(x, y_2, self.upconv2, self.conv_up2_2, self.conv_up2_1)

        # Stage one
        x = self._up_stage(x, y_1, self.upconv1, self.conv_up1_2, self.conv_up1_1)
        x = self.conv_up_1x1(x)

        return self.sigmoid(x)

In [4]:
class UNet(nn.Module):
    """U-Net built by chaining the notebook's ``Encoder`` and ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x``, then decode using the skip maps and the bottleneck."""
        # The encoder returns the skip maps deepest-first, then the bottleneck.
        skip_4, skip_3, skip_2, skip_1, bottleneck = self.encoder(x)
        return self.decoder(skip_4, skip_3, skip_2, skip_1, bottleneck)


In [5]:
import torch
# Smoke test: push one random batch through the network and show the output shape.
# Fall back to the CPU so the cell still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNet().to(device)
# Allocate the batch directly on the target device instead of copying it from the host.
inputs = torch.randint(0, 255, (10, 3, 512, 512), dtype=torch.float32, device=device)
# Only the shape is inspected, so skip building the autograd graph.
with torch.no_grad():
    output_shape = unet(inputs).shape
output_shape

torch.Size([10, 1, 512, 512])

# Metrics

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score


# Collapse labels and predictions to 1-D before computing the metrics.
# NOTE(review): y_true / y_pred are not created anywhere in the visible notebook;
# an earlier cell must define them before this one can run.
y_true, y_pred = y_true.flatten(), y_pred.flatten()

def calculate_iou(y_true, y_pred, empty_value=1.0):
    """Intersection over union of two masks.

    Computed as ``sum(y_true * y_pred) / (sum(y_true) + sum(y_pred) - sum(y_true * y_pred))``.

    Args:
        y_true: Ground-truth mask (array-like, same shape as ``y_pred``).
        y_pred: Predicted mask (array-like).
        empty_value: Value returned when the union is zero, i.e. both masks
            are empty. Defaults to 1.0 (two empty masks agree perfectly)
            instead of the NaN a 0/0 division would give.

    Returns:
        The IoU as a float, or ``empty_value`` when the union is zero.
    """
    # Accept lists / bool / integer masks and avoid integer wrap-around in the product.
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    if union == 0:
        return empty_value
    return intersection / union

# IoU between the flattened ground truth and the flattened predictions.
iou = calculate_iou(y_true=y_true, y_pred=y_pred)

# Dice coefficient
def calculate_dice_coefficient(y_true, y_pred, empty_value=1.0):
    """Dice coefficient of two masks.

    Computed as ``2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred))``.

    Args:
        y_true: Ground-truth mask (array-like, same shape as ``y_pred``).
        y_pred: Predicted mask (array-like).
        empty_value: Value returned when both masks sum to zero. Defaults to
            1.0 (two empty masks agree perfectly) instead of the NaN a 0/0
            division would give.

    Returns:
        The Dice coefficient as a float, or ``empty_value`` when the
        denominator is zero.
    """
    # Accept lists / bool / integer masks and avoid integer wrap-around in the product.
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    intersection = np.sum(y_true * y_pred)
    denominator = np.sum(y_true) + np.sum(y_pred)
    if denominator == 0:
        return empty_value
    return (2.0 * intersection) / denominator

# Overlap metric on the prediction values as given.
dice_coefficient = calculate_dice_coefficient(y_true, y_pred)

# The label-based sklearn metrics (accuracy / precision / recall / F1) reject
# continuous predictions, so binarise them first. This is a no-op when y_pred
# already contains only 0 and 1.
# NOTE(review): assumes y_pred holds scores in [0, 1] (e.g. sigmoid output) — confirm.
PRED_THRESHOLD = 0.5
y_pred_label = (np.asarray(y_pred) >= PRED_THRESHOLD).astype(int)

accuracy = accuracy_score(y_true, y_pred_label)
precision = precision_score(y_true, y_pred_label)
recall = recall_score(y_true, y_pred_label)
f1 = f1_score(y_true, y_pred_label)
# ROC AUC ranks the scores themselves, so it keeps the un-thresholded predictions.
roc_auc = roc_auc_score(y_true, y_pred)

print(f"IoU: {iou}")
print(f"Dice coefficient: {dice_coefficient}")
print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1-score: {f1}")
print(f"ROC AUC: {roc_auc}")


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

# ROC curve points and the area under them, from the labels and predictions.
fpr, tpr, thresholds = roc_curve(y_true, y_pred)
roc_auc = auc(fpr, tpr)

# Draw on an explicit Axes instead of going through the pyplot state machine.
roc_ax = plt.figure(figsize=(8, 6)).add_subplot(111)
roc_ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
# Diagonal reference line from (0, 0) to (1, 1).
roc_ax.plot([0, 1], [0, 1], color='gray', linestyle='--', lw=2, label='Random Guessing')

roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel('False Positive Rate (FPR)')
roc_ax.set_ylabel('True Positive Rate (TPR)')
roc_ax.set_title('Receiver Operating Characteristic (ROC) Curve')
roc_ax.legend(loc='lower right')
plt.show()
