# **Hyperparameter Tuning notebook**
- This notebook tunes the hyperparameters of a DeepLabV3+ segmentation model with a MobileNetV2 backbone.
- It uses Weights & Biases (W&B) Sweeps for hyperparameter tuning.
- The sweep searches over `learning_rate`, `threshold`, `bce_weight`, and `augmentation_p`; `epochs` (5), `batch_size` (64), and `optimizer` (`adam`) are currently held fixed in the sweep configuration.



## **Import**

In [None]:
# Mount Google Drive so the zipped datasets under /content/drive/MyDrive can be read
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install wandb

In [None]:
!pip install torchinfo

In [None]:
import os
import shutil
import zipfile

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from albumentations.pytorch import ToTensorV2
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchinfo import summary
from torchvision import models

In [None]:
# Authenticate this session with Weights & Biases before creating the sweep
wandb.login()

### Move data to local disk

**Note:** Replace these paths with your own.

In [None]:
zip_train_source_path = "/content/drive/MyDrive/FYP/Datasets/zipped/train.zip"
zip_val_source_path = "/content/drive/MyDrive/FYP/Datasets/zipped/validation.zip"

local_data_dir = "/content/data"

# Pure-Python `mkdir -p` + `cp`: unlike the shell versions, a missing or
# unreadable source raises here instead of printing an error and letting
# the following cells run against missing data.
os.makedirs(local_data_dir, exist_ok=True)

for source_path in (zip_train_source_path, zip_val_source_path):
    print(f"Copying {source_path} to {local_data_dir}")
    shutil.copy(source_path, local_data_dir)

print("Copying complete.")

In [None]:
local_zip_train_path = f"{local_data_dir}/train.zip"
local_zip_val_path = f"{local_data_dir}/validation.zip"

unzip_destination_path = local_data_dir

# zipfile overwrites existing files (so re-running the notebook does not stall
# on unzip's interactive "replace?" prompt) and raises if an archive is bad.
for zip_path in (local_zip_train_path, local_zip_val_path):
    print(f"Unzipping {zip_path} to {unzip_destination_path}")
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(unzip_destination_path)

print("Unzipping complete.")

In [None]:
local_train_image_path = os.path.join(local_data_dir, "train", "images")
local_train_mask_path = os.path.join(local_data_dir, "train", "masks")

# Sanity-check the extracted training split: report each folder's file count,
# or that the folder is missing.
if not os.path.exists(local_train_image_path):
  print(f"Directory {local_train_image_path} does not exist.")
else:
  num_images = len(os.listdir(local_train_image_path))
  print(f"Number of images in {local_train_image_path}: {num_images}")

if not os.path.exists(local_train_mask_path):
  print(f"Directory {local_train_mask_path} does not exist.")
else:
  num_masks = len(os.listdir(local_train_mask_path))
  print(f"Number of masks in {local_train_mask_path}: {num_masks}")

In [None]:
local_val_image_path = os.path.join(local_data_dir, "validation", "images")
local_val_mask_path = os.path.join(local_data_dir, "validation", "masks")

# Sanity-check the extracted validation split the same way as the training split
if not os.path.exists(local_val_image_path):
  print(f"Directory {local_val_image_path} does not exist.")
else:
  num_images = len(os.listdir(local_val_image_path))
  print(f"Number of images in {local_val_image_path}: {num_images}")

if not os.path.exists(local_val_mask_path):
  print(f"Directory {local_val_mask_path} does not exist.")
else:
  num_masks = len(os.listdir(local_val_mask_path))
  print(f"Number of masks in {local_val_mask_path}: {num_masks}")

In [None]:
# Delete the local copies of the archives now that they have been extracted
!rm "$local_zip_train_path"
!rm "$local_zip_val_path"

## **Defining the sweep**

In [None]:
# Use W&B's Bayesian search strategy to choose hyperparameter combinations
sweep_config = dict(method='bayes')

In [None]:
# Rank runs by the validation IoU that `train` logs each epoch; higher is better
metric = dict(name='val_iou', goal='maximize')

sweep_config.update(metric=metric)

In [None]:
# Hyperparameters held constant for every run of the sweep
parameters_dict = {
    name: {'value': val}
    for name, val in (('epochs', 5), ('optimizer', 'adam'), ('batch_size', 64))
}

# Hyperparameters the Bayesian search explores
parameters_dict.update({
    'learning_rate': {'distribution': 'log_uniform_values', 'min': 0.0001, 'max': 0.01},
    'threshold': {'distribution': 'uniform', 'min': 0.40, 'max': 0.60},
    'bce_weight': {'distribution': 'uniform', 'min': 0.5, 'max': 0.6},
    'augmentation_p': {'distribution': 'uniform', 'min': 0.6, 'max': 0.8},
})

sweep_config['parameters'] = parameters_dict

In [None]:
# Inspect the assembled sweep configuration before launching it
import pprint
pprint.pprint(sweep_config)

## **Relevant code**

### Custom SegmentationDataset

In [None]:
# Per-channel RGB mean / std on the 0-1 scale, used by SegmentationDataset to
# standardize images. NOTE(review): presumably computed over the training set — confirm.
dataset_mean, dataset_std = (
    (0.55943902, 0.50729719, 0.48297841),
    (0.25904843, 0.25247732, 0.25680549),
)

In [None]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import albumentations as A

class SegmentationDataset(Dataset):
    """
    Binary segmentation dataset pairing images and masks by sorted filename order.

    Each item is ``(image, mask)`` where ``image`` is a standardized float32
    tensor of shape (3, H, W) and ``mask`` is a float32 tensor of shape
    (1, H, W) holding 0.0/1.0 (any non-zero mask pixel counts as foreground).

    Args:
        image_dir: Directory with the input images.
        mask_dir: Directory with the masks; the i-th mask in sorted order is
            paired with the i-th image in sorted order.
        transform: Optional Albumentations pipeline applied jointly to image
            and mask before resizing and standardization.
        img_size: Target size as ``(height, width)``.
        mean: Per-channel RGB mean on the 0-1 scale.
        std: Per-channel RGB standard deviation on the 0-1 scale.

    Raises:
        ValueError: If the two directories contain a different number of files.
    """
    def __init__(self, image_dir, mask_dir, transform=None, img_size=(224, 224), mean=dataset_mean, std=dataset_std):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform  # Applied to image and mask together
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        # Pairing is purely positional, so unequal counts would silently
        # misalign every pair after the first missing file.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {image_dir} but "
                f"{len(self.masks)} masks in {mask_dir}"
            )
        self.img_size = img_size
        # float32 statistics keep the standardized image float32; float64
        # arrays would silently upcast every sample.
        self.mean = np.asarray(mean, dtype=np.float32)
        self.std = np.asarray(std, dtype=np.float32)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # Albumentations expects uint8 NumPy arrays; OpenCV loads BGR, so convert to RGB.
        # cv2.imread returns None (rather than raising) for unreadable files.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # 1. Resize to img_size = (height, width). cv2.resize takes dsize as
        #    (width, height), so the order is swapped for the call.
        target_h, target_w = self.img_size
        if image.shape[:2] != (target_h, target_w):
            image = cv2.resize(image, (target_w, target_h))
        if mask.shape[:2] != (target_h, target_w):
            mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)

        # 2. Scale to [0, 1] and standardize per channel (stays float32)
        image = image.astype(np.float32) / 255.0
        image = (image - self.mean) / self.std

        # 3. HWC -> CHW tensor
        image = torch.from_numpy(image).permute(2, 0, 1)

        # 4. Binarize the mask (0.0 / 1.0) and add a channel dimension: HW -> 1HW
        mask = (mask > 0).astype(np.float32)
        mask = torch.from_numpy(mask).unsqueeze(0)

        return image, mask

### Function to get the datasets and dataloaders

In [None]:
def get_datasets_dataloaders(augmentation_probability, batch_size):
  """
  Builds the train/validation datasets and their DataLoaders.

  Training samples get on-the-fly Albumentations augmentation; validation
  samples are only resized and standardized by SegmentationDataset.

  Args:
    augmentation_probability (float): Probability 'p' used by every augmentation
      except RandomBrightnessContrast, which stays fixed at p=0.3.
    batch_size (int): Number of samples per batch for both loaders.

  Returns:
    tuple: (train_dataset, train_loader, val_dataset, val_loader)
  """
  train_transform = A.Compose([
      A.HorizontalFlip(p=augmentation_probability),
      A.VerticalFlip(p=augmentation_probability),
      A.RandomRotate90(p=augmentation_probability),
      A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
      A.GaussNoise(p=augmentation_probability),
      A.GridDistortion(num_steps=5, distort_limit=0.3, p=augmentation_probability),
      A.ElasticTransform(alpha=1, sigma=50, p=augmentation_probability),
      A.CoarseDropout(p=augmentation_probability),
  ])

  train_dataset = SegmentationDataset(local_train_image_path, local_train_mask_path, transform=train_transform)
  val_dataset = SegmentationDataset(local_val_image_path, local_val_mask_path)

  # A trailing training batch of exactly one sample makes the BatchNorm after
  # ASPP's global-average-pool branch raise in train mode (only one value per
  # channel), so drop the last batch in that single case only.
  drop_single_leftover = len(train_dataset) % batch_size == 1

  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2,
                            pin_memory=True, drop_last=drop_single_leftover)
  val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

  return train_dataset, train_loader, val_dataset, val_loader

### DeepLabv3+ with MobileNetV2

In [None]:
class SeparableConv2d(nn.Module):
    """
    Depthwise-separable 2D convolution block.

    A per-channel (depthwise) KxK convolution followed by a 1x1 (pointwise)
    convolution that mixes channels. Both convolutions are bias-free and each
    is followed by BatchNorm and an in-place ReLU.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the pointwise convolution.
        kernel_size, stride, padding, dilation: Settings of the depthwise conv.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1):
        super().__init__()

        # Depthwise stage: groups=in_channels gives each channel its own filter
        self.depthwise = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            dilation=dilation, groups=in_channels, bias=False,
        )
        self.bn_depth = nn.BatchNorm2d(in_channels)
        self.relu_depth = nn.ReLU(inplace=True)

        # Pointwise stage: 1x1 conv mixing information across channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_point = nn.BatchNorm2d(out_channels)
        self.relu_point = nn.ReLU(inplace=True)

    def forward(self, x):
        stages = (
            self.depthwise, self.bn_depth, self.relu_depth,
            self.pointwise, self.bn_point, self.relu_point,
        )
        for stage in stages:
            x = stage(x)
        return x

In [None]:
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling head (DeepLabV3+ style).

    Five parallel branches see the same feature map: a 1x1 conv, three dilated
    depthwise-separable 3x3 convs, and an image-level branch (global average
    pool -> 1x1 conv, upsampled back to the input size). Their outputs are
    concatenated and fused by a 1x1 conv followed by dropout.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by each branch and by the fused output.
        atrous_rates: Dilation rates of the three separable branches. The default
            (6, 12, 18) matches the branch names. Small feature maps need small
            rates: e.g. 224x224 inputs at output stride 16 give a 14x14 map, where
            rates like 24 or 36 would put every off-centre tap into zero padding.
    """
    def __init__(self, in_channels, out_channels, atrous_rates=(6, 12, 18)):
        super(ASPP, self).__init__()
        rate_small, rate_mid, rate_large = atrous_rates

        # 1x1 convolution branch (This is always a standard 1x1 convolution)
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # Atrous separable 3x3 branches; padding == dilation keeps the spatial size.
        # Dilation does not change weight shapes, so the state_dict layout is unaffected.
        self.atrous_block6 = SeparableConv2d(in_channels, out_channels, kernel_size=3, padding=rate_small, dilation=rate_small)
        self.atrous_block12 = SeparableConv2d(in_channels, out_channels, kernel_size=3, padding=rate_mid, dilation=rate_mid)
        self.atrous_block18 = SeparableConv2d(in_channels, out_channels, kernel_size=3, padding=rate_large, dilation=rate_large)

        # Image-level branch. NOTE: its BatchNorm sees a single value per channel
        # per sample, so training mode needs batches with more than one sample.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, 1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # Final 1x1 convolution to fuse all 5 branches
        self.final_conv = nn.Sequential(
            nn.Conv2d(out_channels * 5, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5)
        )

    def forward(self, x):
        size = x.size()[2:]

        x1 = self.conv1x1(x)
        x2 = self.atrous_block6(x)
        x3 = self.atrous_block12(x)
        x4 = self.atrous_block18(x)
        x5 = self.global_avg_pool(x)

        # Broadcast the pooled 1x1 features back to the feature-map size
        x5 = F.interpolate(x5, size=size, mode='bilinear', align_corners=False)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        return self.final_conv(x)

In [None]:
class DeepLabV3Plus(nn.Module):
    """
    DeepLabV3+ segmentation network with a torchvision MobileNetV2 encoder.

    The encoder is MobileNetV2's ``features`` stack with the downsampling of
    block 14 removed, so the deepest features are at output stride 16 instead
    of 32. ASPP runs on those 1280-channel features; the decoder concatenates
    the upsampled ASPP output with a 48-channel projection of the 24-channel
    block-3 features, refines them with two 3x3 convs, and upsamples the
    logits back to the input resolution.

    Args:
        num_classes (int): Number of output channels (raw logits, no activation).
        pretrained (bool): If True, load the ImageNet ``IMAGENET1K_V1`` weights,
            which are the checkpoint torchvision's legacy ``pretrained=True`` selected.
    """
    def __init__(self, num_classes=1, pretrained=True):
        super(DeepLabV3Plus, self).__init__()
        # `weights=` replaces the `pretrained=` argument deprecated in torchvision 0.13
        weights = models.MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.mobilenet_v2(weights=weights)
        self.backbone = backbone.features  # Get all layers except classifier

        # Output stride 16 (instead of 32): remove the stride-2 downsampling in block 14
        self.backbone[14].conv[1][0].stride = (1, 1)

        # Low-level features come from an early layer (for the decoder)
        self.low_level_idx = 3
        self.low_level_channels = 24

        # ASPP expects 1280 channels from the last MobileNetV2 layer
        self.aspp = ASPP(in_channels=1280, out_channels=256)

        # Decoder: reduce low-level channels so they don't dominate the ASPP features
        self.low_level_project = nn.Sequential(
            nn.Conv2d(self.low_level_channels, 48, kernel_size=1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        self.decoder = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1)
        )

    def forward(self, x):
        """Returns logits with `num_classes` channels at the spatial size of `x`."""
        input_size = x.size()[2:]

        # Run the encoder, keeping the low-level features for the decoder
        low_level_feat = None
        feat = x
        for i, layer in enumerate(self.backbone):
            feat = layer(feat)
            if i == self.low_level_idx:
                low_level_feat = feat

        # 1280 channels at roughly H/16 x W/16 (after the block-14 stride change)
        high_level_feat = feat

        # ASPP on high-level features, upsampled to the low-level resolution
        x = self.aspp(high_level_feat)
        x = F.interpolate(x, size=low_level_feat.shape[2:], mode='bilinear', align_corners=False)

        # Decoder
        low_level = self.low_level_project(low_level_feat)
        x = torch.cat([x, low_level], dim=1)
        x = self.decoder(x)
        return F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)

### Function to initialize and get the model

In [None]:
def get_model():
  """Builds a fresh single-class DeepLabV3Plus and moves it to the global `device`."""
  return DeepLabV3Plus(num_classes=1).to(device)

### Criterion class

In [None]:
class DiceBCELoss(nn.Module):
    """
    Weighted sum of BCE-with-logits loss and soft Dice loss.

    loss = bce_weight * BCE + (1 - bce_weight) * (1 - Dice), where Dice is
    computed on sigmoid probabilities over the whole batch flattened together.

    Args:
        bce_weight (float): Weight of the BCE term; the Dice term gets 1 - bce_weight.
        size_average: Unused; accepted only for backward compatibility.
    """
    def __init__(self, bce_weight=0.6, size_average=True):
        super(DiceBCELoss, self).__init__()
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='mean')
        self.bce_weight = bce_weight

    def forward(self, inputs, targets, smooth=1):
        """
        Args:
            inputs (torch.Tensor): Raw logits.
            targets (torch.Tensor): Binary targets, same shape as `inputs`.
            smooth (float): Added to the Dice numerator and denominator to avoid 0/0.

        Returns:
            torch.Tensor: Scalar combined loss.
        """
        inputs_sig = torch.sigmoid(inputs)

        # reshape (unlike view) also accepts non-contiguous tensors
        inputs_flat = inputs_sig.reshape(-1)
        targets_flat = targets.reshape(-1)

        intersection = (inputs_flat * targets_flat).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (inputs_flat.sum() + targets_flat.sum() + smooth)

        # BCEWithLogitsLoss applies the sigmoid itself, so it gets the raw logits
        bce_loss_val = self.bce_loss(inputs, targets)

        return self.bce_weight * bce_loss_val + (1 - self.bce_weight) * dice_loss

### Functions to calculate the metrics

In [None]:
def get_confusion_matrix_components(y_true, y_pred):
    """
    Counts the confusion-matrix components of a batch of binary masks.

    Args:
        y_true (torch.Tensor): Ground-truth masks containing 0s and 1s.
        y_pred (torch.Tensor): Predictions with the same number of elements,
            already binarized (sigmoid + threshold) to 0s and 1s.

    Returns:
        tuple: (true_positives, false_positives, false_negatives, true_negatives)
        as Python ints. Elements equal to neither 0 nor 1 are not counted.
    """
    # reshape (unlike view) also accepts non-contiguous tensors
    y_true_flat = y_true.reshape(-1)
    y_pred_flat = y_pred.reshape(-1)

    pred_pos, pred_neg = y_pred_flat == 1, y_pred_flat == 0
    true_pos, true_neg = y_true_flat == 1, y_true_flat == 0

    true_positives = (pred_pos & true_pos).sum().item()
    false_positives = (pred_pos & true_neg).sum().item()
    false_negatives = (pred_neg & true_pos).sum().item()
    true_negatives = (pred_neg & true_neg).sum().item()

    return true_positives, false_positives, false_negatives, true_negatives


def calculate_final_metrics(tp, fp, fn, tn, smooth=1e-6):
    """
    Turns accumulated confusion-matrix counts into summary metrics.

    `smooth` is added to every denominator, so empty counts yield 0.0
    instead of raising ZeroDivisionError.

    Returns:
        tuple: (iou, dice, recall, precision, accuracy)
    """
    total = tp + tn + fp + fn

    precision = tp / (tp + fp + smooth)   # positive predictive value
    recall = tp / (tp + fn + smooth)      # sensitivity
    iou = tp / (tp + fp + fn + smooth)
    dice = (2 * precision * recall) / (precision + recall + smooth)  # F1
    accuracy = (tp + tn) / (total + smooth)

    return iou, dice, recall, precision, accuracy

### Function to get the optimizer

In [None]:
def get_optimizer(model, optimizer, learning_rate):
  """
  Builds the optimizer named in the sweep config.

  Args:
    model (nn.Module): Model whose parameters will be optimized.
    optimizer (str): "sgd" (momentum 0.9) or "adam" (weight decay 1e-5).
    learning_rate (float): Initial learning rate.

  Returns:
    torch.optim.Optimizer: The configured optimizer.

  Raises:
    ValueError: If `optimizer` is not a supported name.
  """
  if optimizer == "sgd":
    return torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
  if optimizer == "adam":
    return torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
  # Fail fast; otherwise the bad name only surfaces later inside the training loop
  raise ValueError(f"Unsupported optimizer {optimizer!r}; expected 'sgd' or 'adam'")

### Function to train and validate the model for one epoch

In [None]:
def _run_phase(model, loader, criterion, threshold, device, optimizer=None):
  """
  Runs one full pass over `loader` and returns (mean_loss, iou, dice).

  Trains (model.train(), backward pass, optimizer step) when `optimizer` is
  given; otherwise evaluates in model.eval() mode with gradients disabled.
  The loss is averaged per sample over the samples actually seen; IoU and Dice
  come from confusion-matrix counts accumulated over the whole pass.
  """
  is_train = optimizer is not None
  model.train(is_train)

  running_loss = 0.0
  n_seen = 0
  total_tp = total_fp = total_fn = total_tn = 0

  with torch.set_grad_enabled(is_train):
    for X_batch, y_batch in loader:
      X_batch, y_batch = X_batch.to(device).float(), y_batch.to(device).float()
      output = model(X_batch)

      # The criterion's BCEWithLogitsLoss expects raw logits
      loss = criterion(output, y_batch)

      if is_train:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

      # Weight by batch size so the epoch mean is per sample, not per batch
      batch_size = X_batch.size(0)
      running_loss += loss.item() * batch_size
      n_seen += batch_size

      # Metrics need no autograd history, so work on detached logits
      preds = (torch.sigmoid(output.detach()) > threshold).float()
      tp, fp, fn, tn = get_confusion_matrix_components(y_batch, preds)
      total_tp += tp
      total_fp += fp
      total_fn += fn
      total_tn += tn

  mean_loss = running_loss / n_seen if n_seen else 0.0
  iou, dice, _recall, _precision, _acc = calculate_final_metrics(total_tp, total_fp, total_fn, total_tn)
  return mean_loss, iou, dice


def train_epoch(model, train_dataset, train_loader, val_dataset, val_loader, criterion, optimizer, threshold, device):
  """
  Trains `model` for one epoch on `train_loader`, then evaluates it on `val_loader`.

  Args:
    model: Segmentation network producing logits.
    train_dataset, val_dataset: Kept for backward compatibility; losses are
      averaged over the samples the loaders actually yield.
    train_loader, val_loader: DataLoaders of (image, mask) batches.
    criterion: Loss called as criterion(logits, targets).
    optimizer: Optimizer used for the training pass.
    threshold (float): Probability cut-off for predicting a positive pixel.
    device: Device the batches are moved to.

  Returns:
    tuple: (train_loss, train_iou, train_dice, val_loss, val_iou, val_dice)
  """
  train_loss, train_iou, train_dice = _run_phase(model, train_loader, criterion, threshold, device, optimizer=optimizer)
  val_loss, val_iou, val_dice = _run_phase(model, val_loader, criterion, threshold, device)
  return train_loss, train_iou, train_dice, val_loss, val_iou, val_dice


### Getting the device

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

### Training logic

In [None]:
def train(config=None):
  """
  Runs one W&B sweep trial.

  Builds the data loaders, model, loss and optimizer from the sweep-selected
  config, trains for exactly `config.epochs` epochs, and logs per-epoch
  train/validation loss, IoU and Dice to W&B.

  Args:
    config (dict, optional): Default config for wandb.init; the sweep
      controller supplies the actual hyperparameters.
  """
  with wandb.init(config=config):
    # this config will be set by Sweep Controller
    config = wandb.config

    train_dataset, train_loader, val_dataset, val_loader = get_datasets_dataloaders(config.augmentation_p, config.batch_size)
    model = get_model()
    criterion = DiceBCELoss(config.bce_weight)
    optimizer = get_optimizer(model, config.optimizer, config.learning_rate)

    # Exactly `config.epochs` passes (logged as 0-based epoch indices)
    for epoch in range(config.epochs):
      train_loss, train_iou, train_dice, val_loss, val_iou, val_dice = train_epoch(
          model, train_dataset, train_loader, val_dataset, val_loader,
          criterion, optimizer, config.threshold, device)

      wandb.log({
          "epoch": epoch,
          "train_loss": train_loss,
          "train_iou": train_iou,
          "train_dice": train_dice,
          "val_loss": val_loss,
          "val_iou": val_iou,
          "val_dice": val_dice,
      })

    print(f"Training for {config.epochs} epochs has completed!")


## **Initialize Sweep**

In [None]:
# Register the sweep with W&B; the returned id is what the agent below runs against
sweep_id = wandb.sweep(sweep_config, project="DeepLabv3+_Sweeps")

## **Activate Sweep Agents**

In [None]:
# Run 5 sweep trials in this process, each calling `train` with a sampled config
wandb.agent(sweep_id, train, count=5)