# Glaucoma Segmentation

**Author:** Bc. Ákos Kappel
**Year:** 2022 - 2024


In [None]:
import albumentations as A
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from albumentations.pytorch import ToTensorV2
from functools import partial

from modules import *
from networks import *
from training import *

## Config

In [None]:
# Main
# Selects which network is built and which training/evaluation mode is used.
NETWORK_NAME = 'refunet3+cbam'  # raunet++, refunet3+cbam, swinunet
ARCHITECTURE = 'cascade'  # multiclass, multilabel, binary, cascade, dual
USE_WANDB = False  # gates every wandb call in this notebook

# Dataset
IMAGE_SIZE = 256
BATCH_SIZE = 4
POLAR_TRANSFORM = False
_BASE_DIR = '../data'
# Datasets whose ROI folders are read; add 'ORIGA' to this list to include it.
_ENABLED_DATASETS = ['DRISHTI']


def _roi_dirs(subfolder):
    """Return the `subfolder` path inside the ROI directory of every enabled dataset."""
    return [f'{_BASE_DIR}/{dataset}/ROI/{subfolder}' for dataset in _ENABLED_DATASETS]


TRAIN_IMAGES_DIR = _roi_dirs('TrainImages')
TRAIN_MASKS_DIR = _roi_dirs('TrainMasks')
VAL_IMAGES_DIR = _roi_dirs('TestImages')
VAL_MASKS_DIR = _roi_dirs('TestMasks')

NUM_WORKERS = 0
_cuda_available = torch.cuda.is_available()
# Pinned memory is only requested when the run is on CUDA.
PIN_MEMORY = _cuda_available
DEVICE = torch.device('cuda' if _cuda_available else 'cpu')

# Model
IN_CHANNELS = 3
OUT_CHANNELS = 1
LEARNING_RATE = 1e-4
LAYERS = [16, 32, 48, 64, 80]  # passed to the networks as `features`
DROPOUT_2D = 0.2  # passed to the networks as `dropout`
CLASS_WEIGHTS = None  # forwarded to the loss as `class_weights`
# Checkpoint whose 'model' entry is loaded as the base model for the cascade.
BASE_CASCADE_MODEL = '../models/normal/binary-RefUnet3PlusCBAM-model.pth'
MULTI_SCALE_INPUT = False
DEEP_SUPERVISION = False

# Training
EPOCHS = 3
EARLY_STOPPING_PATIENCE = 11
LOG_INTERVAL = 10
SAVE_INTERVAL = 10
# Loss weights handed to `train` as od_loss_weight / oc_loss_weight.
OD_LOSS_WEIGHT = 1.0
OC_LOSS_WEIGHT = 5.0
LOGS_DIR = '../logs/'
CHECKPOINT_DIR = '../checkpoints/'
BINARY_LABELS = [1, 2]

# Extra notes
# Free-form labels: in this notebook they are only copied into the logged run
# config and select nothing, so they must be kept in sync with the setup by hand.
OPTIMIZER = 'adam'
LOSS_FUNCTION = 'combo'
SCHEDULER = 'plateau'
SCALER = 'none'
# Only the DRISHTI directories are enabled in the dataset config (the ORIGA
# entries are commented out), so the logged label must say DRISHTI.
DATASET = 'DRISHTI'

In [None]:
# Snapshot of the settings that describe this run; handed to wandb below.
config = dict(
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    input_channels=IN_CHANNELS,
    output_channels=OUT_CHANNELS,
    layers=LAYERS,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    mode=ARCHITECTURE,
    model=NETWORK_NAME,
    loss=LOSS_FUNCTION,
    optimizer=OPTIMIZER,
    scheduler=SCHEDULER,
    epochs=EPOCHS,
    class_weights=CLASS_WEIGHTS,
    dropout=DROPOUT_2D,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    dataset=DATASET,
    polar_transform=POLAR_TRANSFORM,
    multi_scale_input=MULTI_SCALE_INPUT,
    deep_supervision=DEEP_SUPERVISION,
)

if USE_WANDB:
    wandb.login()
    # To continue an earlier run instead of starting a new one, use:
    #     wandb.init(project='DP-Glaucoma', config=config, resume=True, id='')
    wandb.init(project='DP-Glaucoma', config=config)

## Dataset

In [None]:
polar_transform_partial = partial(polar_transform, radius_ratio=1.0)


def _resize_step():
    """Return a fresh resize transform to IMAGE_SIZE x IMAGE_SIZE."""
    return A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE, interpolation=cv.INTER_AREA)


def _polar_step():
    """Return the polar transform for image and mask, or an empty Lambda when disabled."""
    if POLAR_TRANSFORM:
        return A.Lambda(image=polar_transform_partial, mask=polar_transform_partial)
    return A.Lambda()


def _final_steps():
    """Return fresh instances of the steps that end both pipelines."""
    return [
        A.Lambda(image=sharpen, p=1.0),
        _polar_step(),
        A.Normalize(),
        ToTensorV2(),
    ]


# Training pipeline: resize, random augmentations, then the shared final steps.
train_transform = A.Compose([
    _resize_step(),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=1.0),
    A.RandomBrightnessContrast(p=0.5),
    A.RandomToneCurve(p=0.5),
    A.MultiplicativeNoise(p=0.5),
    *_final_steps(),
])

# Validation pipeline: same as training but without the random augmentations.
val_transform = A.Compose([
    _resize_step(),
    *_final_steps(),
])

# Positional loader settings shared by both splits.
_loader_settings = (BATCH_SIZE, NUM_WORKERS, PIN_MEMORY)

train_loader = load_dataset(
    TRAIN_IMAGES_DIR, TRAIN_MASKS_DIR, train_transform, *_loader_settings, shuffle=True,
)
val_loader = load_dataset(
    VAL_IMAGES_DIR, VAL_MASKS_DIR, val_transform, *_loader_settings, shuffle=False,
)

# Preview up to four samples of the first training batch:
# images on the top row, their masks on the bottom row.
images, masks = next(iter(train_loader))
images = images[:4].float()
masks = masks[:4].long()
fig, ax = plt.subplots(2, 4, figsize=(12, 6))
ax = ax.flatten()
image_axes, mask_axes = ax[:4], ax[4:]
for image_axis, mask_axis, image, mask in zip(image_axes, mask_axes, images, masks):
    # Rescale in place to [0, 1] so imshow can display the normalised image.
    image.sub_(image.min())
    image.div_(image.max())
    image_axis.imshow(image.permute(1, 2, 0))
    mask_axis.imshow(mask)
plt.tight_layout()
plt.show()

## Model

In [None]:
model = None
binary_model = None
hist = None

# The dual architecture uses the 'dual-' variant of the selected network.
if ARCHITECTURE == 'dual' and 'dual-' not in NETWORK_NAME:
    NETWORK_NAME = 'dual-' + NETWORK_NAME

# Keyword arguments shared by the U-Net style networks.
_unet_kwargs = dict(
    in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, features=LAYERS,
    multi_scale_input=MULTI_SCALE_INPUT, dropout=DROPOUT_2D,
)
# NOTE(review): img_size is hard-coded to 224 rather than taken from IMAGE_SIZE
# -- TODO confirm this is intended for the Swin variants.
_swin_kwargs = dict(
    in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, img_size=224, patch_size=4,
)

# Network name -> zero-argument factory. The factories are lazy so that only the
# selected network is ever instantiated.
_MODEL_FACTORIES = {
    'raunet++': lambda: RAUnetPlusPlus(deep_supervision=DEEP_SUPERVISION, **_unet_kwargs),
    'refunet3+cbam': lambda: RefUnet3PlusCBAM(**_unet_kwargs),
    'swinunet': lambda: SwinUnet(**_swin_kwargs),
    'dual-raunet++': lambda: DualRAUnetPlusPlus(deep_supervision=DEEP_SUPERVISION, **_unet_kwargs),
    'dual-refunet3+cbam': lambda: DualRefUnet3PlusCBAM(**_unet_kwargs),
    'dual-swinunet': lambda: DualSwinUnet(**_swin_kwargs),
}

# Validate with an explicit exception: an `assert` would be stripped under
# `python -O` and let an unknown name fail later with a less helpful error.
if NETWORK_NAME not in _MODEL_FACTORIES:
    raise ValueError(
        f'Invalid network name: {NETWORK_NAME} '
        f'(expected one of: {", ".join(_MODEL_FACTORIES)})'
    )

model = _MODEL_FACTORIES[NETWORK_NAME]()

model = model.to(DEVICE)
init_model_weights(model)

# --- Optimizer (alternatives kept for quick switching) ---
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
)
# optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
# optimizer = optim.RMSprop(model.parameters(), lr=LEARNING_RATE)
# optimizer = optim.Adadelta(model.parameters(), lr=LEARNING_RATE)
# optimizer = optim.Adagrad(model.parameters(), lr=LEARNING_RATE)
# optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

# --- Loss (alternatives kept for quick switching) ---
# Only the multiclass architecture uses OUT_CHANNELS classes; all others use one.
if ARCHITECTURE == 'multiclass':
    num_classes = OUT_CHANNELS
else:
    num_classes = 1
criterion = ComboLoss(
    num_classes=num_classes,
    class_weights=CLASS_WEIGHTS,
)
# criterion = DiceLoss(num_classes=num_classes, class_weights=CLASS_WEIGHTS)
# criterion = GeneralizedDice(num_classes=num_classes, class_weights=CLASS_WEIGHTS)
# criterion = IoULoss(num_classes=num_classes, class_weights=CLASS_WEIGHTS)
# criterion = FocalLoss(num_classes=num_classes)
# criterion = TverskyLoss(num_classes=num_classes, class_weights=CLASS_WEIGHTS, alpha=0.7, beta=0.3)
# criterion = FocalTverskyLoss(num_classes=num_classes, class_weights=CLASS_WEIGHTS, alpha=0.3, beta=0.7)
# criterion = BoundaryLoss(num_classes=num_classes, class_weights=CLASS_WEIGHTS)
# criterion = HausdorffLoss(num_classes=num_classes, class_weights=CLASS_WEIGHTS)
# criterion = EdgeLoss(num_classes=num_classes, class_weights=CLASS_WEIGHTS)
# criterion = CrossEntropyLoss(num_classes=num_classes)
# criterion = SensitivitySpecificityLoss(num_classes=num_classes, class_weights=CLASS_WEIGHTS, alpha=1.0, beta=1.0)
# criterion = CompositeLoss([
#     ComboLoss(num_classes=num_classes, class_weights=CLASS_WEIGHTS),
#     FocalTverskyLoss(num_classes=num_classes, class_weights=CLASS_WEIGHTS),
#     BoundaryLoss(num_classes=num_classes, class_weights=CLASS_WEIGHTS),
# ], weights=[0.5, 1.5, 1.0])

# --- Learning-rate scheduler and (unused) gradient scaler ---
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.1,
    patience=5,
    verbose=True,
)
# scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
#     optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1, verbose=True
# )
scaler = None

# --- Base model for the cascade ---
# Loaded whenever a checkpoint path is configured; the checkpoint's 'model'
# entry is used as the base (binary) model.
if BASE_CASCADE_MODEL:
    checkpoint = load_checkpoint(
        BASE_CASCADE_MODEL,
        map_location=DEVICE,
    )
    binary_model = checkpoint['model']

## Training

In [None]:
# Make sure the output directories exist before training writes into them.
for _output_dir in (LOGS_DIR, CHECKPOINT_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# The inverse transform is only supplied when polar inputs are enabled.
_inverse_transform = undo_polar_transform if POLAR_TRANSFORM else None

hist = train(
    ARCHITECTURE, model, criterion, optimizer, EPOCHS, DEVICE,
    train_loader, val_loader, scheduler, scaler,
    # binary
    binary_labels=BINARY_LABELS,
    # cascade
    binary_model=binary_model,
    inter_processing=interprocess,
    # dual
    od_loss_weight=OD_LOSS_WEIGHT,
    oc_loss_weight=OC_LOSS_WEIGHT,
    # checkpointing and early stopping
    save_interval=SAVE_INTERVAL,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    checkpoint_dir=CHECKPOINT_DIR,
    save_best_model=True,
    # logging and plotting
    log_to_wandb=USE_WANDB,
    log_dir=LOGS_DIR,
    log_interval=LOG_INTERVAL,
    plot_examples='none',
    show_plots=False,
    inverse_transform=_inverse_transform,
)

In [None]:
# Visualise the training history returned by `train` in the training cell.
plot_history(hist, figsize=(14, 12))

## Testing

In [None]:
# The inverse transform is only supplied when polar inputs are enabled.
_inverse_transform = undo_polar_transform if POLAR_TRANSFORM else None

# Evaluate the trained model on the validation loader with the training loss.
results = evaluate(
    ARCHITECTURE,
    model,
    val_loader,
    DEVICE,
    criterion,
    binary_labels=BINARY_LABELS,
    base_model=binary_model,
    inverse_transform=_inverse_transform,
    inter_process_fn=interprocess,
    post_process_fn=postprocess,
    tta=False,
)

In [None]:
# Plot predictions for a few validation samples and save the figure in LOGS_DIR.
# os.path.join is used because LOGS_DIR already ends with a separator, so plain
# string formatting would produce a doubled slash in the saved path.
plot_results_from_loader(
    ARCHITECTURE, val_loader, model, DEVICE,
    n_samples=4, save_path=os.path.join(LOGS_DIR, 'evaluation.png'),
    base_model=binary_model, binary_labels=BINARY_LABELS,
    # The inverse transform is only supplied when polar inputs are enabled.
    inverse_transform=undo_polar_transform if POLAR_TRANSFORM else None,
    inter_process_fn=interprocess, post_process_fn=postprocess, tta=False,
)

In [None]:
# Finish the wandb run; only relevant when wandb logging was enabled above.
if USE_WANDB:
    wandb.finish()