In [None]:
import os
import tqdm
import time
from collections import defaultdict
from sklearn.metrics import cohen_kappa_score

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
import torch.optim as optim

import torchio as tio
import torchvision
# from torchvision.models import resnet50

import optuna
from optuna.trial import TrialState

from coral_pytorch.dataset import levels_from_labelbatch
from coral_pytorch.dataset import proba_to_label, corn_label_from_logits

from losses import get_loss
from models import get_model
from datasets import CMRxMOTION2D

In [None]:
# Reload edited modules (e.g. the project's losses / models / datasets) before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
# --- Experiment identity -----------------------------------------------------
STUDY_NAME = "EB5_RANDOMAFFINE_CENTRALCROP_CORN_CK_PC"
# MODEL_NAME = "ResNet50"
MODEL_NAME = "EfficientNet_B5"
MODE_3D_2D = '2D'               # '3D' or '2D'; affects the dataloader and the network architecture
PIDS_VAL = [3, 9, 14, 19]       # patient IDs held out for validation
# Every other patient ID in 1..20 is used for training (ascending order).
PIDS_TR = [pid for pid in range(1, 21) if pid not in PIDS_VAL]
DATA_PATH = r"path_in"          # TO BE RECONFIGURED
DEBUG = False                   # enables debug-only sanity checks

# --- Runtime -----------------------------------------------------------------
DEVICE = torch.device("cuda")
BATCH_SIZE = 16
CLASSES = 3
DIR = os.getcwd()
EPOCHS = 10
TRAIN_VAL_RATIO = 0.8           # NOTE(review): not referenced by the cells shown
SEED = 42                       # seeds torch (all devices) and numpy for reproducible runs

# --- Loss --------------------------------------------------------------------
LOSS_NAME = "FocalLoss"
ALPHA = 0.25
GAMMA = 2  # 1.5

# --- Optimisation ------------------------------------------------------------
LR = 3e-4  # 0.00021
OPTIMIZER_NAME = "Adam"
SCHEDULER_NAME = "CyclicLR"
USE_CORAL = False
USE_CORN = True

# The CORAL and CORN heads are mutually exclusive: the training/validation loops
# test USE_CORAL first, so enabling both would pair CORAL outputs with the CORN loss.
assert not (USE_CORAL and USE_CORN), "Enable at most one of USE_CORAL / USE_CORN"

if MODE_3D_2D == '2D':
    # EPOCHS = 75
    EPOCHS = 500
if USE_CORAL:
    LOSS_NAME = "CORAL"
if USE_CORN:
    LOSS_NAME = "CORN"

# Seed the RNGs used by weight init, augmentation and DataLoader shuffling.
torch.manual_seed(SEED)
np.random.seed(SEED)


In [None]:
# Fail fast on unsupported modes: silently passing would leave train_dataset /
# val_dataset undefined and only surface later as a NameError in the loader cell.
if MODE_3D_2D == "3D":
    raise NotImplementedError("3D CMRxMOTION dataset has not been implemented.")
if MODE_3D_2D != "2D":
    raise ValueError(f"Unknown MODE_3D_2D: {MODE_3D_2D!r} (expected '2D' or '3D')")

# Training augmentation: rotation in [-120, 120] degrees, translate=(0., 0.1)
# (no horizontal shift, up to 10% vertical shift) and 0.85-1.15 scaling,
# then a 128x128 centre crop.
transforms_tr = T.Compose([
    T.RandomAffine(degrees=120, translate=(0., 0.1), scale=(0.85, 1.15)),
    T.CenterCrop(128),
    T.ToTensor(),
    T.ConvertImageDtype(torch.float),
])
# Validation: deterministic centre crop only.
transforms_val = T.Compose([
    T.CenterCrop(128),
    T.ToTensor(),
    T.ConvertImageDtype(torch.float),
])
# Intensity normalisation shared by both splits: clip to the 0.5-99.5 percentile
# range and rescale to [0, 1].
tio_transforms = [
    # tio.RandomMotion(degrees=3, translation=1, num_transforms=1),
    tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(0.5, 99.5)),
]

train_dataset = CMRxMOTION2D(DATA_PATH, mode='train', pids=PIDS_TR, transforms=transforms_tr, tio_transforms=tio_transforms)
val_dataset = CMRxMOTION2D(DATA_PATH, mode='val', pids=PIDS_VAL, transforms=transforms_val, tio_transforms=tio_transforms)


In [None]:
# Both splits share batch size and worker count; only the training split is shuffled.
train_loader, val_loader = (
    torch.utils.data.DataLoader(
        split_dataset,
        batch_size=BATCH_SIZE,
        shuffle=is_train,
        num_workers=4,
    )
    for split_dataset, is_train in ((train_dataset, True), (val_dataset, False))
)

In [None]:
def show(imgs):
    """Plot a sequence of image tensors side by side in one row.

    Each tensor is moved to the CPU, converted with ``T.ToPILImage`` and drawn
    on its own axis with ticks and tick labels removed.
    """
    _, axes = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(16, 16))
    for ax, tensor in zip(axes[0], imgs):
        pil_img = T.ToPILImage()(tensor.to('cpu'))
        ax.imshow(np.asarray(pil_img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
# Optional full pass over the training set (loads and augments every sample once,
# discarding the result). NOTE(review): presumably a load sanity check; it is
# expensive, so it only runs in DEBUG mode.
if DEBUG:
    for _ in train_dataset:
        pass

# Visual spot-check of a few samples from each split (item[0] is the model input).
show([train_dataset[idx][0] for idx in range(4, 8)])
show([val_dataset[idx][0] for idx in range(1, 5)])
# Labels (item[1]) of the validation samples plotted above.
[val_dataset[idx][1] for idx in range(1, 5)]

In [None]:
def train_loop(model, train_loader, criterion, optimizer, scheduler=None):
    """Train ``model`` for one epoch and return the epoch's training accuracy.

    The prediction head is selected by the global flags:
    - CORAL: the model returns ``(logits, probas)``, and the loss compares the
      logits with ordinal levels built from the labels.
    - CORN: the loss also receives the number of classes.
    - Otherwise: plain classification with an argmax over the outputs.

    Args:
        model: network to train (already on ``DEVICE``).
        train_loader: yields batches whose first two items are ``(data, target)``.
            The remaining items (patient id, phase, depth, ...) are ignored here.
        criterion: loss for the active head.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler, stepped once per batch.

    Returns:
        Fraction of training samples predicted correctly. Each batch's
        predictions are taken before that batch's optimizer step. Returns 0.0
        for an empty dataset.
    """
    model.train()
    correct = 0
    for data, target, *_ in train_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()

        if USE_CORAL:
            logits, probas = model(data)
            levels = levels_from_labelbatch(target, num_classes=CLASSES).to(DEVICE)
            loss = criterion(logits, levels)
            pred = proba_to_label(probas)
        elif USE_CORN:
            output = model(data)
            loss = criterion(output, target, CLASSES)
            pred = corn_label_from_logits(output)
        else:
            output = model(data)
            loss = criterion(output, target)
            pred = output.argmax(dim=1)

        correct += pred.eq(target.view_as(pred)).sum().item()

        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

    # Divide once after the epoch. Previously this ran per batch and was
    # unbound (UnboundLocalError) when the loader yielded no batches.
    n_samples = len(train_loader.dataset)
    return correct / n_samples if n_samples else 0.0


def _slice_weight(depth, max_depth):
    """Vote weight of one slice: 1 at index ``max_depth // 2``, falling linearly with distance from it.

    Same formula as ``1 - |(max_depth // 2 - depth) / (max_depth // 2)|``,
    but returns 1.0 when ``max_depth // 2 == 0``. In that case the raw formula
    divides by zero and injects NaN into the vote.
    NOTE(review): assumes ``depth`` lies in ``[0, max_depth)``; larger distances
    would give negative weights.
    """
    mid = float(max_depth) // 2
    if mid == 0:
        return 1.0
    return 1.0 - abs((mid - float(depth)) / mid)


def val_loop(model, val_loader):
    """Evaluate ``model`` on ``val_loader`` slice-wise and per volume.

    Slice predictions are grouped by the ``"{pid}-{phase}"`` key. Each group
    gets a single prediction by a depth-weighted vote (see ``_slice_weight``),
    so slices near the middle of the stack count most.

    Returns:
        (accuracy, accuracy_patient, ck):
        - accuracy: slice-level accuracy over the validation dataset.
        - accuracy_patient: accuracy of the aggregated per-volume predictions.
        - ck: Cohen's kappa of the per-volume predictions against their labels.
        The two per-volume metrics are 0.0 when no batches are produced.
    """
    model.eval()

    correct = 0
    # Per-volume weighted class votes and the volume's ground-truth label.
    all_votes = defaultdict(lambda: [0.0] * CLASSES)
    targets = {}

    with torch.no_grad():
        for data, target, pid, phase, depth, max_depth in val_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)

            if USE_CORAL:
                _, probas = model(data)
                pred = proba_to_label(probas)
            elif USE_CORN:
                pred = corn_label_from_logits(model(data))
            else:
                # Index of the max logit.
                pred = model(data).argmax(dim=1)

            correct += pred.eq(target.view_as(pred)).sum().item()

            # Mid-stack slices get the most weight; apex/basal slices the least.
            for i, (patient, ph, d) in enumerate(zip(pid, phase, depth)):
                key = f"{patient}-{ph}"
                all_votes[key][int(pred[i])] += _slice_weight(d, max_depth[i])
                targets[key] = int(target[i])

    n_slices = len(val_loader.dataset)
    accuracy = correct / n_slices if n_slices else 0.0

    preds, gt = [], []
    for key, vote in all_votes.items():
        # torch.argmax returns the first maximal index on ties.
        preds.append(int(torch.tensor(vote).argmax()))
        gt.append(targets[key])

    if not gt:
        return accuracy, 0.0, 0.0

    accuracy_patient = sum(p == g for p, g in zip(preds, gt)) / len(gt)
    ck = cohen_kappa_score(preds, gt)
    return accuracy, accuracy_patient, ck

In [None]:
def trainer(train_loader, val_loader, optimizer_name):
    """Train ``MODEL_NAME`` for ``EPOCHS`` epochs, validating and checkpointing each epoch.

    A checkpoint named ``"{val_accuracy:.4f}.pth"`` is written to
    ``saved/<start timestamp>/`` in either of two cases:
    - the slice-level validation accuracy is at least the best seen so far;
    - the per-volume Cohen's kappa exceeds the best seen so far.

    Args:
        train_loader: DataLoader for the training split.
        val_loader: DataLoader for the validation split.
        optimizer_name: name of a ``torch.optim`` optimizer class, e.g. "Adam".

    Returns:
        Cohen's kappa from the final epoch (0.0 if ``EPOCHS == 0``).
    """
    date = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    save_dir = os.path.join("saved", date)

    # Model initialiser.
    model = get_model(MODEL_NAME, CLASSES, use_corn=USE_CORN, use_coral=USE_CORAL).to(DEVICE)
    # To resume: model.load_state_dict(torch.load(os.path.join("saved", "<date>", "<acc>.pth")))

    if LOSS_NAME == "FocalLoss":
        loss_args = {"alpha": ALPHA, "gamma": GAMMA, "reduction": "mean"}
    else:
        loss_args = {}
    criterion = get_loss(LOSS_NAME, **loss_args)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=LR)

    if SCHEDULER_NAME == "CyclicLR":
        # Cycle the learning rate between LR / 10 and LR, without momentum cycling.
        scheduler_args = {"base_lr": LR / 10, "max_lr": LR, "cycle_momentum": False}
    else:
        # Other schedulers get no extra kwargs (``**None`` would raise TypeError).
        scheduler_args = {}
    if SCHEDULER_NAME:
        scheduler = getattr(optim.lr_scheduler, SCHEDULER_NAME)(optimizer, **scheduler_args)
    else:
        scheduler = None

    # Training and validation of the model.
    ck_max = 0.
    acc_max = 0.
    ck = 0.
    for epoch in range(EPOCHS):
        accuracy_tr = train_loop(model, train_loader, criterion, optimizer, scheduler=scheduler)
        accuracy, accuracy_patient, ck = val_loop(model, val_loader)
        print(f"Train Accuracy at Epoch {epoch}: ", accuracy_tr)
        print(f"Val Accuracy at Epoch {epoch}: ", accuracy)
        print(f"Val Patientwise Accuracy at Epoch {epoch}: ", accuracy_patient)
        print(f"Val Patientwise Cohen's Kappa at Epoch {epoch}: ", ck)

        if accuracy >= acc_max or ck > ck_max:
            os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(save_dir, f"{accuracy:.4f}.pth"))

            # Track each best independently. Assigning both unconditionally let a
            # kappa-triggered save lower acc_max (and an accuracy-triggered save
            # lower ck_max).
            acc_max = max(acc_max, accuracy)
            ck_max = max(ck_max, ck)

    return ck

In [None]:
# Launch training; the returned final-epoch Cohen's kappa is shown as the cell output.
trainer(train_loader=train_loader, val_loader=val_loader, optimizer_name=OPTIMIZER_NAME)