# MRI Hippocampus Segmentation

In this project, we propose a better framework for identifying Alzheimer's disease (AD) using U-Net variants trained on brain magnetic resonance imaging (MRI) data. We explore various state-of-the-art U-Net topologies, such as U-Net++, with the intention of enhancing the accuracy and robustness of AD diagnosis using advanced segmentation techniques.

### 1. Libraries

All the necessary libraries that need to be installed to run the code:

- OpenCV (cv2)
- Pandas
- NumPy
- Matplotlib
- Torchvision
- PyTorch
- TQDM
- PIL (Python Imaging Library)

In [None]:
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from pathlib import Path
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, Compose, Normalize
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from torch.nn import Linear, ReLU, CrossEntropyLoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout
from PIL import Image

**TODO:** Replace "mps" with "cuda" or the appropriate CUDA version installed on your machine as needed.

In [None]:
# Smoke test: allocate a one-element tensor on the MPS device when it is available.
if not torch.backends.mps.is_available():
    print("MPS device not found.")
else:
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)

Expected output:

tensor([1.], device='mps:0')

In [None]:
print(f"PyTorch version: {torch.__version__}")

# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")

# Set the device: prefer MPS, then CUDA, and only use the CPU when neither
# accelerator backend reports itself as available.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Expected output:

PyTorch version: 2.1.0.dev20230413

Is MPS (Metal Performance Shader) built? True

Is MPS available? True

Using device: mps

In [None]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### 2. Prepare Data

In [None]:
def get_sorted_file_paths(label_path, image_path, num):
    '''
    Collect the sorted label paths (left / right hippocampus) and image paths of one dataset.

    Both directories are searched recursively for "*.jpg" files. A label path is
    split on the dataset marker ("35label/" or "100label/"); the part after the
    marker must consist of exactly three "/"-separated components, and the
    second-to-last "_"-separated token of the file name selects the side:
    'L' goes to the left list, 'R' to the right list. Label paths that do not
    have this layout are printed with a "LABEL: " prefix and skipped.
    Any path whose part after the marker contains "ADNI_013_S_0325_85153_ACPC"
    is left out of all three lists.

    Args:
        label_path: pathlib.Path of the directory holding the label images.
        image_path: pathlib.Path of the directory holding the original images.
        num: dataset selector, either 35 or 100.

    Returns:
        Tuple (left_labels, right_labels, images) of sorted lists of path strings.

    Raises:
        ValueError: if num is neither 35 nor 100.
    '''
    if num == 35:
        label_marker, image_marker = "35label/", "35/"
    elif num == 100:
        label_marker, image_marker = "100label/", "100/"
    else:
        raise ValueError("Invalid value for 'num'.")

    excluded_scan = "ADNI_013_S_0325_85153_ACPC"

    left_labels = []
    right_labels = []
    for label_file in (str(path) for path in label_path.glob(r"**/*.jpg")):
        # Only the unpacking can fail here, so only ValueError is caught;
        # anything else (e.g. KeyboardInterrupt) propagates normally.
        try:
            _, target = label_file.split(label_marker)
            _, _, file_name = target.split("/")
        except ValueError:
            print("LABEL: ", label_file)
            continue

        if excluded_scan in target:
            continue

        name_tokens = file_name.split("_")
        if len(name_tokens) < 2:
            # No side token to inspect; report it like any other malformed label.
            print("LABEL: ", label_file)
            continue

        side = name_tokens[-2]
        if side == 'L':
            left_labels.append(label_file)
        elif side == 'R':
            right_labels.append(label_file)
        else:
            print("SOMETHING IS WRONG!")

    images = []
    for image_file in (str(path) for path in image_path.glob(r"**/*.jpg")):
        # partition() never raises, unlike a two-way unpack of split(), so an
        # unusual folder name cannot abort the whole scan.
        _, _, target = image_file.partition(image_marker)
        if excluded_scan in target:
            continue
        images.append(image_file)

    return sorted(left_labels), sorted(right_labels), sorted(images)


**TODO:** To use this code, please change the *folder* variable to the path where you have saved the Alzheimer's dataset.

In [None]:
# folder = "YOUR FOLDER"

In [None]:
# Using Google Drive

# from google.colab import drive
# drive.mount('/content/drive')

# folder = "/content/drive/MyDrive"

In [None]:
# Everything lives under "<folder>/hippocampus": labels in "label/<N>label",
# original slices in "original/<N>".
hippocampus_root = folder + "/hippocampus"

Label100_Path = Path(hippocampus_root + "/label/100label/")
Image100_Path = Path(hippocampus_root + "/original/100")
Label35_Path = Path(hippocampus_root + "/label/35label")
Image35_Path = Path(hippocampus_root + "/original/35")

# Sorted left-label, right-label and image paths for each dataset.
Sort_L_100, Sort_R_100, Sort_IMG_100 = get_sorted_file_paths(Label100_Path, Image100_Path, num=100)
Sort_L_35, Sort_R_35, Sort_IMG_35 = get_sorted_file_paths(Label35_Path, Image35_Path, num=35)

In [None]:
# Count the collected paths first, then report them per dataset.
count_L_100, count_R_100, count_IMG_100 = map(len, (Sort_L_100, Sort_R_100, Sort_IMG_100))
count_L_35, count_R_35, count_IMG_35 = map(len, (Sort_L_35, Sort_R_35, Sort_IMG_35))

print("100label - L:", count_L_100)
print("100label - R:", count_R_100)
print("100label - IMG:", count_IMG_100)
print("---------------------------------")
print("35label - L:", count_L_35)
print("35label - R:", count_R_35)
print("35label - IMG:", count_IMG_35)

In [None]:
def _read_image_or_fail(path, *imread_flags):
    '''Read *path* with cv2.imread and raise OSError when OpenCV returns no image.'''
    pixels = cv2.imread(path, *imread_flags)
    if pixels is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"OpenCV could not read image file: {path}")
    return pixels


def extract_and_prepare_data(sorted_images, sorted_left, sorted_right, threshold=250, min_pixel_count=50):
    '''
    Load the image / mask triples that contain enough hippocampus pixels.

    The three path lists are walked in lockstep (zip stops at the shortest one).
    For every triple the left and right masks are read in grayscale and the
    pixels brighter than *threshold* are counted. When either count exceeds
    *min_pixel_count*, the image and both masks are loaded as RGB, the two masks
    are merged with cv2.addWeighted, and the image and merged mask are kept.

    Args:
        sorted_images: paths of the original images.
        sorted_left: paths of the left-hippocampus masks.
        sorted_right: paths of the right-hippocampus masks.
        threshold: gray level a mask pixel must exceed to be counted.
        min_pixel_count: count that at least one of the two masks must exceed.

    Returns:
        Tuple (images, masks) of NumPy arrays built from the kept entries.

    Raises:
        OSError: if OpenCV cannot read one of the files.
    '''
    X_Image = []
    X_Hippocampus = []

    for img, left, right in zip(sorted_images, sorted_left, sorted_right):
        left_gray = _read_image_or_fail(left, cv2.IMREAD_GRAYSCALE)
        right_gray = _read_image_or_fail(right, cv2.IMREAD_GRAYSCALE)

        left_pixel_count = np.sum(left_gray > threshold)
        right_pixel_count = np.sum(right_gray > threshold)

        # Skip slices where neither side shows enough mask pixels.
        if left_pixel_count <= min_pixel_count and right_pixel_count <= min_pixel_count:
            continue

        image = cv2.cvtColor(_read_image_or_fail(img), cv2.COLOR_BGR2RGB)
        left_colored = cv2.cvtColor(_read_image_or_fail(left, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        right_colored = cv2.cvtColor(_read_image_or_fail(right, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)

        # Merge the left and right masks into a single label image.
        hippocampus_concat = cv2.addWeighted(left_colored, 1, right_colored, 1, 0.2)

        X_Image.append(image)
        X_Hippocampus.append(hippocampus_concat)

    return np.array(X_Image), np.array(X_Hippocampus)

In [None]:
# Load the 100-subject data; only the first 10000 entries of each sorted path list are passed in.
X_Image_100, X_Hippocampus_100 = extract_and_prepare_data(Sort_IMG_100[:10000], Sort_L_100[:10000], Sort_R_100[:10000])

In [None]:
# Shapes of the image array and of the hippocampus-mask array.
print(X_Image_100.shape)
print(X_Hippocampus_100.shape)

In [None]:
# Preview one image next to its merged hippocampus label; the array shape is shown as x-label.
index = 1
sample_image = X_Image_100[index]
sample_label = X_Hippocampus_100[index]

figure, axis = plt.subplots(1, 2, figsize=(10, 10))
image_ax, label_ax = axis

image_ax.imshow(sample_image)
image_ax.set_xlabel(sample_image.shape)
image_ax.set_title("IMAGE")

label_ax.imshow(sample_label)
label_ax.set_xlabel(sample_label.shape)
label_ax.set_title("HIPPOCAMPUS")

In [None]:
def pad_images(images, target_shape=256):
    '''
    Zero-pad every image to a square of target_shape x target_shape pixels.

    The padding is split as evenly as possible between the two sides of each
    spatial axis; when the amount is odd, the extra row / column goes to the
    bottom / right. Only the first two axes (height, width) are padded, so both
    grayscale (H, W) and multi-channel (H, W, C) images are accepted.

    Args:
        images: iterable of NumPy arrays with at least two dimensions.
        target_shape: side length of the padded output in pixels.

    Returns:
        NumPy array stacking the padded images.

    Raises:
        ValueError: if an image is taller or wider than target_shape
            (this function only pads, it never crops).
    '''
    padded_images = []
    for image in images:
        height, width = image.shape[:2]
        if height > target_shape or width > target_shape:
            raise ValueError(
                f"Image of size {height}x{width} does not fit into the "
                f"{target_shape}x{target_shape} target; pad_images cannot crop."
            )

        pad_top = (target_shape - height) // 2
        pad_bottom = target_shape - height - pad_top
        pad_left = (target_shape - width) // 2
        pad_right = target_shape - width - pad_left

        # Pad height and width only; any trailing axes (e.g. channels) stay untouched.
        pad_spec = [(pad_top, pad_bottom), (pad_left, pad_right)]
        pad_spec.extend([(0, 0)] * (image.ndim - 2))

        padded_images.append(np.pad(image, pad_spec, mode='constant'))

    return np.array(padded_images)

In [None]:
# Pad the images and the hippocampus labels with pad_images (default target size: 256).
X_Image_100_Hippocampus = pad_images(X_Image_100)
X_Hippocampus_100_Label = pad_images(X_Hippocampus_100)

# Report the shapes of the padded arrays.
print("Padded 100 - ARRAY IMAGE SHAPE: ", X_Image_100_Hippocampus.shape)
print("Padded 100 - ARRAY HIPPOCAMPUS SHAPE: ", X_Hippocampus_100_Label.shape)

In [None]:
# Preview one padded image next to its padded hippocampus label.
index = 5
padded_image = X_Image_100_Hippocampus[index]
padded_label = X_Hippocampus_100_Label[index]

figure, axis = plt.subplots(1, 2, figsize=(10, 10))
image_ax, label_ax = axis

image_ax.imshow(padded_image)
image_ax.set_xlabel(padded_image.shape)
image_ax.set_title("IMAGE")

label_ax.imshow(padded_label, cmap='gray')
label_ax.set_xlabel(padded_label.shape)
label_ax.set_title("HIPPOCAMPUS")


In [None]:
def convert_to_binary_masks(images, threshold=127):
    '''
    Convert a set of images to binary masks using thresholding.

    Three-channel images are first converted from RGB to grayscale; every other
    input is thresholded as it is. Pixels strictly greater than *threshold*
    become 255 and all remaining pixels become 0. The masks keep the dtype of
    the grayscale image.

    Args:
        images: iterable of NumPy image arrays.
        threshold: gray level above which a pixel counts as foreground.

    Returns:
        NumPy array stacking one single-channel mask per input image.
    '''
    binary_masks = []
    for image in images:
        # Convert the image to grayscale if it's not already
        if len(image.shape) == 3 and image.shape[2] == 3:
            gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray_image = image

        # Actually binarize, honouring `threshold`. The previous call,
        # cv2.threshold(gray, 255, 255, cv2.THRESH_TOZERO_INV), left 8-bit
        # images unchanged because no pixel can exceed 255.
        binary_mask = np.where(gray_image > threshold, 255, 0).astype(gray_image.dtype)
        binary_masks.append(binary_mask)

    return np.array(binary_masks)

In [None]:
# Run convert_to_binary_masks over the padded hippocampus labels of the 100-subject set.
binary_masks = convert_to_binary_masks(X_Hippocampus_100_Label)

In [None]:
# Preview one padded image next to the mask produced by convert_to_binary_masks.
index = 5
padded_image = X_Image_100_Hippocampus[index]
mask_preview = binary_masks[index]

figure, axis = plt.subplots(1, 2, figsize=(10, 10))
image_ax, mask_ax = axis

image_ax.imshow(padded_image)
image_ax.set_xlabel(padded_image.shape)
image_ax.set_title("IMAGE")

mask_ax.imshow(mask_preview, cmap='gray')
mask_ax.set_xlabel(mask_preview.shape)
mask_ax.set_title("HIPPOCAMPUS")

In [None]:
# Same preparation pipeline for the 35-subject set (used further down for the test loader):
# load (first 100 entries of each path list only), pad, then convert the labels to masks.
X_Image_35, X_Hippocampus_35 = extract_and_prepare_data(Sort_IMG_35[:100], Sort_L_35[:100], Sort_R_35[:100])
X_Image_35_Hippocampus = pad_images(X_Image_35)
X_Hippocampus_35_Label = pad_images(X_Hippocampus_35)
X_Hippocampus_35_Binary = convert_to_binary_masks(X_Hippocampus_35_Label)

In [None]:
# Preview one padded image of the 35-subject set next to its mask.
index = 5
padded_image_35 = X_Image_35_Hippocampus[index]
mask_preview_35 = X_Hippocampus_35_Binary[index]

figure, axis = plt.subplots(1, 2, figsize=(10, 10))
image_ax, mask_ax = axis

image_ax.imshow(padded_image_35)
image_ax.set_xlabel(padded_image_35.shape)
image_ax.set_title("IMAGE")

mask_ax.imshow(mask_preview_35, cmap='gray')
mask_ax.set_xlabel(mask_preview_35.shape)
mask_ax.set_title("HIPPOCAMPUS")

### 3. Set DataLoader

In [None]:
class CustomDataset(Dataset):
    '''
    Dataset that pairs each image with the mask stored at the same index.

    Both the image and the mask are passed through the same transform before
    they are returned.

    Args:
        images: indexable collection of images.
        masks: indexable collection of masks, aligned with *images*.
        transform: callable applied to the image and to the mask.
            Defaults to torchvision's ToTensor() when omitted.
    '''

    def __init__(self, images, masks, transform=None):
        self.images = images
        self.masks = masks
        # Honour a caller-supplied transform; previously the argument was
        # accepted but always replaced by ToTensor().
        self.transform = ToTensor() if transform is None else transform

    def __len__(self):
        '''Return the number of images in the dataset.'''
        return len(self.images)

    def __getitem__(self, index):
        '''Return the transformed (image, mask) pair stored at *index*.'''
        image = self.images[index]
        mask = self.masks[index]

        image_tensor = self.transform(image)
        mask_tensor = self.transform(mask)

        return image_tensor, mask_tensor

In [None]:
def split_dataset(dataset, batch_size, split_ratio=0.8):
    '''
    Randomly split *dataset* into a training and a validation part and wrap each in a DataLoader.

    The training part receives int(len(dataset) * split_ratio) samples and the
    validation part receives the rest. Only the training loader shuffles.

    Returns:
        Tuple (train_loader, val_loader).
    '''
    total = len(dataset)
    n_train = int(total * split_ratio)
    n_val = total - n_train

    train_part, val_part = random_split(dataset, [n_train, n_val])

    return (
        DataLoader(train_part, batch_size=batch_size, shuffle=True),
        DataLoader(val_part, batch_size=batch_size),
    )

def small_dataset(batch_size):
    '''
    Build train / validation loaders from a random subset of the 100-subject data.

    A random permutation of all sample indices is drawn and its first
    `train_number` entries (a notebook-level variable) select the images and
    masks. The selection is wrapped in a CustomDataset and handed to split_dataset.
    '''
    chosen = torch.randperm(len(X_Image_100_Hippocampus))[:train_number].tolist()

    subset = CustomDataset(
        [X_Image_100_Hippocampus[i] for i in chosen],
        [binary_masks[i] for i in chosen],
    )
    return split_dataset(subset, batch_size=batch_size)

def full_dataset(batch_size):
    '''
    Build train / validation loaders from every image and mask of the 100-subject data.

    All samples are wrapped in one CustomDataset, which split_dataset then divides.
    '''
    everything = CustomDataset(X_Image_100_Hippocampus, binary_masks)
    return split_dataset(everything, batch_size=batch_size)

**TODO:** Before running this code, you need to define the values of *train_number* and *batch_size*: *train_number* is read inside the *small_dataset* function, and *batch_size* is passed to it. Once you have set these values, call the *small_dataset* function with the desired *batch_size*; it returns *train_loader* and *val_loader*, which are used for training and validation.
To normalize the images, use either 'Min-max normalization' or 'Standardization'.

In [None]:
# train_number = 300 # choose your own value
# batch_size = 10
# train_loader, val_loader = small_dataset(batch_size)
# print(len(train_loader))

In [None]:
# batch_size = 64
# train_loader, val_loader = full_dataset(batch_size)

In [None]:
# Pull a single batch from the training loader and show the image-batch and mask-batch shapes.
xb, yb = next(iter(train_loader))
xb.shape, yb.shape

Expected output:

(torch.Size([10, 3, 256, 256]), torch.Size([10, 1, 256, 256]))

Creating two index lists for train and validation data splitting. The total number of images used for training is specified by the variable train_number. The code first creates a list of indices with the length of train_number. It then splits the list into two parts using a 0.2 validation split ratio.

In [None]:
# Index positions 0 .. train_number - 1.
indices = list(range(train_number))
# Number of validation indices: 20% of train_number, rounded down.
split = int(np.floor(0.2 * train_number))
# The first `split` indices form the validation part, the remaining ones the training part.
train_idx, valid_idx = indices[split:], indices[:split]

In [None]:
# Test Loader: built from the padded 35-subject images and their masks.
test_dataset = CustomDataset(X_Image_35_Hippocampus, X_Hippocampus_35_Binary)
# shuffle=False keeps the samples in dataset order.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### 4. Models

U-Net/U-Net++ models from the `segmentation_models_pytorch` library

In [None]:
import segmentation_models_pytorch as smp

# Both networks are built with exactly the same encoder and channel settings.
shared_model_config = dict(
    encoder_name="resnet34",     # choose encoder, e.g. resnet34, resnet50, etc.
    encoder_weights="imagenet",  # use pre-trained weights for encoder initialization
    in_channels=3,               # number of input channels
    classes=1,                   # number of output channels
)

# Create a U-Net model
unet_model = smp.Unet(**shared_model_config)

# Create a U-Net++ model
unet_plusplus_model = smp.UnetPlusPlus(**shared_model_config)

In [None]:
import torch.optim as optim
from segmentation_models_pytorch.utils import metrics

model = unet_plusplus_model.to(device)

# The segmentation_models_pytorch model is created without an `activation`, so
# it returns raw logits. nn.BCELoss only accepts probabilities in [0, 1];
# BCEWithLogitsLoss applies the sigmoid internally and is numerically more stable.
# NOTE: a model whose forward() already ends in a sigmoid needs nn.BCELoss instead.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Instantiate the Dice coefficient metric. The logits are passed through a
# sigmoid first so that the metric thresholds probabilities, not raw scores.
dice_metric = metrics.Fscore(activation="sigmoid")

In [None]:
# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    print("epoch:", epoch)
    model.train()

    # Per-batch values of the current epoch; averaged once the epoch is done.
    batch_losses = []
    batch_dice_scores = []

    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # One optimisation step
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Record loss and dice score of this batch
        batch_losses.append(loss.item())
        batch_dice_scores.append(dice_metric(outputs, targets).item())

    # Calculate average loss and dice score for the epoch
    num_batches = len(batch_losses)
    epoch_loss = sum(batch_losses) / num_batches
    epoch_dice_score = sum(batch_dice_scores) / num_batches

    print(f"Epoch {epoch + 1}/{num_epochs}: Loss = {epoch_loss:.4f}, Dice Score = {epoch_dice_score:.4f}")

    # Save model checkpoint (optional)
    # torch.save(model.state_dict(), f"unet_checkpoint_epoch_{epoch + 1}.pth")

    # Validation loop (optional)
    # model.eval()
    # with torch.no_grad():
    #     ... # Perform validation using the same approach as the training loop
    #     # Print validation statistics


U-Net/U-Net++ model from scratch

In [None]:
class ConvBlock(nn.Module):
    '''
    Two stacked stages of 3x3 convolution (no bias) -> batch norm -> ReLU.

    The first stage maps in_channels to out_channels, the second keeps
    out_channels. Stride 1 with padding 1 keeps the spatial size unchanged.
    '''

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        for stage_inputs in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_inputs, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        '''Run *x* through both convolution stages.'''
        return self.conv(x)


class UpConv(nn.Module):
    '''
    Upsampling step: double the spatial size, then 3x3 convolution (no bias) -> batch norm -> ReLU.

    The convolution maps in_channels to out_channels.
    '''

    def __init__(self, in_channels, out_channels):
        super().__init__()

        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        relu = nn.ReLU(inplace=True)

        self.up = nn.Sequential(upsample, conv, norm, relu)

    def forward(self, x):
        '''Upsample *x* and project it to out_channels.'''
        return self.up(x)


class UNet(nn.Module):
    '''
    U-Net with five encoder levels, four decoder levels and a sigmoid output.

    The encoder widths are base_channels * (1, 2, 4, 8, 16). Each decoder level
    upsamples, concatenates the encoder output of the same level along the
    channel axis, and applies a ConvBlock. A 1x1 convolution followed by a
    sigmoid produces the final map with out_channels channels.
    '''

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # TODO: choose your own 'base_channels' value (e.g. 32 or 64)
        base_channels = 32
        filters = [base_channels * factor for factor in (1, 2, 4, 8, 16)]

        # Encoder part: every level takes the previous level's width as input.
        encoder_inputs = [in_channels] + filters[:-1]
        self.encoders = nn.ModuleList(
            [ConvBlock(width_in, width_out) for width_in, width_out in zip(encoder_inputs, filters)]
        )

        self.pools = nn.ModuleList([nn.MaxPool2d(kernel_size=2, stride=2) for _ in range(4)])

        # Decoder part: (wide, narrow) pairs from the deepest level upwards.
        # After the skip concatenation the ConvBlock again receives `wide` channels.
        decoder_widths = list(zip(filters[:0:-1], filters[-2::-1]))
        self.upconvs = nn.ModuleList([UpConv(wide, narrow) for wide, narrow in decoder_widths])
        self.decoders = nn.ModuleList([ConvBlock(wide, narrow) for wide, narrow in decoder_widths])

        self.output_conv = nn.Conv2d(filters[0], out_channels, kernel_size=1, stride=1, padding=0)

        # Use dropout
        # self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        '''Return the sigmoid-activated segmentation map for *x*.'''
        # Encoder part: remember every level's output, pool between levels.
        skips = []
        for level, encoder in enumerate(self.encoders):
            x = encoder(x)
            skips.append(x)
            if level < len(self.pools):
                x = self.pools[level](x)

        # Decoder part: walk the skip connections from the second-deepest level upwards.
        for upconv, decoder, skip in zip(self.upconvs, self.decoders, reversed(skips[:-1])):
            x = upconv(x)
            x = torch.cat((skip, x), dim=1)
            x = decoder(x)

        # Use a Sigmoid activation function for final output layer
        return torch.sigmoid(self.output_conv(x))

In [None]:
'''
Architecture

UNet(
  (encoders): ModuleList(
    (0): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (2): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (3): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (4): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (pools): ModuleList(
    (0-3): 4 x MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (upconvs): ModuleList(
    (0): UpConv(
      (up): Sequential(
        (0): Upsample(scale_factor=2.0, mode='nearest')
        (1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
      )
    )
    (1): UpConv(
      (up): Sequential(
        (0): Upsample(scale_factor=2.0, mode='nearest')
        (1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
      )
    )
    (2): UpConv(
      (up): Sequential(
        (0): Upsample(scale_factor=2.0, mode='nearest')
        (1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
      )
    )
    (3): UpConv(
      (up): Sequential(
        (0): Upsample(scale_factor=2.0, mode='nearest')
        (1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
      )
    )
  )
  (decoders): ModuleList(
    (0): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (2): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (3): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (output_conv): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
)
'''

In [None]:
class NestedConvBlock(nn.Module):
    '''
    Residual double-convolution block.

    Main path: 3x3 conv -> batch norm -> ReLU -> dropout -> 3x3 conv -> batch norm.
    Shortcut: a 1x1 convolution (no bias) that maps the input to out_channels.
    Both paths are added and passed through a final ReLU.
    '''

    def __init__(self, in_channels, mid_channels, out_channels, dropout_rate=0.5):
        super().__init__()

        self.activation = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 projection so the shortcut has the same channel count as the main path.
        self.match_channels = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        '''Return ReLU(main_path(x) + shortcut(x)).'''
        shortcut = self.match_channels(x)

        features = self.activation(self.bn1(self.conv1(x)))
        features = self.dropout(features)
        features = self.bn2(self.conv2(features))

        return self.activation(features + shortcut)


class UNetPlusPlus(nn.Module):
    '''
    U-Net++ (nested U-Net) built from NestedConvBlock nodes.

    Node `conv{i}_{j}` sits at depth `i` (number of 2x2 max-poolings applied) and
    column `j` (number of dense skip connections feeding it). Each node with j > 0
    receives all earlier nodes of the same depth plus the bilinearly upsampled node
    from one level deeper, concatenated along the channel axis.

    Parameters
    ----------
    in_channels : number of channels of the input image.
    out_channels : number of channels of each output map.
    deep_supervision : if True, `forward` returns a list with one map per top-row node
        (x0_1, x0_2, x0_3, x0_4); otherwise a single map computed from x0_4.
    dropout_rate : dropout probability used inside every NestedConvBlock.

    Notes
    -----
    * Every returned map has already been passed through a sigmoid.
    * The input is pooled four times and upsampled by 2 at each step back up, so height
      and width should be divisible by 16 for the concatenations to line up.
    '''

    def __init__(self, in_channels=3, out_channels=1, deep_supervision=True, dropout_rate=0.5):
        super(UNetPlusPlus, self).__init__()

        # TODO: choose your own 'base_channels' value (e.g. 32 or 64)
        base_channels = 32
        filters = [base_channels * i for i in [1, 2, 4, 8, 16]]

        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Column 0: the encoder backbone.
        self.conv0_0 = NestedConvBlock(in_channels, filters[0], filters[0], dropout_rate)
        self.conv1_0 = NestedConvBlock(filters[0], filters[1], filters[1], dropout_rate)
        self.conv2_0 = NestedConvBlock(filters[1], filters[2], filters[2], dropout_rate)
        self.conv3_0 = NestedConvBlock(filters[2], filters[3], filters[3], dropout_rate)
        self.conv4_0 = NestedConvBlock(filters[3], filters[4], filters[4], dropout_rate)

        # Column j > 0: input channels = j same-depth maps + one upsampled deeper map.
        self.conv0_1 = NestedConvBlock(filters[0] + filters[1], filters[0], filters[0], dropout_rate)
        self.conv1_1 = NestedConvBlock(filters[1] + filters[2], filters[1], filters[1], dropout_rate)
        self.conv2_1 = NestedConvBlock(filters[2] + filters[3], filters[2], filters[2], dropout_rate)
        self.conv3_1 = NestedConvBlock(filters[3] + filters[4], filters[3], filters[3], dropout_rate)

        self.conv0_2 = NestedConvBlock(filters[0]*2 + filters[1], filters[0], filters[0], dropout_rate)
        self.conv1_2 = NestedConvBlock(filters[1]*2 + filters[2], filters[1], filters[1], dropout_rate)
        self.conv2_2 = NestedConvBlock(filters[2]*2 + filters[3], filters[2], filters[2], dropout_rate)

        self.conv0_3 = NestedConvBlock(filters[0]*3 + filters[1], filters[0], filters[0], dropout_rate)
        self.conv1_3 = NestedConvBlock(filters[1]*3 + filters[2], filters[1], filters[1], dropout_rate)

        self.conv0_4 = NestedConvBlock(filters[0]*4 + filters[1], filters[0], filters[0], dropout_rate)

        # `final` is always created (it was created unconditionally before as well), so
        # state_dict keys are the same in both modes; it is only used without deep supervision.
        self.final = nn.Conv2d(filters[0], out_channels, kernel_size=1)

        if self.deep_supervision:
            self.final1 = nn.Conv2d(filters[0], out_channels, kernel_size=1)
            self.final2 = nn.Conv2d(filters[0], out_channels, kernel_size=1)
            self.final3 = nn.Conv2d(filters[0], out_channels, kernel_size=1)
            self.final4 = nn.Conv2d(filters[0], out_channels, kernel_size=1)

    def forward(self, x):
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))

        if self.deep_supervision:
            heads = (self.final1, self.final2, self.final3, self.final4)
            features = (x0_1, x0_2, x0_3, x0_4)
            return [torch.sigmoid(head(feature)) for head, feature in zip(heads, features)]

        # Fixed: the sigmoid was previously applied to an undefined name (`out`),
        # which raised NameError whenever deep_supervision=False.
        return torch.sigmoid(self.final(x0_4))

In [None]:
'''
Architecture

UNetPlusPlus(
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Up): Upsample(scale_factor=2.0, mode='bilinear')
  (conv0_0): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv1_0): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv2_0): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv3_0): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv4_0): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv0_1): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv1_1): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv2_1): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv3_1): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv0_2): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv1_2): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv2_2): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(512, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv0_3): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(160, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(160, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv1_3): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(320, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (conv0_4): NestedConvBlock(
    (activation): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv2d(192, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (match_channels): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (final): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  (final1): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  (final2): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  (final3): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  (final4): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
)
'''

**TODO**: Choose model

U-Net

In [None]:
# model_test = UNet().to(device)
# model_test

U-Net++

In [None]:
# model_test = UNetPlusPlus(in_channels=3, out_channels=1, deep_supervision=True, dropout_rate=0.5).to(device)
# model_test

### 5. Loss functions

In [None]:
def dice_loss(prediction, target):
    '''
    Soft Dice loss between a prediction and a target of the same number of elements.

    Both tensors are flattened and the loss is
        1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth),   smooth = 1.0.
    For inputs in [0, 1] the result lies between 0 and 1; 0 means perfect overlap
    (lower is better) and the value grows towards 1 as the overlap disappears.
    '''
    smooth = 1.0
    # reshape (rather than view) also accepts non-contiguous tensors, e.g. transposed
    # or sliced ones, on which view(-1) raises a RuntimeError.
    p_flat = prediction.reshape(-1)
    t_flat = target.reshape(-1)
    intersection = (p_flat * t_flat).sum()
    return 1 - ((2. * intersection + smooth) / (p_flat.sum() + t_flat.sum() + smooth))

def calc_loss(prediction, target, bce_weight=0.5):
    '''
    Weighted sum of binary cross entropy (BCE) and Dice loss:
        bce_weight * BCE + (1 - bce_weight) * Dice.

    `prediction` is treated as raw logits: BCE is computed with the `*_with_logits`
    variant and a sigmoid is applied before the Dice term.
    NOTE(review): UNetPlusPlus.forward in this notebook already applies a sigmoid, so
    passing its outputs here would squash them twice - confirm which model this is used with.
    '''
    bce_term = F.binary_cross_entropy_with_logits(prediction, target)
    dice_term = dice_loss(F.sigmoid(prediction), target)
    return bce_term * bce_weight + dice_term * (1 - bce_weight)

def combined_loss(y_pred, y_true, alpha=0.5, beta=0.5):
    '''
    Combined loss: alpha * BCE + beta * (1 - Dice coefficient).

    `y_pred` is treated as raw logits: BCE uses the `*_with_logits` variant and a sigmoid
    is applied before the Dice term. `alpha` and `beta` weight the BCE term and the
    Dice term respectively.
    '''
    bce_loss = F.binary_cross_entropy_with_logits(y_pred, y_true)
    # dice_loss already returns 1 - Dice coefficient, i.e. the quantity to minimise.
    # The previous code added 1 - dice_loss (the Dice coefficient itself), which
    # rewarded *less* overlap between prediction and target.
    dice_term = dice_loss(torch.sigmoid(y_pred), y_true)
    return alpha * bce_loss + beta * dice_term

# Plain binary cross entropy. nn.BCELoss applies no sigmoid itself, so it expects
# predictions that are already probabilities in [0, 1].
criterion = nn.BCELoss()

### 6. Metrics

In [None]:
def iou_score(output, target):
    '''
    Intersection over Union (IoU) of two masks, each binarised at 0.5.

    Tensor arguments are moved to the CPU and converted to NumPy first; a tensor `output`
    is additionally passed through a sigmoid, whereas a NumPy `output` is thresholded as given.
    A smoothing term of 1e-5 is added to numerator and denominator, so two empty masks
    score 1.
    '''
    smooth = 1e-5

    if torch.is_tensor(output):
        output = torch.sigmoid(output).detach().cpu().numpy()
    if torch.is_tensor(target):
        target = target.detach().cpu().numpy()

    predicted_mask = output > 0.5
    reference_mask = target > 0.5

    overlap = (predicted_mask & reference_mask).sum()
    combined = (predicted_mask | reference_mask).sum()
    return (overlap + smooth) / (combined + smooth)


def dice_coeff(im1, im2, empty_score=1.0):
    '''
    Dice coefficient of two NumPy arrays after binarising each at 0.5:
        2 * |A and B| / (|A| + |B|).

    Returns `empty_score` when both binarised masks are empty.
    Raises ValueError when the two arrays have different shapes.
    '''
    mask_a = (im1 > 0.5).astype(bool)
    mask_b = (im2 > 0.5).astype(bool)

    if mask_a.shape != mask_b.shape:
        raise ValueError("Error!")

    total = mask_a.sum() + mask_b.sum()
    if total == 0:
        # Nothing is foreground in either mask; the ratio would be 0 / 0.
        return empty_score

    overlap = np.logical_and(mask_a, mask_b).sum()
    return 2.0 * overlap / total


def numeric_scores(output, target):
    '''
    Confusion-matrix counts for two NumPy arrays, each binarised at 0.5.

    Returns the tuple (FP, FN, TP, TN): false positives, false negatives,
    true positives and true negatives, in that order.
    '''
    predicted = (output > 0.5).astype(bool)
    actual = (target > 0.5).astype(bool)

    TP = np.sum(predicted & actual)
    TN = np.sum(~predicted & ~actual)
    FP = np.sum(predicted & ~actual)
    FN = np.sum(~predicted & actual)

    return FP, FN, TP, TN


def accuracy_score(output, target):
    '''
    Pixel accuracy, in percent, between a prediction and the ground truth.

    Both arrays are binarised by `numeric_scores`; the result is
    100 * (TP + TN) / (FP + FN + TP + TN).
    '''
    FP, FN, TP, TN = numeric_scores(output, target)
    correct = TP + TN
    total = FP + FN + TP + TN
    return (correct / total) * 100.0

### 7. Hyper parameters

**TODO**: Try different optimizers and schedulers

In [None]:
# optimizer = torch.optim.Adam(model_test.parameters(), lr=1e-4)
# optimizer = torch.optim.SGD(model_test.parameters(), lr=0.0001, momentum=0.99)
# optimizer = torch.optim.Adam(model_test.parameters(), lr=3e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

In [None]:
# scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
# scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-5)
# scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, int(1e10), eta_min=1e-5)
# scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-4, max_lr=1e-2, step_size_up=2000, mode='triangular2')

In [None]:
def initialize_weights(model):
    '''
    Re-initialise the parameters of `model` in place.

    * nn.Conv2d / nn.Linear: Kaiming-normal weights (mode='fan_out', ReLU gain) and,
      when the layer has a bias, a zero bias.
    * nn.BatchNorm2d: weight set to 1, bias set to 0.
    Every other module type is left untouched.
    '''
    for layer in model.modules():
        if isinstance(layer, nn.BatchNorm2d):
            nn.init.constant_(layer.weight, 1)
            nn.init.constant_(layer.bias, 0)
            continue

        if not isinstance(layer, (nn.Conv2d, nn.Linear)):
            continue

        nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
        if layer.bias is not None:
            nn.init.constant_(layer.bias, 0)

'''
TODO: choose 'initialize_weights' or not
'''
# initialize_weights(model_test)

### 8. Training

In [None]:
# Per-epoch mean losses; the training loop appends one value per epoch to each list
# and the plotting cell reads them.
train_losses = []
valid_losses = []

In [None]:
epoch = 100
loss_train = []
loss_val = []
# Multi-output handling follows the model's own flag: UNetPlusPlus stores
# `deep_supervision` and returns a list of maps when it is True. Models without the
# attribute are assumed to return a single tensor.
deep_supervision = getattr(model_test, 'deep_supervision', False)

'''
TODO: choose loss function
'''
# criterion (nn.BCELoss) and dice_loss take probabilities; calc_loss and combined_loss
# apply a logits-based BCE and a sigmoid themselves, so they take raw logits.
loss_fn = criterion
# loss_fn = dice_loss
# loss_fn = calc_loss
# loss_fn = combined_loss


def batch_loss(prediction, ground_truth):
    '''
    Loss of one batch under `loss_fn`.

    With deep supervision `prediction` is a list of maps: the per-head losses are
    summed and divided by the number of heads. Otherwise `prediction` is a single tensor.
    '''
    if deep_supervision:
        return sum(loss_fn(output, ground_truth) for output in prediction) / len(prediction)
    return loss_fn(prediction, ground_truth)


for i in range(epoch):
    print(f"Epoch {i+1}/{epoch}")

    train_loss = 0.0
    valid_loss = 0.0

    model_test.train()

    for batch_idx, (train, ground_truth) in enumerate(train_loader):
        train, ground_truth = train.to(device), ground_truth.to(device)

        optimizer.zero_grad()

        train_predect = model_test(train)
        loss_train_batch = batch_loss(train_predect, ground_truth)

        # Weight by batch size so the epoch value is a per-sample mean.
        train_loss += loss_train_batch.item() * train.size(0)
        loss_train_batch.backward()
        optimizer.step()

        print(f"Batch {batch_idx + 1} loss: {loss_train_batch.item():.4f}")

    # The scheduler is stepped once per epoch.
    scheduler.step()

    model_test.eval()
    with torch.no_grad():

        for val, ground_truth in val_loader:
            val, ground_truth = val.to(device), ground_truth.to(device)

            val_predect = model_test(val)
            loss_val_batch = batch_loss(val_predect, ground_truth)

            valid_loss += loss_val_batch.item() * val.size(0)

    train_loss = train_loss / len(train_idx)
    valid_loss = valid_loss / len(valid_idx)

    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    print('Epoch: {}/{} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(i + 1, epoch, train_loss, valid_loss))

In [None]:
# Plot the per-epoch training and validation losses collected in
# train_losses / valid_losses on a single figure.
plt.figure(figsize=(10, 5))
'''
TODO: choose image title
'''
# plt.title('U-Net')
# plt.title('U-Net++')
plt.plot(train_losses, label='Training Loss')
plt.plot(valid_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

Print the output tensor w/ threshold

In [None]:
model_test.eval()

# Take one batch from the training loader; only the images are used in this cell.
image, ground_truth = next(iter(train_loader))
image = image.to(device)

with torch.no_grad():
    outputs = model_test(image)

# A model with deep supervision returns a list of maps; keep the last one, which is
# also the one postprocess_output / evaluate_model use for U-Net++.
if isinstance(outputs, (list, tuple)):
    outputs = outputs[-1]

outputs = outputs.cpu()

# Raw model output for the first image of the batch.
print(outputs[0])

In [None]:
# Binarise the model output: values strictly above the threshold become 1.0, all others 0.0.
threshold = 0.5
binary_mask = (outputs > threshold).float()
# Mask of the first image in the batch.
print(binary_mask[0])

### 9. Evaluate

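**TODO**: choose the `evaluate_model` cell that matches your model (U-Net or U-Net++). Both cells define a function with the same name, so the one run last is the one that is used.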

In [None]:
# Evaluate U-Net
def evaluate_model(model, test_loader):
    '''
    Evaluate a single-output model on `test_loader`.

    The model is put in eval mode; each batch is run without gradients, the output is
    passed through a sigmoid, and IoU, Dice and accuracy are computed per image.
    Returns the three per-image means as (iou_avg, dice_avg, accuracy_avg).

    NOTE(review): the sigmoid here assumes the model returns raw logits - confirm
    for the U-Net in use.
    '''
    model.eval()

    iou_scores, dice_scores, accuracy_scores = [], [], []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            probabilities = torch.sigmoid(model(inputs))

            for i in range(probabilities.shape[0]):
                prediction = probabilities[i].cpu().numpy()
                reference = targets[i].cpu().numpy()

                iou_scores.append(iou_score(prediction, reference))
                dice_scores.append(dice_coeff(prediction, reference))
                accuracy_scores.append(accuracy_score(prediction, reference))

    return np.mean(iou_scores), np.mean(dice_scores), np.mean(accuracy_scores)

In [None]:
# Evaluate U-Net++
def evaluate_model(model, test_loader):
    '''
    Evaluate a U-Net++ model on `test_loader`.

    The model is put in eval mode and each batch is run without gradients. When the
    model returns a list (deep supervision), the last map is scored; a single tensor
    is scored as is. IoU, Dice and accuracy are computed per image and their means are
    returned as (iou_avg, dice_avg, accuracy_avg).
    '''
    model.eval()

    iou_scores = []
    dice_scores = []
    accuracy_scores = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            output = outputs[-1] if isinstance(outputs, (list, tuple)) else outputs
            # No sigmoid here: UNetPlusPlus.forward already returns sigmoid probabilities.
            # Squashing them a second time maps every value into (0.5, 0.74), so the 0.5
            # threshold used by the metrics would mark every pixel as foreground.

            for i in range(output.shape[0]):
                output_np = output[i].cpu().numpy()
                target_np = targets[i].cpu().numpy()

                iou_scores.append(iou_score(output_np, target_np))
                dice_scores.append(dice_coeff(output_np, target_np))
                accuracy_scores.append(accuracy_score(output_np, target_np))

    iou_avg = np.mean(iou_scores)
    dice_avg = np.mean(dice_scores)
    accuracy_avg = np.mean(accuracy_scores)

    return iou_avg, dice_avg, accuracy_avg

In [None]:
# Score the trained model on the test loader with whichever evaluate_model was defined last.
iou_avg, dice_avg, accuracy_avg = evaluate_model(model_test, test_loader)

# IoU and Dice are fractions; accuracy_score already returns a percentage.
print(f"IoU Score: {iou_avg:.4f}")
print(f"Dice Coefficient: {dice_avg:.4f}")
print(f"Accuracy: {accuracy_avg:.2f}%")

### 10. Helper Functions

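**TODO**: choose the pre-/post-processing cells that match your model: the first set for a single-output model, or the "Deep supervision" set further below for a model that returns a list of outputs.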

In [None]:
def preprocess_image(image_np, input_shape):
    '''
    Convert a NumPy image into a normalised tensor with a leading batch dimension.

    The array is wrapped in a PIL image, resized to `input_shape[1:]`, converted with
    ToTensor, normalised with a fixed three-channel mean/std (the values commonly used
    with ImageNet-pretrained models), and unsqueezed at dimension 0.
    '''
    target_size = input_shape[1:]
    pipeline = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    pil_image = Image.fromarray(image_np)
    tensor = pipeline(pil_image)
    return tensor.unsqueeze(0)

def postprocess_output(output_tensor):
    '''
    Turn a model output tensor into a displayable uint8 mask.

    Size-1 dimensions are squeezed away, the tensor is moved to the CPU and converted to
    NumPy, and values above 0.5 become 255 while all others become 0.
    '''
    values = output_tensor.squeeze().cpu().detach().numpy()
    foreground = values > 0.5
    return foreground.astype(np.uint8) * 255

In [None]:
# Index of the sample to visualise.
idx = 5

input_shape = (3, 256, 256)
preprocessed_image = preprocess_image(X_Image_100_Hippocampus[idx], input_shape).to(device)

# Fixed: this cell used `model`, while every other cell in the notebook trains,
# evaluates and visualises `model_test`.
with torch.no_grad():
    model_test.eval()
    output = model_test(preprocessed_image)

predicted_segmentation = postprocess_output(output)

# Three panels side by side: input image, predicted mask, ground-truth mask.
figure, axis = plt.subplots(1, 3, figsize=(10, 10))

axis[0].imshow(X_Image_100_Hippocampus[idx], cmap='gray')
axis[0].set_title('Original Test Image')

axis[1].imshow(predicted_segmentation, cmap='gray')
axis[1].set_title('Predicted Segmentation')

axis[2].imshow(X_Hippocampus_100_Label[idx], cmap='gray')
axis[2].set_title('Ground Truth Mask')

# Hide the tick marks on every panel.
for ax in axis:
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

In [None]:
# Binarise the single-output prediction: values strictly above the threshold become 1.0.
threshold = 0.5
binary_mask = (output > threshold).float()
print(binary_mask)

Deep supervision

In [None]:
def preprocess_image(image_np, input_shape):
    '''
    Convert a NumPy image into a normalised tensor with a leading batch dimension.

    The array is wrapped in a PIL image, resized to `input_shape[1:]`, converted with
    ToTensor, normalised with a fixed three-channel mean/std (the values commonly used
    with ImageNet-pretrained models), and unsqueezed at dimension 0.
    '''
    target_size = input_shape[1:]
    pipeline = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    pil_image = Image.fromarray(image_np)
    tensor = pipeline(pil_image)
    return tensor.unsqueeze(0)


def postprocess_output(output_list):
    '''
    Turn the outputs of a deep-supervision model into a displayable uint8 mask.

    Only the last element of `output_list` is used. Its size-1 dimensions are squeezed
    away, it is moved to the CPU and converted to NumPy, and values above 0.5 become 255
    while all others become 0.
    '''
    last_output = output_list[-1]
    values = last_output.squeeze().cpu().detach().numpy()
    foreground = values > 0.5
    return foreground.astype(np.uint8) * 255


In [None]:
# Index of the sample to visualise.
idx = 5

input_shape = (3, 256, 256)
preprocessed_image = preprocess_image(X_Image_100_Hippocampus[idx], input_shape).to(device)

# With deep supervision the model returns a list of maps; postprocess_output
# (list version) picks the last one.
with torch.no_grad():
    model_test.eval()
    output = model_test(preprocessed_image)

predicted_segmentation = postprocess_output(output)

# Three panels side by side: input image, predicted mask, ground-truth mask.
figure, axis = plt.subplots(1, 3, figsize=(10, 10))

axis[0].imshow(X_Image_100_Hippocampus[idx], cmap='gray')
axis[0].set_title('Original Test Image')

axis[1].imshow(predicted_segmentation, cmap='gray')
axis[1].set_title('Predicted Segmentation')

axis[2].imshow(X_Hippocampus_100_Label[idx], cmap='gray')
axis[2].set_title('Ground Truth Mask')

# Hide the tick marks on every panel.
for ax in axis:
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

In [None]:
# Binarise the last deep-supervision output: values strictly above the threshold become 1.0.
threshold = 0.5
binary_mask = (output[-1] > threshold).float()
print(binary_mask)

Additional: U-Net with a ResNeXt50 backbone.

In [None]:
# Model
from torchvision.models import resnext50_32x4d

class ConvRelu(nn.Module):
    '''A Conv2d with the given kernel size and padding, followed by an in-place ReLU.'''

    def __init__(self, in_channels, out_channels, kernel, padding):
        super().__init__()

        convolution = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
        self.convrelu = nn.Sequential(convolution, nn.ReLU(inplace=True))

    def forward(self, x):
        return self.convrelu(x)

class DecoderBlock(nn.Module):
    '''
    Decoder stage that doubles the spatial resolution.

    1x1 ConvRelu reducing the channels to in_channels // 4, a transposed convolution
    (kernel 4, stride 2, padding 1) that doubles height and width, then a 1x1 ConvRelu
    producing `out_channels`.
    '''

    def __init__(self, in_channels, out_channels):
        super().__init__()

        bottleneck = in_channels // 4

        self.conv1 = ConvRelu(in_channels, bottleneck, 1, 0)
        self.deconv = nn.ConvTranspose2d(bottleneck, bottleneck, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.conv2 = ConvRelu(bottleneck, out_channels, 1, 0)

    def forward(self, x):
        return self.conv2(self.deconv(self.conv1(x)))
    
class ResNeXtUNet(nn.Module):
    '''
    U-Net-style segmentation network with an ImageNet-pretrained ResNeXt50 (32x4d) encoder.

    Encoder: children 0-2 of the backbone form `encoder0`; children 4-7 form
    `encoder1`..`encoder4`. Child 3 is skipped.
    NOTE(review): in torchvision's ResNet layout children 0-2 are conv1/bn1/relu and
    child 3 is the max-pool - confirm against the installed torchvision.

    Decoder: four DecoderBlocks, each doubling the resolution; the outputs of the first
    three are added to the encoder feature map of the same stage (additive skips).

    `forward` returns a sigmoid map with `n_classes` channels.
    '''

    def __init__(self, n_classes):
        super().__init__()

        # `pretrained=True` is deprecated since torchvision 0.13; "IMAGENET1K_V1" selects
        # the same weights that `pretrained=True` used to load.
        self.base_model = resnext50_32x4d(weights="IMAGENET1K_V1")
        self.base_layers = list(self.base_model.children())
        # Channel counts produced by encoder1..encoder4.
        filters = [4*64, 4*128, 4*256, 4*512]

        # Down
        self.encoder0 = nn.Sequential(*self.base_layers[:3])
        self.encoder1 = nn.Sequential(*self.base_layers[4])
        self.encoder2 = nn.Sequential(*self.base_layers[5])
        self.encoder3 = nn.Sequential(*self.base_layers[6])
        self.encoder4 = nn.Sequential(*self.base_layers[7])

        # Up
        self.decoder4 = DecoderBlock(filters[3], filters[2])
        self.decoder3 = DecoderBlock(filters[2], filters[1])
        self.decoder2 = DecoderBlock(filters[1], filters[0])
        self.decoder1 = DecoderBlock(filters[0], filters[0])

        # Final classifier; its input is the output of decoder1 (filters[0] channels).
        self.last_conv0 = ConvRelu(filters[0], 128, 3, 1)
        self.last_conv1 = nn.Conv2d(128, n_classes, 3, padding=1)

    def forward(self, x):
        # Down
        x = self.encoder0(x)
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Up + additive skip connections
        d4 = self.decoder4(e4) + e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        d1 = self.decoder1(d2)

        # Final classifier
        out = self.last_conv0(d1)
        out = self.last_conv1(out)
        out = torch.sigmoid(out)

        return out

In [None]:
def compute_iou(model, loader, threshold=0.3):
    '''
    Mean per-batch Dice score of `model` over `loader`.

    Despite its name, the score is computed with `dice_coef_metric` (Dice, not IoU).
    Model outputs are binarised at `threshold` (>= threshold -> 1.0, otherwise 0.0)
    before being compared with the targets.

    The model is switched to eval mode for the duration of the call and its previous
    train/eval mode is restored afterwards.

    Raises ValueError if `loader` yields no batches.
    '''
    was_training = model.training
    model.eval()

    batch_scores = []
    try:
        with torch.no_grad():
            for data, target in loader:
                outputs = model(data.to(device))

                predictions = outputs.detach().cpu().numpy()
                out_cut = np.where(predictions >= threshold, 1.0, 0.0).astype(predictions.dtype)

                batch_scores.append(dice_coef_metric(out_cut, target.detach().cpu().numpy()))
    finally:
        model.train(was_training)

    if not batch_scores:
        raise ValueError("compute_iou: loader yielded no batches")

    # Fixed: the sum used to be divided by the index of the last batch (n - 1), which
    # inflated the mean and divided by zero for a single-batch loader.
    return sum(batch_scores) / len(batch_scores)

In [None]:
# Build the ResNeXt50 U-Net with one output channel and run a random 256x256 RGB
# tensor through it as a shape check.
rx50 = ResNeXtUNet(n_classes=1).to(device)
output = rx50(torch.randn(1,3,256,256).to(device))
print(output.shape)

In [None]:
def dice_coef_loss(inputs, target):
    '''
    Soft Dice loss:
        1 - (2 * sum(target * inputs) + smooth) / (sum(target) + sum(inputs) + smooth),
    with smooth = 1.0. Sums run over all elements of the tensors.
    '''
    smooth = 1.0
    overlap = (target * inputs).sum()
    numerator = 2.0 * overlap + smooth
    denominator = target.sum() + inputs.sum() + smooth
    return 1 - (numerator / denominator)


def bce_dice_loss(inputs, target):
    '''
    Sum of binary cross entropy (nn.BCELoss) and the soft Dice loss (dice_coef_loss).

    nn.BCELoss applies no sigmoid, so `inputs` must already be probabilities in [0, 1].
    '''
    dice_term = dice_coef_loss(inputs, target)
    bce_term = nn.BCELoss()(inputs, target)
    return bce_term + dice_term

In [None]:
def dice_coef_metric(inputs, target):
    '''
    Dice coefficient 2 * sum(target * inputs) / (sum(target) + sum(inputs)).

    Returns 1.0 when both `inputs` and `target` sum to zero.
    '''
    overlap = (target * inputs).sum()
    target_total = target.sum()
    inputs_total = inputs.sum()

    if target_total == 0 and inputs_total == 0:
        return 1.0

    return 2.0 * overlap / (target_total + inputs_total)

In [None]:
def train_model(model_name, model, train_loader, val_loader, train_loss, optimizer, lr_scheduler, num_epochs):
    '''
    Train `model` for `num_epochs` epochs and report Dice scores.

    Parameters
    ----------
    model_name : label of the model (not used inside the function).
    model : network to train; its outputs are thresholded at 0.5 for the training metric.
    train_loader, val_loader : iterables of (data, target) batches.
    train_loss : callable `train_loss(outputs, target)` returning the loss tensor.
    optimizer : optimizer stepped once per batch.
    lr_scheduler : truthy to enable a linear learning-rate warm-up (built with
        `warmup_lr_scheduler` and stepped once per batch); falsy to disable it.
    num_epochs : number of epochs.

    Returns
    -------
    (loss_history, train_history, val_history): per-epoch mean training loss, mean
    training Dice, and the validation score returned by `compute_iou`.
    '''
    loss_history = []
    train_history = []
    val_history = []

    # Build the warm-up scheduler once. It used to be re-created at the start of every
    # epoch, which restarted the warm-up (and dropped the learning rate) each epoch.
    warmup = None
    if lr_scheduler:
        warmup_factor = 1.0 / 100
        warmup_iters = min(100, len(train_loader) - 1)
        warmup = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)

    for epoch in range(num_epochs):
        model.train()

        losses = []
        train_iou = []  # per-batch Dice scores (name kept from the original code)

        for data, target in train_loader:
            data = data.to(device)
            target = target.to(device)

            outputs = model(data)

            # Binarise a detached copy of the outputs for the metric only.
            out_cut = np.copy(outputs.data.cpu().numpy())
            out_cut[np.nonzero(out_cut < 0.5)] = 0.0
            out_cut[np.nonzero(out_cut >= 0.5)] = 1.0

            train_dice = dice_coef_metric(out_cut, target.data.cpu().numpy())

            loss = train_loss(outputs, target)

            losses.append(loss.item())
            train_iou.append(train_dice)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if warmup is not None:
                warmup.step()

        val_mean_iou = compute_iou(model, val_loader)

        loss_history.append(np.array(losses).mean())
        train_history.append(np.array(train_iou).mean())
        val_history.append(val_mean_iou)

        print("Epoch [%d]" % (epoch))
        print("Mean loss on train:", np.array(losses).mean(),
              "\nMean DICE on train:", np.array(train_iou).mean(),
              "\nMean DICE on validation:", val_mean_iou)

    return loss_history, train_history, val_history

In [None]:
# Adam over all parameters of the ResNeXt50 U-Net, learning rate 5e-4.
rx50_optimizer = torch.optim.Adam(rx50.parameters(), lr=5e-4)

def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
    '''
    Build a LambdaLR scheduler that linearly warms the learning rate up.

    The multiplier applied to the optimizer's base learning rate rises linearly from
    `warmup_factor` at step 0 to 1 at step `warmup_iters`, and stays at 1 afterwards.
    '''
    def lr_multiplier(step):
        if step < warmup_iters:
            progress = float(step) / warmup_iters
            return warmup_factor * (1 - progress) + progress
        return 1

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_multiplier)

In [None]:
# Train for 20 epochs with BCE + Dice loss; passing False as lr_scheduler disables the warm-up.
# Results: per-epoch loss history, training Dice history, validation Dice history.
rx50_lh, rx50_th, rx50_vh = train_model("ResNeXt50", rx50, train_loader, val_loader, bce_dice_loss, rx50_optimizer, False, 20)

In [None]:
# compute_iou scores batches with dice_coef_metric, so the value reported here is a
# mean Dice score (the variable name is kept for compatibility with other cells).
test_iou = compute_iou(rx50, test_loader)
# Format as a whole percentage; np.around(x, 2) * 100 could print values such as 56.99999999999999.
print(f"""ResNext50\nMean Dice of the test images - {test_iou * 100:.0f}%""")

Data source: Kaggle, MRI Hippocampus Segmentation