# Glaucoma Segmentation


## Imports

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from collections import defaultdict
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from models import *
from training import *
from utils import *

## Setup

In [None]:
# Input data (pre-cropped ORIGA images and masks) and output folders.
IMAGE_DIR = '../data/ORIGA/Images_Cropped'
MASK_DIR = '../data/ORIGA/Masks_Cropped'
LOGS_DIR = '../logs/'
CHECKPOINT_DIR = '../checkpoints/'

# Data / training hyper-parameters.
IMAGE_HEIGHT, IMAGE_WIDTH = 128, 128
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
EPOCHS = 3
LAYERS = [32, 64, 128, 256, 512]  # passed to the models as `features=`
EARLY_STOPPING_PATIENCE = 10
SAVE_INTERVAL = 10
NUM_WORKERS = 4

# Runtime configuration.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
PIN_MEMORY = DEVICE == 'cuda'  # pinned host memory is only useful when copying to a GPU
LOAD_MODEL = ''  # checkpoint path to resume from; an empty string skips loading
USE_WANDB = False
DEEP_SUPERVISION = False

os.makedirs(LOGS_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f'PyTorch version: {torch.__version__}')
print(f'Using device: {DEVICE}')

## Dataset

In [None]:
# Sanity-check the augmentation pipeline on a single image. The dataset holds
# one file, so each draw from `example_loader` is a (presumably freshly
# augmented) view of the same image/mask pair.
example_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=1.0),  # rotate by 0, 90, 180, or 270 degrees
    A.Rotate(limit=30, p=0.33, border_mode=cv.BORDER_CONSTANT),
    # mean 0 / std 1: only divides by albumentations' max_pixel_value (default 255).
    A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
    ToTensorV2(),
])
example_files = os.listdir(IMAGE_DIR)[:1]
example_ds = OrigaDataset(IMAGE_DIR, MASK_DIR, example_files, example_transform)
example_loader = DataLoader(example_ds, batch_size=1, shuffle=True)

example_image, example_mask = next(iter(example_loader))
print(f'Image shape: {example_image.shape}')
print(f'Mask shape: {example_mask.shape}')

mask_values, mask_counts = np.unique(example_mask, return_counts=True)
print(f'Unique values and their counts in mask: {dict(zip(mask_values, mask_counts))}')

# Plot example augmented images and masks: the 3x6 grid is consumed in
# (even, odd) axis pairs -- image on the left, its mask on the right.
fig, ax = plt.subplots(3, 6, figsize=(12, 6))
ax = ax.ravel()
for image_ax, mask_ax in zip(ax[0::2], ax[1::2]):
    sampled_images, sampled_masks = next(iter(example_loader))
    image_ax.imshow(sampled_images[0].permute(1, 2, 0).numpy())
    image_ax.axis('off')
    mask_ax.imshow(sampled_masks[0].numpy(), cmap='gray')
    mask_ax.axis('off')
plt.show()

In [None]:
# Training augmentation: random flips plus a random multiple-of-90-degree
# rotation. There is no Normalize step, so ToTensorV2 yields raw,
# un-normalized pixel values.
train_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=1.0),
    # A.Rotate(limit=30, p=0.25, border_mode=cv.BORDER_CONSTANT),
    # A.Normalize(mean=ORIGA_MEANS, std=ORIGA_STDS),
    ToTensorV2()
])

# Evaluation transform: deterministic resize only.
val_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    # A.Normalize(mean=ORIGA_MEANS, std=ORIGA_STDS),
    ToTensorV2()
])

# 70 / 15 / 15 split. `val_transform` fills both of the last two transform
# slots (presumably validation and test).
train_ds, val_ds, test_ds = load_origa(
    IMAGE_DIR, MASK_DIR, train_transform, val_transform, val_transform,
    train_size=0.7, val_size=0.15, test_size=0.15,
    # train_size=0.01, val_size=0.01, test_size=0.98,
)

print(f'Train size: {len(train_ds)}')
print(f'Validation size: {len(val_ds)}')
print(f'Test size: {len(test_ds)}')

# All loaders share batching/worker settings; only the training one shuffles.
common_loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS)
train_loader = DataLoader(train_ds, shuffle=True, **common_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **common_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **common_loader_kwargs)

## Model

In [None]:
# Initialize model, loss, optimizer, LR scheduler and gradient scaler.
NUM_CLASSES = 3

# Every architecture below is built with the same keyword arguments;
# keep exactly one `model = ...` line active.
model_kwargs = dict(in_channels=3, out_channels=NUM_CLASSES, features=LAYERS)
model = UNet(**model_kwargs).to(DEVICE)
# model = UNetPlusPlus(**model_kwargs, deep_supervision=DEEP_SUPERVISION).to(DEVICE)
# model = UNet3Plus(**model_kwargs, deep_supervision=DEEP_SUPERVISION).to(DEVICE)

# model = AttentionUNet(**model_kwargs).to(DEVICE)
# model = InceptionUNet(**model_kwargs).to(DEVICE)

# model = ResUNet(**model_kwargs).to(DEVICE)
# model = RUNet(**model_kwargs).to(DEVICE)
# model = R2UNet(**model_kwargs).to(DEVICE)

# model = SqueezeUNet(**model_kwargs).to(DEVICE)

# model = R2AttentionUNet(**model_kwargs).to(DEVICE)
# model = R2UNetPlusPlus(**model_kwargs).to(DEVICE)

# model = ResAttentionUNetPlusPlus(**model_kwargs).to(DEVICE)
# model = RefUNet3PlusCBAM(**model_kwargs).to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Loss: keep exactly one `criterion = ...` line active; class weights are uniform.
# criterion = nn.CrossEntropyLoss()  # softmax layer is already included inside nn.CrossEntropyLoss()
criterion = DiceLoss(num_classes=NUM_CLASSES, class_weights=[1.0] * NUM_CLASSES)
# criterion = IoULoss(num_classes=NUM_CLASSES, class_weights=[1.0] * NUM_CLASSES)
# criterion = FocalLoss(alpha=0.25, gamma=2)
# criterion = TverskyLoss(num_classes=NUM_CLASSES, alpha=0.5, beta=0.5, class_weights=[1.0] * NUM_CLASSES)
# criterion = FocalTverskyLoss(num_classes=NUM_CLASSES, alpha=0.5, beta=0.5, gamma=1.0, class_weights=[1.0] * NUM_CLASSES)

# Divide the LR by 10 when the monitored value plateaus (patience=5).
# Use `scheduler = None` to disable.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5, verbose=True)

# No gradient scaling; torch.cuda.amp.GradScaler() is the alternative.
scaler = None

# Optionally resume model/optimizer state from the LOAD_MODEL checkpoint.
if LOAD_MODEL:
    load_checkpoint(LOAD_MODEL, model, optimizer)

## Training

In [None]:
# Run the multi-class training loop. `hist` maps 'train_<metric>' /
# 'val_<metric>' keys to metric histories (presumably one value per epoch);
# the plotting cell reads those keys.
training_options = dict(
    save_interval=SAVE_INTERVAL,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    log_to_wandb=USE_WANDB,
    log_dir=LOGS_DIR,
    checkpoint_dir=CHECKPOINT_DIR,
    save_best_model=False,
)
hist = train_multiclass(
    model, criterion, optimizer, EPOCHS, DEVICE,
    train_loader, val_loader, scheduler, scaler,
    **training_options,
)

In [None]:
# Plot metrics: one subplot per metric recorded in `hist`, with its train
# curve and, when present, its validation curve. Metric names are the
# 'train_' keys of `hist` with that prefix removed.
used_metrics = sorted(key[len('train_'):] for key in hist if key.startswith('train_'))

# Keep the usual 4x4 grid, but add rows when more than 16 metrics are
# tracked instead of running out of axes.
n_cols = 4
n_rows = max(4, (len(used_metrics) + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 2 * n_rows))
axes = axes.ravel()

for metric_ax, metric in zip(axes, used_metrics):
    metric_ax.plot(hist[f'train_{metric}'], label='train')
    val_key = f'val_{metric}'
    # Membership test first: indexing a missing key would raise KeyError, or
    # silently insert an empty entry if `hist` is a defaultdict.
    if val_key in hist:
        metric_ax.plot(hist[val_key], label='val')
    # Slicing (not metric[0]) so an empty metric name cannot raise IndexError.
    metric_ax.set_title(metric[:1].upper() + metric[1:].replace('_', ' '))
    if metric != 'loss':
        metric_ax.set_ylim(top=1)
    metric_ax.legend()

# Hide unused trailing axes. A separate loop variable leaves `axes` intact.
for unused_ax in axes[len(used_metrics):]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

## Testing

In [None]:
# Evaluate the trained model on the held-out test split using the training
# criterion. `evaluate` is not defined in this notebook (presumably it comes
# from one of the project star imports); `results` is whatever it returns.
results = evaluate(model, criterion, DEVICE, test_loader)

In [None]:
# Plot predictions for 4 test samples; the figure is presumably also written
# to `save_path`. os.path.join avoids the doubled slash that
# f'{LOGS_DIR}/...' produces, because LOGS_DIR already ends with '/'.
plot_results_from_loader(test_loader, model, DEVICE, types='all', n_samples=4,
                         save_path=os.path.join(LOGS_DIR, 'evaluation.png'))

## Work in progress

In [None]:
checkpoint_path = os.path.join(CHECKPOINT_DIR, 'model.pth')
# To overwrite the checkpoint with the current weights, uncomment:
# torch.save(model.state_dict(), checkpoint_path)

# map_location remaps stored tensors onto DEVICE, so a checkpoint written on a
# GPU machine also loads on a CPU-only one.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
# Rebuild the architecture with the same `features` the training cell used,
# so the state-dict keys/shapes match the saved weights.
model = UNet(in_channels=3, out_channels=3, features=LAYERS).to(DEVICE)
model.load_state_dict(checkpoint)

In [None]:
# Binary-segmentation experiment: every non-zero mask label is merged into one
# foreground class and a 1-output UNet is trained with a Dice loss for 5 epochs.
# After each epoch, the last batch is plotted.
# NOTE(review): this iterates `val_loader` both for training and for the printed loss.
model = UNet(in_channels=3, out_channels=1).to(DEVICE)
loss = DiceLoss(num_classes=1)
# loss = GeneralizedDice(num_classes=1)
# loss = IoULoss(num_classes=1)
# loss = FocalLoss(num_classes=1)
# loss = TverskyLoss(num_classes=1)
# loss = FocalTverskyLoss(num_classes=1)
# loss = HausdorffLoss(idc=[0])
# loss = BoundaryLoss(idc=[0])
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(5):
    running_loss = 0
    for images, masks in val_loader:
        images = images.float().to(DEVICE)
        # Binarize: any label > 0 becomes 1 (foreground), 0 stays background.
        masks = (masks.long().to(DEVICE) > 0).long()

        outputs = model(images)
        loss_value = loss(outputs, masks)
        running_loss += loss_value.item()

        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()

    # Last batch of the epoch: input image, binary target, thresholded prediction.
    images = images.cpu().numpy()
    masks = masks.cpu().numpy()
    # probs = F.softmax(outputs, dim=1)
    # preds = torch.argmax(probs, dim=1).cpu().numpy()
    preds = (torch.sigmoid(outputs) > 0.5).float().cpu().numpy().transpose(0, 2, 3, 1)

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    # Images are raw 0-255 pixel values (no Normalize in the transforms), hence / 255.
    panels = (images[0].transpose(1, 2, 0) / 255.0, masks[0], preds[0])
    for panel_ax, panel in zip(ax, panels):
        panel_ax.imshow(panel)
    plt.show()

    print(f'Epoch {epoch + 1} loss:', running_loss / len(val_loader))


In [None]:
from scipy.ndimage import distance_transform_edt as eucl_distance


def one_hot2hd_dist(seg: np.ndarray, resolution: tuple[float, float, float] = None,
                    dtype=None) -> np.ndarray:
    """
    Per-class Euclidean distance transform of a one-hot mask.

    Used for the Hausdorff-distance loss of https://arxiv.org/pdf/1904.10030.pdf;
    implementation adapted from https://github.com/JunMa11/SegWithDistMap.

    Args:
        seg: stack of masks indexed by class along axis 0 (``seg[k]`` is the
            mask of class ``k``). It is not required to be a strict one-hot
            encoding, so it may cover only a subset of the classes.
        resolution: optional per-axis spacing, forwarded as ``sampling`` to the
            distance transform.
        dtype: dtype of the result. ``None`` keeps ``seg``'s dtype, so an
            integer ``seg`` produces truncated (integer) distances.

    Returns:
        Array shaped like ``seg``. Each non-zero pixel of ``seg[k]`` holds its
        distance to the nearest zero pixel of that channel. Zero pixels, and
        channels with no non-zero pixel at all, are 0.
    """
    dist = np.zeros_like(seg, dtype=dtype)
    for k, class_mask in enumerate(seg):
        foreground = class_mask.astype(np.uint8)
        if not foreground.any():
            # Nothing of this class present: leave its map at zero.
            continue
        dist[k] = eucl_distance(foreground, sampling=resolution)
    return dist


class HausdorffLoss(nn.Module):
    """Distance-transform Hausdorff loss (as in arXiv:1904.10030).

    For each selected class, the squared error between softmax probabilities
    and the one-hot target is weighted by the sum of the squared distance maps
    of the target and of the hard (argmax) prediction. The weighted errors are
    then averaged over all elements.

    Args:
        idc: indices of the classes (channels of ``logits``) included in the loss.
    """

    def __init__(self, idc):
        super(HausdorffLoss, self).__init__()
        self.idc: list[int] = idc

    def forward(self, logits, target):
        """Compute the loss.

        Args:
            logits: raw network outputs; softmax is taken over dim 1 (classes).
            target: integer class labels. They are one-hot encoded to
                ``logits.shape[1]`` classes and permuted with (0, 3, 1, 2), so
                ``(B, H, W)`` labels are expected for ``(B, K, H, W)`` logits.

        Returns:
            Scalar loss tensor.
        """
        probs = F.softmax(logits, dim=1)
        target = F.one_hot(target, num_classes=probs.shape[1]).permute(0, 3, 1, 2)

        B, K, *xyz = probs.shape

        pc = probs[:, self.idc, ...].float()
        tc = target[:, self.idc, ...].float()
        assert pc.shape == tc.shape == (B, len(self.idc), *xyz)

        # Distance maps are computed in NumPy/SciPy and carry no gradient.
        target_dm_npy: np.ndarray = np.stack([one_hot2hd_dist(tc[b].cpu().detach().numpy())
                                              for b in range(B)], axis=0)
        assert target_dm_npy.shape == tc.shape == pc.shape
        tdm = torch.tensor(target_dm_npy, device=probs.device, dtype=torch.float32)

        pred_segmentation = torch.argmax(probs, dim=1)
        pred_segmentation = F.one_hot(pred_segmentation, num_classes=K).permute(0, 3, 1, 2).cpu().detach().numpy()
        # The one-hot prediction is an integer (int64) array. Request a float
        # result so the distances are not truncated when stored.
        pred_dm_npy: np.ndarray = np.stack(
            [one_hot2hd_dist(pred_segmentation[b, self.idc, ...], dtype=np.float32) for b in range(B)],
            axis=0)
        assert pred_dm_npy.shape == tc.shape == pc.shape
        pdm = torch.tensor(pred_dm_npy, device=probs.device, dtype=torch.float32)

        delta = (pc - tc) ** 2
        dtm = tdm ** 2 + pdm ** 2

        # Plain element-wise product, equivalent to einsum("bkwh,bkwh->bkwh").
        # This avoids relying on an `einsum` name that this cell never imports.
        multiplied = delta * dtm

        return multiplied.mean()

In [None]:
# Smoke-test HausdorffLoss: train a fresh 3-class UNet for 5 epochs on
# `val_loader`. After each epoch, show input / target / argmax prediction for
# the last batch seen.
model = UNet(in_channels=3, out_channels=3, features=LAYERS).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss = HausdorffLoss(idc=[0, 1, 2])

for epoch in range(5):
    acc_loss = 0  # summed (not averaged) batch losses; the print below shows it by name
    for images, masks in val_loader:
        images, masks = images.float().to(DEVICE), masks.long().to(DEVICE)

        logits = model(images)
        seg_loss = loss(logits, masks)
        acc_loss += seg_loss.item()

        optimizer.zero_grad()
        seg_loss.backward()
        optimizer.step()

    print(f'Epoch {epoch}: {acc_loss = }')

    # Move the last batch to NumPy for plotting; predictions are argmax classes.
    images = images.cpu().numpy()
    masks = masks.cpu().numpy()
    probs = F.softmax(logits, dim=1)
    preds = torch.argmax(probs, dim=1).cpu().numpy()

    fig, ax = plt.subplots(1, 3, figsize=(9, 3))
    panels = (images[0].transpose(1, 2, 0) / 255.0, masks[0], preds[0])
    for panel_ax, panel in zip(ax, panels):
        panel_ax.imshow(panel)
    plt.show()
    # break

## Boundary loss (adapted from the official implementation)

In [None]:
from scipy.ndimage import distance_transform_edt as edt
from torch import einsum


def one_hot_to_dist_maps(one_hot):
    """Signed distance map for every class channel of a one-hot mask.

    For class ``i`` the map is positive outside the class (distance to the
    nearest pixel of the class) and negative inside it (minus the distance to
    the nearest pixel outside the class). Channels with no pixel of the class
    stay 0.

    Args:
        one_hot: array indexed by class along axis 0 (``one_hot[i]`` is the 0/1
            mask of class ``i``); integer or float dtype.

    Returns:
        float64 array with the same shape as ``one_hot``.
    """
    # Distances are fractional. Never store them in the (int64) one-hot
    # array's dtype, or they get truncated.
    dist_maps = np.zeros(np.shape(one_hot), dtype=np.float64)

    for i, class_mask in enumerate(one_hot):
        # Boolean mask: `~` on a uint8 mask flips every bit (0 -> 255, 1 -> 254)
        # instead of negating it, which made the outside distance always 0.
        pos_mask = np.asarray(class_mask).astype(bool)

        if pos_mask.any():
            neg_mask = ~pos_mask
            pos_dist = edt(pos_mask)
            neg_dist = edt(neg_mask)
            dist_maps[i] = neg_dist - pos_dist

    return dist_maps


def labels_to_dist_maps(batched_labels, num_classes):
    """Signed per-class distance maps for a batch of integer label images.

    Args:
        batched_labels: integer label tensor. It is one-hot encoded and
            permuted with (0, 3, 1, 2), so ``(B, H, W)`` labels give
            ``(B, num_classes, H, W)`` maps.
        num_classes: number of classes for ``F.one_hot``; every label must be
            smaller than this.

    Returns:
        float32 CPU tensor with one signed distance map per class.
    """
    batched_one_hot = F.one_hot(batched_labels, num_classes).permute(0, 3, 1, 2).cpu().numpy()

    # Float buffer so the per-sample distance maps are not truncated to integers.
    batched_dist_maps = np.zeros(batched_one_hot.shape, dtype=np.float64)
    for i, sample_one_hot in enumerate(batched_one_hot):
        batched_dist_maps[i] = one_hot_to_dist_maps(sample_one_hot)

    return torch.from_numpy(batched_dist_maps).float()


class BoundaryLoss(nn.Module):
    """Boundary loss: mean of softmax probabilities weighted by the signed
    distance maps of the ground-truth classes.

    Args:
        idc: indices of the classes (channels of ``logits``) included in the loss.
    """

    def __init__(self, idc):
        super(BoundaryLoss, self).__init__()
        self.idc = idc
        # Number of *selected* classes. The one-hot encoding in forward() does
        # not use it, because that encoding must cover every logit channel.
        self.num_classes = len(idc)

    def forward(self, logits, targets):
        """Compute the loss.

        Args:
            logits: raw network outputs; softmax is taken over dim 1 (classes).
            targets: integer label tensor, presumably ``(B, H, W)``, with labels
                in ``[0, logits.shape[1])``.

        Returns:
            Scalar loss tensor.
        """
        probs = F.softmax(logits, dim=1)
        # One-hot over *all* output channels so that `self.idc` selects the same
        # classes in `probs` and `dist_maps`. Using len(idc) breaks when idc is
        # a subset, e.g. idc=[1, 2] with 3 classes: label 2 is out of range.
        dist_maps = labels_to_dist_maps(targets, logits.shape[1])

        pc = probs[:, self.idc, ...]
        dc = dist_maps[:, self.idc, ...].to(logits.device)

        # Element-wise product, same as einsum("bkwh,bkwh->bkwh", pc, dc).
        multiplied = pc * dc
        return multiplied.mean()

In [None]:
# Smoke-test BoundaryLoss: train a fresh 3-class UNet for 5 epochs on
# `val_loader`. After each epoch, show input / target / argmax prediction for
# the last batch seen.
model = UNet(in_channels=3, out_channels=3, features=LAYERS).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss = BoundaryLoss(idc=[0, 1, 2])

for epoch in range(5):
    acc_loss = 0  # summed (not averaged) batch losses; the print below shows it by name
    for images, masks in val_loader:
        images, masks = images.float().to(DEVICE), masks.long().to(DEVICE)

        logits = model(images)
        seg_loss = loss(logits, masks)
        acc_loss += seg_loss.item()

        optimizer.zero_grad()
        seg_loss.backward()
        optimizer.step()

    print(f'Epoch {epoch}: {acc_loss = }')

    # Move the last batch to NumPy for plotting; predictions are argmax classes.
    images = images.cpu().numpy()
    masks = masks.cpu().numpy()
    probs = F.softmax(logits, dim=1)
    preds = torch.argmax(probs, dim=1).cpu().numpy()

    fig, ax = plt.subplots(1, 3, figsize=(9, 3))
    panels = (images[0].transpose(1, 2, 0) / 255.0, masks[0], preds[0])
    for panel_ax, panel in zip(ax, panels):
        panel_ax.imshow(panel)
    plt.show()