# Glaucoma Segmentation


## Imports

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2 as cv
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from networks import *
from training import *
from utils import *

## Setup

In [None]:
# --- Data and output locations ------------------------------------------------
DATASET = 'ORIGA'
IMAGE_DIR = '../data/ORIGA/Images_CenterNet_Cropped'
MASK_DIR = '../data/ORIGA/Masks_CenterNet_Cropped'
LOGS_DIR = '../logs/'
CHECKPOINT_DIR = '../checkpoints/'

# --- Experiment selection -----------------------------------------------------
# Network built in the Model cell: unet, raunet++, refunet3+cbam, swinunet,
# or any of those prefixed with "dual-".
NETWORK_NAME = 'refunet3+cbam'
# Selects the train/evaluate routine: multiclass, multilabel, binary, dual, cascade.
ARCHITECTURE = 'binary'
# Forwarded as target_ids / class_ids to the binary train, evaluate and plot calls.
BINARY_TARGET_CLASSES = [1, 2]
# Checkpoint handed to load_checkpoint in the Model cell when non-empty.
BASE_CASCADE_MODEL = ''
# NOTE(review): in the cells visible here these four labels are only echoed in
# the configuration printout; the Model cell hard-codes its optimizer, loss and
# scheduler.
OPTIMIZER = 'adam'
LOSS_FUNCTION = 'combo'
SCHEDULER = 'plateau'
SCALER = 'none'

# --- Tensor geometry ----------------------------------------------------------
IMAGE_HEIGHT = 128
IMAGE_WIDTH = 128
IN_CHANNELS = 3
OUT_CHANNELS = 1

# --- Optimisation hyper-parameters --------------------------------------------
LEARNING_RATE = 1e-4
BATCH_SIZE = 4
EPOCHS = 5
LAYERS = [16, 32, 48, 64, 80]
CLASS_WEIGHTS = None
DROPOUT_2D = 0.2
EARLY_STOPPING_PATIENCE = 11
# Unpacked positionally into load_dataset; presumably train/val/test fractions.
SET_SIZES = [0.8, 0.1, 0.1]

# --- Logging, checkpointing and data loading ----------------------------------
LOG_INTERVAL = 5
SAVE_INTERVAL = 10
NUM_WORKERS = 0

# --- Feature switches ---------------------------------------------------------
USE_WANDB = False
POLAR_TRANSFORM = True
MULTI_SCALE_INPUT = False
DEEP_SUPERVISION = False

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Pinned host memory only speeds up host-to-GPU copies. Compare the device
# *type*: a torch.device object does not compare equal to the plain string
# 'cuda', so testing `DEVICE == 'cuda'` left this flag False even on a GPU.
PIN_MEMORY = DEVICE.type == 'cuda'

# Create the output folders up front; exist_ok=True makes re-running the cell harmless.
os.makedirs(LOGS_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Echo library versions and every configuration value so the run can be
# identified from the notebook output alone.
print(f'''CONFIGURATION:
    PyTorch version: {torch.__version__}
    NumPy version: {np.__version__}
    OpenCV version: {cv.__version__}

    Network: {NETWORK_NAME}
    Architecture: {ARCHITECTURE}
    Optimizer: {OPTIMIZER}
    Loss function: {LOSS_FUNCTION}
    Scheduler: {SCHEDULER}
    Scaler: {SCALER}
    Using device: {DEVICE}

    Dataset: {DATASET}
    Dataset proportions: {SET_SIZES}
    Image directory: {IMAGE_DIR}
    Mask directory: {MASK_DIR}
    Input image height & width: {IMAGE_HEIGHT}x{IMAGE_WIDTH}
    Number of input channels: {IN_CHANNELS}
    Number of output channels: {OUT_CHANNELS}

    Layers: {LAYERS}
    Batch size: {BATCH_SIZE}
    Learning rate: {LEARNING_RATE}
    Epochs: {EPOCHS}
    Class weights: {CLASS_WEIGHTS}
    Dropout: {DROPOUT_2D}
    Early stopping patience: {EARLY_STOPPING_PATIENCE}

    Save interval: {SAVE_INTERVAL}
    Log interval: {LOG_INTERVAL}
    Number of workers: {NUM_WORKERS}
    Pin memory: {PIN_MEMORY}

    Weight & Biases: {USE_WANDB}
    Polar transform: {POLAR_TRANSFORM}
    Multi-scale input: {MULTI_SCALE_INPUT}
    Deep supervision: {DEEP_SUPERVISION}''')

## Dataset

In [None]:
# Bind the radius ratio once so the very same callable can be given to
# albumentations for the image and for its mask.
polar_transform_partial = partial(polar_transform, radius_ratio=1.0)

# Training pipeline: resize, random flips/rotation, CLAHE, brightness/contrast
# jitter, then the optional polar step and conversion to a tensor.
_train_steps = [
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, interpolation=cv.INTER_AREA),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=1.0),
    A.CLAHE(p=1.0, clip_limit=2.0, tile_grid_size=(8, 8), always_apply=True),
    A.RandomBrightnessContrast(p=0.5),
    # Optional augmentations, currently disabled:
    # A.GridDistortion(p=0.5, border_mode=cv.BORDER_CONSTANT),
    # A.MedianBlur(p=0.5),
    # A.RandomToneCurve(p=0.5),
    # A.MultiplicativeNoise(p=0.5),
    # A.Lambda(image=sharpen, p=1.0),
]
if POLAR_TRANSFORM:
    _train_steps.append(A.Lambda(image=polar_transform_partial, mask=polar_transform_partial))
else:
    # An argument-less Lambda keeps the pipeline length fixed.
    _train_steps.append(A.Lambda())
# Optional single-channel selection, currently disabled:
# _train_steps.append(A.Lambda(image=keep_gray_channel))
# _train_steps.append(A.Lambda(image=keep_red_channel))
# _train_steps.append(A.Lambda(image=keep_green_channel))
# _train_steps.append(A.Lambda(image=keep_blue_channel))
_train_steps.append(ToTensorV2())
train_transform = A.Compose(_train_steps)

# Validation/test pipeline: the deterministic subset of the steps above.
_val_steps = [
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, interpolation=cv.INTER_AREA),
    A.CLAHE(p=1.0, clip_limit=2.0, tile_grid_size=(8, 8), always_apply=True),
]
if POLAR_TRANSFORM:
    _val_steps.append(A.Lambda(image=polar_transform_partial, mask=polar_transform_partial))
else:
    _val_steps.append(A.Lambda())
# _val_steps.append(A.Lambda(image=keep_gray_channel))
# _val_steps.append(A.Lambda(image=keep_red_channel))
# _val_steps.append(A.Lambda(image=keep_green_channel))
# _val_steps.append(A.Lambda(image=keep_blue_channel))
_val_steps.append(ToTensorV2())
val_transform = A.Compose(_val_steps)

# The validation transform is used for both the validation and the test split.
train_loader, val_loader, test_loader = load_dataset(
    IMAGE_DIR,
    MASK_DIR,
    None,
    None,
    *SET_SIZES,
    train_transform,
    val_transform,
    val_transform,
    BATCH_SIZE,
    PIN_MEMORY,
    NUM_WORKERS,
)

In [None]:
# Sanity check: draw one training batch and show its first image next to the
# matching mask.
images, masks = next(iter(train_loader))
image = images[0].permute(1, 2, 0).numpy()  # channel axis moved last for imshow
mask = masks[0].numpy()

_, ax = plt.subplots(1, 2, figsize=(8, 4))
for axis, picture in zip(ax, (image, mask)):
    axis.imshow(picture)
plt.show()

## Model

In [None]:
model = None
binary_model = None
hist = None

# Keyword sets shared by the network families.
_io_kwargs = dict(in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS)
_unet_kwargs = dict(_io_kwargs, features=LAYERS, multi_scale_input=MULTI_SCALE_INPUT)
_raunet_kwargs = dict(_unet_kwargs, deep_supervision=DEEP_SUPERVISION, dropout=DROPOUT_2D)
_refunet_kwargs = dict(_unet_kwargs, dropout=DROPOUT_2D)
# NOTE(review): img_size is fixed at 224 while IMAGE_HEIGHT/IMAGE_WIDTH are
# configured separately - confirm the Swin variants get inputs of this size.
_swin_kwargs = dict(_io_kwargs, img_size=224, patch_size=4)

# Lazy factories: only the selected network class is looked up and built.
_model_factories = {
    'unet': lambda: Unet(**_unet_kwargs),
    'raunet++': lambda: RAUnetPlusPlus(**_raunet_kwargs),
    'refunet3+cbam': lambda: RefUnet3PlusCBAM(**_refunet_kwargs),
    'swinunet': lambda: SwinUnet(**_swin_kwargs),
    'dual-unet': lambda: DualUnet(**_unet_kwargs),
    'dual-raunet++': lambda: DualRAUnetPlusPlus(**_raunet_kwargs),
    'dual-refunet3+cbam': lambda: DualRefUnet3PlusCBAM(**_refunet_kwargs),
    'dual-swinunet': lambda: DualSwinUnet(**_swin_kwargs),
}

_factory = _model_factories.get(NETWORK_NAME)
if _factory is not None:
    model = _factory()

assert model is not None, 'Invalid network name'

model = model.to(DEVICE)
init_model_weights(model)

# Adam + ComboLoss + ReduceLROnPlateau are fixed here regardless of the
# OPTIMIZER / LOSS_FUNCTION / SCHEDULER labels in the Setup cell.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
_loss_classes = OUT_CHANNELS if ARCHITECTURE == 'multiclass' else 1
criterion = ComboLoss(num_classes=_loss_classes, class_weights=CLASS_WEIGHTS)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5, verbose=True)
scaler = None

# Optional first-stage model; the cascade training branch requires it.
if BASE_CASCADE_MODEL:
    checkpoint = load_checkpoint(BASE_CASCADE_MODEL)
    binary_model = checkpoint['model']
    print('Loaded base model')

In [None]:
# WEIGHT_DECAY = 1e-4
# 
# model = SwinUnet(in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, img_size=224, patch_size=4).to(DEVICE)
# optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# criterion = nn.CrossEntropyLoss()
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
#     optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1, verbose=True)

In [None]:
# hist = train_multiclass(
#     model, criterion, optimizer, EPOCHS, DEVICE, train_loader, val_loader, scheduler,
#     save_interval=SAVE_INTERVAL, early_stopping_patience=EARLY_STOPPING_PATIENCE,
#     log_to_wandb=USE_WANDB, log_dir=LOGS_DIR, log_interval=LOG_INTERVAL, checkpoint_dir=CHECKPOINT_DIR,
#     save_best_model=True, plot_examples='all', show_plots=True,
# )
# plot_history(hist)

## Training

In [None]:
# Positional arguments every train_* routine receives after the model(s).
_train_args = (criterion, optimizer, EPOCHS, DEVICE, train_loader, val_loader, scheduler, scaler)

# Keyword arguments common to every train_* routine.
_train_kwargs = dict(
    save_interval=SAVE_INTERVAL,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    log_to_wandb=USE_WANDB,
    log_dir=LOGS_DIR,
    log_interval=LOG_INTERVAL,
    checkpoint_dir=CHECKPOINT_DIR,
    save_best_model=True,
    plot_examples='none',
    show_plots=False,
    inverse_transform=undo_polar_transform if POLAR_TRANSFORM else None,
)

# Pick the routine matching ARCHITECTURE; hist keeps its previous value when
# no branch matches.
if ARCHITECTURE == 'multiclass':
    hist = train_multiclass(model, *_train_args, **_train_kwargs)
elif ARCHITECTURE == 'multilabel':
    hist = train_multilabel(model, *_train_args, **_train_kwargs)
elif ARCHITECTURE == 'binary':
    hist = train_binary(model, *_train_args, target_ids=BINARY_TARGET_CLASSES, **_train_kwargs)
elif ARCHITECTURE == 'cascade':
    assert binary_model is not None, 'Base model not specified'
    hist = train_cascade(binary_model, model, *_train_args, **_train_kwargs)
elif ARCHITECTURE == 'dual':
    # train_dual takes two criteria; the same loss object is used for both.
    hist = train_dual(model, criterion, *_train_args, **_train_kwargs)


In [None]:
# Plot the history returned by the training cell above. hist is initialised to
# None in the Model cell and stays None if no training branch ran.
plot_history(hist)

## Testing

In [None]:
# Extra keyword arguments evaluate() needs for each supported architecture.
_eval_extras = {
    'multiclass': {},
    'multilabel': {},
    'binary': {'class_ids': BINARY_TARGET_CLASSES},
    'dual': {},
    'cascade': {'model0': binary_model},
}

# Unknown architectures are skipped, leaving `results` untouched.
if ARCHITECTURE in _eval_extras:
    results = evaluate(
        ARCHITECTURE, model, test_loader, criterion, DEVICE,
        **_eval_extras[ARCHITECTURE],
    )

In [None]:
# Extra keyword arguments plot_results_from_loader() needs per architecture.
_plot_extras = {
    'multiclass': {},
    'multilabel': {},
    'binary': {'class_ids': BINARY_TARGET_CLASSES},
    'dual': {},
    'cascade': {'model0': binary_model},
}

# Plot four test samples and save the figure in the logs folder; unknown
# architectures are skipped.
if ARCHITECTURE in _plot_extras:
    plot_results_from_loader(
        ARCHITECTURE, test_loader, model, DEVICE,
        n_samples=4,
        save_path=f'{LOGS_DIR}/evaluation.png',
        **_plot_extras[ARCHITECTURE],
    )

## Apply CRF post-processing

In [None]:
# Compare the raw prediction with its dense-CRF refinement on one test sample.
images, masks = next(iter(test_loader))
images = images.float().to(DEVICE)
masks = masks.long().to(DEVICE)

# Pure inference: eval() switches dropout / batch-norm layers to inference
# behaviour and no_grad() skips building an autograd graph nobody uses.
# NOTE(review): softmax over dim=1 followed by argmax presumes a multi-channel
# (multiclass) output; with a single output channel preds is all zeros.
model.eval()
with torch.no_grad():
    outputs = model(images)
    probs = torch.softmax(outputs, dim=1)
    preds = torch.argmax(probs, dim=1)

# Channels-last and divided by 255 for plotting
# (assumes images hold 0-255 values - TODO confirm).
images = images.cpu().numpy().transpose(0, 2, 3, 1) / 255
masks = masks.cpu().numpy()
preds = preds.cpu().numpy()
probs = probs.cpu().numpy()

idx = 0
image, mask, pred, prob = images[idx], masks[idx], preds[idx], probs[idx]

print(f'{image.shape = }')
print(f'{mask.shape = }')
print(f'{pred.shape = }')
print(f'{prob.shape = }')

pred_crf = dense_crf(image, prob)

# Same class grouping for both so the Dice values are directly comparable.
m1 = get_metrics(mask, pred, [[1, 2], [2]])
m2 = get_metrics(mask, pred_crf, [[1, 2], [2]])

print('Before CRF:', m1['dice_OC'], m1['dice_OD'])
print('After CRF:', m2['dice_OC'], m2['dice_OD'])
print('Improvement:', m2['dice_OC'] - m1['dice_OC'], m2['dice_OD'] - m1['dice_OD'])

# Image, ground truth, raw prediction, CRF-refined prediction.
_, ax = plt.subplots(1, 4, figsize=(12, 4))
ax[0].imshow(image)
ax[1].imshow(mask)
ax[2].imshow(pred)
ax[3].imshow(pred_crf)
plt.tight_layout()
plt.show()

In [None]:
# Scratch cell: build a fresh U-Net (positional args 3, 3, LAYERS), an Adam
# optimizer and a DiceLoss, then run a single training epoch through
# MultilabelTrainer. This rebinds the notebook-level model, optimizer and
# criterion used by later cells.
model = Unet(3, 3, LAYERS).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = DiceLoss(1)

trainer = MultilabelTrainer(model, criterion, optimizer, DEVICE)

trainer.train_one_epoch(train_loader)

## Work in progress

In [None]:
# Compare metrics computed on the network output as-is with metrics computed
# after undo_polar_transform has been applied to images, masks and predictions.
images, masks = next(iter(test_loader))
images = images.float().to(DEVICE)
masks = masks.long().to(DEVICE)

# Pure inference: the model may still be in training mode from an earlier
# cell, so force eval() and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    outputs = model(images)
    probs = torch.softmax(outputs, dim=1)
    preds = torch.argmax(probs, dim=1)

met1 = get_metrics(masks, preds, [[1, 2], [2]])

images, masks, preds = undo_polar_transform(images, masks, preds)

met2 = get_metrics(masks, preds, [[1, 2], [2]])

# Channels-last and divided by 255 for plotting
# (assumes 0-255 image values - TODO confirm).
images = images.cpu().numpy().transpose(0, 2, 3, 1) / 255
masks = masks.cpu().numpy()
preds = preds.cpu().numpy()

# One row per sample: image, ground truth, prediction. The grid has four rows;
# a shorter (final) batch simply leaves the remaining rows empty instead of
# raising an IndexError.
fig, ax = plt.subplots(4, 3, figsize=(8, 12))
ax = ax.ravel()
for i in range(min(4, len(images))):
    ax[3 * i].imshow(images[i])
    ax[3 * i + 1].imshow(masks[i])
    ax[3 * i + 2].imshow(preds[i])
plt.tight_layout()
plt.show()

# Cell output: metric name -> (before, after) undoing the polar transform.
{k: (met1[k], met2[k]) for k in met1}

In [None]:
# torch.save(model.state_dict(), CHECKPOINT_DIR + 'model.pth')

# map_location keeps the load working when the checkpoint was written on a
# different device than the one available now (e.g. GPU -> CPU-only machine).
checkpoint = torch.load(CHECKPOINT_DIR + 'multiclass-unet-model.pth', map_location=DEVICE)
model = Unet(3, 3, LAYERS).to(DEVICE)
# model = DualUnet(3, 1, LAYERS).to(DEVICE)
criterion = DiceLoss(3)
# A freshly constructed module starts in training mode; the following cells
# only run inference, so switch to eval mode before restoring the weights.
model.eval()
# Here the checkpoint file is the state dict itself (see the torch.save line above).
model.load_state_dict(checkpoint)

# binary_model = Unet(3, 1, LAYERS).to(DEVICE)
# checkpoint = torch.load(CHECKPOINT_DIR + 'binary-model.pth')
# binary_model.load_state_dict(checkpoint)

In [None]:
# Batch indices worth inspecting:
# Non-ellipse: 5, 12
# Holes: 0, 20

target_batch = 12

# Walk the test loader until the requested batch, predict it, print its
# metrics and plot it. Nothing happens if the loader has fewer batches.
for batch_idx, (images, masks) in enumerate(test_loader):
    if batch_idx != target_batch:
        continue
    images = images.float().to(DEVICE)
    masks = masks.long().to(DEVICE)

    # Pure inference: inference-mode layers and no autograd graph.
    model.eval()
    with torch.no_grad():
        outputs = model(images)
        probs = torch.softmax(outputs, dim=1)
        preds = torch.argmax(probs, dim=1)

    met = get_metrics(masks, preds, [[1, 2], [2]])

    # Channels-last and divided by 255 for plotting
    # (assumes 0-255 image values - TODO confirm).
    images = images.cpu().numpy().transpose(0, 2, 3, 1) / 255
    masks = masks.cpu().numpy()
    preds = preds.cpu().numpy()

    print(f'Batch {batch_idx} metrics: {met}')

    # Plot results: one row per sample (image, ground truth, prediction).
    # A batch shorter than four samples leaves the spare rows empty instead
    # of raising an IndexError.
    fig, ax = plt.subplots(4, 3, figsize=(8, 12))
    ax = ax.ravel()
    for i in range(min(4, len(images))):
        ax[3 * i].imshow(images[i])
        ax[3 * i + 1].imshow(masks[i])
        ax[3 * i + 2].imshow(preds[i])
    plt.tight_layout()
    plt.show()

    break

In [None]:
# Pick one sample from the arrays produced by the previous cell and split its
# ground truth and prediction into separate disc (OD) and cup (OC) masks.
idx = 1

image, mask, pred = images[idx], masks[idx], preds[idx]

mask_od, mask_oc = separate_disc_and_cup_mask(mask)
pred_od, pred_oc = separate_disc_and_cup_mask(pred)

# Image, predicted OD mask, predicted OC mask.
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
ax = ax.ravel()
for axis, picture in zip(ax, (image, pred_od, pred_oc)):
    axis.imshow(picture)
plt.tight_layout()
plt.show()

# Cell output: metrics of the predicted cup mask against the true cup mask.
get_metrics(mask_oc, pred_oc, [[1, 2]])

In [None]:
# Apply each post-processing variant to the same prediction.
mask1 = apply_largest_component_selection(pred)
mask2 = apply_hole_filling(pred)
mask3 = apply_largest_component_selection(apply_hole_filling(pred))
mask4 = apply_hole_filling(apply_largest_component_selection(pred))
mask5 = apply_ellipse_fitting(pred)

# (title, picture) for every panel, in display order.
_panels = [
    ('Image', image),
    ('Ground truth', mask),
    ('Prediction', pred),
    ('Largest component', mask1),
    ('Hole filling', mask2),
    ('Largest + hole', mask3),
    ('Hole + largest', mask4),
    ('Ellipse fitting', mask5),
]

# Plot results on a 2x4 grid without axis decorations.
_, ax = plt.subplots(2, 4, figsize=(12, 7))
ax = ax.ravel()
for axis, (title, picture) in zip(ax, _panels):
    axis.axis('off')
    axis.set_title(title)
    axis.imshow(picture)

plt.tight_layout()
plt.show()

In [None]:
# Smooth the predicted disc and cup contours separately; their sum is kept as
# a combined mask.
smoothed_od_mask = smooth_contours(pred_od)
smoothed_oc_mask = smooth_contours(pred_oc)
smoothed_mask = smoothed_od_mask + smoothed_oc_mask

# Plot results: image, raw prediction, smoothed OD, smoothed OC.
_, ax = plt.subplots(1, 4, figsize=(12, 5))
for axis, picture in zip(ax, (image, pred, smoothed_od_mask, smoothed_oc_mask)):
    axis.imshow(picture)
plt.tight_layout()
plt.show()


def on_trackbar(s):
    """Redraw the contour-smoothing preview for slider position ``s``.

    ``s`` is divided by 10 before being passed to ``smooth_contours`` together
    with the notebook-level ``pred_oc``. The ``dice_OD`` entry of the metrics
    against ``mask_oc`` is printed, and ground truth, raw prediction and
    smoothed mask are shown side by side in the 'Contour Smoothing' window.
    """
    smoothing = s / 10
    candidate = smooth_contours(pred_oc, smoothing)

    print(get_metrics(mask_oc, candidate, [[1, 2]])['dice_OD'])

    # Scale each mask by 255 for display, then stack them horizontally.
    panels = [layer * 255 for layer in (mask_oc, pred_oc, candidate)]
    cv.imshow('Contour Smoothing', np.hstack(panels))


# Interactive preview: one slider drives on_trackbar.
cv.namedWindow('Contour Smoothing')

# on_trackbar receives raw slider positions and divides them by 10 itself, so
# the start value is kept in slider units and used for BOTH the slider and the
# first render. Previously the first render was called with 1.0 (i.e. s = 0.1)
# while the slider sat at 5 (s = 0.5), so the picture did not match the slider.
# NOTE(review): position 5 keeps the slider where it always started; use 10 if
# an initial s of 1.0 was the intent.
initial_s = 5
cv.createTrackbar('S Parameter', 'Contour Smoothing', initial_s, 100, on_trackbar)
on_trackbar(initial_s)

# Block until a key is pressed, then close the preview window.
cv.waitKey(0)
cv.destroyAllWindows()

In [None]:
# Refine the predicted cup mask with snakes() using its default parameters.
snake_img = snakes(pred_oc)

# Plot results: image, true cup, predicted cup, snake-refined cup.
fig, ax = plt.subplots(1, 4, figsize=(12, 5))
for axis, picture in zip(ax, (image, mask_oc, pred_oc, snake_img)):
    axis.imshow(picture)
plt.tight_layout()
plt.show()


def on_trackbar(a, b, c):
    """Redraw the snake preview for slider positions ``a``, ``b`` and ``c``.

    Each position is divided by 10 before being passed to ``snakes`` together
    with the notebook-level ``pred_oc``. The ``dice_OD`` entry of the metrics
    against ``mask_oc`` is printed, and ground truth, raw prediction and snake
    result are shown side by side in the 'Snake' window.
    """
    alpha, beta, gamma = a / 10, b / 10, c / 10

    candidate = snakes(pred_oc, alpha, beta, gamma)

    print(get_metrics(mask_oc, candidate, [[1, 2]])['dice_OD'])

    # Scale each mask by 255 for display, then stack them horizontally.
    panels = [layer * 255 for layer in (mask_oc, pred_oc, candidate)]
    cv.imshow('Snake', np.hstack(panels))


def on_trackbar1(a):
    """Slider callback: remember the new alpha position and redraw."""
    global initial_alpha
    initial_alpha = a
    on_trackbar(initial_alpha, initial_beta, initial_gamma)


def on_trackbar2(b):
    """Slider callback: remember the new beta position and redraw."""
    global initial_beta
    initial_beta = b
    on_trackbar(initial_alpha, initial_beta, initial_gamma)


def on_trackbar3(c):
    """Slider callback: remember the new gamma position and redraw."""
    global initial_gamma
    initial_gamma = c
    on_trackbar(initial_alpha, initial_beta, initial_gamma)


# Current slider positions in raw slider units. on_trackbar divides every
# value by 10, so 1 / 20 / 50 mean alpha=0.1, beta=2.0, gamma=5.0.
# These used to be stored already divided (0.1, 2.0, 5.0), which made the first
# render - and every parameter whose slider had not been moved yet - ten times
# smaller than what the sliders showed.
initial_alpha = 1
initial_beta = 20
initial_gamma = 50

cv.namedWindow('Snake')
cv.createTrackbar('alpha', 'Snake', initial_alpha, 50, on_trackbar1)
cv.createTrackbar('beta', 'Snake', initial_beta, 50, on_trackbar2)
cv.createTrackbar('gamma', 'Snake', initial_gamma, 100, on_trackbar3)
on_trackbar(initial_alpha, initial_beta, initial_gamma)

# Block until a key is pressed, then close the preview window.
cv.waitKey(0)
cv.destroyAllWindows()

## Contour detection method

In [None]:
from skimage import segmentation
from scipy.ndimage import distance_transform_edt

# Take the first validation sample as a numpy image / mask pair.
images, masks = next(iter(val_loader))
images = images.float().to(DEVICE)
masks = masks.long().to(DEVICE)

image = images[0].cpu().numpy().transpose(1, 2, 0) / 255.0
mask = masks[0].cpu().numpy()

# Hand-made dummy prediction: label 2 nested inside label 1
# (the pixel ranges are hard-coded).
prediction = np.zeros_like(mask)
prediction[16:112, 16:112] = 1
prediction[32:96, 32:96] = 2

# Boundary pixels of the ground truth, painted black on a copy of the image.
boundaries = segmentation.find_boundaries(mask, mode='inner').astype(np.uint8)
marked = image.copy()
marked[boundaries == 1] = [0, 0, 0]

# Distance of every pixel to the nearest boundary pixel, scaled by its maximum.
dist_map = distance_transform_edt(1 - boundaries)
dist_map /= dist_map.max()

_, ax = plt.subplots(2, 4, figsize=(15, 7))
# Top row: mask, boundaries, distance map and inverted distance map.
_top_row = (mask, boundaries, dist_map, dist_map.max() - dist_map)
for axis, picture in zip(ax[0], _top_row):
    axis.imshow(picture)
# Bottom row: dummy prediction and the boundary-marked image.
ax[1, 0].imshow(prediction)
ax[1, 1].imshow(marked)
plt.show()

In [None]:
# Experiment: train a single-channel U-Net to predict mask boundaries,
# iterating over val_loader for five epochs.
model = Unet(in_channels=3, out_channels=1).to(DEVICE)
loss = DiceLoss(num_classes=1)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(5):
    acc_loss = 0
    for images, masks in val_loader:
        # Replace every segmentation mask of the batch by its boundary map.
        edge_masks = np.stack([
            segmentation.find_boundaries(single.cpu().numpy(), mode='thick').astype(np.uint8)
            for single in masks
        ])

        images = images.float().to(DEVICE)
        masks = torch.from_numpy(edge_masks).long().to(DEVICE)

        outputs = model(images)
        loss_value = loss(outputs, masks)
        acc_loss += loss_value.item()

        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()

    # Plot an example from the last batch of the epoch: image, edge target,
    # thresholded sigmoid prediction.
    images = images.cpu().numpy()
    masks = masks.cpu().numpy()
    probs = torch.sigmoid(outputs)
    preds = (probs > 0.5).float().cpu().numpy().transpose(0, 2, 3, 1)

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    _example = (images[0].transpose(1, 2, 0) / 255.0, masks[0], preds[0])
    for axis, picture in zip(ax, _example):
        axis.imshow(picture)
    plt.show()

    print(f'Epoch {epoch + 1} loss:', acc_loss / len(val_loader))


## Interpretability

In [None]:
# Restore the trained multiclass U-Net that the interpretability cells analyse.
model = Unet(in_channels=3, out_channels=3, features=[32, 64, 128, 256, 512]).to(DEVICE)
# map_location keeps the load working when the checkpoint was written on a
# different device than the one available now (e.g. GPU -> CPU-only machine).
checkpoint = torch.load(CHECKPOINT_DIR + 'best-multiclass-unet-model.pth', map_location=DEVICE)
# In this checkpoint the weights live under the 'model' key.
model.load_state_dict(checkpoint['model'])
# The model already sits on DEVICE (moved at construction), so only the mode
# switch is needed here.
model.eval()

# One validation batch shared by the Grad-CAM cells below.
images, masks = next(iter(val_loader))
images = images.float().to(DEVICE)
masks = masks.long().to(DEVICE)
# print(model)

### Guided Backpropagation

In [None]:
class GuidedBackpropagation:
    """Guided backpropagation saliency for a model that uses ``nn.ReLU`` modules.

    On construction a backward hook is attached to every ``nn.ReLU`` module of
    the model; the hook clamps the gradient flowing back through the ReLU to
    non-negative values. The hooks stay active - and therefore affect every
    other backward pass through the model - until ``unhook_layers`` is called.
    """

    def __init__(self, model):
        self.model = model
        self.hooks = []
        self.hook_layers()

    def hook_layers(self):
        """Attach the gradient-clamping hook to every ``nn.ReLU`` of the model.

        Hooks installed by an earlier call are removed first, so calling this
        repeatedly never stacks several hooks on the same module.
        """
        def relu_hook_function(module, grad_in, grad_out):
            if isinstance(module, torch.nn.ReLU):
                return (torch.clamp(grad_in[0], min=0.),)

        self.unhook_layers()
        for module in self.model.modules():
            if isinstance(module, nn.ReLU):
                # NOTE(review): the legacy hook API is kept deliberately;
                # register_full_backward_hook does not allow modules that
                # modify their input in place (e.g. nn.ReLU(inplace=True)).
                hook = module.register_backward_hook(relu_hook_function)
                self.hooks.append(hook)

    def unhook_layers(self):
        """Remove every hook installed by ``hook_layers``."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    def guided_backward(self, image, class_idx=None, to_numpy=True):
        """Return the guided gradient of the selected output w.r.t. ``image``.

        Args:
            image: Input batch for the model. The final ``squeeze`` and
                ``permute(1, 2, 0)`` assume a single multi-channel image, i.e.
                shape (1, C, H, W).
            class_idx: Output channel to back-propagate from. ``None`` selects,
                for every pixel, the channel with the highest score.
            to_numpy: If true, return a channels-last numpy array; otherwise a
                channels-first tensor.

        Returns:
            The positive part of the input gradient, min-max scaled to [0, 1].
            All zeros when the gradient has no range (e.g. is zero everywhere).
        """
        # Work on a detached leaf copy so the caller's tensor is untouched and
        # requires_grad can always be switched on.
        inputs = image.detach().clone().requires_grad_(True)

        # Parameter gradients are a by-product here; do not let them pile up.
        self.model.zero_grad()
        outputs = self.model(inputs)

        onehot = torch.zeros_like(outputs)
        if class_idx is None:
            # Per-pixel one-hot of the predicted class along the channel axis.
            # (Indexing onehot[0] with the argmax map would instead switch on
            # whole channels for every class that occurs anywhere.)
            onehot.scatter_(1, torch.argmax(outputs, dim=1, keepdim=True), 1)
        else:
            onehot[0][class_idx] = 1

        outputs.backward(gradient=onehot)

        # Keep only positive gradients
        gradients = F.relu(inputs.grad)

        # Normalize; a constant gradient map has no range to scale by, so
        # return zeros rather than dividing by zero (which produced NaN).
        min_grad, max_grad = torch.min(gradients), torch.max(gradients)
        grad_range = max_grad - min_grad
        if grad_range > 0:
            gradients = (gradients - min_grad) / grad_range
        else:
            gradients = torch.zeros_like(gradients)
        gradients = gradients.squeeze()

        if to_numpy:
            # Move channel axis to the last dimension
            gradients = gradients.permute(1, 2, 0).cpu().numpy()

        return gradients


# NOTE(review): machine-specific absolute path - point it at a local image.
image_path = r"C:\Users\ASUS\PycharmProjects\DP-GlaucomaSegmentation\data\ORIGA\Images_Cropped\049.jpg"
input_image = cv.imread(image_path)
# cv.imread signals a missing/unreadable file by returning None; fail with a
# clear message instead of an opaque error inside cvtColor.
if input_image is None:
    raise FileNotFoundError(f'Could not read image: {image_path}')
input_image = cv.cvtColor(input_image, cv.COLOR_BGR2RGB)
input_image = cv.resize(input_image, (IMAGE_WIDTH, IMAGE_HEIGHT))
# HWC image -> float tensor of shape (1, C, H, W) on the model's device.
input_image = torch.tensor(input_image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
input_image = input_image.to(DEVICE)

guided_bp = GuidedBackpropagation(model)
try:
    guided_grads = guided_bp.guided_backward(input_image, class_idx=2)
finally:
    # The ReLU hooks change every backward pass through `model`; remove them
    # so the Grad-CAM cells further down see unmodified gradients.
    guided_bp.unhook_layers()

print(guided_grads.shape, guided_grads.min(), guided_grads.max())

# Back to a channels-last image divided by 255 for plotting.
input_image = input_image.cpu().numpy()[0].transpose(1, 2, 0) / 255.0

_, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(input_image)
ax[1].imshow(guided_grads)
plt.show()

### Grad-CAM

In [None]:
class GradCAM:
    """Grad-CAM heat maps for one layer of a model.

    A forward hook stores the target layer's output and a full backward hook
    stores the gradient with respect to that output. Calling the object runs a
    forward and a backward pass and combines the two into a class activation
    map. Call ``unregister_hooks`` when done to detach the hooks again.
    """

    def __init__(self, model, target_layer):
        """Resolve ``target_layer`` (module name or module object) and hook it.

        Raises:
            AssertionError: if ``target_layer`` is not a module of ``model``.
        """
        self.modules = dict(model.named_modules())
        if isinstance(target_layer, str):
            assert target_layer in self.modules.keys(), \
                f'Invalid target layer: {target_layer}, available layers: {self.modules.keys()}'
            self.target_layer = self.modules[target_layer]
        else:
            assert target_layer in self.modules.values(), \
                f'Invalid target layer: {target_layer}, available layers: {self.modules.keys()}'
            self.target_layer = target_layer
        self.model = model
        self.gradients = None
        self.activations = None
        self.hooks = []
        self.register_hooks()

    def register_hooks(self):
        """Install the forward/backward hooks that capture the layer's tensors."""
        def forward_hook(module, module_input, module_output):
            self.activations = module_output

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        hook = self.target_layer.register_forward_hook(forward_hook)
        self.hooks.append(hook)
        hook = self.target_layer.register_full_backward_hook(backward_hook)
        self.hooks.append(hook)

    def unregister_hooks(self):
        """Remove the hooks installed by ``register_hooks``."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    def __call__(self, inputs, class_idx=None, to_numpy=True):
        """Compute the class activation map for ``inputs``.

        Args:
            inputs: Input batch for the model (the one-hot target is written
                for the first batch item when ``class_idx`` is given).
            class_idx: Output channel to explain. ``None`` selects, for every
                pixel, the channel with the highest score.
            to_numpy: If true, return a numpy array, otherwise a tensor.

        Returns:
            The map at the target layer's spatial size, with singleton
            dimensions squeezed out, scaled to [0, 1]. All zeros when no
            location has a positive activation.
        """
        self.model.zero_grad()

        # Forward pass
        outputs = self.model(inputs)

        onehot = torch.zeros_like(outputs)
        if class_idx is None:
            # Per-pixel one-hot of the predicted class along the channel axis.
            # (Indexing onehot[0] with the argmax map would instead switch on
            # whole channels for every class that occurs anywhere.)
            onehot.scatter_(1, torch.argmax(outputs, dim=1, keepdim=True), 1)
        else:
            onehot[0][class_idx] = 1
        outputs.backward(gradient=onehot)

        # Global average pooling to obtain the pooled gradients
        weights = torch.mean(self.gradients, dim=(0, 2, 3), keepdim=True)
        # Weight the channels by corresponding gradients
        cam = torch.sum(weights * self.activations, dim=1, keepdim=True)
        # Clamp negative values to zero
        cam = F.relu(cam)
        # Normalize to [0, 1]; an all-zero map is left as is, because dividing
        # by a zero maximum would fill it with NaN.
        peak = torch.max(cam)
        if peak > 0:
            cam = cam / peak
        # Remove batch dimension
        cam = cam.squeeze()

        if to_numpy:
            cam = cam.detach().cpu().numpy()

        return cam

In [None]:
# Look the target layer up by its qualified module name.
model_layers = dict(model.named_modules())
# print(model_layers.keys())
target_layer = model_layers['decoder.conv2.conv']

# Slices (not plain indices) keep the batch dimension for the model.
image, mask = images[3:4], masks[3:4]

outputs = model(image)
probs = torch.softmax(outputs, dim=1)
preds = torch.argmax(probs, dim=1)

# Heat map for output channel 2, resized to the input size and rescaled to
# 0-255 bytes as expected by cv.applyColorMap.
gradcam = GradCAM(model, target_layer)
heatmap = gradcam(image, class_idx=2)
heatmap = cv.resize(heatmap, (IMAGE_WIDTH, IMAGE_HEIGHT))
heatmap = np.uint8(255 * heatmap / np.max(heatmap))

plt.imshow(heatmap)

# Tensors -> numpy arrays for plotting (image channels-last, divided by 255).
image = image.squeeze().detach().cpu().numpy().transpose(1, 2, 0) / 255
mask = mask.squeeze().detach().cpu().numpy()
pred = preds.squeeze().detach().cpu().numpy()

# Colour the heat map (OpenCV returns BGR) and blend it 50/50 with the image.
overlay = cv.cvtColor(cv.applyColorMap(heatmap, cv.COLORMAP_JET), cv.COLOR_BGR2RGB) / 255

alpha = 0.5
combined = alpha * overlay + (1 - alpha) * image

# Image, ground truth, prediction, coloured heat map, blended overlay.
_, ax = plt.subplots(1, 5, figsize=(15, 5))
for axis, picture in zip(ax, (image, mask, pred, overlay, combined)):
    axis.imshow(picture)
plt.show()

# Detach the Grad-CAM hooks from the model again.
gradcam.unregister_hooks()

### Guided Grad-CAM

In [None]:
class GuidedGradCAM:
    """Guided Grad-CAM: guided-backpropagation gradients weighted by a Grad-CAM map."""

    def __init__(self, model, target_layer):
        self.model = model
        self.gradcam = GradCAM(model, target_layer)
        self.guided_bp = GuidedBackpropagation(model)
        # GuidedBackpropagation installs its ReLU hooks on construction. They
        # alter every backward pass through the model, so take them off again
        # and enable them only around the guided pass in __call__; otherwise
        # the Grad-CAM map would be computed from clamped gradients.
        self.guided_bp.unhook_layers()

    def __call__(self, inputs, class_idx=None):
        """Return the guided gradients multiplied by the resized CAM.

        Args:
            inputs: Input batch for the model (both helpers are written for a
                single image, i.e. batch size 1).
            class_idx: Output channel to explain; forwarded to both helpers.

        Returns:
            A channels-last numpy array with the shape of the guided gradients.
        """
        # 1) Grad-CAM with unmodified gradients.
        cam = self.gradcam(inputs, class_idx, to_numpy=True)

        # 2) Guided backpropagation with the ReLU hooks active only for this pass.
        self.guided_bp.hook_layers()
        try:
            grads = self.guided_bp.guided_backward(inputs, class_idx, to_numpy=True)
        finally:
            self.guided_bp.unhook_layers()

        # Resize cam to guided gradients' shape. grads is (H, W, C) while
        # cv.resize expects the target size as (width, height).
        height, width = grads.shape[:2]
        cam = cv.resize(cam, (width, height))
        cam = np.expand_dims(cam, axis=2)

        return grads * cam


# Look the target layer up by its qualified module name.
model_layers = dict(model.named_modules())
target_layer = model_layers['decoder.conv2.conv']

# Slices (not plain indices) keep the batch dimension for the model.
image = images[3:4]
mask = masks[3:4]

outputs = model(image)
probs = torch.softmax(outputs, dim=1)
preds = torch.argmax(probs, dim=1)

guided_gradcam = GuidedGradCAM(model, target_layer)
try:
    guided_gradcam_mask = guided_gradcam(image, class_idx=2)
finally:
    # Release every hook the helper put on `model` (the Grad-CAM cell above
    # does the same) so later backward passes are not affected.
    guided_gradcam.gradcam.unregister_hooks()
    guided_gradcam.guided_bp.unhook_layers()

# Resize to the input size and rescale to 0-255 bytes for display.
guided_gradcam_mask = cv.resize(guided_gradcam_mask, (IMAGE_WIDTH, IMAGE_HEIGHT))
guided_gradcam_mask = np.uint8(255 * guided_gradcam_mask / np.max(guided_gradcam_mask))

# Input image next to its Guided Grad-CAM map.
_, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(image.squeeze().detach().cpu().numpy().transpose(1, 2, 0) / 255)
ax[1].imshow(guided_gradcam_mask)
plt.show()