In [None]:
import random
from functools import partial

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim

from networks import *
from training import *
from modules import *

In [None]:
# Data and output locations (relative to the notebook's working directory).
IMAGE_DIR = '../data/ORIGA/Images_CenterNet_Cropped'
MASK_DIR = '../data/ORIGA/Masks_CenterNet_Cropped'
LOGS_DIR = '../logs/'
CHECKPOINT_DIR = '../checkpoints/'

# Experiment identifiers (several are not read by the cells in this notebook).
NETWORK_NAME = 'refunet3+cbam'  # raunet++, refunet3+cbam, swinunet
OPTIMIZER = 'adam'
LOSS_FUNCTION = 'combo'
ARCHITECTURE = 'binary'  # multiclass, multilabel, binary, dual, cascade
BINARY_TARGET_CLASSES = [1, 2]
SCHEDULER = 'plateau'
SCALER = 'none'
DATASET = 'ORIGA'
BASE_CASCADE_MODEL = ''

# Model / optimisation hyper-parameters.
IMAGE_HEIGHT, IMAGE_WIDTH = 128, 128
IN_CHANNELS, OUT_CHANNELS = 3, 1
LEARNING_RATE = 1e-4
BATCH_SIZE = 4
EPOCHS = 5
LAYERS = [16, 32, 48, 64, 80]
CLASS_WEIGHTS = None
SET_SIZES = [0.8, 0.1, 0.1]  # presumably train / val / test fractions (see load_dataset)
DROPOUT_2D = 0.2
EARLY_STOPPING_PATIENCE = 11
LOG_INTERVAL = 5
SAVE_INTERVAL = 10
NUM_WORKERS = 0
SEED = 42

USE_WANDB = False
POLAR_TRANSFORM = True
MULTI_SCALE_INPUT = False
DEEP_SUPERVISION = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Compare the device *type*: a torch.device never compares equal to the plain
# string 'cuda', so `DEVICE == 'cuda'` is always False and pinning never turns on.
PIN_MEMORY = DEVICE.type == 'cuda'

# Seed every RNG the data pipeline may draw from (dataset split, shuffling,
# augmentations) so that re-running the notebook is repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

polar_transform_partial = partial(polar_transform, radius_ratio=1.0)


def polar_steps():
    """Return the optional polar-warp step for an ``A.Compose`` pipeline.

    The same warp is applied to image and mask. When ``POLAR_TRANSFORM`` is off
    the list is empty, so no step is added at all. A fresh list (and transform
    instance) is built per call so pipelines do not share transform objects.
    """
    if not POLAR_TRANSFORM:
        return []
    return [A.Lambda(image=polar_transform_partial, mask=polar_transform_partial)]


# CLAHE always runs because p=1.0; `always_apply=True` was redundant with that and
# is deprecated in newer albumentations versions.
train_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, interpolation=cv.INTER_AREA),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=1.0),
    A.CLAHE(p=1.0, clip_limit=2.0, tile_grid_size=(8, 8)),
    A.RandomBrightnessContrast(p=0.5),
    *polar_steps(),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, interpolation=cv.INTER_AREA),
    A.CLAHE(p=1.0, clip_limit=2.0, tile_grid_size=(8, 8)),
    *polar_steps(),
    ToTensorV2(),
])

# NOTE(review): the two `None` positional arguments are passed through unchanged;
# confirm their meaning against load_dataset's signature.
train_loader, val_loader, test_loader = load_dataset(
    IMAGE_DIR, MASK_DIR, None, None, *SET_SIZES,
    train_transform, val_transform, val_transform, BATCH_SIZE, PIN_MEMORY, NUM_WORKERS,
)

In [None]:
# Smoke-test the training loop: build the network from the configured channel
# counts (instead of literals that can drift from the config) and run one epoch.
model = RefUnet3PlusCBAM(IN_CHANNELS, OUT_CHANNELS, LAYERS).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = DiceLoss(OUT_CHANNELS)

# NOTE(review): a single-output network is paired with MultilabelTrainer here -
# confirm this is the intended trainer for ARCHITECTURE == 'binary'.
trainer = MultilabelTrainer(model, criterion, optimizer, DEVICE)
trainer.train_one_epoch(train_loader)

## Apply CRF post-processing

In [None]:
# Run the current model on the first test batch, refine one prediction with
# dense_crf and compare the Dice scores before and after.
images, masks = next(iter(test_loader))
images = images.float().to(DEVICE)
masks = masks.long().to(DEVICE)

# Inference only: eval() disables dropout / uses running batch-norm statistics,
# and no_grad() avoids building an autograd graph.
model.eval()
with torch.no_grad():
    outputs = model(images)

if outputs.shape[1] == 1:
    # Single-logit (binary) head: softmax over one channel is identically 1 and
    # argmax is always 0, so expand to explicit background/foreground probabilities.
    foreground = torch.sigmoid(outputs)
    probs = torch.cat([1 - foreground, foreground], dim=1)
else:
    probs = torch.softmax(outputs, dim=1)
preds = torch.argmax(probs, dim=1)

# To NumPy; images become channels-last and are divided by 255 for display.
images = images.cpu().numpy().transpose(0, 2, 3, 1) / 255
masks = masks.cpu().numpy()
preds = preds.cpu().numpy()
probs = probs.cpu().numpy()

idx = 0
image, mask, pred, prob = images[idx], masks[idx], preds[idx], probs[idx]

print(f'{image.shape = }')
print(f'{mask.shape = }')
print(f'{pred.shape = }')
print(f'{prob.shape = }')

pred_crf = dense_crf(image, prob)

m1 = get_metrics(mask, pred, [[1, 2], [2]])
m2 = get_metrics(mask, pred_crf, [[1, 2], [2]])

print('Before CRF:', m1['dice_OC'], m1['dice_OD'])
print('After CRF:', m2['dice_OC'], m2['dice_OD'])
print('Improvement:', m2['dice_OC'] - m1['dice_OC'], m2['dice_OD'] - m1['dice_OD'])

_, ax = plt.subplots(1, 4, figsize=(12, 4))
ax[0].set_title('Image')
ax[0].imshow(image)
ax[1].set_title('Ground truth')
ax[1].imshow(mask)
ax[2].set_title('Prediction')
ax[2].imshow(pred)
ax[3].set_title('Prediction + CRF')
ax[3].imshow(pred_crf)
plt.tight_layout()
plt.show()

## Work in progress: metrics before and after undoing the polar transform

In [None]:
# Score one test batch as predicted, then again after undo_polar_transform maps
# images, masks and predictions back, and show a few samples.
images, masks = next(iter(test_loader))
images = images.float().to(DEVICE)
masks = masks.long().to(DEVICE)

# Inference only: no dropout / batch-norm updates and no autograd graph.
model.eval()
with torch.no_grad():
    outputs = model(images)

if outputs.shape[1] == 1:
    # Single-logit (binary) head: softmax over one channel is identically 1 and
    # argmax is always 0, so expand to explicit background/foreground probabilities.
    foreground = torch.sigmoid(outputs)
    probs = torch.cat([1 - foreground, foreground], dim=1)
else:
    probs = torch.softmax(outputs, dim=1)
preds = torch.argmax(probs, dim=1)

met1 = get_metrics(masks, preds, [[1, 2], [2]])

images, masks, preds = undo_polar_transform(images, masks, preds)

met2 = get_metrics(masks, preds, [[1, 2], [2]])

images = images.detach().cpu().numpy().transpose(0, 2, 3, 1) / 255
masks = masks.detach().cpu().numpy()
preds = preds.detach().cpu().numpy()

# Show up to four samples (a final batch may be smaller): image | truth | prediction.
n_show = min(4, len(images))
fig, ax = plt.subplots(n_show, 3, figsize=(8, 3 * n_show), squeeze=False)
for i in range(n_show):
    ax[i, 0].imshow(images[i])
    ax[i, 1].imshow(masks[i])
    ax[i, 2].imshow(preds[i])
ax[0, 0].set_title('Image')
ax[0, 1].set_title('Ground truth')
ax[0, 2].set_title('Prediction')
plt.tight_layout()
plt.show()

# Metric pairs: (before, after undo_polar_transform).
{k: (met1[k], met2[k]) for k in met1}

In [None]:
# torch.save(model.state_dict(), CHECKPOINT_DIR + 'model.pth')

# Load a previously trained 3-output-channel model for the inspection cells below.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
checkpoint = torch.load(CHECKPOINT_DIR + 'multiclass-unet-model.pth', map_location=DEVICE)
model = RefUnet3PlusCBAM(3, 3, LAYERS).to(DEVICE)
# model = DualUnet(3, 1, LAYERS).to(DEVICE)
criterion = DiceLoss(3)
model.load_state_dict(checkpoint)
# The following cells only run inference: disable dropout, use running BN statistics.
model.eval()

# binary_model = Unet(3, 1, LAYERS).to(DEVICE)
# checkpoint = torch.load(CHECKPOINT_DIR + 'binary-model.pth')
# binary_model.load_state_dict(checkpoint)

In [None]:
# Batch indices noted during inspection:
# Non-ellipse: 5, 12
# Holes: 0, 20
# NOTE(review): these indices are only stable if test_loader does not shuffle.

target_batch = 12

# Advance the loader to the requested batch. Fail loudly if it does not exist,
# instead of leaving stale `images` / `masks` / `preds` for the following cells.
for batch_idx, (images, masks) in enumerate(test_loader):
    if batch_idx == target_batch:
        break
else:
    raise IndexError(
        f'test_loader has {len(test_loader)} batches; batch {target_batch} does not exist'
    )

images = images.float().to(DEVICE)
masks = masks.long().to(DEVICE)

# Inference only: no dropout / batch-norm updates and no autograd graph.
model.eval()
with torch.no_grad():
    outputs = model(images)
probs = torch.softmax(outputs, dim=1)
preds = torch.argmax(probs, dim=1)

met = get_metrics(masks, preds, [[1, 2], [2]])

images = images.cpu().numpy().transpose(0, 2, 3, 1) / 255
masks = masks.cpu().numpy()
preds = preds.cpu().numpy()

print(f'Batch {batch_idx} metrics: {met}')

# Plot up to four samples (a final batch may be smaller): image | truth | prediction.
n_show = min(4, len(images))
fig, ax = plt.subplots(n_show, 3, figsize=(8, 3 * n_show), squeeze=False)
for i in range(n_show):
    ax[i, 0].imshow(images[i])
    ax[i, 1].imshow(masks[i])
    ax[i, 2].imshow(preds[i])
ax[0, 0].set_title('Image')
ax[0, 1].set_title('Ground truth')
ax[0, 2].set_title('Prediction')
plt.tight_layout()
plt.show()

In [None]:
# Inspect one sample of the batch selected above, split into its OD and OC masks
# by separate_disc_and_cup_mask.
idx = 1

image = images[idx]
mask = masks[idx]
pred = preds[idx]

mask_od, mask_oc = separate_disc_and_cup_mask(mask)
pred_od, pred_oc = separate_disc_and_cup_mask(pred)

# Image next to the *predicted* OD and OC masks (ground-truth masks are not drawn).
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
ax = ax.ravel()
ax[0].set_title('Image')
ax[0].imshow(image)
ax[1].set_title('Predicted OD')
ax[1].imshow(pred_od)
ax[2].set_title('Predicted OC')
ax[2].imshow(pred_oc)
plt.tight_layout()
plt.show()

# Metrics of the predicted OC mask against the ground-truth OC mask.
get_metrics(mask_oc, pred_oc, [[1, 2]])

In [None]:
# Compare post-processing variants of the same prediction (calls kept in order,
# in case any helper has side effects on `pred`).
mask1 = apply_largest_component_selection(pred)
mask2 = apply_hole_filling(pred)
mask3 = apply_largest_component_selection(apply_hole_filling(pred))
mask4 = apply_hole_filling(apply_largest_component_selection(pred))
mask5 = apply_ellipse_fitting(pred)

# (title, array) for each panel of the 2 x 4 grid, in row-major order.
panels = [
    ('Image', image),
    ('Ground truth', mask),
    ('Prediction', pred),
    ('Largest component', mask1),
    ('Hole filling', mask2),
    ('Largest + hole', mask3),
    ('Hole + largest', mask4),
    ('Ellipse fitting', mask5),
]

_, ax = plt.subplots(2, 4, figsize=(12, 7))
ax = ax.ravel()
for panel_axis, (panel_title, panel_data) in zip(ax, panels):
    panel_axis.axis('off')
    panel_axis.set_title(panel_title)
    panel_axis.imshow(panel_data)

plt.tight_layout()
plt.show()

In [None]:
# Smooth the predicted OD and OC contours independently and compare them.
smoothed_od_mask = smooth_contours(pred_od)
smoothed_oc_mask = smooth_contours(pred_oc)
# Sum of the two smoothed masks; not plotted below, kept for manual inspection.
smoothed_mask = smoothed_od_mask + smoothed_oc_mask

_, ax = plt.subplots(1, 4, figsize=(12, 5))
ax[0].set_title('Image')
ax[0].imshow(image)
ax[1].set_title('Prediction')
ax[1].imshow(pred)
ax[2].set_title('Smoothed OD')
ax[2].imshow(smoothed_od_mask)
ax[3].set_title('Smoothed OC')
ax[3].imshow(smoothed_oc_mask)
plt.tight_layout()
plt.show()


def on_trackbar(s):
    """OpenCV trackbar callback: re-smooth the predicted OC mask and redraw.

    Args:
        s: Trackbar position; ``smooth_contours`` receives ``s / 10``.
    """
    s = s / 10
    smoothed_mask = smooth_contours(pred_oc, s)

    # get_metrics receives a single class group; its score is read from 'dice_OD'.
    m = get_metrics(mask_oc, smoothed_mask, [[1, 2]])
    print(m['dice_OD'])

    # Display the original and smoothed masks side by side
    side_by_side = np.hstack((mask_oc * 255, pred_oc * 255, smoothed_mask * 255))
    # cv.imshow cannot display int64 arrays (what bool / long masks become after
    # `* 255`); 0/255 values fit losslessly in uint8.
    cv.imshow('Contour Smoothing', side_by_side.astype(np.uint8))


# Interactive window: on_trackbar divides the slider position by 10, so the
# initial redraw must be given a *position*, matching the slider's start value.
initial_s = 5  # trackbar position -> s = 0.5

cv.namedWindow('Contour Smoothing')
cv.createTrackbar('S Parameter', 'Contour Smoothing', initial_s, 100, on_trackbar)
on_trackbar(initial_s)

cv.waitKey(0)
cv.destroyAllWindows()

In [None]:
# Run `snakes` on the predicted OC mask with its default parameters and show
# image | ground-truth OC | predicted OC | snake result.
snake_img = snakes(pred_oc)

fig, ax = plt.subplots(1, 4, figsize=(12, 5))
for panel_axis, panel_data in zip(ax, (image, mask_oc, pred_oc, snake_img)):
    panel_axis.imshow(panel_data)
plt.tight_layout()
plt.show()


def on_trackbar(a, b, c):
    """Redraw the snake window for the given trackbar positions.

    Note: this redefines the contour-smoothing ``on_trackbar`` callback.

    Args:
        a, b, c: Trackbar positions; each is divided by 10 and passed to
            ``snakes`` after the predicted OC mask.
    """
    a = a / 10
    b = b / 10
    c = c / 10

    snake_mask = snakes(pred_oc, a, b, c)

    # get_metrics receives a single class group; its score is read from 'dice_OD'.
    m = get_metrics(mask_oc, snake_mask, [[1, 2]])
    print(m['dice_OD'])

    # Display the original and smoothed masks side by side
    side_by_side = np.hstack((mask_oc * 255, pred_oc * 255, snake_mask * 255))
    # cv.imshow cannot display int64 arrays (what bool / long masks become after
    # `* 255`); 0/255 values fit losslessly in uint8.
    cv.imshow('Snake', side_by_side.astype(np.uint8))


def on_trackbar1(a):
    """Alpha slider callback: remember the raw position, then redraw."""
    global initial_alpha
    initial_alpha = a
    on_trackbar(a, initial_beta, initial_gamma)


def on_trackbar2(b):
    """Beta slider callback: remember the raw position, then redraw."""
    global initial_beta
    initial_beta = b
    on_trackbar(initial_alpha, b, initial_gamma)


def on_trackbar3(c):
    """Gamma slider callback: remember the raw position, then redraw."""
    global initial_gamma
    initial_gamma = c
    on_trackbar(initial_alpha, initial_beta, c)


# Initial trackbar *positions*. The slider callbacks store raw positions in these
# globals and on_trackbar divides every argument by 10, so these must be in
# position units too; otherwise untouched parameters end up 10x too small.
initial_alpha = 1   # alpha = 0.1
initial_beta = 20   # beta = 2.0
initial_gamma = 50  # gamma = 5.0

cv.namedWindow('Snake')
cv.createTrackbar('alpha', 'Snake', initial_alpha, 50, on_trackbar1)
cv.createTrackbar('beta', 'Snake', initial_beta, 50, on_trackbar2)
cv.createTrackbar('gamma', 'Snake', initial_gamma, 100, on_trackbar3)
on_trackbar(initial_alpha, initial_beta, initial_gamma)

cv.waitKey(0)
cv.destroyAllWindows()

## Contour detection method

In [None]:
from skimage import segmentation
from scipy.ndimage import distance_transform_edt

# Inspect one validation sample: its inner label boundaries, the boundaries drawn
# onto the image, and a normalised distance-to-boundary map.
images, masks = next(iter(val_loader))
images = images.float().to(DEVICE)
masks = masks.long().to(DEVICE)

image = images[0].cpu().numpy().transpose(1, 2, 0) / 255.0
mask = masks[0].cpu().numpy()

# Synthetic nested-square "prediction" (outer = 1, inner = 2) scaled to the mask
# size; for 128x128 inputs this gives [16:112] and [32:96].
height, width = mask.shape
prediction = np.zeros_like(mask)
prediction[height // 8: height - height // 8, width // 8: width - width // 8] = 1
prediction[height // 4: height - height // 4, width // 4: width - width // 4] = 2

boundaries = segmentation.find_boundaries(mask, mode='inner').astype(np.uint8)
marked = image.copy()
marked[boundaries == 1] = [0, 0, 0]  # draw boundary pixels in black
# Boundary pixels are 0 in `1 - boundaries`, so this is the distance to the
# nearest boundary pixel.
dist_map = distance_transform_edt(1 - boundaries)
dist_map = dist_map / dist_map.max()

_, ax = plt.subplots(2, 4, figsize=(15, 7))
ax[0, 0].set_title('Ground truth')
ax[0, 0].imshow(mask)
ax[0, 1].set_title('Inner boundaries')
ax[0, 1].imshow(boundaries)
ax[0, 2].set_title('Distance to boundary')
ax[0, 2].imshow(dist_map)
ax[0, 3].set_title('Inverted distance')
ax[0, 3].imshow(dist_map.max() - dist_map)
ax[1, 0].set_title('Synthetic prediction')
ax[1, 0].imshow(prediction)
ax[1, 1].set_title('Boundaries on image')
ax[1, 1].imshow(marked)
# The last two slots of the grid are unused.
ax[1, 2].axis('off')
ax[1, 3].axis('off')
plt.show()

In [None]:
def make_edge_masks(masks):
    """Turn a batch of label masks into binary 'thick' boundary masks.

    Args:
        masks: Tensor of label masks, indexed here as (batch, H, W).

    Returns:
        uint8 CPU tensor of the same shape: 1 where
        ``segmentation.find_boundaries(..., mode='thick')`` marks a boundary, 0 elsewhere.
    """
    label_masks = masks.cpu().numpy()
    edges = np.zeros(label_masks.shape, dtype=np.uint8)
    for b, label_mask in enumerate(label_masks):
        edges[b] = segmentation.find_boundaries(label_mask, mode='thick')
    return torch.from_numpy(edges)


# Experiment: train a single-output network to predict mask boundaries instead of
# regions. Same constructor arguments as the other RefUnet3PlusCBAM cells.
model = RefUnet3PlusCBAM(3, 1, LAYERS).to(DEVICE)
loss = DiceLoss(num_classes=1)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    model.train()
    acc_loss = 0
    # Fit on the training split so the validation split stays held out.
    for images, masks in train_loader:
        images = images.float().to(DEVICE)
        masks = make_edge_masks(masks).long().to(DEVICE)

        outputs = model(images)
        loss_value = loss(outputs, masks)
        acc_loss += loss_value.item()

        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()

    # Plot the first sample of the epoch's last batch: image | edge target | prediction.
    images = images.cpu().numpy()
    masks = masks.cpu().numpy()
    probs = torch.sigmoid(outputs.detach())
    # Take the single output channel so each prediction is a plain (H, W) array.
    preds = (probs[:, 0] > 0.5).float().cpu().numpy()

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].set_title('Image')
    ax[0].imshow(images[0].transpose(1, 2, 0) / 255.0)
    ax[1].set_title('Edge target')
    ax[1].imshow(masks[0])
    ax[2].set_title('Prediction (p > 0.5)')
    ax[2].imshow(preds[0])
    plt.show()

    print(f'Epoch {epoch + 1} loss:', acc_loss / len(train_loader))
