### Select the best parameter setup
This notebook selects the best parameter setup from a list of settings based on Dice performance, backward transfer (BWT) and forward transfer (FWT).

#### Set paths

In [1]:
import os

# Root folder holding the prediction folders of all experiments. Set the
# MICCAI_2023_PRED_ROOT environment variable to run on another machine instead of
# editing every path below.
PRED_ROOT = os.environ.get(
    'MICCAI_2023_PRED_ROOT',
    '/home/aranem_locale/Desktop/mnts/local/scratch/aranem/MICCAI_2023/MICCAI_2023_predictions',
)

# Parameter settings to compare; one experiment folder per setting and method.
# The order defines the setup index used by the result and selection cells.
SETTINGS = ['0-4', '1-1', '1-7', '2-2']

rwalk_exps = [os.path.join(PRED_ROOT, 'RWalk', f'UNet_VxM_rwalk_{setting}_0-9') for setting in SETTINGS]
ewc_exps = [os.path.join(PRED_ROOT, 'EWC', f'UNet_VxM_ewc_{setting}') for setting in SETTINGS]

#### Import necessary libraries

In [2]:
from math import pi
import numpy as np
import pandas as pd
import seaborn as sns
from time import sleep
import os, pystrum, copy
import SimpleITK as sitk
from tqdm.notebook import trange, tqdm
import matplotlib.pyplot as plt
from torch.autograd import Variable

os.environ['NEURITE_BACKEND'] = "pytorch"
import neurite as ne

  from .autonotebook import tqdm as notebook_tqdm


#### Helper functions and validation splits

In [3]:
def mean_dice_coef(y_true, y_pred_bin, num_classes=1, do_torch=False):
    # from: https://www.codegrepper.com/code-examples/python/dice+similarity+coefficient+python
    # shape of y_true and y_pred_bin: (n_samples, height, width, n_channels)
    batch_size = y_true.shape[0]
    depth = y_true.shape[-1]
    # channel_num = y_true.shape[-1]
    mean_dice_channel = 0.
    # dict contains label: dice per batch
    channel_dices_per_batch = {i+1:list() for i in range(num_classes)}
    for i in range(batch_size):
        # for n in range(depth):
        for j in range(1, num_classes+1):
            y_t = y_true[i, ...].clone() if do_torch else copy.deepcopy(y_true[i, ...])
            y_p = y_pred_bin[i, ...].clone() if do_torch else copy.deepcopy(y_pred_bin[i, ...])
            y_t[y_t != j] = 0
            y_t[y_t == j] = 1
            y_p[y_p != j] = 0
            y_p[y_p == j] = 1
            channel_dice = single_dice_coef(y_t, y_p, do_torch)
            channel_dices_per_batch[j].append(channel_dice)
            # channel_dice = single_dice_coef(y_true[i, :, :, j], y_pred_bin[i, :, :, j], num_classes, do_torch)
            mean_dice_channel += channel_dice/(num_classes*batch_size)
    return mean_dice_channel, channel_dices_per_batch

def single_dice_coef(y_true, y_pred_bin, do_torch=False):
    """Dice coefficient 2*|A∩B| / (|A| + |B|) between two binary (0/1) masks.

    Parameters
    ----------
    y_true, y_pred_bin : masks of identical shape; numpy arrays, or torch tensors
        when ``do_torch`` is True.
    do_torch : compute with torch and return a Python float via ``.item()``.

    Returns
    -------
    The Dice value. Returns 1 when both masks are empty, i.e. perfect agreement
    that nothing is present.
    """
    if do_torch:
        # Imported here because the notebook only imports torch.autograd.Variable,
        # so a module-level `torch` name is not available.
        import torch
        sum_fn = torch.sum
    else:
        sum_fn = np.sum

    true_sum = sum_fn(y_true)
    pred_sum = sum_fn(y_pred_bin)
    if true_sum == 0 and pred_sum == 0:
        return 1
    dice = (2 * sum_fn(y_true * y_pred_bin)) / (true_sum + pred_sum)
    return dice.item() if do_torch else dice

def trunc(values, decs=0):
    """Truncate ``values`` toward zero, keeping ``decs`` decimal places (no rounding)."""
    scale = 10 ** decs
    shifted = values * scale
    return np.trunc(shifted) / scale
    
# Validation case IDs per task folder. The result cells look up val_keys[<task folder name>]
# to decide which cases of each prediction folder are evaluated.
val_keys = dict(
    Task110_RUNMC=['Case03', 'Case08', 'Case12', 'Case15', 'Case18', 'Case26'],
    Task111_BMC=['Case03', 'Case08', 'Case12', 'Case15', 'Case18', 'Case26'],
    Task112_I2CVB=['Case03', 'Case08', 'Case13', 'Case15'],
    Task113_UCL=['Case01', 'Case32', 'Case34'],
    Task114_BIDMC=['Case00', 'Case04', 'Case09'],
    Task115_HK=['Case38', 'Case41', 'Case46'],
    Task116_DecathProst=['prostate_00', 'prostate_04', 'prostate_14', 'prostate_20',
                         'prostate_25', 'prostate_31', 'prostate_42'],
)

### EWC results

In [4]:
# EWC: Dice of every trained model on the validation cases of each task folder it was
# evaluated on, for every parameter setup in ewc_exps.
# Layout: dices_e_all[setup_idx][model_key][task_id] = {
#     <case>: {'moved': per-case Dice (rounded to 4 decimals)},
#     'mean_dice_moved': mean over cases, 'mean_std_moved': std over cases,
#     'mean_dice_std_moved': "mean+/-std" string }
dices_e_all = dict()
for i, path in enumerate(ewc_exps):
    models = sorted(x for x in os.listdir(path) if 'unet' in x and 'joint' not in x)
    dices_e = dict()
    for model in models:
        # Drop the last 7 characters (presumably the method/setting suffix — confirm) and
        # the common prefix. The metric cell expects keys such as '110' or '110_111_..._116'.
        model_key = model[:-7].replace('unet_torch_250_', '')
        preds = os.path.join(path, model, 'predictions')
        dices_e[model_key] = dict()
        for datas in sorted(os.listdir(preds)):
            task_id = datas.split('_')[0].replace('Task', '')  # 'Task110_RUNMC' -> '110'
            task_res = dict()
            dices_e[model_key][task_id] = task_res
            dices_ = list()
            for case in val_keys[datas]:
                case_res = dict()
                task_res[case] = case_res
                case_dir = os.path.join(preds, datas, case)
                gt = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(case_dir, 'seg_gt.nii.gz')))
                y_p = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(case_dir, 'pred_seg.nii.gz')))
                _, channel_dices_per_batch = mean_dice_coef(gt, y_p, 1, False)
                # One value per label: the mean of the Dice values mean_dice_coef returned for it
                # (ground truth vs. predicted segmentation; the 'moved' key is kept as-is).
                dice = [np.mean(v) for v in channel_dices_per_batch.values()]
                case_res['moved'] = np.round(dice, 4)
                dices_.append(dice)
            mean_dice, std_dice = np.mean(dices_), np.std(dices_)
            task_res['mean_dice_moved'] = mean_dice
            task_res['mean_std_moved'] = std_dice
            task_res['mean_dice_std_moved'] = str(np.round(mean_dice, 4)) + '+/-' + str(np.round(std_dice, 4))
    dices_e_all[i] = dices_e

### RWalk results

In [5]:
# RWalk: Dice of every trained model on the validation cases of each task folder it was
# evaluated on, for every parameter setup in rwalk_exps.
# Layout: dices_r_all[setup_idx][model_key][task_id] = {
#     <case>: {'moved': per-case Dice (rounded to 4 decimals)},
#     'mean_dice_moved': mean over cases, 'mean_std_moved': std over cases,
#     'mean_dice_std_moved': "mean+/-std" string }
dices_r_all = dict()
for i, path in enumerate(rwalk_exps):
    models = sorted(x for x in os.listdir(path) if 'unet' in x and 'joint' not in x)
    dices_r = dict()
    for model in models:
        # Drop the last 9 characters (presumably the method/setting suffix — confirm) and
        # the common prefix. The metric cell expects keys such as '110' or '110_111_..._116'.
        model_key = model[:-9].replace('unet_torch_250_', '')
        preds = os.path.join(path, model, 'predictions')
        dices_r[model_key] = dict()
        for datas in sorted(os.listdir(preds)):
            task_id = datas.split('_')[0].replace('Task', '')  # 'Task110_RUNMC' -> '110'
            task_res = dict()
            dices_r[model_key][task_id] = task_res
            dices_ = list()
            for case in val_keys[datas]:
                case_res = dict()
                task_res[case] = case_res
                case_dir = os.path.join(preds, datas, case)
                gt = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(case_dir, 'seg_gt.nii.gz')))
                y_p = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(case_dir, 'pred_seg.nii.gz')))
                _, channel_dices_per_batch = mean_dice_coef(gt, y_p, 1, False)
                # One value per label: the mean of the Dice values mean_dice_coef returned for it
                # (ground truth vs. predicted segmentation; the 'moved' key is kept as-is).
                dice = [np.mean(v) for v in channel_dices_per_batch.values()]
                case_res['moved'] = np.round(dice, 4)
                dices_.append(dice)
            mean_dice, std_dice = np.mean(dices_), np.std(dices_)
            task_res['mean_dice_moved'] = mean_dice
            task_res['mean_std_moved'] = std_dice
            task_res['mean_dice_std_moved'] = str(np.round(mean_dice, 4)) + '+/-' + str(np.round(std_dice, 4))
    dices_r_all[i] = dices_r

In [6]:
# Results analysed by the metric and selection cells below.
# Set ANALYSED_METHOD to 'rwalk' to analyse the RWalk runs instead of EWC.
ANALYSED_METHOD = 'ewc'
data = {'ewc': dices_e_all, 'rwalk': dices_r_all}[ANALYSED_METHOD]

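#### Forgetting, Positive Backward Transfer, Remembering and Forward Transfer
- BWT (every task except the last) is the Dice on the task after training on all tasks, minus the Dice right after training on that task. Positive BWT = max(BWT, 0), forgetting = |min(BWT, 0)|, remembering = 1 − forgetting.
- FWT (every task except the first) is the Dice on the task of the model trained up to the previous task, minus the Dice of the model trained only on that task.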

In [7]:
tasks = ['110', '111', '112', '113', '114', '115', '116']
# Setup indices present in `data` (0..n-1, as built by the result cells). Deriving them
# from `data` keeps this cell consistent whichever method `data` points to.
methods = sorted(data)
# STD is kept for compatibility; this cell does not fill it.
B_T, F_T, F_NEG, F_POS, B_TP, FOR, REM, Dice, Dice_F, Dice_L, STD = [dict() for _ in range(11)]


def _mean_dice(setup_results, model_key, task):
    """Mean validation Dice on `task` of model `model_key`, or None if unavailable.

    `model_key` is the underscore-joined task sequence the model was trained on
    (e.g. '110' or '110_111'). A missing model or missing evaluation yields None.
    """
    try:
        return setup_results[model_key][task]['mean_dice_moved']
    except KeyError:
        return None


all_t_j = '_'.join(tasks)
for m in methods:
    for metric in (B_T, F_T, F_NEG, F_POS, B_TP, FOR, REM, Dice, Dice_F, Dice_L, STD):
        metric[m] = dict()
    setup = data[m]
    for idx, t in enumerate(tasks):
        seq_key = '_'.join(tasks[:idx + 1])           # model trained on tasks up to and including t
        dice_final = _mean_dice(setup, all_t_j, t)    # Dice on t after training on all tasks
        dice_after_t = _mean_dice(setup, seq_key, t)  # Dice on t right after training on t

        # Each metric is recorded whenever its own inputs exist. A single missing model
        # (e.g. the single-task baseline) no longer silently drops the other metrics for t.
        if idx > 0:
            # Forward transfer: Dice on t before training on it minus the single-task model.
            dice_before_t = _mean_dice(setup, '_'.join(tasks[:idx]), t)
            dice_single = _mean_dice(setup, t, t)
            if dice_before_t is not None and dice_single is not None:
                F_T[m][t] = dice_before_t - dice_single
                F_NEG[m][t] = abs(min(F_T[m][t], 0))
                F_POS[m][t] = 1 - abs(min(F_T[m][t], 0))

        if dice_final is not None:
            Dice[m][t] = dice_final

        if seq_key == all_t_j:
            # Last task only: Dice of the final model on the first and on the last task.
            dice_first = _mean_dice(setup, all_t_j, tasks[0])
            dice_last = _mean_dice(setup, all_t_j, tasks[-1])
            if dice_first is not None:
                Dice_F[m][t] = dice_first
            if dice_last is not None:
                Dice_L[m][t] = dice_last
        elif dice_final is not None and dice_after_t is not None:
            # Backward transfer and derived positive BWT / forgetting / remembering.
            B_T[m][t] = dice_final - dice_after_t
            B_TP[m][t] = max(B_T[m][t], 0)
            FOR[m][t] = abs(min(B_T[m][t], 0))
            REM[m][t] = 1 - abs(min(B_T[m][t], 0))

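#### Select the best setting
Each setup is scored by mean BWT + mean FWT + mean Dice (all in %); the setup with the highest score is selected.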

In [8]:
# Calculate mean Dices, BWT and FWT per setup and select the best setup.
# Score = mean BWT + mean FWT + mean Dice (all in %); higher is better, and ties keep
# the earlier setup.
# Pick the experiment list that matches the results in `data`, so the reported folder
# is correct for both EWC and RWalk.
exps = ewc_exps if data is dices_e_all else rwalk_exps
metrics = [('BWT', B_T), ('REM', REM), ('FWT', F_T), ('Dice', Dice), ('Forgetting', FOR)]

print('Mean BWT, REM, FWT, Dice, Forgetting [%]')
best_m = -np.inf
best_setup = None
for m in methods:
    means = dict()
    print(f"Setup {m}:")
    for name, metric in metrics:
        vals = list(metric[m].values())
        if not vals:
            # Avoid np.mean([]) warnings; a missing metric makes the score NaN.
            means[name] = np.nan
            print(f"  {name}: n/a (no values)")
            continue
        means[name] = np.mean(vals) * 100
        print(f"  {name}: {means[name]} +/- {np.std(vals) * 100}")
    m_ = means['BWT'] + means['FWT'] + means['Dice']
    print(f"{m}: Mean over Dices, BWT and FWT: {m_}")
    # A NaN score compares False here and therefore never wins.
    if m_ > best_m:
        best_m = m_
        best_setup = m

if best_setup is None:
    print("No setup has BWT, FWT and Dice values; cannot select a best setup.")
else:
    print(f"The best parameters are from {best_setup}: {exps[best_setup]}.")

Mean BWT, REM, FWT, Dice, Forgetting [%]
-28.467163746309826 12.34699631455628
71.53283625369018 12.346996314556279
-28.87330350939934 19.482110998035072
42.43886571857907 20.67217666252717
28.467163746309826 12.34699631455628
0: Mean over Dices, BWT and FWT: -14.901601537130091
-29.275575257004533 14.815028007884608
70.72442474299547 14.815028007884608
-27.473033231369488 16.28978944820951
44.82734995175964 20.855827905917298
29.275575257004533 14.815028007884608
1: Mean over Dices, BWT and FWT: -11.921258536614374
-30.959548135947294 15.333171893247815
69.0404518640527 15.333171893247815
-30.593663096020325 16.66339244126253
44.02862974661069 20.485931102830307
30.959548135947294 15.333171893247815
2: Mean over Dices, BWT and FWT: -17.524581485356933
-27.47769966669711 14.996302357665229
72.52230033330291 14.996302357665229
-29.65065769253637 18.31779875409807
46.05998528939446 18.45626251600051
27.47769966669711 14.996302357665229
3: Mean over Dices, BWT and FWT: -11.068372069839015