### Import libraries

In [None]:
import sys
sys.path.append("../")

In [None]:
import os
from os import environ
import numpy as np
from random import choices
import pandas as pd
from tqdm.notebook import tqdm
import torch

In [None]:
import skimage.io as io

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.transforms as mtrans
%matplotlib inline

In [None]:
from preprocess.common import load_nii
import copy

In [None]:
from torch.utils.data import DataLoader
from torchcontrib.optim import SWA

# ---- My utils ----
from utils.data_augmentation import data_augmentation_selector
from utils.dataload import *
from utils.training import *

In [None]:
from pylab import rcParams
# Default figure size (width, height) and a dashed grey grid for every plot.
rcParams["figure.figsize"] = (8, 5)
plt.rc("grid", color="gray", linestyle="--")

# Bar colours for the per-vendor charts below, picked with
# https://learnui.design/tools/data-color-picker.html#palette
colors = [
    "#33508f",
    "#ff5d68",
    "#ffa600",
    "#af4f9b",
]

--------------------------------

### Load Data

In [None]:
train_aug, train_aug_img, val_aug = data_augmentation_selector("none", 224, 224)

In [None]:
data_partition = "validation"
general_aug, img_aug = train_aug, train_aug_img
normalization = "standardize"
fold_system = "patient"
label_type = "vendor_label_full"
data_fold = 0
add_depth=False
in_channels = 3 if add_depth else 1
data_fold_validation=None

discriminator_val_dataset = MMsDataset(
    mode=data_partition, transform=train_aug, img_transform=train_aug_img,
    folding_system=fold_system, normalization=normalization, label_type=label_type,
    train_fold=data_fold, val_fold=data_fold_validation, add_depth=add_depth
)

discriminator_loader = DataLoader(discriminator_val_dataset, batch_size=1, shuffle=False, drop_last=False)

In [None]:
data_partition = "validation"
general_aug, img_aug = train_aug, train_aug_img
normalization = "standardize"
fold_system = "vendor"
label_type = "mask"


segmentation_val_dataset = MMsDataset(
    mode=data_partition, transform=general_aug, img_transform=img_aug,
    folding_system=fold_system, normalization=normalization, label_type=label_type,
    train_fold="A", val_fold="B",
)

segmentation_loader = DataLoader(segmentation_val_dataset, batch_size=1, shuffle=False, drop_last=False)

In [None]:
val_same_patients = np.intersect1d(discriminator_val_dataset.df["External code"], segmentation_val_dataset.df["External code"])
print(f"Pacientes en común ({len(val_same_patients)}): {val_same_patients}")

------------------------------------------

### Load Models

In [None]:
from models import *

In [None]:
num_classes, crop_size, model_name = 3, 224, "resnet34_unet_scratch_classification"

discriminator = model_selector(model_name, num_classes=num_classes, in_channels=in_channels)
model_total_params = sum(p.numel() for p in discriminator.parameters())
print("Model total number of parameters: {}".format(model_total_params))
discriminator = torch.nn.DataParallel(discriminator, device_ids=range(torch.cuda.device_count()))

###########################################################################################

model_checkpoint = "../checkpoints/full_discriminator_{}channel_fold{}.pt".format(in_channels, data_fold)
discriminator.load_state_dict(torch.load(model_checkpoint))
print("Discriminator checkpoint loaded correctly!")

In [None]:
num_classes, crop_size, model_name = 4, 224, "resnet34_unet_scratch"

segmentator = model_selector(model_name, num_classes=num_classes, in_channels=in_channels)
model_total_params = sum(p.numel() for p in segmentator.parameters())
print("Model total number of parameters: {}".format(model_total_params))
segmentator = torch.nn.DataParallel(segmentator, device_ids=range(torch.cuda.device_count()))

###########################################################################################

segmentation_train_fold = 'A'
segmentation_val_fold = 'B'
model_checkpoint = "../checkpoints/segmentator_{}vs{}_{}channel.pt".format(segmentation_train_fold, segmentation_val_fold, in_channels)
segmentator.load_state_dict(torch.load(model_checkpoint))
print("Segmentator checkpoint loaded correctly!")

-----------------------------------

## Label Modification

#### Check discriminator accuracy

In [None]:
criterion, weights_criterion = "ce", "default"
criterion, weights_criterion, multiclass_criterion = get_criterion(criterion, weights_criterion)
task = "classification" # binary_classification or classification

In [None]:
accuracy, val_loss = val_step_accuracy(
    discriminator_loader, discriminator, criterion, weights_criterion, multiclass_criterion, task=task
)

In [None]:
print(f"Discriminator accuracy: {accuracy}")

#### Check segmentator metrics

In [None]:
train_csv = pd.read_csv("../utils/data/train.csv")

In [None]:
os.makedirs("testa", exist_ok=True)
iou, dice, val_loss, stats = val_step(
    segmentation_loader, segmentator, criterion, weights_criterion, multiclass_criterion, 0.5,
    generate_stats=True, save_path="testa",
    generate_overlays=False,
)

In [None]:
def clean_stats(df, train_df):
    """Fill missing metrics and attach per-patient metadata to a stats table.

    Parameters:
        df: per-sample stats with at least ``patient`` and ``phase`` columns
            (e.g. the ``stats`` returned by ``val_step(..., generate_stats=True)``).
            It is not modified.
        train_df: metadata table with ``External code``, ``Phase``, ``Centre``,
            ``Vendor`` and ``Type`` columns.

    Returns:
        A new DataFrame in which every NaN is replaced by 1 and ``Vendor``,
        ``Centre`` and ``Type`` columns are filled from ``train_df``.
        ``Vendor`` and ``Centre`` come from the patient's first metadata row;
        ``Type`` from the first row matching both patient and ``int(phase)``.

    Raises:
        IndexError: if a patient, or a (patient, phase) pair, is missing
            from ``train_df``.
    """
    # NOTE(review): NaN metrics are counted as perfect scores (1) -- presumably
    # slices where prediction and ground truth are both empty; confirm.
    df = df.fillna(1)
    # Placeholders, overwritten row by row below.
    df["Vendor"] = "Z"
    df["Centre"] = 999
    df["Type"] = "XX"

    for i, row in df.iterrows():
        # Filter the metadata table once per row and reuse the patient's rows
        # for all three lookups instead of scanning train_df three times.
        patient_rows = train_df.loc[train_df["External code"] == row["patient"]]
        first_row = patient_rows.iloc[0]
        phase_rows = patient_rows.loc[patient_rows["Phase"] == int(row["phase"])]
        c_type = phase_rows.iloc[0]["Type"]

        df.at[i, "Vendor"] = first_row["Vendor"]
        df.at[i, "Centre"] = first_row["Centre"]
        df.at[i, "Type"] = c_type

    return df

In [None]:
stats = clean_stats(stats, train_csv)
same_stats = stats[stats['patient'].isin(val_same_patients)]
same_stats["Vendor"].value_counts()

In [None]:
same_stats.groupby("Vendor")["IOU_MEAN"].mean()

In [None]:
same_stats.groupby("Vendor")["IOU_MEAN"].mean().plot.bar(color=colors)
# -------------------------------------------------------------- #
plt.ylabel("Mean IOU")
plt.xticks(rotation='horizontal')
plt.yticks(np.arange(0, same_stats.groupby("Vendor")["IOU_MEAN"].mean().max()+0.05, .05))
plt.title("Mean IOU by Vendor")
plt.grid()
#plt.savefig(os.path.join(save_dir, 'iou_vendor.png'), bbox_inches='tight', dpi=160)

# Image modification guided by the discriminator's cross-entropy

In [None]:
def CXE(predicted, target):
    """Soft-target cross-entropy between predicted and target probabilities.

    Computes ``-(target * log(predicted)).sum(dim=1).mean()``: the
    cross-entropy summed over dim 1 (assumed to be the class dimension)
    and averaged over the remaining elements.

    Parameters:
        predicted: probabilities, e.g. a softmax output (floating point).
        target: target probabilities, broadcastable against ``predicted``.

    Returns:
        Scalar tensor.
    """
    # A probability that underflowed to exactly 0 would give log(0) = -inf,
    # and 0 * -inf = NaN for classes whose target is 0, poisoning the loss and
    # its gradients. Clamping to the dtype's smallest normal number only
    # affects zero/subnormal inputs.
    eps = torch.finfo(predicted.dtype).tiny
    return -(target * torch.log(predicted.clamp(min=eps))).sum(dim=1).mean()

### Baseline: without image modification

In [None]:
stats = val_step_experiments(segmentation_loader, segmentator, val_same_patients, train_csv,
                         num_classes=4, generate_imgs=False, image_modificator_fn=None)

In [None]:
stats.groupby("Vendor")["IOU_MEAN"].mean()

In [None]:
stats.groupby("Vendor")["IOU_MEAN"].mean().plot.bar(color=colors)
# -------------------------------------------------------------- #
plt.ylabel("Mean IOU")
plt.xticks(rotation='horizontal')
plt.yticks(np.arange(0, stats.groupby("Vendor")["IOU_MEAN"].mean().max()+0.05, .05))
plt.title("Mean IOU by Vendor")
plt.grid()
#plt.savefig(os.path.join(save_dir, 'iou_vendor.png'), bbox_inches='tight', dpi=160)

### Simple CXE (soft-target cross-entropy towards a fixed target)

In [None]:
class ImageBackwardEntropy:
    """Modify an image until a discriminator's softmax output matches a target.

    Gradient descent is run on a copy of the input image (never on the
    discriminator's weights). The loss is the soft-target cross-entropy
    ``CXE(softmax(model(x)), target)``, optionally plus
    ``l1_lambda * L1Loss(image, x)`` to keep ``x`` close to the original image.
    """

    def __init__(self, discriminator_model, target, max_epochs=500, 
                 out_threshold=0.01, grad_gamma=0.9, add_l1=False, l1_lambda=10, verbose=False):
        """
        Parameters:
            discriminator_model: callable classifier; ``softmax(dim=1)`` is
                applied to its output.
            target: desired output distribution as a tensor, broadcastable
                against the softmax output (e.g. shape ``(num_classes,)``).
            max_epochs: maximum number of gradient steps on the image.
            out_threshold: stop as soon as every ``|softmax - target|`` entry
                is ``<= out_threshold``.
            grad_gamma: step size multiplying the image gradient.
            add_l1: if True, add the L1 distance to the original image to the loss.
            l1_lambda: weight of the L1 term (only used when ``add_l1``).
            verbose: print initial/final predictions at the end of ``apply``.
        """

        self.discriminator_model = discriminator_model
        self.target = target
        self.max_epochs = max_epochs
        self.out_threshold = out_threshold
        self.grad_gamma = grad_gamma
        self.verbose = verbose
        self.add_l1 = add_l1
        self.l1_lambda = l1_lambda

    def apply(self, image):
        """Optimise a copy of ``image`` towards ``self.target``.

        Parameters:
            image: input tensor for ``discriminator_model``; it is not modified.

        Returns:
            (x, y): the modified image and the discriminator's softmax output
            for that same ``x``, both detached. With ``max_epochs <= 0``,
            ``x`` is an unmodified copy and ``y`` its initial prediction.
        """
        softmax = torch.nn.functional.softmax
        reference = image.detach()
        x = reference.clone()

        with torch.no_grad():
            initial_y = softmax(self.discriminator_model(x), dim=1)

        # Place the target next to the predictions once (instead of calling
        # .cuda() every iteration); this also lets the class run on CPU.
        target = self.target.to(initial_y.device)
        l1_loss = torch.nn.L1Loss()

        y = initial_y
        stopped_at = None
        for k in range(self.max_epochs):
            x.requires_grad_(True)
            y = softmax(self.discriminator_model(x), dim=1)

            # Check convergence before stepping, so the returned x is the
            # image that actually produced y.
            if (y - target).abs().max() <= self.out_threshold:
                stopped_at = k
                break

            # https://discuss.pytorch.org/t/catrogircal-cross-entropy-with-soft-classes/50871
            error = CXE(y, target)
            if self.add_l1:
                error = error + l1_loss(reference, x) * self.l1_lambda

            # Differentiate w.r.t. the image only: error.backward() would also
            # accumulate .grad in the discriminator's parameters on every call.
            (x_grad,) = torch.autograd.grad(error, x)
            x = x.detach() - self.grad_gamma * x_grad
        else:
            # Step budget exhausted: x moved after y was computed, so
            # re-evaluate to report the prediction for the returned image.
            if self.max_epochs > 0:
                with torch.no_grad():
                    y = softmax(self.discriminator_model(x), dim=1)

        x = x.detach()
        y = y.detach()

        if self.verbose:
            print("")
            if stopped_at is not None:
                print(f"----- Early stopping at iteration {stopped_at} -----")
            print("Target: {}".format(self.target))
            print("Initial y: {}".format(['%.4f' % elem for elem in initial_y.tolist()[0]]))
            print("Final y: {}".format(['%.4f' % elem for elem in y.tolist()[0]]))
            print("")

        return x, y

In [None]:
target = torch.from_numpy(np.array([1.0, 0.0, 0.0]))
out_threshold = 0.01
grad_gamma=0.99
max_epochs=50

image_modificator_fn = ImageBackwardEntropy(
    discriminator, target, max_epochs=max_epochs, 
    out_threshold=out_threshold, grad_gamma=grad_gamma, verbose=False,
    add_l1=True, l1_lambda=10
)

entropy_descriptor = "simple"

In [None]:
stats = val_step_experiments(
    segmentation_loader, segmentator, val_same_patients, train_csv,
    num_classes=4, generate_imgs=True, image_modificator_fn=image_modificator_fn,
    save_dir="entropy_images/{}vs{}/{}/outThreshold{}_gradGamma{}_maxEpochs{}".format(segmentation_train_fold, segmentation_val_fold, entropy_descriptor, out_threshold, grad_gamma, max_epochs)
)

In [None]:
stats.groupby("Vendor")["IOU_MEAN"].mean()

In [None]:
stats.groupby("Vendor")["IOU_MEAN"].mean().plot.bar(color=colors)
# -------------------------------------------------------------- #
plt.ylabel("Mean IOU")
plt.xticks(rotation='horizontal')
plt.yticks(np.arange(0, stats.groupby("Vendor")["IOU_MEAN"].mean().max()+0.05, .05))
plt.title("Mean IOU by Vendor")
plt.grid()
#plt.savefig(os.path.join(save_dir, 'iou_vendor.png'), bbox_inches='tight', dpi=160)