### Importing packages

In [None]:
import os, glob
from typing import Optional

# PyTorch
import torch
torch.multiprocessing.set_sharing_strategy('file_system')
from torch.backends import cudnn
cudnn.benchmark = True

# Monai
import monai
from monai.data.dataset import Dataset
from monai.metrics import DiceMetric, HausdorffDistanceMetric, SurfaceDistanceMetric, CumulativeIterationMetric
from monai.data.utils import pad_list_data_collate
from monai.losses.dice import DiceCELoss
from monai.data.dataloader import DataLoader
from monai.inferers.utils import sliding_window_inference
import monai.transforms as transform
from monai.transforms import AsDiscrete

# Other
import numpy as np

# Local
from MAGIC import MAGIC_framework
from MagicianAssistant import argmax_with_multiOutput, split_groups_transform, channel_to_stacked_binary_transform

# Run on the first CUDA GPU when one is present; otherwise fall back to the CPU
# so the notebook can still be executed (slowly) on machines without a GPU
# instead of failing on the first `.to(device)` call.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### Loading the data sources / preprocessing for testing

In [None]:
# Modalities this notebook can evaluate, and the one selected for this run
mode_options = ["VR", "simCT", "CCTA"]
mode = "VR"

# Root directory of the data; each modality is read from its own sub-folder
master_src = "path/to/data/location"

# Dictionary keys used by the transforms
image_keys = ['image']
# Keys that are loaded from disk
real_keys = [*image_keys, 'label']
# 'g{i}' represent the sub-groups used in multi-task learning
all_keys = [*image_keys, *(f'g{n}' for n in range(4))]

# Label indices that make up each sub-group (passed to split_groups_transform)
_label_group_idxs = [
    [0],
    list(range(1, 10)),
    list(range(10, 18)),
    [18, 19],
]

# Pre-processing pipeline, applied in order to every sample
preprocessing_transforms = []
# 1) read the image and the label volumes
preprocessing_transforms.append(transform.LoadImaged(real_keys))
# 2) split the 'label' entry into the per-group entries 'g0'..'g3'
preprocessing_transforms.append(
    split_groups_transform(
        target_key='label',
        group_idxs=_label_group_idxs,
        out_names=['g0', 'g1', 'g2', 'g3'],
        stacking_order=None,
    )
)
# 3) derive 'blabel' from 'label'
# NOTE(review): presumably a stack of per-structure binary masks — confirm in MagicianAssistant
preprocessing_transforms.append(channel_to_stacked_binary_transform('label', 'blabel'))
# 4) z-score normalization of the image: helps consistency from patient to
#    patient and brings the mean to zero, which helps with deep learning
preprocessing_transforms.append(transform.NormalizeIntensityd(image_keys, nonzero=True))

# ----------------------------------------------------------------
# Helpers shared by every modality
# ----------------------------------------------------------------
def _collect_image_label_pairs(root, pids, stem_suffix=''):
    """Return one ``{'image': path, 'label': path}`` record per scan of each patient.

    Files are searched exactly one directory level below *root* and must be named
    ``<pid><stem_suffix>.IMAGE.nii.gz`` / ``<pid><stem_suffix>.LABEL.nii.gz``.
    Within a patient, images and labels are paired by their sorted order.

    Args:
        root: directory whose immediate sub-directories contain the NIfTI files.
        pids: patient identifiers (file-name stems) to look for.
        stem_suffix: text between the patient id and ``.IMAGE`` / ``.LABEL``.

    Returns:
        A list of dicts with the keys ``'image'`` and ``'label'``.

    Raises:
        ValueError: if a patient has a different number of image and label files,
            since pairing them by position would then be unreliable.
    """
    records = []
    for pid in pids:
        image_paths = sorted(glob.glob(os.path.join(root, '*', f"{pid}{stem_suffix}.IMAGE.nii.gz")))
        label_paths = sorted(glob.glob(os.path.join(root, '*', f"{pid}{stem_suffix}.LABEL.nii.gz")))
        if len(image_paths) != len(label_paths):
            raise ValueError(
                f'{pid}: found {len(image_paths)} image file(s) but '
                f'{len(label_paths)} label file(s) under {root}'
            )
        records.extend(
            {'image': image_path, 'label': label_path}
            for image_path, label_path in zip(image_paths, label_paths)
        )
    return records


def _make_testing_loader(records):
    """Build the (dataset, dataloader) pair used for testing from *records*.

    The dataset applies ``preprocessing_transforms``; the loader yields one
    un-shuffled sample per batch.
    """
    dataset = Dataset(records, transform.Compose(preprocessing_transforms))
    loader = DataLoader(
        dataset,
        num_workers=1,
        batch_size=1,
        shuffle=False,
        collate_fn=pad_list_data_collate,
        pin_memory=True,
    )
    return dataset, loader


# ----------------------------------------------------------------
# For ViewRay
# ----------------------------------------------------------------
if mode.upper() == "VR":
    main_path = os.path.join(master_src, "VR")

    HF_testing_pids = [1, 18, 29, 35, 36]
    testing_pids = [f'HF_VR_{pid:02d}' for pid in HF_testing_pids]
    UW_testing_pids = [1, 2, 3, 4, 5] #, 6, 7, 8, 9, 10]
    testing_pids += [f'UW_VR_{pid:02d}' for pid in UW_testing_pids]

    # ViewRay files carry an extra "_SIM" between the patient id and the extension
    testing_data = _collect_image_label_pairs(main_path, testing_pids, stem_suffix='_SIM')
    testing_dataset, VR_testing_dataloader = _make_testing_loader(testing_data)

    testing_sets = [
        ('VR', VR_testing_dataloader)
    ]

# ----------------------------------------------------------------
# For simCT
# ----------------------------------------------------------------
elif mode.upper() == 'SIMCT':
    main_path = os.path.join(master_src, "simCT")

    HF_testing_pids = [13, 22, 28, 29, 34]
    testing_pids = [f'HF_simCT_{pid:02d}' for pid in HF_testing_pids]
    UW_testing_pids = [16, 18, 22, 28, 32]
    testing_pids += [f'UW_simCT_{pid:02d}' for pid in UW_testing_pids]

    testing_data = _collect_image_label_pairs(main_path, testing_pids)
    testing_dataset, simCT_testing_dataloader = _make_testing_loader(testing_data)

    testing_sets = [
        ('simCT', simCT_testing_dataloader)
    ]

# ----------------------------------------------------------------
# For CCTA
# ----------------------------------------------------------------
elif mode.upper() == "CCTA":
    main_path = os.path.join(master_src, "CCTA")

    UW_testing_pids = [3] # PLACEHOLDER FOR REAL TESTING SET
    testing_pids = [f'UW_CCTA_{pid:03d}' for pid in UW_testing_pids]

    # The dataset / dataloader are built once, after all patients were collected
    # (they used to be rebuilt inside the per-patient loop).
    testing_data = _collect_image_label_pairs(main_path, testing_pids)
    testing_dataset, CCTA_testing_dataloader = _make_testing_loader(testing_data)

    testing_sets = [
        ('CCTA', CCTA_testing_dataloader)
    ]

else:
    raise ValueError(f'{mode=} not found, arguments include: {mode_options}')

# ----------------------------------------------------------------
print(f'Running {mode=}')

### Setting up basic post-processing for the prediction output

In [None]:
# Channel indices of the combined output, listed per group.
# Note: each group has its own background channel, so the groups together
# span 24 channels.
grps = [
    list(range(0, 2)),
    list(range(2, 12)),
    list(range(12, 21)),
    list(range(21, 24)),
]

# One one-hot converter per group, keyed by the group index; the number of
# classes of each converter equals the number of channels of its group.
label_OneHot_fns = dict(enumerate(AsDiscrete(to_onehot=len(channels)) for channels in grps))

# Metrics used to compare predictions against labels. All of them keep every
# channel and are left un-reduced (reduction='none').
_metric_kwargs = dict(include_background=True, reduction='none')
dice_fn = DiceMetric(get_not_nans=False, **_metric_kwargs)
hd95 = HausdorffDistanceMetric(percentile=95, **_metric_kwargs)
msd = SurfaceDistanceMetric(**_metric_kwargs)
metric_fns: dict[str, CumulativeIterationMetric] = dict(dice=dice_fn, hd95=hd95, msd=msd)

### Selecting the model to use

In [None]:
# Path to the model
model_path = 'Experiments/MAGIC_ForWorking/Best_Val_dice'

# Loading the model & send to the selected device
magic = MAGIC_framework.load_magic(model_path).to(device)

# Switch to evaluation mode so that train-time layers (dropout, batch-norm
# statistics updates, ...) do not perturb the test predictions.
# NOTE(review): assumes the loaded framework is a torch.nn.Module (it already
# exposes `.to`) — confirm against MAGIC_framework.
magic.eval()

### Running the model on a dataset

In [None]:
# Per-case records of the inputs, the predictions and the labels (NumPy arrays)
images, predictions, labels = [], [], []

# Number of output channels that belong to each group
group_sizes = [len(x) for x in grps]

with torch.no_grad():
    # Set up in a way that matches the training loop
    for modality, testing_dataloader in testing_sets:
        print(f'Looking at {modality} inputs')

        # Start every testing set with empty metric buffers
        for metric_fn in metric_fns.values():
            metric_fn.reset()

        # Testing loop
        for i, ds in enumerate(testing_dataloader):
            print(f'  Working on {i=}')

            # Pull the image from the dataloader
            image = ds['image'].to(device)

            # The label is stored as N different sub-groups for the multi-task learning.
            # One-hot encode each sub-group, drop its first (background) channel, add a
            # batch axis and join all sub-groups along the channel axis.
            per_group_labels = []
            for g in range(len(grps)):
                one_hot = label_OneHot_fns[g](ds[f'g{g}'][0])
                per_group_labels.append(one_hot[1:].unsqueeze(0))
            master_label: torch.Tensor = torch.concatenate(per_group_labels, dim=1).to(device)

            # Run a sliding window inference with the modality-specific model
            logits = sliding_window_inference(
                inputs=image,
                roi_size=magic.config['roi_size'],
                sw_batch_size=2,
                predictor=magic.tricks[modality],
                overlap=0.9,
            )

            # Perform an argmax on each output group within the combined prediction,
            # remove the background from each group and stack back into a combined prediction
            full_prediction = argmax_with_multiOutput(logits[0], group_sizes)

            # Record the images / labels / predictions if needed
            images.append(image.cpu().numpy()[0, 0])
            labels.append(master_label.cpu().numpy())
            predictions.append(full_prediction.cpu().numpy())

            # Accumulate the metrics; the prediction gets a batch axis to match the label
            batched_prediction = full_prediction.unsqueeze(0)
            for metric_fn in metric_fns.values():
                metric_fn(batched_prediction, master_label)

### Model Performance

In [None]:
# Collect the per-case, per-structure values accumulated by each metric.
# NOTE(review): only the first 10 rows of every aggregated metric are kept —
# confirm this cap is intended when a testing set holds more than 10 cases.
aggregates = {k: metric_fn.aggregate()[:10] for k, metric_fn in metric_fns.items()}

# NOTE(review): the 1.5 factor presumably converts voxel distances to mm
# (1.5 mm voxel spacing) — confirm against the data.
_hd95 = aggregates['hd95'].detach().cpu().numpy() * 1.5
_msd = aggregates['msd'].detach().cpu().numpy() * 1.5
_dice = aggregates['dice'].detach().cpu().numpy()

# Column index of the metric arrays -> structure name
name_map = {0: 'WH', 1: 'RA', 2: 'LA', 3: 'RV', 4: 'LV', 5: 'AA', 6: 'SVC', 7: 'IVC', 8: 'PA', 9: 'PVs', 10: 'LMCA', 11: 'LADA', 12: 'RCA', 13: 'LCx', 14: 'V-AV', 15: 'V-PV', 16: 'V-TV', 17: 'V-MV', 18: 'N-SA', 19: 'N-AV'}

line_report = "{:>5} || {:0.3f} ± {:0.3f} || {:6.3f} ± {:6.3f} mm || {:6.3f} ± {:6.3f} mm"
top_line = f'{"Str":>5} || {"Dice":^13} || {"HD95":^18} || {"MSD":^18}'

# Title the table with the testing set that was actually evaluated: `modality`
# is the name of the last set processed by the evaluation loop, and the metric
# buffers are reset at the start of every set, so they hold that set's results.
print("{:^{}}".format(f'{modality} Inputs', len(top_line)))
print(top_line)
print('='*len(top_line))

# One row per structure: mean ± std over the cases, ignoring NaN entries
for idx, name in name_map.items():
    print(line_report.format(
        name,
        np.nanmean(_dice[:, idx]), np.nanstd(_dice[:, idx]),
        np.nanmean(_hd95[:, idx]), np.nanstd(_hd95[:, idx]),
        np.nanmean(_msd[:, idx]), np.nanstd(_msd[:, idx]),
    ))