# Ensemble Evaluation - All Models and 3-Model Subset

## Environment Setup

### Imports

In [1]:
import os
os.sys.path.append("utils")

import torch
import numpy as np
import pandas as pd
from functools import reduce
import torch.nn.functional as F
from torch.utils.data import DataLoader
from utils.Transforms import Transforms
from monai.transforms import AsDiscrete
from utils.Datasets import EnsembleDataset
from utils.Models import LogisticRegression
from torch.utils.data import SequentialSampler
from utils.Inference import ensemble_inference, model_ensemble_inference

### Config

In [2]:
# Single source of truth for randomness: the value is also passed to Transforms
# later, and every RNG the notebook imports is seeded from it here.
seed = 33
np.random.seed(seed)
torch.manual_seed(seed)

pd.set_option("display.max_columns", None)
pd.set_option("display.max_colwidth", None)

## Load Test Data

In [3]:
subject_ids = pd.read_csv('./data/TEST.csv')['SubjectID'].values
channels = ['TC', 'WT', 'ET']

# One row per test subject. Every model / ground-truth cell holds the three
# per-channel .npz paths in TC, WT, ET order.
test_df = pd.DataFrame()
test_df['SubjectID'] = subject_ids
test_df['AHNet'] = [
    [f'./outputs/AHNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]
test_df['UNet'] = [
    [f'./outputs/UNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]
test_df['SegResNet'] = [
    [f'./outputs/SegResNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]
test_df['UNETR'] = [
    [f'./outputs/UNETR/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]
test_df['GT'] = [
    [f'./outputs/gt_segs/test_gt_segs/gt_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]

test_df.head()

Unnamed: 0,SubjectID,AHNet,UNet,SegResNet,UNETR,GT
0,100214B,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100214B_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100214B_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100214B_ET.npz]"
1,100340A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100340A_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100340A_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100340A_ET.npz]"
2,100391A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100391A_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100391A_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100391A_ET.npz]"
3,100190B,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100190B_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100190B_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100190B_ET.npz]"
4,100142A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100142A_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100142A_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100142A_ET.npz]"


## Ensemble Functions

In [4]:
# Average
def avg(predictions):
    """Unweighted ensemble: element-wise mean of the model outputs.

    Args:
        predictions: sequence of same-shaped tensors, one per model.

    Returns:
        Tensor with the shape of a single prediction.
    """
    return torch.stack(predictions).mean(dim=0)

#---------------------------------#

# Weighted Average
def weigthed_avg(predictions, weights=None):
    """Weighted ensemble: per-model weighted mean of the model outputs.

    Args:
        predictions: sequence of same-shaped tensors, one per model.
        weights: optional per-model weights; length must equal
            ``len(predictions)``. When None, hard-coded defaults are picked
            from the number of models (4, 3 or 2); any other count falls back
            to uniform weights. Weights are always normalised to sum to 1.

    Returns:
        Tensor with the shape of a single prediction.

    Raises:
        ValueError: if the number of weights does not match the number of
            predictions.
    """
    n_models = len(predictions)
    if weights is None:
        # Defaults are chosen by model count only. NOTE(review): they presumably
        # follow the order of the notebook's `models` list (AHNet, SegResNet,
        # UNet, UNETR), with UNet dropped for 3 models and SegResNet as well
        # for 2; confirm against the order ensemble_inference passes them in.
        default_weights = {
            4: [0.546402, 0.648790, 0.428639, 0.547957],  # All models
            3: [0.546402, 0.648790, 0.547957],            # 3 models
            2: [0.546402, 0.547957],                      # 2 models
        }
        weights = default_weights.get(n_models, np.ones(n_models))
    weights = np.asarray(weights, dtype=np.float64)
    if weights.shape != (n_models,):
        raise ValueError(f'expected {n_models} weights, got shape {weights.shape}')
    weights = weights / np.sum(weights)

    stacked_predictions = torch.stack(predictions, dim=0)
    # Keep float32 weights as before, but place them on the predictions'
    # device so CUDA inputs do not hit a device-mismatch error.
    weight_tensor = torch.tensor(weights).float().to(stacked_predictions.device)
    # One weight per model, broadcast over every remaining dimension. The old
    # fixed [:, None, None, None, None, None] only broadcast correctly for 5-D
    # predictions and silently produced a wrong shape otherwise.
    weight_tensor = weight_tensor.view(-1, *([1] * (stacked_predictions.dim() - 1)))
    return torch.sum(stacked_predictions * weight_tensor, dim=0)

#---------------------------------#

# Median
def median(predictions):
    """Median ensemble: element-wise median across the model outputs.

    With an even number of models, torch.median returns the lower of the two
    middle values (e.g. the 2nd-smallest of 4), not their mean.

    Args:
        predictions: sequence of same-shaped tensors, one per model.

    Returns:
        Tensor with the shape of a single prediction.
    """
    median_values, _ = torch.stack(predictions, dim=0).median(dim=0)
    return median_values

#---------------------------------#

# Maximum Confidence
def max_confidence(predictions):
    """Maximum-confidence ensemble: element-wise maximum over the model outputs.

    Args:
        predictions: sequence of same-shaped tensors, one per model.

    Returns:
        Tensor with the shape of a single prediction.
    """
    highest, _ = torch.stack(predictions, dim=0).max(dim=0)
    return highest

#---------------------------------#

# Majority Voting
def majority_voting(predictions, threshold=0.3):
    """Majority-vote ensemble over binarised model outputs.

    Each prediction is thresholded to 0/1 with MONAI ``AsDiscrete`` and the
    element-wise mode of the votes across models is returned.

    Args:
        predictions: sequence of same-shaped tensors, one per model.
        threshold: binarisation threshold applied to every model (default
            0.3, the value previously hard-coded).

    Returns:
        Tensor of 0/1 votes with the shape of a single prediction.

    NOTE(review): with an even number of models (e.g. 4), a 2-2 split is a tie
    and the value torch.mode reports for it is not controlled here; confirm
    whether ties should count as tumor.
    """
    # One shared transform for every model. The previous hard-coded list of four
    # transforms made zip() silently drop any prediction beyond the fourth.
    binarise = AsDiscrete(threshold=threshold)
    votes = torch.stack([binarise(prediction) for prediction in predictions], dim=0)
    return torch.mode(votes, dim=0).values

In [5]:
# Dempster-Shafer
def dempster_rule(mass1, mass2):
    """Combine two voxel-wise mass functions over the frame {tumor, non-tumor}.

    The last axis holds [m(tumor), m(non-tumor), m(uncertainty / Theta)].

    Args:
        mass1: tensor of shape (..., 3).
        mass2: tensor with the same shape as ``mass1``.

    Returns:
        Tensor of shape (..., 3) with the combined, unnormalised masses.

    NOTE(review): this is not the textbook Dempster rule. The conflicting mass
    m1(T)m2(N) + m1(N)m2(T) is reassigned to Theta, and every entry is
    additionally scaled by K = 1 / (1 - conflict). K multiplies all three
    entries, so any later normalisation by the total mass removes it. Confirm
    that the conflict-to-Theta reassignment is intended.
    """
    t1, n1, u1 = mass1[..., 0], mass1[..., 1], mass1[..., 2]
    t2, n2, u2 = mass2[..., 0], mass2[..., 1], mass2[..., 2]

    conflict_tn = t1 * n2
    conflict_nt = n1 * t2
    denominator = 1 - conflict_tn - conflict_nt
    # Total conflict (e.g. one model saturated at 1.0 against another at 0.0)
    # makes the denominator exactly 0 in float32, so K = inf and inf * 0 = NaN.
    # Replace (near-)zero denominators by eps. Non-degenerate voxels are untouched.
    eps = torch.finfo(denominator.dtype).eps
    denominator = torch.where(denominator.abs() < eps, torch.full_like(denominator, eps), denominator)
    K = 1 / denominator

    combined_mass = torch.zeros_like(mass1)
    combined_mass[..., 0] = K * (t1 * t2 + t1 * u2 + u1 * t2)
    combined_mass[..., 1] = K * (n1 * n2 + n1 * u2 + u1 * n2)
    combined_mass[..., 2] = K * (conflict_tn + conflict_nt + u1 * u2)

    return combined_mass

def dempster_shafer_inference(predictions):
    """Fuse model outputs with the evidence-combination rule in ``dempster_rule``.

    Each model output is turned into a mass function over {tumor, non-tumor}:
    tumor = p, non-tumor = 1 - p, and a constant uncertainty mass equal to
    1 - max(p) over the whole prediction. The masses are folded left-to-right
    with ``dempster_rule``, normalised by their sum, and the tumor belief is
    returned with a leading axis of size 1 restored.

    Args:
        predictions: sequence of same-shaped tensors, one per model; assumed
            to carry a leading batch axis of size 1 (it is squeezed) - TODO confirm.

    Returns:
        Tensor with the same shape as one input prediction.

    NOTE(review): per voxel the three masses sum to 1 + uncertainty rather
    than 1, and the uncertainty is a single scalar per model (taken from the
    global maximum across all channels); confirm this is intended.
    """
    squeezed = [prediction.squeeze(0) for prediction in predictions]

    # Each mass tensor has the squeezed prediction's shape plus a trailing
    # axis of size 3: [tumor, non-tumor, uncertainty].
    mass_functions = []
    for prob in squeezed:
        uncertainty = 1 - torch.max(prob)
        mass_functions.append(
            torch.stack([prob, 1 - prob, torch.full_like(prob, uncertainty)], dim=-1)
        )

    fused = reduce(dempster_rule, mass_functions)
    belief = fused / fused.sum(dim=-1, keepdim=True)

    return belief[..., 0].unsqueeze(0)

In [16]:
# CRF - Mean Field Approximation
def crf_mean(predictions, sigma_spatial=3.0, sigma_bilateral=1.0, iterations=10, alpha=0.7):
    """Smooth the weighted-average ensemble with a mean-field-style update.

    This is a lightweight heuristic rather than a dense CRF. The "spatial"
    message is a 3x3x3 box average (``F.avg_pool3d``). The "bilateral" message
    compares the channel values of the same voxel with each other.

    Args:
        predictions: per-model tensors forwarded to ``weigthed_avg``; assumed
            to be (batch, channel, x, y, z) because of ``F.avg_pool3d`` - TODO confirm.
        sigma_spatial: divisor applied to the spatial message.
        sigma_bilateral: bandwidth of the channel-similarity kernel and divisor
            applied to the bilateral message.
        iterations: number of update steps.
        alpha: weight of the original weighted average in the output blend.

    Returns:
        CPU tensor ``alpha * avg + (1 - alpha) * Q``.
    """
    # Weighted Average
    avg_pred = weigthed_avg(predictions)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    avg_pred = avg_pred.to(device)
    Q = avg_pred.clone()
    # Smallest positive normal value of Q's dtype, used to guard the channel normalisation.
    tiny = torch.finfo(Q.dtype).tiny

    for _ in range(iterations):
        # Spatial
        Q_spatial = F.avg_pool3d(Q, kernel_size=3, stride=1, padding=1)

        # Bilateral: pairwise squared differences between the channels of each voxel
        diff = (Q.unsqueeze(2) - Q.unsqueeze(1)).pow(2)
        bilateral_weight = torch.exp(-diff / (2 * sigma_bilateral**2))
        Q_bilateral = (Q.unsqueeze(2) * bilateral_weight).sum(1) / bilateral_weight.sum(1)

        # Combine spatial and bilateral
        Q_new = Q * torch.exp((Q_spatial - Q) / sigma_spatial + (Q_bilateral - Q) / sigma_bilateral)
        # NOTE(review): normalising over dim=1 makes the channels sum to 1 as if
        # they were mutually exclusive classes; confirm that is intended.
        # clamp_min: a voxel whose channels are all exactly 0 would give 0/0 = NaN,
        # and avg_pool3d would spread that NaN to its neighbours on every later iteration.
        Q = Q_new / Q_new.sum(dim=1, keepdim=True).clamp_min(tiny)

    final_pred = alpha * avg_pred + (1 - alpha) * Q

    return final_pred.cpu()

In [17]:
# CRF - Energy Minimization
def crf_energy(predictions, iterations=5, sigma_spatial=3.0, sigma_bilateral=1.0, learning_rate=1e-3, alpha=0.7):
    """Refine the weighted-average ensemble with a few gradient steps on a hand-built energy.

    The energy has three terms:
    - a cross-entropy unary term against the weighted average;
    - a squared-difference penalty towards a 3x3x3 box average ("spatial");
    - a squared-difference penalty towards a channel-similarity average ("bilateral").
    After each plain gradient step, Q is passed through a softmax over dim=1.

    Args:
        predictions: per-model tensors forwarded to ``weigthed_avg``; assumed
            to be (batch, channel, x, y, z) because of ``F.avg_pool3d`` - TODO confirm.
        iterations: number of gradient steps.
        sigma_spatial: scale of the spatial penalty (divided by 2 * sigma**2).
        sigma_bilateral: bandwidth of the channel-similarity kernel and scale
            of the bilateral penalty.
        learning_rate: gradient-descent step size.
        alpha: weight of the weighted average in the output blend.

    Returns:
        Detached CPU tensor ``alpha * avg + (1 - alpha) * Q``.

    NOTE(review): F.softmax(Q, dim=1) is applied to probabilities rather than
    logits. It also treats the channels as mutually exclusive classes, although
    the ensemble averages them independently. With the small default learning
    rate, the softmax probably dominates the change to Q. Confirm this is intended.
    """
    # Weighted Average
    avg_pred = weigthed_avg(predictions)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    avg_pred = avg_pred.to(device)

    # Q is the optimised variable; it starts from the weighted average.
    Q = avg_pred.clone().requires_grad_(True)

    for _ in range(iterations):
        # Compute energy (1e-10 keeps log() finite where Q is 0)
        unary_energy = -torch.sum(avg_pred * torch.log(Q + 1e-10))
        
        # Spatial pairwise energy
        Q_spatial = F.avg_pool3d(Q, kernel_size=3, stride=1, padding=1)
        spatial_energy = torch.sum((Q - Q_spatial).pow(2)) / (2 * sigma_spatial**2)
        
        # Bilateral pairwise energy (pairwise squared differences between the channels of each voxel)
        diff = (Q.unsqueeze(2) - Q.unsqueeze(1)).pow(2)
        bilateral_weight = torch.exp(-diff / (2 * sigma_bilateral**2))
        Q_bilateral = (Q.unsqueeze(2) * bilateral_weight).sum(1) / bilateral_weight.sum(1)
        bilateral_energy = torch.sum((Q - Q_bilateral).pow(2)) / (2 * sigma_bilateral**2)
        
        # Total energy
        energy = unary_energy + spatial_energy + bilateral_energy
        
        # Update Q: in-place gradient step on the leaf tensor, outside autograd
        energy.backward()
        with torch.no_grad():
            Q -= learning_rate * Q.grad
            Q.grad.zero_()
        Q = F.softmax(Q, dim=1)
        # Start a fresh leaf so the next backward() does not traverse earlier iterations' graphs.
        Q = Q.detach().requires_grad_(True)

    final_pred = alpha * avg_pred + (1 - alpha) * Q
    
    return final_pred.detach().cpu()

## Ensembles - All Models

In [8]:
# Output folder for the 4-model results. Every CSV below is written as
# f'{path}<name>.csv', so `path` must keep its trailing slash.
path = './outputs/Ensemble/'
models = ['AHNet', 'SegResNet', 'UNet', 'UNETR']
# DataFrame.to_csv does not create missing folders.
os.makedirs(path, exist_ok=True)

In [9]:
# Average: unweighted mean of the model outputs, binarised at 0.4
df = ensemble_inference(models, test_df, avg, threshold=0.4)
csv_file = f'{path}ensemble_avg.csv'
df.to_csv(csv_file, index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.630636,0.662461,0.673749,0.533071,6.677419,6.516129,87.129032,8.709677,8.709677,5.129032,16592.322581,16591.774194,14575.709677,17105.387097,17985.870968,13810.806452
std,0.259768,0.25531,0.260801,0.293511,5.906985,5.977017,166.985177,9.103822,9.103822,4.951376,24168.923954,24207.681579,22133.776319,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,7.0,6.0,0.0,44.0,44.0,0.0
25%,0.484863,0.554968,0.560128,0.282297,3.0,2.5,5.5,2.0,2.0,1.5,1057.0,1048.0,662.5,802.0,802.0,410.0
50%,0.67684,0.717632,0.698113,0.542388,5.0,5.0,14.0,6.0,6.0,4.0,4993.0,4986.0,3298.0,3847.0,4138.0,1944.0
75%,0.842653,0.869613,0.885937,0.769723,8.0,8.0,65.0,13.5,13.5,7.0,29342.5,29308.0,27872.0,26305.0,27036.5,23520.5
max,0.946255,0.955846,0.958104,0.924816,25.0,25.0,684.0,43.0,43.0,21.0,90746.0,90706.0,83437.0,102770.0,105729.0,92221.0


In [10]:
# Weighted Average
df = ensemble_inference(models, test_df, weigthed_avg, threshold=0.4)
# `path` already ends with '/'; the old f'{path}/...' produced a double slash.
df.to_csv(f'{path}ensemble_weighted_avg.csv', index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.6359,0.667215,0.679828,0.53887,6.903226,6.419355,109.774194,8.709677,8.709677,5.129032,16629.322581,16633.903226,14432.903226,17105.387097,17985.870968,13810.806452
std,0.257415,0.252058,0.256418,0.295345,5.985287,5.863868,204.315884,9.103822,9.103822,4.951376,24313.317312,24377.260244,22128.085428,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,6.0,5.0,0.0,44.0,44.0,0.0
25%,0.498647,0.566927,0.570827,0.290649,2.5,2.0,6.0,2.0,2.0,1.5,1042.0,1032.5,635.0,802.0,802.0,410.0
50%,0.675439,0.723906,0.727524,0.549588,5.0,5.0,21.0,6.0,6.0,4.0,4940.0,4945.0,3194.0,3847.0,4138.0,1944.0
75%,0.846335,0.871845,0.890754,0.776012,8.5,8.0,103.5,13.5,13.5,7.0,29476.5,29444.5,27817.0,26305.0,27036.5,23520.5
max,0.947493,0.956289,0.958427,0.927762,25.0,25.0,922.0,43.0,43.0,21.0,91540.0,91606.0,83508.0,102770.0,105729.0,92221.0


In [11]:
# Median
df = ensemble_inference(models, test_df, median, threshold=0.4)
# `path` already ends with '/'; the old f'{path}/...' produced a double slash.
df.to_csv(f'{path}ensemble_median.csv', index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.592251,0.625845,0.626673,0.516698,5.903226,5.83871,8.16129,8.709677,8.709677,5.129032,14633.16129,14741.741935,12157.387097,17105.387097,17985.870968,13810.806452
std,0.286974,0.276813,0.291834,0.320492,5.752419,6.061335,9.612827,9.103822,9.103822,4.951376,22319.908553,22558.45792,19780.442551,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.39889,0.471294,0.466704,0.192902,2.0,2.0,2.5,2.0,2.0,1.5,732.0,718.5,465.5,802.0,802.0,410.0
50%,0.653768,0.684043,0.717003,0.59392,4.0,4.0,5.0,6.0,6.0,4.0,3549.0,3546.0,2550.0,3847.0,4138.0,1944.0
75%,0.834203,0.867059,0.87476,0.793132,7.5,7.0,10.0,13.5,13.5,7.0,24729.0,24901.5,22015.0,26305.0,27036.5,23520.5
max,0.951454,0.957341,0.959509,0.937512,22.0,22.0,39.0,43.0,43.0,21.0,83221.0,83512.0,74275.0,102770.0,105729.0,92221.0


In [12]:
# Maximum Confidence: element-wise max over models, binarised at 0.4
df = ensemble_inference(models, test_df, max_confidence, threshold=0.4)
csv_file = f'{path}ensemble_max_confidence.csv'
df.to_csv(csv_file, index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.519687,0.55769,0.56923,0.431639,34.16129,31.677419,34.580645,8.709677,8.709677,5.129032,21617.032258,21632.225806,21275.935484,17105.387097,17985.870968,13810.806452
std,0.312949,0.318218,0.327271,0.302892,54.264228,48.573235,58.82617,9.103822,9.103822,4.951376,27887.914179,27994.480121,27943.434225,25565.209783,26912.616906,22937.645993
min,0.00768,0.012817,0.010222,0.0,4.0,3.0,2.0,1.0,1.0,0.0,1986.0,2005.0,1813.0,44.0,44.0,0.0
25%,0.214851,0.249721,0.248241,0.130362,11.5,10.5,8.0,2.0,2.0,1.5,4126.5,4106.5,3889.5,802.0,802.0,410.0
50%,0.638677,0.676597,0.697987,0.496722,21.0,19.0,15.0,6.0,6.0,4.0,7842.0,7704.0,6868.0,3847.0,4138.0,1944.0
75%,0.783556,0.841341,0.858065,0.674008,33.5,30.5,25.0,13.5,13.5,7.0,35506.5,35595.0,35823.5,26305.0,27036.5,23520.5
max,0.916666,0.936059,0.937123,0.876816,263.0,244.0,284.0,43.0,43.0,21.0,107651.0,107975.0,106480.0,102770.0,105729.0,92221.0


In [13]:
# Majority Voting: no threshold argument, so ensemble_inference's own default applies
majority_vote_df = ensemble_inference(models, test_df, majority_voting)
csv_file = f'{path}ensemble_majority_vote.csv'
majority_vote_df.to_csv(csv_file, index=False)
majority_vote_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.602996,0.633261,0.638629,0.525343,6.225806,5.709677,8.129032,8.709677,8.709677,5.129032,15176.096774,15288.451613,12763.709677,17105.387097,17985.870968,13810.806452
std,0.280308,0.273121,0.282172,0.314649,6.540691,5.734071,10.111848,9.103822,9.103822,4.951376,23072.3079,23298.294801,20540.39066,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.41986,0.49123,0.492867,0.213326,2.0,2.0,2.0,2.0,2.0,1.5,806.0,809.5,527.0,802.0,802.0,410.0
50%,0.646907,0.677307,0.725038,0.587416,4.0,4.0,5.0,6.0,6.0,4.0,3673.0,3696.0,2722.0,3847.0,4138.0,1944.0
75%,0.842977,0.869454,0.87987,0.789589,8.0,7.0,8.0,13.5,13.5,7.0,25870.5,26098.0,23754.5,26305.0,27036.5,23520.5
max,0.950614,0.957061,0.959217,0.936127,29.0,25.0,43.0,43.0,43.0,21.0,86124.0,86616.0,77293.0,102770.0,105729.0,92221.0


In [14]:
# Dempster-Shafer: fused tumor belief, binarised at 0.1
ds_df = ensemble_inference(models, test_df, dempster_shafer_inference, threshold=0.1)
csv_file = f'{path}ensemble_dempster_shafer.csv'
ds_df.to_csv(csv_file, index=False)
ds_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.577118,0.60608,0.618176,0.490236,6.290323,6.225806,6.741935,8.709677,8.709677,5.129032,15446.258065,15558.225806,13528.290323,17105.387097,17985.870968,13810.806452
std,0.279771,0.276253,0.281809,0.313583,5.768845,5.649246,6.093536,9.103822,9.103822,4.951376,23810.414574,23963.213483,21562.27938,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,16.0,12.0,0.0,44.0,44.0,0.0
25%,0.442785,0.495962,0.520909,0.213849,4.0,4.0,3.0,2.0,2.0,1.5,698.0,724.0,533.5,802.0,802.0,410.0
50%,0.56542,0.608258,0.612368,0.492505,5.0,5.0,5.0,6.0,6.0,4.0,3781.0,3802.0,3008.0,3847.0,4138.0,1944.0
75%,0.834014,0.862115,0.876348,0.775048,7.0,7.0,8.5,13.5,13.5,7.0,27914.5,28144.5,26598.0,26305.0,27036.5,23520.5
max,0.948727,0.954471,0.957102,0.934608,28.0,28.0,27.0,43.0,43.0,21.0,89020.0,89231.0,81336.0,102770.0,105729.0,92221.0


In [15]:
# CRF - Mean Field Approximation, binarised at 0.4
df = ensemble_inference(models, test_df, crf_mean, threshold=0.4)
csv_file = f'{path}ensemble_crf_mean.csv'
df.to_csv(csv_file, index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.635137,0.666215,0.678172,0.53882,6.645161,6.548387,123.709677,8.709677,8.709677,5.129032,16324.032258,16321.354839,14036.870968,17105.387097,17985.870968,13810.806452
std,0.259541,0.253901,0.258408,0.298055,5.689454,6.048905,232.605129,9.103822,9.103822,4.951376,23963.957376,24025.909997,21674.041599,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,5.0,2.0,0.0,44.0,44.0,0.0
25%,0.500498,0.559208,0.571193,0.293311,3.5,2.0,7.0,2.0,2.0,1.5,992.0,974.0,584.0,802.0,802.0,410.0
50%,0.688387,0.722104,0.740951,0.561484,5.0,5.0,26.0,6.0,6.0,4.0,4791.0,4748.0,3074.0,3847.0,4138.0,1944.0
75%,0.846332,0.870643,0.889466,0.781494,8.5,8.0,126.5,13.5,13.5,7.0,29037.0,29010.0,27191.0,26305.0,27036.5,23520.5
max,0.948949,0.957131,0.959748,0.929968,25.0,25.0,1103.0,43.0,43.0,21.0,90066.0,90116.0,81821.0,102770.0,105729.0,92221.0


In [16]:
# CRF - Energy Minimization, binarised at 0.4
df = ensemble_inference(models, test_df, crf_energy, threshold=0.4)
csv_file = f'{path}ensemble_crf_energy.csv'
df.to_csv(csv_file, index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.635099,0.666059,0.678025,0.538904,6.612903,6.387097,123.451613,8.709677,8.709677,5.129032,16321.967742,16317.0,14038.354839,17105.387097,17985.870968,13810.806452
std,0.25975,0.253963,0.259043,0.297947,5.643152,6.059029,231.764512,9.103822,9.103822,4.951376,23962.280433,24021.201451,21676.087175,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,5.0,2.0,0.0,44.0,44.0,0.0
25%,0.500684,0.558539,0.571357,0.293646,3.5,2.0,7.0,2.0,2.0,1.5,992.5,973.5,583.5,802.0,802.0,410.0
50%,0.691589,0.722104,0.740583,0.561484,5.0,5.0,26.0,6.0,6.0,4.0,4791.0,4746.0,3075.0,3847.0,4138.0,1944.0
75%,0.846302,0.870695,0.889422,0.781489,8.0,8.0,125.5,13.5,13.5,7.0,29038.5,29009.0,27190.5,26305.0,27036.5,23520.5
max,0.948939,0.957131,0.959761,0.929925,25.0,25.0,1097.0,43.0,43.0,21.0,90051.0,90081.0,81829.0,102770.0,105729.0,92221.0


In [17]:
# Data
transforms = Transforms(seed=seed)  # same seed as the config cell instead of a second literal
test_dataset = EnsembleDataset(test_df.drop(columns=['SubjectID']), transform=transforms.val_ensemble(), size=None)
test_sampler = SequentialSampler(test_dataset)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, sampler=test_sampler)
# Same subjects in the same order as test_df (and therefore as the sequential loader).
subjects = test_df['SubjectID'].values

# Model
# NOTE(review): 12 inputs presumably = 4 models x 3 channels (TC, WT, ET) - confirm.
model = LogisticRegression(12, 3)
# The freshly built model lives on the CPU. map_location='cpu' lets a
# checkpoint saved from a GPU load on CPU-only machines too.
model.load_state_dict(torch.load(f'{path}LogRegCheckpoints/LogisticRegression_2.pt', map_location='cpu'))
model.eval()

#---------------------------------------------------------#

# Logistic Regression - No TTA
df = model_ensemble_inference(subjects, test_loader, model, spatial_size=12, threshold=0.9)
df.to_csv(f'{path}ensemble_logreg.csv', index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.591372,0.633501,0.645053,0.488979,6.451613,5.967742,8.129032,8.709677,8.709677,5.129032,16603.354839,15922.129032,16926.645161,17105.387097,17985.870968,13810.806452
std,0.269072,0.275739,0.281359,0.275928,6.152716,6.590315,7.710348,9.103822,9.103822,4.951376,24234.257254,23657.322502,24612.208108,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0,3.0,44.0,44.0,0.0
25%,0.42468,0.518179,0.513945,0.288068,2.0,2.0,2.5,2.0,2.0,1.5,979.0,796.0,1090.5,802.0,802.0,410.0
50%,0.628824,0.682927,0.715011,0.501049,4.0,4.0,5.0,6.0,6.0,4.0,4922.0,4432.0,5163.0,3847.0,4138.0,1944.0
75%,0.81361,0.869493,0.883609,0.702449,8.0,7.0,12.0,13.5,13.5,7.0,29586.5,28904.5,29824.0,26305.0,27036.5,23520.5
max,0.938326,0.955768,0.959606,0.900631,27.0,28.0,29.0,43.0,43.0,21.0,90315.0,87750.0,91928.0,102770.0,105729.0,92221.0


## Ensemble - 3 Models - AHNet, SegResNet & UNETR

In [8]:
# Output folder for the 3-model results. Every CSV below is written as
# f'{path}<name>.csv', so `path` must keep its trailing slash.
path = './outputs/Ensemble_3/'
models = ['AHNet', 'SegResNet', 'UNETR']
# DataFrame.to_csv does not create missing folders.
os.makedirs(path, exist_ok=True)

In [9]:
# Average: unweighted mean of the model outputs, binarised at 0.4
df = ensemble_inference(models, test_df, avg, threshold=0.4)
csv_file = f'{path}ensemble_avg.csv'
df.to_csv(csv_file, index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.6429,0.675335,0.684806,0.547999,7.870968,7.032258,16.096774,8.709677,8.709677,5.129032,16631.483871,16666.580645,13891.870968,17105.387097,17985.870968,13810.806452
std,0.252928,0.241508,0.246799,0.303004,7.830462,6.374344,21.357832,9.103822,9.103822,4.951376,25129.788233,25262.45967,22386.06211,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.523253,0.570072,0.575907,0.320578,3.0,3.0,3.0,2.0,2.0,1.5,982.5,960.0,551.5,802.0,802.0,410.0
50%,0.637359,0.703949,0.753131,0.577692,5.0,5.0,8.0,6.0,6.0,4.0,4718.0,4724.0,2852.0,3847.0,4138.0,1944.0
75%,0.855782,0.873027,0.898497,0.793281,9.5,9.5,21.5,13.5,13.5,7.0,30380.0,30354.0,28023.0,26305.0,27036.5,23520.5
max,0.949085,0.955557,0.958394,0.935284,34.0,26.0,99.0,43.0,43.0,21.0,95617.0,95689.0,84916.0,102770.0,105729.0,92221.0


In [10]:
# Weighted Average
df = ensemble_inference(models, test_df, weigthed_avg, threshold=0.4)
# `path` already ends with '/'; the old f'{path}/...' produced a double slash.
df.to_csv(f'{path}ensemble_weighted_avg.csv', index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.645398,0.676805,0.687453,0.552012,7.193548,6.967742,14.0,8.709677,8.709677,5.129032,16555.032258,16586.225806,13788.322581,17105.387097,17985.870968,13810.806452
std,0.251027,0.24021,0.243862,0.301197,6.935029,6.635681,17.189144,9.103822,9.103822,4.951376,24977.822062,25104.340697,22205.619714,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.534398,0.5762,0.580152,0.33164,3.5,3.0,3.0,2.0,2.0,1.5,995.5,970.0,566.0,802.0,802.0,410.0
50%,0.631351,0.713736,0.762651,0.576646,5.0,5.0,7.0,6.0,6.0,4.0,4724.0,4729.0,2863.0,3847.0,4138.0,1944.0
75%,0.856444,0.873076,0.898498,0.794745,8.5,8.5,19.0,13.5,13.5,7.0,29923.5,29944.5,27641.0,26305.0,27036.5,23520.5
max,0.949815,0.956011,0.959204,0.935853,30.0,27.0,71.0,43.0,43.0,21.0,95005.0,95181.0,84273.0,102770.0,105729.0,92221.0


In [11]:
# Median
df = ensemble_inference(models, test_df, median, threshold=0.4)
# `path` already ends with '/'; the old f'{path}/...' produced a double slash.
df.to_csv(f'{path}ensemble_median.csv', index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.637875,0.672492,0.680239,0.541232,8.709677,7.967742,9.612903,8.709677,8.709677,5.129032,16071.645161,16227.548387,13444.032258,17105.387097,17985.870968,13810.806452
std,0.254386,0.241583,0.246623,0.308399,7.751101,7.467636,11.238853,9.103822,9.103822,4.951376,24189.381509,24527.321182,21631.56032,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.529374,0.570772,0.579417,0.291528,3.5,3.5,2.5,2.0,2.0,1.5,974.0,953.0,556.0,802.0,802.0,410.0
50%,0.613806,0.693872,0.745872,0.567229,5.0,5.0,6.0,6.0,6.0,4.0,4626.0,4656.0,2851.0,3847.0,4138.0,1944.0
75%,0.853084,0.873338,0.894805,0.791979,12.5,9.0,12.0,13.5,13.5,7.0,28907.5,29085.5,26872.0,26305.0,27036.5,23520.5
max,0.949471,0.956106,0.957815,0.936082,30.0,28.0,47.0,43.0,43.0,21.0,92858.0,93420.0,82590.0,102770.0,105729.0,92221.0


In [12]:
# Maximum Confidence: element-wise max over models, binarised at 0.4
df = ensemble_inference(models, test_df, max_confidence, threshold=0.4)
csv_file = f'{path}ensemble_max_confidence.csv'
df.to_csv(csv_file, index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.610785,0.649291,0.657078,0.522498,20.032258,18.935484,84.741935,8.709677,8.709677,5.129032,20064.064516,20117.387097,17951.483871,17105.387097,17985.870968,13810.806452
std,0.251513,0.237557,0.251231,0.28307,11.870085,12.179041,145.158986,9.103822,9.103822,4.951376,28396.293512,28495.471263,26651.570137,25565.209783,26912.616906,22937.645993
min,0.108131,0.120713,0.095548,0.0,4.0,3.0,0.0,1.0,1.0,0.0,117.0,88.0,0.0,44.0,44.0,0.0
25%,0.46681,0.519662,0.517151,0.288739,10.5,9.0,11.5,2.0,2.0,1.5,1402.5,1414.5,1001.0,802.0,802.0,410.0
50%,0.668579,0.704439,0.765412,0.593307,18.0,14.0,23.0,6.0,6.0,4.0,6221.0,6254.0,4554.0,3847.0,4138.0,1944.0
75%,0.828359,0.850638,0.868865,0.759921,26.5,28.5,79.5,13.5,13.5,7.0,35372.5,35460.5,33510.5,26305.0,27036.5,23520.5
max,0.924435,0.937159,0.938148,0.897998,50.0,53.0,584.0,43.0,43.0,21.0,107428.0,107767.0,100362.0,102770.0,105729.0,92221.0


In [13]:
# Majority Voting: no threshold argument, so ensemble_inference's own default applies
majority_vote_df = ensemble_inference(models, test_df, majority_voting)
csv_file = f'{path}ensemble_majority_vote.csv'
majority_vote_df.to_csv(csv_file, index=False)
majority_vote_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.64156,0.674658,0.684467,0.54681,8.806452,8.225806,9.387097,8.709677,8.709677,5.129032,16842.516129,17002.580645,14252.354839,17105.387097,17985.870968,13810.806452
std,0.248368,0.236819,0.240307,0.300191,7.717596,7.419388,11.062481,9.103822,9.103822,4.951376,25211.120421,25523.302593,22620.338032,25565.209783,26912.616906,22937.645993
min,0.00381,0.0,0.011429,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,44.0,44.0,0.0
25%,0.524193,0.56468,0.565742,0.332082,3.5,3.0,2.0,2.0,2.0,1.5,1077.5,1073.5,637.0,802.0,802.0,410.0
50%,0.624035,0.710088,0.759918,0.548019,6.0,5.0,5.0,6.0,6.0,4.0,4886.0,4910.0,3168.0,3847.0,4138.0,1944.0
75%,0.853133,0.87441,0.892536,0.788852,10.0,11.5,11.5,13.5,13.5,7.0,30742.0,30948.5,28709.0,26305.0,27036.5,23520.5
max,0.9448,0.952258,0.954205,0.931873,30.0,29.0,45.0,43.0,43.0,21.0,96593.0,97200.0,86239.0,102770.0,105729.0,92221.0


In [14]:
# Dempster-Shafer: fused tumor belief, binarised at 0.1
ds_df = ensemble_inference(models, test_df, dempster_shafer_inference, threshold=0.1)
csv_file = f'{path}ensemble_dempster_shafer.csv'
ds_df.to_csv(csv_file, index=False)
ds_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.601745,0.633941,0.644781,0.505325,8.516129,8.419355,9.290323,8.709677,8.709677,5.129032,16802.903226,16908.709677,14467.83871,17105.387097,17985.870968,13810.806452
std,0.264539,0.254862,0.260568,0.308289,7.719546,7.873052,10.687044,9.103822,9.103822,4.951376,25970.063724,26162.84428,23447.012299,25565.209783,26912.616906,22937.645993
min,0.011385,0.022727,0.011429,0.0,2.0,1.0,0.0,1.0,1.0,0.0,2.0,1.0,0.0,44.0,44.0,0.0
25%,0.456363,0.526884,0.549712,0.230807,3.5,4.0,2.5,2.0,2.0,1.5,850.0,850.0,490.0,802.0,802.0,410.0
50%,0.585188,0.64,0.629866,0.503819,6.0,6.0,6.0,6.0,6.0,4.0,4183.0,4317.0,3003.0,3847.0,4138.0,1944.0
75%,0.84931,0.874172,0.884706,0.78592,11.0,10.5,10.0,13.5,13.5,7.0,31359.5,31485.0,29419.0,26305.0,27036.5,23520.5
max,0.937775,0.944772,0.94706,0.921493,31.0,32.0,47.0,43.0,43.0,21.0,98062.0,98116.0,88349.0,102770.0,105729.0,92221.0


In [18]:
# Conditional random field (CRF) ensemble, mean-field approximation variant.
df = ensemble_inference(
    models,
    test_df,
    crf_mean,
    threshold=0.4,
)
df.to_csv(index=False, path_or_buf=f'{path}ensemble_crf_mean.csv')
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.643267,0.676059,0.684608,0.548649,7.193548,6.645161,11.483871,8.709677,8.709677,5.129032,16231.225806,16279.806452,13430.064516,17105.387097,17985.870968,13810.806452
std,0.25389,0.242339,0.247343,0.304838,6.920594,6.457803,13.22087,9.103822,9.103822,4.951376,24534.594601,24696.136302,21763.150106,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.533568,0.578342,0.581405,0.32396,3.5,3.0,3.0,2.0,2.0,1.5,965.0,935.0,529.5,802.0,802.0,410.0
50%,0.635036,0.703195,0.752616,0.58286,5.0,4.0,6.0,6.0,6.0,4.0,4661.0,4643.0,2750.0,3847.0,4138.0,1944.0
75%,0.855568,0.871944,0.898139,0.796715,8.5,8.5,16.5,13.5,13.5,7.0,28962.0,28988.0,26655.0,26305.0,27036.5,23520.5
max,0.951052,0.956996,0.959822,0.93741,30.0,26.0,51.0,43.0,43.0,21.0,93512.0,93718.0,82698.0,102770.0,105729.0,92221.0


In [19]:
# Conditional random field (CRF) ensemble, energy-minimisation variant.
df = ensemble_inference(
    models,
    test_df,
    crf_energy,
    threshold=0.4,
)
df.to_csv(index=False, path_or_buf=f'{path}ensemble_crf_energy.csv')
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.642601,0.675014,0.683434,0.548793,7.580645,6.903226,11.806452,8.709677,8.709677,5.129032,16208.903226,16254.0,13433.935484,17105.387097,17985.870968,13810.806452
std,0.254197,0.242991,0.247759,0.304729,7.29737,6.410173,13.731762,9.103822,9.103822,4.951376,24532.480847,24693.675019,21767.654772,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.533431,0.577694,0.581337,0.323906,2.5,3.0,3.0,2.0,2.0,1.5,961.5,934.5,530.5,802.0,802.0,410.0
50%,0.637398,0.703398,0.752826,0.582716,5.0,5.0,6.0,6.0,6.0,4.0,4655.0,4638.0,2751.0,3847.0,4138.0,1944.0
75%,0.855494,0.872075,0.898147,0.796974,9.0,9.0,17.0,13.5,13.5,7.0,28955.5,28980.5,26658.0,26305.0,27036.5,23520.5
max,0.951048,0.956983,0.959822,0.937292,30.0,25.0,53.0,43.0,43.0,21.0,93493.0,93684.0,82721.0,102770.0,105729.0,92221.0


In [20]:
# Learned ensemble: a LogisticRegression combiner over the per-model predictions.

# Which model columns feed the combiner — presumably in test_df column order
# (AHNet, UNet, SegResNet, UNETR), i.e. SegResNet is excluded here; confirm.
include_models = [True, True, False, True]
# One input per (included model, tumour channel) pair; each model has 3 channels
# (TC, WT, ET). NOTE(review): assumes EnsembleDataset stacks the channels of
# every included model — confirm. Deriving this keeps the model width,
# spatial_size and the dataset mask in sync.
n_inputs = 3 * sum(include_models)

# Data
transforms = Transforms(seed=seed)
test_dataset = EnsembleDataset(test_df.drop(columns=['SubjectID']), transform=transforms.val_ensemble(), size=None, include_models=include_models)
test_sampler = SequentialSampler(test_dataset)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, sampler=test_sampler)
# Take IDs from the same frame the dataset was built from so rows stay aligned.
subjects = test_df['SubjectID'].values

# Model (3 outputs — presumably TC/WT/ET; confirm)
model = LogisticRegression(n_inputs, 3)
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
# load_state_dict copies the weights into the model's existing parameters.
model.load_state_dict(torch.load(f'{path}LogRegCheckpoints/LogisticRegression_2.pt', map_location='cpu'))

#---------------------------------------------------------#

# Logistic Regression - No TTA
# NOTE(review): spatial_size is given the same value as the model input width,
# matching the previously hard-coded value; confirm this coupling is intended.
df = model_ensemble_inference(subjects, test_loader, model, spatial_size=n_inputs, threshold=0.9)
df.to_csv(f'{path}ensemble_logreg.csv', index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.612213,0.663321,0.669514,0.489864,8.774194,9.516129,9.83871,8.709677,8.709677,5.129032,18290.258065,19003.064516,18774.0,17105.387097,17985.870968,13810.806452
std,0.248036,0.251599,0.25671,0.271166,8.208571,8.286016,8.426533,9.103822,9.103822,4.951376,27038.498033,27869.187866,27409.691265,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0,1.0,10.0,44.0,44.0,0.0
25%,0.462423,0.554807,0.542016,0.302372,3.5,4.0,5.0,2.0,2.0,1.5,1056.0,1168.0,1187.5,802.0,802.0,410.0
50%,0.654818,0.729425,0.753564,0.51798,6.0,7.0,8.0,6.0,6.0,4.0,4991.0,5138.0,5175.0,3847.0,4138.0,1944.0
75%,0.818042,0.87939,0.884497,0.689917,9.0,10.0,12.0,13.5,13.5,7.0,33123.5,34406.5,33904.5,26305.0,27036.5,23520.5
max,0.926134,0.946958,0.942607,0.888838,31.0,33.0,36.0,43.0,43.0,21.0,102134.0,105173.0,103049.0,102770.0,105729.0,92221.0


## Ensemble - 2 Models - AHNet & UNETR

In [26]:
# Two-model ensemble (AHNet + UNETR). `path` keeps its trailing slash because
# output file names are appended to it directly.
models = ['AHNet', 'UNETR']
path = './outputs/Ensemble_2/'

In [27]:
# Simple average ensemble (`avg`).
df = ensemble_inference(
    models,
    test_df,
    avg,
    threshold=0.4,
)
df.to_csv(index=False, path_or_buf=f'{path}ensemble_avg.csv')
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.605769,0.643667,0.65036,0.508975,13.677419,10.612903,150.774194,8.709677,8.709677,5.129032,17663.193548,17646.387097,15357.774194,17105.387097,17985.870968,13810.806452
std,0.265101,0.256697,0.265615,0.306157,11.341332,9.376131,243.93916,9.103822,9.103822,4.951376,26408.863503,26472.33078,24240.955873,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,44.0,44.0,0.0
25%,0.4485,0.509076,0.484161,0.232632,6.0,4.0,15.5,2.0,2.0,1.5,880.0,851.5,517.0,802.0,802.0,410.0
50%,0.630689,0.637417,0.718654,0.551874,11.0,8.0,49.0,6.0,6.0,4.0,4740.0,4738.0,2982.0,3847.0,4138.0,1944.0
75%,0.843353,0.869894,0.893236,0.770017,18.0,13.5,153.5,13.5,13.5,7.0,32329.0,32220.0,30417.0,26305.0,27036.5,23520.5
max,0.93977,0.948976,0.951038,0.919298,44.0,35.0,1049.0,43.0,43.0,21.0,100554.0,100044.0,91481.0,102770.0,105729.0,92221.0


In [28]:
# Weighted average ensemble. `weigthed_avg` (sic) is the name the combiner is
# defined under, so the call must keep that spelling.
df = ensemble_inference(models, test_df, weigthed_avg, threshold=0.4)
# `path` already ends in '/', so no extra separator (was producing '//').
df.to_csv(f'{path}ensemble_weighted_avg.csv', index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.605617,0.643362,0.650314,0.508993,13.677419,10.677419,150.806452,8.709677,8.709677,5.129032,17661.580645,17644.580645,15354.903226,17105.387097,17985.870968,13810.806452
std,0.26519,0.256849,0.265586,0.306192,11.232059,9.47413,244.045818,9.103822,9.103822,4.951376,26406.623136,26470.523231,24237.338339,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,44.0,44.0,0.0
25%,0.446578,0.504957,0.484027,0.232324,6.0,4.0,15.0,2.0,2.0,1.5,879.5,852.5,517.0,802.0,802.0,410.0
50%,0.630594,0.637417,0.71701,0.551845,11.0,8.0,49.0,6.0,6.0,4.0,4741.0,4737.0,2979.0,3847.0,4138.0,1944.0
75%,0.843416,0.86985,0.893239,0.770212,18.0,14.0,154.0,13.5,13.5,7.0,32332.0,32218.0,30411.5,26305.0,27036.5,23520.5
max,0.939784,0.948976,0.951051,0.919326,43.0,36.0,1048.0,43.0,43.0,21.0,100543.0,100034.0,91456.0,102770.0,105729.0,92221.0


In [29]:
# Median ensemble (`median`).
df = ensemble_inference(models, test_df, median, threshold=0.4)
# `path` already ends in '/', so no extra separator (was producing '//').
df.to_csv(f'{path}ensemble_median.csv', index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.538917,0.571657,0.574137,0.463418,5.870968,5.258065,8.129032,8.709677,8.709677,5.129032,14391.806452,14530.451613,12065.064516,17105.387097,17985.870968,13810.806452
std,0.303437,0.29496,0.304586,0.338225,5.920259,5.409669,9.969092,9.103822,9.103822,4.951376,23191.244583,23455.686142,20697.480201,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.317495,0.38559,0.374568,0.12327,2.0,2.0,2.5,2.0,2.0,1.5,465.5,463.0,256.0,802.0,802.0,410.0
50%,0.509097,0.561654,0.563958,0.432812,5.0,4.0,5.0,6.0,6.0,4.0,2746.0,2909.0,2021.0,3847.0,4138.0,1944.0
75%,0.83634,0.861178,0.881641,0.770279,7.0,5.5,10.0,13.5,13.5,7.0,26027.0,26167.0,23096.0,26305.0,27036.5,23520.5
max,0.950051,0.955172,0.957395,0.937587,26.0,23.0,42.0,43.0,43.0,21.0,87741.0,87904.0,77850.0,102770.0,105729.0,92221.0


In [30]:
# Maximum-confidence ensemble (`max_confidence`).
df = ensemble_inference(
    models,
    test_df,
    max_confidence,
    threshold=0.4,
)
df.to_csv(index=False, path_or_buf=f'{path}ensemble_max_confidence.csv')
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.591759,0.629597,0.637825,0.504159,19.483871,17.580645,89.0,8.709677,8.709677,5.129032,19367.322581,19508.0,17485.83871,17105.387097,17985.870968,13810.806452
std,0.26876,0.261602,0.272585,0.292832,12.935406,13.180223,152.3986,9.103822,9.103822,4.951376,28103.705865,28352.972036,26499.924684,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,5.0,2.0,0.0,44.0,44.0,0.0
25%,0.371002,0.466939,0.458415,0.260921,9.0,8.0,11.0,2.0,2.0,1.5,1158.0,1172.5,745.5,802.0,802.0,410.0
50%,0.657639,0.691783,0.743555,0.578174,16.0,12.0,25.0,6.0,6.0,4.0,5192.0,5246.0,3982.0,3847.0,4138.0,1944.0
75%,0.827439,0.856369,0.866257,0.758419,29.0,26.5,82.0,13.5,13.5,7.0,35242.0,35306.0,33239.0,26305.0,27036.5,23520.5
max,0.924995,0.937072,0.938124,0.899788,53.0,52.0,633.0,43.0,43.0,21.0,106717.0,107006.0,99950.0,102770.0,105729.0,92221.0


In [31]:
# Majority-voting ensemble; no threshold argument is passed for this combiner.
majority_vote_df = ensemble_inference(
    models,
    test_df,
    majority_voting,
)
majority_vote_df.to_csv(index=False, path_or_buf=f'{path}ensemble_majority_vote.csv')
majority_vote_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.554642,0.583924,0.590844,0.477403,6.193548,5.645161,8.354839,8.709677,8.709677,5.129032,14988.645161,15144.580645,12698.645161,17105.387097,17985.870968,13810.806452
std,0.29449,0.288909,0.29334,0.330313,6.4209,5.885283,10.731724,9.103822,9.103822,4.951376,24014.155578,24282.418701,21525.714756,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.367777,0.442381,0.404659,0.149619,2.0,2.0,2.0,2.0,2.0,1.5,507.5,507.5,317.5,802.0,802.0,410.0
50%,0.510321,0.567679,0.57028,0.454534,5.0,4.0,4.0,6.0,6.0,4.0,3098.0,3272.0,2246.0,3847.0,4138.0,1944.0
75%,0.840903,0.866662,0.884677,0.775159,7.5,6.0,9.0,13.5,13.5,7.0,27495.0,27667.0,25094.5,26305.0,27036.5,23520.5
max,0.947089,0.952529,0.955163,0.933574,30.0,27.0,43.0,43.0,43.0,21.0,91055.0,91149.0,81208.0,102770.0,105729.0,92221.0


In [32]:
# Dempster-Shafer ensemble (via `dempster_shafer_inference`).
# NOTE: this cell uses threshold=0.1, unlike the 0.4 used by the averaging,
# median and CRF ensembles in this section; the reason is not recorded here —
# TODO document why.
ds_df = ensemble_inference(models, test_df, dempster_shafer_inference, threshold=0.1)
ds_df.to_csv(f'{path}ensemble_dempster_shafer.csv', index=False)
ds_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.577296,0.603725,0.616029,0.494138,6.774194,5.967742,11.451613,8.709677,8.709677,5.129032,16001.16129,16154.677419,13782.0,17105.387097,17985.870968,13810.806452
std,0.283266,0.279498,0.281408,0.319218,6.571198,5.805077,15.037373,9.103822,9.103822,4.951376,25349.161554,25590.406769,22925.056378,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.40939,0.514237,0.511225,0.181478,2.5,2.5,3.5,2.0,2.0,1.5,575.5,582.5,424.5,802.0,802.0,410.0
50%,0.551461,0.598369,0.601221,0.493097,5.0,5.0,6.0,6.0,6.0,4.0,3697.0,3632.0,2629.0,3847.0,4138.0,1944.0
75%,0.845099,0.874625,0.886767,0.776769,7.5,6.5,11.5,13.5,13.5,7.0,30352.0,30570.0,28244.5,26305.0,27036.5,23520.5
max,0.940381,0.946734,0.94931,0.9251,25.0,26.0,58.0,43.0,43.0,21.0,96632.0,96629.0,86888.0,102770.0,105729.0,92221.0


In [33]:
# Conditional random field (CRF) ensemble, mean-field approximation variant.
df = ensemble_inference(
    models,
    test_df,
    crf_mean,
    threshold=0.4,
)
df.to_csv(index=False, path_or_buf=f'{path}ensemble_crf_mean.csv')
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.605497,0.641875,0.651633,0.506557,13.741935,9.83871,162.16129,8.709677,8.709677,5.129032,17362.645161,17317.677419,14967.612903,17105.387097,17985.870968,13810.806452
std,0.265536,0.257429,0.263442,0.308732,12.521895,8.165361,263.428054,9.103822,9.103822,4.951376,26069.468031,26089.228239,23814.192373,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,44.0,44.0,0.0
25%,0.460574,0.521829,0.512217,0.224473,4.5,4.5,15.5,2.0,2.0,1.5,834.5,809.5,477.0,802.0,802.0,410.0
50%,0.622315,0.642066,0.717188,0.550936,10.0,7.0,47.0,6.0,6.0,4.0,4672.0,4601.0,2860.0,3847.0,4138.0,1944.0
75%,0.842275,0.870411,0.895477,0.772928,18.5,13.0,163.5,13.5,13.5,7.0,31703.5,31590.0,29868.0,26305.0,27036.5,23520.5
max,0.942068,0.950584,0.952914,0.922705,50.0,31.0,1176.0,43.0,43.0,21.0,99290.0,98733.0,89919.0,102770.0,105729.0,92221.0


In [34]:
# Conditional random field (CRF) ensemble, energy-minimisation variant.
df = ensemble_inference(
    models,
    test_df,
    crf_energy,
    threshold=0.4,
)
df.to_csv(index=False, path_or_buf=f'{path}ensemble_crf_energy.csv')
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.605683,0.641933,0.652164,0.506371,13.580645,9.903226,160.741935,8.709677,8.709677,5.129032,17350.580645,17293.774194,14975.451613,17105.387097,17985.870968,13810.806452
std,0.265386,0.257282,0.263292,0.308636,12.510727,8.580423,260.579862,9.103822,9.103822,4.951376,26070.477501,26083.192592,23822.433151,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,44.0,44.0,0.0
25%,0.465194,0.526412,0.5259,0.224549,4.5,3.5,15.5,2.0,2.0,1.5,831.5,803.5,476.5,802.0,802.0,410.0
50%,0.622171,0.643253,0.713107,0.550534,10.0,7.0,48.0,6.0,6.0,4.0,4647.0,4595.0,2864.0,3847.0,4138.0,1944.0
75%,0.842798,0.870454,0.895478,0.772731,18.0,13.0,161.0,13.5,13.5,7.0,31703.5,31588.0,29882.5,26305.0,27036.5,23520.5
max,0.942006,0.950542,0.95294,0.922537,50.0,31.0,1151.0,43.0,43.0,21.0,99268.0,98665.0,89945.0,102770.0,105729.0,92221.0


In [35]:
# Learned ensemble: a LogisticRegression combiner over the per-model predictions.

# Which model columns feed the combiner — presumably in test_df column order
# (AHNet, UNet, SegResNet, UNETR), i.e. only AHNet and UNETR here; confirm.
include_models = [True, False, False, True]
# One input per (included model, tumour channel) pair; each model has 3 channels
# (TC, WT, ET). NOTE(review): assumes EnsembleDataset stacks the channels of
# every included model — confirm. Deriving this keeps the model width,
# spatial_size and the dataset mask in sync.
n_inputs = 3 * sum(include_models)

# Data
transforms = Transforms(seed=seed)
test_dataset = EnsembleDataset(test_df.drop(columns=['SubjectID']), transform=transforms.val_ensemble(), size=None, include_models=include_models)
test_sampler = SequentialSampler(test_dataset)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, sampler=test_sampler)
# Take IDs from the same frame the dataset was built from so rows stay aligned.
subjects = test_df['SubjectID'].values

# Model (3 outputs — presumably TC/WT/ET; confirm)
model = LogisticRegression(n_inputs, 3)
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
# load_state_dict copies the weights into the model's existing parameters.
model.load_state_dict(torch.load(f'{path}LogRegCheckpoints/LogisticRegression_4.pt', map_location='cpu'))

#---------------------------------------------------------#

# Logistic Regression - No TTA
# NOTE(review): spatial_size is given the same value as the model input width,
# matching the previously hard-coded value; confirm this coupling is intended.
df = model_ensemble_inference(subjects, test_loader, model, spatial_size=n_inputs, threshold=0.9)
df.to_csv(f'{path}ensemble_logreg.csv', index=False)
df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.583115,0.62804,0.640094,0.477012,13.870968,12.612903,12.483871,8.709677,8.709677,5.129032,19690.322581,19310.193548,18150.483871,17105.387097,17985.870968,13810.806452
std,0.264822,0.267097,0.274947,0.276754,11.194469,11.692383,15.244389,9.103822,9.103822,4.951376,28819.670283,28435.020747,27167.19043,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,3.0,3.0,0.0,44.0,44.0,0.0
25%,0.392455,0.470563,0.473526,0.255282,7.5,4.5,4.0,2.0,2.0,1.5,1108.0,1056.5,755.0,802.0,802.0,410.0
50%,0.649034,0.673607,0.717494,0.511213,10.0,9.0,9.0,6.0,6.0,4.0,5276.0,5139.0,4812.0,3847.0,4138.0,1944.0
75%,0.8106,0.858849,0.878384,0.698885,16.0,15.0,14.0,13.5,13.5,7.0,36119.5,35745.0,34043.0,26305.0,27036.5,23520.5
max,0.916859,0.928759,0.93452,0.887298,52.0,55.0,81.0,43.0,43.0,21.0,108875.0,107411.0,102438.0,102770.0,105729.0,92221.0
