# Model training and evaluation

This notebook includes the model training and evaluation for DeepLabv3 with ResNet101 and MobileNetV3 backbone layers.

In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torchvision import models

from utils import (
    get_device, pixel_accuracy, iou_score, dice_score,
    f1_score, batch_files, files_to_tensors
)

## Importing dataset

The data we use consists of POEM images, each paired with a 2D tensor containing the label of every pixel. The images are read from a local directory, and the label tensors are stored in a pickled dictionary with the image file names as keys and the tensors as values.

In [None]:
# Directory holding the annotated POEM images and the pickled label tensors.
# The default is the original machine-specific location; set the
# POEM_DATA_DIR environment variable to point at the data on another machine
# without editing the notebook.
im_dir = os.environ.get(
    "POEM_DATA_DIR", "/Users/naman/Workspace/Data/BM5020-POEM/Annotated"
)

pkl_path = f"{im_dir}/annotations.pkl"
# Alternative annotation file kept from the original notebook:
# pkl_path = "segmented_images.pkl"

# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only load annotation files that come from a trusted source.
with open(pkl_path, "rb") as file:
    data_dict = pickle.load(file)

# Show a compact summary (entry count and the first few file names) instead of
# dumping every label tensor into the cell output.
len(data_dict), list(data_dict)[:5]

## Data preprocessing

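Here we collect the image file names from the annotation dictionary and group them into batches with `batch_files`, using a batch size of 4.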

In [None]:
# File names of all annotated images, i.e. the keys of the annotation dict.
all_files = list(data_dict)

# Split the file names into batches with the project helper (presumably groups
# of `batch_size` names each -- see utils.batch_files). The training loop
# iterates over `batched_files` and turns each entry into tensors.
batch_size = 4
batched_files = batch_files(all_files, batch_size)

batched_files

## Loading the models

We load two DeepLabv3 models, each with a different backbone: ResNet101 and MobileNetV3. In each model the classifier head classifies every pixel, while the backbone is used for feature extraction. The backbones are initialised with pretrained weights, whereas the classifier heads are initialised with random weights and 4 output classes.

In [None]:
# Device used for the models and for every training batch in this notebook.
# NOTE(review): the selection logic lives in utils.get_device (presumably it
# picks an accelerator when one is available) -- confirm there.
device = get_device()

In [None]:
# Number of segmentation classes predicted for each pixel.
classes = 4

# ImageNet-pretrained weights for the two backbones. No pretrained weights are
# passed for the DeepLabv3 heads (weights=None), so the heads start from random
# initialisation with `classes` output channels. The previously assigned but
# unused COCO DeepLabv3 weights enum was removed, because it suggested the
# heads were pretrained.
resnet101_weights = models.ResNet101_Weights.IMAGENET1K_V2
mobilenetv3_weights = models.MobileNet_V3_Large_Weights.IMAGENET1K_V2

# DeepLabv3 with a ResNet101 backbone, moved to the selected device.
deeplabv3_resnet101 = models.segmentation.deeplabv3_resnet101(
    weights=None, num_classes=classes, weights_backbone=resnet101_weights
).to(device)

# DeepLabv3 with a MobileNetV3-Large backbone, moved to the selected device.
deeplabv3_mobilenetv3 = models.segmentation.deeplabv3_mobilenet_v3_large(
    weights=None, num_classes=classes, weights_backbone=mobilenetv3_weights
).to(device)

In [None]:
# Display the module structure (repr) of the ResNet101-backed model.
deeplabv3_resnet101

In [None]:
# Display the module structure (repr) of the MobileNetV3-backed model.
deeplabv3_mobilenetv3

## Loading the optimizer and scheduler

We use the Adam optimizer with exponential learning-rate scheduling. $\gamma$ is the factor by which the scheduler multiplies the learning rate at each step, and the scheduler is stepped once after every epoch. Hence, the learning rate during the $i$-th epoch (counting from $i = 0$) is $\text{lr}_0 \cdot \gamma^i$, where $\text{lr}_0$ is the initial learning rate (`initial_lr`).

In [None]:
# Hyperparameters: starting learning rate and the multiplicative decay factor
# applied each time the scheduler is stepped.
initial_lr = 1e-2
gamma = 0.8

# Adam optimises only the parameters of the MobileNetV3-backed model; the
# exponential scheduler is attached to that optimizer.
optimizer = torch.optim.Adam(
    params=deeplabv3_mobilenetv3.parameters(),
    lr=initial_lr,
)
scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer=optimizer,
    gamma=gamma,
)

## Model training

In [None]:
# Per-epoch training history; one value is appended at the end of every epoch.
loss_list = []
accuracy_list = []
# Per-class histories: entry i holds the per-epoch scores of class i.
iou_list = [[] for _ in range(classes)]
dice_list = [[] for _ in range(classes)]
f1_list = [[] for _ in range(classes)]

# Learning rate that was in effect during each epoch.
lr_vals = []

epochs = 40
batches = len(batched_files)

# Put the model in training mode explicitly so that re-running this cell after
# the model was switched to eval mode still trains mode-dependent layers
# (e.g. batch normalisation) correctly.
deeplabv3_mobilenetv3.train()

for epoch in range(epochs):

    # Running sums over the batches of this epoch; averaged over `batches`
    # when the epoch ends.
    epoch_loss = 0
    epoch_accuracy = 0
    epoch_iou_list = [0] * classes
    epoch_dice_list = [0] * classes
    epoch_f1_list = [0] * classes

    for batch in batched_files:

        inputs, labels = files_to_tensors(batch, im_dir, data_dict)
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear the gradients left over from the previous batch. Without this,
        # backward() keeps adding into .grad and every optimizer step would use
        # the sum of all gradients computed so far.
        optimizer.zero_grad()

        output = deeplabv3_mobilenetv3(inputs)

        # The model returns a mapping; "out" holds the per-pixel class logits.
        logits = output["out"]

        loss = F.cross_entropy(logits, labels)
        epoch_loss += loss.item()

        loss.backward()
        optimizer.step()

        # Metrics are computed on CPU from the hard (argmax) predictions;
        # detach so the autograd graph is not kept alive by the metric code.
        predictions = logits.detach().argmax(dim=1).cpu()
        labels = labels.cpu()
        epoch_accuracy += pixel_accuracy(predictions, labels)

        # One-vs-rest scores: compare the boolean mask of class i in the
        # predictions with the same mask in the labels.
        for i in range(classes):
            epoch_iou_list[i] += iou_score(predictions == i, labels == i)
            epoch_dice_list[i] += dice_score(predictions == i, labels == i)
            epoch_f1_list[i] += f1_score(predictions == i, labels == i)

    # Record the learning rate used in this epoch before the scheduler decays it.
    lr_val = optimizer.param_groups[0]["lr"]
    lr_vals.append(lr_val)
    scheduler.step()

    print(f"Epoch {epoch + 1}/{epochs} average loss: {epoch_loss / batches}")

    # Store the per-batch averages of this epoch.
    loss_list.append(epoch_loss / batches)
    accuracy_list.append(epoch_accuracy / batches)
    for i in range(classes):
        iou_list[i].append(epoch_iou_list[i] / batches)
        dice_list[i].append(epoch_dice_list[i] / batches)
        f1_list[i].append(epoch_f1_list[i] / batches)

In [None]:
# Average training loss per epoch.
fig, ax = plt.subplots()
ax.plot(loss_list)
ax.set_xlabel("Epochs")
ax.set_ylabel("Train loss")
plt.show()

In [None]:
# Learning rate that was in effect during each epoch.
fig, ax = plt.subplots()
ax.plot(lr_vals)
ax.set_xlabel("Epochs")
ax.set_ylabel("Learning rate")
plt.show()

In [None]:
# Average pixel accuracy per epoch, measured on the training batches.
fig, ax = plt.subplots()
ax.plot(accuracy_list)
ax.set_xlabel("Epochs")
ax.set_ylabel("Pixel accuracy")
plt.show()

In [None]:
# Per-class IoU over the epochs, one curve per class (legend numbers classes
# starting from 1).
fig, ax = plt.subplots()
for class_number in range(1, classes + 1):
    ax.plot(iou_list[class_number - 1], label=f"Class {class_number}")
ax.set_xlabel("Epochs")
ax.set_ylabel("IoU score")
ax.legend()
plt.show()

In [None]:
# Per-class Dice score over the epochs, one curve per class (legend numbers
# classes starting from 1).
fig, ax = plt.subplots()
for class_number in range(1, classes + 1):
    ax.plot(dice_list[class_number - 1], label=f"Class {class_number}")
ax.set_xlabel("Epochs")
ax.set_ylabel("Dice score")
ax.legend()
plt.show()

In [None]:
# Per-class F1 score over the epochs, one curve per class (legend numbers
# classes starting from 1).
fig, ax = plt.subplots()
for class_number in range(1, classes + 1):
    ax.plot(f1_list[class_number - 1], label=f"Class {class_number}")
ax.set_xlabel("Epochs")
ax.set_ylabel("F1 score")
ax.legend()
plt.show()