## Import libraries

In [1]:
import os, sys
# Append the parent of the current working directory to the import path;
# presumably this is what lets `from src.utils import ...` below resolve.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

In [2]:
from PIL import Image
from tqdm import tqdm
import numpy as np
from typing import Literal

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from torch import optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, Subset

from src.utils import get_lr

## Define dataset, model and function

### Dataset

**Mean** and **std** for image normalization. These are the values suggested for ImageNet, computed from millions of images.

> All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]

See the [torchvision models documentation](https://pytorch.org/vision/0.8/models.html) for more details.

In [3]:
# Per-channel (R, G, B) normalization statistics passed to transforms.Normalize
# in BCSSDataset; the values are the ones quoted above from the torchvision
# pre-trained model documentation.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

In [4]:
class BCSSDataset(Dataset):
    """Dataset of (image, mask) pairs read from two directories.

    The i-th image is paired with the i-th mask after both directory
    listings are sorted by file name, so the two directories are assumed to
    use file names that sort into the same order (e.g. identical names).

    Args:
        image_path (str): Directory containing the input images.
        mask_path (str): Directory containing the segmentation masks.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """
    SIZE = (224, 224)
    _img_transformer = transforms.Compose([
            transforms.Resize(SIZE),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=MEAN, std=STD),
        ])
    _mask_transformer = transforms.Compose([
            # Nearest-neighbour resizing only copies existing pixel values, so
            # the resized mask never contains blended (invalid) class ids the
            # way the default bilinear interpolation can.
            transforms.Resize(SIZE, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.PILToTensor(),
        ])

    def __init__(self, image_path: str, mask_path: str):
        image_path = os.path.abspath(image_path)
        mask_path = os.path.abspath(mask_path)

        # os.listdir returns entries in arbitrary order; sort both listings so
        # that index i refers to the same sample in both lists.
        self.images = [os.path.join(image_path, filename)
                       for filename in sorted(os.listdir(image_path))]
        self.masks = [os.path.join(mask_path, filename)
                      for filename in sorted(os.listdir(mask_path))]

        if len(self.images) != len(self.masks):
            raise ValueError(
                f'Found {len(self.images)} images in {image_path} but '
                f'{len(self.masks)} masks in {mask_path}'
            )

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.images)

    def __getitem__(self, idx: int):
        """Load, resize and convert the idx-th sample.

        Returns:
            tuple: ``(image, mask)`` where ``image`` is a normalized float
            tensor and ``mask`` is a long tensor with the leading dimension
            squeezed out when it has size 1.
        """
        # Context managers close the underlying file handles once the pixel
        # data has been read by the transforms.
        with Image.open(self.images[idx]) as raw_image:
            # Force three channels so Normalize(mean=MEAN, std=STD) applies.
            image = self._img_transformer(raw_image.convert('RGB'))

        with Image.open(self.masks[idx]) as raw_mask:
            mask = self._mask_transformer(raw_mask)
        mask = torch.squeeze(mask, 0).long()

        return image, mask

### Unet

In [5]:
class DoubleConv(nn.Module):
    """Two consecutive Conv2d -> ReLU -> BatchNorm2d stages.

    Args:
        in_channels (int): Channels of the input tensor.
        out_channels (int): Channels produced by the second convolution.
        mid_channels (int, optional): Channels between the two convolutions;
            falls back to ``out_channels`` when not given.
        kernel_size (int): Kernel size shared by both convolutions.
        stride (int): Stride shared by both convolutions.
        padding (int): Padding shared by both convolutions.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 mid_channels: int = None,
                 kernel_size: int = 3,
                 stride: int = 1,
                 padding: int = 0):

        super().__init__()

        hidden = mid_channels or out_channels

        # Both stages share the same layout; only the channel counts differ.
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers.append(nn.Conv2d(in_channels=c_in,
                                    out_channels=c_out,
                                    kernel_size=kernel_size,
                                    padding=padding,
                                    stride=stride))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.BatchNorm2d(num_features=c_out))

        self.conv_ops = nn.Sequential(*layers)

    def forward(self, X: Tensor):
        """Apply both convolution stages to ``X``."""
        return self.conv_ops(X)

In [6]:
class DownSample(nn.Module):
    """Thin wrapper around ``nn.MaxPool2d``.

    Args:
        kernel_size (int): Pooling window size.
        stride (int): Pooling stride.
        padding (int): Padding added before pooling.
    """

    def __init__(self,
                 kernel_size: int = 2,
                 stride: int = 1,
                 padding: int = 0):
        super().__init__()

        self.pool = nn.MaxPool2d(kernel_size=kernel_size,
                                 stride=stride,
                                 padding=padding)

    def forward(self, X: Tensor):
        """Max-pool ``X``."""
        pooled = self.pool(X)
        return pooled

In [7]:
class UpSample(nn.Module):
    """Thin wrapper around ``nn.ConvTranspose2d``.

    Args:
        in_channels (int): Channels of the input tensor.
        out_channels (int): Channels of the output tensor.
        kernel_size (int): Transposed-convolution kernel size.
        stride (int): Transposed-convolution stride.
        padding (int): Transposed-convolution padding.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 2,
                 stride: int = 1,
                 padding: int = 0):

        super().__init__()

        conv_args = dict(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding)
        self.up_conv = nn.ConvTranspose2d(**conv_args)

    def forward(self, X: Tensor):
        """Apply the transposed convolution to ``X``."""
        upsampled = self.up_conv(X)
        return upsampled

In [8]:
class CropAndConcat(nn.Module):
    """Center-crop a skip-connection tensor to match ``X`` and concatenate them."""

    def forward(self, X: Tensor, contracting_X: Tensor):
        """Return ``X`` concatenated with the cropped ``contracting_X`` along dim 1.

        ``contracting_X`` is center-cropped to the sizes of dimensions 2 and 3
        of ``X`` before the concatenation.
        """
        height, width = X.shape[2], X.shape[3]
        cropped = transforms.functional.center_crop(
            img=contracting_X,
            output_size=(height, width)
        )
        return torch.cat((X, cropped), dim=1)

In [9]:
class Unet(nn.Module):
    """U-Net built from four encoder stages, a bottleneck and four decoder stages.

    Encoder stages are ``DoubleConv`` followed by ``DownSample``; decoder
    stages are ``UpSample``, ``CropAndConcat`` with the matching encoder
    output, then ``DoubleConv``. A final 1x1 convolution maps the 64 decoder
    channels to ``output_classes``.

    Args:
        in_channels (int): Channels of the input tensor.
        output_classes (int): Channels of the output tensor.
        down_conv_kwargs (dict, optional): Extra keyword arguments for the
            encoder ``DoubleConv`` blocks.
        down_sample_kwargs (dict, optional): Extra keyword arguments for the
            ``DownSample`` blocks.
        up_conv_kwargs (dict, optional): Extra keyword arguments for the
            decoder ``DoubleConv`` blocks and for the bottleneck.
        up_sample_kwargs (dict, optional): Extra keyword arguments for the
            ``UpSample`` blocks.
    """

    # TODO: Customize the conv blocks for easy-scalable
    def __init__(
        self,
        in_channels: int,
        output_classes: int,
        down_conv_kwargs: dict = None,
        down_sample_kwargs: dict = None,
        up_conv_kwargs: dict = None,
        up_sample_kwargs: dict = None,
    ):
        super().__init__()

        # Treat a missing kwargs dict as "no extra arguments".
        down_conv_kwargs = down_conv_kwargs or {}
        down_sample_kwargs = down_sample_kwargs or {}
        up_conv_kwargs = up_conv_kwargs or {}
        up_sample_kwargs = up_sample_kwargs or {}

        # (input channels, output channels) of every stage.
        encoder_plan = ((in_channels, 64), (64, 128), (128, 256), (256, 512))
        decoder_plan = ((1024, 512), (512, 256), (256, 128), (128, 64))

        self.down_conv = nn.ModuleList(
            DoubleConv(in_channels=c_in, out_channels=c_out, **down_conv_kwargs)
            for c_in, c_out in encoder_plan
        )

        self.down_sample = nn.ModuleList(
            DownSample(**down_sample_kwargs) for _ in encoder_plan
        )

        self.up_conv = nn.ModuleList(
            DoubleConv(in_channels=c_in, out_channels=c_out, **up_conv_kwargs)
            for c_in, c_out in decoder_plan
        )

        self.up_sample = nn.ModuleList(
            UpSample(in_channels=c_in, out_channels=c_out, **up_sample_kwargs)
            for c_in, c_out in decoder_plan
        )

        self.crop_concat = nn.ModuleList(CropAndConcat() for _ in decoder_plan)

        self.bottlekneck = DoubleConv(
            in_channels=512, out_channels=1024, **up_conv_kwargs
        )

        self.output = nn.Conv2d(
            in_channels=64, out_channels=output_classes, kernel_size=1
        )

    def forward(self, X: Tensor):
        """Run the encoder, bottleneck and decoder and return the output tensor."""
        # Encoder outputs, kept in the order they are produced.
        skips = []
        for conv, pool in zip(self.down_conv, self.down_sample):
            X = conv(X)
            skips.append(X)
            X = pool(X)

        X = self.bottlekneck(X)

        # The decoder consumes the encoder outputs deepest-first.
        decoder_stages = zip(
            self.up_sample, self.crop_concat, self.up_conv, reversed(skips)
        )
        for upsample, merge, conv, skip in decoder_stages:
            X = upsample(X)
            X = merge(X, skip)
            X = conv(X)

        return self.output(X)

### Evaluation functions

* Pixel Accuracy

In [10]:
def pixel_accuracy(logits: Tensor, masks: Tensor):
    """
    Compute the fraction of pixels whose predicted class equals the ground truth.

    Args:
        logits (Tensor): A tensor of shape (N, C, H, W) containing the logits for each class.
        masks (Tensor): A tensor of shape (N, H, W) containing the ground truth masks.

    Returns:
        float: Number of correctly predicted pixels divided by the total number of pixels.
    """
    with torch.no_grad():
        # The predicted class is the channel with the highest probability.
        class_probs = F.softmax(logits, dim=1)
        hits = class_probs.argmax(dim=1).eq(masks)
        return hits.sum().item() / hits.numel()

* Mean Intersection over Union (Jaccard index)

In [11]:
def mean_iou(logits: Tensor, masks: Tensor, num_classes: int):
    """
    Calculate the mean Intersection over Union (IoU) of the predictions.

    The IoU is computed per class over the whole batch and then averaged.
    Classes that appear in neither the prediction nor the ground truth
    (empty union) are left out of the average.

    Args:
        logits (Tensor): A tensor of shape (N, C, H, W) containing the logits for each class.
        masks (Tensor): A tensor of shape (N, H, W) containing the ground truth masks.
        num_classes (int): The number of classes in the dataset.

    Returns:
        float: The mean IoU of the predicted masks, or NaN when no class has
        a non-empty union.
    """
    with torch.no_grad():
        pred_masks = torch.argmax(F.softmax(logits, dim=1), dim=1)

        iou_per_class = []
        for cls in range(num_classes):
            pred_inds = (pred_masks == cls)
            target_inds = (masks == cls)

            union = (pred_inds | target_inds).sum().item()
            if union == 0:
                # Class absent from both prediction and target: IoU undefined.
                continue

            # Intersection of the two boolean masks (not mask & union count).
            intersection = (pred_inds & target_inds).sum().item()
            iou_per_class.append(intersection / union)

    if not iou_per_class:
        return float('nan')
    return float(np.mean(iou_per_class))
            

### Loss function

Besides *cross entropy* (or *binary cross entropy*, depending on the kind of segmentation), **Dice loss** is another loss function commonly used for segmentation problems.

In [12]:
class DiceLoss(nn.Module):
    """
    DiceLoss class calculates the Dice coefficient loss, which is often used for
    image segmentation tasks. The logits are turned into class probabilities
    with a softmax over the class dimension and compared against the one-hot
    encoded ground truth.

    Args:
        smooth (float): A smoothing constant added to both the numerator and the
            denominator of the Dice coefficient to avoid division by zero errors.
            Default is 1e-10.

    Methods:
        forward(logits, masks):
            Computes the Dice loss between the predicted logits and the ground truth masks.

            Args:
                logits (Tensor): A tensor of shape (N, C, H, W) containing the predicted logits for each class.
                masks (Tensor): A tensor of shape (N, H, W) containing the ground truth masks.

            Returns:
                Tensor: The calculated Dice loss.
    """
    def __init__(self, smooth: float = 1e-10):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits: Tensor, masks: Tensor):
        # class probabilities and one-hot targets, both laid out as (N, C, H, W)
        probs = F.softmax(logits, dim=1)

        one_hot_masks = F.one_hot(masks, num_classes=probs.shape[1])
        one_hot_masks = one_hot_masks.permute(0, 3, 1, 2).float()

        # merge the spatial dimensions; flatten copies only if it has to, so it
        # also works on the permuted (non-contiguous) one-hot tensor
        probs = probs.flatten(start_dim=2)
        one_hot_masks = one_hot_masks.flatten(start_dim=2)

        # per-sample, per-class Dice coefficient
        intersection = torch.sum(probs * one_hot_masks, dim=2)
        total = probs.sum(dim=2) + one_hot_masks.sum(dim=2)

        # smooth keeps the ratio finite (and equal to 1) when a class has no
        # probability mass and no ground-truth pixels, instead of 0 / 0 = NaN
        dice_coef = (2 * intersection + self.smooth) / (total + self.smooth)
        avg_class_dice_coef = dice_coef.mean(dim=1)
        loss = 1 - avg_class_dice_coef.mean() # mean for batch

        return loss


## Train

### Load data

* Dataset

In [13]:
# Image / mask directories, resolved relative to the current working directory
train_image_path = os.path.abspath('../data/bcss/train')
train_mask_path = os.path.abspath('../data/bcss/train_mask')
val_image_path = os.path.abspath('../data/bcss/val')
val_mask_path = os.path.abspath('../data/bcss/val_mask')

# Number of classes passed to mean_iou during training and validation
NUM_CLASSES = 22
train_dataset = BCSSDataset(image_path=train_image_path, mask_path=train_mask_path)
val_dataset = BCSSDataset(image_path=val_image_path, mask_path=val_mask_path)

# Use Subset for testing: keep at most SUBSET_SIZE samples per split. The size
# is clamped to the dataset length so a split with fewer samples does not
# raise an IndexError when the out-of-range indices are accessed.
SUBSET_SIZE = 50
train_dataset = Subset(train_dataset, list(range(min(SUBSET_SIZE, len(train_dataset)))))
val_dataset = Subset(val_dataset, list(range(min(SUBSET_SIZE, len(val_dataset)))))

* Dataloader

In [14]:
# Mini-batch size shared by the training and validation loaders
batch = 16

# Only the training split is shuffled; validation keeps the dataset order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch,
    shuffle=True,
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=batch,
    shuffle=False,
)

### Training

* Setup training utils

In [15]:
# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Number of passes over the training data
epochs = 1

In [16]:
# Same-padded 3x3 convolutions with 2x2 / stride-2 pooling and transposed
# convolutions. The number of output channels comes from NUM_CLASSES so the
# model always agrees with the class count used by the metrics.
model = Unet(
    in_channels=3,
    output_classes=NUM_CLASSES,
    down_conv_kwargs={'kernel_size': 3, 'padding': 1},
    down_sample_kwargs={'kernel_size': 2, 'stride': 2},
    up_conv_kwargs={'kernel_size': 3, 'padding': 1},
    up_sample_kwargs={'kernel_size': 2, 'stride': 2}
)

# The training loss is the sum of these two terms.
ce_loss = nn.CrossEntropyLoss().to(device)
dice_loss = DiceLoss().to(device)

max_lr = 1e-3
weight_decay = 1e-4

# NOTE(review): per the PyTorch docs OneCycleLR derives its own starting
# learning rate from max_lr, so the lr passed to AdamW here is presumably
# only a placeholder -- confirm against the scheduler documentation.
optimizer = optim.AdamW(params=model.parameters(), lr=1e-5, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=max_lr,
    epochs=epochs,
    steps_per_epoch=len(train_dataloader),
)

* Setup tracking with Wandb

In [17]:
import wandb
from dotenv import load_dotenv

# Load variables from a .env file before logging in; presumably this is where
# the wandb credentials are expected to come from -- confirm.
load_dotenv()
wandb.login()

# Start a run and record the hyper-parameters known at this point.
wandb.init(
    name='[Unet] BCSS segmentation',
    config={
        'epoch': epochs,
        'batch_size': batch
    },
)

# Number of training batches between two wandb log calls (used by the training loop).
STEP_PER_LOG = 10

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbuitanphuong10c13[0m ([33mbtp712[0m). Use [1m`wandb login --relogin`[0m to force relogin


* Train

In [18]:
# TODO: save checkpoint

In [None]:
# Initialize tracking variables
train_loss, val_loss, train_acc, val_acc, train_iou, val_iou, lrs = [], [], [], [], [], [], []

model.to(device)
for epoch in range(epochs):
    running_loss, iou_score, accuracy = 0, 0, 0
    batch_count, num_log = 0, 1
    last_train_data = None

    # Training loop
    model.train()
    train_loop = tqdm(train_dataloader, desc=f'Training Epoch {epoch+1}/{epochs}', leave=True)
    for i, data in enumerate(train_loop):
        X, y = (_.to(device) for _ in data)

        # Forward
        y_pred = model(X)

        # compute loss
        loss = dice_loss(y_pred, y) + ce_loss(y_pred, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # update metrics
        running_loss += loss.item()
        iou_score += mean_iou(y_pred, y, num_classes=NUM_CLASSES)
        accuracy += pixel_accuracy(y_pred, y)

        # update progress bar
        logging_dict = {
            'loss': running_loss / (i + 1),
            'mean IoU': iou_score / (i + 1),
            'accuracy': accuracy / (i + 1)
        }
        train_loop.set_postfix(logging_dict)

        # step learning rate scheduler
        lrs.append(get_lr(optimizer))
        scheduler.step()

        # update wandb
        batch_count += 1
        if batch_count // STEP_PER_LOG == num_log or i == len(train_dataloader) - 1:
            logging_dict['epoch'] = batch_count / len(train_dataloader)
            wandb.log({f'train/{k}': v for k, v in logging_dict.items()}, step=batch_count)
            
            num_log += 1

    # Validation loop
    model.eval()
    val_running_loss, val_iou_score, val_accuracy = 0, 0, 0
    val_loop = tqdm(val_dataloader, desc='Validation', leave=True)
    with torch.no_grad():
        for i, data in enumerate(val_loop):
            X, y = (_.to(device) for _ in data)

            # Forward
            y_pred = model(X)

            # compute loss
            loss = dice_loss(y_pred, y) + ce_loss(y_pred, y)

            # update metrics
            val_running_loss += loss.item()
            val_iou_score += mean_iou(y_pred, y, num_classes=NUM_CLASSES)
            val_accuracy += pixel_accuracy(y_pred, y)

            # update progress bar
            logging_dict = {
                'loss': val_running_loss / (i + 1),
                'mean IoU': val_iou_score / (i + 1),
                'accuracy': val_accuracy / (i + 1)
            }
            val_loop.set_postfix(logging_dict)

    # Log the evaluation data together with train data
    wandb.log({
        'train/epoch': epoch + 1,
        'eval/loss': val_running_loss / len(val_dataloader),
        'eval/mean IoU': val_iou_score / len(val_dataloader),
        'eval/accuracy': val_accuracy / len(val_dataloader)
    })
