In [None]:
import os
import tqdm
import time
from collections import defaultdict
from sklearn.metrics import cohen_kappa_score

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
import torch.optim as optim

import torchio as tio
import torchvision
from torchvision.models import resnet50

import optuna
from optuna.trial import TrialState

from coral_pytorch.dataset import levels_from_labelbatch
from coral_pytorch.dataset import proba_to_label, corn_label_from_logits

from losses import get_loss
from models import get_model
from datasets import CMRxMOTION2DEval

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local losses / models / datasets modules) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# ---------------------------------------------------------------------------
# Evaluation configuration. Everything a reader may need to tune lives here.
# ---------------------------------------------------------------------------
MODEL_TRAIN_DATE = "2022-09-01-01-25-56"         ### DATE OF THE TRAINING
MODEL_FILENAME = "0.8283.pth"                    ### MODEL FILENAME WHICH IS DEFINED BY THE VAL SPLIT CK SCORE
MODEL_NAME = "EfficientNet_B5"

MODE_3D_2D = '2D'                                ### 3D/2D CONFIG? WILL AFFECT DATALOADER AND NETWORK ARCHITECTURE
DATA_PATH = "path_test"        ### TO BE RECONFIGURED
DEBUG = False                                    ### SWITCHING TO DEBUG MODE

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16
CLASSES = 3
DIR = os.getcwd()
# Training-time hyperparameters, kept for reference only (unused here).
# EPOCHS = 10
# TRAIN_VAL_RATIO = 0.8
# LOSS_NAME = "FocalLoss"
# ALPHA = 0.25
# GAMMA = 2.0
# LR = 0.0008
# OPTIMIZER_NAME = "RMSprop"

# Ordinal-regression head used by the checkpoint. The evaluation loop treats
# these as mutually exclusive alternatives (CORAL is checked first).
USE_CORAL = False
USE_CORN = True
assert not (USE_CORAL and USE_CORN), "Enable at most one of USE_CORAL / USE_CORN"

if MODE_3D_2D == '2D':
    EPOCHS = 300


In [None]:
# Build the evaluation-time preprocessing and dataset for the configured mode.
if MODE_3D_2D == "2D":
    # TODO: Test-time Augmentation??
    # A candidate augmented pipeline would be, e.g.:
    #     T.RandomAffine(degrees=120, translate=(0., 0.1), scale=(0.85, 1.15)),
    #     T.CenterCrop(256), T.ToTensor(), T.ConvertImageDtype(torch.float)
    # wrapped in a second CMRxMOTION2DEval dataset (val_dataset_aug).

    # Deterministic per-slice preprocessing: crop the central 128x128 region,
    # convert to a tensor and make sure the dtype is float.
    transforms_val = T.Compose([
        T.CenterCrop(128),
        T.ToTensor(),
        T.ConvertImageDtype(torch.float),
    ])

    # TorchIO transforms handed to the dataset: rescale intensities to [0, 1]
    # using the 0.5 / 99.5 percentiles as the input range.
    # (Previously tried: tio.RandomMotion(degrees=3, translation=1, num_transforms=1)
    #  and tio.RescaleIntensity(out_min_max=(0, 1)) without percentiles.)
    tio_transforms = [
        tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(0.5, 99.5)),
    ]

    val_dataset = CMRxMOTION2DEval(DATA_PATH,
                                   transforms=transforms_val,
                                   tio_transforms=tio_transforms)
elif MODE_3D_2D == "3D":
    # The 3D pipeline was never filled in. Fail here with a clear message
    # rather than leaving `val_dataset` undefined and hitting a NameError in
    # a later cell.
    raise NotImplementedError("3D evaluation is not implemented; set MODE_3D_2D = '2D'")
else:
    raise ValueError(f"Unknown MODE_3D_2D: {MODE_3D_2D!r} (expected '2D' or '3D')")


In [None]:
# TODO: Test-time Augmentation??
# That would need a second loader (val_aug_loader) over an augmented dataset
# (val_dataset_aug) with the same batch size / worker settings.

# Sequential (unshuffled) loader over the evaluation dataset.
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset, batch_size=BATCH_SIZE, num_workers=4, shuffle=False,
)

In [None]:
def show(imgs):
    """Display images side by side in a single row, without axis ticks.

    Args:
        imgs: Iterable of images. Each one is moved to the CPU with
            ``.to('cpu')`` and converted with ``T.ToPILImage()``, so tensors
            are expected.

    Returns:
        None. Does nothing when ``imgs`` is empty.
    """
    imgs = list(imgs)
    if not imgs:
        # plt.subplots(ncols=0) raises, and there is nothing to draw anyway.
        return

    to_pil = T.ToPILImage()
    # squeeze=False keeps `axs` two-dimensional even for a single image.
    _, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(16, 16))
    for ax, img in zip(axs[0], imgs):
        ax.imshow(np.asarray(to_pil(img.to('cpu'))))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
# Preview the first four samples of the evaluation dataset (element 0 of each
# item is the image). Change the range to inspect other samples, e.g.
# range(4, 8), range(8, 12), range(12, 16) or range(16, 20).
# An augmented dataset (val_dataset_aug) could be previewed the same way.
show([val_dataset[idx][0] for idx in range(4)])


In [None]:
def eval_loop(model, val_loader, use_corn=False, use_coral=False):
    """Run slice-level inference and aggregate the votes per "<pid>-<phase>".

    Each batch yielded by ``val_loader`` must unpack into
    ``(data, pid, phase, depth, max_depth)``; ``pid`` and ``phase`` must be
    sequences of strings (they are joined with "-" to form the key).

    Args:
        model: Network to evaluate. With ``use_coral`` it must return a
            ``(logits, probas)`` pair, otherwise a single output tensor.
        val_loader: Iterable of batches as described above.
        use_corn: Decode predictions with ``corn_label_from_logits``.
        use_coral: Decode predictions with ``proba_to_label``; takes
            precedence over ``use_corn``.
            If neither flag is set, the argmax over dim 1 is used.

    Returns:
        Tuple ``(all_votes, preds)``, both lists sorted by key:
        ``all_votes`` holds ``(key, [n0, n1, n2])`` vote counts per class
        index, ``preds`` holds ``(key, label)`` with 1-based labels.
    """
    model.eval()

    # Feed batches to whatever device the model lives on; only fall back to
    # the notebook-level DEVICE for models without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    # One vote counter per key; three slots, one per class index.
    all_votes = defaultdict(lambda: [0, 0, 0])
    preds = {}

    with torch.no_grad():
        for data, pid, phase, depth, max_depth in val_loader:
            data = data.to(device)

            # Honour the function arguments rather than module-level flags.
            if use_coral:
                _, probas = model(data)
                pred = proba_to_label(probas)
            elif use_corn:
                pred = corn_label_from_logits(model(data))
            else:
                # Get the index of the max log-probability.
                pred = model(data).argmax(dim=1)

            # Plain Python ints, so list indexing does not go through a
            # (possibly CUDA) tensor.
            labels = pred.reshape(-1).tolist()

            # TODO: Max importance for mid, min importance
            # for apex and basal slices for final decision, e.g. weight by
            # abs((max_depth // 2 - depth) / (max_depth // 2)) instead of 1.
            for patient, ph, label in zip(pid, phase, labels):
                all_votes["-".join([patient, ph])][label] += 1

    for key, vote in all_votes.items():
        # Majority vote; ties go to the lowest class index. Labels are 1-based.
        preds[key] = vote.index(max(vote)) + 1
        # Any single vote for the last class forces label 3
        # (presumably the most severe grade -- TODO confirm the rationale).
        if vote[2] > 0:
            preds[key] = 3

    return sorted(all_votes.items()), sorted(preds.items())

In [None]:
def evaluate(val_loader):
    """Load the configured checkpoint, run evaluation and save predictions.

    The model is built with ``get_model`` from the notebook-level settings
    (MODEL_NAME, CLASSES, USE_CORN, USE_CORAL), and its weights are read from
    ``saved/<MODEL_TRAIN_DATE>/<MODEL_FILENAME>``. The aggregated votes are
    printed and the final predictions are written to ``preds.csv`` in the
    same directory, with columns "Image" and "Label".

    Args:
        val_loader: Loader passed straight through to ``eval_loop``.
    """
    run_dir = os.path.join("saved", MODEL_TRAIN_DATE)

    # Model Initializer
    model = get_model(MODEL_NAME, CLASSES, use_corn=USE_CORN, use_coral=USE_CORAL).to(DEVICE)
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # loaded on whichever device this session uses.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(os.path.join(run_dir, MODEL_FILENAME), map_location=DEVICE)
    model.load_state_dict(state_dict)

    # Evaluation of the model.
    all_votes, preds = eval_loop(model, val_loader, use_corn=USE_CORN, use_coral=USE_CORAL)

    print(all_votes)

    # Save the predictions of the model. `preds` is a list of (key, label)
    # tuples; naming the columns in the constructor also works when it is empty.
    df = pd.DataFrame(preds, columns=["Image", "Label"])
    df.to_csv(os.path.join(run_dir, "preds.csv"), index=False)

In [None]:
# Run the full evaluation; prints the per-key votes and writes preds.csv into
# the same "saved/<MODEL_TRAIN_DATE>" directory the checkpoint is loaded from.
evaluate(val_loader)