### running the evaluation for the challenge dataset

In [1]:
%load_ext autoreload

In [2]:
%autoreload

In [3]:
import torch
import torch.nn as nn
import numpy as np
import os
import matplotlib.pyplot as plt
from torchinfo import summary
from collections import defaultdict, namedtuple
import pandas as pd
# import proplot as pplt
import warnings
warnings.filterwarnings("ignore") # I should only put this in place when I am doing plotting and using proplot....
from twaibrain.braintorch.utils.resize import crop_or_pad_dims
from twaibrain.brainpreprep.utils.image_io import load_image
import scipy
import scipy.stats

# model architecture
import json
from twaibrain.braintorch.models.nnUNet.nnUNetV2_model_loader import get_network_from_plans
from twaibrain.braintorch.models.ssn import SSN_Wrapped_Deep_Supervision, SSN_Wrapped_Deep_Supervision_LLO, Hierarchical_SSN_with_ConvRefine, Hierarchical_SSN_with_ConvSpatialAttention

# fitting code
from twaibrain.braintorch.fitting_and_inference.get_trainer import get_trainer
from twaibrain.braintorch.fitting_and_inference.get_scratch_dir import scratch_dir
from twaibrain.braintorch.fitting_and_inference.optimizer_constructor import OptimizerConfigurator
from twaibrain.braintorch.fitting_and_inference.lightning_fitter import StandardLitModelWrapper

# loss function
from twaibrain.braintorch.losses.ssn_losses import DeepSupervisionSSN, SSNCombinedDiceXent_and_MC_loss, SSNCombinedDiceXent_and_MC_loss_FromSamples
from twaibrain.braintorch.losses.generic_deep_supervision import MultiDeepSupervisionLoss, DeepSupervisionLoss
from twaibrain.braintorch.losses.dice_loss import SoftDiceV2
from twaibrain.braintorch.losses.xent import dice_xent_loss

# data
from twaibrain.brainexperiments.run_nnUNet_v2.old_dataloading.dataset_pipelines import load_data
from twaibrain.braintorch.data.legacy_dataset_types.dataset_wrappers import MonaiAugmentedDataset
from twaibrain.braintorch.augmentation.nnunet_augmentations import get_nnunet_transforms, get_val_transforms
from torch.utils.data import ConcatDataset
from twaibrain.braintorch.data.legacy_dataset_types.mri_dataset_inram import MRISegmentation3DDataset
from twaibrain.braintorch.data.legacy_dataset_types.mri_dataset_from_file import MRISegmentationDatasetFromFile, ArrayMRISegmentationDatasetFromFile
from twaibrain.braintorch.data.legacy_dataset_types.mri_dataset_directory_parsers import *

# evaluation code
from twaibrain.brainexperiments.run_nnUNet_v2.evaluation.eval_helper_functions import *
from twaibrain.brainexperiments.run_nnUNet_v2.evaluation.model_predictions import get_means_and_samples_2D, ssn_ensemble_mean_and_samples
from twaibrain.brainexperiments.run_nnUNet_v2.evaluation.model_predictions import *

### setting up params for the rest of the code to run

In [4]:
ARGS = namedtuple("args", "dataset test_split val_split seed " + 
                  "empty_slice_retention batch_size cross_validate cv_split cv_test_fold_smooth no_test_fold " +
                  "num_workers dice_factor xent_factor xent_reweighting eval_split uncertainty_type ckpt_dir"
)

In [5]:
VOXELS_TO_WMH_RATIO = 382
VOXELS_TO_WMH_RATIO_EXCLUDING_EMPTY_SLICES = 140
ESR = 0.5


# setup xent reweighting factor
XENT_VOXEL_RESCALE = VOXELS_TO_WMH_RATIO - (1-ESR) * (VOXELS_TO_WMH_RATIO - VOXELS_TO_WMH_RATIO_EXCLUDING_EMPTY_SLICES)

XENT_WEIGHTING = XENT_VOXEL_RESCALE/2

In [18]:
args = ARGS(
    dataset = "ed",
    
    test_split=0.15,
    val_split=0.15,
    eval_split='test',

    seed=5,

    empty_slice_retention=ESR,

    uncertainty_type="deterministic",

    batch_size=2,
    cross_validate=True,
    cv_split=0,
    cv_test_fold_smooth=1,
    no_test_fold="false",
    num_workers=16,

    dice_factor=1,
    xent_factor=1,
    xent_reweighting=XENT_WEIGHTING,
    ckpt_dir= "/home/s2208943/projects/twaibrain/twaibrain/brainexperiments/run_nnUNet_v2/training/ssn_model_ckpts/"
)

### loading eval data

In [7]:
def load_cv_test_set(cv_split):

    data_dict = load_data(
        dataset=args.dataset, 
        test_proportion=args.test_split, 
        validation_proportion=args.val_split,
        seed=args.seed,
        empty_proportion_retained=args.empty_slice_retention,
        batch_size=args.batch_size,
        dataloader2d_only=False,
        cross_validate=True,
        cv_split=cv_split,
        cv_test_fold_smooth=args.cv_test_fold_smooth,
        merge_val_test=args.no_test_fold
    )
    
    if args.eval_split == "all":
        eval_ds = ConcatDataset([data_dict['train_dataset3d'], data_dict['val_dataset3d'], data_dict['test_dataset3d']])
    else:
        eval_ds = data_dict[f'{args.eval_split}_dataset3d']
    
    # get the xs and ys
    xs3d_test = []
    ys3d_test = []
    
    for i, data in enumerate(eval_ds):
        ys3d_test.append(data[1].squeeze())
        xs3d_test.append(data[0])
    
    ys3d_test = [y * (y==1).type(y.dtype) for y in ys3d_test] # fix bug with challenge data having 3 classes on cluster only?
    gt_vols = GT_volumes(ys3d_test)
    
    # to run on the nnUnet version of the model,
    # I should crop the x and y to the min and max points
    # this will also effectively centre the images for evaluation.
    for i in range(len(xs3d_test)):
        x = xs3d_test[i]
        y = ys3d_test[i]
        wheres = torch.where(x[-1])
        zs = (wheres[0].min().item(), wheres[0].max().item())
        xs = (wheres[1].min().item(), wheres[1].max().item())
        ys = (wheres[2].min().item(), wheres[2].max().item())
        # print(zs, xs, ys)
        # print(x.shape, y.shape)
    
        x = x[:, zs[0]:zs[1]+1, xs[0]:xs[1]+1, ys[0]:ys[1]+1]
        y = y[zs[0]:zs[1]+1, xs[0]:xs[1]+1, ys[0]:ys[1]+1]
    
        # print(x.shape, y.shape)
        
        x = crop_or_pad_dims(x, [1,2,3], [48, 192, 192])
        y = crop_or_pad_dims(y, [0,1,2], [48, 192, 192])
    
        x = x.unsqueeze(0)
    
        # print(x.shape, y.shape)
        # print()
        xs3d_test[i] = x
        ys3d_test[i] = y

    return xs3d_test, ys3d_test, gt_vols

### load a model checkpoint and get relevant predictions

In [8]:
model_config = "/home/s2208943/projects/twaibrain/twaibrain/braintorch/models/nnUNet/cvd_configs/nnUNetResEncUNetMPlans.json"

with open(model_config) as f:
    model_config = json.load(f)

In [9]:
dims = "2d"
config = model_config['configurations'][dims]['architecture']
network_name = config['network_class_name']
kw_requires_import = config['_kw_requires_import']

def get_model():
    return get_network_from_plans(
        arch_class_name=network_name,
        arch_kwargs=config['arch_kwargs'],
        arch_kwargs_req_import=kw_requires_import,
        input_channels=3,
        output_channels=2,
        allow_init=True,
        deep_supervision=True,
    )

model = get_model()

In [15]:
def get_model_preds(cv_fold, model=model, model_name="nnunet", ckpt_path=None, model_func=deterministic_mean, do_reorder_samples=False, load_ckpt=True, out_domain=True):
    if load_ckpt:
        model = load_best_checkpoint(model, None, '.', cv_fold, ckpt_path).model

    xs3d_test, ys3d_test, gt_vols = load_cv_test_set(int(cv_fold))

    means, samples, miscs = get_means_and_samples_2D(model, zip(xs3d_test, ys3d_test), 10, model_func, args=args)

    means = [m.squeeze(0).swapaxes(0, 1) for m in means]
    means = [m[:,:2] for m in means]
    if samples[0] is not None:
        samples = [s.swapaxes(2, 3) for s in samples]
        samples = [s[:,:,:,:2].squeeze() for s in samples]

    chal_results = per_model_chal_stats(means, ys3d_test)

    rmses = []
    for m, y in zip(means, ys3d_test):
        m = m.cuda()
        m = m.softmax(dim=1)[:,:2]
        rmses.append(fast_rmse(m, y.cuda()).cpu())
    rmses = torch.Tensor(rmses)
    chal_results['rmse'] = rmses

    chal_results['gt_vols'] = gt_vols

    # run the evaluation on the samples
    print("GETTING PER SAMPLE RESULTS")
    if samples[0] is not None:
        if do_reorder_samples:
            samples = [reorder_samples(s) for s in samples]
        sample_top_dices, sample_dices = per_sample_metric(samples, ys3d_test, f=fast_dice, do_argmax=True, do_softmax=False, minimise=False)
        sample_best_avds, sample_avds = per_sample_metric(samples, ys3d_test, f=fast_avd, do_argmax=True, do_softmax=False, minimise=True)
        sample_best_rmses, sample_rmses = per_sample_metric(samples, ys3d_test, f=fast_rmse, do_argmax=False, do_softmax=True, minimise=True)
    
        # best dice, avd, rmse
        chal_results['best_dice'] = sample_top_dices
        chal_results['best_avd'] = sample_best_avds
        chal_results['best_rmse'] = sample_best_rmses
        
        _, sample_vds = per_sample_metric(samples, ys3d_test, f=fast_vd, do_argmax=True, do_softmax=False, minimise=True, take_abs=False)
        sample_vd_skew = torch.from_numpy(scipy.stats.skew(sample_vds, axis=1, bias=True))
    
        # vd of the sample distribution
        chal_results['sample_vd_skew'] = sample_vd_skew
        for s in range(sample_vds.shape[1]):
            chal_results[f'sample_{s}_vd'] = sample_vds[:,s]
            
        # ged score
        geds = iou_GED(means, ys3d_test, samples)
        chal_results['GED^2'] = geds

    # get the uncertainty maps
    print("GENREATING UNCERTAINTY MAPS")
    uncertainty_thresholds = torch.arange(0, 0.7, 0.01)
    ent_maps = get_uncertainty_maps(means, samples, miscs, args)

    # pavpu
    print("PAVPU")
    all_acc_cert, all_uncert_inacc,all_pavpu = all_individuals_pavpu(means, ent_maps, ys3d_test, 4, 0.8, uncertainty_thresholds)
    
    for i, tau in enumerate(uncertainty_thresholds):
        chal_results[f'p_acc_cert_{tau:.2f}'] = all_acc_cert[:,i]
        chal_results[f'p_uncert_inacc_{tau:.2f}'] = all_uncert_inacc[:,i]
        chal_results[f'pavpu_{tau:.2f}'] = all_pavpu[:,i]
    
    # sUEO score and UEO per threshold
    print("UEO")
    sUEOs = get_sUEOs(means, ys3d_test, ent_maps)
    chal_results['sUEO'] = sUEOs
    ueos = UEO_per_threshold_analysis(uncertainty_thresholds, ys3d_test, ent_maps, means, 0.7)
    for i, tau in enumerate(uncertainty_thresholds):
        chal_results[f'UEO_{tau:.2f}'] = ueos[i]
    
    # 3D connected component analysis
    print("3D CC ANALYSIS")
    num_lesions_all, sizes_all, mean_missed_area3d_all, mean_size_missed_lesions3d_all, mean_cov_mean_missed_lesions3d_all, prop_lesions_missed3d_all = do_3d_cc_analysis_per_individual(means, ys3d_test, ent_maps, uncertainty_thresholds)
    for i, tau in enumerate(uncertainty_thresholds):
        chal_results[f'mean_missed_area3d_all_{tau:.2f}'] = torch.stack(mean_missed_area3d_all)[:,i]
        chal_results[f'mean_cov_mean_missed_lesions3d_all_{tau:.2f}'] = torch.stack(mean_cov_mean_missed_lesions3d_all)[:,i]
        chal_results[f'mean_size_missed_lesions3d_all_{tau:.2f}'] = torch.stack(mean_size_missed_lesions3d_all)[:,i]
        chal_results[f'prop_lesions_missed3d_all_{tau:.2f}'] = torch.stack(prop_lesions_missed3d_all)[:,i]
        

    # save the results
    print("SAVING RESULTS")
    domain_folder = "out_domain_results" if out_domain else "in_domain_results"
    results_out_dir = f"/home/s2208943/ipdis/WMH_UQ_assessment/trustworthai/journal_run/evaluation/results/cross_validated_results/{domain_folder}/"
    write_per_model_channel_stats(results_out_dir, model_name, f"{model_name}0_cv{cv_fold}", preds=None, ys3d_test=None, args=args, chal_results=chal_results)
    
    print("DONE")

In [19]:
base_folder = "/home/s2208943/projects/twaibrain/twaibrain/brainexperiments/run_nnUNet_v2/training/model_ckpts/"
for folder in os.listdir(base_folder):
    if "3D" in folder:
        continue
    print(folder)
    cv_split = folder[-1]

    if cv_split == "3":
        continue
    
    # try:
    ckpt = sorted([f for f in os.listdir(os.path.join(base_folder, folder)) if f.endswith(".ckpt")])[-1]
    get_model_preds(cv_split, model=model, model_name="nnunet2D", ckpt_path=os.path.join(base_folder, folder, ckpt), model_func=deterministic_mean, out_domain=False)
    # except Exception as e:
    #     print(e)
    #     continue

nnunet2D_ens0_cv4
173 35 42


42it [00:08,  5.13it/s]
100%|███████████████████████████████████████████████████████████████| 42/42 [00:38<00:00,  1.08it/s]


GETTING PER SAMPLE RESULTS
GENREATING UNCERTAINTY MAPS
deterministic
generating uncertainty maps


100%|██████████████████████████████████████████████████████████████| 42/42 [00:00<00:00, 101.31it/s]


PAVPU


100%|███████████████████████████████████████████████████████████████| 42/42 [00:02<00:00, 17.85it/s]


UEO


100%|███████████████████████████████████████████████████████████████| 42/42 [00:07<00:00,  5.32it/s]
100%|███████████████████████████████████████████████████████████████| 42/42 [00:09<00:00,  4.49it/s]


3D CC ANALYSIS


100%|███████████████████████████████████████████████████████████████| 42/42 [01:30<00:00,  2.16s/it]


SAVING RESULTS
DONE
nnunet2D_ens0_cv3
nnunet2D_ens0_cv0
173 35 42


42it [00:08,  5.07it/s]
100%|███████████████████████████████████████████████████████████████| 42/42 [00:39<00:00,  1.07it/s]


GETTING PER SAMPLE RESULTS
GENREATING UNCERTAINTY MAPS
deterministic
generating uncertainty maps


100%|██████████████████████████████████████████████████████████████| 42/42 [00:00<00:00, 112.63it/s]


PAVPU


100%|███████████████████████████████████████████████████████████████| 42/42 [00:02<00:00, 18.81it/s]


UEO


100%|███████████████████████████████████████████████████████████████| 42/42 [00:08<00:00,  5.13it/s]
100%|███████████████████████████████████████████████████████████████| 42/42 [00:09<00:00,  4.52it/s]


3D CC ANALYSIS


100%|███████████████████████████████████████████████████████████████| 42/42 [01:20<00:00,  1.92s/it]


SAVING RESULTS
DONE
nnunet2D_ens0_cv2
173 35 42


42it [00:08,  5.06it/s]
100%|███████████████████████████████████████████████████████████████| 42/42 [00:39<00:00,  1.07it/s]


GETTING PER SAMPLE RESULTS
GENREATING UNCERTAINTY MAPS
deterministic
generating uncertainty maps


100%|██████████████████████████████████████████████████████████████| 42/42 [00:00<00:00, 130.91it/s]


PAVPU


100%|███████████████████████████████████████████████████████████████| 42/42 [00:02<00:00, 18.82it/s]


UEO


100%|███████████████████████████████████████████████████████████████| 42/42 [00:08<00:00,  5.25it/s]
100%|███████████████████████████████████████████████████████████████| 42/42 [00:09<00:00,  4.52it/s]


3D CC ANALYSIS


100%|███████████████████████████████████████████████████████████████| 42/42 [01:12<00:00,  1.72s/it]


SAVING RESULTS
DONE
nnunet2D_ens0_cv5
175 35 40


40it [00:07,  5.07it/s]
100%|███████████████████████████████████████████████████████████████| 40/40 [00:37<00:00,  1.07it/s]


GETTING PER SAMPLE RESULTS
GENREATING UNCERTAINTY MAPS
deterministic
generating uncertainty maps


100%|██████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 110.36it/s]


PAVPU


100%|███████████████████████████████████████████████████████████████| 40/40 [00:02<00:00, 18.81it/s]


UEO


100%|███████████████████████████████████████████████████████████████| 40/40 [00:07<00:00,  5.13it/s]
100%|███████████████████████████████████████████████████████████████| 40/40 [00:08<00:00,  4.56it/s]


3D CC ANALYSIS


100%|███████████████████████████████████████████████████████████████| 40/40 [01:19<00:00,  1.99s/it]


SAVING RESULTS
DONE
nnunet2D_ens0_cv1
173 35 42


42it [00:08,  5.07it/s]
100%|███████████████████████████████████████████████████████████████| 42/42 [00:38<00:00,  1.08it/s]


GETTING PER SAMPLE RESULTS
GENREATING UNCERTAINTY MAPS
deterministic
generating uncertainty maps


100%|██████████████████████████████████████████████████████████████| 42/42 [00:00<00:00, 115.44it/s]


PAVPU


100%|███████████████████████████████████████████████████████████████| 42/42 [00:02<00:00, 18.82it/s]


UEO


100%|███████████████████████████████████████████████████████████████| 42/42 [00:08<00:00,  5.12it/s]
100%|███████████████████████████████████████████████████████████████| 42/42 [00:09<00:00,  4.56it/s]


3D CC ANALYSIS


100%|███████████████████████████████████████████████████████████████| 42/42 [01:23<00:00,  1.98s/it]


SAVING RESULTS
DONE


In [20]:
base_folder = "/home/s2208943/projects/twaibrain/twaibrain/brainexperiments/run_nnUNet_v2/training/model_ckpts/"
for folder in os.listdir(base_folder):
    if "3D" in folder:
        continue
    print(folder)
    cv_split = folder[-1]

    if cv_split != "3":
        continue
    
    # try:
    ckpt = sorted([f for f in os.listdir(os.path.join(base_folder, folder)) if f.endswith(".ckpt")])[-1]
    get_model_preds(cv_split, model=model, model_name="nnunet2D", ckpt_path=os.path.join(base_folder, folder, ckpt), model_func=deterministic_mean, out_domain=False)
    # except Exception as e:
    #     print(e)
    #     continue

nnunet2D_ens0_cv4
nnunet2D_ens0_cv3
173 35 42


42it [00:08,  4.90it/s]
100%|███████████████████████████████████████████████████████████████| 42/42 [00:39<00:00,  1.07it/s]


GETTING PER SAMPLE RESULTS
GENREATING UNCERTAINTY MAPS
deterministic
generating uncertainty maps


100%|██████████████████████████████████████████████████████████████| 42/42 [00:00<00:00, 105.20it/s]


PAVPU


100%|███████████████████████████████████████████████████████████████| 42/42 [00:02<00:00, 18.33it/s]


UEO


100%|███████████████████████████████████████████████████████████████| 42/42 [00:08<00:00,  5.17it/s]
100%|███████████████████████████████████████████████████████████████| 42/42 [00:09<00:00,  4.54it/s]


3D CC ANALYSIS


100%|███████████████████████████████████████████████████████████████| 42/42 [01:20<00:00,  1.91s/it]


SAVING RESULTS
DONE
nnunet2D_ens0_cv0
nnunet2D_ens0_cv2
nnunet2D_ens0_cv5
nnunet2D_ens0_cv1
