# Ensemble Techniques

## Environment Setup

### Imports

In [1]:
import os
os.sys.path.append("../utils")

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from Transforms import Transforms
from Inference import inference
from monai.losses import DiceLoss
from Datasets import EnsembleDataset
from monai.metrics import DiceMetric
from Models import LogisticRegression
from monai.transforms import AsDiscrete
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler   

### Config

In [2]:
seed = 33
pd.set_option("display.max_columns", None)
pd.set_option('display.max_colwidth', None)

## Load Train Data

In [3]:
subject_ids = pd.read_csv('../data/TRAIN.csv')['SubjectID'].values

ah_segs, segresnet_segs, unet_segs, untr_segs, gt_segs = [], [], [], [], []
for sid in subject_ids:
    ah_channels, unet_channels, segresnet_channels, untr_channels, gt_channels = [], [], [], [], []
    for channel in ['TC', 'WT', 'ET']:
        ah_channels.append(f'../outputs/AHNet/pred_segs/train_pred_segs/pred_{sid}_{channel}.npz')
        segresnet_channels.append(f'../outputs/SegResNet/pred_segs/train_pred_segs/pred_{sid}_{channel}.npz')
        unet_channels.append(f'../outputs/UNet/pred_segs/train_pred_segs/pred_{sid}_{channel}.npz')
        untr_channels.append(f'../outputs/UNETR/pred_segs/train_pred_segs/pred_{sid}_{channel}.npz')
        gt_channels.append(f'../outputs/gt_segs/train_gt_segs/gt_{sid}_{channel}.npz')
    
    ah_segs.append(ah_channels)
    segresnet_segs.append(segresnet_channels)
    unet_segs.append(unet_channels)
    untr_segs.append(untr_channels)
    gt_segs.append(gt_channels)

# Dataframe
train_df = pd.DataFrame()
train_df['SubjectID'] = subject_ids
train_df['AHNet'] = ah_segs
train_df['UNet'] = unet_segs
train_df['SegResNet'] = segresnet_segs
train_df['UNETR'] = untr_segs
train_df['GT'] = gt_segs

train_df.head()

Unnamed: 0,SubjectID,AHNet,UNet,SegResNet,UNETR,GT
0,100381A,"[../outputs/AHNet/pred_segs/train_pred_segs/pred_100381A_TC.npz, ../outputs/AHNet/pred_segs/train_pred_segs/pred_100381A_WT.npz, ../outputs/AHNet/pred_segs/train_pred_segs/pred_100381A_ET.npz]","[../outputs/UNet/pred_segs/train_pred_segs/pred_100381A_TC.npz, ../outputs/UNet/pred_segs/train_pred_segs/pred_100381A_WT.npz, ../outputs/UNet/pred_segs/train_pred_segs/pred_100381A_ET.npz]","[../outputs/SegResNet/pred_segs/train_pred_segs/pred_100381A_TC.npz, ../outputs/SegResNet/pred_segs/train_pred_segs/pred_100381A_WT.npz, ../outputs/SegResNet/pred_segs/train_pred_segs/pred_100381A_ET.npz]","[../outputs/UNETR/pred_segs/train_pred_segs/pred_100381A_TC.npz, ../outputs/UNETR/pred_segs/train_pred_segs/pred_100381A_WT.npz, ../outputs/UNETR/pred_segs/train_pred_segs/pred_100381A_ET.npz]","[../outputs/gt_segs/train_gt_segs/gt_100381A_TC.npz, ../outputs/gt_segs/train_gt_segs/gt_100381A_WT.npz, ../outputs/gt_segs/train_gt_segs/gt_100381A_ET.npz]"
1,100414B,"[../outputs/AHNet/pred_segs/train_pred_segs/pred_100414B_TC.npz, ../outputs/AHNet/pred_segs/train_pred_segs/pred_100414B_WT.npz, ../outputs/AHNet/pred_segs/train_pred_segs/pred_100414B_ET.npz]","[../outputs/UNet/pred_segs/train_pred_segs/pred_100414B_TC.npz, ../outputs/UNet/pred_segs/train_pred_segs/pred_100414B_WT.npz, ../outputs/UNet/pred_segs/train_pred_segs/pred_100414B_ET.npz]","[../outputs/SegResNet/pred_segs/train_pred_segs/pred_100414B_TC.npz, ../outputs/SegResNet/pred_segs/train_pred_segs/pred_100414B_WT.npz, ../outputs/SegResNet/pred_segs/train_pred_segs/pred_100414B_ET.npz]","[../outputs/UNETR/pred_segs/train_pred_segs/pred_100414B_TC.npz, ../outputs/UNETR/pred_segs/train_pred_segs/pred_100414B_WT.npz, ../outputs/UNETR/pred_segs/train_pred_segs/pred_100414B_ET.npz]","[../outputs/gt_segs/train_gt_segs/gt_100414B_TC.npz, ../outputs/gt_segs/train_gt_segs/gt_100414B_WT.npz, ../outputs/gt_segs/train_gt_segs/gt_100414B_ET.npz]"
2,100132B,"[../outputs/AHNet/pred_segs/train_pred_segs/pred_100132B_TC.npz, ../outputs/AHNet/pred_segs/train_pred_segs/pred_100132B_WT.npz, ../outputs/AHNet/pred_segs/train_pred_segs/pred_100132B_ET.npz]","[../outputs/UNet/pred_segs/train_pred_segs/pred_100132B_TC.npz, ../outputs/UNet/pred_segs/train_pred_segs/pred_100132B_WT.npz, ../outputs/UNet/pred_segs/train_pred_segs/pred_100132B_ET.npz]","[../outputs/SegResNet/pred_segs/train_pred_segs/pred_100132B_TC.npz, ../outputs/SegResNet/pred_segs/train_pred_segs/pred_100132B_WT.npz, ../outputs/SegResNet/pred_segs/train_pred_segs/pred_100132B_ET.npz]","[../outputs/UNETR/pred_segs/train_pred_segs/pred_100132B_TC.npz, ../outputs/UNETR/pred_segs/train_pred_segs/pred_100132B_WT.npz, ../outputs/UNETR/pred_segs/train_pred_segs/pred_100132B_ET.npz]","[../outputs/gt_segs/train_gt_segs/gt_100132B_TC.npz, ../outputs/gt_segs/train_gt_segs/gt_100132B_WT.npz, ../outputs/gt_segs/train_gt_segs/gt_100132B_ET.npz]"
3,100212A,"[../outputs/AHNet/pred_segs/train_pred_segs/pred_100212A_TC.npz, ../outputs/AHNet/pred_segs/train_pred_segs/pred_100212A_WT.npz, ../outputs/AHNet/pred_segs/train_pred_segs/pred_100212A_ET.npz]","[../outputs/UNet/pred_segs/train_pred_segs/pred_100212A_TC.npz, ../outputs/UNet/pred_segs/train_pred_segs/pred_100212A_WT.npz, ../outputs/UNet/pred_segs/train_pred_segs/pred_100212A_ET.npz]","[../outputs/SegResNet/pred_segs/train_pred_segs/pred_100212A_TC.npz, ../outputs/SegResNet/pred_segs/train_pred_segs/pred_100212A_WT.npz, ../outputs/SegResNet/pred_segs/train_pred_segs/pred_100212A_ET.npz]","[../outputs/UNETR/pred_segs/train_pred_segs/pred_100212A_TC.npz, ../outputs/UNETR/pred_segs/train_pred_segs/pred_100212A_WT.npz, ../outputs/UNETR/pred_segs/train_pred_segs/pred_100212A_ET.npz]","[../outputs/gt_segs/train_gt_segs/gt_100212A_TC.npz, ../outputs/gt_segs/train_gt_segs/gt_100212A_WT.npz, ../outputs/gt_segs/train_gt_segs/gt_100212A_ET.npz]"
4,100243B,"[../outputs/AHNet/pred_segs/train_pred_segs/pred_100243B_TC.npz, ../outputs/AHNet/pred_segs/train_pred_segs/pred_100243B_WT.npz, ../outputs/AHNet/pred_segs/train_pred_segs/pred_100243B_ET.npz]","[../outputs/UNet/pred_segs/train_pred_segs/pred_100243B_TC.npz, ../outputs/UNet/pred_segs/train_pred_segs/pred_100243B_WT.npz, ../outputs/UNet/pred_segs/train_pred_segs/pred_100243B_ET.npz]","[../outputs/SegResNet/pred_segs/train_pred_segs/pred_100243B_TC.npz, ../outputs/SegResNet/pred_segs/train_pred_segs/pred_100243B_WT.npz, ../outputs/SegResNet/pred_segs/train_pred_segs/pred_100243B_ET.npz]","[../outputs/UNETR/pred_segs/train_pred_segs/pred_100243B_TC.npz, ../outputs/UNETR/pred_segs/train_pred_segs/pred_100243B_WT.npz, ../outputs/UNETR/pred_segs/train_pred_segs/pred_100243B_ET.npz]","[../outputs/gt_segs/train_gt_segs/gt_100243B_TC.npz, ../outputs/gt_segs/train_gt_segs/gt_100243B_WT.npz, ../outputs/gt_segs/train_gt_segs/gt_100243B_ET.npz]"


## Load Val Data

In [4]:
subject_ids = pd.read_csv('../data/VAL.csv')['SubjectID'].values

ah_segs, unet_segs, segresnet_segs, untr_segs, gt_segs = [], [], [], [], []
for sid in subject_ids:
    ah_channels, unet_channels, segresnet_channels, untr_channels, gt_channels = [], [], [], [], []
    for channel in ['TC', 'WT', 'ET']:
        ah_channels.append(f'../outputs/AHNet/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz')
        unet_channels.append(f'../outputs/UNet/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz')
        segresnet_channels.append(f'../outputs/SegResNet/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz')
        untr_channels.append(f'../outputs/UNETR/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz')
        gt_channels.append(f'../outputs/gt_segs/val_gt_segs/gt_{sid}_{channel}.npz')
    
    ah_segs.append(ah_channels)
    unet_segs.append(unet_channels)
    segresnet_segs.append(segresnet_channels)
    untr_segs.append(untr_channels)
    gt_segs.append(gt_channels)

# Dataframe
val_df = pd.DataFrame()
val_df['SubjectID'] = subject_ids
val_df['AHNet'] = ah_segs
val_df['UNet'] = unet_segs
val_df['SegResNet'] = segresnet_segs
val_df['UNETR'] = untr_segs
val_df['GT'] = gt_segs

val_df.head()

Unnamed: 0,SubjectID,AHNet,UNet,SegResNet,UNETR,GT
0,100237A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100237A_TC.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100237A_WT.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100237A_ET.npz]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100237A_TC.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100237A_WT.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100237A_ET.npz]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100237A_TC.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100237A_WT.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100237A_ET.npz]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100237A_TC.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100237A_WT.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100237A_ET.npz]","[../outputs/gt_segs/val_gt_segs/gt_100237A_TC.npz, ../outputs/gt_segs/val_gt_segs/gt_100237A_WT.npz, ../outputs/gt_segs/val_gt_segs/gt_100237A_ET.npz]"
1,100219A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100219A_TC.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100219A_WT.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100219A_ET.npz]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100219A_TC.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100219A_WT.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100219A_ET.npz]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100219A_TC.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100219A_WT.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100219A_ET.npz]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100219A_TC.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100219A_WT.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100219A_ET.npz]","[../outputs/gt_segs/val_gt_segs/gt_100219A_TC.npz, ../outputs/gt_segs/val_gt_segs/gt_100219A_WT.npz, ../outputs/gt_segs/val_gt_segs/gt_100219A_ET.npz]"
2,100363A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100363A_TC.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100363A_WT.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100363A_ET.npz]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100363A_TC.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100363A_WT.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100363A_ET.npz]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100363A_TC.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100363A_WT.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100363A_ET.npz]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100363A_TC.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100363A_WT.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100363A_ET.npz]","[../outputs/gt_segs/val_gt_segs/gt_100363A_TC.npz, ../outputs/gt_segs/val_gt_segs/gt_100363A_WT.npz, ../outputs/gt_segs/val_gt_segs/gt_100363A_ET.npz]"
3,100354A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100354A_TC.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100354A_WT.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100354A_ET.npz]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100354A_TC.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100354A_WT.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100354A_ET.npz]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100354A_TC.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100354A_WT.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100354A_ET.npz]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100354A_TC.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100354A_WT.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100354A_ET.npz]","[../outputs/gt_segs/val_gt_segs/gt_100354A_TC.npz, ../outputs/gt_segs/val_gt_segs/gt_100354A_WT.npz, ../outputs/gt_segs/val_gt_segs/gt_100354A_ET.npz]"
4,100303A,"[../outputs/AHNet/pred_segs/val_pred_segs/pred_100303A_TC.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100303A_WT.npz, ../outputs/AHNet/pred_segs/val_pred_segs/pred_100303A_ET.npz]","[../outputs/UNet/pred_segs/val_pred_segs/pred_100303A_TC.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100303A_WT.npz, ../outputs/UNet/pred_segs/val_pred_segs/pred_100303A_ET.npz]","[../outputs/SegResNet/pred_segs/val_pred_segs/pred_100303A_TC.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100303A_WT.npz, ../outputs/SegResNet/pred_segs/val_pred_segs/pred_100303A_ET.npz]","[../outputs/UNETR/pred_segs/val_pred_segs/pred_100303A_TC.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100303A_WT.npz, ../outputs/UNETR/pred_segs/val_pred_segs/pred_100303A_ET.npz]","[../outputs/gt_segs/val_gt_segs/gt_100303A_TC.npz, ../outputs/gt_segs/val_gt_segs/gt_100303A_WT.npz, ../outputs/gt_segs/val_gt_segs/gt_100303A_ET.npz]"


## Datasets

In [5]:
# Transforms
transforms = Transforms(seed=33)
    
# Datasets
train_dataset = EnsembleDataset(train_df.drop(columns = ['SubjectID']), transform=transforms.train_ensemble((240,240,160)), size = None)
val_dataset = EnsembleDataset(val_df.drop(columns = ['SubjectID']), transform=transforms.val_ensemble(), size = None)

# Samplers
train_sampler = SequentialSampler(train_dataset)
val_sampler = SequentialSampler(val_dataset)

# Dataloaders
train_loader = DataLoader(train_dataset, batch_size = 1, shuffle = False, sampler = train_sampler)
val_loader = DataLoader(val_dataset, batch_size = 1, shuffle = False, sampler = val_sampler)

image, label, og_shape = train_dataset[0]
print(image.shape, label.shape, og_shape)

torch.Size([9216000, 12]) torch.Size([9216000, 3]) torch.Size([3, 240, 240, 160])


## Training & Validation

In [6]:
# Params
val_interval = 1
threshold=0.9
lr=0.001
wd=0.0001
max_epochs = 10
    
input_dim = 12
output_dim = 3
model = LogisticRegression(input_dim, output_dim)

# Post Transforms, Optimizer, Loss, Evaluation Metric
trans = AsDiscrete(threshold=threshold)
loss_function = DiceLoss(smooth_nr=1e-5, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
dice_metric = DiceMetric(include_background=True, reduction="mean")
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

In [7]:
# TRAIN
for epoch in range(max_epochs):

    model.train()

    print(f"Epoch {epoch+1}/{max_epochs}")
    print('TRAIN')

    epoch_loss = 0
    step = 0
    for i, batch in tqdm(enumerate(train_loader)):
        # Load
        step += 1
        images, targets, _ = batch   

        # Flatten
        images = [img.contiguous().view(3, -1).t() for img in images]
        images = torch.cat(images, dim=1)
        target = target.contiguous().view(3, -1).t()
        images, targets = images.squeeze(0), targets.squeeze(0)    

        # Logistic Regression
        outputs = model(images)

        # Loss
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    
    lr_scheduler.step()
    epoch_loss /= step
    print(f"Loss: {epoch_loss}")

    # VAL
    if (i+1) % val_interval == 0:
        model.eval()
        print('---------------------------------------')
        print('VAL')
        dice_values = []
        for batch in tqdm(val_loader):

            # Load
            images, target, og_shape = batch   
            images = [img.contiguous().view(3, -1).t() for img in images]
            images = torch.cat(images, dim=1)
            target = target.contiguous().view(3, -1).t()

            # Predict 
            outputs = inference(images, 12, model, VAL_AMP=False)
            img = trans(outputs).squeeze(0)
            target = target.squeeze(0) 

            # To OG Shape
            img = img.mT.reshape(1, og_shape[0], og_shape[1], og_shape[2], og_shape[3])
            target = target.mT.reshape(1, og_shape[0], og_shape[1], og_shape[2], og_shape[3])

            # Dice Metric
            dice_metric(y_pred=img, y=target)
            dice_score = dice_metric.aggregate()
            dice_values.append(dice_score.item())
            dice_metric.reset()

        # Results
        print('Dice Scores:', np.mean(dice_values))
        print('---------------------------------------')

    torch.save(model.state_dict(), f'../outputs/Ensemble/LogRegCheckpoints/LogisticRegression_{epoch+1}.pt')

Epoch 1/10
TRAIN


248it [13:43,  3.32s/it]


Loss: 0.9922427636000418
---------------------------------------
VAL


100%|██████████| 31/31 [01:39<00:00,  3.20s/it]


Dice Scores: 0.4762267150856074
---------------------------------------
Epoch 2/10
TRAIN


248it [14:12,  3.44s/it]


Loss: 0.9910916077994532
---------------------------------------
VAL


100%|██████████| 31/31 [01:50<00:00,  3.57s/it]


Dice Scores: 0.5583472551717874
---------------------------------------
Epoch 3/10
TRAIN


248it [13:48,  3.34s/it]


Loss: 0.9904430080806056
---------------------------------------
VAL


100%|██████████| 31/31 [01:50<00:00,  3.57s/it]


Dice Scores: 0.555258952922398
---------------------------------------
Epoch 4/10
TRAIN


248it [14:26,  3.49s/it]


Loss: 0.9899060622819008
---------------------------------------
VAL


100%|██████████| 31/31 [01:50<00:00,  3.56s/it]


Dice Scores: 0.5467652793252661
---------------------------------------
Epoch 5/10
TRAIN


248it [16:23,  3.97s/it]


Loss: 0.989539108449413
---------------------------------------
VAL


100%|██████████| 31/31 [02:30<00:00,  4.85s/it]


Dice Scores: 0.4964453305808767
---------------------------------------
Epoch 6/10
TRAIN


248it [16:14,  3.93s/it]


Loss: 0.9892987224363512
---------------------------------------
VAL


100%|██████████| 31/31 [02:09<00:00,  4.19s/it]


Dice Scores: 0.4934465692588879
---------------------------------------
Epoch 7/10
TRAIN


248it [15:50,  3.83s/it]


Loss: 0.9891094577408606
---------------------------------------
VAL


100%|██████████| 31/31 [02:08<00:00,  4.14s/it]


Dice Scores: 0.49063856136654654
---------------------------------------
Epoch 8/10
TRAIN


248it [15:48,  3.82s/it]


Loss: 0.9890163855687264
---------------------------------------
VAL


100%|██████████| 31/31 [02:07<00:00,  4.12s/it]


Dice Scores: 0.48800478848598655
---------------------------------------
Epoch 9/10
TRAIN


248it [16:00,  3.87s/it]


Loss: 0.9889847227642613
---------------------------------------
VAL


100%|██████████| 31/31 [02:06<00:00,  4.07s/it]


Dice Scores: 0.48570411478079134
---------------------------------------
Epoch 10/10
TRAIN


248it [15:42,  3.80s/it]


Loss: 0.9889660208455978
---------------------------------------
VAL


100%|██████████| 31/31 [02:04<00:00,  4.01s/it]

Dice Scores: 0.4850154718564403
---------------------------------------



