# Glaucoma Segmentation


## Imports

In [None]:
import albumentations as A
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from functools import partial
from tqdm import tqdm

from networks import *
from training import *
from utils import *

# prepare_origa_dataset('../data/ORIGA')
# prepare_drishti_dataset('../data/DRISHTI')
# prepare_rimone_dataset('../data/RIMONE', 1.0)

## Setup

In [None]:
# Data and output locations (relative to the notebook's working directory).
IMAGE_DIR = '../data/Kaggle-ORIGA/Images_CenterNet_Cropped'
MASK_DIR = '../data/Kaggle-ORIGA/Masks_CenterNet_Cropped'
LOGS_DIR = '../logs/'
CHECKPOINT_DIR = '../checkpoints/'

# Experiment selection.
NETWORK_NAME = 'refunet3+cbam'  # raunet++, refunet3+cbam, swinunet
OPTIMIZER = 'adam'
LOSS_FUNCTION = 'combo'
ARCHITECTURE = 'binary'  # multiclass, multilabel, binary, dual, cascade
BINARY_TARGET_CLASSES = [1, 2]  # forwarded as target_ids / class_ids for the binary architecture
SCHEDULER = 'plateau'
SCALER = 'none'
DATASET = 'ORIGA'
BASE_CASCADE_MODEL = ''  # checkpoint path; empty string means no base model is loaded

# Model / training hyper-parameters.
IMAGE_HEIGHT, IMAGE_WIDTH = 128, 128
# IMAGE_HEIGHT, IMAGE_WIDTH = 224, 224
IN_CHANNELS, OUT_CHANNELS = 3, 1
LEARNING_RATE = 1e-4
BATCH_SIZE = 4
EPOCHS = 5
LAYERS = [16, 32, 48, 64, 80]
CLASS_WEIGHTS = None
SET_SIZES = [0.8, 0.1, 0.1]  # train / validation / test proportions
DROPOUT_2D = 0.2
EARLY_STOPPING_PATIENCE = 11
LOG_INTERVAL = 5
SAVE_INTERVAL = 10
NUM_WORKERS = 0

USE_WANDB = False
POLAR_TRANSFORM = True
MULTI_SCALE_INPUT = False
DEEP_SUPERVISION = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Compare the device *type*: a torch.device never compares equal to a plain
# string, so `DEVICE == 'cuda'` was always False and pinning was never enabled.
PIN_MEMORY = DEVICE.type == 'cuda'

# Post-processing steps applied in order to predictions. In this notebook they
# are only consumed by the cascade training call.
POSTPROCESSING = [
    to_numpy,
    unpack,
    lambda x: fill_holes(x, binary=True),
    lambda x: keep_largest_component(x, binary=True),
    lambda x: dilate(x, kernel_size=5, iterations=1),
    pack,
    to_tensor,
]

os.makedirs(LOGS_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f'''CONFIGURATION:
    PyTorch version: {torch.__version__}
    NumPy version: {np.__version__}
    OpenCV version: {cv.__version__}

    Network: {NETWORK_NAME}
    Architecture: {ARCHITECTURE}
    Optimizer: {OPTIMIZER}
    Loss function: {LOSS_FUNCTION}
    Scheduler: {SCHEDULER}
    Scaler: {SCALER}
    Using device: {DEVICE}

    Dataset: {DATASET}
    Dataset proportions: {SET_SIZES}
    Image directory: {IMAGE_DIR}
    Mask directory: {MASK_DIR}
    Input image height & width: {IMAGE_HEIGHT}x{IMAGE_WIDTH}
    Number of input channels: {IN_CHANNELS}
    Number of output channels: {OUT_CHANNELS}

    Layers: {LAYERS}
    Batch size: {BATCH_SIZE}
    Learning rate: {LEARNING_RATE}
    Epochs: {EPOCHS}
    Class weights: {CLASS_WEIGHTS}
    Dropout: {DROPOUT_2D}
    Early stopping patience: {EARLY_STOPPING_PATIENCE}

    Save interval: {SAVE_INTERVAL}
    Log interval: {LOG_INTERVAL}
    Number of workers: {NUM_WORKERS}
    Pin memory: {PIN_MEMORY}

    Weight & Biases: {USE_WANDB}
    Polar transform: {POLAR_TRANSFORM}
    Multi-scale input: {MULTI_SCALE_INPUT}
    Deep supervision: {DEEP_SUPERVISION}''')

## Dataset

In [None]:
polar_transform_partial = partial(polar_transform, radius_ratio=1.0)


def _polar_steps():
    """Return the optional polar-transform step as a (possibly empty) list.

    The same transform is applied to image and mask so they stay aligned.
    A fresh ``A.Lambda`` is created per call so the train and validation
    pipelines do not share a transform instance.
    """
    if not POLAR_TRANSFORM:
        return []
    return [A.Lambda(image=polar_transform_partial, mask=polar_transform_partial)]


# Training pipeline: resize, random flips/rotations, contrast enhancement.
# CLAHE uses p=1.0 (always applied); the deprecated `always_apply=True` flag
# was redundant with that and has been dropped.
train_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, interpolation=cv.INTER_AREA),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=1.0),
    A.CLAHE(p=1.0, clip_limit=2.0, tile_grid_size=(8, 8)),
    A.RandomBrightnessContrast(p=0.5),
    # A.GridDistortion(p=0.5, border_mode=cv.BORDER_CONSTANT),
    # A.MedianBlur(p=0.5),
    # A.RandomToneCurve(p=0.5),
    # A.MultiplicativeNoise(p=0.5),
    # A.Lambda(image=sharpen, p=1.0),
    *_polar_steps(),
    # A.Lambda(image=keep_gray_channel),
    # A.Lambda(image=keep_red_channel),
    # A.Lambda(image=keep_green_channel),
    # A.Lambda(image=keep_blue_channel),
    ToTensorV2(),
])

# Validation/test pipeline: deterministic preprocessing only (no augmentation).
val_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, interpolation=cv.INTER_AREA),
    A.CLAHE(p=1.0, clip_limit=2.0, tile_grid_size=(8, 8)),
    *_polar_steps(),
    # A.Lambda(image=keep_gray_channel),
    # A.Lambda(image=keep_red_channel),
    # A.Lambda(image=keep_green_channel),
    # A.Lambda(image=keep_blue_channel),
    ToTensorV2(),
])

train_loader, val_loader, test_loader = load_dataset(
    [IMAGE_DIR], [MASK_DIR],
    None, None, *SET_SIZES,
    train_transform, val_transform, val_transform,
    BATCH_SIZE, PIN_MEMORY, NUM_WORKERS,
)

## Model

In [None]:
# Build the network selected by NETWORK_NAME, plus its optimizer, loss and scheduler.
model = None
binary_model = None  # base model for the cascade architecture, loaded from BASE_CASCADE_MODEL
hist = None  # training history, assigned by the training cell

# Only these choices are implemented below. Fail loudly instead of silently
# ignoring a setting that the configuration printout reports as active.
if (OPTIMIZER, LOSS_FUNCTION, SCHEDULER, SCALER) != ('adam', 'combo', 'plateau', 'none'):
    raise ValueError(
        'Unsupported configuration: only OPTIMIZER="adam", LOSS_FUNCTION="combo", '
        'SCHEDULER="plateau" and SCALER="none" are implemented; got '
        f'{OPTIMIZER!r}, {LOSS_FUNCTION!r}, {SCHEDULER!r}, {SCALER!r}'
    )

# NOTE(review): the Swin variants use a fixed img_size=224 that does not follow
# IMAGE_HEIGHT/IMAGE_WIDTH (currently 128) -- confirm inputs are 224x224 for them.
if NETWORK_NAME == 'raunet++':
    model = RAUnetPlusPlus(
        in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, features=LAYERS,
        multi_scale_input=MULTI_SCALE_INPUT, deep_supervision=DEEP_SUPERVISION, dropout=DROPOUT_2D,
    )
elif NETWORK_NAME == 'refunet3+cbam':
    model = RefUnet3PlusCBAM(
        in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, features=LAYERS,
        multi_scale_input=MULTI_SCALE_INPUT, dropout=DROPOUT_2D,
    )
elif NETWORK_NAME == 'swinunet':
    model = SwinUnet(
        in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, img_size=224, patch_size=4,
    )
elif NETWORK_NAME == 'dual-raunet++':
    model = DualRAUnetPlusPlus(
        in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, features=LAYERS,
        multi_scale_input=MULTI_SCALE_INPUT, deep_supervision=DEEP_SUPERVISION, dropout=DROPOUT_2D,
    )
elif NETWORK_NAME == 'dual-refunet3+cbam':
    model = DualRefUnet3PlusCBAM(
        in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, features=LAYERS,
        multi_scale_input=MULTI_SCALE_INPUT, dropout=DROPOUT_2D,
    )
elif NETWORK_NAME == 'dual-swinunet':
    model = DualSwinUnet(
        in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, img_size=224, patch_size=4,
    )
else:
    # Raised explicitly (not via assert) so the check survives `python -O`.
    raise ValueError(f'Invalid network name: {NETWORK_NAME!r}')

model = model.to(DEVICE)
init_model_weights(model)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# The loss sees OUT_CHANNELS classes only for multiclass; otherwise a single channel.
criterion = ComboLoss(num_classes=OUT_CHANNELS if ARCHITECTURE == 'multiclass' else 1, class_weights=CLASS_WEIGHTS)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5, verbose=True)
scaler = None  # SCALER == 'none': no gradient scaler

if BASE_CASCADE_MODEL:
    checkpoint = load_checkpoint(BASE_CASCADE_MODEL)
    binary_model = checkpoint['model']

## Training

In [None]:
# Train with the routine matching ARCHITECTURE; the returned history is stored in `hist`.

# Positional arguments shared by every train_* routine (after the model/criterion part).
shared_args = (optimizer, EPOCHS, DEVICE, train_loader, val_loader, scheduler, scaler)
# Keyword arguments shared by every train_* routine.
shared_kwargs = dict(
    save_interval=SAVE_INTERVAL, early_stopping_patience=EARLY_STOPPING_PATIENCE,
    log_to_wandb=USE_WANDB, log_dir=LOGS_DIR, log_interval=LOG_INTERVAL, checkpoint_dir=CHECKPOINT_DIR,
    save_best_model=True, plot_examples='none', show_plots=False,
    inverse_transform=undo_polar_transform if POLAR_TRANSFORM else None,
)

if ARCHITECTURE == 'multiclass':
    hist = train_multiclass(model, criterion, *shared_args, **shared_kwargs)
elif ARCHITECTURE == 'multilabel':
    hist = train_multilabel(model, criterion, *shared_args, **shared_kwargs)
elif ARCHITECTURE == 'binary':
    hist = train_binary(
        model, criterion, *shared_args, target_ids=BINARY_TARGET_CLASSES, **shared_kwargs,
    )
elif ARCHITECTURE == 'cascade':
    # Cascade training needs the pre-trained base model loaded in the model cell.
    if binary_model is None:
        raise ValueError('Base model not specified')
    hist = train_cascade(
        binary_model, model, criterion, *shared_args, postprocess=POSTPROCESSING, **shared_kwargs,
    )
elif ARCHITECTURE == 'dual':
    # The same criterion is passed twice, once per output of the dual routine.
    hist = train_dual(model, criterion, criterion, *shared_args, **shared_kwargs)
else:
    # Previously an unknown architecture silently left hist = None.
    raise ValueError(f'Invalid architecture: {ARCHITECTURE!r}')


In [None]:
# Plot the training history returned by the train_* call in the training cell.
# NOTE(review): hist is initialised to None in the model cell -- confirm how
# plot_history behaves if training did not assign it.
plot_history(hist)

## Testing

In [None]:
# Evaluate the trained model on the test split. The binary mode also needs the
# target class ids, and the cascade mode also needs the base model.
if ARCHITECTURE == 'multiclass':
    results = evaluate('multiclass', model, test_loader, criterion, DEVICE)
elif ARCHITECTURE == 'multilabel':
    results = evaluate('multilabel', model, test_loader, criterion, DEVICE)
elif ARCHITECTURE == 'binary':
    results = evaluate('binary', model, test_loader, criterion, DEVICE, class_ids=BINARY_TARGET_CLASSES)
elif ARCHITECTURE == 'dual':
    results = evaluate('dual', model, test_loader, criterion, DEVICE)
elif ARCHITECTURE == 'cascade':
    results = evaluate('cascade', model, test_loader, criterion, DEVICE, model0=binary_model)
else:
    # Previously an unknown architecture left `results` undefined (or stale).
    raise ValueError(f'Invalid architecture: {ARCHITECTURE!r}')

In [None]:
# Save a figure with 4 test-set predictions. Every architecture uses the same
# call; only 'binary' and 'cascade' need an extra keyword argument.
_extra_plot_kwargs = {
    'multiclass': {},
    'multilabel': {},
    'binary': {'class_ids': BINARY_TARGET_CLASSES},
    'dual': {},
    'cascade': {'model0': binary_model},
}

# Unknown architectures are skipped, exactly as before.
if ARCHITECTURE in _extra_plot_kwargs:
    plot_results_from_loader(
        ARCHITECTURE, test_loader, model, DEVICE,
        n_samples=4,
        save_path=f'{LOGS_DIR}/evaluation.png',
        **_extra_plot_kwargs[ARCHITECTURE],
    )

## Work in progress

In [None]:
# TODO: experimental Swin-UNet setup. This cell rebinds model, optimizer,
# criterion and scheduler built earlier (AdamW + cosine warm restarts).
# NOTE(review): img_size=224 while IMAGE_HEIGHT/IMAGE_WIDTH are 128 -- confirm
# the loaders produce 224x224 inputs before running this model.
WEIGHT_DECAY = 1e-4

model = SwinUnet(in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, img_size=224, patch_size=4)
model = model.to(DEVICE)

optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
criterion = ComboLoss(OUT_CHANNELS)

# Restart the cosine schedule every 10 scheduler steps (T_mult=1), annealing towards 1e-6.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer=optimizer,
    T_0=10,
    T_mult=1,
    eta_min=1e-6,
    last_epoch=-1,
    verbose=True,
)



In [None]:
NUM_EPOCHS = 1
SWIN_SAVE_PATH = '.models/swinunet.pth'


def _binary_batch(inputs, labels):
    """Move one batch to DEVICE and collapse every non-zero label to 1.

    Returns float inputs and long labels, which is what the loop below feeds
    into ``model`` and ``criterion``.
    """
    inputs, labels = inputs.float().to(DEVICE), labels.long().to(DEVICE)
    labels[labels > 0] = 1
    return inputs, labels


# Loop over epochs
for epoch in range(NUM_EPOCHS):
    # Training phase
    model.train()
    train_loss = 0.0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = _binary_batch(inputs, labels)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is per sample.
        train_loss += loss.item() * inputs.size(0)

    train_loss = train_loss / len(train_loader.dataset)

    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():  # no gradients needed for validation
        for inputs, labels in tqdm(val_loader):
            inputs, labels = _binary_batch(inputs, labels)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)

    val_loss = val_loss / len(val_loader.dataset)

    print(f'Epoch {epoch + 1}/{NUM_EPOCHS}')
    print(f'Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

    # One scheduler step per epoch
    scheduler.step()

# torch.save does not create missing parent directories, and nothing else in
# this notebook creates '.models', so make sure it exists before saving.
# NOTE(review): '.models' is a hidden folder next to the notebook, not
# CHECKPOINT_DIR -- confirm this location is intended.
os.makedirs(os.path.dirname(SWIN_SAVE_PATH), exist_ok=True)
torch.save(model, SWIN_SAVE_PATH)

In [None]:
# Show some examples: predictions for the first test batch, one figure per sample
# (input image, binarized ground-truth mask, raw model output).
model.eval()
with torch.no_grad():
    batch = next(iter(test_loader), None)
    if batch is None:
        # Previously an empty loader left images/masks/preds undefined.
        raise RuntimeError('test_loader is empty; nothing to show')
    inputs, labels = batch
    inputs, labels = inputs.float().to(DEVICE), labels.long().to(DEVICE)

    # Binarize labels (every non-zero class becomes foreground)
    labels[labels > 0] = 1

    outputs = model(inputs)

    # Channels-first -> channels-last for matplotlib.
    images = inputs.cpu().numpy().transpose(0, 2, 3, 1)
    masks = labels.cpu().numpy()
    # Drop only the channel axis. Tensor.squeeze(1) is a no-op when that axis
    # is not size 1, and unlike .squeeze() it never drops the batch axis when
    # the batch holds a single sample.
    preds = outputs.squeeze(1).cpu().numpy()

for img, mask, pred in zip(images, masks, preds):
    peak = img.max()
    # Rescale to 0-255 for display; skip when the image is all zeros (0/0 -> NaN).
    if peak > 0:
        img = img / peak * 255
    img = img.astype(np.uint8)

    plt.figure(figsize=(16, 8))
    plt.subplot(1, 3, 1)
    plt.imshow(img)
    plt.subplot(1, 3, 2)
    plt.imshow(mask)
    plt.subplot(1, 3, 3)
    plt.imshow(pred)
    plt.show()