In [None]:
from collections import defaultdict
import os
import random

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from networks import *
from training import *
from utils import *

In [None]:
IMAGE_DIR = '../data/ORIGA/Images_Cropped'
MASK_DIR = '../data/ORIGA/Masks_Cropped'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every RNG the pipeline may draw from (augmentations, weight init,
# DataLoader shuffling, and presumably the train/val/test split inside
# load_origa) so that Restart & Run All reproduces the same results.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

IMAGE_HEIGHT, IMAGE_WIDTH = 128, 128
BATCH_SIZE = 4
# Pinned host memory only helps host->GPU copies; without CUDA PyTorch just
# warns and ignores it, so tie it to the device actually in use.
PIN_MEMORY = DEVICE == 'cuda'
NUM_WORKERS = 4

IN_CHANNELS = 3   # channels of the input images fed to the model
OUT_CHANNELS = 3  # number of segmentation classes (model output channels)

In [None]:
# Augmentation pipelines. ToTensorV2 only converts HWC numpy arrays to CHW
# tensors; it does not rescale, so images keep their loaded value range.
# The commented-out steps are optional experiments. The Normalize lines use the
# standard ImageNet statistics (blue-channel std is 0.225, not 0.255).
train_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=1.0),  # rotates by a random multiple of 90 degrees
    # A.Lambda(image=occlude),
    # A.Lambda(image=polar_transform, mask=polar_transform),
    # A.Lambda(image=keep_gray_channel),
    # A.Lambda(image=keep_red_channel),
    # A.Lambda(image=keep_green_channel),
    # A.Lambda(image=keep_blue_channel),
    # A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# Deterministic pipeline shared by the validation and test splits.
val_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    # A.Lambda(image=polar_transform, mask=polar_transform),
    # A.Lambda(image=keep_gray_channel),
    # A.Lambda(image=keep_red_channel),
    # A.Lambda(image=keep_green_channel),
    # A.Lambda(image=keep_blue_channel),
    # A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# Fractions of the dataset assigned to train / validation / test.
TRAIN_FRAC, VAL_FRAC, TEST_FRAC = 0.7, 0.15, 0.15

train_loader, val_loader, test_loader = load_origa(
    IMAGE_DIR, MASK_DIR, TRAIN_FRAC, VAL_FRAC, TEST_FRAC,
    train_transform, val_transform, val_transform, BATCH_SIZE, PIN_MEMORY, NUM_WORKERS,
)

In [None]:
BACKBONE = 'resnet34'
LR = 1e-4
EPOCHS = 5

# U-Net with an ImageNet-pretrained encoder; one output channel per class.
model = smp.Unet(
    encoder_name=BACKBONE,
    encoder_weights='imagenet',
    in_channels=IN_CHANNELS,
    classes=OUT_CHANNELS,
    # decoder_attention_type='scse',
).to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
preprocess_input = get_preprocessing_fn(BACKBONE, pretrained='imagenet')

# The pretrained encoder expects inputs scaled to [0, 1] and standardised with
# ImageNet statistics, which is the preprocessing `preprocess_input` describes
# for numpy arrays. In this notebook the transforms leave A.Normalize commented
# out, so images reach the model as raw [0, 255] values (the prediction-plot
# cell divides by 255 for display). Apply the preprocessing on-device inside
# every forward pass, so training and inference see identical inputs.
# Do not also enable A.Normalize, or the inputs get normalised twice.
# NOTE(review): assumes load_origa yields RGB (not BGR) channel order; confirm.
if IN_CHANNELS == 3:
    ENCODER_INPUT_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, -1, 1, 1)
    ENCODER_INPUT_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, -1, 1, 1)

    def _normalize_input(module, inputs):
        """Forward pre-hook: map raw [0, 255] image batches to ImageNet-normalised inputs."""
        images, *rest = inputs
        return ((images / 255.0 - ENCODER_INPUT_MEAN) / ENCODER_INPUT_STD, *rest)

    # The model is rebuilt above on every run of this cell, so the hook is never stacked.
    model.register_forward_pre_hook(_normalize_input)

In [None]:
# Freeze the pretrained encoder: none of its parameters receive gradients, so
# only the remaining (non-encoder) parameters are updated by the optimizer.
# NOTE(review): this freezes parameters only; any normalisation layers in the
# encoder still update running statistics while the model is in train mode.
# Trailing `;` keeps Jupyter from displaying the returned module.
model.encoder.requires_grad_(False);

In [None]:
# Class groups handed to update_metrics for the two reported structures.
# NOTE(review): presumably optic disc = labels {1, 2} and optic cup = label {2},
# matching the *_OD / *_OC history keys; confirm against update_metrics in utils.
METRIC_CLASSES = [[1, 2], [2]]


def _run_epoch(loader, train):
    """Run one pass over ``loader`` and return the mean of every tracked metric.

    With ``train=True`` the model is optimised with ``criterion``/``optimizer``.
    Otherwise it runs in eval mode with autograd disabled. Returns a dict mapping
    metric name (``'loss'`` plus whatever ``update_metrics`` records) to its
    mean over all batches.
    """
    model.train(train)
    if train and not any(p.requires_grad for p in model.encoder.parameters()):
        # A fully frozen encoder should also keep its normalisation layers
        # frozen. In train mode, BatchNorm would still update running stats.
        model.encoder.eval()

    metrics = defaultdict(list)
    with torch.set_grad_enabled(train):
        for images, masks in loader:
            images = images.float().to(DEVICE)
            masks = masks.long().to(DEVICE)

            outputs = model(images)
            loss = criterion(outputs, masks)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # argmax over the raw logits equals argmax over their softmax.
            preds = outputs.argmax(dim=1)

            metrics['loss'].append(loss.item())
            update_metrics(masks, preds, metrics, METRIC_CLASSES)
    return {name: np.mean(values) for name, values in metrics.items()}


def _print_summary(title, prefix):
    """Print the latest loss / accuracy / Dice entries stored in ``hist`` under ``prefix``."""
    def latest(name):
        return hist[f'{prefix}_{name}'][-1]

    print(f'''{title}:
    Loss: {latest('loss'):.4f}
    Accuracy: {latest('accuracy_OD'):.4f} (Disc), {latest('accuracy_OC'):.4f} (Cup)
    Dice: {latest('dice_OD'):.4f} (Disc), {latest('dice_OC'):.4f} (Cup)
    ''')


hist = defaultdict(list)
for epoch in range(EPOCHS):
    print(f'Epoch {epoch + 1}/{EPOCHS}:')

    update_history(hist, _run_epoch(train_loader, train=True), prefix='train')
    _print_summary('Training', 'train')

    update_history(hist, _run_epoch(val_loader, train=False), prefix='val')
    _print_summary('Validation', 'val')


In [None]:
# Visualise the per-epoch metrics accumulated in `hist` (train_* and val_* keys).
# NOTE(review): plot_history comes from one of the star-imported project modules
# (networks / training / utils); what exactly it draws is not visible here.
plot_history(hist)

In [None]:
# Plot some predictions: input image, ground-truth mask and predicted mask
# for the first few samples of one test batch.
N_EXAMPLES = 3

model.eval()
images, masks = next(iter(test_loader))
with torch.no_grad():  # inference only; no autograd graph needed
    outputs = model(images.float().to(DEVICE))
# argmax over the raw logits equals argmax over their softmax.
preds = outputs.argmax(dim=1).cpu().numpy()

# Images come out of the transforms as [0, 255] values; rescale to [0, 1] for imshow.
images = images.float().cpu().numpy().transpose(0, 2, 3, 1) / 255
masks = masks.long().cpu().numpy()

n_rows = min(N_EXAMPLES, len(images))  # the batch may hold fewer samples
fig, ax = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
for i in range(n_rows):
    ax[i, 0].imshow(images[i])
    # Pin the colour scale to the label range so a class gets the same colour in
    # both panels, even when a mask or prediction lacks some classes.
    ax[i, 1].imshow(masks[i], vmin=0, vmax=OUT_CHANNELS - 1)
    ax[i, 2].imshow(preds[i], vmin=0, vmax=OUT_CHANNELS - 1)
    for j, title in enumerate(('Image', 'Ground truth', 'Prediction')):
        ax[i, j].set_title(title)
        ax[i, j].axis('off')
plt.show()