# Ensemble Techniques

## Environment Setup

### Imports

In [1]:
import os
os.sys.path.append("utils")

import torch
import numpy as np
import pandas as pd
from functools import reduce
from monai.transforms import AsDiscrete
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
# Imports
import torch
import random
import numpy as np
import pandas as pd
import nibabel as nib
from tqdm import tqdm
from Transforms import Transforms
from monai.metrics import DiceMetric
from monai.utils import set_determinism
from monai.inferers import sliding_window_inference
from skimage.measure import regionprops, label
from monai.transforms import Activations, AsDiscrete

### Config

In [2]:
seed = 33
pd.set_option("display.max_columns", None)
pd.set_option('display.max_colwidth', None)

channels = ['TC', 'WT', 'ET']

In [3]:
# Ensemble Inference
def ensemble_inference(dataframe, ensemble_function, threshold = 0.5, include_label = False, model = None, store_npz = False, model_name = None):

    # Transforms
    trans = AsDiscrete(threshold=threshold)

    # Dice Params
    dice_values, dice_values_tc, dice_values_wt, dice_values_et = [], [], [], []
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

    # Biometrics Params
    ids = []
    pred_paths = []
    gt_nm, pred_nm = {'TC': [], 'WT': [], 'ET': []}, {'TC': [], 'WT': [], 'ET': []}
    gt_v, pred_v = {'TC': [], 'WT': [], 'ET': []}, {'TC': [], 'WT': [], 'ET': []}

    # Iterate over the dataframe
    for i in range(len(dataframe)):

        # Subject & Label
        subject_id = dataframe['SubjectID'][i]
        load_label = dataframe['GT'][i]
        
        # Images
        load_ahnet = dataframe['AHNet'][i]
        load_segresnet = dataframe['SegResNet'][i]
        load_unet = dataframe['UNet'][i]
        load_unetr = dataframe['UNETR'][i]
        
        # Params 
        image_voxel_volume = np.prod((1,1,1))
        label_voxel_volume = np.prod((1,1,1))

        # Load Images and Labels
        ahnet_image = [np.load(x)['arr_0'] for x in load_ahnet]
        segresnet_image = [np.load(x)['arr_0'] for x in load_segresnet]
        unetr_image = [np.load(x)['arr_0'] for x in load_unetr]
        img_label = [np.load(x)['arr_0'] for x in load_label] 

        # To Tensor
        ahnet_image = [torch.from_numpy(x) for x in ahnet_image]
        segresnet_image = [torch.from_numpy(x) for x in segresnet_image]
        unetr_image = [torch.from_numpy(x) for x in unetr_image]
        img_label = [torch.from_numpy(x) for x in img_label]
            
        # Stack Images and Label
        ahnet_image = torch.stack(ahnet_image, dim = 0).unsqueeze(0)
        segresnet_image = torch.stack(segresnet_image, dim = 0).unsqueeze(0)
        unetr_image = torch.stack(unetr_image, dim = 0).unsqueeze(0)
        img_label = torch.stack(img_label, dim = 0).unsqueeze(0)

        # Ensemble Function
        img = None
        if include_label:
            img = ensemble_function([ahnet_image, segresnet_image, unetr_image], img_label)
        elif model is not None:
            img = ensemble_function([ahnet_image, segresnet_image, unetr_image], model)
        else:
            img = ensemble_function([ahnet_image, segresnet_image, unetr_image])

        # Save NPZ
        if store_npz:
            np.savez_compressed(f'./outputs/Ensemble/pred_{model_name}/pred_{subject_id}_TC.npz', img[0][0])
            np.savez_compressed(f'./outputs/Ensemble/pred_{model_name}/pred_{subject_id}_WT.npz', img[0][1])
            np.savez_compressed(f'./outputs/Ensemble/pred_{model_name}/pred_{subject_id}_ET.npz', img[0][2])
            pred_paths.append([f'./outputs/Ensemble/pred_{model_name}/pred_{subject_id}_TC.npz', 
                               f'./outputs/Ensemble/pred_{model_name}/pred_{subject_id}_WT.npz',
                               f'./outputs/Ensemble/pred_{model_name}/pred_{subject_id}_ET.npz'])

        # Discretizise
        img = trans(img)
        img_label = trans(img_label)

        # Dice Metric
        dice_metric(y_pred=img, y=img_label)
        dice_score = dice_metric.aggregate()
        dice_values.append(dice_score.item())
        dice_metric.reset()
            
		# Batch Dice
        dice_metric_batch(y_pred=img, y=img_label)
        dice_batch = dice_metric_batch.aggregate()
        dice_values_tc.append(dice_batch[0].item())
        dice_values_wt.append(dice_batch[1].item())
        dice_values_et.append(dice_batch[2].item())
        dice_metric_batch.reset()     

        # Biometrics
        for j, channel in enumerate(channels):
            # Image
            props = regionprops(label(nib.Nifti1Image(img[0][j].cpu().numpy(), np.eye(4)).get_fdata()))
            volumes = [prop.area * image_voxel_volume for prop in props]
            pred_nm[channel].append(int(len(volumes)))
            pred_v[channel].append(int(np.sum(volumes)))
            # Label
            props = regionprops(label(nib.Nifti1Image(img_label[0][j].cpu().numpy(), np.eye(4)).get_fdata()))
            volumes = [prop.area * label_voxel_volume for prop in props]
            gt_nm[channel].append(int(len(volumes)))
            gt_v[channel].append(int(np.sum(volumes)))

        # Subject ID
        ids.append(subject_id)
                
    # Excel
    df = pd.DataFrame({
        'SubjectID': ids,
		'Dice': dice_values,
		'Dice TC': dice_values_tc,
		'Dice WT': dice_values_wt,
		'Dice ET': dice_values_et,
		'Pred NM TC': pred_nm['TC'],
		'Pred NM WT': pred_nm['WT'],
		'Pred NM ET': pred_nm['ET'],
		'GT NM TC': gt_nm['TC'],
		'GT NM WT': gt_nm['WT'],
		'GT NM ET': gt_nm['ET'],
		'Pred V TC': pred_v['TC'],
		'Pred V WT': pred_v['WT'],
		'Pred V ET': pred_v['ET'],
		'GT V TC': gt_v['TC'],
		'GT V WT': gt_v['WT'],
		'GT V ET': gt_v['ET'],
	})
    
    if store_npz:
        df['Pred Paths'] = pred_paths
    
    return df

## Load Test Data

In [4]:
subject_ids = pd.read_csv('./data/TEST.csv')['SubjectID'].values

ah_segs, unet_segs, segresnet_segs, untr_segs, gt_segs = [], [], [], [], []
for sid in subject_ids:
    ah_channels, unet_channels, segresnet_channels, untr_channels, gt_channels = [], [], [], [], []
    for channel in ['TC', 'WT', 'ET']:
        ah_channels.append(f'./outputs/AHNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz')
        unet_channels.append(f'./outputs/UNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz')
        segresnet_channels.append(f'./outputs/SegResNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz')
        untr_channels.append(f'./outputs/UNETR/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz')
        gt_channels.append(f'./outputs/gt_segs/test_gt_segs/gt_{sid}_{channel}.npz')
    
    ah_segs.append(ah_channels)
    unet_segs.append(unet_channels)
    segresnet_segs.append(segresnet_channels)
    untr_segs.append(untr_channels)
    gt_segs.append(gt_channels)

# Dataframe
test_df = pd.DataFrame()
test_df['SubjectID'] = subject_ids
test_df['AHNet'] = ah_segs
test_df['UNet'] = unet_segs
test_df['SegResNet'] = segresnet_segs
test_df['UNETR'] = untr_segs
test_df['GT'] = gt_segs

test_df.head()

Unnamed: 0,SubjectID,AHNet,UNet,SegResNet,UNETR,GT
0,100214B,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100214B_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100214B_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100214B_ET.npz]"
1,100340A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100340A_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100340A_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100340A_ET.npz]"
2,100391A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100391A_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100391A_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100391A_ET.npz]"
3,100190B,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100190B_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100190B_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100190B_ET.npz]"
4,100142A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100142A_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100142A_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100142A_ET.npz]"


## Ensemble 1 - Average

In [5]:
def avg(predictions):
    stacked_predictions = torch.stack(predictions, dim=0)
    return torch.mean(stacked_predictions, dim=0)

#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    avg_df = ensemble_inference(test_df, avg, threshold)
    mean = avg_df['Dice'].mean()
    print(f'Average Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = avg_df
print(f'BEST THRESHOLD: {best_threshold}')
df.describe()

Average Mean Dice for threshold 0.1: 0.5863614200103667
Average Mean Dice for threshold 0.2: 0.6211351408593117
Average Mean Dice for threshold 0.3: 0.6383202962967898
Average Mean Dice for threshold 0.4: 0.6428995512185558
Average Mean Dice for threshold 0.5: 0.6311559115927066
Average Mean Dice for threshold 0.6: 0.613352604450718
Average Mean Dice for threshold 0.7: 0.5494273100649157
Average Mean Dice for threshold 0.8: 0.5011968788120055
Average Mean Dice for threshold 0.9: 0.44571167771374026
BEST THRESHOLD: 0.4


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.6429,0.675335,0.684806,0.547999,7.870968,7.032258,16.096774,8.709677,8.709677,5.129032,16631.483871,16666.580645,13891.870968,17105.387097,17985.870968,13810.806452
std,0.252928,0.241508,0.246799,0.303004,7.830462,6.374344,21.357832,9.103822,9.103822,4.951376,25129.788233,25262.45967,22386.06211,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.523253,0.570072,0.575907,0.320578,3.0,3.0,3.0,2.0,2.0,1.5,982.5,960.0,551.5,802.0,802.0,410.0
50%,0.637359,0.703949,0.753131,0.577692,5.0,5.0,8.0,6.0,6.0,4.0,4718.0,4724.0,2852.0,3847.0,4138.0,1944.0
75%,0.855782,0.873027,0.898497,0.793281,9.5,9.5,21.5,13.5,13.5,7.0,30380.0,30354.0,28023.0,26305.0,27036.5,23520.5
max,0.949085,0.955557,0.958394,0.935284,34.0,26.0,99.0,43.0,43.0,21.0,95617.0,95689.0,84916.0,102770.0,105729.0,92221.0


## Ensemble 2 - Weighted Average (Mean Dice Weights)

In [7]:
def weigthed_avg(predictions):

    # Weights form validation set AHNet, SegResNet, UNet, UNETR
    weights = np.array([0.546402, 0.648790, 0.547957])
    weights = weights / np.sum(weights)
    weights = torch.tensor(weights).float()
    # Weighted average
    stacked_predictions = torch.stack(predictions, dim=0)
    return torch.sum(stacked_predictions * weights[:, None, None, None, None, None], dim=0)
     
#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    weigthed_avg_df = ensemble_inference(test_df, weigthed_avg, threshold)
    mean = weigthed_avg_df['Dice'].mean()
    print(f'Weighted Average Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = weigthed_avg_df
print(f'BEST THRESHOLD: {best_threshold}')
df.describe()

Weighted Average Mean Dice for threshold 0.1: 0.5867425597483112
Weighted Average Mean Dice for threshold 0.2: 0.6235906039514849
Weighted Average Mean Dice for threshold 0.3: 0.6423478781455948
Weighted Average Mean Dice for threshold 0.4: 0.6453976748931792
Weighted Average Mean Dice for threshold 0.5: 0.6335489165879065
Weighted Average Mean Dice for threshold 0.6: 0.6176399252107067
Weighted Average Mean Dice for threshold 0.7: 0.5546291386408191
Weighted Average Mean Dice for threshold 0.8: 0.5038006154279555
Weighted Average Mean Dice for threshold 0.9: 0.4465532794354423
BEST THRESHOLD: 0.4


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.645398,0.676805,0.687453,0.552012,7.193548,6.967742,14.0,8.709677,8.709677,5.129032,16555.032258,16586.225806,13788.322581,17105.387097,17985.870968,13810.806452
std,0.251027,0.24021,0.243862,0.301197,6.935029,6.635681,17.189144,9.103822,9.103822,4.951376,24977.822062,25104.340697,22205.619714,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.534398,0.5762,0.580152,0.33164,3.5,3.0,3.0,2.0,2.0,1.5,995.5,970.0,566.0,802.0,802.0,410.0
50%,0.631351,0.713736,0.762651,0.576646,5.0,5.0,7.0,6.0,6.0,4.0,4724.0,4729.0,2863.0,3847.0,4138.0,1944.0
75%,0.856444,0.873076,0.898498,0.794745,8.5,8.5,19.0,13.5,13.5,7.0,29923.5,29944.5,27641.0,26305.0,27036.5,23520.5
max,0.949815,0.956011,0.959204,0.935853,30.0,27.0,71.0,43.0,43.0,21.0,95005.0,95181.0,84273.0,102770.0,105729.0,92221.0


## Ensemble 3 - Median 

In [8]:
def median(predictions):
    stacked_predictions = torch.stack(predictions, dim=0)
    return torch.median(stacked_predictions, dim=0).values

#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    median_df = ensemble_inference(test_df, median, threshold)
    mean = median_df['Dice'].mean()
    print(f'Median Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = median_df
print(f'BEST THRESHOLD: {best_threshold}')
df.describe()

Median Mean Dice for threshold 0.1: 0.6239719347607705
Median Mean Dice for threshold 0.2: 0.6403637986269689
Median Mean Dice for threshold 0.3: 0.6415600694055038
Median Mean Dice for threshold 0.4: 0.6378753589045617
Median Mean Dice for threshold 0.5: 0.630699080324942
Median Mean Dice for threshold 0.6: 0.6228696912406914
Median Mean Dice for threshold 0.7: 0.6113415539204594
Median Mean Dice for threshold 0.8: 0.5896395473710953
Median Mean Dice for threshold 0.9: 0.5505636014765308
BEST THRESHOLD: 0.3


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.64156,0.674658,0.684467,0.54681,8.806452,8.225806,9.387097,8.709677,8.709677,5.129032,16842.516129,17002.580645,14252.354839,17105.387097,17985.870968,13810.806452
std,0.248368,0.236819,0.240307,0.300191,7.717596,7.419388,11.062481,9.103822,9.103822,4.951376,25211.120421,25523.302593,22620.338032,25565.209783,26912.616906,22937.645993
min,0.00381,0.0,0.011429,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,44.0,44.0,0.0
25%,0.524193,0.56468,0.565742,0.332082,3.5,3.0,2.0,2.0,2.0,1.5,1077.5,1073.5,637.0,802.0,802.0,410.0
50%,0.624035,0.710088,0.759918,0.548019,6.0,5.0,5.0,6.0,6.0,4.0,4886.0,4910.0,3168.0,3847.0,4138.0,1944.0
75%,0.853133,0.87441,0.892536,0.788852,10.0,11.5,11.5,13.5,13.5,7.0,30742.0,30948.5,28709.0,26305.0,27036.5,23520.5
max,0.9448,0.952258,0.954205,0.931873,30.0,29.0,45.0,43.0,43.0,21.0,96593.0,97200.0,86239.0,102770.0,105729.0,92221.0


## Ensemble 4 - Maximum Confidence 

In [9]:
def max_confidence(predictions):
    stacked_predictions = torch.stack(predictions, dim=0)
    return torch.max(stacked_predictions, dim=0).values

#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    max_confidence_df = ensemble_inference(test_df, max_confidence, threshold)
    mean = max_confidence_df['Dice'].mean()
    print(f'Max Confidence Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = max_confidence_df
print(f'BEST THRESHOLD: {best_threshold}')
df.describe()

Max Confidence Mean Dice for threshold 0.1: 0.5307200302279764
Max Confidence Mean Dice for threshold 0.2: 0.5740139740369012
Max Confidence Mean Dice for threshold 0.3: 0.5972267687320709
Max Confidence Mean Dice for threshold 0.4: 0.6107847342087377
Max Confidence Mean Dice for threshold 0.5: 0.6194333650412098
Max Confidence Mean Dice for threshold 0.6: 0.6250772557912334
Max Confidence Mean Dice for threshold 0.7: 0.6274599907619338
Max Confidence Mean Dice for threshold 0.8: 0.6256679824343131
Max Confidence Mean Dice for threshold 0.9: 0.6132325974082754
BEST THRESHOLD: 0.7


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.62746,0.664593,0.670343,0.537957,14.225806,12.354839,155.516129,8.709677,8.709677,5.129032,17815.870968,17703.354839,15239.096774,17105.387097,17985.870968,13810.806452
std,0.253009,0.237301,0.254247,0.294088,9.96564,9.11244,253.996177,9.103822,9.103822,4.951376,25973.511098,25893.051202,23651.232029,25565.209783,26912.616906,22937.645993
min,0.037298,0.082192,0.029703,0.0,1.0,2.0,0.0,1.0,1.0,0.0,45.0,28.0,0.0,44.0,44.0,0.0
25%,0.4438,0.528188,0.501666,0.300136,8.5,7.0,14.5,2.0,2.0,1.5,1056.0,1017.5,593.5,802.0,802.0,410.0
50%,0.668926,0.747748,0.738101,0.56473,11.0,9.0,49.0,6.0,6.0,4.0,4778.0,4772.0,3160.0,3847.0,4138.0,1944.0
75%,0.844789,0.860551,0.896599,0.772265,15.0,16.0,145.5,13.5,13.5,7.0,31523.0,31307.0,29318.0,26305.0,27036.5,23520.5
max,0.943834,0.953132,0.955554,0.922817,41.0,37.0,1118.0,43.0,43.0,21.0,98422.0,97834.0,89299.0,102770.0,105729.0,92221.0


## Ensemble 5 - Majority Voting

In [10]:
def majority_voting(predictions):
    transforms =[AsDiscrete(threshold=0.3), AsDiscrete(threshold=0.5), AsDiscrete(threshold=0.2), AsDiscrete(threshold=0.4)]
    predictions = [transform(prediction) for transform, prediction in zip(transforms, predictions)]
    stacked_predictions = torch.stack(predictions, dim=0)  
    return torch.mode(stacked_predictions, dim=0).values

#---------------------------------------------------------#

majority_vote_df = ensemble_inference(test_df, majority_voting, threshold)
majority_vote_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.641846,0.673063,0.684523,0.548373,8.774194,7.967742,9.677419,8.709677,8.709677,5.129032,16753.774194,16843.483871,14103.16129,17105.387097,17985.870968,13810.806452
std,0.246541,0.235789,0.237658,0.297688,8.321497,7.190196,10.734328,9.103822,9.103822,4.951376,25181.015711,25374.714886,22529.979089,25565.209783,26912.616906,22937.645993
min,0.018875,0.022727,0.033898,0.0,1.0,1.0,0.0,1.0,1.0,0.0,2.0,3.0,0.0,44.0,44.0,0.0
25%,0.530874,0.575143,0.587454,0.342862,3.0,3.0,3.0,2.0,2.0,1.5,1012.5,984.5,580.0,802.0,802.0,410.0
50%,0.639261,0.718425,0.767663,0.560427,5.0,5.0,6.0,6.0,6.0,4.0,4725.0,4746.0,3012.0,3847.0,4138.0,1944.0
75%,0.854089,0.873574,0.891474,0.79225,11.0,10.0,11.5,13.5,13.5,7.0,30727.0,30834.0,28597.0,26305.0,27036.5,23520.5
max,0.941908,0.948914,0.950959,0.928411,33.0,30.0,45.0,43.0,43.0,21.0,96470.0,96708.0,86145.0,102770.0,105729.0,92221.0


## Ensemble 6 - Logistic Regression dimension wise

In [12]:
def logreg_ensemble(predictions, targets):
    stacked_predictions = torch.stack(predictions, dim=0)
    num_models, batch, channels, depth, height, width = stacked_predictions.shape

    # Flatten
    flattened_predictions = stacked_predictions.reshape(num_models, -1).T 
    flattened_targets = targets.reshape(-1)
    
    # Logistic Regression
    meta_model = LogisticRegression(max_iter=1000)
    meta_model.fit(flattened_predictions, flattened_targets)
    ensemble_predictions = meta_model.predict_proba(flattened_predictions)
    
    # Reshape to original shape
    ensemble_predictions_reshaped = ensemble_predictions[:, 1].reshape(batch, channels, depth, height, width)
    return ensemble_predictions_reshaped

#---------------------------------------------------------#

log_df = ensemble_inference(test_df, logreg_ensemble, 0.1, include_label=True)
log_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.692221,0.727702,0.739015,0.584764,8.322581,6.806452,72.0,8.709677,8.709677,5.129032,17272.129032,17337.903226,14476.16129,17105.387097,17985.870968,13810.806452
std,0.22474,0.202622,0.209009,0.294194,7.161876,5.688698,153.011546,9.103822,9.103822,4.951376,26235.492829,26392.848424,23900.673742,25565.209783,26912.616906,22937.645993
min,0.055067,0.098214,0.066986,0.0,1.0,1.0,0.0,1.0,1.0,0.0,50.0,35.0,0.0,44.0,44.0,0.0
25%,0.564896,0.617037,0.644426,0.43615,3.5,3.0,2.0,2.0,2.0,1.5,875.0,849.0,408.0,802.0,802.0,410.0
50%,0.771234,0.790497,0.798819,0.652061,6.0,5.0,6.0,6.0,6.0,4.0,3766.0,3653.0,2660.0,3847.0,4138.0,1944.0
75%,0.868173,0.884655,0.899783,0.813635,10.0,9.0,37.5,13.5,13.5,7.0,28947.0,28905.0,26589.0,26305.0,27036.5,23520.5
max,0.953243,0.958147,0.960415,0.941167,28.0,23.0,593.0,43.0,43.0,21.0,102397.0,102515.0,93874.0,102770.0,105729.0,92221.0


## Ensemble 7 - Logistic Regression voxel wise

In [13]:
def logreg_ensemble(predictions, targets):
    batch, num_channels, depth, height, width = predictions[0].shape

    # Flatten
    flattened_predictions = [pred.view(3, -1).t() for pred in predictions]
    flattened_predictions = torch.cat(flattened_predictions, dim=1)
    flattened_targets = targets.view(3, -1).t()

    # Logistic Regression
    meta_model = OneVsRestClassifier(LogisticRegression(max_iter=1000))
    meta_model.fit(flattened_predictions, flattened_targets)
    ensemble_predictions = meta_model.predict(flattened_predictions)
    ensemble_predictions = ensemble_predictions.T.reshape(batch, num_channels, depth, height, width)
    return ensemble_predictions

#---------------------------------------------------------#
log_df = ensemble_inference(test_df, logreg_ensemble, threshold=0.1, include_label=True)
log_df.describe()

Label not 2 is present in all training examples.


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.689662,0.725067,0.739073,0.578591,8.0,6.967742,11.870968,8.709677,8.709677,5.129032,15525.774194,16242.806452,12265.0,17105.387097,17985.870968,13810.806452
std,0.23698,0.218273,0.21969,0.307505,10.708252,6.295416,23.227486,9.103822,9.103822,4.951376,23812.824179,24797.35464,21187.159245,25565.209783,26912.616906,22937.645993
min,0.006981,0.010471,0.010471,0.0,1.0,1.0,0.0,1.0,1.0,0.0,17.0,17.0,0.0,44.0,44.0,0.0
25%,0.552065,0.628436,0.656629,0.417846,2.0,2.0,2.0,2.0,2.0,1.5,604.0,604.0,195.5,802.0,802.0,410.0
50%,0.774989,0.792078,0.813953,0.658773,5.0,5.0,6.0,6.0,6.0,4.0,3247.0,3452.0,1886.0,3847.0,4138.0,1944.0
75%,0.869333,0.886609,0.901684,0.831295,8.5,8.5,10.5,13.5,13.5,7.0,25896.0,25992.5,22288.0,26305.0,27036.5,23520.5
max,0.954906,0.959421,0.961846,0.943452,54.0,24.0,124.0,43.0,43.0,21.0,94180.0,96400.0,83050.0,102770.0,105729.0,92221.0


## Ensemble 8 - Dempster-Shafer Theory

In [None]:
def dempster_rule(mass1, mass2):

    K = 1 / (1 - mass1[..., 0] * mass2[..., 1] - mass1[..., 1] * mass2[..., 0])

    combined_mass = torch.zeros_like(mass1)
    combined_mass[..., 0] = K * (mass1[..., 0] * mass2[..., 0] + mass1[..., 0] * mass2[..., 2] + mass1[..., 2] * mass2[..., 0])
    combined_mass[..., 1] = K * (mass1[..., 1] * mass2[..., 1] + mass1[..., 1] * mass2[..., 2] + mass1[..., 2] * mass2[..., 1])
    combined_mass[..., 2] = K * (mass1[..., 0] * mass2[..., 1] + mass1[..., 1] * mass2[..., 0] + mass1[..., 2] * mass2[..., 2])
    
    return combined_mass

def dempster_shafer_inference(predictions):

    # Params
    predictions = [pred.squeeze(0) for pred in predictions]

    # Mass Beliefs for {tumor, non-tumor} - (3, 3, x, y, z) - (masses, channels, x, y, z)
    masses = []	
    for pred in predictions:
        epsilon = 1 - torch.max(pred) 
        tumor_mass = pred
        non_tumor_mass = 1 - pred
        uncertainty_mass = torch.full_like(pred, epsilon)
        masses.append(torch.stack([tumor_mass, non_tumor_mass, uncertainty_mass], dim=-1))

    combined_mass = reduce(dempster_rule, masses)

    # Normalized Mass
    combined_mass_sum = combined_mass.sum(dim=-1, keepdim=True)
    normalized_mass = combined_mass / combined_mass_sum
    
    return normalized_mass[..., 0].unsqueeze(0)   

#---------------------------------------------------------#

ds_df = ensemble_inference(test_df, dempster_shafer_inference, threshold=0.1)
ds_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.601745,0.633941,0.644781,0.505325,8.516129,8.419355,9.290323,8.709677,8.709677,5.129032,16802.903226,16908.709677,14467.83871,17105.387097,17985.870968,13810.806452
std,0.264539,0.254862,0.260568,0.308289,7.719546,7.873052,10.687044,9.103822,9.103822,4.951376,25970.063724,26162.84428,23447.012299,25565.209783,26912.616906,22937.645993
min,0.011385,0.022727,0.011429,0.0,2.0,1.0,0.0,1.0,1.0,0.0,2.0,1.0,0.0,44.0,44.0,0.0
25%,0.456363,0.526884,0.549712,0.230807,3.5,4.0,2.5,2.0,2.0,1.5,850.0,850.0,490.0,802.0,802.0,410.0
50%,0.585188,0.64,0.629866,0.503819,6.0,6.0,6.0,6.0,6.0,4.0,4183.0,4317.0,3003.0,3847.0,4138.0,1944.0
75%,0.84931,0.874172,0.884706,0.78592,11.0,10.5,10.0,13.5,13.5,7.0,31359.5,31485.0,29419.0,26305.0,27036.5,23520.5
max,0.937775,0.944772,0.94706,0.921493,31.0,32.0,47.0,43.0,43.0,21.0,98062.0,98116.0,88349.0,102770.0,105729.0,92221.0
