# Setup

## Add dependencies

In [None]:
from pathlib import Path
import sys
sys.path.append(".")


## Detect Device

In [None]:
import torch

# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load & Pre-process Data

In [None]:
from data import get_img_dataset
from masked_dataset import MaskedDataset
from torchvision import transforms
from project3Lib import EnhanceContrast, ScaleToFill, SubtractPCA

# Labelled split: images under unique_images/ paired with masks under masks/.
train_dataset, val_dataset, test_dataset = get_img_dataset(
    device=device,
    data_path=Path("/content/data/unique_images"),
    folder_type=MaskedDataset,
    mask_folder=Path("/content/data/masks"),
    common_transforms=[EnhanceContrast(reduce_dim=True)],
)

# Unlabelled split from tl_dataset/: there is no mask folder, so
# `use_empty_mask=True` is passed instead (presumably substitutes an empty
# placeholder mask — confirm in data.get_img_dataset). A separate
# EnhanceContrast instance is created for this call, as in the labelled one.
train_dataset_unlabeled, val_dataset_unlabeled, test_dataset_unlabeled = get_img_dataset(
    device=device,
    data_path=Path("/content/data/tl_dataset"),
    folder_type=MaskedDataset,
    mask_folder=None,
    use_empty_mask=True,
    common_transforms=[EnhanceContrast(reduce_dim=True)],
)

# Train Models

## Helper functions for model training

In [None]:
from project3Lib import UNet, MaxClassifier, dice_loss
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
import numpy as np

def get_model(train_dataset, val_dataset, device, train_dataset_unlabeled=None, **kwargs):
    """Train a UNet segmenter and wrap it in a fitted MaxClassifier.

    Args:
        train_dataset: labelled training samples; materialised with ``list()``
            before being handed to the training routine.
        val_dataset: validation samples; passed (as a list) to training and
            then used to fit the classifier.
        device: device the UNet is moved to before training.
        train_dataset_unlabeled: optional unlabelled samples. When ``None`` the
            UNet is trained with ``train_supervised``; otherwise with
            ``train_semisupervised`` using these samples as a third argument.
        **kwargs: forwarded unchanged to the training method (callers in this
            notebook pass e.g. ``alpha``, ``epochs``, ``lr``, ``beta``).

    Returns:
        A ``MaxClassifier`` wrapping the trained UNet, fitted on ``val_dataset``.
    """
    # NOTE(review): UNet(1, 1) presumably means 1 input / 1 output channel — confirm in project3Lib.
    segmenter = UNet(1, 1)
    segmenter.to(device=device)

    # Both training modes take the labelled train/val data as their leading arguments.
    datasets = [list(train_dataset), list(val_dataset)]
    if train_dataset_unlabeled is None:
        segmenter.train_supervised(*datasets, **kwargs)
    else:
        datasets.append(list(train_dataset_unlabeled))
        segmenter.train_semisupervised(*datasets, **kwargs)

    classifier = MaxClassifier(segmenter)
    classifier.fit(val_dataset)
    return classifier


## Select Hyperparameters

In [None]:
import optuna
from project3Lib import dice_loss

# Objective to optimize dice loss over 
# validation set for different hyper parameters
def objective(trial, train_dataset, val_dataset, device, train_dataset_unlabeled=None, **kwargs):

    # Sugest parameters
    if train_dataset_unlabeled is None:
        kwargs["alpha"] = trial.suggest_float("alpha", 0, 1)
        kwargs["epochs"] = trial.suggest_int("epochs", 5, 20)
        kwargs["lr"] = trial.suggest_float("lr", 1e-6, 1e-4)
    else:
        kwargs["beta"] = trial.suggest_float("beta", 0.1, 10)

    # Train model
    model = get_model(train_dataset, val_dataset, device, train_dataset_unlabeled, **kwargs)

    # Compute validation dice loss
    with torch.no_grad():
        val_dice = 0
        for x, target, _ in val_dataset:
            pred = model.model(x)
            val_dice += dice_loss(pred, target.unsqueeze(0))

    return val_dice.item()

# Optimize the hyperparameters of the purely supervised model.
study_supervised = optuna.create_study(direction="minimize")
study_supervised.optimize(
    lambda trial: objective(trial, train_dataset, val_dataset, device),
    n_trials=50,
)

# Reuse the best supervised hyperparameters for every later model.
alpha = study_supervised.best_params["alpha"]
epochs = study_supervised.best_params["epochs"]
lr = study_supervised.best_params["lr"]

# Optimize only `beta` for the semi-supervised model with consistency
# regularization, keeping alpha/epochs/lr fixed at the values found above.
study_semisupervised = optuna.create_study(direction="minimize")
study_semisupervised.optimize(
    lambda trial: objective(
        trial, train_dataset, val_dataset, device, train_dataset_unlabeled,
        alpha=alpha, epochs=epochs, lr=lr,
    ),
    n_trials=10,
)

# Use found parameters
beta = study_semisupervised.best_params["beta"]
# `Study.optimize` returns None, so read the best trial from the study itself.
best_trial = study_semisupervised.best_trial


## Train and test the final models

In [None]:
# Helper function for testing the quality of the different models
def test_model(model_fn, test_dataset, n_expreiments=10):
    """Retrain a model several times and evaluate each run on ``test_dataset``.

    Args:
        model_fn: zero-argument callable returning a freshly trained classifier
            (e.g. from ``get_model``). The classifier must expose its
            segmentation network as ``.model``, and calling it must return a
            one-element result (``.item()`` is taken).
        test_dataset: sized iterable of ``(x, target, label)`` samples,
            iterated once per experiment.
        n_expreiments: number of independent retrain/evaluate runs. The
            misspelled name is kept because callers pass it by keyword.

    Returns:
        Tuple of numpy arrays ``(dice, f1, accuracy, confusion_matrices)`` with
        one entry per experiment. ``dice`` is the mean of ``1 - dice_loss`` over
        the test set.

    Raises:
        ValueError: if ``test_dataset`` is empty. The check runs before any
            training, so no model is built in that case.
    """
    n_test = len(test_dataset)
    if n_test == 0:
        raise ValueError("test_dataset is empty; cannot compute test metrics")

    dice_values, f1_values, acc_values, cm_values = [], [], [], []

    # Retrain model multiple times and report metrics
    for _ in range(n_expreiments):
        # Training needs autograd, so only the evaluation below runs under no_grad.
        classifier = model_fn()

        dice = 0.0
        preds, truth = [], []
        # Pure inference: disable autograd so no computation graphs are built or kept.
        with torch.no_grad():
            for x, target, label in test_dataset:
                dice += 1 - dice_loss(classifier.model(x), target.unsqueeze(0)).item()
                preds.append(classifier(x).item())
                truth.append(label)
        dice /= n_test

        dice_values.append(dice)
        f1_values.append(f1_score(truth, preds))
        acc_values.append(accuracy_score(truth, preds))
        cm_values.append(confusion_matrix(truth, preds))

    return np.array(dice_values), np.array(f1_values), np.array(acc_values), np.array(cm_values)


# Test metrics for the unregularized (supervised-only) model, each saved to
# "<metric>_supervised.npy".
dice_supervised, f1_supervised, acc_supervised, cm_supervised = test_model(
    lambda: get_model(
        train_dataset, val_dataset, device,
        alpha=alpha, epochs=epochs, lr=lr,
    ),
    test_dataset,
    n_expreiments=10,
)
for metric_name, metric_values in (
    ("dice", dice_supervised),
    ("f1", f1_supervised),
    ("acc", acc_supervised),
    ("cm", cm_supervised),
):
    np.save(f"{metric_name}_supervised", metric_values)

# Test metrics for the semi-supervised model regularized with `beta`, each
# saved to "<metric>_semi.npy".
dice_values, f1_values, acc_values, cm_values = test_model(
    lambda: get_model(
        train_dataset, val_dataset, device, train_dataset_unlabeled,
        alpha=alpha, epochs=epochs, lr=lr, beta=beta,
    ),
    test_dataset,
    n_expreiments=10,
)
for metric_name, metric_values in (
    ("dice", dice_values),
    ("f1", f1_values),
    ("acc", acc_values),
    ("cm", cm_values),
):
    np.save(f"{metric_name}_semi", metric_values)


# Interpret Trained Models 

In [None]:
from project3Lib import gradcam_unet, integrad_unet, shap_unet, evaluate_interpretability
import matplotlib.pyplot as plt
import numpy as np

def interpret_model(model, test_dataset, plot=True, background_dataset=None):
    """Score GradCAM, Integrated Gradients and SHAP attributions of ``model``.

    For every ``(im, target, label)`` in ``test_dataset``, each attribution map
    is scored against the ground-truth mask with ``evaluate_interpretability``.
    The number of non-zero mask entries is passed as its third argument.

    Args:
        model: segmentation network (e.g. the ``.model`` of a ``get_model``
            classifier). GradCAM is computed at its ``down3`` layer.
        test_dataset: iterable of ``(im, target, label)`` samples.
        plot: if True, show per sample the image with its mask overlaid, the
            prediction, and the three attribution maps.
        background_dataset: samples whose inputs are concatenated to form the
            SHAP background. Defaults to the notebook-global ``val_dataset``,
            which is what this function always used before the parameter existed.

    Returns:
        ``(gc_scores, ig_scores, shap_scores)`` as numpy arrays, one entry per sample.
    """
    if background_dataset is None:
        background_dataset = val_dataset
    background = torch.concat([x for x, _, _ in background_dataset])

    gc_scores, ig_scores, shap_scores = [], [], []
    for im, target, label in test_dataset:
        # The mask view and its pixel count are shared by all three methods; compute them once.
        mask = target[0].cpu()
        n_mask_pixels = torch.count_nonzero(target).item()

        # GradCam
        attr_gc = gradcam_unet(model, model.down3, im, target, label).cpu().detach()
        # NOTE(review): only the GradCAM map is negated before scoring —
        # presumably its sign convention is inverted relative to IG/SHAP; confirm.
        gc_scores.append(evaluate_interpretability(-attr_gc[0, 0], mask, n_mask_pixels))

        # Integrated Gradients
        attr_ig = integrad_unet(model, im, target, label).cpu().detach()
        ig_scores.append(evaluate_interpretability(attr_ig[0, 0], mask, n_mask_pixels))

        # Shap
        attr_shap = shap_unet(model, background, im, target, label)
        shap_scores.append(evaluate_interpretability(attr_shap[0, 0], mask, n_mask_pixels))

        if plot:
            pred = model(im).cpu().detach()
            im_cpu = im.cpu().detach()

            fig, axarr = plt.subplots(1, 5)
            axarr[0].imshow(im_cpu[0, 0], cmap="gray")
            axarr[0].imshow(mask, alpha=0.5, cmap="jet")
            axarr[1].imshow(pred[0, 0], cmap="gray")
            axarr[2].imshow(attr_gc[0, 0], alpha=0.5, cmap="gray")
            axarr[3].imshow(attr_ig[0, 0], alpha=0.5, cmap="gray")
            axarr[4].imshow(attr_shap[0, 0], alpha=0.5, cmap="gray")
            plt.show()
            # Close explicitly so figures do not pile up over many test samples.
            plt.close(fig)

    return np.array(gc_scores), np.array(ig_scores), np.array(shap_scores)


# Attribution scores for a freshly trained unregularized (supervised-only)
# model, each saved to "<method>_supervised.npy".
gc_supervised, ig_supervised, shap_supervised = interpret_model(
    get_model(
        train_dataset, val_dataset, device,
        alpha=alpha, epochs=epochs, lr=lr,
    ).model,
    test_dataset,
    plot=True,
)
for method_name, method_scores in (
    ("gc", gc_supervised),
    ("ig", ig_supervised),
    ("shap", shap_supervised),
):
    np.save(f"{method_name}_supervised", method_scores)

# Attribution scores for a freshly trained semi-supervised (beta-regularized)
# model, each saved to "<method>_semi.npy".
gc_semi, ig_semi, shap_semi = interpret_model(
    get_model(
        train_dataset, val_dataset, device, train_dataset_unlabeled,
        alpha=alpha, epochs=epochs, lr=lr, beta=beta,
    ).model,
    test_dataset,
    plot=True,
)
for method_name, method_scores in (
    ("gc", gc_semi),
    ("ig", ig_semi),
    ("shap", shap_semi),
):
    np.save(f"{method_name}_semi", method_scores)
