# Test Predictions Generator

## Environment Setup

### Imports

In [1]:
# Data
import pandas as pd

# Utils
from Inference import calculate_metrics

### Config

In [2]:
# Seed value; not referenced by any other cell in this notebook.
seed = 33

# Show every column and untruncated cell contents so the per-channel
# path lists in the frames below are fully readable.
for display_option in ("display.max_columns", "display.max_colwidth"):
    pd.set_option(display_option, None)

## Load Val Data

In [3]:
# Validation subjects, in the order they appear in VAL.csv.
subject_ids = pd.read_csv('../data/VAL.csv')['SubjectID'].values

# Every model/GT cell holds one subject's .pt paths, one per channel,
# in the fixed order TC, WT, ET.
seg_channels = ('TC', 'WT', 'ET')

val_df = pd.DataFrame({'SubjectID': subject_ids})
for model_name in ('AHNet', 'UNet', 'SegResNet', 'UNETR'):
    val_df[model_name] = [
        [f'../outputs/{model_name}/pred_segs/val_pred_segs/pred_{sid}_{channel}.pt' for channel in seg_channels]
        for sid in subject_ids
    ]
val_df['GT'] = [
    [f'../outputs/gt_segs/val_gt_segs/gt_{sid}_{channel}.pt' for channel in seg_channels]
    for sid in subject_ids
]

val_df.head()

Unnamed: 0,SubjectID,AHNet,UNet,SegResNet,UNETR,GT
0,100237A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100237A_TC.pt, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100237A_WT.pt, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100237A_ET.pt]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100237A_TC.pt, ../outputs/UNet/pred_segs/val_pred_segs/pred_100237A_WT.pt, ../outputs/UNet/pred_segs/val_pred_segs/pred_100237A_ET.pt]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100237A_TC.pt, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100237A_WT.pt, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100237A_ET.pt]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100237A_TC.pt, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100237A_WT.pt, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100237A_ET.pt]","[../outputs/gt_segs/val_gt_segs/gt_100237A_TC.pt, ../outputs/gt_segs/val_gt_segs/gt_100237A_WT.pt, ../outputs/gt_segs/val_gt_segs/gt_100237A_ET.pt]"
1,100219A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100219A_TC.pt, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100219A_WT.pt, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100219A_ET.pt]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100219A_TC.pt, ../outputs/UNet/pred_segs/val_pred_segs/pred_100219A_WT.pt, ../outputs/UNet/pred_segs/val_pred_segs/pred_100219A_ET.pt]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100219A_TC.pt, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100219A_WT.pt, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100219A_ET.pt]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100219A_TC.pt, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100219A_WT.pt, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100219A_ET.pt]","[../outputs/gt_segs/val_gt_segs/gt_100219A_TC.pt, ../outputs/gt_segs/val_gt_segs/gt_100219A_WT.pt, ../outputs/gt_segs/val_gt_segs/gt_100219A_ET.pt]"
2,100363A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100363A_TC.pt, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100363A_WT.pt, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100363A_ET.pt]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100363A_TC.pt, ../outputs/UNet/pred_segs/val_pred_segs/pred_100363A_WT.pt, ../outputs/UNet/pred_segs/val_pred_segs/pred_100363A_ET.pt]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100363A_TC.pt, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100363A_WT.pt, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100363A_ET.pt]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100363A_TC.pt, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100363A_WT.pt, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100363A_ET.pt]","[../outputs/gt_segs/val_gt_segs/gt_100363A_TC.pt, ../outputs/gt_segs/val_gt_segs/gt_100363A_WT.pt, ../outputs/gt_segs/val_gt_segs/gt_100363A_ET.pt]"
3,100354A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100354A_TC.pt, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100354A_WT.pt, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100354A_ET.pt]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100354A_TC.pt, ../outputs/UNet/pred_segs/val_pred_segs/pred_100354A_WT.pt, ../outputs/UNet/pred_segs/val_pred_segs/pred_100354A_ET.pt]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100354A_TC.pt, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100354A_WT.pt, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100354A_ET.pt]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100354A_TC.pt, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100354A_WT.pt, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100354A_ET.pt]","[../outputs/gt_segs/val_gt_segs/gt_100354A_TC.pt, ../outputs/gt_segs/val_gt_segs/gt_100354A_WT.pt, ../outputs/gt_segs/val_gt_segs/gt_100354A_ET.pt]"
4,100303A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100303A_TC.pt, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100303A_WT.pt, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100303A_ET.pt]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100303A_TC.pt, ../outputs/UNet/pred_segs/val_pred_segs/pred_100303A_WT.pt, ../outputs/UNet/pred_segs/val_pred_segs/pred_100303A_ET.pt]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100303A_TC.pt, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100303A_WT.pt, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100303A_ET.pt]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100303A_TC.pt, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100303A_WT.pt, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100303A_ET.pt]","[../outputs/gt_segs/val_gt_segs/gt_100303A_TC.pt, ../outputs/gt_segs/val_gt_segs/gt_100303A_WT.pt, ../outputs/gt_segs/val_gt_segs/gt_100303A_ET.pt]"


## Load Test Data

In [4]:
# Test subjects, in the order they appear in TEST.csv.
subject_ids = pd.read_csv('../data/TEST.csv')['SubjectID'].values

# Every model/GT cell holds one subject's .pt paths, one per channel,
# in the fixed order TC, WT, ET.
seg_channels = ('TC', 'WT', 'ET')

test_df = pd.DataFrame({'SubjectID': subject_ids})
for model_name in ('AHNet', 'UNet', 'SegResNet', 'UNETR'):
    test_df[model_name] = [
        [f'../outputs/{model_name}/pred_segs/test_pred_segs/pred_{sid}_{channel}.pt' for channel in seg_channels]
        for sid in subject_ids
    ]
test_df['GT'] = [
    [f'../outputs/gt_segs/test_gt_segs/gt_{sid}_{channel}.pt' for channel in seg_channels]
    for sid in subject_ids
]

test_df.head()

Unnamed: 0,SubjectID,AHNet,UNet,SegResNet,UNETR,GT
0,100214B,"[../outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_TC.pt, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_WT.pt, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_ET.pt]","[../outputs/UNet/pred_segs/test_pred_segs/pred_100214B_TC.pt, ../outputs/UNet/pred_segs/test_pred_segs/pred_100214B_WT.pt, ../outputs/UNet/pred_segs/test_pred_segs/pred_100214B_ET.pt]","[../outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_TC.pt, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_WT.pt, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_ET.pt]","[../outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_TC.pt, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_WT.pt, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_ET.pt]","[../outputs/gt_segs/test_gt_segs/gt_100214B_TC.pt, ../outputs/gt_segs/test_gt_segs/gt_100214B_WT.pt, ../outputs/gt_segs/test_gt_segs/gt_100214B_ET.pt]"
1,100340A,"[../outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_TC.pt, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_WT.pt, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_ET.pt]","[../outputs/UNet/pred_segs/test_pred_segs/pred_100340A_TC.pt, ../outputs/UNet/pred_segs/test_pred_segs/pred_100340A_WT.pt, ../outputs/UNet/pred_segs/test_pred_segs/pred_100340A_ET.pt]","[../outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_TC.pt, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_WT.pt, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_ET.pt]","[../outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_TC.pt, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_WT.pt, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_ET.pt]","[../outputs/gt_segs/test_gt_segs/gt_100340A_TC.pt, ../outputs/gt_segs/test_gt_segs/gt_100340A_WT.pt, ../outputs/gt_segs/test_gt_segs/gt_100340A_ET.pt]"
2,100391A,"[../outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_TC.pt, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_WT.pt, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_ET.pt]","[../outputs/UNet/pred_segs/test_pred_segs/pred_100391A_TC.pt, ../outputs/UNet/pred_segs/test_pred_segs/pred_100391A_WT.pt, ../outputs/UNet/pred_segs/test_pred_segs/pred_100391A_ET.pt]","[../outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_TC.pt, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_WT.pt, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_ET.pt]","[../outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_TC.pt, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_WT.pt, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_ET.pt]","[../outputs/gt_segs/test_gt_segs/gt_100391A_TC.pt, ../outputs/gt_segs/test_gt_segs/gt_100391A_WT.pt, ../outputs/gt_segs/test_gt_segs/gt_100391A_ET.pt]"
3,100190B,"[../outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_TC.pt, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_WT.pt, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_ET.pt]","[../outputs/UNet/pred_segs/test_pred_segs/pred_100190B_TC.pt, ../outputs/UNet/pred_segs/test_pred_segs/pred_100190B_WT.pt, ../outputs/UNet/pred_segs/test_pred_segs/pred_100190B_ET.pt]","[../outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_TC.pt, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_WT.pt, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_ET.pt]","[../outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_TC.pt, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_WT.pt, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_ET.pt]","[../outputs/gt_segs/test_gt_segs/gt_100190B_TC.pt, ../outputs/gt_segs/test_gt_segs/gt_100190B_WT.pt, ../outputs/gt_segs/test_gt_segs/gt_100190B_ET.pt]"
4,100142A,"[../outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_TC.pt, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_WT.pt, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_ET.pt]","[../outputs/UNet/pred_segs/test_pred_segs/pred_100142A_TC.pt, ../outputs/UNet/pred_segs/test_pred_segs/pred_100142A_WT.pt, ../outputs/UNet/pred_segs/test_pred_segs/pred_100142A_ET.pt]","[../outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_TC.pt, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_WT.pt, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_ET.pt]","[../outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_TC.pt, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_WT.pt, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_ET.pt]","[../outputs/gt_segs/test_gt_segs/gt_100142A_TC.pt, ../outputs/gt_segs/test_gt_segs/gt_100142A_WT.pt, ../outputs/gt_segs/test_gt_segs/gt_100142A_ET.pt]"


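## Generate Results CSVs

For each model, sweep the threshold passed to `calculate_metrics` over 0.1–0.9 on VAL and keep the value with the highest mean Dice. Save those VAL scores, then score TEST with that same threshold, so the test split is never used for tuning.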

### AHNet

In [5]:
# VAL AHNet: try each threshold, keep the one with the highest mean Dice.
# `best_threshold` is reused by the TEST AHNet cell that follows.
best_mean = float('-inf')  # -inf, not 0: an all-zero-Dice sweep must still select a threshold
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    ahnet_df = calculate_metrics('AHNet', val_df, threshold)
    mean = ahnet_df['Dice'].mean()
    print(f'AHNet VAL Mean Dice for threshold {threshold}: {mean}')
    # Strict `>` keeps the lowest threshold on ties; a NaN mean never wins.
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = ahnet_df
if df is None:
    # Only reachable when every mean Dice was NaN; fail loudly instead of on `None.to_csv`.
    raise ValueError('AHNet: no VAL threshold produced a valid mean Dice')
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('../outputs/AHNet/val_scores.csv', index=False)
df.describe()

AHNet VAL Mean Dice for threshold 0.1: 0.5489151429745459
AHNet VAL Mean Dice for threshold 0.2: 0.5602947120224276
AHNet VAL Mean Dice for threshold 0.3: 0.5604525240168216
AHNet VAL Mean Dice for threshold 0.4: 0.5557049424087088
AHNet VAL Mean Dice for threshold 0.5: 0.5464015600812291
AHNet VAL Mean Dice for threshold 0.6: 0.5358502230336589
AHNet VAL Mean Dice for threshold 0.7: 0.5210102844622827
AHNet VAL Mean Dice for threshold 0.8: 0.500065658722193
AHNet VAL Mean Dice for threshold 0.9: 0.4654711919926828
BEST THRESHOLD: 0.3


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.560453,0.594703,0.598837,0.481495,12.806452,10.354839,109.645161,7.741935,7.741935,3.032258,21333.225806,21510.903226,19285.870968,19099.290323,19349.16129,15783.483871
std,0.291469,0.294154,0.299704,0.304243,12.589465,10.656918,241.870289,10.708152,10.708152,2.4696,41956.106469,42272.915557,38737.951879,37454.248946,37733.912158,32731.406821
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,69.0,69.0,0.0
25%,0.333367,0.361916,0.361883,0.272072,4.5,4.0,6.0,2.5,2.5,1.5,597.5,599.0,416.5,606.5,606.5,271.0
50%,0.637068,0.703095,0.708434,0.482574,8.0,7.0,19.0,5.0,5.0,2.0,1542.0,1532.0,1161.0,2111.0,2149.0,949.0
75%,0.837684,0.855268,0.865622,0.783561,18.0,13.5,61.0,8.0,8.0,4.0,21153.5,21419.5,18357.5,20221.0,20732.5,16174.5
max,0.928312,0.933032,0.939695,0.912208,57.0,44.0,1095.0,55.0,55.0,9.0,171800.0,173245.0,158599.0,153858.0,153901.0,133325.0


In [6]:
# TEST AHNet: score the test split using the `best_threshold` chosen on VAL
# by the preceding cell (run that cell first).
model_name = 'AHNet'
ahnet_df = calculate_metrics(model_name, test_df, best_threshold)
ahnet_df.to_csv(f'../outputs/{model_name}/test_scores.csv', index=False)
ahnet_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.582135,0.616634,0.625311,0.501514,10.935484,8.16129,68.677419,8.709677,8.709677,5.129032,18752.064516,18934.903226,17132.870968,17103.677419,17984.16129,13810.870968
std,0.275483,0.273861,0.28112,0.288308,8.425895,7.057369,128.556184,9.103822,9.103822,4.951376,28000.109033,28274.407608,26444.429711,25564.191374,26911.703415,22937.67913
min,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,5.0,0.0,0.0,44.0,44.0,0.0
25%,0.416782,0.489047,0.486695,0.27316,5.5,3.0,7.0,2.0,2.0,1.5,910.5,914.0,814.5,802.0,802.0,410.0
50%,0.648561,0.659333,0.694444,0.561328,7.0,7.0,16.0,6.0,6.0,4.0,4825.0,4866.0,3967.0,3847.0,4138.0,1944.0
75%,0.821729,0.85397,0.872481,0.731286,15.0,10.5,60.5,13.5,13.5,7.0,33219.5,33641.0,32436.0,26305.0,27036.5,23520.5
max,0.920359,0.934362,0.933894,0.89282,34.0,27.0,571.0,43.0,43.0,21.0,107917.0,108916.0,101786.0,102770.0,105729.0,92221.0


### SegResNet

In [7]:
# VAL SegResNet: try each threshold, keep the one with the highest mean Dice.
# `best_threshold` is reused by the TEST SegResNet cell that follows.
best_mean = float('-inf')  # -inf, not 0: an all-zero-Dice sweep must still select a threshold
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    segresnet_df = calculate_metrics('SegResNet', val_df, threshold)
    mean = segresnet_df['Dice'].mean()
    print(f'SegResNet VAL Mean Dice for threshold {threshold}: {mean}')
    # Strict `>` keeps the lowest threshold on ties; a NaN mean never wins.
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = segresnet_df
if df is None:
    # Only reachable when every mean Dice was NaN; fail loudly instead of on `None.to_csv`.
    raise ValueError('SegResNet: no VAL threshold produced a valid mean Dice')
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('../outputs/SegResNet/val_scores.csv', index=False)
df.describe()

SegResNet VAL Mean Dice for threshold 0.1: 0.596730723706705
SegResNet VAL Mean Dice for threshold 0.2: 0.6288285100411984
SegResNet VAL Mean Dice for threshold 0.3: 0.6418485007278861
SegResNet VAL Mean Dice for threshold 0.4: 0.6467655073370664
SegResNet VAL Mean Dice for threshold 0.5: 0.6487895887224905
SegResNet VAL Mean Dice for threshold 0.6: 0.6455584521735868
SegResNet VAL Mean Dice for threshold 0.7: 0.6360755350080228
SegResNet VAL Mean Dice for threshold 0.8: 0.619358874556999
SegResNet VAL Mean Dice for threshold 0.9: 0.5848339720659198
BEST THRESHOLD: 0.5


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.64879,0.682159,0.689137,0.555028,7.064516,6.806452,4.129032,7.741935,7.741935,3.032258,18358.193548,18363.967742,15317.419355,19099.290323,19349.16129,15783.483871
std,0.237248,0.229526,0.233742,0.296625,5.434062,4.895027,3.344069,10.708152,10.708152,2.4696,34927.020202,35133.636532,30754.945666,37454.248946,37733.912158,32731.406821
min,0.017208,0.017117,0.017299,0.0,1.0,1.0,1.0,1.0,1.0,0.0,135.0,126.0,13.0,69.0,69.0,0.0
25%,0.474633,0.535691,0.5253,0.373331,4.0,4.0,2.0,2.5,2.5,1.5,660.5,684.5,243.5,606.5,606.5,271.0
50%,0.734192,0.770978,0.783077,0.595791,6.0,5.0,3.0,5.0,5.0,2.0,1190.0,1167.0,759.0,2111.0,2149.0,949.0
75%,0.836072,0.848044,0.877968,0.817787,9.0,8.5,5.5,8.0,8.0,4.0,16905.5,16828.5,13712.5,20221.0,20732.5,16174.5
max,0.949027,0.946284,0.955822,0.944976,29.0,25.0,17.0,55.0,55.0,9.0,135459.0,136617.0,121015.0,153858.0,153901.0,133325.0


In [8]:
# TEST SegResNet: score the test split using the `best_threshold` chosen on VAL
# by the preceding cell (run that cell first).
model_name = 'SegResNet'
segresnet_df = calculate_metrics(model_name, test_df, best_threshold)
segresnet_df.to_csv(f'../outputs/{model_name}/test_scores.csv', index=False)
segresnet_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.646463,0.679766,0.684896,0.561792,8.483871,8.322581,5.419355,8.709677,8.709677,5.129032,14547.064516,14531.0,11533.322581,17103.677419,17984.16129,13810.870968
std,0.237764,0.217345,0.228654,0.295048,6.840424,6.842938,5.25848,9.103822,9.103822,4.951376,21314.629841,21391.868382,18371.986988,25564.191374,26911.703415,22937.67913
min,0.08422,0.151394,0.101266,0.0,1.0,1.0,0.0,1.0,1.0,0.0,71.0,63.0,0.0,44.0,44.0,0.0
25%,0.484024,0.571311,0.567735,0.324266,4.5,3.5,1.0,2.0,2.0,1.5,1101.5,1081.0,636.5,802.0,802.0,410.0
50%,0.723051,0.763834,0.768185,0.638653,6.0,6.0,4.0,6.0,6.0,4.0,4401.0,4441.0,2781.0,3847.0,4138.0,1944.0
75%,0.825656,0.855362,0.866187,0.796817,10.5,10.5,8.0,13.5,13.5,7.0,21443.5,21289.5,18183.0,26305.0,27036.5,23520.5
max,0.945202,0.950976,0.95245,0.935746,26.0,27.0,26.0,43.0,43.0,21.0,78216.0,79071.0,68497.0,102770.0,105729.0,92221.0


### UNet

In [11]:
# VAL UNet: try each threshold, keep the one with the highest mean Dice.
# `best_threshold` is reused by the TEST UNet cell that follows.
best_mean = float('-inf')  # -inf, not 0: an all-zero-Dice sweep must still select a threshold
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    # UNet (the original comment here wrongly said UNETR)
    unet_df = calculate_metrics('UNet', val_df, threshold)
    mean = unet_df['Dice'].mean()
    print(f'UNet VAL Mean Dice for threshold {threshold}: {mean}')
    # Strict `>` keeps the lowest threshold on ties; a NaN mean never wins.
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = unet_df
if df is None:
    # Only reachable when every mean Dice was NaN; fail loudly instead of on `None.to_csv`.
    raise ValueError('UNet: no VAL threshold produced a valid mean Dice')
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('../outputs/UNet/val_scores.csv', index=False)
df.describe()

UNet VAL Mean Dice for threshold 0.1: 0.42865688396587726
UNet VAL Mean Dice for threshold 0.2: 0.4287183751500467
UNet VAL Mean Dice for threshold 0.3: 0.42873568030359643
UNet VAL Mean Dice for threshold 0.4: 0.4286642036430778
UNet VAL Mean Dice for threshold 0.5: 0.4286389134959468
UNet VAL Mean Dice for threshold 0.6: 0.4285550280549233
UNet VAL Mean Dice for threshold 0.7: 0.42842960879281766
UNet VAL Mean Dice for threshold 0.8: 0.4284043742584125
UNet VAL Mean Dice for threshold 0.9: 0.4281843339201934
BEST THRESHOLD: 0.3


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.428736,0.453638,0.461929,0.370293,10.193548,9.290323,10.032258,7.741935,7.741935,3.032258,21208.16129,21197.741935,21206.387097,19099.290323,19349.16129,15783.483871
std,0.344426,0.35836,0.366375,0.32202,28.478553,24.039126,28.12174,10.708152,10.708152,2.4696,33949.330818,33933.385756,33939.67651,37454.248946,37733.912158,32731.406821
min,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,1315.0,1321.0,1317.0,69.0,69.0,0.0
25%,0.065312,0.077523,0.076632,0.041782,3.0,3.0,3.0,2.5,2.5,1.5,3095.0,3091.0,3094.0,606.5,606.5,271.0
50%,0.544594,0.581749,0.581052,0.342887,4.0,4.0,4.0,5.0,5.0,2.0,7007.0,7028.0,7012.0,2111.0,2149.0,949.0
75%,0.728406,0.77135,0.833644,0.660151,6.5,6.5,6.5,8.0,8.0,4.0,19724.0,19736.5,19731.0,20221.0,20732.5,16174.5
max,0.912942,0.928051,0.937673,0.873101,162.0,137.0,160.0,55.0,55.0,9.0,138966.0,138895.0,138884.0,153858.0,153901.0,133325.0


In [12]:
# TEST UNet: score the test split using the `best_threshold` chosen on VAL
# by the preceding cell (run that cell first).
model_name = 'UNet'
unet_df = calculate_metrics(model_name, test_df, best_threshold)
unet_df.to_csv(f'../outputs/{model_name}/test_scores.csv', index=False)
unet_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.48441,0.516615,0.528952,0.407664,22.0,19.451613,24.0,8.709677,8.709677,5.129032,16522.580645,16488.419355,16521.225806,17103.677419,17984.16129,13810.870968
std,0.328612,0.33997,0.346949,0.308005,65.407441,56.570804,75.696763,9.103822,9.103822,4.951376,20858.686809,20872.251606,20854.198478,25564.191374,26911.703415,22937.67913
min,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,1783.0,1781.0,1788.0,44.0,44.0,0.0
25%,0.116837,0.142044,0.142032,0.087392,2.5,2.5,2.5,2.0,2.0,1.5,3640.0,3644.0,3642.0,802.0,802.0,410.0
50%,0.583553,0.647756,0.648623,0.44648,5.0,5.0,5.0,6.0,6.0,4.0,5495.0,5505.0,5510.0,3847.0,4138.0,1944.0
75%,0.767797,0.833568,0.835908,0.672236,8.5,8.0,9.0,13.5,13.5,7.0,26132.0,26135.0,26134.0,26305.0,27036.5,23520.5
max,0.930027,0.946852,0.94916,0.898922,314.0,276.0,387.0,43.0,43.0,21.0,75512.0,75540.0,75493.0,102770.0,105729.0,92221.0


### UNETR

In [13]:
# VAL UNETR: try each threshold, keep the one with the highest mean Dice.
# `best_threshold` is reused by the TEST UNETR cell that follows.
best_mean = float('-inf')  # -inf, not 0: an all-zero-Dice sweep must still select a threshold
best_threshold = 0
df = None
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    untr_df = calculate_metrics('UNETR', val_df, threshold)
    mean = untr_df['Dice'].mean()
    print(f'UNETR VAL Mean Dice for threshold {threshold}: {mean}')
    # Strict `>` keeps the lowest threshold on ties; a NaN mean never wins.
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = untr_df
if df is None:
    # Only reachable when every mean Dice was NaN; fail loudly instead of on `None.to_csv`.
    raise ValueError('UNETR: no VAL threshold produced a valid mean Dice')
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('../outputs/UNETR/val_scores.csv', index=False)
df.describe()

UNETR VAL Mean Dice for threshold 0.1: 0.5110577699998694
UNETR VAL Mean Dice for threshold 0.2: 0.5364357509740418
UNETR VAL Mean Dice for threshold 0.3: 0.5473616691907087
UNETR VAL Mean Dice for threshold 0.4: 0.5504147643523831
UNETR VAL Mean Dice for threshold 0.5: 0.5479566721065391
UNETR VAL Mean Dice for threshold 0.6: 0.5435637275297796
UNETR VAL Mean Dice for threshold 0.7: 0.5330707826801846
UNETR VAL Mean Dice for threshold 0.8: 0.5149553074170986
UNETR VAL Mean Dice for threshold 0.9: 0.48047986836923706
BEST THRESHOLD: 0.4


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.550415,0.573424,0.58091,0.485023,16.870968,17.032258,10.387097,7.741935,7.741935,3.032258,19711.935484,19931.483871,16429.193548,19099.290323,19349.16129,15783.483871
std,0.300526,0.296214,0.299794,0.335349,21.654471,21.350229,10.355924,10.708152,10.708152,2.4696,36508.199657,37093.678059,32046.188421,37454.248946,37733.912158,32731.406821
min,0.012344,0.012241,0.012447,0.0,2.0,2.0,1.0,1.0,1.0,0.0,84.0,96.0,1.0,69.0,69.0,0.0
25%,0.308373,0.316527,0.348451,0.171904,7.5,8.0,4.5,2.5,2.5,1.5,744.0,762.0,246.0,606.5,606.5,271.0
50%,0.579267,0.66007,0.673875,0.486154,11.0,11.0,7.0,5.0,5.0,2.0,2739.0,2738.0,1661.0,2111.0,2149.0,949.0
75%,0.800706,0.815115,0.841115,0.775039,15.0,16.5,10.5,8.0,8.0,4.0,18429.5,18473.5,13786.0,20221.0,20732.5,16174.5
max,0.937153,0.932987,0.953532,0.92494,120.0,118.0,45.0,55.0,55.0,9.0,147550.0,148563.0,129968.0,153858.0,153901.0,133325.0


In [14]:
# TEST UNETR: score the test split using the `best_threshold` chosen on VAL
# by the preceding cell (run that cell first).
model_name = 'UNETR'
untr_df = calculate_metrics(model_name, test_df, best_threshold)
untr_df.to_csv(f'../outputs/{model_name}/test_scores.csv', index=False)
untr_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.54553,0.579356,0.586186,0.466583,16.290323,15.806452,10.322581,8.709677,8.709677,5.129032,15709.387097,15931.967742,13370.774194,17103.677419,17984.16129,13810.870968
std,0.297068,0.286649,0.295135,0.33433,12.122138,12.65285,11.232059,9.103822,9.103822,4.951376,24073.481871,24472.203404,21774.736968,25564.191374,26911.703415,22937.67913
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.337892,0.428992,0.408093,0.178731,7.5,7.5,3.0,2.0,2.0,1.5,762.0,750.5,374.5,802.0,802.0,410.0
50%,0.559129,0.599966,0.589826,0.472939,12.0,10.0,6.0,6.0,6.0,4.0,3234.0,3160.0,2309.0,3847.0,4138.0,1944.0
75%,0.834007,0.861489,0.875078,0.778451,25.5,24.5,14.0,13.5,13.5,7.0,29258.0,29396.0,27112.5,26305.0,27036.5,23520.5
max,0.947723,0.952594,0.955062,0.935512,50.0,53.0,52.0,43.0,43.0,21.0,89827.0,89912.0,80350.0,102770.0,105729.0,92221.0
