## Imports

In [None]:
# Imports - Inherent
import os, glob, datetime, math as m, logging, warnings
from pathlib import Path

# Imports - Base Modules
import torch, pandas as pd, numpy as np, matplotlib.pyplot as plt
from PIL import Image

from monai.metrics.meandice import DiceMetric
from monai.metrics import HausdorffDistanceMetric, SurfaceDistanceMetric
from monai.data.dataloader import DataLoader
from monai.data.utils import pad_list_data_collate

# Imports - Robert's Torch Manager
from torchmanager_monai import metrics, Manager as UntManager
from torchmanager_monai import Manager as UntManager

# Imports - Our Code
from stuff import data

# from SpatialHDMetric import SpatialHausdorffDistanceMetric
# from metrics.selective_input import simple_dice3D, grouped_metric


from monai.networks.nets.unet import UNet

# Backbone class that the wrapper below inherits from.
base_model = UNet


class model_dict(base_model):
    """UNet whose output is wrapped in a dict while in training mode.

    When ``self.training`` is true the network output is returned as
    ``{'out': y}``; otherwise the raw output is returned unchanged.
    """

    def forward(self, x):
        logits = super().forward(x)
        if not self.training:
            return logits
        return {'out': logits}

# Inference device: CUDA device 3 when it exists, otherwise the CPU, so the
# notebook still runs on machines with fewer GPUs (or none). Previously the
# hard-coded 'cuda:3' failed as soon as a tensor was moved to it.
_preferred_cuda_index = 3
if torch.cuda.is_available() and torch.cuda.device_count() > _preferred_cuda_index:
    device = torch.device(f'cuda:{_preferred_cuda_index}')
else:
    device = torch.device('cpu')

# Share tensors between DataLoader worker processes via the file system
# strategy (see the torch.multiprocessing docs on sharing strategies).
torch.multiprocessing.set_sharing_strategy('file_system')
# NOTE: silences *all* warnings, including metric warnings about empty masks.
warnings.filterwarnings("ignore")

# Output directory for metric arrays; None disables the optional np.save calls.
save_dir = None
# Sliding-window inferer (not referenced by the cells in this notebook section).
from monai.inferers.utils import sliding_window_inference

## Loading the trained Manager/Model

In [None]:
# Checkpoint to evaluate. Alternatives (swap into `ckpt_path` to use):
#   Deeper nnUNETwSD, Oct 4:  "experiments/Deeper_nnUNETwSD_Oct4_1.exp/best_dice.model"
#   Best nnUNETwSD-Sim+FX:    'experiments/Deeper_nnUNETwSD_Oct3_1.exp/best_dice.model'
#   Best UNET:                "experiments/UNET_onSortedInterpolatedData_Aug15.exp/best_dice.model"
# Active: best nnUNETwSD-Sim.
ckpt_path = 'experiments/Deep_nnUNetwSD_Best_NoSims.exp/best_dice.model'

# Restore the manager with its tensors mapped to the CPU; the test cell later
# passes `device` explicitly. Metrics are (re)registered in a following cell.
_load_location = torch.device('cpu')
manager = UntManager.from_checkpoint(ckpt_path, map_location=_load_location)
print(f'Running {ckpt_path}\nSaved Epoch: {manager.current_epoch}')


In [None]:
# Show the `notes` attribute stored on the loaded manager.
# (Scratch values kept from an earlier run: 0.645 0.59 0.64 0.64 0.627 0.621)
notes = manager.notes
print(notes)

In [None]:
# Metrics accumulated by `manager.test`. Every MONAI metric keeps per-case,
# per-class values (reduction="none") and skips the background class.
# `target="out"` is passed to the cumulative wrapper (presumably selecting the
# 'out' entry of the model output - confirm against torchmanager_monai).
_metric_kwargs = dict(include_background=False, reduction="none")

dice_fn = metrics.CumulativeIterationMetric(
    DiceMetric(get_not_nans=False, **_metric_kwargs), target="out")
hd_fn = metrics.CumulativeIterationMetric(
    HausdorffDistanceMetric(distance_metric='euclidean', percentile=95, **_metric_kwargs), target="out")
msd_fn = metrics.CumulativeIterationMetric(
    SurfaceDistanceMetric(distance_metric='euclidean', **_metric_kwargs), target="out")

# Disabled options (their imports are commented out at the top of the file):
#   'val_shd': SpatialHausdorffDistanceMetric(..., percentile=95, sampling=(1.5, 1.5, 1.5))
#   'val_WH_dice' / 'val_GC_dice' / 'val_GV_dice' / 'val_CA_dice':
#       grouped_metric(simple_dice3D, indicies=[...], target='out')
metric_fns = {
    'val_dice': dice_fn,
    'val_hd': hd_fn,
    'val_msd': msd_fn,
}
manager.metric_fns = metric_fns

## Selecting the data to be Tested

In [None]:
# Test cohort: ViewRay 1.5 mm volumes. Other sources that have been used here
# (each has its cohort.json inside `src`):
#   ViewRay resampled 1 mm: '/mnt/data/Summerfield/Data/ViewRay.data/ViewRay.Resampled_1mm'
#       (used with roi_size=(200, 200, 200), img_size=(144, 144, 144))
#   Case prediction:        '/mnt/data/Summerfield/Data/CasePrediction/to_model'
#   UW external test:       '/mnt/data/Summerfield/Data/UWVR_Cardiac_MRsimOnly-Processed/for_model'
src = '/mnt/data/Summerfield/Data/ViewRay.data/1.5mm_volumes.data'
data_json = os.path.join(src, 'cohort.json')
num_workers = 10
dataset_configuration = dict(
    src=src,
    data_json=data_json,
    fold=1,
    roi_size=(128, 128, 128),
    img_size=(96, 96, 96),
    cached=True,
    cache_num=(10, 10),
    num_samples=4,
    num_workers=num_workers,
    sim_only=True,
)

# Load the testing split; the first returned value is discarded and `paths`
# holds the per-case file entries used when reporting results.
_, test_ds, paths = data.load_ViewRay(
    testing_or_training='testing',
    logger=None,
    return_testing_paths=True,
    **dataset_configuration,
)
# One case per batch; pad_list_data_collate pads the items of a batch to a common size.
testing_dataset = DataLoader(
    test_ds,
    batch_size=1,
    collate_fn=pad_list_data_collate,
    num_workers=num_workers,
    pin_memory=True,
)

## Running the model and recording results

In [None]:
# Run inference once and keep the raw per-case predictions. Later cells read
# `predictions` (label-map extraction, NIfTI export), which raised NameError
# while this call was commented out.
predictions = manager.predict(testing_dataset, device=device, show_verbose=True)
# Accumulate the registered metrics (manager.metric_fns) over the test set.
manager.test(testing_dataset, device=device, show_verbose=True)

In [None]:
# Pair every prediction with its input batch and keep, on the CPU, the argmax
# label map (argmax over dim 1, presumably the class axis) and image channel 0.
preds, imgs = [], []
for pred, img in zip(predictions, testing_dataset):
    img = img['image'].detach().cpu()[:, 0]
    pred = pred.detach().cpu()
    label_map = torch.argmax(pred, dim=1)
    print(label_map.shape, img.shape)
    imgs.append(img)
    preds.append(label_map)


In [None]:
# Browse every slice (last axis) of test case `n`: the intensity-clipped image
# on the left, the same image with the predicted label map overlaid on the right.
n = 0
WL = 0.4  # display ceiling as a fraction of the volume's maximum intensity

img = np.asarray(imgs[n][0])
_max = img.max() * WL
img = np.minimum(img, _max)  # clip bright voxels to the ceiling
img_min, img_max = img.min(), img.max()

pred = np.asarray(preds[n][0])
mask = np.ma.masked_where(pred == 0, pred)  # hide label 0 in the overlay
mask_min, mask_max = mask.min(), mask.max()

# Iterate over the real slice count; the former hard-coded range(128) raised
# IndexError whenever a volume had fewer than 128 slices on the last axis.
for i in range(img.shape[-1]):
    fig, axs = plt.subplots(ncols=2, dpi=200)

    ax: plt.Axes = axs[0]
    ax.imshow(img[..., i], cmap='gray', vmin=img_min, vmax=img_max)
    ax.set_title(f'slice {i}')
    ax.axis('off')

    ax: plt.Axes = axs[1]
    ax.imshow(img[..., i], cmap='gray', vmin=img_min, vmax=img_max)
    ax.imshow(mask[..., i], cmap='Reds', vmin=mask_min, vmax=mask_max)
    ax.set_title('Prediction')
    ax.axis('off')

    plt.show()
    plt.close(fig)  # release the figure once shown



In [None]:
# Export the first test case (image and predicted label map) as .npy arrays.
# Take case 0 directly: stacking every case just to keep the first is wasted
# work and fails outright when cases have different shapes.
np_img = np.asarray(imgs[0])
np_preds = np.asarray(preds[0])

np.save('to_Carri-IMG.npy', arr=np_img, allow_pickle=True)
np.save('to_Carri-PRED.npy', arr=np_preds, allow_pickle=True)

# Earlier export names used with other runs:
#   'UWVR_Cardiac_FULL_nnUNETwSD_IMGS.npy' / 'UWVR_Cardiac_FULL_nnUNETwSD_PREDS.npy'
#   'UWVR_Cardiac_nnUNETwSD_Oct3_IMGS.npy' / 'UWVR_Cardiac_nnUNETwSD_Oct3_PREDS.npy'

In [None]:
print(*(arr.shape for arr in (np_img, np_preds)))  # image and prediction shapes

In [None]:
# Per-case, per-structure metric arrays (background is excluded by the metrics'
# include_background=False). squeeze(1) drops axis 1 - presumably a singleton
# per-iteration axis from batch_size=1; TODO confirm.
dice = manager.metric_fns['val_dice'].results.type(torch.float).squeeze(1).cpu().numpy()
# hd = manager.metric_fns['val_shd'].results.type(torch.float).squeeze(1).cpu().numpy()
hd = manager.metric_fns['val_hd'].results.type(torch.float).squeeze(1).cpu().numpy()
mda = manager.metric_fns['val_msd'].results.type(torch.float).squeeze(1).cpu().numpy()

print(dice.shape, type(dice))
print(hd.shape, mda.shape, dice.shape)

# NOTE(review): the active checkpoint is an nnUNETwSD model but the files are
# labelled 'UNETR_2' (and would overwrite any UNETR results) - confirm the label.
model_str = 'UNETR_2'
# np.save does not create missing directories.
os.makedirs('ViewRay_Predictions', exist_ok=True)
np.save(f'ViewRay_Predictions/{model_str}_dice_metrics.npy', dice, allow_pickle=True)
np.save(f'ViewRay_Predictions/{model_str}_HD_metrics.npy', hd, allow_pickle=True)
np.save(f'ViewRay_Predictions/{model_str}_MDA_metrics.npy', mda, allow_pickle=True)

In [None]:
# Gather channel 0 of the image and label of every test case (batch index 0),
# and drop the batch axis from each raw prediction.
imgs, labs = [], []
for batch in testing_dataset:
    img = batch['image'][0, 0]
    lab = batch['label'][0, 0]
    imgs.append(img)
    labs.append(lab)

preds = [p[0] for p in predictions]

In [None]:
# Turn each per-case list into one case-first NumPy array.
preds, imgs, labs = (torch.stack(seq, dim=0).cpu().numpy() for seq in (preds, imgs, labs))

In [None]:
# Shapes of the stacked predictions, images and labels.
for arr in (preds, imgs, labs):
    print(arr.shape)

In [None]:
# Save the stacked arrays; np.save appends the .npy extension.
n = 'nnUNET_WITH_SD'
for path, arr in ((f'{n}_imgs', imgs), (f'{n}_labs', labs), (f'{n}_preds', preds)):
    np.save(path, arr, allow_pickle=True)


## Presenting Results

#### Metrics - DSC and HD95 values

In [None]:
cout = print
# cout = results_logger.info

structures = ['RA','LA','RV','LV','AA','SVC','IVC','PA','PV','LMCA','LADA','RCA']
cout(f'DSC for ViewRay patients (n={len(testing_dataset)})')

# Per-case, per-structure DSC (background excluded by the metric). NaN entries
# are ignored everywhere (nanmean/nanstd), consistently with the HD95 summary;
# plain mean/std let a single NaN turn the totals into NaN.
dice = manager.metric_fns['val_dice'].results.type(torch.float).squeeze(1).cpu().numpy()
sub_std = np.nanstd(dice, axis=0)
p_dice = np.nanmean(dice, axis=1)  # mean DSC of each case
cout(f'Total average DSC without background = {np.nanmean(dice):0.4f} ± {np.nanstd(p_dice):0.4f}')
if save_dir:
    np.save(os.path.join(save_dir, 'DSC.metrics.npy'), dice)

subs = np.nanmean(dice, axis=0)
cout('Breakdown by substructure: (mean ± std)')
for name, sub_mean, sub_sd in zip(structures, subs, sub_std):
    cout(f'\tSub {name}: \t{sub_mean:0.4f} ± {sub_sd:0.4f}')

cout()

# Per-case group means: great chambers (RA-LV), great veins (AA-PV) and
# coronary arteries (LMCA-RCA), matching the groups used in the plots.
avg_GCs = np.nanmean(dice[:, :4], axis=1)
avg_GVs = np.nanmean(dice[:, 4:9], axis=1)
avg_CAs = np.nanmean(dice[:, 9:], axis=1)
for group, values in (('GCs', avg_GCs), ('GVs', avg_GVs), ('CAs', avg_CAs)):
    cout(f'\tgroup {group}:\t{np.nanmean(values):0.4f} ± {np.nanstd(values):0.4f}')

# From here on `dice` holds one mean DSC per case; `best`/`worst` (np.where
# index tuples) are used by the best/worst export cells.
dice = np.nanmean(dice, axis=1)

best = np.where(dice == np.nanmax(dice))
worst = np.where(dice == np.nanmin(dice))

cout(f'\nBest: {best[0]} -> {dice[best][0]:0.4f}')
cout(f'Worst: {worst[0]} -> {dice[worst][0]:0.4f}')

_files = paths
cout(f'\nAverage DSC for each test volume:')
for i, case_dice in enumerate(dice):
    _file = os.path.split(_files[i]['image'])[-1]
    cout(f'[{i:02d}]: {_file} -> {case_dice:0.4f}')

In [None]:
# Record the index and mean DSC of the best and worst test cases.
# open() fails if the output directory does not exist yet, so create it here.
os.makedirs('EXAMPLE_nnUNET', exist_ok=True)
with open('EXAMPLE_nnUNET/nnUNET_bestworst.txt', 'w', encoding='utf-8') as f:
    f.write(f'Best: {best[0][0]} -> {dice[best][0]:0.4f}\n')
    f.write(f'Worst: {worst[0][0]} -> {dice[worst][0]:0.4f}')

In [None]:
cout = print
# cout = results_logger.info

structures = ['RA','LA','RV','LV','AA','SVC','IVC','PA','PV','LMCA','LADA','RCA']
cout(f'HD95 for ViewRay patients (n={len(testing_dataset)})')

# Per-case, per-structure HD95 (background excluded). Read from 'val_hd': the
# 'val_shd' metric is commented out of metric_fns, so looking it up raised KeyError.
# NOTE(review): HausdorffDistanceMetric is built without a voxel spacing, so these
# values are presumably in voxels rather than mm - confirm. Structures missing
# from a case may also produce inf/NaN entries; only NaN is skipped below.
hd = manager.metric_fns['val_hd'].results.type(torch.float).squeeze(1).cpu().numpy()
sub_std = np.nanstd(hd, axis=0)

p_hd = np.nanmean(hd, axis=1)  # mean HD95 of each case
cout(f'Total average HD95 without background = {np.nanmean(hd):0.4f} ± {np.nanstd(p_hd):0.4f}')

subs = np.nanmean(hd, axis=0)
cout('Breakdown by substructure: (mean ± std)')
for name, sub_mean, sub_sd in zip(structures, subs, sub_std):
    cout(f'\tSub {name}: \t{sub_mean:0.4f} ± {sub_sd:0.4f}')

cout()

# Per-case group means: great chambers, great veins, coronary arteries.
avg_GCs = np.nanmean(hd[:, :4], axis=1)
avg_GVs = np.nanmean(hd[:, 4:9], axis=1)
avg_CAs = np.nanmean(hd[:, 9:], axis=1)
for group, values in (('GCs', avg_GCs), ('GVs', avg_GVs), ('CAs', avg_CAs)):
    cout(f'\tgroup {group}:\t{np.nanmean(values):0.4f} ± {np.nanstd(values):0.4f}')

# From here on `hd` holds one mean HD95 per case; lower is better.
hd = np.nanmean(hd, axis=1)

worst_hd = np.where(hd == np.nanmax(hd))
best_hd = np.where(hd == np.nanmin(hd))

cout(f'\nBest: {best_hd[0]} -> {hd[best_hd][0]:0.4f}')
cout(f'Worst: {worst_hd[0]} -> {hd[worst_hd][0]:0.4f}')

#### Metrics - DSC and HD95 graphs

In [None]:
dice = manager.metric_fns['val_dice'].results.type(torch.float).squeeze(1).cpu().numpy() # No background is calcualted


structures = ['Total','RA','LA','RV','LV','AA','SVC','IVC','PA','PV','LMCA','LADA','RCA']
averages = [np.nanmean(dice)] + list(np.nanmean(dice, axis=0))
stds = [np.nanstd(np.nanmean(dice, axis=1))] + list(np.nanstd(dice, axis=0))

pd.DataFrame({'Structures': structures, 'AVG': averages, 'STD': stds})


%matplotlib inline
fig, (ax, tab) = plt.subplots(dpi=150, figsize=(8, 4), ncols=2, width_ratios=[5,1])



x = [i for i in range(13)]
dice = manager.metric_fns['val_dice'].results.type(torch.float).squeeze(1).cpu().numpy()
p_dice = np.nanmean(dice, axis=1)

box_args = {
    'showmeans': True,
    'widths': 0.5,
    'showfliers': True,
    'patch_artist':True,
    'meanprops': {
        'marker': '.',
        'markerfacecolor': 'w',
        'markeredgecolor': 'k',
        'markersize': 10
        },
    'medianprops': {
        'color': 'k'
        },
    'flierprops': {
        'marker': '.',
        'markerfacecolor': 'k',
        }
    }
box_bkg = ax.boxplot(x=p_dice, positions=x[:1], **box_args)
box_GC = ax.boxplot(x=dice[:, 0:4], positions=x[1:5], **box_args)
box_GV = ax.boxplot(x=dice[:, 4:9], positions=x[5:10], **box_args)
box_CA = ax.boxplot(x=dice[:, 9:], positions=x[10:], **box_args)

for box, color in zip([box_bkg, box_GC, box_GV, box_CA], ['k', 'firebrick', 'royalblue', 'forestgreen']):
    for b in box['boxes']:
        b.set(edgecolor='k', linewidth=1)
        b.set(facecolor=color)

ax.plot([], [], marker='.', markerfacecolor='w', markeredgecolor='k', markersize=8, linestyle='none', label='Average')
ax.plot([], [], linestyle='-', linewidth=1, color='k', label='Median')


ax.legend(fontsize=10, ncols=1, loc = 'upper right', frameon=False, handlelength=1)

ax.plot([0.5, 0.5], [-0.5, 1], c='k', ls='--')
ax.plot([4.5, 4.5], [-0.5, 1], c='k', ls='--')
ax.plot([9.5, 9.5], [-0.5, 1], c='k', ls='--')

ax.set_xlim([-0.5, 12.5])
ax.set_ylim([0, 1])
# ax.set_ylim([0, 1.2])

structures = ['Total','RA','LA','RV','LV','AA','SVC','IVC','PA','PV','LMCA','LADA','RCA']
ax.set_xticklabels(structures, rotation=45)
ax.set_title(f'Dice Similarity Coefficient (DSC' + r'$\uparrow$' + f', n={dice.shape[0]}) [nnUNETwSD, no Fractions]', weight='bold')
# fig.text(0.5, 0.95, 'Dice Similarity Coefficient (DSC, n=12)', ha='center', va='center', weight='bold', fontsize=12)
ax.set_ylabel('DSC (AU)')
# ax.set_ylabel('Worse ' + r'$\leftarrow$' + ' DSC (AU) ' + r'$\rightarrow$' + ' Better')

text_args = {
    'ha': 'center',
    'va': 'center',
    'fontsize': 12,
    'weight': 'normal'
}
# ax.text(2.5, 0.075, 'Great\nChambers', color='firebrick', **text_args)
# ax.text(7, 0.075, 'Great\nVeins', color='royalblue', **text_args)
# ax.text(11, 0.075, 'Coronary\nArteries', color='forestgreen', **text_args)
ax.text(2.5, -0.25, 'Great\nChambers', color='firebrick', **text_args)
ax.text(7, -0.25, 'Great\nVeins', color='royalblue', **text_args)
ax.text(11, -0.25, 'Coronary\nArteries', color='forestgreen', **text_args)


tab.axis('off')
tab.set_yticks([])
tab.set_xticks([])

tab.set_ylim([-2.5, 11.5])
tab.set_xlim([-0.5, 3])
structures = ['Name'] + structures
averages = ['  AVG '] + averages
stds = [' STD '] + stds

k = 'k'
r = 'firebrick'
b = 'royalblue'
g = 'forestgreen'
colors = [k, k, r, r, r, r, b, b, b, b, b, g, g, g]
for i, (n, a, s, c) in enumerate(zip(structures, averages, stds, colors)):
    if i != 0:
        tab.text(0, 11-i, f'{n}:', ha='left', va='center', fontsize=10, color=c)
        tab.text(1.75, 11-i, f'{a:0.3f} ± {s:0.3f}', ha='left', va='center')
    else:
        # tab.text(0, 11-i, f'{n}', ha='left', va='center', fontsize=10)
        tab.text(0.5, 11-i, 'AVG ± STD (AU)', ha='left', va='center')
        # tab.text(1.75, 11-i, f'{a}   {s}', ha='left', va='center')
plt.subplots_adjust(wspace=0)
plt.show()

In [None]:
dice = manager.metric_fns['val_shd'].results.type(torch.float).squeeze(1).cpu().numpy() # No background is calcualted


structures = ['Total','RA','LA','RV','LV','AA','SVC','IVC','PA','PV','LMCA','LADA','RCA']
averages = [np.nanmean(dice)] + list(np.nanmean(dice, axis=0))
stds = [np.nanstd(np.nanmean(dice, axis=1))] + list(np.nanstd(dice, axis=0))

pd.DataFrame({'Structures': structures, 'AVG': averages, 'STD': stds})


%matplotlib inline
fig, (ax, tab) = plt.subplots(dpi=150, figsize=(8, 4), ncols=2, width_ratios=[5,1])



x = [i for i in range(13)]
dice = manager.metric_fns['val_shd'].results.type(torch.float).squeeze(1).cpu().numpy()
p_dice = np.nanmean(dice, axis=1)

box_args = {
    'showmeans': True,
    'widths': 0.5,
    'showfliers': True,
    'patch_artist':True,
    'meanprops': {
        'marker': '.',
        'markerfacecolor': 'w',
        'markeredgecolor': 'k',
        'markersize': 10
        },
    'medianprops': {
        'color': 'k'
        },
    'flierprops': {
        'marker': '.',
        'markerfacecolor': 'k',
        }
    }
box_bkg = ax.boxplot(x=p_dice, positions=x[:1], **box_args)
box_GC = ax.boxplot(x=dice[:, 0:4], positions=x[1:5], **box_args)
box_GV = ax.boxplot(x=dice[:, 4:9], positions=x[5:10], **box_args)

CA_boxes = []
for i in range(9, 12):
    ca = dice[:, i][~np.isnan(dice[:, i])]
    CA_boxes.append(ax.boxplot(x=ca, positions=[i+1], **box_args))

# CAs = dice[:, 9:]#[~np.isnan(dice[:,9:])]
# box_CA = ax.boxplot(x=CAs, positions=x[10:], **box_args)

for box, color in zip([box_bkg, box_GC, box_GV], ['k', 'firebrick', 'royalblue']):
    for b in box['boxes']:
        b.set(edgecolor='k', linewidth=1)
        b.set(facecolor=color)

box_color = 'forestgreen'
for box in CA_boxes:
    for b in box['boxes']:
        b.set(edgecolor='k', linewidth=1)
        b.set(facecolor=box_color)

ax.plot([], [], marker='.', markerfacecolor='w', markeredgecolor='k', markersize=8, linestyle='none', label='Average')
ax.plot([], [], linestyle='-', linewidth=1, color='k', label='Median')


ax.legend(fontsize=10, ncols=1, loc = 'upper right', frameon=False, handlelength=1)

_max_dice = np.nanmax(dice)
ax.plot([0.5, 0.5], [-1*_max_dice, 1.5*_max_dice], c='k', ls='--')
ax.plot([4.5, 4.5], [-1*_max_dice, 1.5*_max_dice], c='k', ls='--')
ax.plot([9.5, 9.5], [-1*_max_dice, 1.5*_max_dice], c='k', ls='--')

ax.set_xlim([-0.5, 12.5])
ax.set_ylim([0, 1.25*_max_dice])
# ax.set_ylim([0, 1.2])
# ax.set_yticks([5*i for i in range(7)])
print(int(_max_dice // 10))
ax.set_yticks([10*i for i in range(int(_max_dice // 10)+1)])
structures = ['Total','RA','LA','RV','LV','AA','SVC','IVC','PA','PV','LMCA','LADA','RCA']
ax.set_xticklabels(structures, rotation=45)
ax.set_title('95% Hausdorf Distance (HD95' + r'$\downarrow$' + f', n={dice.shape[0]}) [nnUNETwSD, no Fractions]', weight='bold')
# fig.text(0.5, 0.95, 'Dice Similarity Coefficient (DSC, n=12)', ha='center', va='center', weight='bold', fontsize=12)
ax.set_ylabel('HD95 (mm)')
# ax.set_ylabel('Worse ' + r'$\leftarrow$' + ' DSC (AU) ' + r'$\rightarrow$' + ' Better')

text_args = {
    'ha': 'center',
    'va': 'bottom',
    'fontsize': 12,
    'weight': 'normal'
}
# ax.text(2.5, -0.1*_max_dice, 'Great\nChambers', color='firebrick', **text_args)
# ax.text(7, -0.1*_max_dice, 'Great\nVeins', color='royalblue', **text_args)
# ax.text(11, -0.1*_max_dice, 'Coronary\nArteries', color='forestgreen', **text_args)
ax.text(2.5, -0.3*(1.25*_max_dice), 'Great\nChambers', color='firebrick', **text_args)
ax.text(7, -0.3*(1.25*_max_dice), 'Great\nVeins', color='royalblue', **text_args)
ax.text(11, -0.3*(1.25*_max_dice), 'Coronary\nArteries', color='forestgreen', **text_args)



tab.axis('off')
tab.set_yticks([])
tab.set_xticks([])

tab.set_ylim([-2.5, 11.5])
tab.set_xlim([-0.5, 3])
structures = ['Name'] + structures
averages = ['  AVG '] + averages
stds = [' STD '] + stds

k = 'k'
r = 'firebrick'
b = 'royalblue'
g = 'forestgreen'
colors = [k, k, r, r, r, r, b, b, b, b, b, g, g, g]
for i, (n, a, s, c) in enumerate(zip(structures, averages, stds, colors)):
    if i != 0:
        tab.text(0, 11-i, f'{n}:', ha='left', va='center', fontsize=10, color=c)
        tab.text(1.75, 11-i, f'{a:0.3f} ± {s:0.3f}', ha='left', va='center')
    else:
        # tab.text(0, 11-i, f'{n}', ha='left', va='center', fontsize=10)
        tab.text(0.5, 11-i, f'AVG ± STD (mm)', ha='left', va='center')
        # tab.text(1.75, 11-i, f'{a}   {s}', ha='left', va='center')
plt.subplots_adjust(wspace=0)
plt.show()

In [None]:
# Save the per-case, per-structure DSC matrix. Re-read it from the metric instead
# of using whatever the global `dice` holds at this point, which depends on the
# cells run before (other cells rebind `dice` to per-case means or other metrics).
dice_matrix = manager.metric_fns['val_dice'].results.type(torch.float).squeeze(1).cpu().numpy()
np.save('ViewRay_diceresults.npy', dice_matrix, allow_pickle=True)

## Images and Masks

In [None]:
# Collect channel 0 of the image and label of every test case (batch index 0).
# Note: the next cell repeats this collection and overwrites the result.
imgs, labs = [], []
for batch in testing_dataset:
    img, lab = (batch[key][0, 0] for key in ('image', 'label'))
    imgs.append(img)
    labs.append(lab)

In [None]:
# Re-collect image/label volumes (channel 0, batch index 0) for every test case
# and strip the batch axis from each raw prediction.
imgs, labs = [], []
for batch in testing_dataset:
    img = batch['image'][0, 0]
    imgs.append(img)
    lab = batch['label'][0, 0]
    labs.append(lab)

preds = [raw[0] for raw in predictions]

In [None]:
# Stack the per-case predictions into one tensor (case axis first) and show its shape.
preds = torch.stack([*preds])  # default dim=0
print(preds.shape)

In [None]:
import nibabel as nib

# Export the best case (highest mean DSC, `best` from the DSC summary) as NIfTI
# and show every slice (last axis) containing predicted labels:
# image | prediction overlay | ground-truth overlay.
# NOTE(review): affine=np.eye(4) drops voxel spacing/orientation - pass the real
# affine if these files are opened in other tools.
best_i = best[0][0]
img = np.asarray(imgs[best_i])
lab = np.asarray(labs[best_i])

os.makedirs('EXAMPLE_nnUNET', exist_ok=True)
nib.save(img=nib.nifti1.Nifti1Image(img, affine=np.eye(4)), filename='EXAMPLE_nnUNET/BestIMG.nii.gz')
# Ground truth -> *LAB, prediction -> *PRD (these two file names used to be swapped).
nib.save(img=nib.nifti1.Nifti1Image(lab, affine=np.eye(4)), filename='EXAMPLE_nnUNET/BestLAB.nii.gz')

print(img.dtype)

# Label map: argmax over dim 1 (presumably the class axis), batch item 0, float32.
pred = predictions[best_i].cpu().detach()
pred = np.float32(torch.argmax(pred, dim=1)[0].numpy())
print(pred.shape)
nib.save(img=nib.nifti1.Nifti1Image(pred, affine=np.eye(4)), filename='EXAMPLE_nnUNET/BestPRD.nii.gz')

# Hide background (label 0) in both overlays; both use vmax=pred.max() so a label
# gets the same colour in the prediction and ground-truth panels.
pred = np.ma.masked_where(pred == 0, pred)
lab = np.ma.masked_where(lab == 0, lab)
img_min, img_max = img.min(), img.max()
label_vmax = pred.max()

for i in range(pred.shape[-1]):
    if not np.nansum(pred[..., i]):
        continue  # no predicted labels on this slice
    fig, ax = plt.subplots(ncols=3, figsize=(18,6.5))
    ax[0].imshow(img[..., i], cmap='gray', vmin=img_min, vmax=img_max)
    ax[0].axis('off')
    ax[0].set_title(f'Image, slice {i}', fontsize=24)

    ax[1].imshow(img[..., i], cmap='gray', vmin=img_min, vmax=img_max)
    ax[1].imshow(pred[..., i], cmap='Reds', vmin=0, vmax=label_vmax, interpolation='none')
    ax[1].axis('off')
    ax[1].set_title(f'Prediction, slice {i}', fontsize=24)

    ax[2].imshow(img[..., i], cmap='gray', vmin=img_min, vmax=img_max)
    ax[2].imshow(lab[..., i], cmap='Reds', vmin=0, vmax=label_vmax, interpolation='none')
    ax[2].axis('off')
    ax[2].set_title(f'GT, slice {i}', fontsize=24)

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    # plt.savefig(os.path.join(best_preds, f'slice.{i:04d}.png'), dpi=150)
    plt.show()
    plt.close(fig)  # release the figure once shown

# Optional GIF of the saved slices (requires best_preds = os.path.join(save_dir, 'Prediction.Best')):
# if save_dir:
#     frames = [Image.open(file) for file in sorted(glob.glob(os.path.join(best_preds, '*.png')))]
#     frames[0].save(os.path.join(save_dir, 'Prediction.Best.gif'), format="GIF", append_images=frames, save_all=True, duration=300, loop=0)

In [None]:
import nibabel as nib

# Export the worst case (lowest mean DSC, `worst` from the DSC summary) as NIfTI
# and show every slice (last axis) containing predicted labels:
# image | prediction overlay | ground-truth overlay.
# NOTE(review): affine=np.eye(4) drops voxel spacing/orientation - pass the real
# affine if these files are opened in other tools.
worst_ = worst[0][0]
img = np.asarray(imgs[worst_])
lab = np.asarray(labs[worst_])

os.makedirs('EXAMPLE_nnUNET', exist_ok=True)
nib.save(img=nib.nifti1.Nifti1Image(img, affine=np.eye(4)), filename='EXAMPLE_nnUNET/worstIMG.nii.gz')
# Ground truth -> *LAB, prediction -> *PRD (these two file names used to be swapped).
nib.save(img=nib.nifti1.Nifti1Image(lab, affine=np.eye(4)), filename='EXAMPLE_nnUNET/worstLAB.nii.gz')

# Label map: argmax over dim 1 (presumably the class axis), batch item 0, float32.
pred = predictions[worst_].cpu().detach()
pred = np.float32(torch.argmax(pred, dim=1)[0].numpy())
nib.save(img=nib.nifti1.Nifti1Image(pred, affine=np.eye(4)), filename='EXAMPLE_nnUNET/worstPRD.nii.gz')

# Hide background (label 0) in both overlays; both use vmax=pred.max() so a label
# gets the same colour in the prediction and ground-truth panels.
pred = np.ma.masked_where(pred == 0, pred)
lab = np.ma.masked_where(lab == 0, lab)
img_min, img_max = img.min(), img.max()
label_vmax = pred.max()

for i in range(pred.shape[-1]):
    if not np.nansum(pred[..., i]):
        continue  # no predicted labels on this slice
    fig, ax = plt.subplots(ncols=3, figsize=(18,6.5))
    ax[0].imshow(img[..., i], cmap='gray', vmin=img_min, vmax=img_max)
    ax[0].axis('off')
    ax[0].set_title(f'Image, slice {i}', fontsize=24)

    ax[1].imshow(img[..., i], cmap='gray', vmin=img_min, vmax=img_max)
    ax[1].imshow(pred[..., i], cmap='Reds', vmin=0, vmax=label_vmax, interpolation='none')
    ax[1].axis('off')
    ax[1].set_title(f'Prediction, slice {i}', fontsize=24)

    ax[2].imshow(img[..., i], cmap='gray', vmin=img_min, vmax=img_max)
    ax[2].imshow(lab[..., i], cmap='Reds', vmin=0, vmax=label_vmax, interpolation='none')
    ax[2].axis('off')
    ax[2].set_title(f'GT, slice {i}', fontsize=24)

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    # plt.savefig(os.path.join(worst_preds, f'slice.{i:04d}.png'), dpi=150)
    plt.show()
    plt.close(fig)  # release the figure once shown

# Optional GIF of the saved slices (requires worst_preds = os.path.join(save_dir, 'Prediction.Worst')):
# if save_dir:
#     frames = [Image.open(file) for file in sorted(glob.glob(os.path.join(worst_preds, '*.png')))]
#     frames[0].save(os.path.join(save_dir, 'Prediction.Worst.gif'), format="GIF", append_images=frames, save_all=True, duration=300, loop=0)