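# Predictions Generator

Evaluates the saved checkpoint of each trained segmentation model (AHNet, SegResNet, UNet, UNETR) on the test subjects listed in `../data/TEST.csv`, and writes the per-subject scores to `../outputs/<model>/test_scores.csv`.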

## Environment Setup

### Imports

In [1]:
# System
import os
import warnings

# Data / MONAI
import pandas as pd
from monai.data import DataLoader

# Torch
import torch
from torch.utils.data import SequentialSampler

# Utils
from Transforms import Transforms
from UCSF_Dataset import UCSF_Dataset
from Models import SEGRESNET, UNET, AHNET, UNTR
from Inference import test_model

### Config

In [2]:
# Select the compute device: the GPU when PyTorch can see one, otherwise the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")
print("Running on GPU" if use_gpu else "Running on CPU")

# Echo the resolved device so the output records where inference ran
print(f"Device: {device}")

Running on GPU
Device: cuda


In [3]:
# Global seed: handed to Transforms below and also used to seed PyTorch itself,
# so any stochastic op in the inference path is reproducible across runs.
seed = 33
# torch.manual_seed seeds the RNG of every device (CPU and all CUDA devices)
torch.manual_seed(seed)

# Show every column of wide frames such as the test_df metadata table
pd.set_option("display.max_columns", None)

## Load

### Load Models

In [4]:
# Discover the trained checkpoint of each model.
# Every sub-directory of ../outputs becomes a key; if it contains a .pth file,
# entry['model'] holds that checkpoint's path. Sub-directories without a
# checkpoint (e.g. gt_segmentations) map to an empty dict.
outputs_dir = '../outputs'
models = {}
for model_name in sorted(os.listdir(outputs_dir)):
	model_dir = os.path.join(outputs_dir, model_name)
	# Skip stray files (CSVs, logs, ...): listing them as directories would raise
	if not os.path.isdir(model_dir):
		continue
	models[model_name] = {}
	checkpoints = [
		os.path.join(model_dir, fname)
		for fname in os.listdir(model_dir)
		if fname.endswith('.pth')
	]
	if len(checkpoints) > 1:
		# os.listdir order is arbitrary, so make the choice explicit and visible
		warnings.warn(
			f'{len(checkpoints)} checkpoints found in {model_dir}; using the most recently modified one'
		)
	if checkpoints:
		models[model_name]['model'] = max(checkpoints, key=os.path.getmtime)

models

{'AHNet': {'model': '../outputs/AHNet/best_AHNet_92.pth'},
 'gt_segmentations': {},
 'SegResNet': {'model': '../outputs/SegResNet/best_SegResNet_96.pth'},
 'UNet': {'model': '../outputs/UNet/best_UNet_97.pth'},
 'UNETR': {'model': '../outputs/UNETR/best_UNETR_99.pth'}}

### Load Test Data

In [5]:
# Subject metadata plus per-modality image and label paths for the test split
test_df = pd.read_csv('../data/TEST.csv')

# Paths in the CSV are relative to the directory above this notebook's working
# directory, so prefix them with '../'. NOTE(review): the 'Seg' column is
# intentionally left as-is here — confirm nothing downstream reads it as a path.
path_columns = ['BraTS-seg', 'T1post', 'T1pre', 'FLAIR', 'T2Synth']
for column in path_columns:
    test_df[column] = test_df[column].apply(lambda rel_path: f'../{rel_path}')

# Loader parameters: one subject per batch; t_size is forwarded to UCSF_Dataset
b_size = 1
t_size = None

# Seeded transform factory; the validation pipeline is used for testing
transforms = Transforms(seed)

# Input channels in the order the models were trained on, plus the labels
test_images = [
    test_df['T1pre'],
    test_df['FLAIR'],
    test_df['T1post'],
    test_df['T2Synth'],
]
test_labels = test_df['BraTS-seg']
test_dataset = UCSF_Dataset(test_images, test_labels, transforms.val(), t_size)

# Visit subjects in CSV order so per-subject scores line up with test_df rows
test_sampler = SequentialSampler(test_dataset)
test_loader = DataLoader(
    test_dataset,
    batch_size=b_size,
    shuffle=False,
    sampler=test_sampler,
)

test_df

Unnamed: 0,SubjectID,Sex,CancerType,ScannerType,In-plane voxel size (mm),Matrix size,Prior Craniotomy/Biopsy/Resection,Age,Scanner Strength (Tesla),Slice Thickness (mm),NumberMetastases,VolumeMetastases_mm3,S-NM,S-V,S-VMax,S-VMin,S-VMean,S-VStd,S-VDiff,T1pre,FLAIR,T1post,T2Synth,Seg,BraTS-seg
0,100214B,Female,Melanoma,Philips 1.5 T Achieva,0.69x0.69,320x320x104,No,72.0,1.5,1.5,17.0,438.861328,17.0,438.861328,133.998047,4.253906,25.815372,37.943862,0.0,../data/raw/UCSF_BrainMetastases_TRAIN/100214B...,../data/raw/UCSF_BrainMetastases_TRAIN/100214B...,../data/raw/UCSF_BrainMetastases_TRAIN/100214B...,../data/raw/UCSF_BrainMetastases_TRAIN/100214B...,data/raw/UCSF_BrainMetastases_TRAIN/100214B/10...,../data/raw/UCSF_BrainMetastases_TRAIN/100214B...
1,100340A,Female,Lung,GE 1.5 T Signa HDxt,1.17x1.17,256x256x98,No,64.0,1.5,1.5,9.0,1845.781983,9.0,1845.781952,716.888526,20.600245,205.086884,216.778603,3.1e-05,../data/raw/UCSF_BrainMetastases_TRAIN/100340A...,../data/raw/UCSF_BrainMetastases_TRAIN/100340A...,../data/raw/UCSF_BrainMetastases_TRAIN/100340A...,../data/raw/UCSF_BrainMetastases_TRAIN/100340A...,data/raw/UCSF_BrainMetastases_TRAIN/100340A/10...,../data/raw/UCSF_BrainMetastases_TRAIN/100340A...
2,100391A,Female,Lung,Philips 1.5 T Achieva,0.69x0.69,320x320x104,No,68.0,1.5,1.5,1.0,280.757812,1.0,280.757812,280.757812,280.757812,280.757812,0.0,0.0,../data/raw/UCSF_BrainMetastases_TRAIN/100391A...,../data/raw/UCSF_BrainMetastases_TRAIN/100391A...,../data/raw/UCSF_BrainMetastases_TRAIN/100391A...,../data/raw/UCSF_BrainMetastases_TRAIN/100391A...,data/raw/UCSF_BrainMetastases_TRAIN/100391A/10...,../data/raw/UCSF_BrainMetastases_TRAIN/100391A...
3,100190B,Female,Lung,GE 1.5 T Signa HDxt,0.86x0.86,256x256x104,No,42.0,1.5,1.5,3.0,163.962166,2.0,163.962164,132.942295,31.019869,81.981082,50.961213,2e-06,../data/raw/UCSF_BrainMetastases_TRAIN/100190B...,../data/raw/UCSF_BrainMetastases_TRAIN/100190B...,../data/raw/UCSF_BrainMetastases_TRAIN/100190B...,../data/raw/UCSF_BrainMetastases_TRAIN/100190B...,data/raw/UCSF_BrainMetastases_TRAIN/100190B/10...,../data/raw/UCSF_BrainMetastases_TRAIN/100190B...
4,100142A,Female,Neuroendocrine,GE 1.5 T Signa HDxt,0.86x0.86,256x256x106,No,64.0,1.5,1.5,2.0,1197.588524,2.0,1197.588508,1186.509984,11.078525,598.794254,587.715729,1.6e-05,../data/raw/UCSF_BrainMetastases_TRAIN/100142A...,../data/raw/UCSF_BrainMetastases_TRAIN/100142A...,../data/raw/UCSF_BrainMetastases_TRAIN/100142A...,../data/raw/UCSF_BrainMetastases_TRAIN/100142A...,data/raw/UCSF_BrainMetastases_TRAIN/100142A/10...,../data/raw/UCSF_BrainMetastases_TRAIN/100142A...
5,100162A,Female,Lung,GE 1.5 T Signa HDxt,1x1,256x256x136,No,86.0,1.5,1.5,1.0,54.0,1.0,54.0,54.0,54.0,54.0,0.0,0.0,../data/raw/UCSF_BrainMetastases_TRAIN/100162A...,../data/raw/UCSF_BrainMetastases_TRAIN/100162A...,../data/raw/UCSF_BrainMetastases_TRAIN/100162A...,../data/raw/UCSF_BrainMetastases_TRAIN/100162A...,data/raw/UCSF_BrainMetastases_TRAIN/100162A/10...,../data/raw/UCSF_BrainMetastases_TRAIN/100162A...
6,100357B,Female,Breast,GE 3.0 T Discovery MR750,0.47x0.47,512x512x56,No,50.0,3.0,3.0,1.0,6291.234716,1.0,6291.234822,6291.234822,6291.234822,6291.234822,0.0,0.000106,../data/raw/UCSF_BrainMetastases_TRAIN/100357B...,../data/raw/UCSF_BrainMetastases_TRAIN/100357B...,../data/raw/UCSF_BrainMetastases_TRAIN/100357B...,../data/raw/UCSF_BrainMetastases_TRAIN/100357B...,data/raw/UCSF_BrainMetastases_TRAIN/100357B/10...,../data/raw/UCSF_BrainMetastases_TRAIN/100357B...
7,100178A,Male,Melanoma,GE 1.5 T Signa HDxt,0.86x0.86,256x256x110,No,72.0,1.5,1.5,2.0,291.365201,2.0,291.365197,238.188279,53.176918,145.682598,92.50568,4e-06,../data/raw/UCSF_BrainMetastases_TRAIN/100178A...,../data/raw/UCSF_BrainMetastases_TRAIN/100178A...,../data/raw/UCSF_BrainMetastases_TRAIN/100178A...,../data/raw/UCSF_BrainMetastases_TRAIN/100178A...,data/raw/UCSF_BrainMetastases_TRAIN/100178A/10...,../data/raw/UCSF_BrainMetastases_TRAIN/100178A...
8,100290A,Female,Rectal,Philips 1.5 T Achieva,0.69x0.69,320x320x104,No,65.0,1.5,1.5,15.0,5007.556641,15.0,5007.556641,3931.318359,9.925781,333.837109,965.765836,0.0,../data/raw/UCSF_BrainMetastases_TRAIN/100290A...,../data/raw/UCSF_BrainMetastases_TRAIN/100290A...,../data/raw/UCSF_BrainMetastases_TRAIN/100290A...,../data/raw/UCSF_BrainMetastases_TRAIN/100290A...,data/raw/UCSF_BrainMetastases_TRAIN/100290A/10...,../data/raw/UCSF_BrainMetastases_TRAIN/100290A...
9,100412B,Male,Lung,GE 1.5 T Signa HDxt,0.86x0.86,256x256x126,No,64.0,1.5,1.5,8.0,14275.786978,8.0,14275.786786,13811.596606,13.29423,1784.473348,4546.121085,0.000192,../data/raw/UCSF_BrainMetastases_TRAIN/100412B...,../data/raw/UCSF_BrainMetastases_TRAIN/100412B...,../data/raw/UCSF_BrainMetastases_TRAIN/100412B...,../data/raw/UCSF_BrainMetastases_TRAIN/100412B...,data/raw/UCSF_BrainMetastases_TRAIN/100412B/10...,../data/raw/UCSF_BrainMetastases_TRAIN/100412B...


## Testing

In [6]:
# Spatial sizes passed to test_model: a shared size for SegResNet / UNet / UNETR,
# and a larger in-plane size for AHNet (presumably an architectural size
# constraint of AHNet — TODO confirm).
spatial_size, ah_spatial_size = (240, 240, 160), (256, 256, 160)

### Test Models

In [7]:
# AHNet: load the best checkpoint, evaluate on the test split, save per-subject scores
model = AHNET
model.to(device)
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint on
# a CPU-only machine), which torch.load would otherwise refuse to deserialize.
model.load_state_dict(torch.load(models['AHNet']['model'], map_location=device))
# AHNet is evaluated with its own, larger spatial size
AHNet_scores = test_model(model, 'AHNet', test_loader, test_df, ah_spatial_size)
AHNet_scores.to_csv('../outputs/AHNet/test_scores.csv')
AHNet_scores.describe()

100%|██████████| 31/31 [04:50<00:00,  9.37s/it]


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.577865,0.61272,0.619734,0.499683,10.548387,7.387097,117.935484,8.709677,8.709677,5.129032,17486.193548,17435.677419,15412.129032,17103.677419,17984.16129,13810.870968
std,0.285203,0.281992,0.289538,0.298995,8.781946,6.396235,193.890164,9.103822,9.103822,4.951376,26467.969212,26476.04487,24446.999325,25564.191374,26911.703415,22937.67913
min,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,2.0,0.0,0.0,44.0,44.0,0.0
25%,0.392068,0.464731,0.453269,0.248504,5.0,3.0,11.0,2.0,2.0,1.5,737.5,710.5,546.5,802.0,802.0,410.0
50%,0.632319,0.635346,0.689713,0.544421,7.0,5.0,30.0,6.0,6.0,4.0,4106.0,4059.0,3191.0,3847.0,4138.0,1944.0
75%,0.81358,0.857451,0.876358,0.73842,16.0,9.5,142.0,13.5,13.5,7.0,30876.5,30896.5,29286.0,26305.0,27036.5,23520.5
max,0.930925,0.942043,0.943445,0.907288,31.0,25.0,787.0,43.0,43.0,21.0,101717.0,101717.0,93702.0,102770.0,105729.0,92221.0


In [8]:
# SegResNet: load the best checkpoint, evaluate on the test split, save per-subject scores
model = SEGRESNET
model.to(device)
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint on
# a CPU-only machine), which torch.load would otherwise refuse to deserialize.
model.load_state_dict(torch.load(models['SegResNet']['model'], map_location=device))
SegResNet_scores = test_model(model, 'SegResNet', test_loader, test_df, spatial_size)
SegResNet_scores.to_csv('../outputs/SegResNet/test_scores.csv')
SegResNet_scores.describe()

100%|██████████| 31/31 [06:56<00:00, 13.42s/it]


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.646463,0.679766,0.684896,0.561792,8.483871,8.322581,5.419355,8.709677,8.709677,5.129032,14547.064516,14531.0,11533.322581,17103.677419,17984.16129,13810.870968
std,0.237764,0.217345,0.228654,0.295048,6.840424,6.842938,5.25848,9.103822,9.103822,4.951376,21314.629841,21391.868382,18371.986988,25564.191374,26911.703415,22937.67913
min,0.08422,0.151394,0.101266,0.0,1.0,1.0,0.0,1.0,1.0,0.0,71.0,63.0,0.0,44.0,44.0,0.0
25%,0.484024,0.571311,0.567735,0.324266,4.5,3.5,1.0,2.0,2.0,1.5,1101.5,1081.0,636.5,802.0,802.0,410.0
50%,0.723051,0.763834,0.768185,0.638653,6.0,6.0,4.0,6.0,6.0,4.0,4401.0,4441.0,2781.0,3847.0,4138.0,1944.0
75%,0.825656,0.855362,0.866187,0.796817,10.5,10.5,8.0,13.5,13.5,7.0,21443.5,21289.5,18183.0,26305.0,27036.5,23520.5
max,0.945202,0.950976,0.95245,0.935746,26.0,27.0,26.0,43.0,43.0,21.0,78216.0,79071.0,68497.0,102770.0,105729.0,92221.0


In [9]:
# UNet: load the best checkpoint, evaluate on the test split, save per-subject scores
model = UNET
model.to(device)
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint on
# a CPU-only machine), which torch.load would otherwise refuse to deserialize.
model.load_state_dict(torch.load(models['UNet']['model'], map_location=device))
UNet_scores = test_model(model, 'UNet', test_loader, test_df, spatial_size)
UNet_scores.to_csv('../outputs/UNet/test_scores.csv')
UNet_scores.describe()

100%|██████████| 31/31 [03:03<00:00,  5.91s/it]


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.484214,0.516452,0.528742,0.407448,18.870968,17.064516,20.419355,8.709677,8.709677,5.129032,16343.0,16311.709677,16331.193548,17103.677419,17984.16129,13810.870968
std,0.328109,0.339412,0.346458,0.307584,53.905932,47.892891,62.166858,9.103822,9.103822,4.951376,20781.41927,20792.337948,20771.678511,25564.191374,26911.703415,22937.67913
min,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,1774.0,1772.0,1778.0,44.0,44.0,0.0
25%,0.117693,0.142674,0.142999,0.087825,2.5,2.5,2.5,2.0,2.0,1.5,3579.5,3461.0,3538.0,802.0,802.0,410.0
50%,0.585055,0.649518,0.649864,0.444811,5.0,4.0,5.0,6.0,6.0,4.0,5488.0,5496.0,5498.0,3847.0,4138.0,1944.0
75%,0.767356,0.832715,0.835276,0.671355,8.5,9.0,9.0,13.5,13.5,7.0,25991.0,25986.0,25990.5,26305.0,27036.5,23520.5
max,0.930025,0.946794,0.949309,0.898954,260.0,237.0,322.0,43.0,43.0,21.0,75085.0,75122.0,75042.0,102770.0,105729.0,92221.0


In [10]:
# UNETR: load the best checkpoint, evaluate on the test split, save per-subject scores
model = UNTR
model.to(device)
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint on
# a CPU-only machine), which torch.load would otherwise refuse to deserialize.
model.load_state_dict(torch.load(models['UNETR']['model'], map_location=device))
UNetR_scores = test_model(model, 'UNETR', test_loader, test_df, spatial_size)
UNetR_scores.to_csv('../outputs/UNETR/test_scores.csv')
UNetR_scores.describe()

100%|██████████| 31/31 [06:35<00:00, 12.76s/it]


Unnamed: 0,Dice,Dice TC,Dice WT,Dice ET,Pred NM TC,Pred NM WT,Pred NM ET,GT NM TC,GT NM WT,GT NM ET,Pred V TC,Pred V WT,Pred V ET,GT V TC,GT V WT,GT V ET
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.541837,0.577307,0.5821,0.459992,13.516129,12.774194,9.0,8.709677,8.709677,5.129032,15038.935484,15297.032258,12731.225806,17103.677419,17984.16129,13810.870968
std,0.299187,0.28761,0.297085,0.339434,10.026202,10.28173,9.640194,9.103822,9.103822,4.951376,23252.777197,23707.287224,20966.015148,25564.191374,26911.703415,22937.67913
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,44.0,44.0,0.0
25%,0.333243,0.426975,0.403104,0.150649,7.0,7.0,3.0,2.0,2.0,1.5,684.5,681.5,310.5,802.0,802.0,410.0
50%,0.560271,0.603463,0.598039,0.456625,10.0,9.0,6.0,6.0,6.0,4.0,3004.0,2955.0,2028.0,3847.0,4138.0,1944.0
75%,0.837698,0.861524,0.878676,0.773247,19.0,16.0,11.5,13.5,13.5,7.0,27791.5,28026.0,25568.0,26305.0,27036.5,23520.5
max,0.951585,0.956219,0.958249,0.940287,44.0,46.0,39.0,43.0,43.0,21.0,86733.0,87052.0,77353.0,102770.0,105729.0,92221.0
