# Glaucoma Segmentation


## Imports

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from collections import defaultdict
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from models import *
from utils import *

## Setup

In [None]:
# --- Filesystem locations ---------------------------------------------------
IMAGE_DIR = '../data/ORIGA/Images_Cropped'
MASK_DIR = '../data/ORIGA/Masks_Cropped'
LOGS_DIR = '../logs/'
CHECKPOINT_DIR = '../checkpoints/'

# --- Data and optimisation settings -----------------------------------------
IMAGE_HEIGHT = IMAGE_WIDTH = 128
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
EPOCHS = 3
LAYERS = [32 * 2 ** depth for depth in range(5)]  # 32, 64, 128, 256, 512
EARLY_STOPPING_PATIENCE = 10
SAVE_INTERVAL = 10
NUM_WORKERS = 4

# --- Runtime switches -------------------------------------------------------
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
PIN_MEMORY = DEVICE == 'cuda'  # pinned host memory is only requested when running on CUDA
LOAD_MODEL = ''  # falsy value: the model cell skips load_checkpoint()
USE_WANDB = False
DEEP_SUPERVISION = False

# Make sure the output directories exist before anything writes to them.
for directory in (LOGS_DIR, CHECKPOINT_DIR):
    os.makedirs(directory, exist_ok=True)

print(f'PyTorch version: {torch.__version__}')
print(f'Using device: {DEVICE}')

## Dataset

In [None]:
# Augmentation pipeline used only for this visual sanity check.
example_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=1.0),  # rotate by 0, 90, 180, or 270 degrees
    A.Rotate(limit=30, p=0.33, border_mode=cv.BORDER_CONSTANT),
    A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
    ToTensorV2(),
])

# A dataset holding just the first listed image, served one sample at a time.
example_files = os.listdir(IMAGE_DIR)[:1]
example_ds = OrigaDataset(IMAGE_DIR, MASK_DIR, example_files, example_transform)
example_loader = DataLoader(example_ds, batch_size=1, shuffle=True)

example_image, example_mask = next(iter(example_loader))
print(f'Image shape: {example_image.shape}')
print(f'Mask shape: {example_mask.shape}')

unique, counts = np.unique(example_mask, return_counts=True)
print(f'Unique values and their counts in mask: {dict(zip(unique, counts))}')

# Plot example augmented images and masks: nine (image, mask) pairs laid out
# side by side on a 3x6 grid.
fig, ax = plt.subplots(3, 6, figsize=(12, 6))
ax = ax.ravel()
for image_ax, mask_ax in zip(ax[0::2], ax[1::2]):
    # A fresh iterator per panel re-fetches the single sample; presumably the
    # dataset applies the random transform on every fetch - TODO confirm.
    images, masks = next(iter(example_loader))
    image = images[0].permute(1, 2, 0).numpy()  # channels-first -> channels-last for imshow
    mask = masks[0].numpy()

    image_ax.imshow(image)
    image_ax.axis('off')
    mask_ax.imshow(mask, cmap='gray')
    mask_ax.axis('off')
plt.show()

In [None]:
# Training pipeline: resize plus flips and a right-angle rotation.
train_steps = [
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=1.0),
    # A.Rotate(limit=30, p=0.25, border_mode=cv.BORDER_CONSTANT),
    # A.Normalize(mean=ORIGA_MEANS, std=ORIGA_STDS),
    ToTensorV2(),
]
train_transform = A.Compose(train_steps)

# Validation/test pipeline: resize only, no random augmentation.
eval_steps = [
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    # A.Normalize(mean=ORIGA_MEANS, std=ORIGA_STDS),
    ToTensorV2(),
]
val_transform = A.Compose(eval_steps)

# Split fractions; the test split reuses the validation transform.
split_fractions = dict(train_size=0.7, val_size=0.15, test_size=0.15)
# split_fractions = dict(train_size=0.01, val_size=0.01, test_size=0.98)
train_ds, val_ds, test_ds = load_origa(
    IMAGE_DIR, MASK_DIR, train_transform, val_transform, val_transform,
    **split_fractions,
)

print(f'Train size: {len(train_ds)}')
print(f'Validation size: {len(val_ds)}')
print(f'Test size: {len(test_ds)}')

# Only the training loader shuffles; everything else is shared.
loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

## Model

In [None]:
# Build the network, optimizer, loss, LR scheduler and optional AMP scaler.
# The commented alternatives are constructed with the same arguments, so
# switching architecture only means swapping the class name.
net_kwargs = dict(in_channels=3, out_channels=3, features=LAYERS)

model = UNet(**net_kwargs).to(DEVICE)
# model = UNetPlusPlus(**net_kwargs, deep_supervision=DEEP_SUPERVISION).to(DEVICE)
# model = UNet3Plus(**net_kwargs, deep_supervision=DEEP_SUPERVISION).to(DEVICE)

# model = AttentionUNet(**net_kwargs).to(DEVICE)
# model = InceptionUNet(**net_kwargs).to(DEVICE)

# model = ResUNet(**net_kwargs).to(DEVICE)
# model = RUNet(**net_kwargs).to(DEVICE)
# model = R2UNet(**net_kwargs).to(DEVICE)

# model = SqueezeUNet(**net_kwargs).to(DEVICE)

# model = R2AttentionUNet(**net_kwargs).to(DEVICE)
# model = R2UNetPlusPlus(**net_kwargs).to(DEVICE)

# model = ResAttentionUNetPlusPlus(**net_kwargs).to(DEVICE)
# model = RefUNet3PlusCBAM(**net_kwargs).to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Loss function: Dice with equal weight on each of the three classes.
criterion = DiceLoss(num_classes=3, device=DEVICE, class_weights=[1.0] * 3)
# criterion = nn.CrossEntropyLoss()  # softmax layer is already included inside nn.CrossEntropyLoss()
# criterion = IoULoss(num_classes=3, device=DEVICE, class_weights=[1.0, 1.0, 1.0])
# criterion = FocalLoss(alpha=0.25, gamma=2)
# criterion = TverskyLoss(num_classes=3, alpha=0.5, beta=0.5, class_weights=[1.0, 1.0, 1.0], device=DEVICE)
# criterion = FocalTverskyLoss(num_classes=3, alpha=0.5, beta=0.5, gamma=1.0,
#                              class_weights=[1.0, 1.0, 1.0], device=DEVICE)

# Cut the learning rate tenfold after 5 epochs without improvement.
plateau_kwargs = dict(factor=0.1, patience=5, verbose=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_kwargs)
# scheduler = None

# Mixed-precision gradient scaling is disabled by default.
scaler = None
# scaler = torch.cuda.amp.GradScaler()

# Optionally resume from the checkpoint path given in LOAD_MODEL.
if LOAD_MODEL:
    load_checkpoint(LOAD_MODEL, model, optimizer)

## Training

In [None]:
# Keyword options for the training run, gathered in one place.
train_options = dict(
    save_interval=SAVE_INTERVAL,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    log_to_wandb=USE_WANDB,
    log_dir=LOGS_DIR,
    checkpoint_dir=CHECKPOINT_DIR,
    save_best_model=False,
)

# `hist` is read by the plotting cell through its 'train_<metric>' and
# 'val_<metric>' keys.
hist = train(
    model, criterion, optimizer, EPOCHS, DEVICE, train_loader, val_loader, scheduler, scaler,
    **train_options,
)

In [None]:
# Plot metrics: one panel per metric, training and validation curves overlaid.
train_prefix = 'train_'
# Metric names are the 'train_*' keys of `hist` with the prefix stripped.
used_metrics = sorted(key[len(train_prefix):] for key in hist if key.startswith(train_prefix))

# Keep the 4x4 layout, but add rows (and height) when there are more than
# 16 metrics so indexing the axes can never run out of panels.
n_cols = 4
n_rows = max(4, -(-len(used_metrics) // n_cols))  # ceiling division
fig, ax = plt.subplots(n_rows, n_cols, figsize=(14, 2 * n_rows))
ax = ax.ravel()

for panel, metric in zip(ax, used_metrics):
    panel.plot(hist[f'train_{metric}'], label='train')
    # Membership test instead of indexing: avoids a KeyError on a plain dict
    # and avoids silently inserting a key on a defaultdict.
    if f'val_{metric}' in hist:
        panel.plot(hist[f'val_{metric}'], label='val')
    panel.set_title(metric[:1].upper() + metric[1:].replace('_', ' '))
    if metric != 'loss':
        # Every metric except the loss is capped at 1 on the y-axis.
        panel.set_ylim(top=1)
    panel.legend()

# Hide the unused panels. A separate loop name keeps `ax` bound to the full
# axes array instead of rebinding it to a single Axes.
for panel in ax[len(used_metrics):]:
    panel.axis('off')

plt.tight_layout()
plt.show()

## Testing

In [None]:
# Evaluate the current `model` on the test loader using the same criterion
# that was passed to train(); `results` holds whatever evaluate() returns.
results = evaluate(model, criterion, DEVICE, test_loader)

In [None]:
# Visualise predictions for 5 test samples and save the figure in the logs
# directory. os.path.join is used because LOGS_DIR already ends with a
# separator, so plain concatenation with '/' produced a doubled slash.
plot_results_from_loader(
    test_loader, model, DEVICE, types='all', n_samples=5,
    save_path=os.path.join(LOGS_DIR, 'evaluation.png'),
)

## Work in progress

In [None]:
# torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'model.pth'))

# Reload the saved weights into a fresh UNet.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU machine can still be opened on a CPU-only one.
checkpoint = torch.load(os.path.join(CHECKPOINT_DIR, 'model.pth'), map_location=DEVICE)
# Build the network with the same `features` as the other UNet constructions
# in this notebook so the parameter shapes match the saved state dict.
model = UNet(in_channels=3, out_channels=3, features=LAYERS).to(DEVICE)
model.load_state_dict(checkpoint)

In [None]:
from scipy.ndimage import distance_transform_edt


class HausdorffLoss(nn.Module):
    """Multi-class loss that weights the prediction error by a Hausdorff-style distance.

    For every class index except ``ignore_index`` the label maps are binarised
    as ``label >= class_idx``. Euclidean distance transforms of the hard
    prediction and target masks yield a scalar distance estimate, which scales
    the mean absolute difference between the *soft* prediction and the target.

    The distance estimate comes from ``argmax`` masks and carries no gradient;
    gradients reach ``logits`` through the soft prediction only.

    Args:
        num_classes: number of class indices to iterate over.
        ignore_index: class index that is skipped (0 by default).
    """

    def __init__(self, num_classes=3, ignore_index=0):
        super().__init__()
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.alpha = 2  # exponent applied to both directed distances

    def forward(self, logits, target):
        """Return the scalar loss.

        Args:
            logits: raw scores with the class axis at dim 1.
            target: integer label map without a class axis
                (assumed (batch, H, W) - TODO confirm against the dataset).

        Returns:
            0-dim tensor: average of the per-class losses over the classes
            that were actually evaluated.
        """
        probs = F.softmax(logits, dim=1)
        preds = torch.argmax(logits, dim=1)

        class_losses = []
        for class_idx in range(self.num_classes):
            if class_idx == self.ignore_index:
                continue

            # Hard masks: only used to derive the distance weight.
            pred_class = (preds >= class_idx).float()
            target_class = (target >= class_idx).float()
            # Differentiable counterpart of pred_class: probability mass on
            # all channels with index >= class_idx.
            soft_pred_class = probs[:, class_idx:].sum(dim=1)

            pred_dist = self.compute_distance_map(pred_class)
            target_dist = self.compute_distance_map(target_class)

            # Largest distance-map value observed inside the other mask.
            forward_hd = torch.max(pred_dist * target_class)
            backward_hd = torch.max(target_dist * pred_class)
            hausdorff_dist = torch.pow(forward_hd, self.alpha) + torch.pow(backward_hd, self.alpha)

            class_loss = torch.abs(soft_pred_class - target_class) * hausdorff_dist
            class_losses.append(class_loss.mean())

        if not class_losses:
            # Nothing to evaluate (e.g. num_classes == 1 with ignore_index == 0):
            # return a zero that is still attached to the graph.
            return logits.sum() * 0.0

        # Average over the classes that were really processed, so the result
        # stays correct when ignore_index lies outside range(num_classes).
        return torch.stack(class_losses).mean()

    def compute_distance_map(self, mask):
        """Per-sample Euclidean distance map of a batch of binary masks.

        Each element holds the distance to the nearest pixel of the opposite
        value (foreground transform plus background transform).

        Args:
            mask: binary tensor whose first axis indexes the samples.

        Returns:
            float32 tensor with the same shape and device as ``mask``;
            detached from the autograd graph.
        """
        # No squeeze(): it would drop the batch axis for a batch of one and
        # make the loop below run over image rows instead of samples.
        mask_np = mask.detach().cpu().numpy().astype(np.float32)
        dist_map = np.zeros_like(mask_np, dtype=np.float32)

        for i, sample in enumerate(mask_np):
            dist_map[i] = distance_transform_edt(sample) + distance_transform_edt(1 - sample)

        return torch.from_numpy(dist_map).to(mask.device)


class HausdorffDTLoss(nn.Module):
    """Binary Hausdorff loss based on distance transform.

    The squared prediction error is weighted element-wise by the sum of the
    distance fields of prediction and target, each raised to ``alpha``.

    Args:
        alpha: exponent applied to the distance fields.
        **kwargs: accepted and ignored.
    """

    def __init__(self, alpha=2.0, **kwargs):
        super().__init__()
        self.alpha = alpha

    @torch.no_grad()
    def distance_field(self, img: np.ndarray) -> np.ndarray:
        """Return the per-sample distance field of a batch of masks.

        Values above 0.5 count as foreground. For each sample the field is the
        foreground distance transform plus the background distance transform;
        samples without any foreground keep an all-zero field.

        Args:
            img: array whose first axis indexes the samples.

        Returns:
            Array of the same shape as ``img``. The dtype is ``img``'s own
            when that is a floating type, otherwise float32.
        """
        # Integer or bool inputs (e.g. label masks) must not dictate the
        # output dtype, otherwise fractional distances get truncated.
        field_dtype = img.dtype if np.issubdtype(img.dtype, np.floating) else np.float32
        field = np.zeros_like(img, dtype=field_dtype)

        for i, sample in enumerate(img):
            fg_mask = sample > 0.5

            if fg_mask.any():
                fg_dist = distance_transform_edt(fg_mask)
                bg_dist = distance_transform_edt(~fg_mask)
                field[i] = fg_dist + bg_dist

        return field

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Uses one binary channel: 1 - fg, 0 - bg
        pred: (b, 1, x, y, z) or (b, 1, x, y)
        target: (b, 1, x, y, z) or (b, 1, x, y)

        Returns the mean of the distance-weighted squared error as a scalar.
        """
        # detach(): numpy() refuses tensors that require grad. The distance
        # fields are constant weights, so nothing is lost by detaching.
        # to(pred.device): keeps the weights on the same device as the inputs.
        pred_dt = torch.from_numpy(
            self.distance_field(pred.detach().cpu().numpy())
        ).float().to(pred.device)
        target_dt = torch.from_numpy(
            self.distance_field(target.detach().cpu().numpy())
        ).float().to(pred.device)

        pred_error = (pred - target) ** 2
        distance = pred_dt ** self.alpha + target_dt ** self.alpha

        dt_field = pred_error * distance
        return dt_field.mean()



In [None]:
# Scratch check: run HausdorffLoss on one batch and back-propagate through it.
hausdorff_loss = HausdorffLoss()

# NOTE(review): `images` and `masks` are not assigned in this cell. With the
# three lines below commented out, the cell relies on values left in the
# kernel by a previously executed cell.
# images, masks = next(iter(train_loader))
# images = images.float().to(DEVICE)
# masks = masks.long().to(DEVICE)

print(f'{images.shape = }')
print(f'{masks.shape = }')

# Forward pass through the current `model`, then evaluate the loss.
logits = model(images)
print(f'{logits.shape = }')
loss = hausdorff_loss(logits, masks)
print(f'Hausdorff loss: {loss.item()}, {loss.shape = }')

# Confirms that backward() runs on the loss without raising.
loss.backward()


In [None]:
class BoundaryLoss(nn.Module):
    """Mean squared difference between Sobel edge maps of prediction and target.

    The edge maps are computed per class: on the softmax probabilities for the
    prediction (so the loss is differentiable with respect to ``logits``) and
    on the one-hot encoding for the target.

    Args:
        eps: small constant that keeps the square root and the normalisation
            numerically safe on flat (edge-free) maps.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, logits, target):
        """Return the scalar boundary loss.

        Args:
            logits: raw scores of shape (batch, classes, H, W).
            target: integer label map of shape (batch, H, W) with values
                below the number of classes.
        """
        probs = F.softmax(logits, dim=1)
        num_classes = probs.shape[1]

        # One-hot target with the class axis moved to dim 1 to line up with probs.
        target_one_hot = F.one_hot(target.long(), num_classes).permute(0, 3, 1, 2).to(probs.dtype)

        # Compute the boundary maps for both prediction and target
        pred_boundary = self.compute_boundary_map(probs)
        target_boundary = self.compute_boundary_map(target_one_hot)

        # Calculate the boundary loss
        return torch.mean((pred_boundary - target_boundary) ** 2)

    def compute_boundary_map(self, mask):
        """Return the normalised Sobel gradient magnitude of ``mask``.

        Args:
            mask: tensor of shape (batch, H, W) or (batch, channels, H, W);
                non-floating inputs are converted to float.

        Returns:
            Tensor of the same shape as ``mask``, scaled by its global
            maximum; a flat input yields (approximately) zeros.
        """
        had_no_channel_axis = mask.dim() == 3
        if had_no_channel_axis:
            mask = mask.unsqueeze(1)
        if not torch.is_floating_point(mask):
            mask = mask.float()
        channels = mask.shape[1]

        # Define the horizontal and vertical sobel filters on the input's own
        # device/dtype, shaped (channels, 1, 3, 3) as conv2d requires for a
        # depthwise (groups=channels) convolution.
        sobel_x = torch.tensor(
            [[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]],
            dtype=mask.dtype, device=mask.device,
        ).view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
        sobel_y = torch.tensor(
            [[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]],
            dtype=mask.dtype, device=mask.device,
        ).view(1, 1, 3, 3).repeat(channels, 1, 1, 1)

        # Compute the horizontal and vertical gradients, one filter per channel
        g_x = F.conv2d(mask, sobel_x, padding=1, groups=channels)
        g_y = F.conv2d(mask, sobel_y, padding=1, groups=channels)

        # Gradient magnitude. eps inside the root keeps its derivative finite
        # at zero; subtracting sqrt(eps) afterwards restores ~0 on flat regions.
        magnitude = torch.sqrt(g_x ** 2 + g_y ** 2 + self.eps)
        magnitude = (magnitude - self.eps ** 0.5).clamp_min(0)

        # Normalize by the global maximum, guarded against division by zero
        magnitude = magnitude / magnitude.max().clamp_min(self.eps)

        if had_no_channel_axis:
            magnitude = magnitude.squeeze(1)
        return magnitude

In [None]:
# Smoke test: a single optimisation step on a freshly initialised UNet using
# cross-entropy plus a weighted boundary term.
lambda_boundary = 0.5  # weight of the boundary term in the combined loss
cross_entropy = nn.CrossEntropyLoss()
boundary_loss = BoundaryLoss()

model = UNet(in_channels=3, out_channels=3, features=LAYERS).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Only the first batch of the validation loader is used (see the `break`).
for images, masks in val_loader:
    images, masks = images.float().to(DEVICE), masks.long().to(DEVICE)

    # Forward pass, then both loss terms on the same logits.
    logits = model(images)
    seg_loss = cross_entropy(logits, masks)
    bnd_loss = boundary_loss(logits, masks)
    total_loss = seg_loss + lambda_boundary * bnd_loss

    # Standard update: clear old gradients, back-propagate, step.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    break


In [None]:
# NOTE(review): with the lines below commented out, `preds` is not assigned
# anywhere in the visible code; this cell relies on `masks` and `preds` left
# in the kernel by earlier interactive runs.
# images = images.float().to(DEVICE)
# masks = masks.long().to(DEVICE)
# outputs = model(images)
# preds = torch.argmax(outputs, dim=1)

# Compute the performance metrics for the current batch and display them.
metrics = get_performance_metrics(masks, preds)
metrics

In [None]:
# NOTE(review): `old_metrics` is not assigned anywhere in the visible file;
# presumably a leftover from an earlier interactive session - running this
# cell in a fresh kernel raises NameError.
old_metrics