# 🔬 Larynx segmentation using transformers

### 📦 Imports

In [None]:
import sys
sys.path.append('utils')
from utils.notebook_utils import generate_model_name, get_nb_filename
from utils.train_logs import display_metrics, pretty_time
from utils.metrics import Metrics
from utils.preprocessing import preprocess_masks, split_dataset
from utils.plotting import plot_image_with_mask
from utils.dataset import CleDataset, custom_collate

from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import tabulate
from time import perf_counter
from IPython.display import clear_output

import torch
import torch.optim as optim
from torch.utils.data import DataLoader  
from tqdm.auto import tqdm as tq
from torchvision.transforms import v2 as T
from torchvision.transforms import InterpolationMode
from transformers import SegformerForSemanticSegmentation, SegformerConfig

### 🔢 Constants

Define constants

In [None]:
# Tensor dtype and input geometry.
# NOTE(review): img_size is later passed to T.Resize(size=...), so it is
# presumably (height, width) - confirm against the dataset.
data_type = torch.float32
img_size = (424, 530)
num_classes = 3
batch_size = 2

# Class indices used to split a predicted label map into per-class masks.
background_class, trachea_class, supraglottis_class = 0, 1, 2

# Batch-level transform switches forwarded to CleDataset.
use_batch_transforms = True
mix_batch_transforms = False

# Number of samples a full batch is expected to contain at evaluation time;
# doubled when batch transforms are enabled (presumably because the dataset
# then adds a transformed copy of every sample - confirm in CleDataset).
if use_batch_transforms:
    actual_batch_size = 2 * batch_size
else:
    actual_batch_size = batch_size

# File name under which the best model weights are saved during training.
model_filename = generate_model_name(get_nb_filename())

Use CUDA if available, otherwise Apple MPS, otherwise fall back to the CPU

In [None]:
# Pick the compute backend in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

### 📂 Data

Define paths for the data

In [None]:
# Placeholder value: replace with the dataset folder to load.
folder_name = '/path/to/data/folder'
path = f'data/{folder_name}'

# Dataset root, the image directory inside it, and the JSON file that is
# handed to preprocess_masks.
data_path = Path(path)
images = data_path.joinpath('images')
mask_data = data_path.joinpath('result.json')

Connecting images with their corresponding masks

In [None]:
# mask_data is the result.json file inside the data folder.
# NOTE(review): the returned structure is defined in utils.preprocessing;
# it is consumed below by split_dataset - confirm its exact shape there.
dataset = preprocess_masks(mask_data)

Split dataset

In [None]:
# Three-way split; the result is unpacked in the order train, valid, test.
# NOTE(review): split ratios and any shuffling are decided inside
# split_dataset - confirm there.
train, valid, test = split_dataset(dataset)

### 🛠️ Helper functions

In [None]:
def create_dataloader(dataset, shuffle=True):
  """Create a DataLoader object from a dataset.

  Args:
    dataset: Collection of images with their annotations; forwarded to
      CleDataset as ``images_with_annotations``.
    shuffle: Whether the DataLoader shuffles the samples. Defaults to True,
      which is what every existing caller gets; pass False for a fixed
      iteration order (e.g. when comparing validation/test runs).

  Returns:
    A DataLoader over a CleDataset built with the notebook-level settings
    (image folder, image size, batch size, batch-transform flags), batched
    with ``custom_collate``.
  """
  return DataLoader(
    CleDataset(
      images_with_annotations=dataset,
      data_folder=images,
      img_size=img_size,
      batch_size=batch_size,
      use_batch_transforms=use_batch_transforms,
      mix_batch_transforms=mix_batch_transforms
    ),
    batch_size=batch_size,
    shuffle=shuffle,
    collate_fn=custom_collate
  )

In [None]:
def upsamle_logits(logits):
  """Resize the logits to ``img_size`` using bilinear interpolation.

  NOTE: the function name keeps its original spelling because other cells
  call it by that name.
  """
  resize_to_image = T.Resize(size=img_size, interpolation=InterpolationMode.BILINEAR)
  return resize_to_image(logits)

### 🔄 DataLoaders

Create data loaders

In [None]:
# One DataLoader per split, created in the order train, valid, test.
train_dataloader, valid_dataloader, test_dataloader = (
    create_dataloader(split) for split in (train, valid, test)
)

Print the length of the data loaders

In [None]:
# len() of a DataLoader is its number of batches, not its number of samples.
print(f"Train batches: {len(train_dataloader)}")
print(f"Valid batches: {len(valid_dataloader)}")
print(f"Test batches: {len(test_dataloader)}")

### 👁 Display images

Display sample training image and mask

In [None]:
# Pull one batch from the training loader and plot every sample in it
# together with its mask.
train_iterator = iter(train_dataloader)
train_features, train_labels = next(train_iterator)

for i, sample in enumerate(train_features):
  img, label = sample.squeeze(), train_labels[i]
  plot_image_with_mask(img, label)

### 🧠 Model

Define the configuration of the transformer model

In [None]:
# MiT-b2
# num_labels comes from the num_classes constant so the model head cannot
# drift from the class count used elsewhere in the notebook.
# NOTE(review): image_size is not among the documented SegformerConfig
# arguments; it is presumably just stored as an extra config attribute -
# confirm it has the intended effect.
config = SegformerConfig(
    image_size=img_size,
    num_channels=3,
    num_labels=num_classes,
    depths=[3, 4, 6, 3],
    hidden_sizes=[64, 128, 320, 512],
    decoder_hidden_size=768
)

Create the model

In [None]:
# Built directly from the config (no from_pretrained call), so the weights
# start untrained; the model is then moved to the selected device.
model = SegformerForSemanticSegmentation(config).to(device)

### 🎯 Training

Configure the optimizer

In [None]:
# AdamW over all model parameters; no learning-rate scheduler is attached,
# so the learning rate stays at 1e-4 unless changed manually.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

Train the model

In [None]:
# Upper bound on epochs; training normally ends earlier via early stopping,
# which triggers after `early_stop_limit` consecutive epochs without a
# validation-loss improvement.
n_epochs = 2000
early_stop_limit = 10
epochs_since_improvement = 0

# Per-epoch history, one entry appended to every list at the end of an epoch.
epochs_list, train_loss_list, valid_loss_list, valid_loss_decreased_list = [], [], [], []
dice_list, iou_list, precision_list, recall_list, f1_list = [], [], [], [], []
lr_rate_list, duration_list = [], []

# Columns shown by display_metrics. The lists are mutated in place, so this
# single list always reflects the current history.
metric_columns = [epochs_list, valid_loss_decreased_list, train_loss_list, valid_loss_list, dice_list, iou_list, f1_list, precision_list, recall_list, duration_list]

# np.inf (lower case): the capitalised alias np.Inf was removed in NumPy 2.0.
valid_loss_min = np.inf

for epoch in range(1, n_epochs + 1):
    # Redraw the metrics table of all finished epochs before starting a new one.
    clear_output(wait=False)
    display_metrics(metric_columns)
    train_loss, valid_loss = 0., 0.
    metrics = Metrics(device)

    # ---- Training pass ----
    model.train()
    bar = tq(train_dataloader, postfix={"train_loss":0.})
    epochs_list.append(epoch)
    start_time = perf_counter()

    for data, target in bar:
        # Inputs and targets do not need requires_grad: only the model
        # parameters are optimised, and enabling it on the inputs would just
        # make autograd compute unused input gradients.
        optimizer.zero_grad()

        data = data.to(device=device)
        target = target.to(device=device)
        # The model computes the loss itself from class-index labels, so the
        # per-class target masks are collapsed with argmax over dim 1.
        output = model(data, labels=target.argmax(dim=1))

        loss = output.loss
        loss.backward()

        optimizer.step()
        train_loss += loss.item()

        bar.set_postfix({"train_loss":loss.item()})

    # ---- Validation pass ----
    model.eval()
    # Drop the references to the last training batch before validating.
    del data, target
    with torch.no_grad():
        bar = tq(valid_dataloader, postfix={"valid_loss":0.0, "dice_score":0.0})
        for data, target in bar:
            data = data.to(device=device)
            target = target.to(device=device)
            output = model(data, labels=target.argmax(dim=1))

            loss = output.loss
            valid_loss += loss.item()

            # Resize the logits to img_size before they are compared with
            # the target masks.
            logits_upsampled = upsamle_logits(output.logits)
            logits_soft = torch.softmax(logits_upsampled, dim=1)

            metrics.accumulate(logits_soft, target)

            bar.set_postfix(ordered_dict={"valid_loss": loss.item()})

    # Calculate average losses and metrics
    train_loss = train_loss/len(train_dataloader)
    valid_loss = valid_loss/len(valid_dataloader)
    valid_loss_decreased = valid_loss <= valid_loss_min
    dice_score, iou_score, precision, recall, f1 = metrics.get_value_and_reset(n_batches=len(valid_dataloader))

    # Append losses to lists
    train_loss_list.append(train_loss)
    valid_loss_list.append(valid_loss)
    valid_loss_decreased_list.append(valid_loss_decreased)
    dice_list.append(dice_score)
    iou_list.append(iou_score)
    precision_list.append(precision)
    recall_list.append(recall)
    f1_list.append(f1)
    lr_rate_list.append([param_group['lr'] for param_group in optimizer.param_groups])
    duration_list.append(perf_counter() - start_time)

    # Save model if validation loss has decreased
    if valid_loss_decreased:
        torch.save(model.state_dict(), model_filename)
        valid_loss_min = valid_loss
        epochs_since_improvement = 0
    else:
        epochs_since_improvement += 1
        if epochs_since_improvement >= early_stop_limit:
            # Report only the time spent up to the last improving epoch.
            # NOTE: the in-memory model holds the weights of the final epoch;
            # the best weights are the ones saved to model_filename.
            total_duration_to_stop = sum(duration_list[:-early_stop_limit])
            clear_output(wait=False)
            print(f"🛑 Early stopping. No improvement since epoch {epoch - early_stop_limit}. Duration: {pretty_time(total_duration_to_stop)}")
            display_metrics(metric_columns)
            break

### 📊 Loss graphs

This section displays graphs for loss and the metrics

#### Loss

In [None]:
# Number of leading epochs to drop from the curves (0 keeps everything).
epochs_to_exclude = 0

plt.figure(figsize=(6,6))
# Training and validation loss share one axis so they can be compared.
for loss_history, curve_label, curve_color in (
    (train_loss_list, "Training Loss", 'blue'),
    (valid_loss_list, "Validation Loss", 'orange'),
):
    plt.plot(loss_history[epochs_to_exclude:], marker='o', label=curve_label, color=curve_color)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(['Training Loss', 'Validation Loss'])
plt.show()

#### Precision, Recall, F1

In [None]:
# Number of leading epochs to drop from the curves (0 keeps everything).
epochs_to_exclude = 0
# Position of the class inside each per-epoch metric tensor; index 0 is
# reported as "Trachea" in the evaluation table of this notebook.
index = 0

# Move each epoch's metric tensor to the CPU and pick this class's value.
precision_list_cpu = [tensor.cpu()[index] for tensor in precision_list]
recall_list_cpu = [tensor.cpu()[index] for tensor in recall_list]
f1_list_cpu = [tensor.cpu()[index] for tensor in f1_list]

plt.figure(figsize=(6,6))
plt.title('Trachea')
plt.plot(precision_list_cpu[epochs_to_exclude:],  marker='o', label="Precision", color='green')
plt.plot(recall_list_cpu[epochs_to_exclude:],  marker='o', label="Recall", color='red')
plt.plot(f1_list_cpu[epochs_to_exclude:],  marker='o', label="F1", color='black')
plt.legend(['Precision', 'Recall', 'F1'])
plt.show()

In [None]:
# Number of leading epochs to drop from the curves (0 keeps everything).
epochs_to_exclude = 0
# Position of the class inside each per-epoch metric tensor.
index = 1

# Move each epoch's metric tensor to the CPU and pick this class's value.
precision_list_cpu = [epoch_value.cpu()[index] for epoch_value in precision_list]
recall_list_cpu = [epoch_value.cpu()[index] for epoch_value in recall_list]
f1_list_cpu = [epoch_value.cpu()[index] for epoch_value in f1_list]

plt.figure(figsize=(6,6))
plt.title('Supraglottis')
for metric_history, curve_label, curve_color in (
    (precision_list_cpu, "Precision", 'green'),
    (recall_list_cpu, "Recall", 'red'),
    (f1_list_cpu, "F1", 'black'),
):
    plt.plot(metric_history[epochs_to_exclude:], marker='o', label=curve_label, color=curve_color)
plt.legend(['Precision', 'Recall', 'F1'])
plt.show()

#### Dice

In [None]:
# Position of the class inside each per-epoch Dice tensor; index 0 is
# reported as "Trachea" in the evaluation table of this notebook.
index = 0
dice_list_1_cpu = [tensor.cpu()[index] for tensor in dice_list]

plt.figure(figsize=(6,6))
plt.plot(dice_list_1_cpu)
plt.title('Trachea Dice score')
plt.ylabel('Dice')
plt.legend(['Dice score'])
plt.show()

In [None]:
# Position of the class inside each per-epoch Dice tensor.
index = 1
# Move each epoch's Dice tensor to the CPU and pick this class's value.
dice_list_2_cpu = [epoch_dice.cpu()[index] for epoch_dice in dice_list]

plt.figure(figsize=(6,6))
plt.title('Supraglottis Dice score')
plt.ylabel('Dice')
plt.plot(dice_list_2_cpu)
plt.legend(['Dice score'])
plt.show()

### 🔍 Evaluation

In [None]:
def dice_score_single(pred, y, dim):
  """Calculate the Dice score of one class for a single prediction.

  Args:
    pred: Per-class scores (logits or probabilities) with the class axis
      first; the predicted class of each element is the argmax over axis 0.
    y: Ground-truth masks indexed by class on axis 0; ``y[dim]`` is used as
      the target for the scored class (assumed to be a 0/1 mask - confirm).
    dim: Index of the class to score (a class index, not a tensor axis).

  Returns:
    ``2 * intersection / (predicted + target)`` as a float, or ``None`` when
    neither the prediction nor the target contains the class.
  """
  # softmax is monotonic, so the argmax of the raw scores already gives the
  # predicted class map; no extra softmax is needed here.
  predicted_classes = pred.argmax(dim=0)
  prediction = torch.where(predicted_classes == dim, 1, 0)
  target = y[dim, :]

  inter = torch.sum(prediction * target).item()
  # Sum of both mask sizes: the Dice denominator (not a set union).
  denominator = torch.sum(prediction).item() + torch.sum(target).item()
  return 2. * inter / denominator if denominator > 0 else None

In [None]:
# When True, the evaluation cell plots every test image with its predicted
# masks and prints per-image Dice scores; set to False to only compute the
# aggregated metrics.
plot_test_images = True

Evaluate the model by displaying the predicted mask and calculating the metrics for the test data

In [None]:
dataloader_to_test = test_dataloader
metrics = Metrics(device)
# Batches that actually reached metrics.accumulate; used to normalise below.
n_evaluated_batches = 0

# NOTE: this evaluates the weights currently held by `model` (those of the
# last training epoch). The best checkpoint is the one saved to
# model_filename by the training cell; load it first if that is what should
# be evaluated.
model.eval()
# Inference only: no autograd graph is needed for predictions or metrics.
with torch.no_grad():
  for batch, (img_batch, label_batch) in enumerate(dataloader_to_test):
    # Stop at the first batch that does not have the expected full size.
    if len(img_batch) != actual_batch_size: break
    logits = model(img_batch.to(device=device)).logits.cpu()
    # Resize the logits to img_size before comparing them with the labels.
    logits_upsampled = upsamle_logits(logits)
    logits_soft = torch.softmax(logits_upsampled, dim=1)

    metrics.accumulate(logits_soft, label_batch)
    n_evaluated_batches += 1

    if plot_test_images:
      for i in range(len(img_batch)):
        # Plot
        print(f"Batch {batch}, image {i}")
        pred_argmax = logits_soft[i].argmax(dim=0)
        background = torch.where(pred_argmax == background_class, 1, 0)
        trachea_pixels = torch.where(pred_argmax == trachea_class, 1, 0)
        supraglottis_pixels = torch.where(pred_argmax == supraglottis_class, 1, 0)
        plot_image_with_mask(img_batch[i], [background, trachea_pixels, supraglottis_pixels])

        # Single dice
        single_dice_1 = dice_score_single(logits_soft[i], label_batch[i], dim=1)
        single_dice_2 = dice_score_single(logits_soft[i], label_batch[i], dim=2)
        print(f"Single dice (trachea): {single_dice_1}")
        print(f"Single dice (supraglottis): {single_dice_2}\n")

# Normalise by the batches that were accumulated, not by len(dataloader):
# a skipped batch would otherwise drag the averages down.
# NOTE(review): assumes get_value_and_reset averages over n_batches - confirm
# in utils.metrics. max(..., 1) avoids passing 0 when nothing was evaluated.
dice_score, iou_score, precision, recall, f1 = metrics.get_value_and_reset(n_batches=max(n_evaluated_batches, 1))

# Index 0 / 1 of each metric is reported as Trachea / Supraglottis.
data = [
        ["Metric", "Trachea", "Supraglottis"],
        ["Dice score", dice_score[0], dice_score[1]],
        ["IoU", iou_score[0], iou_score[1]],
        ["F1", f1[0], f1[1]],
        ["Precision", precision[0], precision[1]],
        ["Recall", recall[0], recall[1]]
       ]
table = tabulate.tabulate(data, tablefmt='html', headers='firstrow', floatfmt='0.4f')
# Bare expression so the notebook renders the HTML table as the cell output.
table