# Ensemble Techniques

## Environment Setup

### Imports

In [None]:
import os
os.sys.path.append("../utils")

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from Transforms import Transforms
from Inference import inference
from monai.losses import DiceLoss
from Datasets import EnsembleDataset
from monai.metrics import DiceMetric
from Models import LogisticRegression
from monai.transforms import AsDiscrete
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler   

### Config

In [None]:
# Seed constant defined for the notebook.
seed = 33

# Pandas display settings: None removes the limit, so DataFrames are shown
# with every column and without truncating long cell contents.
pd.options.display.max_columns = None
pd.options.display.max_colwidth = None

## Load Train Data

In [None]:
# Subject identifiers of the training split.
subject_ids = pd.read_csv('../data/TRAIN.csv')['SubjectID'].values

# 'Dice' column of each base model's training score file.
ahnet_dice = pd.read_csv('../outputs/AHNet/train_scores.csv')['Dice'].values
segresnet_dice = pd.read_csv('../outputs/SegResNet/train_scores.csv')['Dice'].values
unet_dice = pd.read_csv('../outputs/UNet/train_scores.csv')['Dice'].values
unetr_dice = pd.read_csv('../outputs/UNETR/train_scores.csv')['Dice'].values

# For every subject build one list of three .npz paths (TC, WT, ET) per base
# model and for the ground truth. Only path strings are created here; no
# segmentation file is opened.
channels = ('TC', 'WT', 'ET')
ah_segs = [
    [f'../outputs/AHNet/pred_segs/train_pred_segs/pred_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]
segresnet_segs = [
    [f'../outputs/SegResNet/pred_segs/train_pred_segs/pred_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]
unet_segs = [
    [f'../outputs/UNet/pred_segs/train_pred_segs/pred_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]
untr_segs = [
    [f'../outputs/UNETR/pred_segs/train_pred_segs/pred_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]
gt_segs = [
    [f'../outputs/gt_segs/train_gt_segs/gt_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]

# Dataframe
# Columns are attached one by one to an empty frame; the order below is the
# resulting column order (SubjectID first, GT last).
train_df = pd.DataFrame()
for column_name, column_values in (
    ('SubjectID', subject_ids),
    ('AHNet', ah_segs),
    ('UNet', unet_segs),
    ('SegResNet', segresnet_segs),
    ('UNETR', untr_segs),
    ('GT', gt_segs),
):
    train_df[column_name] = column_values

train_df.head()

## Load Val Data

In [None]:
# Subject identifiers of the validation split.
subject_ids = pd.read_csv('../data/VAL.csv')['SubjectID'].values

# For every subject build one list of three .npz paths (TC, WT, ET) per base
# model and for the ground truth. Only path strings are created here; no
# segmentation file is opened.
channels = ('TC', 'WT', 'ET')
ah_segs = [
    [f'../outputs/AHNet/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]
unet_segs = [
    [f'../outputs/UNet/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]
segresnet_segs = [
    [f'../outputs/SegResNet/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]
untr_segs = [
    [f'../outputs/UNETR/pred_segs/val_pred_segs/pred_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]
gt_segs = [
    [f'../outputs/gt_segs/val_gt_segs/gt_{sid}_{channel}.npz' for channel in channels]
    for sid in subject_ids
]

# Dataframe
# Columns are attached one by one to an empty frame; the order below is the
# resulting column order (SubjectID first, GT last).
val_df = pd.DataFrame()
for column_name, column_values in (
    ('SubjectID', subject_ids),
    ('AHNet', ah_segs),
    ('UNet', unet_segs),
    ('SegResNet', segresnet_segs),
    ('UNETR', untr_segs),
    ('GT', gt_segs),
):
    val_df[column_name] = column_values

val_df.head()

## Datasets

In [None]:
# Transforms
transforms = Transforms(seed=33)

# Datasets
# Both splits drop the SubjectID column and share the same dataset options.
# NOTE(review): include_unet=False presumably leaves the UNet predictions out
# of the ensemble inputs - confirm in Datasets.EnsembleDataset.
dataset_options = dict(size=None, include_unet=False)
train_dataset = EnsembleDataset(
    train_df.drop(columns=['SubjectID']),
    transform=transforms.train_ensemble((100, 100, 50)),
    **dataset_options,
)
val_dataset = EnsembleDataset(
    val_df.drop(columns=['SubjectID']),
    transform=transforms.val_ensemble(),
    **dataset_options,
)

# Samplers
# Sequential samplers visit the samples in dataset order for both splits.
train_sampler = SequentialSampler(train_dataset)
val_sampler = SequentialSampler(val_dataset)

# Dataloaders
# One sample per batch; ordering is delegated to the samplers above.
loader_options = dict(batch_size=1, shuffle=False)
train_loader = DataLoader(train_dataset, sampler=train_sampler, **loader_options)
val_loader = DataLoader(val_dataset, sampler=val_sampler, **loader_options)

# Sanity check: fetch the first training sample and show its shapes.
first_sample = train_dataset[0]
image, label, og_shape = first_sample
print(image.shape, label.shape, og_shape)

## Training & Validation

In [None]:
# Params
val_interval = 1        # validation frequency setting read by the training cell
threshold = 0.9         # cut-off handed to AsDiscrete below
lr, wd = 0.001, 0.0001  # Adam learning rate and weight decay
max_epochs = 10         # also used as T_max of the cosine schedule

# Model
# NOTE(review): 9 inputs / 3 outputs presumably correspond to 3 base models x
# 3 channels in and the 3 channels (TC, WT, ET) out - confirm against
# EnsembleDataset and Models.LogisticRegression.
input_dim = 9
output_dim = 3
model = LogisticRegression(input_dim, output_dim)

# Post Transforms, Optimizer, Loss, Evaluation Metric
# Post transform: discretises predictions using the threshold above.
trans = AsDiscrete(threshold=threshold)

# Loss: Dice loss configured with sigmoid=True and without one-hot conversion
# of the targets.
loss_function = DiceLoss(
    smooth_nr=1e-5,
    smooth_dr=1e-5,
    squared_pred=True,
    to_onehot_y=False,
    sigmoid=True,
)

# Metric: mean Dice, background channel included.
dice_metric = DiceMetric(include_background=True, reduction="mean")

# Optimizer and learning-rate schedule (stepped once per epoch by the
# training cell).
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

In [None]:
# TRAIN
# Trains `model` for `max_epochs` epochs over `train_loader`, validates on
# `val_loader` every `val_interval` epochs, and writes one checkpoint per epoch.

# torch.save does not create missing directories, so make sure the target
# folder exists before the first checkpoint is written.
checkpoint_dir = '../outputs/EnsembleNU/LogRegCheckpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(max_epochs):

    model.train()

    print(f"Epoch {epoch+1}/{max_epochs}")
    print('TRAIN')

    epoch_loss = 0
    step = 0
    # Iterating the loader directly (instead of enumerate(...)) lets tqdm read
    # its length and show a total.
    for batch in tqdm(train_loader):
        # Load
        step += 1
        images, targets, _ = batch

        # Clear gradients from the previous step; without this, backward()
        # keeps adding to the existing .grad buffers and every update would
        # use the sum of all gradients seen so far.
        optimizer.zero_grad()

        # Logistic Regression
        outputs = model(images)

        # Loss
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    lr_scheduler.step()
    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batches.
    epoch_loss /= max(step, 1)
    print(f"Loss: {epoch_loss}")

    # VAL
    # Gate on the epoch counter (the original tested the last batch index).
    if (epoch + 1) % val_interval == 0:
        model.eval()
        print('---------------------------------------')
        print('VAL')
        dice_values = []
        # No gradients are needed for evaluation; disabling autograd saves
        # memory and time.
        with torch.no_grad():
            for batch in tqdm(val_loader):

                # Load
                images, target, og_shape = batch

                # Predict
                outputs = inference(images, 9, model, VAL_AMP=False)
                img = trans(outputs).squeeze(0)
                target = target.squeeze(0)

                # To OG Shape
                # NOTE(review): assumes og_shape holds four sizes describing
                # the original volume layout - confirm against EnsembleDataset.
                img = img.mT.reshape(1, og_shape[0], og_shape[1], og_shape[2], og_shape[3])
                target = target.mT.reshape(1, og_shape[0], og_shape[1], og_shape[2], og_shape[3])

                # Dice Metric
                # Aggregate and reset per sample so each entry of dice_values
                # is the score of exactly one validation case.
                dice_metric(y_pred=img, y=target)
                dice_score = dice_metric.aggregate()
                dice_values.append(dice_score.item())
                dice_metric.reset()

        # Results
        print('Dice Scores:', np.mean(dice_values))
        print('---------------------------------------')

    # One checkpoint per epoch, numbered from 1.
    torch.save(model.state_dict(), f'{checkpoint_dir}/LogisticRegression_{epoch+1}.pt')