# Ensemble Techniques

## Environment Setup

### Imports

In [1]:
import os
os.sys.path.append("utils")

import torch
import numpy as np
import pandas as pd
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.transforms import AsDiscrete

### Config

In [2]:
seed = 33
pd.set_option("display.max_columns", None)
pd.set_option('display.max_colwidth', None)

## Load Train Data

In [3]:
subject_ids = pd.read_csv('./data/TRAIN.csv')['SubjectID'].values

ah_segs, segresnet_segs, untr_segs, gt_segs = [], [], [], []
for sid in subject_ids:
    ah_channels, segresnet_channels, untr_channels, gt_channels = [], [], [], []
    for channel in ['TC', 'WT', 'ET']:
        ah_channels.append(f'./outputs/AHNet/pred_segs/train_pred_segs/pred_{sid}_{channel}.npz')
        segresnet_channels.append(f'./outputs/SegResNet/pred_segs/train_pred_segs/pred_{sid}_{channel}.npz')
        untr_channels.append(f'./outputs/UNETR/pred_segs/train_pred_segs/pred_{sid}_{channel}.npz')
        gt_channels.append(f'./outputs/gt_segs/train_gt_segs/gt_{sid}_{channel}.npz')
    
    ah_segs.append(ah_channels)
    segresnet_segs.append(segresnet_channels)
    untr_segs.append(untr_channels)
    gt_segs.append(gt_channels)

# Dataframe
train_df = pd.DataFrame()
train_df['SubjectID'] = subject_ids
train_df['AHNet'] = ah_segs
train_df['SegResNet'] = segresnet_segs
train_df['UNETR'] = untr_segs
train_df['GT'] = gt_segs

train_df.head()

Unnamed: 0,SubjectID,AHNet,SegResNet,UNETR,GT
0,100381A,"[./outputs/AHNet/pred_segs/train_pred_segs/pred_100381A_TC.npz, ./outputs/AHNet/pred_segs/train_pred_segs/pred_100381A_WT.npz, ./outputs/AHNet/pred_segs/train_pred_segs/pred_100381A_ET.npz]","[./outputs/SegResNet/pred_segs/train_pred_segs/pred_100381A_TC.npz, ./outputs/SegResNet/pred_segs/train_pred_segs/pred_100381A_WT.npz, ./outputs/SegResNet/pred_segs/train_pred_segs/pred_100381A_ET.npz]","[./outputs/UNETR/pred_segs/train_pred_segs/pred_100381A_TC.npz, ./outputs/UNETR/pred_segs/train_pred_segs/pred_100381A_WT.npz, ./outputs/UNETR/pred_segs/train_pred_segs/pred_100381A_ET.npz]","[./outputs/gt_segs/train_gt_segs/gt_100381A_TC.npz, ./outputs/gt_segs/train_gt_segs/gt_100381A_WT.npz, ./outputs/gt_segs/train_gt_segs/gt_100381A_ET.npz]"
1,100414B,"[./outputs/AHNet/pred_segs/train_pred_segs/pred_100414B_TC.npz, ./outputs/AHNet/pred_segs/train_pred_segs/pred_100414B_WT.npz, ./outputs/AHNet/pred_segs/train_pred_segs/pred_100414B_ET.npz]","[./outputs/SegResNet/pred_segs/train_pred_segs/pred_100414B_TC.npz, ./outputs/SegResNet/pred_segs/train_pred_segs/pred_100414B_WT.npz, ./outputs/SegResNet/pred_segs/train_pred_segs/pred_100414B_ET.npz]","[./outputs/UNETR/pred_segs/train_pred_segs/pred_100414B_TC.npz, ./outputs/UNETR/pred_segs/train_pred_segs/pred_100414B_WT.npz, ./outputs/UNETR/pred_segs/train_pred_segs/pred_100414B_ET.npz]","[./outputs/gt_segs/train_gt_segs/gt_100414B_TC.npz, ./outputs/gt_segs/train_gt_segs/gt_100414B_WT.npz, ./outputs/gt_segs/train_gt_segs/gt_100414B_ET.npz]"
2,100132B,"[./outputs/AHNet/pred_segs/train_pred_segs/pred_100132B_TC.npz, ./outputs/AHNet/pred_segs/train_pred_segs/pred_100132B_WT.npz, ./outputs/AHNet/pred_segs/train_pred_segs/pred_100132B_ET.npz]","[./outputs/SegResNet/pred_segs/train_pred_segs/pred_100132B_TC.npz, ./outputs/SegResNet/pred_segs/train_pred_segs/pred_100132B_WT.npz, ./outputs/SegResNet/pred_segs/train_pred_segs/pred_100132B_ET.npz]","[./outputs/UNETR/pred_segs/train_pred_segs/pred_100132B_TC.npz, ./outputs/UNETR/pred_segs/train_pred_segs/pred_100132B_WT.npz, ./outputs/UNETR/pred_segs/train_pred_segs/pred_100132B_ET.npz]","[./outputs/gt_segs/train_gt_segs/gt_100132B_TC.npz, ./outputs/gt_segs/train_gt_segs/gt_100132B_WT.npz, ./outputs/gt_segs/train_gt_segs/gt_100132B_ET.npz]"
3,100212A,"[./outputs/AHNet/pred_segs/train_pred_segs/pred_100212A_TC.npz, ./outputs/AHNet/pred_segs/train_pred_segs/pred_100212A_WT.npz, ./outputs/AHNet/pred_segs/train_pred_segs/pred_100212A_ET.npz]","[./outputs/SegResNet/pred_segs/train_pred_segs/pred_100212A_TC.npz, ./outputs/SegResNet/pred_segs/train_pred_segs/pred_100212A_WT.npz, ./outputs/SegResNet/pred_segs/train_pred_segs/pred_100212A_ET.npz]","[./outputs/UNETR/pred_segs/train_pred_segs/pred_100212A_TC.npz, ./outputs/UNETR/pred_segs/train_pred_segs/pred_100212A_WT.npz, ./outputs/UNETR/pred_segs/train_pred_segs/pred_100212A_ET.npz]","[./outputs/gt_segs/train_gt_segs/gt_100212A_TC.npz, ./outputs/gt_segs/train_gt_segs/gt_100212A_WT.npz, ./outputs/gt_segs/train_gt_segs/gt_100212A_ET.npz]"
4,100243B,"[./outputs/AHNet/pred_segs/train_pred_segs/pred_100243B_TC.npz, ./outputs/AHNet/pred_segs/train_pred_segs/pred_100243B_WT.npz, ./outputs/AHNet/pred_segs/train_pred_segs/pred_100243B_ET.npz]","[./outputs/SegResNet/pred_segs/train_pred_segs/pred_100243B_TC.npz, ./outputs/SegResNet/pred_segs/train_pred_segs/pred_100243B_WT.npz, ./outputs/SegResNet/pred_segs/train_pred_segs/pred_100243B_ET.npz]","[./outputs/UNETR/pred_segs/train_pred_segs/pred_100243B_TC.npz, ./outputs/UNETR/pred_segs/train_pred_segs/pred_100243B_WT.npz, ./outputs/UNETR/pred_segs/train_pred_segs/pred_100243B_ET.npz]","[./outputs/gt_segs/train_gt_segs/gt_100243B_TC.npz, ./outputs/gt_segs/train_gt_segs/gt_100243B_WT.npz, ./outputs/gt_segs/train_gt_segs/gt_100243B_ET.npz]"


## Load Val Data

In [4]:
subject_ids = pd.read_csv('./data/VAL.csv')['SubjectID'].values

ah_segs, segresnet_segs, untr_segs, gt_segs = [], [], [], []
for sid in subject_ids:
    ah_channels, segresnet_channels, untr_channels, gt_channels = [], [], [], []
    for channel in ['TC', 'WT', 'ET']:
        ah_channels.append(f'./outputs/AHNet/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz')
        segresnet_channels.append(f'./outputs/SegResNet/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz')
        untr_channels.append(f'./outputs/UNETR/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz')
        gt_channels.append(f'./outputs/gt_segs/val_gt_segs/gt_{sid}_{channel}.npz')
    
    ah_segs.append(ah_channels)
    segresnet_segs.append(segresnet_channels)
    untr_segs.append(untr_channels)
    gt_segs.append(gt_channels)

# Dataframe
val_df = pd.DataFrame()
val_df['SubjectID'] = subject_ids
val_df['AHNet'] = ah_segs
val_df['SegResNet'] = segresnet_segs
val_df['UNETR'] = untr_segs
val_df['GT'] = gt_segs

val_df.head()

Unnamed: 0,SubjectID,AHNet,SegResNet,UNETR,GT
0,100237A,"[./outputs/AHNet/pred_segs/val_pred_segs/pred_100237A_TC.npz, ./outputs/AHNet/pred_segs/val_pred_segs/pred_100237A_WT.npz, ./outputs/AHNet/pred_segs/val_pred_segs/pred_100237A_ET.npz]","[./outputs/SegResNet/pred_segs/val_pred_segs/pred_100237A_TC.npz, ./outputs/SegResNet/pred_segs/val_pred_segs/pred_100237A_WT.npz, ./outputs/SegResNet/pred_segs/val_pred_segs/pred_100237A_ET.npz]","[./outputs/UNETR/pred_segs/val_pred_segs/pred_100237A_TC.npz, ./outputs/UNETR/pred_segs/val_pred_segs/pred_100237A_WT.npz, ./outputs/UNETR/pred_segs/val_pred_segs/pred_100237A_ET.npz]","[./outputs/gt_segs/val_gt_segs/gt_100237A_TC.npz, ./outputs/gt_segs/val_gt_segs/gt_100237A_WT.npz, ./outputs/gt_segs/val_gt_segs/gt_100237A_ET.npz]"
1,100219A,"[./outputs/AHNet/pred_segs/val_pred_segs/pred_100219A_TC.npz, ./outputs/AHNet/pred_segs/val_pred_segs/pred_100219A_WT.npz, ./outputs/AHNet/pred_segs/val_pred_segs/pred_100219A_ET.npz]","[./outputs/SegResNet/pred_segs/val_pred_segs/pred_100219A_TC.npz, ./outputs/SegResNet/pred_segs/val_pred_segs/pred_100219A_WT.npz, ./outputs/SegResNet/pred_segs/val_pred_segs/pred_100219A_ET.npz]","[./outputs/UNETR/pred_segs/val_pred_segs/pred_100219A_TC.npz, ./outputs/UNETR/pred_segs/val_pred_segs/pred_100219A_WT.npz, ./outputs/UNETR/pred_segs/val_pred_segs/pred_100219A_ET.npz]","[./outputs/gt_segs/val_gt_segs/gt_100219A_TC.npz, ./outputs/gt_segs/val_gt_segs/gt_100219A_WT.npz, ./outputs/gt_segs/val_gt_segs/gt_100219A_ET.npz]"
2,100363A,"[./outputs/AHNet/pred_segs/val_pred_segs/pred_100363A_TC.npz, ./outputs/AHNet/pred_segs/val_pred_segs/pred_100363A_WT.npz, ./outputs/AHNet/pred_segs/val_pred_segs/pred_100363A_ET.npz]","[./outputs/SegResNet/pred_segs/val_pred_segs/pred_100363A_TC.npz, ./outputs/SegResNet/pred_segs/val_pred_segs/pred_100363A_WT.npz, ./outputs/SegResNet/pred_segs/val_pred_segs/pred_100363A_ET.npz]","[./outputs/UNETR/pred_segs/val_pred_segs/pred_100363A_TC.npz, ./outputs/UNETR/pred_segs/val_pred_segs/pred_100363A_WT.npz, ./outputs/UNETR/pred_segs/val_pred_segs/pred_100363A_ET.npz]","[./outputs/gt_segs/val_gt_segs/gt_100363A_TC.npz, ./outputs/gt_segs/val_gt_segs/gt_100363A_WT.npz, ./outputs/gt_segs/val_gt_segs/gt_100363A_ET.npz]"
3,100354A,"[./outputs/AHNet/pred_segs/val_pred_segs/pred_100354A_TC.npz, ./outputs/AHNet/pred_segs/val_pred_segs/pred_100354A_WT.npz, ./outputs/AHNet/pred_segs/val_pred_segs/pred_100354A_ET.npz]","[./outputs/SegResNet/pred_segs/val_pred_segs/pred_100354A_TC.npz, ./outputs/SegResNet/pred_segs/val_pred_segs/pred_100354A_WT.npz, ./outputs/SegResNet/pred_segs/val_pred_segs/pred_100354A_ET.npz]","[./outputs/UNETR/pred_segs/val_pred_segs/pred_100354A_TC.npz, ./outputs/UNETR/pred_segs/val_pred_segs/pred_100354A_WT.npz, ./outputs/UNETR/pred_segs/val_pred_segs/pred_100354A_ET.npz]","[./outputs/gt_segs/val_gt_segs/gt_100354A_TC.npz, ./outputs/gt_segs/val_gt_segs/gt_100354A_WT.npz, ./outputs/gt_segs/val_gt_segs/gt_100354A_ET.npz]"
4,100303A,"[./outputs/AHNet/pred_segs/val_pred_segs/pred_100303A_TC.npz, ./outputs/AHNet/pred_segs/val_pred_segs/pred_100303A_WT.npz, ./outputs/AHNet/pred_segs/val_pred_segs/pred_100303A_ET.npz]","[./outputs/SegResNet/pred_segs/val_pred_segs/pred_100303A_TC.npz, ./outputs/SegResNet/pred_segs/val_pred_segs/pred_100303A_WT.npz, ./outputs/SegResNet/pred_segs/val_pred_segs/pred_100303A_ET.npz]","[./outputs/UNETR/pred_segs/val_pred_segs/pred_100303A_TC.npz, ./outputs/UNETR/pred_segs/val_pred_segs/pred_100303A_WT.npz, ./outputs/UNETR/pred_segs/val_pred_segs/pred_100303A_ET.npz]","[./outputs/gt_segs/val_gt_segs/gt_100303A_TC.npz, ./outputs/gt_segs/val_gt_segs/gt_100303A_WT.npz, ./outputs/gt_segs/val_gt_segs/gt_100303A_ET.npz]"


## Load Test Data

In [5]:
subject_ids = pd.read_csv('./data/TEST.csv')['SubjectID'].values

ah_segs, segresnet_segs, untr_segs, gt_segs = [], [], [], []
for sid in subject_ids:
    ah_channels, segresnet_channels, untr_channels, gt_channels = [], [], [], []
    for channel in ['TC', 'WT', 'ET']:
        ah_channels.append(f'./outputs/AHNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz')
        segresnet_channels.append(f'./outputs/SegResNet/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz')
        untr_channels.append(f'./outputs/UNETR/pred_segs/test_pred_segs/pred_{sid}_{channel}.npz')
        gt_channels.append(f'./outputs/gt_segs/test_gt_segs/gt_{sid}_{channel}.npz')
    
    ah_segs.append(ah_channels)
    segresnet_segs.append(segresnet_channels)
    untr_segs.append(untr_channels)
    gt_segs.append(gt_channels)

# Dataframe
test_df = pd.DataFrame()
test_df['SubjectID'] = subject_ids
test_df['AHNet'] = ah_segs
test_df['SegResNet'] = segresnet_segs
test_df['UNETR'] = untr_segs
test_df['GT'] = gt_segs

test_df.head()

Unnamed: 0,SubjectID,AHNet,SegResNet,UNETR,GT
0,100214B,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100214B_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100214B_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100214B_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100214B_ET.npz]"
1,100340A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100340A_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100340A_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100340A_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100340A_ET.npz]"
2,100391A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100391A_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100391A_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100391A_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100391A_ET.npz]"
3,100190B,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100190B_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100190B_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100190B_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100190B_ET.npz]"
4,100142A,"[./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/AHNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/SegResNet/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_TC.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_WT.npz, ./outputs/UNETR/pred_segs/test_pred_segs/pred_100142A_ET.npz]","[./outputs/gt_segs/test_gt_segs/gt_100142A_TC.npz, ./outputs/gt_segs/test_gt_segs/gt_100142A_WT.npz, ./outputs/gt_segs/test_gt_segs/gt_100142A_ET.npz]"


## Training & Validation

In [6]:
# Get highest size of the inputs
#max_size =  (0,0,0)
#for i in range(len(train_df)):
#    max_size = tuple(map(max, zip(max_size, np.load(train_df['AHNet'][i][0])['arr_0'].shape)))
#for i in range(len(val_df)):
#    max_size = tuple(map(max, zip(max_size, np.load(val_df['AHNet'][i][0])['arr_0'].shape)))
#for i in range(len(test_df)):
#    max_size = tuple(map(max, zip(max_size, np.load(test_df['AHNet'][i][0])['arr_0'].shape)))
#print(max_size)

In [7]:
max_size = (315, 315, 308)

In [8]:
import torch.nn as nn

# Params
val_interval = 1
threshold=0.9
lr=0.001
wd=0.0001
max_epochs = 5

# Model
class LogisticRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
    
    def forward(self, x):
        return torch.sigmoid(self.linear(x))
    
input_dim = 9
output_dim = 3
model = LogisticRegression(input_dim, output_dim)

# Post Transforms, Optimizer, Loss, Evaluation Metric
trans = AsDiscrete(threshold=threshold)
loss_function = DiceLoss(smooth_nr=1e-5, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
dice_metric = DiceMetric(include_background=True, reduction="mean")
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

In [9]:
# TRAIN

for epoch in range(max_epochs):

    print(f"Epoch {epoch+1}/{max_epochs}")

    print('TRAIN')
    for i in range(len(train_df)):

        model.train()

        # Image & Label
        load_ahnet = train_df['AHNet'][i]
        load_segresnet = train_df['SegResNet'][i]
        load_unetr = train_df['UNETR'][i]
        load_label = train_df['GT'][i]
            
        # Load Images and Labels
        ahnet_image = [np.load(x)['arr_0'] for x in load_ahnet]
        segresnet_image = [np.load(x)['arr_0'] for x in load_segresnet]
        unetr_image = [np.load(x)['arr_0'] for x in load_unetr]
        img_label = [np.load(x)['arr_0'] for x in load_label] 

        # Pad till max size
        w, h, d = ahnet_image[0].shape
        ahnet_image = [np.pad(x, ((0, max_size[0] - x.shape[0]), (0, max_size[1] - x.shape[1]), (0, max_size[2] - x.shape[2])), 'constant', constant_values=0) for x in ahnet_image]
        segresnet_image = [np.pad(x, ((0, max_size[0] - x.shape[0]), (0, max_size[1] - x.shape[1]), (0, max_size[2] - x.shape[2])), 'constant', constant_values=0) for x in segresnet_image]
        unetr_image = [np.pad(x, ((0, max_size[0] - x.shape[0]), (0, max_size[1] - x.shape[1]), (0, max_size[2] - x.shape[2])), 'constant', constant_values=0) for x in unetr_image]
        img_label = [np.pad(x, ((0, max_size[0] - x.shape[0]), (0, max_size[1] - x.shape[1]), (0, max_size[2] - x.shape[2])), 'constant', constant_values=0) for x in img_label]

        # To Tensor
        ahnet_image = [torch.from_numpy(x) for x in ahnet_image]
        segresnet_image = [torch.from_numpy(x) for x in segresnet_image]
        unetr_image = [torch.from_numpy(x) for x in unetr_image]
        img_label = [torch.from_numpy(x) for x in img_label]
            
        # Stack Images and Label
        ahnet_image = torch.stack(ahnet_image, dim = 0).unsqueeze(0)
        segresnet_image = torch.stack(segresnet_image, dim = 0).unsqueeze(0)
        unetr_image = torch.stack(unetr_image, dim = 0).unsqueeze(0)
        img_label = torch.stack(img_label, dim = 0).unsqueeze(0)

        predictions = [ahnet_image, segresnet_image, unetr_image]

        # Flatten
        flattened_predictions = [pred.view(3, -1).t() for pred in predictions]
        flattened_predictions = torch.cat(flattened_predictions, dim=1)
        flattened_targets = img_label.view(3, -1).t()

        # Logistic Regression
        outputs = model(flattened_predictions)

        # Loss
        loss = loss_function(outputs, flattened_targets)
        loss.backward()
        optimizer.step()
    
    lr_scheduler.step()

    # VAL
    if (i+1) % val_interval == 0:
        model.eval()
        print('---------------------------------------')
        print('VAL')
        dice_values = []
        for j in range(len(val_df)):

            # Image & Label
            load_ahnet = val_df['AHNet'][j]
            load_segresnet = val_df['SegResNet'][j]
            load_unetr = val_df['UNETR'][j]
            load_label = val_df['GT'][j]

            # Load Images and Labels
            ahnet_image = [np.load(x)['arr_0'] for x in load_ahnet]
            segresnet_image = [np.load(x)['arr_0'] for x in load_segresnet]
            unetr_image = [np.load(x)['arr_0'] for x in load_unetr]
            img_label = [np.load(x)['arr_0'] for x in load_label]

            # Pad till max size
            w, h, d = ahnet_image[0].shape
            ahnet_image = [np.pad(x, ((0, max_size[0] - x.shape[0]), (0, max_size[1] - x.shape[1]), (0, max_size[2] - x.shape[2])), 'constant', constant_values=0) for x in ahnet_image]
            segresnet_image = [np.pad(x, ((0, max_size[0] - x.shape[0]), (0, max_size[1] - x.shape[1]), (0, max_size[2] - x.shape[2])), 'constant', constant_values=0) for x in segresnet_image]
            unetr_image = [np.pad(x, ((0, max_size[0] - x.shape[0]), (0, max_size[1] - x.shape[1]), (0, max_size[2] - x.shape[2])), 'constant', constant_values=0) for x in unetr_image]
            img_label = [np.pad(x, ((0, max_size[0] - x.shape[0]), (0, max_size[1] - x.shape[1]), (0, max_size[2] - x.shape[2])), 'constant', constant_values=0) for x in img_label]

            # To Tensor
            ahnet_image = [torch.from_numpy(x) for x in ahnet_image]
            segresnet_image = [torch.from_numpy(x) for x in segresnet_image]
            unetr_image = [torch.from_numpy(x) for x in unetr_image]
            img_label = [torch.from_numpy(x) for x in img_label]
                
            # Stack Images and Label
            ahnet_image = torch.stack(ahnet_image, dim = 0).unsqueeze(0)
            segresnet_image = torch.stack(segresnet_image, dim = 0).unsqueeze(0)
            unetr_image = torch.stack(unetr_image, dim = 0).unsqueeze(0)
            img_label = torch.stack(img_label, dim = 0).unsqueeze(0)

            predictions = [ahnet_image, segresnet_image, unetr_image]

            # Flatten
            flattened_predictions = [pred.view(3, -1).t() for pred in predictions]
            flattened_predictions = torch.cat(flattened_predictions, dim=1)

            # Predict & Eval
            with torch.no_grad():
                # Forward pass
                output = model(flattened_predictions)
            img = trans(output)

            # Dice Metric
            img = img.T.reshape(1, 3, max_size[0], max_size[1], max_size[2])
            dice_metric(y_pred=img, y=img_label)
            dice_score = dice_metric.aggregate()
            dice_values.append(dice_score.item())
            dice_metric.reset()

        # Results
        print('Dice Scores:', np.mean(dice_values))
        print('---------------------------------------')

Epoch 1/5
TRAIN
---------------------------------------
VAL
Dice Scores: 0.5260393355161913
---------------------------------------
Epoch 2/5
TRAIN
---------------------------------------
VAL
Dice Scores: 0.601629643070121
---------------------------------------
Epoch 3/5
TRAIN
---------------------------------------
VAL
Dice Scores: 0.6123274225260942
---------------------------------------
Epoch 4/5
TRAIN
---------------------------------------
VAL
Dice Scores: 0.6125137841389063
---------------------------------------
Epoch 5/5
TRAIN
---------------------------------------
VAL
Dice Scores: 0.6123114915624741
---------------------------------------


In [10]:
# Save model
torch.save(model.state_dict(), './outputs/EnsembleNU/LogisticRegression_nu.pth')

In [12]:
# Test set predictions
model.eval()
print('---------------------------------------')
print('TEST')
dice_values = []
for j in range(len(test_df)):

    # Image & Label
    load_ahnet = test_df['AHNet'][j]
    load_segresnet = test_df['SegResNet'][j]
    load_unetr = test_df['UNETR'][j]
    load_label = test_df['GT'][j]

    # Load Images and Labels
    ahnet_image = [np.load(x)['arr_0'] for x in load_ahnet]
    segresnet_image = [np.load(x)['arr_0'] for x in load_segresnet]
    unetr_image = [np.load(x)['arr_0'] for x in load_unetr]
    img_label = [np.load(x)['arr_0'] for x in load_label]

    # Pad till max size
    w, h, d = ahnet_image[0].shape
    ahnet_image = [np.pad(x, ((0, max_size[0] - x.shape[0]), (0, max_size[1] - x.shape[1]), (0, max_size[2] - x.shape[2])), 'constant', constant_values=0) for x in ahnet_image]
    segresnet_image = [np.pad(x, ((0, max_size[0] - x.shape[0]), (0, max_size[1] - x.shape[1]), (0, max_size[2] - x.shape[2])), 'constant', constant_values=0) for x in segresnet_image]
    unetr_image = [np.pad(x, ((0, max_size[0] - x.shape[0]), (0, max_size[1] - x.shape[1]), (0, max_size[2] - x.shape[2])), 'constant', constant_values=0) for x in unetr_image]
    img_label = [np.pad(x, ((0, max_size[0] - x.shape[0]), (0, max_size[1] - x.shape[1]), (0, max_size[2] - x.shape[2])), 'constant', constant_values=0) for x in img_label]

    # To Tensor
    ahnet_image = [torch.from_numpy(x) for x in ahnet_image]
    segresnet_image = [torch.from_numpy(x) for x in segresnet_image]
    unetr_image = [torch.from_numpy(x) for x in unetr_image]
    img_label = [torch.from_numpy(x) for x in img_label]

    # Stack Images and Label
    ahnet_image = torch.stack(ahnet_image, dim = 0).unsqueeze(0)
    segresnet_image = torch.stack(segresnet_image, dim = 0).unsqueeze(0)
    unetr_image = torch.stack(unetr_image, dim = 0).unsqueeze(0)
    img_label = torch.stack(img_label, dim = 0).unsqueeze(0)

    predictions = [ahnet_image, segresnet_image, unetr_image]

    # Flatten
    flattened_predictions = [pred.view(3, -1).t() for pred in predictions]
    flattened_predictions = torch.cat(flattened_predictions, dim=1)

    # Predict & Eval
    with torch.no_grad():
        # Forward pass
        output = model(flattened_predictions)

    img = AsDiscrete(threshold=0.9)(output)

    # Dice Metric
    img = img.T.reshape(1, 3, max_size[0], max_size[1], max_size[2])
    dice_metric(y_pred=img, y=img_label)
    dice_score = dice_metric.aggregate()
    dice_values.append(dice_score.item())
    print('Dice Scores:', dice_score.item())
    dice_metric.reset()

# Results
print('Mean Dice Scores:', np.mean(dice_values))
print('---------------------------------------')

---------------------------------------
TEST
Dice Scores: 0.4661494791507721
Dice Scores: 0.6523823142051697
Dice Scores: 0.3122517764568329
Dice Scores: 0.6976320743560791
Dice Scores: 0.5357033610343933
Dice Scores: 0.0
Dice Scores: 0.7682256698608398
Dice Scores: 0.7723116278648376
Dice Scores: 0.9012286067008972
Dice Scores: 0.8204750418663025
Dice Scores: 0.3039762079715729
Dice Scores: 0.7393490672111511
Dice Scores: 0.8881688714027405
Dice Scores: 0.6099405288696289
Dice Scores: 0.9053166508674622
Dice Scores: 0.6416249871253967
Dice Scores: 0.05784592404961586
Dice Scores: 0.8747085928916931
Dice Scores: 0.4156076908111572
Dice Scores: 0.8256669044494629
Dice Scores: 0.3720388412475586
Dice Scores: 0.8126528263092041
Dice Scores: 0.9251006245613098
Dice Scores: 0.36115559935569763
Dice Scores: 0.548288881778717
Dice Scores: 0.4893968999385834
Dice Scores: 0.4627707302570343
Dice Scores: 0.7486152648925781
Dice Scores: 0.8436587452888489
Dice Scores: 0.7477752566337585
Dice Scor