# Ensemble Techniques

## Environment Setup

### Imports

In [1]:
import os
os.sys.path.append("utils")

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from monai.transforms import AsDiscrete
from utils.Inference import ensemble_inference
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression

### Config

In [2]:
seed = 33
pd.set_option("display.max_columns", None)
pd.set_option('display.max_colwidth', None)

## Load Test Data

In [3]:
subject_ids = pd.read_csv('./data/TEST.csv')['SubjectID'].values

ah_segs, unet_segs, segresnet_segs, untr_segs, gt_segs = [], [], [], [], []
for sid in subject_ids:
    ah_channels, unet_channels, segresnet_channels, untr_channels, gt_channels = [], [], [], [], []
    for channel in ['TC', 'WT', 'ET']:
        ah_channels.append(f'./outputs/AHNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz')
        unet_channels.append(f'./outputs/UNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz')
        segresnet_channels.append(f'./outputs/SegResNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz')
        untr_channels.append(f'./outputs/UNETR/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz')
        gt_channels.append(f'./outputs/gt_segs/test_gt_segs/gt_{sid}_{channel}.npz')
    
    ah_segs.append(ah_channels)
    unet_segs.append(unet_channels)
    segresnet_segs.append(segresnet_channels)
    untr_segs.append(untr_channels)
    gt_segs.append(gt_channels)

# Dataframe
test_df = pd.DataFrame()
test_df['SubjectID'] = subject_ids
test_df['AHNet'] = ah_segs
test_df['UNet'] = unet_segs
test_df['SegResNet'] = segresnet_segs
test_df['UNETR'] = untr_segs
test_df['GT'] = gt_segs

test_df.head()

Unnamed: 0,SubjectID,AHNet,UNet,SegResNet,UNETR,GT
0,100214B,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100214B_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100214B_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100214B_ET.npz]"
1,100340A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100340A_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100340A_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100340A_ET.npz]"
2,100391A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100391A_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100391A_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100391A_ET.npz]"
3,100190B,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100190B_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100190B_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100190B_ET.npz]"
4,100142A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100142A_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100142A_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100142A_ET.npz]"


## Ensemble 1 - Average

In [4]:
def avg(predictions):
    stacked_predictions = torch.stack(predictions, dim=0)
    return torch.mean(stacked_predictions, dim=0)

#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    avg_df = ensemble_inference(test_df, avg, threshold)
    mean = avg_df['Dice'].mean()
    print(f'Average Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = avg_df
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('./outputs/Ensemble/avg_test_scores.csv', index=False)
df.describe()

Average Mean Dice for threshold 0.1: 0.5150403138250113
Average Mean Dice for threshold 0.2: 0.5272680526043498
Average Mean Dice for threshold 0.3: 0.6201561954953978
Average Mean Dice for threshold 0.4: 0.6306360651919197
Average Mean Dice for threshold 0.5: 0.6117088165014021
Average Mean Dice for threshold 0.6: 0.582749746499523
Average Mean Dice for threshold 0.7: 0.5514685651587863
Average Mean Dice for threshold 0.8: 0.4944665648283497
Average Mean Dice for threshold 0.9: 0.44803049511486487
BEST THRESHOLD: 0.4


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.630636,0.662461,0.673749,0.533071,6.677419,6.516129,87.129032,8.709677,8.709677,5.129032,16592.322581,16591.774194,14575.709677,17105.387097,17985.870968,13810.806452
std,0.259768,0.25531,0.260801,0.293511,5.906985,5.977017,166.985177,9.103822,9.103822,4.951376,24168.923954,24207.681579,22133.776319,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,7.0,6.0,0.0,44.0,44.0,0.0
25%,0.484863,0.554968,0.560128,0.282297,3.0,2.5,5.5,2.0,2.0,1.5,1057.0,1048.0,662.5,802.0,802.0,410.0
50%,0.67684,0.717632,0.698113,0.542388,5.0,5.0,14.0,6.0,6.0,4.0,4993.0,4986.0,3298.0,3847.0,4138.0,1944.0
75%,0.842653,0.869613,0.885937,0.769723,8.0,8.0,65.0,13.5,13.5,7.0,29342.5,29308.0,27872.0,26305.0,27036.5,23520.5
max,0.946255,0.955846,0.958104,0.924816,25.0,25.0,684.0,43.0,43.0,21.0,90746.0,90706.0,83437.0,102770.0,105729.0,92221.0


## Ensemble 2 - Weighted Average (Mean Dice Weights)

In [5]:
def weigthed_avg(predictions):

    # Weights form validation set AHNet, SegResNet, UNet, UNETR
    weights = np.array([0.546402, 0.648790, 0.428639, 0.547957])
    weights = weights / np.sum(weights)
    weights = torch.tensor(weights).float()
    # Weighted average
    stacked_predictions = torch.stack(predictions, dim=0)
    return torch.sum(stacked_predictions * weights[:, None, None, None, None, None], dim=0)
     
#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    weigthed_avg_df = ensemble_inference(test_df, weigthed_avg, threshold)
    mean = weigthed_avg_df['Dice'].mean()
    print(f'Weighted Average Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = weigthed_avg_df
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('./outputs/Ensemble/weighted_avg_test_scores.csv', index=False)
df.describe()

Weighted Average Mean Dice for threshold 0.1: 0.5145711503442256
Weighted Average Mean Dice for threshold 0.2: 0.576161464856517
Weighted Average Mean Dice for threshold 0.3: 0.632053408891924
Weighted Average Mean Dice for threshold 0.4: 0.6359004551965383
Weighted Average Mean Dice for threshold 0.5: 0.6326853892495555
Weighted Average Mean Dice for threshold 0.6: 0.5931650337673002
Weighted Average Mean Dice for threshold 0.7: 0.5561348192634121
Weighted Average Mean Dice for threshold 0.8: 0.49360962188051594
Weighted Average Mean Dice for threshold 0.9: 0.4448244151088499
BEST THRESHOLD: 0.4


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.6359,0.667215,0.679828,0.53887,6.903226,6.419355,109.774194,8.709677,8.709677,5.129032,16629.322581,16633.903226,14432.903226,17105.387097,17985.870968,13810.806452
std,0.257415,0.252058,0.256418,0.295345,5.985287,5.863868,204.315884,9.103822,9.103822,4.951376,24313.317312,24377.260244,22128.085428,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,6.0,5.0,0.0,44.0,44.0,0.0
25%,0.498647,0.566927,0.570827,0.290649,2.5,2.0,6.0,2.0,2.0,1.5,1042.0,1032.5,635.0,802.0,802.0,410.0
50%,0.675439,0.723906,0.727524,0.549588,5.0,5.0,21.0,6.0,6.0,4.0,4940.0,4945.0,3194.0,3847.0,4138.0,1944.0
75%,0.846335,0.871845,0.890754,0.776012,8.5,8.0,103.5,13.5,13.5,7.0,29476.5,29444.5,27817.0,26305.0,27036.5,23520.5
max,0.947493,0.956289,0.958427,0.927762,25.0,25.0,922.0,43.0,43.0,21.0,91540.0,91606.0,83508.0,102770.0,105729.0,92221.0


## Ensemble 3 - Median 

In [6]:
def median(predictions):
    stacked_predictions = torch.stack(predictions, dim=0)
    return torch.median(stacked_predictions, dim=0).values

#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    median_df = ensemble_inference(test_df, median, threshold)
    mean = median_df['Dice'].mean()
    print(f'Median Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = median_df
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('./outputs/Ensemble/median_test_scores.csv', index=False)
df.describe()

Median Mean Dice for threshold 0.1: 0.6100782444400172
Median Mean Dice for threshold 0.2: 0.6120831245376218
Median Mean Dice for threshold 0.3: 0.6029955754356999
Median Mean Dice for threshold 0.4: 0.592251036436327
Median Mean Dice for threshold 0.5: 0.5806126330168017
Median Mean Dice for threshold 0.6: 0.5694510761287904
Median Mean Dice for threshold 0.7: 0.5551219413117054
Median Mean Dice for threshold 0.8: 0.536392334788557
Median Mean Dice for threshold 0.9: 0.5016900248104527
BEST THRESHOLD: 0.2


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.612083,0.63885,0.64994,0.531447,6.225806,5.741935,7.032258,8.709677,8.709677,5.129032,15876.483871,16000.548387,13569.354839,17105.387097,17985.870968,13810.806452
std,0.273661,0.269515,0.274134,0.306581,5.948163,5.483112,8.12192,9.103822,9.103822,4.951376,24037.063603,24258.577521,21550.894202,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.486434,0.508295,0.520046,0.242854,2.5,2.0,2.0,2.0,2.0,1.5,943.5,937.5,592.0,802.0,802.0,410.0
50%,0.635331,0.679783,0.72,0.571091,4.0,4.0,4.0,6.0,6.0,4.0,3859.0,3888.0,2995.0,3847.0,4138.0,1944.0
75%,0.850039,0.868642,0.88466,0.785967,8.0,7.5,9.0,13.5,13.5,7.0,27416.0,27688.0,25957.5,26305.0,27036.5,23520.5
max,0.947509,0.954614,0.956809,0.932726,25.0,25.0,35.0,43.0,43.0,21.0,90012.0,90697.0,81341.0,102770.0,105729.0,92221.0


## Ensemble 4 - Maximum Confidence 

In [7]:
def max_confidence(predictions):
    stacked_predictions = torch.stack(predictions, dim=0)
    return torch.max(stacked_predictions, dim=0).values

#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    max_confidence_df = ensemble_inference(test_df, max_confidence, threshold)
    mean = max_confidence_df['Dice'].mean()
    print(f'Max Confidence Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = max_confidence_df
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('./outputs/Ensemble/max_confidence_test_scores.csv', index=False)
df.describe()

Max Confidence Mean Dice for threshold 0.1: 0.4769555576325905
Max Confidence Mean Dice for threshold 0.2: 0.5018590381669421
Max Confidence Mean Dice for threshold 0.3: 0.5132624893899886
Max Confidence Mean Dice for threshold 0.4: 0.5196868994032904
Max Confidence Mean Dice for threshold 0.5: 0.5234780021372342
Max Confidence Mean Dice for threshold 0.6: 0.5252501968624851
Max Confidence Mean Dice for threshold 0.7: 0.5257121397167324
Max Confidence Mean Dice for threshold 0.8: 0.5242511283835366
Max Confidence Mean Dice for threshold 0.9: 0.5196892056584118
BEST THRESHOLD: 0.7


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.525712,0.563654,0.575627,0.437288,25.709677,21.967742,37.870968,8.709677,8.709677,5.129032,19678.903226,19567.967742,19298.677419,17105.387097,17985.870968,13810.806452
std,0.318874,0.324253,0.333119,0.309278,43.836205,36.408318,55.675094,9.103822,9.103822,4.951376,25807.250693,25731.65804,25660.328382,25565.209783,26912.616906,22937.645993
min,0.002576,0.005687,0.00204,0.0,1.0,2.0,1.0,1.0,1.0,0.0,1821.0,1813.0,1750.0,44.0,44.0,0.0
25%,0.218996,0.249058,0.249256,0.133651,8.5,7.0,10.0,2.0,2.0,1.5,3700.5,3664.0,3465.5,802.0,802.0,410.0
50%,0.632381,0.661653,0.712883,0.479532,13.0,11.0,19.0,6.0,6.0,4.0,6310.0,5889.0,5875.0,3847.0,4138.0,1944.0
75%,0.794857,0.850547,0.87358,0.694644,22.0,19.5,30.0,13.5,13.5,7.0,32187.0,32010.0,32388.0,26305.0,27036.5,23520.5
max,0.932704,0.950421,0.953007,0.894683,221.0,190.0,251.0,43.0,43.0,21.0,99301.0,98700.0,97619.0,102770.0,105729.0,92221.0


## Ensemble 5 - Majority Voting

In [8]:
def majority_voting(predictions):
    transforms =[AsDiscrete(threshold=0.3), AsDiscrete(threshold=0.5), AsDiscrete(threshold=0.2), AsDiscrete(threshold=0.4)]
    predictions = [transform(prediction) for transform, prediction in zip(transforms, predictions)]
    stacked_predictions = torch.stack(predictions, dim=0)  
    return torch.mode(stacked_predictions, dim=0).values

#---------------------------------------------------------#

majority_vote_df = ensemble_inference(test_df, majority_voting, threshold)
majority_vote_df.to_csv('./outputs/Ensemble/majority_vote_test_scores.csv', index=False)
majority_vote_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.598177,0.629584,0.634287,0.518907,5.774194,5.322581,6.967742,8.709677,8.709677,5.129032,14518.096774,14620.096774,12023.290323,17105.387097,17985.870968,13810.806452
std,0.281068,0.273253,0.28147,0.319628,5.529977,5.042401,8.130125,9.103822,9.103822,4.951376,22096.846216,22308.181914,19524.494051,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.40679,0.476757,0.47681,0.194828,2.0,2.0,2.0,2.0,2.0,1.5,735.0,726.5,461.0,802.0,802.0,410.0
50%,0.654574,0.685359,0.715392,0.594832,4.0,4.0,6.0,6.0,6.0,4.0,3544.0,3562.0,2571.0,3847.0,4138.0,1944.0
75%,0.832495,0.866259,0.871473,0.795037,7.5,7.0,8.5,13.5,13.5,7.0,24593.0,24801.5,21853.0,26305.0,27036.5,23520.5
max,0.951305,0.957194,0.95926,0.937462,21.0,21.0,38.0,43.0,43.0,21.0,82128.0,82342.0,73322.0,102770.0,105729.0,92221.0


## Ensemble 6 - Logistic Regression dimension wise

In [4]:
def logreg_ensemble(predictions, targets):
    stacked_predictions = torch.stack(predictions, dim=0)
    num_models, batch, channels, depth, height, width = stacked_predictions.shape

    # Flatten
    flattened_predictions = stacked_predictions.reshape(num_models, -1).T 
    flattened_targets = targets.reshape(-1)
    
    # Logistic Regression
    meta_model = LogisticRegression(max_iter=1000)
    meta_model.fit(flattened_predictions, flattened_targets)
    ensemble_predictions = meta_model.predict_proba(flattened_predictions)
    
    # Reshape to original shape
    ensemble_predictions_reshaped = ensemble_predictions[:, 1].reshape(batch, channels, depth, height, width)
    return ensemble_predictions_reshaped

#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    log_df = ensemble_inference(test_df, logreg_ensemble, threshold, include_label=True, store_npz=True, model_name='dim_logreg')
    mean = log_df['Dice'].mean()
    print(f'Logistic Regression Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = log_df
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('./outputs/Ensemble/test_dim_logreg_test_scores.csv', index=False)
df.describe()

Logistic Regression Mean Dice for threshold 0.1: 0.6961464016668258
Logistic Regression Mean Dice for threshold 0.2: 0.693949933494291
Logistic Regression Mean Dice for threshold 0.3: 0.6907285230775033
Logistic Regression Mean Dice for threshold 0.4: 0.6873722626797615
Logistic Regression Mean Dice for threshold 0.5: 0.6813228882909301
Logistic Regression Mean Dice for threshold 0.6: 0.670750834483413
Logistic Regression Mean Dice for threshold 0.7: 0.6421724648004578
Logistic Regression Mean Dice for threshold 0.8: 0.6066610277179749
Logistic Regression Mean Dice for threshold 0.9: 0.5064416873839593
BEST THRESHOLD: 0.1


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.696146,0.729733,0.743127,0.590396,8.064516,7.064516,71.322581,8.709677,8.709677,5.129032,17306.774194,17372.709677,14472.612903,17105.387097,17985.870968,13810.806452
std,0.221089,0.199089,0.205319,0.29059,7.882197,6.757393,161.865662,9.103822,9.103822,4.951376,26300.231843,26456.874389,23840.836203,25565.209783,26912.616906,22937.645993
min,0.05869,0.108108,0.067961,0.0,1.0,1.0,0.0,1.0,1.0,0.0,48.0,32.0,0.0,44.0,44.0,0.0
25%,0.565964,0.623152,0.644902,0.434253,2.0,2.0,1.5,2.0,2.0,1.5,874.5,845.5,411.0,802.0,802.0,410.0
50%,0.771663,0.784452,0.798665,0.652569,6.0,5.0,6.0,6.0,6.0,4.0,3700.0,3650.0,2633.0,3847.0,4138.0,1944.0
75%,0.865933,0.881005,0.898048,0.810923,9.5,9.0,25.5,13.5,13.5,7.0,29011.0,28953.0,26329.0,26305.0,27036.5,23520.5
max,0.953246,0.958113,0.9604,0.941226,31.0,29.0,675.0,43.0,43.0,21.0,102821.0,103018.0,93728.0,102770.0,105729.0,92221.0


## Ensemble 7 - Logistic Regression voxel wise

In [7]:
def logreg_ensemble(predictions, targets):
    batch, num_channels, depth, height, width = predictions[0].shape

    # Flatten
    flattened_predictions = [pred.view(3, -1).t() for pred in predictions]
    flattened_predictions = torch.cat(flattened_predictions, dim=1)
    flattened_targets = targets.view(3, -1).t()

    # Logistic Regression
    meta_model = OneVsRestClassifier(LogisticRegression(max_iter=1000))
    meta_model.fit(flattened_predictions, flattened_targets)
    ensemble_predictions = meta_model.predict(flattened_predictions)
    ensemble_predictions = ensemble_predictions.T.reshape(batch, num_channels, depth, height, width)
    return ensemble_predictions

#---------------------------------------------------------#
log_df = ensemble_inference(test_df, logreg_ensemble, threshold=0.1, include_label=True, store_npz=True, model_name='voxel_logreg')
log_df.to_csv('./outputs/Ensemble/test_voxel_logreg_test_scores.csv', index=False)
log_df.describe()

Label not 2 is present in all training examples.


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.692374,0.727502,0.741233,0.582129,7.387097,6.645161,11.903226,8.709677,8.709677,5.129032,15512.0,16272.580645,12278.064516,17105.387097,17985.870968,13810.806452
std,0.233793,0.214656,0.2164,0.305549,8.826768,6.390349,23.744971,9.103822,9.103822,4.951376,23804.061275,24819.26776,21193.697282,25565.209783,26912.616906,22937.645993
min,0.014035,0.021053,0.021053,0.0,1.0,1.0,0.0,1.0,1.0,0.0,16.0,16.0,0.0,44.0,44.0,0.0
25%,0.554798,0.632412,0.661439,0.433333,2.0,2.0,2.0,2.0,2.0,1.5,606.0,606.0,212.5,802.0,802.0,410.0
50%,0.775395,0.792676,0.813953,0.651661,5.0,5.0,7.0,6.0,6.0,4.0,3254.0,3467.0,1907.0,3847.0,4138.0,1944.0
75%,0.869011,0.886596,0.900485,0.832299,8.5,8.5,11.0,13.5,13.5,7.0,25946.0,26020.5,22304.0,26305.0,27036.5,23520.5
max,0.954892,0.959309,0.961732,0.943636,42.0,25.0,128.0,43.0,43.0,21.0,94297.0,96433.0,83083.0,102770.0,105729.0,92221.0
