In [None]:
# Re-import edited project modules (src.*, tavi_predictor, ...) before each cell runs,
# so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
%matplotlib inline
# The relative paths used below ('checkpoints/...', 'images/...', './tmp') presumably assume the
# working directory is the project root, which this helper is expected to set — confirm.
from ProjectRoot import change_wd_to_project_root
change_wd_to_project_root()

In [None]:
import shutil
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import SimpleITK as sitk
import torch
from tqdm.notebook import tqdm
from monai.inferers import sliding_window_inference
from src.utils import pts_from_mps, mps_from_pts
from src.models.multihead_swinunetr import MultiHeadSwinUNETR
from src.models.nnunet import nnunet_configuration, MultiHeadnnUNet
from tavi_predictor import TAVIPredictor

In [None]:
import os

# Root of the evaluation data: images/, Totalsegmentator/, Human/<site>/ and the per-run
# prediction folders written by the evaluation loop. Set $FLOTO_IMAGECAS_DIR to use another location.
src = Path(os.environ.get('FLOTO_IMAGECAS_DIR', '/mnt/hdd/data/FLOTO_ImageCas_Subset'))

In [None]:
# Checkpoint registry: training setup -> {architecture -> checkpoint path}.
# 'Federated' and 'KD' are multi-site setups; every other key is presumably a model trained on
# that single site's data. Munich and Goettingen have checkpoints here but no human reads in `locs`.
# NOTE(review): some single-site UNet files lack the '...seg2' suffix their SWIN-UNETR siblings
# carry (e.g. Hamburg) — confirm these are the intended checkpoints.
EXPERIMENTS = {
    setup: {'UNet': unet_ckpt, 'SWIN-UNETR': swin_ckpt}
    for setup, unet_ckpt, swin_ckpt in (
        # (setup, UNet checkpoint, SWIN-UNETR checkpoint)
        ('Federated',
         './checkpoints/federated_hps_conditioned_on_heart.pt',
         'checkpoints/hinge_points/swin_unetr_heidelberg_muenster_goettingen_conditioned_on_heart_seg.pt'),
        ('KD',
         'checkpoints/kd/nnunet_conditioned_on_heart_seg_heidelberg.pt',
         'checkpoints/kd/swin_unetr_conditioned_on_heart_seg_heidelberg_munich_goettingen_hps_finetuned.pt'),
        ('Heidelberg',
         'checkpoints/hinge_points/nnunet_heidelberg_conditioned_on_heart_seg.pt',
         'checkpoints/hinge_points/swin_unetr_heidelberg_conditioned_on_heart_seg.pt'),
        ('Muenster',
         'checkpoints/hinge_points/nnunet_muenster_conditioned_on_heart_seg2.pt',
         'checkpoints/hinge_points/swin_unetr_muenster_conditioned_on_heart_seg2.pt'),
        ('Munich',
         'checkpoints/hinge_points/nnunet_munich_conditioned_on_heart_seg2.pt',
         'checkpoints/hinge_points/swin_unetr_munich_conditioned_on_heart_seg2.pt'),
        ('Goettingen',
         'checkpoints/hinge_points/nnunet_goettingen_conditioned_on_heart_seg2.pt',
         'checkpoints/hinge_points/swin_unetr_goettingen_conditioned_on_heart_seg2.pt'),
        ('Hamburg',
         'checkpoints/hinge_points/nnunet_hamburg_conditioned_on_heart_seg.pt',
         'checkpoints/hinge_points/swin_unetr_hamburg_conditioned_on_heart_seg2.pt'),
        ('Frankfurt',
         'checkpoints/hinge_points/nnunet_frankfurt_conditioned_on_heart_seg.pt',
         'checkpoints/hinge_points/swin_unetr_frankfurt_conditioned_on_heart_seg2.pt'),
        ('Greifswald',
         'checkpoints/hinge_points/nnunet_greifswald_conditioned_on_heart_seg.pt',
         'checkpoints/hinge_points/swin_unetr_greifswald_conditioned_on_heart_seg2.pt'),
        ('Berlin',
         'checkpoints/hinge_points/nnunet_berlin_conditioned_on_heart_seg.pt',
         'checkpoints/hinge_points/swin_unetr_berlin_conditioned_on_heart_seg2.pt'),
    )
}

In [None]:
def _unwrap_logits(logits):
    """Reduce a model's raw output to the hinge-point logits tensor.

    Handles, in this order: a tuple (take element 0), a dict of heads (take 'hps'), and a
    list (take element 0, presumably the full-resolution deep-supervision output — confirm).
    The order matters for nested outputs, e.g. ``({'hps': [full_res, ...]}, ...)``.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    if isinstance(logits, dict):
        logits = logits['hps']
    if isinstance(logits, list):
        logits = logits[0]
    return logits


@torch.no_grad()
def eval_model(
        model,
        dst,
        condition_on_heart_seg=True,
        sliding_window=False,
        device='cuda'
):
    """Predict hinge points for every image in ``src / 'images'`` and write one .mps per subject.

    For each subject, the TotalSegmentator heart ROI and heart segmentation are read from
    ``src / 'Totalsegmentator'``, converted to a model input by ``TAVIPredictor.transform``,
    run through ``model``, arg-maxed over dim 1 and converted to points by ``pts_from_pred``.

    Args:
        model: torch module; moved to ``device`` and switched to eval mode.
        dst: output directory (``Path``); created if missing. Receives ``<subject>.mps``.
        condition_on_heart_seg: if False, only the first input channel (``x[:, :1]``) is fed
            to the model, i.e. the heart segmentation channel is dropped (assumes channel dim 1).
        sliding_window: run MONAI sliding-window inference with 96^3 windows, 4 windows per
            batch, under autocast (on CUDA) instead of a single forward pass.
        device: torch device string for the model and its inputs.

    Subjects for which no hinge points are found get no .mps file.
    """
    dst.mkdir(exist_ok=True)
    model.to(device)
    # Inference must not run with training-mode dropout / normalisation behaviour.
    model.eval()
    device_type = torch.device(device).type

    tp = TAVIPredictor(fname='.', tmp_dir='./tmp',)
    # Sorted for a reproducible processing order (and a known tqdm total).
    for img_path in tqdm(sorted((src / 'images').iterdir())):
        subject = img_path.name.replace('.nii.gz', '')
        heart_roi = sitk.ReadImage(src / 'Totalsegmentator' / f'{subject}_heart_roi.nii.gz')
        heart_seg = sitk.ReadImage(src / 'Totalsegmentator' / f'{subject}_heart_seg.nii.gz')

        x_torch, heart_roi_res, _ = tp.transform(heart_roi, heart_seg)
        if not condition_on_heart_seg:
            x_torch = x_torch[:, :1]
        # Move the input once for both branches: the sliding-window path must also see the
        # input on the same device as the model.
        x_torch = x_torch.to(device)

        if sliding_window:
            with torch.autocast(device_type=device_type, enabled=device_type == 'cuda'):
                logits = sliding_window_inference(x_torch, (96, 96, 96), 4, model)
        else:
            logits = model(x_torch)

        pred = torch.argmax(_unwrap_logits(logits), dim=1)[0].cpu()
        _, hps = tp.pts_from_pred(pred, heart_roi_res)
        if len(hps) == 0:
            continue

        mps_from_pts(hps, dst / f'{subject}.mps')

In [None]:
# Prefixes of auxiliary output heads that some checkpoints contain but the hinge-point-only
# models built here do not; their weights are dropped before `load_state_dict`.
AUX_HEAD_PREFIXES = ('outs.ms', 'outs.calc', 'outs.heart', 'seg_layers.ms', 'seg_layers.heart', 'seg_layers.calc')
# Non-weight metadata entries stored alongside the weights in some checkpoints.
CKPT_META_KEYS = ('fed_task', 'patches', 'deep_supervision', 'condition_on_seg', 'output_seg')


def _load_checkpoint(ckpt_path):
    """Load a checkpoint onto the CPU and return its state dict, unwrapping a 'model' entry."""
    ckpt = torch.load(ckpt_path, map_location='cpu')
    return ckpt['model'] if 'model' in ckpt else ckpt


def _build_model(model_name, ckpt):
    """Instantiate the architecture for `model_name`; for UNet the checkpoint layout picks the class."""
    if model_name == 'UNet':
        # Multi-head nnU-Net checkpoints carry a dedicated 'seg_layers.hps' head.
        if 'seg_layers.hps.0.weight' in ckpt:
            return MultiHeadnnUNet(num_input_channels=2, num_segmentation_heads={'hps': 6})
        return nnunet_configuration(num_segmentation_heads=6, num_input_channels=2)
    if model_name == 'SWIN-UNETR':
        return MultiHeadSwinUNETR(
            out_channels={'hps': 6},
            img_size=(96, 96, 96),
            in_channels=2,
            intermediate_out_channels=1,
            ckpt_path=None
        )
    # Fail loudly instead of evaluating a stale `model` left over from an earlier iteration.
    raise ValueError(f'Unknown model name {model_name!r}; expected UNet or SWIN-UNETR')


def _clean_state_dict(ckpt, model_name):
    """Return `ckpt` without the 'unet.' key prefix (UNet only), auxiliary heads and metadata."""
    if model_name == 'UNet':
        ckpt = {k.replace('unet.', ''): v for k, v in ckpt.items()}
    return {
        k: v for k, v in ckpt.items()
        if not k.startswith(AUX_HEAD_PREFIXES) and k not in CKPT_META_KEYS
    }


for loc, ckpt_paths in EXPERIMENTS.items():
    for model_name, ckpt_path in ckpt_paths.items():
        if not ckpt_path:
            continue
        dst = src / f'{loc}_{model_name}'
        # An existing output directory counts as a finished run.
        # NOTE(review): a crashed run also leaves `dst` behind; delete it to re-run.
        if dst.exists():
            continue
        # The UNet class is chosen from the raw keys, before prefixes are stripped.
        ckpt = _load_checkpoint(ckpt_path)
        model = _build_model(model_name, ckpt)
        print(loc, model_name)
        model.load_state_dict(_clean_state_dict(ckpt, model_name))

        eval_model(
            model=model,
            dst=dst,
            condition_on_heart_seg=True,
            sliding_window=model_name == 'SWIN-UNETR',
            device='cuda'
        )
        # Free GPU memory before the next checkpoint is loaded.
        del model, ckpt
        torch.cuda.empty_cache()

In [None]:
# Sites whose human reads (src / 'Human' / <site>) form the reference consensus.
locs = 'Frankfurt Hamburg Muenster Greifswald Berlin Heidelberg'.split()

In [None]:
# Landmark names in the row order of every point array below: three hinge points
# (RCC, LCC, NCC) followed by the two coronary ostia (RCA, LCA).
pts_labels = ['RCC', 'LCC', 'NCC', 'RCA', 'LCA']
# data:  flat records (one per source and subject) that become `df` below.
# data2: nested lookup {source: {subject: point array}}; human readers are keyed by site name.
data, data2 = [], {}
for loc in locs:
    # Sorted for a reproducible row order in `data` / `df`.
    for pts_path in sorted((src / 'Human' / loc).glob('*.mps')):
        loc_pts = data2.setdefault(loc, {})
        subject = pts_path.stem
        pts = pts_from_mps(pts_path)
        # Skip incomplete reads. This must precede the Greifswald re-ordering, whose fancy
        # index would raise IndexError on arrays with fewer than 5 points.
        if pts.shape[0] < 5:
            continue
        if loc == 'Greifswald':
            # Swap RCC<->LCC and RCA<->LCA, presumably because this site stored them in a
            # different order — confirm against the annotation protocol.
            pts = pts[[1, 0, 2, 4, 3]]
        loc_pts[subject] = pts

        record = {f'{l}_{c}': v for l, pt in zip(pts_labels, pts) for c, v in zip('xyz', pt)}
        record.update(location=loc, subject=subject, model='human')
        data.append(record)

In [None]:
# Architectures whose prediction folders ('<setup>_<architecture>') are read below.
model_names = 'UNet SWIN-UNETR'.split()

In [None]:
# Load each model run's predictions (folders '<setup>_<model_name>' written by eval_model) and
# label every predicted point with the landmark whose Heidelberg reference point is nearest.
for model_name in model_names:
    for pts_path in sorted(src.glob(f'*_{model_name}/*.mps')):
        run_name = pts_path.parent.name  # e.g. 'Federated_UNet'
        subject = pts_path.stem
        reference = data2['Heidelberg'].get(subject)
        if reference is None:
            # Without a Heidelberg read the predicted points cannot be labelled.
            continue
        # Only the five labelled landmarks: an extra reference row would make argmin index
        # past the end of pts_labels.
        reference = reference[:5]
        run_pts = data2.setdefault(run_name, {})
        pts = pts_from_mps(pts_path)

        # Landmarks without a predicted point stay NaN.
        labelled = {l: np.full((1, 3), np.nan) for l in pts_labels}
        for pt in pts:
            nearest = np.argmin(np.linalg.norm(reference - pt, axis=1))
            # NOTE(review): if two predictions share a nearest landmark, the later one silently wins.
            labelled[pts_labels[nearest]] = pt[None]
        pts = np.concatenate(list(labelled.values()), axis=0)
        run_pts[subject] = pts

        record = {f'{l}_{c}': v for l, pt in zip(pts_labels, pts) for c, v in zip('xyz', pt)}
        record.update(location=run_name.split('_')[0], subject=subject, model=model_name)
        data.append(record)

In [None]:
# One row per (source, subject): <landmark>_<x|y|z> coordinates plus location / subject / model.
df = pd.DataFrame(data)
# Preview only; the full frame holds every reader and model run.
df.head()

In [None]:
# Evaluation subjects: those with a Berlin human read (in that order) that every human reader in
# `locs` annotated, since the consensus below looks each subject up for every reader.
# Only human rows are used: model runs such as 'Berlin_UNet' also carry location 'Berlin'.
human_df = df[df.model == 'human']
readers_per_subject = human_df.groupby('subject').location.nunique()
fully_read = readers_per_subject.index[readers_per_subject == len(locs)]
subjects = human_df.loc[(human_df.location == 'Berlin') & human_df.subject.isin(fully_read), 'subject'].unique()

In [None]:
# Per-subject human consensus: element-wise mean and std over the readers in `locs` of each
# reader's first five landmarks (one (5, 3) array per subject).
pts_mean, pts_std = {}, {}
for subject in subjects:
    reader_pts = np.stack(
        [by_subject[subject][:5] for source, by_subject in data2.items() if source in locs],
        axis=0,
    )
    pts_mean[subject], pts_std[subject] = reader_pts.mean(axis=0), reader_pts.std(axis=0)

In [None]:
# Per-landmark Euclidean distance of every human reader and model run to the human consensus.
# distances[key] holds one (1, 5) array per entry of `subjects`, in the same order; a NaN row is
# appended when a model run has no prediction for a subject, so the next cell's zip() with
# `subjects` stays aligned.
# NOTE(review): pts_mean includes each human reader's own points, so human distances are biased
# low relative to model distances; a leave-one-out consensus would be a fairer comparison.
distances = {}
locations = df.location.unique()
run_models = [m for m in df.model.unique() if m != 'human']
for subject in subjects:
    subject_pts_mean = pts_mean[subject]
    for loc in locations:
        if loc in data2:
            # `loc` is a human reader.
            reader_pts = data2[loc][subject][:5]
            distances.setdefault(f'{loc}_human', []).append(
                np.linalg.norm(subject_pts_mean - reader_pts, axis=1)[None]
            )
        for m in run_models:
            k = f'{loc}_{m}'
            if k not in data2:
                continue
            if subject in data2[k]:
                m_distance = np.linalg.norm(subject_pts_mean - data2[k][subject][:5], axis=1)[None]
            else:
                m_distance = np.full((1, len(pts_labels)), np.nan)
            distances.setdefault(k, []).append(m_distance)

In [None]:
# Long table: one row per (reader/run key, subject) with one distance column per landmark.
records = []
for key, key_distances in distances.items():
    # zip() would silently truncate and shift subject labels if the lists were misaligned.
    if len(key_distances) != len(subjects):
        raise ValueError(
            f'{key}: {len(key_distances)} distance rows for {len(subjects)} subjects; '
            'rows cannot be matched to subjects'
        )
    for d, subject in zip(key_distances, subjects):
        record = {'location': key, 'subject': subject}
        record.update(zip(pts_labels, d[0]))
        records.append(record)
df_distances = pd.DataFrame(records)

In [None]:
# Preview only; the full table has one row per (reader/run key, subject).
df_distances.head()

In [None]:
def filter_fn(x):
    """Map a distance key such as 'Berlin_human' or 'Federated_SWIN-UNETR' to a short method code.

    Returns 'H' for any key containing 'human'. Otherwise returns a setup prefix — 'F' for keys
    containing 'Federated', 'KD' for keys containing 'KD', 'L' (presumably local / single-site)
    for everything else — followed by 'U' when the key contains 'UNet' and 'T' otherwise.
    """
    if 'human' in x:
        return 'H'
    arch = 'U' if 'UNet' in x else 'T'
    if 'Federated' in x:
        return 'F' + arch
    if 'KD' in x:
        return 'KD' + arch
    return 'L' + arch
    
# Short method code per distance key: H, LU/LT, FU/FT, KDU/KDT.
df_distances['location2'] = df_distances.location.apply(filter_fn)
# The KD UNet run (KDU) is left out of the figures and summaries.
plot_df = df_distances[df_distances.location2 != 'KDU']
plot_order = ['H', 'LU', 'LT', 'FU', 'FT', 'KDT']
# sharey=True: y tick labels are shown on the first panel only, so all panels need one y scale.
fig, axs = plt.subplots(1, len(pts_labels), figsize=(15, 3), sharey=True)
for i, (l, ax) in enumerate(zip(pts_labels, axs)):
    sns.boxplot(data=plot_df, x='location2', y=l, ax=ax, order=plot_order)
    ax.set_title(l)
    ax.set_yscale('log')
    ax.set_xlabel(None)
    ax.set_ylabel(None if i else 'd [mm]')
fig.tight_layout()
Path('images').mkdir(exist_ok=True)
fig.savefig('./images/public_comparison_interobserver_variability.png', bbox_inches='tight')

In [None]:
# Mean distance per method and landmark (KDU excluded), then mean and std across the landmarks.
means = (
    df_distances
    .loc[df_distances['location2'] != 'KDU']
    .groupby('location2')[pts_labels]
    .mean()
)
means.mean(axis=1), means.std(axis=1)

In [None]:
# Long format: one row per (method row, landmark). The first three landmarks (RCC, LCC, NCC) are
# hinge points and fill `distance_hps`; the coronary ostia (RCA, LCA) fill `distance_ca`.
plot_df = df_distances[df_distances.location2 != 'KDU']
flattened_df = pd.DataFrame([
    {'location': row.location2, ('distance_hps' if j < 3 else 'distance_ca'): row[l]}
    for _, row in plot_df.iterrows()
    for j, l in enumerate(pts_labels)
])
order = ['H', 'LU', 'LT', 'FU', 'FT', 'KDT']
# sharey=True: the right panel shows no y tick labels, so both panels must share one y scale.
fig, axs = plt.subplots(1, 2, figsize=(12, 5), sharey=True)
sns.boxplot(data=flattened_df, x='location', y='distance_hps', ax=axs[0], order=order)
sns.boxplot(data=flattened_df, x='location', y='distance_ca', ax=axs[1], order=order)
for ax in axs:
    ax.set_yscale('log')
    ax.set_xlabel('Method')
axs[0].set_ylabel('d [mm]')
axs[1].set_ylabel(None)
axs[0].set_title('Hinge Points')
axs[1].set_title('Coronary Ostia')
fig.tight_layout()
Path('images').mkdir(exist_ok=True)
fig.savefig('images/performance_methods_public_dataset.png', bbox_inches='tight')

In [None]:
# Long format pooling all five landmarks (hinge points and coronary ostia) into one `distance` column.
plot_df = df_distances[df_distances.location2 != 'KDU']
flattened_df = pd.DataFrame([
    {'location': row.location2, 'distance': row[l]}
    for _, row in plot_df.iterrows()
    for l in pts_labels
])
# Method code -> tick label; the dict order is also the box order, so labels cannot drift.
method_names = {
    'H': 'Human',
    'LU': 'UNet',
    'LT': 'SWIN-UNETR',
    'FU': 'FedUNet',
    'FT': 'FedSWIN-\nUNETR',
    'KDT': 'FedKDSWIN-\nUNETR',
}
order = list(method_names)
fig, ax = plt.subplots(1, 1, figsize=(7, 4))
sns.boxplot(data=flattened_df, x='location', y='distance', ax=ax, order=order)
ax.set_yscale('log')
ax.set_xlabel(None)
ax.set_ylabel('d [mm]')
ax.set_title(r'Hinge Points and Coronary Ostia (HPs and COs) $\downarrow$', fontsize=16)
ax.set_xticklabels([method_names[m] for m in order])
fig.tight_layout()
# NOTE(review): other figures are saved to 'images/'; confirm this separate location is intended.
out_dir = Path('notebooks/images2')
out_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(out_dir / 'performance_methods_public_dataset.png', bbox_inches='tight')

In [None]:
# Per-method summary of the pooled landmark distances: median and first / third quartiles.
by_method = flattened_df.groupby('location')['distance']
medians = by_method.median()
q25, q75 = by_method.quantile(.25), by_method.quantile(.75)

In [None]:
# Method codes (see filter_fn) in reporting order; KDU is excluded throughout.
methods = 'H LU LT FU FT KDT'.split()

In [None]:
# Print a LaTeX snippet per method: median distance and interquartile range (Q3 - Q1), in mm.
# Raw f-string so the LaTeX backslashes are not (invalid) Python escape sequences; the IQR is
# computed from the unrounded quartiles and both numbers are shown with exactly two decimals.
for method in methods:
    median = medians[method]
    iqr = q75[method] - q25[method]
    print(rf"{method}  ${median:.2f}\,(\mathrm{{IQR:}}{iqr:.2f})\,\mathrm{{mm}}$")