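# Test Predictions Generator

Scores the saved AHNet, UNet, SegResNet and UNETR segmentations against the ground truth. For each model, the threshold (swept from 0.1 to 0.9) with the highest mean validation Dice is selected, and that threshold is then used to score the test split.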

## Environment Setup

### Imports

In [1]:
# Data / MONAI
import pandas as pd

# Utils
from Inference import calculate_metrics

### Config

In [2]:
# Seed constant (not consumed by any cell in this notebook yet).
seed = 33

# Show every column and untruncated cell text, so wide score tables and
# long .npz path lists render in full.
for option_name in ('display.max_columns', 'display.max_colwidth'):
    pd.set_option(option_name, None)

## Load Val Data

In [3]:
# Build the VAL lookup table: one row per subject and, for every model (plus the
# ground truth), a list of three .npz paths ordered TC, WT, ET. Only paths are
# assembled here; nothing is loaded or checked for existence.
subject_ids = pd.read_csv('../data/VAL.csv')['SubjectID'].values

# Column name -> path template ({sid} = subject ID, {channel} = TC/WT/ET).
# Dict order fixes the DataFrame column order; adding a model is one line.
val_path_templates = {
    'AHNet': '../outputs/AHNet/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz',
    'UNet': '../outputs/UNet/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz',
    'SegResNet': '../outputs/SegResNet/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz',
    'UNETR': '../outputs/UNETR/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz',
    'GT': '../outputs/gt_segs/val_gt_segs/gt_{sid}_{channel}.npz',
}

# Dataframe
val_df = pd.DataFrame({'SubjectID': subject_ids})
for column, path_template in val_path_templates.items():
    val_df[column] = [
        [path_template.format(sid=sid, channel=channel) for channel in ['TC', 'WT', 'ET']]
        for sid in subject_ids
    ]

val_df.head()

Unnamed: 0,SubjectID,AHNet,UNet,SegResNet,UNETR,GT
0,100237A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100237A_TC.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100237A_WT.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100237A_ET.npz]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100237A_TC.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100237A_WT.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100237A_ET.npz]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100237A_TC.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100237A_WT.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100237A_ET.npz]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100237A_TC.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100237A_WT.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100237A_ET.npz]","[../outputs/gt_segs/val_gt_segs/gt_100237A_TC.npz, ../outputs/gt_segs/val_gt_segs/gt_100237A_WT.npz, ../outputs/gt_segs/val_gt_segs/gt_100237A_ET.npz]"
1,100219A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100219A_TC.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100219A_WT.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100219A_ET.npz]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100219A_TC.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100219A_WT.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100219A_ET.npz]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100219A_TC.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100219A_WT.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100219A_ET.npz]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100219A_TC.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100219A_WT.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100219A_ET.npz]","[../outputs/gt_segs/val_gt_segs/gt_100219A_TC.npz, ../outputs/gt_segs/val_gt_segs/gt_100219A_WT.npz, ../outputs/gt_segs/val_gt_segs/gt_100219A_ET.npz]"
2,100363A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100363A_TC.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100363A_WT.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100363A_ET.npz]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100363A_TC.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100363A_WT.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100363A_ET.npz]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100363A_TC.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100363A_WT.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100363A_ET.npz]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100363A_TC.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100363A_WT.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100363A_ET.npz]","[../outputs/gt_segs/val_gt_segs/gt_100363A_TC.npz, ../outputs/gt_segs/val_gt_segs/gt_100363A_WT.npz, ../outputs/gt_segs/val_gt_segs/gt_100363A_ET.npz]"
3,100354A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100354A_TC.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100354A_WT.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100354A_ET.npz]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100354A_TC.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100354A_WT.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100354A_ET.npz]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100354A_TC.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100354A_WT.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100354A_ET.npz]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100354A_TC.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100354A_WT.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100354A_ET.npz]","[../outputs/gt_segs/val_gt_segs/gt_100354A_TC.npz, ../outputs/gt_segs/val_gt_segs/gt_100354A_WT.npz, ../outputs/gt_segs/val_gt_segs/gt_100354A_ET.npz]"
4,100303A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100303A_TC.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100303A_WT.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100303A_ET.npz]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100303A_TC.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100303A_WT.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100303A_ET.npz]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100303A_TC.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100303A_WT.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100303A_ET.npz]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100303A_TC.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100303A_WT.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100303A_ET.npz]","[../outputs/gt_segs/val_gt_segs/gt_100303A_TC.npz, ../outputs/gt_segs/val_gt_segs/gt_100303A_WT.npz, ../outputs/gt_segs/val_gt_segs/gt_100303A_ET.npz]"


## Load Test Data

In [4]:
# Build the TEST lookup table: one row per subject and, for every model (plus the
# ground truth), a list of three .npz paths ordered TC, WT, ET. Only paths are
# assembled here; nothing is loaded or checked for existence.
subject_ids = pd.read_csv('../data/TEST.csv')['SubjectID'].values

# Column name -> path template ({sid} = subject ID, {channel} = TC/WT/ET).
# Dict order fixes the DataFrame column order; adding a model is one line.
test_path_templates = {
    'AHNet': '../outputs/AHNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz',
    'UNet': '../outputs/UNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz',
    'SegResNet': '../outputs/SegResNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz',
    'UNETR': '../outputs/UNETR/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz',
    'GT': '../outputs/gt_segs/test_gt_segs/gt_{sid}_{channel}.npz',
}

# Dataframe
test_df = pd.DataFrame({'SubjectID': subject_ids})
for column, path_template in test_path_templates.items():
    test_df[column] = [
        [path_template.format(sid=sid, channel=channel) for channel in ['TC', 'WT', 'ET']]
        for sid in subject_ids
    ]

test_df.head()

Unnamed: 0,SubjectID,AHNet,UNet,SegResNet,UNETR,GT
0,100214B,"[../outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[../outputs/UNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ../outputs/UNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ../outputs/UNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[../outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[../outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_TC.npz, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_WT.npz, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[../outputs/gt_segs/test_gt_segs/gt_100214B_TC.npz, ../outputs/gt_segs/test_gt_segs/gt_100214B_WT.npz, ../outputs/gt_segs/test_gt_segs/gt_100214B_ET.npz]"
1,100340A,"[../outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[../outputs/UNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ../outputs/UNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ../outputs/UNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[../outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[../outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_TC.npz, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_WT.npz, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[../outputs/gt_segs/test_gt_segs/gt_100340A_TC.npz, ../outputs/gt_segs/test_gt_segs/gt_100340A_WT.npz, ../outputs/gt_segs/test_gt_segs/gt_100340A_ET.npz]"
2,100391A,"[../outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[../outputs/UNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ../outputs/UNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ../outputs/UNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[../outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[../outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_TC.npz, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_WT.npz, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[../outputs/gt_segs/test_gt_segs/gt_100391A_TC.npz, ../outputs/gt_segs/test_gt_segs/gt_100391A_WT.npz, ../outputs/gt_segs/test_gt_segs/gt_100391A_ET.npz]"
3,100190B,"[../outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[../outputs/UNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ../outputs/UNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ../outputs/UNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[../outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[../outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_TC.npz, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_WT.npz, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[../outputs/gt_segs/test_gt_segs/gt_100190B_TC.npz, ../outputs/gt_segs/test_gt_segs/gt_100190B_WT.npz, ../outputs/gt_segs/test_gt_segs/gt_100190B_ET.npz]"
4,100142A,"[../outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ../outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[../outputs/UNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ../outputs/UNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ../outputs/UNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[../outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ../outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[../outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_TC.npz, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_WT.npz, ../outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[../outputs/gt_segs/test_gt_segs/gt_100142A_TC.npz, ../outputs/gt_segs/test_gt_segs/gt_100142A_WT.npz, ../outputs/gt_segs/test_gt_segs/gt_100142A_ET.npz]"


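## Generate Results CSVs

For each model, thresholds 0.1–0.9 are swept on the validation split. The threshold with the highest mean Dice is kept, and its validation scores are saved to `../outputs/<model>/val_scores.csv`. The test split is then scored at that threshold and saved to `../outputs/<model>/test_scores.csv`.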

### AHNet

In [5]:
# Sweep the threshold for AHNet on the VAL split and keep the one with the highest
# mean Dice. `best_threshold` is read by the AHNet TEST cell below and is
# reassigned by every model's sweep cell, so run this cell right before that one.
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
best_mean = float('-inf')  # -inf rather than 0 so a mean Dice of exactly 0 can still be selected
best_threshold = None
df = None
for threshold in thresholds:
    # VAL AHNet
    ahnet_df = calculate_metrics('AHNet', val_df, threshold)
    mean = ahnet_df['Dice'].mean()
    print(f'AHNet VAL Mean Dice for threshold {threshold}: {mean}')
    # Strict '>' keeps the lowest threshold on ties; a NaN mean never wins.
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = ahnet_df
if df is None:
    raise RuntimeError('AHNet: no threshold gave a usable (non-NaN) mean Dice on VAL')
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('../outputs/AHNet/val_scores.csv', index=False)
df.describe()

AHNet VAL Mean Dice for threshold 0.1: 0.5492611799028612
AHNet VAL Mean Dice for threshold 0.2: 0.5604259324890952
AHNet VAL Mean Dice for threshold 0.3: 0.5605223681627503
AHNet VAL Mean Dice for threshold 0.4: 0.5556399243162765
AHNet VAL Mean Dice for threshold 0.5: 0.5462690406899539
AHNet VAL Mean Dice for threshold 0.6: 0.5356928398532252
AHNet VAL Mean Dice for threshold 0.7: 0.5207302911627677
AHNet VAL Mean Dice for threshold 0.8: 0.49963182919929106
AHNet VAL Mean Dice for threshold 0.9: 0.46508253750301176
BEST THRESHOLD: 0.3


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.560522,0.595008,0.599129,0.481108,12.774194,10.258065,109.290323,7.741935,7.741935,3.032258,21333.903226,21511.290323,19288.032258,19105.225806,19355.193548,15780.967742
std,0.291371,0.294258,0.299808,0.303951,12.41695,10.318164,241.63322,10.708152,10.708152,2.4696,41954.305231,42270.831368,38737.120564,37451.165077,37730.84962,32728.506884
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,69.0,69.0,0.0
25%,0.333367,0.361916,0.361883,0.273857,5.0,4.0,6.0,2.5,2.5,1.5,597.5,599.0,416.5,606.5,606.5,271.0
50%,0.639329,0.70902,0.712976,0.477851,8.0,7.0,19.0,5.0,5.0,2.0,1542.0,1532.0,1161.0,2111.0,2149.0,949.0
75%,0.837559,0.855369,0.865494,0.783561,18.0,14.0,61.0,8.0,8.0,4.0,21160.5,21425.0,18357.5,20289.0,20793.5,16180.5
max,0.928312,0.933032,0.939695,0.912208,57.0,42.0,1095.0,55.0,55.0,9.0,171800.0,173245.0,158599.0,153858.0,153901.0,133325.0


In [6]:
# TEST AHNet: score the held-out split at the threshold chosen on VAL.
# `best_threshold` is shared by all four sweep cells, so this must run directly
# after the AHNet sweep. It is echoed so the saved output records what was used.
print(f'AHNet TEST threshold: {best_threshold}')
ahnet_df = calculate_metrics('AHNet', test_df, best_threshold)
ahnet_df.to_csv('../outputs/AHNet/test_scores.csv', index=False)
ahnet_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.583364,0.617662,0.626915,0.501632,10.580645,7.870968,69.064516,8.709677,8.709677,5.129032,18751.741935,18937.580645,17134.516129,17105.387097,17985.870968,13810.806452
std,0.273949,0.272766,0.278902,0.288378,8.135413,6.652027,130.761089,9.103822,9.103822,4.951376,28004.605846,28278.026189,26452.299711,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,5.0,0.0,0.0,44.0,44.0,0.0
25%,0.41718,0.48934,0.488268,0.2725,5.5,3.0,7.0,2.0,2.0,1.5,911.5,914.5,815.5,802.0,802.0,410.0
50%,0.648561,0.659333,0.694444,0.561328,7.0,7.0,16.0,6.0,6.0,4.0,4825.0,4866.0,3967.0,3847.0,4138.0,1944.0
75%,0.821729,0.853907,0.872481,0.732228,15.0,10.0,60.5,13.5,13.5,7.0,33080.5,33513.5,32337.5,26305.0,27036.5,23520.5
max,0.920359,0.934362,0.933894,0.89282,34.0,26.0,588.0,43.0,43.0,21.0,107917.0,108916.0,101786.0,102770.0,105729.0,92221.0


### SegResNet

In [7]:
# Sweep the threshold for SegResNet on the VAL split and keep the one with the
# highest mean Dice. `best_threshold` is read by the SegResNet TEST cell below and
# is reassigned by every model's sweep cell, so run this cell right before that one.
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
best_mean = float('-inf')  # -inf rather than 0 so a mean Dice of exactly 0 can still be selected
best_threshold = None
df = None
for threshold in thresholds:
    # VAL SegResNet
    segresnet_df = calculate_metrics('SegResNet', val_df, threshold)
    mean = segresnet_df['Dice'].mean()
    print(f'SegResNet VAL Mean Dice for threshold {threshold}: {mean}')
    # Strict '>' keeps the lowest threshold on ties; a NaN mean never wins.
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = segresnet_df
if df is None:
    raise RuntimeError('SegResNet: no threshold gave a usable (non-NaN) mean Dice on VAL')
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('../outputs/SegResNet/val_scores.csv', index=False)
df.describe()

SegResNet VAL Mean Dice for threshold 0.1: 0.5970835380196091
SegResNet VAL Mean Dice for threshold 0.2: 0.6291356047074641
SegResNet VAL Mean Dice for threshold 0.3: 0.642114634505443
SegResNet VAL Mean Dice for threshold 0.4: 0.646779386567012
SegResNet VAL Mean Dice for threshold 0.5: 0.6486533776887001
SegResNet VAL Mean Dice for threshold 0.6: 0.6453569404059841
SegResNet VAL Mean Dice for threshold 0.7: 0.635466123660726
SegResNet VAL Mean Dice for threshold 0.8: 0.6187912326065763
SegResNet VAL Mean Dice for threshold 0.9: 0.5839838963903247
BEST THRESHOLD: 0.5


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.648653,0.682256,0.689264,0.554394,7.129032,6.83871,4.129032,7.741935,7.741935,3.032258,18359.354839,18367.225806,15318.516129,19105.225806,19355.193548,15780.967742
std,0.237255,0.229664,0.233772,0.296713,5.439007,4.89283,3.344069,10.708152,10.708152,2.4696,34927.167462,35135.087426,30755.426912,37451.165077,37730.84962,32728.506884
min,0.017208,0.017117,0.017299,0.0,1.0,1.0,1.0,1.0,1.0,0.0,135.0,126.0,13.0,69.0,69.0,0.0
25%,0.474633,0.535691,0.5253,0.373331,4.0,4.0,2.0,2.5,2.5,1.5,659.0,682.5,243.5,606.5,606.5,271.0
50%,0.734192,0.772202,0.783077,0.57103,6.0,5.0,3.0,5.0,5.0,2.0,1190.0,1167.0,759.0,2111.0,2149.0,949.0
75%,0.835967,0.848231,0.875156,0.817324,9.0,8.5,5.5,8.0,8.0,4.0,16917.0,16844.0,13720.0,20289.0,20793.5,16180.5
max,0.949027,0.946284,0.955822,0.944976,29.0,25.0,17.0,55.0,55.0,9.0,135459.0,136617.0,121015.0,153858.0,153901.0,133325.0


In [8]:
# TEST SegResNet: score the held-out split at the threshold chosen on VAL.
# `best_threshold` is shared by all four sweep cells, so this must run directly
# after the SegResNet sweep. It is echoed so the saved output records what was used.
print(f'SegResNet TEST threshold: {best_threshold}')
segresnet_df = calculate_metrics('SegResNet', test_df, best_threshold)
segresnet_df.to_csv('../outputs/SegResNet/test_scores.csv', index=False)
segresnet_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.646272,0.679585,0.68475,0.561635,8.451613,8.354839,5.451613,8.709677,8.709677,5.129032,14533.483871,14518.193548,11519.967742,17105.387097,17985.870968,13810.806452
std,0.237641,0.217268,0.228583,0.294801,6.850006,6.848593,5.271551,9.103822,9.103822,4.951376,21299.762596,21376.291349,18354.865167,25565.209783,26912.616906,22937.645993
min,0.08422,0.151394,0.101266,0.0,1.0,1.0,0.0,1.0,1.0,0.0,71.0,63.0,0.0,44.0,44.0,0.0
25%,0.484024,0.572342,0.567735,0.324266,4.5,3.5,1.0,2.0,2.0,1.5,1100.0,1081.0,636.0,802.0,802.0,410.0
50%,0.723051,0.763834,0.768185,0.638653,6.0,6.0,4.0,6.0,6.0,4.0,4401.0,4441.0,2781.0,3847.0,4138.0,1944.0
75%,0.825707,0.855405,0.86577,0.79676,10.5,10.5,8.0,13.5,13.5,7.0,21443.5,21289.5,18070.5,26305.0,27036.5,23520.5
max,0.945202,0.950976,0.95245,0.935746,26.0,27.0,26.0,43.0,43.0,21.0,78216.0,79071.0,68497.0,102770.0,105729.0,92221.0


### UNet

In [9]:
# Sweep the threshold for UNet on the VAL split and keep the one with the highest
# mean Dice. `best_threshold` is read by the UNet TEST cell below and is
# reassigned by every model's sweep cell, so run this cell right before that one.
# NOTE(review): in the saved outputs, UNet mean Dice is nearly flat across all
# thresholds (~0.4277-0.4283), and predicted TC/WT/ET volumes are almost identical.
# The saved UNet predictions may already be binarized, or may share one channel.
# Verify before trusting this threshold choice.
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
best_mean = float('-inf')  # -inf rather than 0 so a mean Dice of exactly 0 can still be selected
best_threshold = None
df = None
for threshold in thresholds:
    # VAL UNet
    unet_df = calculate_metrics('UNet', val_df, threshold)
    mean = unet_df['Dice'].mean()
    print(f'UNet VAL Mean Dice for threshold {threshold}: {mean}')
    # Strict '>' keeps the lowest threshold on ties; a NaN mean never wins.
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = unet_df
if df is None:
    raise RuntimeError('UNet: no threshold gave a usable (non-NaN) mean Dice on VAL')
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('../outputs/UNet/val_scores.csv', index=False)
df.describe()

UNet VAL Mean Dice for threshold 0.1: 0.4282878165118276
UNet VAL Mean Dice for threshold 0.2: 0.42832267277164293
UNet VAL Mean Dice for threshold 0.3: 0.42831103770708245
UNet VAL Mean Dice for threshold 0.4: 0.4282265670537468
UNet VAL Mean Dice for threshold 0.5: 0.42820052421771954
UNet VAL Mean Dice for threshold 0.6: 0.42809437097410763
UNet VAL Mean Dice for threshold 0.7: 0.42796112505734085
UNet VAL Mean Dice for threshold 0.8: 0.4279389574583019
UNet VAL Mean Dice for threshold 0.9: 0.4277176525642825
BEST THRESHOLD: 0.2


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.428323,0.453467,0.461658,0.369501,11.419355,10.83871,11.451613,7.741935,7.741935,3.032258,21324.774194,21312.677419,21327.548387,19105.225806,19355.193548,15780.967742
std,0.344629,0.358677,0.366625,0.322365,34.860459,31.846608,34.501052,10.708152,10.708152,2.4696,34065.143266,34056.285313,34061.761377,37451.165077,37730.84962,32728.506884
min,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,1315.0,1321.0,1318.0,69.0,69.0,0.0
25%,0.066284,0.078828,0.077547,0.042478,3.0,3.0,3.0,2.5,2.5,1.5,3118.5,3116.0,3121.5,606.5,606.5,271.0
50%,0.544761,0.582043,0.581235,0.341828,4.0,4.0,4.0,5.0,5.0,2.0,7019.0,7039.0,7020.0,2111.0,2149.0,949.0
75%,0.729432,0.769625,0.832015,0.661509,7.0,6.5,7.0,8.0,8.0,4.0,19793.0,19788.5,19789.0,20289.0,20793.5,16180.5
max,0.913806,0.928784,0.938529,0.874104,198.0,181.0,196.0,55.0,55.0,9.0,139620.0,139598.0,139578.0,153858.0,153901.0,133325.0


In [10]:
# TEST UNet: score the held-out split at the threshold chosen on VAL.
# `best_threshold` is shared by all four sweep cells, so this must run directly
# after the UNet sweep. It is echoed so the saved output records what was used.
print(f'UNet TEST threshold: {best_threshold}')
unet_df = calculate_metrics('UNet', test_df, best_threshold)
unet_df.to_csv('../outputs/UNet/test_scores.csv', index=False)
unet_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.484883,0.517132,0.529417,0.4081,21.548387,19.16129,23.193548,8.709677,8.709677,5.129032,16574.387097,16540.774194,16578.83871,17105.387097,17985.870968,13810.806452
std,0.328195,0.339524,0.34653,0.307589,62.4232,53.717841,69.292818,9.103822,9.103822,4.951376,20918.777034,20937.924118,20923.296144,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,1785.0,1787.0,1791.0,44.0,44.0,0.0
25%,0.116562,0.141904,0.141575,0.08697,3.0,2.5,3.0,2.0,2.0,1.5,3656.5,3662.5,3661.5,802.0,802.0,410.0
50%,0.582786,0.647238,0.647275,0.447944,5.0,5.0,5.0,6.0,6.0,4.0,5500.0,5513.0,5516.0,3847.0,4138.0,1944.0
75%,0.767916,0.834543,0.836938,0.671957,8.5,8.0,9.0,13.5,13.5,7.0,26130.0,26137.5,26131.5,26305.0,27036.5,23520.5
max,0.93002,0.946876,0.94914,0.898466,274.0,237.0,328.0,43.0,43.0,21.0,75747.0,75778.0,75750.0,102770.0,105729.0,92221.0


### UNETR

In [11]:
# Sweep the threshold for UNETR on the VAL split and keep the one with the highest
# mean Dice. `best_threshold` is read by the UNETR TEST cell below and is
# reassigned by every model's sweep cell, so run this cell right before that one.
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
best_mean = float('-inf')  # -inf rather than 0 so a mean Dice of exactly 0 can still be selected
best_threshold = None
df = None
for threshold in thresholds:
    # VAL UNETR
    untr_df = calculate_metrics('UNETR', val_df, threshold)
    mean = untr_df['Dice'].mean()
    print(f'UNETR VAL Mean Dice for threshold {threshold}: {mean}')
    # Strict '>' keeps the lowest threshold on ties; a NaN mean never wins.
    if mean > best_mean:
        best_mean = mean
        best_threshold = threshold
        df = untr_df
if df is None:
    raise RuntimeError('UNETR: no threshold gave a usable (non-NaN) mean Dice on VAL')
print(f'BEST THRESHOLD: {best_threshold}')
df.to_csv('../outputs/UNETR/val_scores.csv', index=False)
df.describe()

UNETR VAL Mean Dice for threshold 0.1: 0.5113578857553582
UNETR VAL Mean Dice for threshold 0.2: 0.5365295441701047
UNETR VAL Mean Dice for threshold 0.3: 0.5473637978936876
UNETR VAL Mean Dice for threshold 0.4: 0.5504476521765033
UNETR VAL Mean Dice for threshold 0.5: 0.547998258844018
UNETR VAL Mean Dice for threshold 0.6: 0.5435207280660829
UNETR VAL Mean Dice for threshold 0.7: 0.5328597827183623
UNETR VAL Mean Dice for threshold 0.8: 0.5149041368295589
UNETR VAL Mean Dice for threshold 0.9: 0.4807232413440943
BEST THRESHOLD: 0.4


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.550448,0.573631,0.581319,0.484506,16.677419,16.967742,10.419355,7.741935,7.741935,3.032258,19717.354839,19936.096774,16438.709677,19105.225806,19355.193548,15780.967742
std,0.300557,0.296484,0.300111,0.335131,20.989501,21.138407,10.626301,10.708152,10.708152,2.4696,36504.672766,37090.636853,32046.198094,37451.165077,37730.84962,32728.506884
min,0.012344,0.012241,0.012447,0.0,2.0,2.0,1.0,1.0,1.0,0.0,84.0,96.0,1.0,69.0,69.0,0.0
25%,0.308373,0.316527,0.348451,0.171904,7.5,8.0,4.5,2.5,2.5,1.5,744.0,762.0,246.0,606.5,606.5,271.0
50%,0.579267,0.662887,0.680541,0.486154,11.0,11.0,7.0,5.0,5.0,2.0,2739.0,2738.0,1661.0,2111.0,2149.0,949.0
75%,0.796567,0.813837,0.843381,0.775039,15.0,16.5,12.0,8.0,8.0,4.0,18423.0,18468.0,13786.0,20289.0,20793.5,16180.5
max,0.937153,0.932987,0.953532,0.92494,116.0,117.0,46.0,55.0,55.0,9.0,147550.0,148563.0,129968.0,153858.0,153901.0,133325.0


In [12]:
# TEST UNETR: score the held-out split at the threshold chosen on VAL.
# `best_threshold` is shared by all four sweep cells, so this must run directly
# after the UNETR sweep. It is echoed so the saved output records what was used.
print(f'UNETR TEST threshold: {best_threshold}')
untr_df = calculate_metrics('UNETR', test_df, best_threshold)
untr_df.to_csv('../outputs/UNETR/test_scores.csv', index=False)
untr_df.describe()

Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.545455,0.579272,0.586039,0.466769,16.225806,15.806452,10.258065,8.709677,8.709677,5.129032,15686.032258,15907.258065,13347.096774,17105.387097,17985.870968,13810.806452
std,0.297384,0.287008,0.295619,0.334207,12.082245,12.626478,11.245644,9.103822,9.103822,4.951376,24049.223531,24448.639346,21750.937461,25565.209783,26912.616906,22937.645993
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.338497,0.428638,0.408833,0.1793,8.0,7.5,3.0,2.0,2.0,1.5,764.5,753.5,375.5,802.0,802.0,410.0
50%,0.558096,0.599966,0.589826,0.472939,11.0,10.0,6.0,6.0,6.0,4.0,3234.0,3160.0,2309.0,3847.0,4138.0,1944.0
75%,0.834007,0.861489,0.875078,0.778451,25.0,24.5,13.0,13.5,13.5,7.0,29258.0,29396.0,27112.5,26305.0,27036.5,23520.5
max,0.947723,0.952594,0.955062,0.935512,51.0,53.0,52.0,43.0,43.0,21.0,89827.0,89912.0,80350.0,102770.0,105729.0,92221.0
