# Ensemble Techniques

## Environment Setup

### Imports

In [1]:
import os
os.sys.path.append("utils")

import torch
import numpy as np
import pandas as pd
from monai.transforms import AsDiscrete
from utils.Inference import ensemble_inference

### Config

In [2]:
seed = 33
pd.set_option("display.max_columns", None)
pd.set_option('display.max_colwidth', None)

## Load Test Data

In [3]:
test_pred_segs_paths = './outputs/AHNet/pred_segs/test_pred_segs'
test_gt_segs_paths = './outputs/gt_segs/test_gt_segs'
subject_ids = pd.read_csv('./data/TEST.csv')['SubjectID'].values

ah_segs, unet_segs, segresnet_segs, untr_segs, gt_segs = [], [], [], [], []
for sid in subject_ids:
    ah_channels, unet_channels, segresnet_channels, untr_channels, gt_channels = [], [], [], [], []
    for channel in ['TC', 'WT', 'ET']:
        ah_channels.append(f'./outputs/AHNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.pt')
        unet_channels.append(f'./outputs/UNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.pt')
        segresnet_channels.append(f'./outputs/SegResNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.pt')
        untr_channels.append(f'./outputs/UNETR/pred_segs/test_pred_segs/pred_{sid}_{channel}.pt')
        gt_channels.append(f'./outputs/gt_segs/test_gt_segs/gt_{sid}_{channel}.pt')
    
    ah_segs.append(ah_channels)
    unet_segs.append(unet_channels)
    segresnet_segs.append(segresnet_channels)
    untr_segs.append(untr_channels)
    gt_segs.append(gt_channels)

# Dataframe
data_df = pd.DataFrame()
data_df['SubjectID'] = subject_ids
data_df['AHNet'] = ah_segs
data_df['UNet'] = unet_segs
data_df['SegResNet'] = segresnet_segs
data_df['UNETR'] = untr_segs
data_df['GT'] = gt_segs

data_df.head()

Unnamed: 0,SubjectID,AHNet,UNet,SegResNet,UNETR,GT
0,100214B,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_TC.pt, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_WT.pt, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_ET.pt]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_TC.pt, ./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_WT.pt, ./outputs/UNet/pred_segs/test_pred_segs/pred_100214B_ET.pt]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_TC.pt, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_WT.pt, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_ET.pt]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_TC.pt, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_WT.pt, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_ET.pt]","[./outputs/gt_segs/test_gt_segs/gt_100214B_TC.pt, ./outputs/gt_segs/test_gt_segs/gt_100214B_WT.pt, ./outputs/gt_segs/test_gt_segs/gt_100214B_ET.pt]"
1,100340A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_TC.pt, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_WT.pt, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_ET.pt]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_TC.pt, ./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_WT.pt, ./outputs/UNet/pred_segs/test_pred_segs/pred_100340A_ET.pt]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_TC.pt, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_WT.pt, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_ET.pt]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_TC.pt, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_WT.pt, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_ET.pt]","[./outputs/gt_segs/test_gt_segs/gt_100340A_TC.pt, ./outputs/gt_segs/test_gt_segs/gt_100340A_WT.pt, ./outputs/gt_segs/test_gt_segs/gt_100340A_ET.pt]"
2,100391A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_TC.pt, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_WT.pt, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_ET.pt]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_TC.pt, ./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_WT.pt, ./outputs/UNet/pred_segs/test_pred_segs/pred_100391A_ET.pt]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_TC.pt, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_WT.pt, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_ET.pt]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_TC.pt, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_WT.pt, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_ET.pt]","[./outputs/gt_segs/test_gt_segs/gt_100391A_TC.pt, ./outputs/gt_segs/test_gt_segs/gt_100391A_WT.pt, ./outputs/gt_segs/test_gt_segs/gt_100391A_ET.pt]"
3,100190B,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_TC.pt, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_WT.pt, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_ET.pt]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_TC.pt, ./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_WT.pt, ./outputs/UNet/pred_segs/test_pred_segs/pred_100190B_ET.pt]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_TC.pt, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_WT.pt, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_ET.pt]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_TC.pt, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_WT.pt, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_ET.pt]","[./outputs/gt_segs/test_gt_segs/gt_100190B_TC.pt, ./outputs/gt_segs/test_gt_segs/gt_100190B_WT.pt, ./outputs/gt_segs/test_gt_segs/gt_100190B_ET.pt]"
4,100142A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_TC.pt, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_WT.pt, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_ET.pt]","[./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_TC.pt, ./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_WT.pt, ./outputs/UNet/pred_segs/test_pred_segs/pred_100142A_ET.pt]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_TC.pt, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_WT.pt, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_ET.pt]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_TC.pt, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_WT.pt, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_ET.pt]","[./outputs/gt_segs/test_gt_segs/gt_100142A_TC.pt, ./outputs/gt_segs/test_gt_segs/gt_100142A_WT.pt, ./outputs/gt_segs/test_gt_segs/gt_100142A_ET.pt]"


## Ensemble 1 - Average

In [4]:
def avg(predictions):
    stacked_predictions = torch.stack(predictions, dim=0)
    return torch.mean(stacked_predictions, dim=0)

#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    avg_df = ensemble_inference(data_df, avg, threshold)
    mean = avg_df['Dice'].mean()
    print(f'Average Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = avg_df
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('./outputs/Ensemble/avg_test_scores.csv', index=False)
df.describe()

Average Mean Dice for threshold 0.1: 0.5142114165449335
Average Mean Dice for threshold 0.2: 0.5262180840401279
Average Mean Dice for threshold 0.3: 0.6198014113451203
Average Mean Dice for threshold 0.4: 0.6305191401061753
Average Mean Dice for threshold 0.5: 0.6116523122595202
Average Mean Dice for threshold 0.6: 0.5820567079128758
Average Mean Dice for threshold 0.7: 0.5514330773223792
Average Mean Dice for threshold 0.8: 0.4948232361385899
Average Mean Dice for threshold 0.9: 0.44838402900964985
BEST THRESHOLD: 0.4


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.630519,0.662353,0.673632,0.532945,6.645161,6.451613,86.935484,8.709677,8.709677,5.129032,16601.322581,16599.741935,14583.129032,17103.677419,17984.16129,13810.870968
std,0.259899,0.255405,0.260923,0.293641,5.8966,5.971257,166.053593,9.103822,9.103822,4.951376,24177.018311,24214.837841,22140.608537,25564.191374,26911.703415,22937.67913
min,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,7.0,6.0,0.0,44.0,44.0,0.0
25%,0.48447,0.554301,0.558237,0.282204,2.5,2.5,6.5,2.0,2.0,1.5,1069.5,1056.0,672.0,802.0,802.0,410.0
50%,0.67684,0.717632,0.698113,0.542388,5.0,5.0,14.0,6.0,6.0,4.0,4993.0,4986.0,3298.0,3847.0,4138.0,1944.0
75%,0.842653,0.869613,0.885937,0.76938,8.0,8.0,65.0,13.5,13.5,7.0,29460.5,29426.0,27984.5,26305.0,27036.5,23520.5
max,0.946255,0.955846,0.958104,0.924816,25.0,25.0,684.0,43.0,43.0,21.0,90746.0,90706.0,83437.0,102770.0,105729.0,92221.0


## Ensemble 2 - Weighted Average (Mean Dice Weights)

In [6]:
def weigthed_avg(predictions):

    # Weights form validation set AHNet, SegResNet, UNet, UNETR
    weights = np.array([0.546402, 0.648790, 0.428639, 0.547957])
    weights = weights / np.sum(weights)
    weights = torch.tensor(weights).float().cuda()

    # Weighted average
    stacked_predictions = torch.stack(predictions, dim=0)
    return torch.sum(stacked_predictions * weights[:, None, None, None, None, None], dim=0)
     
#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    weigthed_avg_df = ensemble_inference(data_df, weigthed_avg, threshold)
    mean = weigthed_avg_df['Dice'].mean()
    print(f'Weighted Average Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = weigthed_avg_df
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('./outputs/Ensemble/weighted_avg_test_scores.csv', index=False)
df.describe()

Weighted Average Mean Dice for threshold 0.1: 0.5136893866884131
Weighted Average Mean Dice for threshold 0.2: 0.5755864612517818
Weighted Average Mean Dice for threshold 0.3: 0.6318546302856938
Weighted Average Mean Dice for threshold 0.4: 0.6359359584147891
Weighted Average Mean Dice for threshold 0.5: 0.6332755531034162
Weighted Average Mean Dice for threshold 0.6: 0.5931168280301555
Weighted Average Mean Dice for threshold 0.7: 0.5560377530993954
Weighted Average Mean Dice for threshold 0.8: 0.49385631973704985
Weighted Average Mean Dice for threshold 0.9: 0.4451204692163775
BEST THRESHOLD: 0.4


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.635936,0.667285,0.679758,0.538886,6.935484,6.451613,109.612903,8.709677,8.709677,5.129032,16637.387097,16643.096774,14440.064516,17103.677419,17984.16129,13810.870968
std,0.257522,0.252134,0.256506,0.295454,6.021824,5.886927,203.173928,9.103822,9.103822,4.951376,24321.068977,24386.495176,22136.931103,25564.191374,26911.703415,22937.67913
min,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,6.0,5.0,0.0,44.0,44.0,0.0
25%,0.498541,0.564774,0.570827,0.290279,2.5,2.0,7.5,2.0,2.0,1.5,1049.0,1037.5,637.5,802.0,802.0,410.0
50%,0.67824,0.723906,0.727524,0.549588,5.0,5.0,21.0,6.0,6.0,4.0,4940.0,4945.0,3194.0,3847.0,4138.0,1944.0
75%,0.846335,0.871845,0.890754,0.775534,8.5,8.0,103.5,13.5,13.5,7.0,29594.0,29571.5,27916.5,26305.0,27036.5,23520.5
max,0.947493,0.956289,0.958427,0.927762,25.0,25.0,922.0,43.0,43.0,21.0,91540.0,91606.0,83508.0,102770.0,105729.0,92221.0


## Ensemble 3 - Median 

In [7]:
def median(predictions):
    stacked_predictions = torch.stack(predictions, dim=0)
    return torch.median(stacked_predictions, dim=0).values

#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    median_df = ensemble_inference(data_df, median, threshold)
    mean = median_df['Dice'].mean()
    print(f'Median Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = median_df
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('./outputs/Ensemble/median_test_scores.csv', index=False)
df.describe()

Median Mean Dice for threshold 0.1: 0.6094920510245908
Median Mean Dice for threshold 0.2: 0.6115314936445605
Median Mean Dice for threshold 0.3: 0.6017218155245627
Median Mean Dice for threshold 0.4: 0.5915184443996798
Median Mean Dice for threshold 0.5: 0.5804547960719755
Median Mean Dice for threshold 0.6: 0.5693406931334927
Median Mean Dice for threshold 0.7: 0.5551250854567173
Median Mean Dice for threshold 0.8: 0.5364319516886626
Median Mean Dice for threshold 0.9: 0.5016437357952518
BEST THRESHOLD: 0.2


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.611531,0.638636,0.649176,0.53103,6.225806,5.967742,7.096774,8.709677,8.709677,5.129032,15885.903226,16009.16129,13580.322581,17103.677419,17984.16129,13810.870968
std,0.274064,0.269728,0.274676,0.306877,5.903161,5.770522,8.150071,9.103822,9.103822,4.951376,24052.275952,24273.456627,21565.561,25564.191374,26911.703415,22937.67913
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.48573,0.508654,0.520046,0.242661,2.5,2.0,2.0,2.0,2.0,1.5,955.5,949.0,599.0,802.0,802.0,410.0
50%,0.635331,0.679783,0.72,0.571091,4.0,4.0,5.0,6.0,6.0,4.0,3859.0,3888.0,2995.0,3847.0,4138.0,1944.0
75%,0.850039,0.868642,0.88466,0.785967,8.0,7.5,9.0,13.5,13.5,7.0,27416.0,27688.0,25991.5,26305.0,27036.5,23520.5
max,0.947509,0.954614,0.956809,0.932726,25.0,25.0,35.0,43.0,43.0,21.0,90012.0,90697.0,81341.0,102770.0,105729.0,92221.0


## Ensemble 4 - Maximum Confidence 

In [8]:
def max_confidence(predictions):
    stacked_predictions = torch.stack(predictions, dim=0)
    return torch.max(stacked_predictions, dim=0).values

#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    max_confidence_df = ensemble_inference(data_df, max_confidence, threshold)
    mean = max_confidence_df['Dice'].mean()
    print(f'Max Confidence Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = max_confidence_df
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('./outputs/Ensemble/max_confidence_test_scores.csv', index=False)
df.describe()

Max Confidence Mean Dice for threshold 0.1: 0.4763980285234509
Max Confidence Mean Dice for threshold 0.2: 0.5011439663808672
Max Confidence Mean Dice for threshold 0.3: 0.5124773197777329
Max Confidence Mean Dice for threshold 0.4: 0.5188757432174058
Max Confidence Mean Dice for threshold 0.5: 0.5226507567830624
Max Confidence Mean Dice for threshold 0.6: 0.5243801138544034
Max Confidence Mean Dice for threshold 0.7: 0.5248255786080394
Max Confidence Mean Dice for threshold 0.8: 0.5232918485266818
Max Confidence Mean Dice for threshold 0.9: 0.5186530648310098
BEST THRESHOLD: 0.7


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.524826,0.562644,0.574504,0.43674,27.225806,22.870968,39.741935,8.709677,8.709677,5.129032,19742.032258,19628.967742,19364.16129,17103.677419,17984.16129,13810.870968
std,0.319932,0.325454,0.334434,0.309956,50.699579,40.858897,63.172762,9.103822,9.103822,4.951376,25787.326831,25712.940527,25646.224809,25564.191374,26911.703415,22937.67913
min,0.002576,0.005687,0.00204,0.0,1.0,2.0,1.0,1.0,1.0,0.0,1821.0,1813.0,1750.0,44.0,44.0,0.0
25%,0.218913,0.248949,0.249116,0.133874,8.5,7.0,10.0,2.0,2.0,1.5,3700.5,3664.0,3465.5,802.0,802.0,410.0
50%,0.632381,0.661653,0.712883,0.479532,13.0,11.0,19.0,6.0,6.0,4.0,6957.0,6894.0,5875.0,3847.0,4138.0,1944.0
75%,0.794857,0.850547,0.87444,0.692742,22.0,19.5,30.0,13.5,13.5,7.0,32451.0,32274.5,32699.5,26305.0,27036.5,23520.5
max,0.932704,0.950421,0.953007,0.894683,266.0,218.0,307.0,43.0,43.0,21.0,99301.0,98700.0,97619.0,102770.0,105729.0,92221.0


## Ensemble 5 - Majority Voting

In [10]:
def majority_voting(predictions):
    transforms =[AsDiscrete(threshold=0.3), AsDiscrete(threshold=0.5), AsDiscrete(threshold=0.3), AsDiscrete(threshold=0.4)]
    predictions = [transform(prediction) for transform, prediction in zip(transforms, predictions)]
    stacked_predictions = torch.stack(predictions, dim=0)  
    return torch.mode(stacked_predictions, dim=0).values

#---------------------------------------------------------#

majority_vote_df = ensemble_inference(data_df, majority_voting, threshold)
majority_vote_df.to_csv('./outputs/Ensemble/majority_vote_test_scores.csv', index=False)
majority_vote_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.596554,0.629098,0.63158,0.51832,5.83871,5.612903,7.032258,8.709677,8.709677,5.129032,14509.258065,14605.225806,12013.935484,17103.677419,17984.16129,13810.870968
std,0.282648,0.273798,0.28481,0.32015,5.721869,5.595697,7.968621,9.103822,9.103822,4.951376,22083.541983,22288.719447,19507.19584,25564.191374,26911.703415,22937.67913
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.405486,0.476419,0.476057,0.193824,2.0,2.0,2.0,2.0,2.0,1.5,740.0,731.5,465.5,802.0,802.0,410.0
50%,0.654706,0.685359,0.715706,0.594623,4.0,4.0,6.0,6.0,6.0,4.0,3544.0,3557.0,2569.0,3847.0,4138.0,1944.0
75%,0.832263,0.866253,0.873796,0.795604,8.0,7.5,9.0,13.5,13.5,7.0,24550.0,24748.5,21871.5,26305.0,27036.5,23520.5
max,0.951277,0.957172,0.959179,0.937479,22.0,23.0,37.0,43.0,43.0,21.0,82028.0,82219.0,73217.0,102770.0,105729.0,92221.0


## Ensemble 6 - 

In [11]:
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral, create_pairwise_gaussian

def apply_crf(image, predictions):

    iterations=10

    combined_probs = torch.stack(predictions, dim=0)

    num_models, _, num_classes, x, y, z = combined_probs.shape

    # Convert image and combined_probs to numpy
    image_np = image.cpu().numpy().astype(np.float32)
    combined_probs_np = combined_probs.mean(dim=0).cpu().numpy()  # Shape: (num_classes, x, y, z)

    # Initialize CRF model
    d = dcrf.DenseCRF2D(x, y * z, num_classes)  # Width, Height*Depth, NumClasses
    
    # Reshape combined_probs for unary energy
    combined_probs_np = combined_probs_np.reshape((num_classes, -1))  # Shape: (num_classes, x*y*z)
    
    # Unary potentials
    unary = unary_from_softmax(combined_probs_np)
    d.setUnaryEnergy(unary)
    
    # Pairwise potentials
    feats = create_pairwise_gaussian(sdims=(10, 10), shape=(x, y, z))
    d.addPairwiseEnergy(feats, compat=3)
    
    feats = create_pairwise_bilateral(sdims=(10, 10), schan=0.01, img=image_np, chdim=0)
    d.addPairwiseEnergy(feats, compat=10)
    
    # Run inference
    Q = d.inference(iterations)
    result = np.argmax(Q, axis=0).reshape((x, y, z))
    
    return torch.tensor(result)


#---------------------------------------------------------#

best_mean = 0
best_threshold = 0
df = None
for threshold in [0.5]:
    crf_df = ensemble_inference(data_df, apply_crf, threshold, apply_crf=True)
    mean = crf_df['Dice'].mean()
    print(f'CRF Mean Dice for threshold {threshold}: {mean}')
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = crf_df
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('./outputs/Ensemble/crf_test_scores.csv', index=False)
df.describe()

ValueError: Bad shape for pairwise energy (Need (?, 7550400), got (2, 11325600))

In [None]:
from sklearn.ensemble import RandomForestClassifier

def random_forest_ensemble(predictions, rf_model):
    """
    Uses a Random Forest to combine predictions voxel-wise.

    Args:
    predictions (list of torch.Tensor): List of prediction tensors of shape (3, x, y, z).
    rf_model: A trained Random Forest model.

    Returns:
    torch.Tensor: Ensemble prediction of shape (3, x, y, z).
    """
    pred_shape = predictions[0].shape
    flat_predictions = torch.cat([pred.view(3, -1).T for pred in predictions], dim=1).cpu().numpy()
    rf_prediction = rf_model.predict(flat_predictions)
    return torch.tensor(rf_prediction).view(pred_shape[1:])
