# Ensemble Techniques

## Environment Setup

### Imports

In [1]:
import os
os.sys.path.append("utils")

import torch
import pandas as pd
from monai.metrics import DiceMetric
from utils.Inference import ensemble_inference

### Config

In [2]:
# Global random seed. It was previously declared but never applied; seeding torch
# keeps results reproducible if any downstream step is stochastic.
seed = 33
torch.manual_seed(seed)

# Show every column and the full (untruncated) cell contents, e.g. the path lists, in DataFrame previews.
pd.set_option("display.max_columns", None)
pd.set_option("display.max_colwidth", None)

## Load Test Data

In [3]:
# Layout of the saved segmentations. Each model's predictions live under
# ./outputs/<Model>/pred_segs/test_pred_segs/ and the labels under
# ./outputs/gt_segs/test_gt_segs/, one NIfTI file per subject and channel.
OUTPUTS_DIR = './outputs'
MODEL_NAMES = ['AHNet', 'UNet', 'SegResNet', 'UNETR']
CHANNELS = ['TC', 'WT', 'ET']

test_gt_segs_paths = f'{OUTPUTS_DIR}/gt_segs/test_gt_segs'
subject_ids = pd.read_csv('./data/TEST.csv')['SubjectID'].values


def channel_paths(seg_dir, prefix, sid):
    """Return the per-channel segmentation file paths for one subject.

    Args:
        seg_dir: Directory that holds the segmentation files.
        prefix: File-name prefix ('pred' for model outputs, 'gt' for labels).
        sid: Subject identifier.

    Returns:
        A list of paths, one per entry of CHANNELS and in that order.
    """
    return [f'{seg_dir}/{prefix}_{sid}_{channel}.nii.gz' for channel in CHANNELS]


# One row per subject; every model column (and GT) holds that subject's
# [TC, WT, ET] file paths. Build all columns first, then the frame in one go.
segs_by_column = {'SubjectID': subject_ids}
for model in MODEL_NAMES:
    model_dir = f'{OUTPUTS_DIR}/{model}/pred_segs/test_pred_segs'
    segs_by_column[model] = [channel_paths(model_dir, 'pred', sid) for sid in subject_ids]
segs_by_column['GT'] = [channel_paths(test_gt_segs_paths, 'gt', sid) for sid in subject_ids]

data_df = pd.DataFrame(segs_by_column)

data_df.head()

Unnamed: 0,SubjectID,AHNet,UNet,SegResNet,UNETR,GT
0,100214B,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_TC.nii.gz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_WT.nii.gz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_ET.nii.gz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_TC.nii.gz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_WT.nii.gz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_ET.nii.gz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_TC.nii.gz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_WT.nii.gz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_ET.nii.gz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_TC.nii.gz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_WT.nii.gz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_ET.nii.gz]","[./outputs/gt_segs/test_gt_segs/gt_100214B_TC.nii.gz, ./outputs/gt_segs/test_gt_segs/gt_100214B_WT.nii.gz, ./outputs/gt_segs/test_gt_segs/gt_100214B_ET.nii.gz]"
1,100340A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_TC.nii.gz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_WT.nii.gz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_ET.nii.gz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_TC.nii.gz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_WT.nii.gz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_ET.nii.gz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_TC.nii.gz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_WT.nii.gz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_ET.nii.gz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_TC.nii.gz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_WT.nii.gz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_ET.nii.gz]","[./outputs/gt_segs/test_gt_segs/gt_100340A_TC.nii.gz, ./outputs/gt_segs/test_gt_segs/gt_100340A_WT.nii.gz, ./outputs/gt_segs/test_gt_segs/gt_100340A_ET.nii.gz]"
2,100391A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_TC.nii.gz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_WT.nii.gz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_ET.nii.gz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_TC.nii.gz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_WT.nii.gz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_ET.nii.gz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_TC.nii.gz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_WT.nii.gz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_ET.nii.gz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_TC.nii.gz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_WT.nii.gz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_ET.nii.gz]","[./outputs/gt_segs/test_gt_segs/gt_100391A_TC.nii.gz, ./outputs/gt_segs/test_gt_segs/gt_100391A_WT.nii.gz, ./outputs/gt_segs/test_gt_segs/gt_100391A_ET.nii.gz]"
3,100190B,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_TC.nii.gz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_WT.nii.gz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_ET.nii.gz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_TC.nii.gz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_WT.nii.gz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_ET.nii.gz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_TC.nii.gz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_WT.nii.gz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_ET.nii.gz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_TC.nii.gz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_WT.nii.gz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_ET.nii.gz]","[./outputs/gt_segs/test_gt_segs/gt_100190B_TC.nii.gz, ./outputs/gt_segs/test_gt_segs/gt_100190B_WT.nii.gz, ./outputs/gt_segs/test_gt_segs/gt_100190B_ET.nii.gz]"
4,100142A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_TC.nii.gz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_WT.nii.gz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_ET.nii.gz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_TC.nii.gz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_WT.nii.gz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_ET.nii.gz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_TC.nii.gz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_WT.nii.gz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_ET.nii.gz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_TC.nii.gz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_WT.nii.gz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_ET.nii.gz]","[./outputs/gt_segs/test_gt_segs/gt_100142A_TC.nii.gz, ./outputs/gt_segs/test_gt_segs/gt_100142A_WT.nii.gz, ./outputs/gt_segs/test_gt_segs/gt_100142A_ET.nii.gz]"


## Ensemble 1 - Average

In [4]:
def avg(predictions, _):
    """Uniform-average ensemble.

    Args:
        predictions: Sequence of same-shaped tensors, one per model.
        _: Unused; accepted so the signature matches the other combiners
            handed to ``ensemble_inference``.

    Returns:
        The element-wise mean of the predictions, shaped like each input.
    """
    return torch.stack(predictions).mean(dim=0)

#---------------------------------------------------------#

# Make sure the output directory exists so the cell also works on a fresh checkout.
os.makedirs('./outputs/Ensemble', exist_ok=True)

# Score the test set with the uniform-average ensemble, persist the scores, and summarize them.
avg_df = ensemble_inference(data_df, avg)
avg_df.to_csv('./outputs/Ensemble/avg_test_scores.csv', index=False)
avg_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.522651,0.560597,0.572364,0.434448,3.516129,3.258065,4.193548,8.709677,8.709677,5.129032,11607.709677,11683.129032,9211.83871,17103.677419,17984.16129,13810.870968
std,0.316486,0.321923,0.330971,0.306287,3.46286,3.316301,5.237171,9.103822,9.103822,4.951376,18957.730429,19169.474793,16349.038039,25564.191374,26911.703415,22937.67913
min,0.005555,0.009937,0.006728,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.218448,0.25224,0.253042,0.13048,1.0,1.0,1.0,2.0,2.0,1.5,278.5,271.0,103.0,802.0,802.0,410.0
50%,0.640889,0.67405,0.706515,0.492227,3.0,3.0,4.0,6.0,6.0,4.0,1928.0,2088.0,1306.0,3847.0,4138.0,1944.0
75%,0.790898,0.845353,0.866648,0.678609,5.0,4.5,5.0,13.5,13.5,7.0,16246.0,16055.5,11403.0,26305.0,27036.5,23520.5
max,0.923536,0.94244,0.943803,0.884366,16.0,16.0,27.0,43.0,43.0,21.0,65343.0,65754.0,61903.0,102770.0,105729.0,92221.0


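## Ensemble 2 - Weighted Average (Mean Dice Weights)

Each model's weight is its Dice score against the `label` passed to the combiner, normalized so the weights sum to 1. If that label is the subject's test ground truth (as the `GT` column suggests), the weights leak test labels into the prediction. These scores are then an optimistic (oracle) upper bound rather than a fair test-set estimate.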

In [5]:
def weigthed_avg(predictions, label):
    # Weights 
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    weights = []
    for pred in predictions:
        dice_metric(y_pred=pred, y=label)
        weights.append(dice_metric.aggregate().item())
        dice_metric.reset()
    weights = torch.tensor(weights)
    weights = weights / weights.sum()

    # Weighted average
    stacked_predictions = torch.stack(predictions, dim=0)
    t = torch.sum(stacked_predictions * weights[:, None, None, None, None, None], dim=0)
    return t

#---------------------------------------------------------#

# Make sure the output directory exists so the cell also works on a fresh checkout.
os.makedirs('./outputs/Ensemble', exist_ok=True)

# Score the test set with the Dice-weighted ensemble, persist the scores, and summarize them.
weigthed_avg_df = ensemble_inference(data_df, weigthed_avg)
weigthed_avg_df.to_csv('./outputs/Ensemble/weighted_avg_test_scores.csv', index=False)
weigthed_avg_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.533306,0.576017,0.584808,0.434448,1.709677,1.483871,1.612903,8.709677,8.709677,5.129032,4778.387097,4786.225806,3900.354839,17103.677419,17984.16129,13810.870968
std,0.300442,0.298409,0.311887,0.306287,2.465678,2.142755,2.894563,9.103822,9.103822,4.951376,13192.788883,13253.044273,11209.350649,25564.191374,26911.703415,22937.67913
min,0.08422,0.114951,0.101266,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.218448,0.25224,0.253042,0.13048,0.0,0.0,0.0,2.0,2.0,1.5,0.0,0.0,0.0,802.0,802.0,410.0
50%,0.640889,0.67405,0.706515,0.492227,1.0,1.0,0.0,6.0,6.0,4.0,5.0,2.0,0.0,3847.0,4138.0,1944.0
75%,0.790898,0.845353,0.866648,0.678609,3.0,3.0,2.0,13.5,13.5,7.0,602.0,608.5,256.5,26305.0,27036.5,23520.5
max,0.923536,0.94244,0.943803,0.884366,10.0,9.0,13.0,43.0,43.0,21.0,63969.0,64497.0,54662.0,102770.0,105729.0,92221.0


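## Ensemble 3 - Majority Voting

Each output element takes the most frequent value across the model predictions (`torch.mode` over the model axis). With an even number of models, ties are possible. They are resolved by `torch.mode`'s own tie-breaking.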

In [6]:
def majority_vote(predictions, _):
    """Element-wise majority-vote ensemble.

    Args:
        predictions: Sequence of same-shaped tensors, one per model.
        _: Unused; accepted so the signature matches the other combiners
            handed to ``ensemble_inference``.

    Returns:
        For each element, the most frequent value across the predictions, as
        chosen by ``torch.mode`` along the model axis.
    """
    # NOTE(review): with an even number of models (data_df has four model
    # columns), ties are possible. The winner is whatever torch.mode's tie-breaking
    # picks; confirm that this is acceptable.
    votes = torch.stack(predictions)
    return votes.mode(dim=0).values

# Make sure the output directory exists so the cell also works on a fresh checkout.
os.makedirs('./outputs/Ensemble', exist_ok=True)

# Score the test set with the majority-vote ensemble, persist the scores, and summarize them.
majority_vote_df = ensemble_inference(data_df, majority_vote)
majority_vote_df.to_csv('./outputs/Ensemble/majority_vote_test_scores.csv', index=False)
majority_vote_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.580455,0.61474,0.6171,0.50483,6.16129,5.419355,7.903226,8.709677,8.709677,5.129032,14172.387097,14257.225806,11632.774194,17103.677419,17984.16129,13810.870968
std,0.292811,0.283028,0.294667,0.327607,6.361849,5.451447,8.553186,9.103822,9.103822,4.951376,21705.019766,21917.715107,19121.380415,25564.191374,26911.703415,22937.67913
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.385514,0.445795,0.438938,0.168871,2.0,2.0,2.0,2.0,2.0,1.5,666.5,653.5,413.5,802.0,802.0,410.0
50%,0.657048,0.68589,0.706022,0.599136,4.0,4.0,5.0,6.0,6.0,4.0,3423.0,3427.0,2357.0,3847.0,4138.0,1944.0
75%,0.829055,0.860853,0.871061,0.793708,7.0,7.0,11.0,13.5,13.5,7.0,23699.5,23822.0,20584.5,26305.0,27036.5,23520.5
max,0.950684,0.955906,0.958353,0.937794,24.0,22.0,32.0,43.0,43.0,21.0,80558.0,80731.0,71546.0,102770.0,105729.0,92221.0
