In [3]:
import os, glob, datetime
from typing import Optional

# PyTorch
import torch
torch.multiprocessing.set_sharing_strategy('file_system')
from torch.backends import cudnn
cudnn.benchmark = True

# Monai
from monai.data.dataset import Dataset
from monai.metrics import DiceMetric, HausdorffDistanceMetric, SurfaceDistanceMetric, CumulativeIterationMetric
from monai.data.utils import pad_list_data_collate
from monai.data.dataloader import DataLoader
from monai.inferers.utils import sliding_window_inference
import monai.transforms as transform
from monai.transforms import AsDiscrete

# Other
import numpy as np

# Local
from MAGIC import grow_forest

# Inference device. Defaults to the second GPU; set MAGIC_DEVICE (e.g. "cuda:0" or "cpu")
# to change it without editing the notebook.
device = torch.device(os.environ.get('MAGIC_DEVICE', 'cuda:1'))

In [2]:
# Checkpoint folder of the trained MAGIC forest to evaluate.
src = "MAGIC_experiments/PaperModel/Best_Val_dice"

forest = grow_forest.load_forest(src, pickle_name_forTraining=None).to(device)

# Channel count of each decoder output group. The first channel of every group is the
# one dropped below (treated as that group's background).
ogrps = forest.config['decoder_module'][1]['out_groups']

# grps[k] lists the indices of group k's channels in the network's concatenated output.
grps = []
offset = 0
for ngrp in ogrps:
    grps.append(list(range(offset, offset + ngrp)))
    offset += ngrp
print(f'{grps=}')

# Map each group's non-first channels onto label-volume channel indices. The label volume
# appears to carry no per-group background channel, so group k's channels shift left by
# the k + 1 first-channels dropped so far (one per group up to and including k).
idx_grps = []
for k, members in enumerate(grps):
    idx_grps.append([ch - (k + 1) for ch in members[1:]])
print(f"{idx_grps=}")

group_idxs = [ch for members in idx_grps for ch in members]  # flattened label-channel indices
out_groups = [len(members) for members in grps]              # channels per group (first channel included)
num_classes_set = [sum(out_groups)]                           # total network output channels

Loading model info...
4 5 False
4 4
[2, 10, 9, 3]
[2, 10, 9, 3]
[2, 10, 9, 3]
[2, 10, 9, 3]
[2, 10, 9, 3]
Loading the models...
Done.
grps=[[0, 1], [2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20], [21, 22, 23]]
idx_grps=[[0], [1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11, 12, 13, 14, 15, 16, 17], [18, 19]]


In [4]:
# Root folder that holds the per-modality sub-folders (VR, simCT, CCTA) used below.
# Set MAGIC_DATA_ROOT to point at the data without editing the notebook.
master_src = os.environ.get("MAGIC_DATA_ROOT", "/path/to/data/location")

image_keys = ['image']
real_keys = image_keys + ['label']                   # keys read from disk by LoadImaged
all_keys = image_keys + [f'g{i}' for i in range(4)]  # image + one label map per output group (4 groups)

class split_groups:
    """Dictionary transform that collapses groups of one-hot label channels into integer label maps.

    For each group, the selected channels of ``d[target_key]`` are painted onto a zero
    background: voxels where the group's channel ``i`` equals 1 receive the value ``i + 1``.
    Channels are painted in ``stacking_order``, so where channels overlap the one painted
    last wins. Each map is stored under its ``out_name`` with a leading singleton channel axis.
    The input dict is shallow-copied, never mutated.

    Args:
        target_key: key of the channel-first label array/tensor in the input dict.
        group_idxs: for each output, the channel indices of ``d[target_key]`` forming that group.
        out_names: output key for each group; one entry per group.
        stacking_order: optional per-group painting order (positions within the group).
            Defaults to ``0 .. len(group) - 1`` for every group.

    Raises:
        ValueError: if ``out_names`` or ``stacking_order`` does not have exactly one entry per
            group (``zip`` would otherwise silently drop the surplus groups).
    """

    def __init__(self, target_key: str, group_idxs: list[list[int]], out_names: list[str], stacking_order: Optional[list[list[int]]] = None):
        if len(out_names) != len(group_idxs):
            raise ValueError(
                f"out_names has {len(out_names)} entries but group_idxs has {len(group_idxs)} groups"
            )
        if stacking_order is not None and len(stacking_order) != len(group_idxs):
            raise ValueError(
                f"stacking_order has {len(stacking_order)} entries but group_idxs has {len(group_idxs)} groups"
            )
        self.tk = target_key
        self.group_idxs = group_idxs
        self.out_names = out_names
        self.stacking_order = stacking_order if stacking_order is not None else [list(range(len(g))) for g in group_idxs]

    def __call__(self, d: dict):
        d = dict(d)  # shallow copy so the caller's dict is left untouched
        src_arr = d[self.tk]
        for gidx, gname, gorder in zip(self.group_idxs, self.out_names, self.stacking_order):
            if isinstance(src_arr, np.ndarray):
                base = np.zeros(src_arr.shape[1:])
            else:
                # Allocate on the source's device: a CPU buffer cannot be indexed with a CUDA mask.
                base = torch.zeros(src_arr.shape[1:], dtype=src_arr.dtype, device=src_arr.device)
            glabel = src_arr[gidx]
            for i in gorder:
                base[glabel[i] == 1] = i + 1
            d[gname] = base[None]
        return d
    
class channel_to_stacked_binary:
    """Dictionary transform producing a single-channel boolean mask from a channel-first label.

    The output ``d[out_key]`` has shape ``(1, *spatial)``. It is True wherever the sum over
    channels of ``d[target_key]`` is greater than 1. The input dict is shallow-copied.

    NOTE(review): the threshold is ``> 1`` (more than one channel active), not ``> 0``
    (any channel active). Confirm this is the intended "binary" label.
    """

    def __init__(self, target_key: str, out_key: str):
        self.tk, self.ok = target_key, out_key

    def __call__(self, d: dict):
        out = dict(d)
        channel_sum = out[self.tk].sum(0)
        out[self.ok] = channel_sum[None] > 1
        return out

preprocessing_transforms = [
    # Read the image and the multi-channel label volume from disk.
    transform.LoadImaged(keys=real_keys),
    # Collapse each output group's label channels into one integer map stored as g0, g1, ...
    split_groups(
        target_key='label',
        group_idxs=idx_grps,
        out_names=[f'g{i}' for i in range(len(grps))],
        stacking_order=None,
    ),
    channel_to_stacked_binary(target_key='label', out_key='blabel'),
    # Z-score normalisation over non-zero voxels: brings intensities to zero mean, which keeps
    # them consistent from patient to patient and helps training/inference.
    transform.NormalizeIntensityd(keys=image_keys, nonzero=True),
]

# ----------------------------------------------------------------
# Helpers shared by the per-modality test sets below
# ----------------------------------------------------------------
def _collect_image_label_pairs(root: str, pids: list[str], name_template: str) -> list[dict]:
    """Glob image/label NIfTI pairs for each patient id under ``root/*/``.

    ``name_template`` is formatted with ``pid`` and ``kind`` (``'IMAGE'`` or ``'LABEL'``),
    e.g. ``"{pid}.{kind}.nii.gz"``. Images and labels are paired by sorted path order and
    returned in pid order as ``{'image': ..., 'label': ...}`` dicts.

    Raises:
        ValueError: if a patient has a different number of images and labels. Pairing by
            position would otherwise mis-pair, drop files, or raise an opaque IndexError.
    """
    import warnings  # only needed here

    pairs = []
    for pid in pids:
        image_paths = sorted(glob.glob(os.path.join(root, '*', name_template.format(pid=pid, kind='IMAGE'))))
        label_paths = sorted(glob.glob(os.path.join(root, '*', name_template.format(pid=pid, kind='LABEL'))))
        if len(image_paths) != len(label_paths):
            raise ValueError(
                f"{pid}: found {len(image_paths)} image(s) but {len(label_paths)} label(s) under {root}"
            )
        if not image_paths:
            warnings.warn(f"{pid}: no image/label pair found under {root}; skipping")
        pairs.extend({'image': img, 'label': lbl} for img, lbl in zip(image_paths, label_paths))
    return pairs


def _make_test_loader(dataset: Dataset) -> DataLoader:
    """Wrap ``dataset`` in the evaluation DataLoader shared by all modalities (one case per batch, fixed order)."""
    return DataLoader(dataset, num_workers=1, batch_size=1, shuffle=False, collate_fn=pad_list_data_collate, pin_memory=True)


# ----------------------------------------------------------------
# For ViewRay
# ----------------------------------------------------------------
main_path = os.path.join(master_src, "VR")

HF_testing_pids = [1, 18, 29, 35, 36]
testing_pids = [f'HF_VR_{pid:02d}' for pid in HF_testing_pids]
UW_testing_pids = [1, 2, 3, 4, 5]  # NOTE(review): UW_VR_06..10 are not evaluated; confirm intentional
testing_pids += [f'UW_VR_{pid:02d}' for pid in UW_testing_pids]

testing_data = _collect_image_label_pairs(main_path, testing_pids, "{pid}_SIM.{kind}.nii.gz")
testing_dataset = Dataset(testing_data, transform.Compose(preprocessing_transforms))
VR_testing_dataloader = _make_test_loader(testing_dataset)

# ----------------------------------------------------------------
# For simCT
# ----------------------------------------------------------------
main_path = os.path.join(master_src, "simCT")

HF_testing_pids = [13, 22, 28, 29, 34]
testing_pids = [f'HF_simCT_{pid:02d}' for pid in HF_testing_pids]
UW_testing_pids = [16, 18, 22, 28, 32]
testing_pids += [f'UW_simCT_{pid:02d}' for pid in UW_testing_pids]

testing_data = _collect_image_label_pairs(main_path, testing_pids, "{pid}.{kind}.nii.gz")
testing_dataset = Dataset(testing_data, transform.Compose(preprocessing_transforms))
simCT_testing_dataloader = _make_test_loader(testing_dataset)

# ----------------------------------------------------------------
# For CCTA
# ----------------------------------------------------------------
main_path = os.path.join(master_src, "CCTA")

UW_testing_pids = [112, 115, 127, 131, 145, 151, 153, 158, 162, 163]
testing_pids = [f'UW_CCTA_{pid:03d}' for pid in UW_testing_pids]

# Built once, after all pids are collected. Previously this was inside the pid loop.
testing_data = _collect_image_label_pairs(main_path, testing_pids, "{pid}.{kind}.nii.gz")
testing_dataset = Dataset(testing_data, transform.Compose(preprocessing_transforms))
CCTA_testing_dataloader = _make_test_loader(testing_dataset)

testing_sets = [
    ('simCT', simCT_testing_dataloader),
    ('VR', VR_testing_dataloader),
    ('CCTA', CCTA_testing_dataloader),
]

In [5]:
# One AsDiscrete one-hot encoder per output group, keyed by group index and sized to that
# group's channel count.
label_OneHot_fns = dict(enumerate(AsDiscrete(to_onehot=len(members)) for members in grps))

# Per-structure metrics; reduction='none' keeps one value per case and channel.
dice_fn = DiceMetric(include_background=True, reduction='none', get_not_nans=False)
hd95 = HausdorffDistanceMetric(include_background=True, percentile=95, reduction='none')
hd = HausdorffDistanceMetric(include_background=True, reduction='none')  # NOTE(review): not part of metric_fns
msd = SurfaceDistanceMetric(include_background=True, reduction='none')

# Metrics that are evaluated and reported, by short name.
metric_fns: dict[str, CumulativeIterationMetric] = dict(dice=dice_fn, hd95=hd95, msd=msd)

In [6]:
class GroupSeperate:
    """Split a channel-first batch tensor into per-group channel slices.

    (The class name's spelling is kept as-is for compatibility with existing callers.)

    Args:
        idxs_groups: for each group, the channel indices (dimension 1) to select.
    """

    def __init__(self, idxs_groups: list[list[int]]):
        self.idxs_groups = idxs_groups

    def __call__(self, arr: torch.Tensor) -> list[torch.Tensor]:
        """Return ``[arr[:, g] for g in idxs_groups]``.

        Each slice keeps the batch dimension and a channel dimension (also for single-channel
        groups), because the indices are given as lists.
        """
        return [arr[:, igrp] for igrp in self.idxs_groups]

# Splits the network's concatenated output channels back into the decoder's output groups.
grp_seperator = GroupSeperate(idxs_groups=grps)

In [7]:
# Display names for the foreground channels, in the order produced by concatenating each
# group's one-hot channels after dropping the group's first (background) channel.
name_map = {0: 'WH', 1: 'RA', 2: 'LA', 3: 'RV', 4: 'LV', 5: 'AA', 6: 'SVC', 7: 'IVC', 8: 'PA', 9: 'PVs', 10: 'LMCA', 11: 'LADA', 12: 'RCA', 13: 'LCx', 14: 'V-AV', 15: 'V-PV', 16: 'V-TV', 17: 'V-MV', 18: 'N-SA', 19: 'N-AV'}
line_report = "{:>5} || {:0.2f} ± {:0.2f} || {:4.1f} ± {:4.1f} mm || {:4.1f} ± {:4.1f} mm"
top_line = f'{"Str":>5} || {"Dice":^11} || {"HD95":^14} || {"MSD":^14}'

# No spacing is passed to the distance metrics, so they return voxel units; scale to mm.
# NOTE(review): assumes isotropic 1.5 mm voxels in the preprocessed volumes — confirm.
VOXEL_SPACING_MM = 1.5
# NOTE(review): only the first 10 cases per modality are summarised (the test-set cells list
# 10 patient ids per modality). Any extra matched images are silently left out of the report.
MAX_REPORTED_CASES = 10

times = []  # wall-clock inference time per case, across all modalities
with torch.no_grad():
    for modality, testing_dataloader in testing_sets:
        net_key = f"{modality}_0"
        print(f'Testing {modality=}')
        for metric in metric_fns.values():
            metric.reset()  # metrics accumulate across calls; start fresh for each modality

        for ds in testing_dataloader:
            image = ds['image'].to(device)
            # Ground truth per group: one-hot encode the integer map, drop the first channel,
            # restore the batch axis. Kept on the CPU, where the metrics are computed anyway.
            master_labels: list[torch.Tensor] = [label_OneHot_fns[i](ds[f'g{i}'][0])[1:].unsqueeze(0) for i in range(len(grps))]

            if device.type == 'cuda':
                torch.cuda.synchronize(device)  # don't charge previously queued GPU work to this case
            start = datetime.datetime.now()
            full_prediction = sliding_window_inference(inputs=image, roi_size=forest.config['roi_size'], sw_batch_size=1, predictor=forest.trees[net_key], overlap=0.9)
            if device.type == 'cuda':
                # CUDA kernels run asynchronously; wait for them so the time covers the actual compute.
                torch.cuda.synchronize(device)
            end = datetime.datetime.now()
            times.append(end - start)

            grp_predictions = grp_seperator(full_prediction)

            # Prediction per group: argmax over the group's channels, one-hot, drop the first channel.
            _predictions = []
            for ii in range(len(grps)):
                _out = label_OneHot_fns[ii](torch.argmax(grp_predictions[ii], dim=1))[1:].unsqueeze(0)
                _predictions.append(_out)
            _predictions = torch.concatenate(_predictions, dim=1)
            _master_labels = torch.concatenate(master_labels, dim=1)

            for metric in metric_fns.values():
                metric(_predictions.cpu(), _master_labels)

        aggregates = {k: metric_fns[k].aggregate()[:MAX_REPORTED_CASES] for k in metric_fns.keys()}
        _hd95 = aggregates['hd95'].detach().cpu().numpy() * VOXEL_SPACING_MM
        _msd = aggregates['msd'].detach().cpu().numpy() * VOXEL_SPACING_MM
        _dice = aggregates['dice'].detach().cpu().numpy()
        print("{:^{}}".format(f'{modality} Inputs', len(top_line)))
        print(top_line)
        print('='*len(top_line))
        for i in range(len(name_map)):
            print(line_report.format(name_map[i], np.nanmean(_dice[:, i]), np.nanstd(_dice[:, i]), np.nanmean(_hd95[:, i]), np.nanstd(_hd95[:, i]), np.nanmean(_msd[:, i]), np.nanstd(_msd[:, i])))
        print()


Testing modality='simCT'
                      simCT Inputs                      
  Str ||    Dice     ||      HD95      ||      MSD      
   WH || 0.96 ± 0.01 ||  4.3 ±  1.4 mm ||  1.3 ±  0.4 mm
   RA || 0.87 ± 0.06 ||  5.7 ±  2.5 mm ||  1.8 ±  0.9 mm
   LA || 0.88 ± 0.03 ||  5.2 ±  1.4 mm ||  1.6 ±  0.4 mm
   RV || 0.88 ± 0.05 ||  7.0 ±  4.2 mm ||  1.5 ±  0.4 mm
   LV || 0.93 ± 0.02 ||  5.1 ±  2.4 mm ||  1.4 ±  0.4 mm
   AA || 0.88 ± 0.06 ||  4.8 ±  2.8 mm ||  1.4 ±  0.6 mm
  SVC || 0.80 ± 0.08 ||  4.8 ±  3.1 mm ||  1.6 ±  0.8 mm
  IVC || 0.75 ± 0.11 ||  7.3 ±  3.2 mm ||  2.0 ±  1.0 mm
   PA || 0.86 ± 0.04 ||  4.2 ±  1.9 mm ||  1.3 ±  0.5 mm
  PVs || 0.76 ± 0.04 ||  5.1 ±  1.4 mm ||  1.3 ±  0.4 mm
 LMCA || 0.64 ± 0.10 ||  3.7 ±  0.9 mm ||  1.4 ±  0.4 mm
 LADA || 0.59 ± 0.13 ||  8.5 ±  4.1 mm ||  1.8 ±  0.7 mm
  RCA || 0.57 ± 0.16 ||  6.4 ±  2.7 mm ||  1.7 ±  0.8 mm
  LCx || 0.58 ± 0.10 ||  5.2 ±  1.8 mm ||  1.7 ±  0.5 mm
 V-AV || 0.71 ± 0.16 ||  3.8 ±  1.8 mm ||  2.0 ±  1.2 mm
 V-PV 