In [1]:
import os

In [6]:
# Scratch check: os.path.join ignores the empty "" component (result is 'la/casa/bonita').
os.path.join("la","casa", "", "bonita")

'la/casa/bonita'

In [1]:
# IPython magic: turn off Jedi-based tab completion for this kernel session.
%config Completer.use_jedi = False

In [2]:
import os
from tqdm.notebook import tqdm
from skimage import io
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import albumentations
import torch

from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

In [3]:
import sys
sys.path.append('../')

from utils.general import *
import utils.dataload as d
from models import model_selector
from utils.data_augmentation import data_augmentation_selector
from medpy.metric.binary import hd, dc, jc, assd

from utils.neural import *
from utils.datasets import *
from utils.metrics import *

In [4]:
def find_path(directory, filename):
    """Return the first path under ``directory`` (searched recursively) named ``filename``.

    ``filename`` may be a glob pattern, since it is passed to ``Path.rglob``.
    Returns ``None`` when nothing matches.
    """
    matches = Path(directory).rglob(filename)
    return next(matches, None)

In [5]:
def save_pred(image, mask, pred_mask, case, metric_value, descriptor,
              metric_name="Jaccard", out_root="ALLMMS2ACDC3D_preds_overlays"):
    """Save a three-panel figure: input image, ground-truth overlay, predicted overlay.

    The figure is written to ``<out_root>/<descriptor>/<case>.jpg`` at 300 dpi.
    The parent directory is created if needed.

    Args:
        image: Image drawn in grayscale in every panel (presumably a 2D slice — confirm).
        mask: Ground-truth label array drawn over ``image``; label 0 is left transparent.
        pred_mask: Predicted label array drawn over ``image``; label 0 is left transparent.
        case: Case identifier, used in the title and as the output file stem.
        metric_value: Score shown in the title, formatted with 4 decimals.
        descriptor: Sub-directory of ``out_root`` that receives the figure.
        metric_name: Label printed before ``metric_value`` in the title (default "Jaccard").
        out_root: Root directory for the overlays.
    """
    fig, axes = plt.subplots(1, 3, figsize=(17, 10))
    try:
        fig.tight_layout(pad=3)  # Set spacing between plots

        panels = (
            (None, "Input Image"),
            (mask, "Ground-truth"),
            (pred_mask, "Automatic Segmentation"),
        )
        for ax, (labels, title) in zip(axes, panels):
            ax.imshow(image, cmap="gray")
            if labels is not None:
                # Mask out background so only labelled pixels are tinted over the image.
                overlay = np.ma.masked_where(labels == 0, labels)
                ax.imshow(overlay, 'hsv', interpolation='bilinear', alpha=0.33)
            ax.axis("off")
            ax.set_title(title)

        fig.suptitle(f"{case} - {metric_name} {metric_value:.4f}", y=0.9)
        parent_dir = os.path.join(out_root, descriptor)
        os.makedirs(parent_dir, exist_ok=True)
        # Save/close this exact figure rather than pyplot's "current" one.
        fig.savefig(os.path.join(parent_dir, f"{case}.jpg"), dpi=300)
    finally:
        # Always release the figure, even if saving fails, to avoid leaking memory in loops.
        plt.close(fig)

In [6]:
# Bucket boundaries for a [0, 1] score and the quality label of each bucket.
value_ranges = [0, 0.25, 0.5, 0.75, 1]
values_desc = ["awful", "average", "good", "excellent"]

# Pair each lower bound with the next upper bound and its label.
for low, high, desc in zip(value_ranges, value_ranges[1:], values_desc):
    print(f"{low} - {high}: {desc}")

0 - 0.25: awful
0.25 - 0.5: average
0.5 - 0.75: good
0.75 - 1: excellent


In [8]:
# Build the segmentation network (4 classes, 3 input channels, device 0) and load the SWA checkpoint.
model = model_selector(
    "segmentation", "resnet18_unet_scratch_scse_hypercols", num_classes=4, from_swa=True,
    in_channels=3, devices=[0], checkpoint="../checks/ALLMMSACDC/n1_swa.pt"
)

# Only the third returned transform (validation) is kept. The "padd" mode prints
# "Padding masks!", so inputs are presumably padded to 224x224 rather than resized — confirm.
_, _, val_aug = data_augmentation_selector(
    "lvsc2d", 224, 224, "padd"
)

Model total number of parameters: 25641328
Loaded model from checkpoint: ../checks/ALLMMSACDC/n1_swa.pt
Using LVSC 2D Segmentation Data Augmentation Combinations
Padding masks!
Padding masks!


In [9]:
# Must stay 1: the evaluation loop reads every batch field with [0] (one volume per batch).
batch_size = 1
add_depth = True
normalization = "standardize"

# NOTE(review): named train_dataset, but it is only used for evaluation here — the "full_train"
# split with the validation transform (val_aug) and no extra image transforms.
train_dataset = ACDC173Dataset(
    mode="full_train", transform=val_aug, img_transform=[],
    add_depth=add_depth, normalization=normalization, relative_path="../"
)

# Deterministic order (shuffle=False); batches are assembled by the dataset's own collate function.
acdc_loader =  DataLoader(
    train_dataset, batch_size=batch_size, shuffle=False, pin_memory=True,
    drop_last=False, collate_fn=train_dataset.custom_collate
)

In [10]:
def map_mask_classes(mask, classes_map):
    """Relabel a mask by mapping each label value through ``classes_map``.

    Works on arrays of any shape (2D slices or 3D volumes); the input is not modified.

    Args:
        mask: (np.array) Label array to map, any shape.
        classes_map: (dict) Mapping old label -> new label, e.g. {0: 0, 1: 3, 2: 2, 3: 1}.
            Every value present in ``mask`` must be a key.

    Returns:
        (np.array) New array with the same shape and dtype as ``mask``, holding the mapped labels.

    Raises:
        ValueError: if ``mask`` contains a label that has no entry in ``classes_map``.
    """
    present = np.unique(mask).tolist()
    # Validate with an explicit raise: `assert` is stripped under `python -O`, which would
    # silently leave unmapped labels as 0.
    missing = [value for value in present if value not in classes_map]
    if missing:
        raise ValueError(f"Please specify all class maps. {missing} not in {classes_map}")

    res = np.zeros_like(mask)
    for value in present:
        # Label regions are disjoint, so direct assignment is equivalent to accumulating.
        res[mask == value] = classes_map[value]
    return res

In [11]:
model.eval()

# Per-bucket plot budget for the quality buckets ["awful", "average", "good", "excellent"].
# NOTE(review): not referenced by any visible cell.
plot_per_range = [50, 50, 50, 50]
# NOTE(review): "ACDCD3D" differs from save_pred's "ALLMMS2ACDC3D_preds_overlays" root — possibly a
# typo; left unchanged so existing outputs stay at the same path.
preds_dir = "ALLMMS2ACDCD3D_TESTA"
os.makedirs(preds_dir, exist_ok=True)

# Predicted labels are remapped with map_classes (swaps ids 1 and 3) before scoring against the
# ground truth; per the comments below, M&Ms and ACDC disagree on the LV/RV label ids.
# MMS -> class_to_cat = {1: "LV", 2: "MYO", 3: "RV"}
# Original ACDC -> class_to_cat = {1: "RV", 2: "MYO", 3: "LV"}
map_classes = {0: 0, 1: 3, 2: 2, 3: 1}
class_to_cat = {1: "RV", 2: "MYO", 3: "LV"}
mask_reshape_method = "padd"
include_background = False

# Label id -> column suffix for every class that gets scored. Background has no entry in
# class_to_cat, so it is named here (looking up class_to_cat[0] raised KeyError before).
eval_classes = {0: "BG", **class_to_cat} if include_background else dict(class_to_cat)

# One list per column, ordered img_id, phase, then iou/dice/hd/assd for each scored class.
metrics = {"img_id": [], "phase": []}
for class_str in eval_classes.values():
    for metric_prefix in ("iou", "dice", "hd", "assd"):
        metrics[f"{metric_prefix}_{class_str}"] = []

with torch.no_grad():
    for batch in tqdm(acdc_loader):
        # One volume per batch (batch_size = 1), hence [0] on every field.
        img_id = batch["img_id"][0]
        img_phase = batch["phase"][0]
        image = batch["image"].squeeze().cuda()
        prob_preds = model(image)

        original_masks = batch["original_mask"][0]
        original_img = batch["original_img"][0]

        mask_affine = batch["mask_affine"][0]
        mask_header = batch["mask_header"][0]

        # Hard labels as uint8, remapped to the ground-truth label ids, then reshaped ("padd")
        # to the ground truth's last two dimensions.
        pred_mask = convert_multiclass_mask(prob_preds).cpu().numpy().astype(np.uint8)
        pred_mask = map_mask_classes(pred_mask, map_classes)
        pred_mask = reshape_volume(pred_mask, original_masks.shape[-2:], mask_reshape_method)

        for current_class, class_str in eval_classes.items():
            y_true = np.where(original_masks == current_class, 1, 0).astype(np.int32)
            y_pred = np.where(pred_mask == current_class, 1, 0).astype(np.int32)

            metrics[f"iou_{class_str}"].append(jaccard_coef(y_true, y_pred))
            metrics[f"dice_{class_str}"].append(dc_score := dice_coef(y_true, y_pred)) if False else \
                metrics[f"dice_{class_str}"].append(dice_coef(y_true, y_pred))
            # NOTE(review): no voxel spacing is passed, so distances are presumably in voxels,
            # not mm — confirm against secure_hd / secure_assd.
            metrics[f"hd_{class_str}"].append(secure_hd(y_true, y_pred))
            metrics[f"assd_{class_str}"].append(secure_assd(y_true, y_pred))

        metrics["img_id"].append(img_id)
        metrics["phase"].append(img_phase)

        # Written as <preds_dir>/<patient>/<patient>_<phase>.nii.gz with the batch's mask affine/header.
        patient = img_id.split("_")[0]
        pred_name = f"{patient}_{img_phase}.nii.gz"
        os.makedirs(os.path.join(preds_dir, patient), exist_ok=True)
        d.save_nii(os.path.join(preds_dir, patient, pred_name), pred_mask, mask_affine, mask_header)

  0%|          | 0/200 [00:00<?, ?it/s]

In [12]:
# One row per evaluated volume; columns follow the key order of `metrics`.
df = pd.DataFrame.from_dict(metrics)
df.head(n=5)

Unnamed: 0,img_id,phase,iou_RV,dice_RV,hd_RV,assd_RV,iou_MYO,dice_MYO,hd_MYO,assd_MYO,iou_LV,dice_LV,hd_LV,assd_LV
0,patient051_frame01,ED,0.842785,0.914686,13.601471,0.509872,0.71983,0.837095,3.605551,0.302142,0.934202,0.965982,1.414214,0.16082
1,patient018_frame01,ED,0.863565,0.926788,18.708287,0.338179,0.712289,0.831973,5.09902,0.300584,0.945392,0.97193,2.236068,0.123663
2,patient015_frame01,ED,0.915189,0.955717,9.433981,0.133869,0.767924,0.86873,4.898979,0.257387,0.951124,0.97495,1.732051,0.122097
3,patient073_frame01,ED,0.881644,0.9371,15.297059,0.215369,0.78235,0.877886,2.236068,0.244498,0.927044,0.962141,1.732051,0.145375
4,patient042_frame01,ED,0.906044,0.950706,7.874008,0.160014,0.772626,0.87173,2.828427,0.273555,0.945771,0.97213,2.44949,0.136378


## Get metrics by replacing infinite distance values with max value

In [15]:
# Mean of every per-volume metric, grouped by metric family (IoU, Dice, Hausdorff, ASSD),
# with a dashed separator between families.
print("\n--------------\n".join(
    "\n".join(
        f"Mean {label} {structure}: {df[f'{prefix}_{structure}'].mean()}"
        for structure in ("RV", "LV", "MYO")
    )
    for prefix, label in (("iou", "IOU"), ("dice", "DICE"), ("hd", "Hausdorff"), ("assd", "ASSD"))
))

Mean IOU RV: 0.7968828491968787
Mean IOU LV: 0.8902449435843482
Mean IOU MYO: 0.7822543512953747
--------------
Mean DICE RV: 0.8804279094285203
Mean DICE LV: 0.9403192728835166
Mean DICE MYO: 0.8765771757802159
--------------
Mean Hausdorff RV: 10.664647535448967
Mean Hausdorff LV: 3.618678185082103
Mean Hausdorff MYO: 6.543688610479315
--------------
Mean ASSD RV: 0.3873731101328225
Mean ASSD LV: 0.23301156610986581
Mean ASSD MYO: 0.31038336348727474


In [16]:
# Best (smallest) surface-distance score per structure.
min_hausdorff_lv, min_hausdorff_rv, min_hausdorff_myo = (
    df[col].min() for col in ("hd_LV", "hd_RV", "hd_MYO")
)
min_assd_lv, min_assd_rv, min_assd_myo = (
    df[col].min() for col in ("assd_LV", "assd_RV", "assd_MYO")
)

print(f"min_hausdorff_lv: {min_hausdorff_lv}")
print(f"min_hausdorff_rv: {min_hausdorff_rv}")
print(f"min_hausdorff_myo: {min_hausdorff_myo}")

print(f"min_assd_lv: {min_assd_lv}")
print(f"min_assd_rv: {min_assd_rv}")
print(f"min_assd_myo: {min_assd_myo}")

min_hausdorff_lv: 1.0
min_hausdorff_rv: 2.0
min_hausdorff_myo: 1.4142135623730951
min_assd_lv: 0.05008315546084634
min_assd_rv: 0.08211581768275747
min_assd_myo: 0.11815573697853594


In [17]:
# Largest *finite* value of each surface-distance column. Infinite entries are excluded,
# otherwise a single inf would make the maximum itself inf.
distance_cols = ["hd_LV", "hd_RV", "hd_MYO", "assd_LV", "assd_RV", "assd_MYO"]
finite_max = {col: df.loc[np.isfinite(df[col]), col].max() for col in distance_cols}

max_hausdorff_lv = finite_max["hd_LV"]
max_hausdorff_rv = finite_max["hd_RV"]
max_hausdorff_myo = finite_max["hd_MYO"]

max_assd_lv = finite_max["assd_LV"]
max_assd_rv = finite_max["assd_RV"]
max_assd_myo = finite_max["assd_MYO"]

# Replace infinite distances with the column's finite maximum, so the summary that follows
# reports finite means. Reassigns df (no in-place mutation); re-running is a no-op.
df = df.assign(**{col: df[col].replace(np.inf, finite_max[col]) for col in distance_cols})

In [18]:
# Same summary as before, recomputed on the current `df`: mean of every per-volume metric,
# grouped by metric family with a dashed separator between families.
print("\n--------------\n".join(
    "\n".join(
        f"Mean {label} {structure}: {df[f'{prefix}_{structure}'].mean()}"
        for structure in ("RV", "LV", "MYO")
    )
    for prefix, label in (("iou", "IOU"), ("dice", "DICE"), ("hd", "Hausdorff"), ("assd", "ASSD"))
))

Mean IOU RV: 0.7968828491968787
Mean IOU LV: 0.8902449435843482
Mean IOU MYO: 0.7822543512953747
--------------
Mean DICE RV: 0.8804279094285203
Mean DICE LV: 0.9403192728835166
Mean DICE MYO: 0.8765771757802159
--------------
Mean Hausdorff RV: 10.664647535448967
Mean Hausdorff LV: 3.618678185082103
Mean Hausdorff MYO: 6.543688610479315
--------------
Mean ASSD RV: 0.3873731101328225
Mean ASSD LV: 0.23301156610986581
Mean ASSD MYO: 0.31038336348727474


In [19]:
# Per-phase (ED / ES) averages of every metric. numeric_only=True skips the string img_id
# column: pandas >= 2.0 raises TypeError on it instead of silently dropping it.
df.groupby("phase").mean(numeric_only=True)

Unnamed: 0_level_0,iou_RV,dice_RV,hd_RV,assd_RV,iou_MYO,dice_MYO,hd_MYO,assd_MYO,iou_LV,dice_LV,hd_LV,assd_LV
phase,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
ED,0.872809,0.930961,10.127142,0.255008,0.770981,0.86939,6.020326,0.285724,0.923994,0.960221,3.281441,0.185644
ES,0.720957,0.829895,11.202153,0.519738,0.793528,0.883765,7.067051,0.335043,0.856496,0.920417,3.955915,0.280379


In [20]:
# Persist the per-volume metrics next to the saved predictions (no index column).
df.to_csv(index=False, path_or_buf=os.path.join(preds_dir, "results.csv"))