In [1]:
%config Completer.use_jedi = False

In [2]:
import os
from tqdm.notebook import tqdm
from skimage import io
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import albumentations
import torch

from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

In [3]:
import sys
sys.path.append('../')

from utils.general import *
import utils.dataload as d
from models import model_selector
from utils.data_augmentation import data_augmentation_selector
from medpy.metric.binary import hd, dc, jc, assd

from utils.neural import *
from utils.datasets import *
from utils.metrics import *

In [4]:
def find_path(directory, filename):
    for path in Path(directory).rglob(filename):
        return path

In [5]:
def save_pred(image, mask, pred_mask, case, metric_value, descriptor):
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(17, 10))
    fig.tight_layout(pad=3)  # Set spacing between plots

    ax1.imshow(image, cmap="gray")
    ax1.axis("off")
    ax1.set_title("Input Image")

    masked_lv = np.ma.masked_where(mask == 0, mask)
    ax2.imshow(image, cmap="gray")
    ax2.imshow(masked_lv, 'jet', interpolation='bilinear', alpha=0.33, vmax=3)
    ax2.axis("off")
    ax2.set_title("Ground-truth")

    masked_lv = np.ma.masked_where(pred_mask == 0, pred_mask)
    ax3.imshow(image, cmap="gray")
    ax3.imshow(masked_lv, 'jet', interpolation='bilinear', alpha=0.33, vmax=3)
    ax3.axis("off")
    ax3.set_title("Automatic Segmentation")

    fig.suptitle(f"{case} - Jaccard {metric_value:.4f}", y=0.9)
    parent_dir = os.path.join("Plots/ACDC_PlotPreds", descriptor)
    os.makedirs(parent_dir, exist_ok=True)
    plt.savefig(os.path.join(parent_dir, f"{case}.jpg"), dpi=300)
    plt.close()

In [6]:
value_ranges = [0, 0.25, 0.5, 0.75, 1]
values_desc = ["awful", "average", "good", "excellent"]

for i in range(4):
    print(f"{value_ranges[i]} - {value_ranges[i+1]}: {values_desc[i]}")

0 - 0.25: awful
0.25 - 0.5: average
0.5 - 0.75: good
0.75 - 1: excellent


In [7]:
model = model_selector(
    "segmentation", "resnet34_unet_imagenet_encoder_scse_hypercols", num_classes=4, from_swa=True,
    in_channels=3, devices=[0], checkpoint="../checks/ACDC/n1_100_swa.pt"
)

_, _, val_aug = data_augmentation_selector(
    "none", 224, 224, "padd"
)


--- Frosted pretrained backbone! ---
Model total number of parameters: 35749488
Loaded model from checkpoint: ../checks/ACDC/n1_100_swa.pt
Using None Data Augmentation
Padding masks!
Padding masks!


In [8]:
batch_size = 1
add_depth = True
normalization = "standardize"

train_dataset = ALLMMsDataset3D(
    transform=val_aug, img_transform=[],
    add_depth=add_depth, normalization=normalization, data_relative_path="../../mnms_da",
    vendor=["A"]  # Solo queremos comparar con mismo fabricante -> Siemens == Vendor A
)

allmms_loader =  DataLoader(
    train_dataset, batch_size=batch_size, shuffle=False, pin_memory=True,
    drop_last=False, collate_fn=train_dataset.simple_collate
)

In [9]:
train_dataset.df

Unnamed: 0,External code,VendorName,Vendor,Centre,ED,ES,Partition,Labeled
0,A0S9V9,Siemens,A,1,0,9,Training,True
1,A1D9Z7,Siemens,A,1,22,11,Training,True
2,A1E9Q1,Siemens,A,1,0,9,Training,True
3,A2C0I1,Siemens,A,1,0,7,Training,True
4,A2L1N6,Siemens,A,1,0,12,Testing,False
...,...,...,...,...,...,...,...,...
90,Q3R9W7,Siemens,A,1,0,10,Training,True
91,R4Y1Z9,Siemens,A,1,24,7,Training,True
92,S1S3Z7,Siemens,A,1,0,9,Training,True
93,T2T9Z9,Siemens,A,1,0,13,Training,True


In [10]:
def map_mask_classes(mask, classes_map):
    """

    Args:
        mask: (np.array) Mask Array to map (height, width)
        classes_map: (dict) Mapping between classes. E.g.  {0:0, 1:3, 2:2, 3:1 ,4:4}

    Returns: (np.array) Mapped mask array

    """
    res = np.zeros_like(mask).astype(mask.dtype)
    for value in np.unique(mask):
        if value not in classes_map:
            assert False, f"Please specify all class maps. {value} not in {classes_map}"
        res += np.where(mask == value, classes_map[value], 0).astype(mask.dtype)
    return res

In [11]:
model.eval()

metrics = {
    'img_id':[], 'phase':[],
    'iou_RV':[], 'dice_RV':[], 'hd_RV':[], 'assd_RV': [], 
    'iou_MYO':[], 'dice_MYO':[], 'hd_MYO':[], 'assd_MYO': [],
    'iou_LV':[], 'dice_LV':[], 'hd_LV':[], 'assd_LV': [],
}
plot_per_range = [50,50,50,50] # ["awful", "average", "good", "excellent"]
preds_dir = "CrossDatabase_v2/ACDCD2ALLMMS3D"
os.makedirs(preds_dir, exist_ok=True)

# MMS -> class_to_cat = {1: "LV", 2: "MYO", 3: "RV"}
# Original ACDC -> class_to_cat = {1: "RV", 2: "MYO", 3: "LV"}
map_classes = {0: 0, 1: 3, 2: 2, 3: 1}
class_to_cat = {1: "RV", 2: "MYO", 3: "LV"}
mask_reshape_method = "padd"
include_background = False

pred_indx=-1

with torch.no_grad():
    for batch_indx, batch in enumerate(tqdm(allmms_loader)):
        img_id = batch["external_code"]
        partition = batch["partition"]
        init_shape = batch["initial_shape"]
        
        ed_vol = batch["ed_volume"].squeeze().cuda()
        es_vol = batch["es_volume"].squeeze().cuda()
        
        original_ed = batch["original_ed"]
        original_es = batch["original_es"]
        
        original_ed_mask = batch["original_ed_mask"]
        original_es_mask = batch["original_es_mask"]
        
        mask_affine = batch["affine"]
        mask_header = batch["header"]
        
        save_dir = os.path.join(preds_dir, partition, "Labeled" if partition=="Training" else "", img_id)
        os.makedirs(save_dir, exist_ok=True)
        
        # ---- ED CASE ----
        
        prob_preds = model(ed_vol)
        pred_mask = convert_multiclass_mask(prob_preds).data.cpu().numpy()
        pred_mask = pred_mask.astype(np.uint8)
        pred_mask = map_mask_classes(pred_mask, map_classes)
        pred_mask = reshape_volume(pred_mask, init_shape[:2], mask_reshape_method)
        pred_mask = pred_mask.transpose(1,2,0)
        
        for current_class in range(len(map_classes)):

            if not include_background and current_class == 0:
                continue

            for pred_indx in range(pred_mask.shape[2]):
                
                y_true = np.where(original_ed_mask == current_class, 1, 0).astype(np.int32)[...,pred_indx]
                y_pred = np.where(pred_mask == current_class, 1, 0).astype(np.int32)[...,pred_indx]
                class_str = class_to_cat[current_class]

                jc_score = jaccard_coef(y_true, y_pred)
                dc_score = dice_coef(y_true, y_pred)
                hd_score = secure_hd(y_true, y_pred)
                assd_score = secure_assd(y_true, y_pred)

                metrics[f'iou_{class_str}'].append(jc_score)
                metrics[f'dice_{class_str}'].append(dc_score)
                metrics[f'hd_{class_str}'].append(hd_score)
                metrics[f'assd_{class_str}'].append(assd_score)

                if jc_score <0.25:
                    if plot_per_range[0] > 0:
                        plot_per_range[0] = plot_per_range[0] - 1
                        save_pred(
                            original_ed[...,pred_indx], original_ed_mask[...,pred_indx], pred_mask[...,pred_indx], 
                            f"{img_id}_slice{pred_indx}", jc_score, values_desc[0]
                        )
                elif jc_score>0.25 and jc_score<0.5:
                    if plot_per_range[1] > 0:
                        plot_per_range[1] = plot_per_range[1] - 1
                        save_pred(
                            original_ed[...,pred_indx], original_ed_mask[...,pred_indx], pred_mask[...,pred_indx], 
                            f"{img_id}_slice{pred_indx}", jc_score, values_desc[1]
                        )
                elif jc_score>0.5 and jc_score<0.75:
                    if plot_per_range[2] > 0:
                        plot_per_range[2] = plot_per_range[2] - 1
                        save_pred(
                            original_ed[...,pred_indx], original_ed_mask[...,pred_indx], pred_mask[...,pred_indx], 
                            f"{img_id}_slice{pred_indx}", jc_score, values_desc[2]
                        )
                elif jc_score>0.75:
                    if plot_per_range[3] > 0:
                        plot_per_range[3] = plot_per_range[3] - 1
                        save_pred(
                            original_ed[...,pred_indx], original_ed_mask[...,pred_indx], pred_mask[...,pred_indx], 
                            f"{img_id}_slice{pred_indx}", jc_score, values_desc[3]
                        )
                        
            if np.array(plot_per_range).sum()==0:
                break
        if np.array(plot_per_range).sum()==0:
            break

  0%|          | 0/95 [00:00<?, ?it/s]

In [12]:
df = pd.DataFrame(metrics)
df.head()

ValueError: arrays must all be same length

## Get metrics by replacing infinite distance values with max value

In [None]:
print(f"Mean IOU RV: {df['iou_RV'].mean()}")
print(f"Mean IOU LV: {df['iou_LV'].mean()}")
print(f"Mean IOU MYO: {df['iou_MYO'].mean()}")

print("--------------")

print(f"Mean DICE RV: {df['dice_RV'].mean()}")
print(f"Mean DICE LV: {df['dice_LV'].mean()}")
print(f"Mean DICE MYO: {df['dice_MYO'].mean()}")

print("--------------")

print(f"Mean Hausdorff RV: {df['hd_RV'].mean()}")
print(f"Mean Hausdorff LV: {df['hd_LV'].mean()}")
print(f"Mean Hausdorff MYO: {df['hd_MYO'].mean()}")

print("--------------")

print(f"Mean ASSD RV: {df['assd_RV'].mean()}")
print(f"Mean ASSD LV: {df['assd_LV'].mean()}")
print(f"Mean ASSD MYO: {df['assd_MYO'].mean()}")

In [None]:
min_hausdorff_lv = df["hd_LV"].min()
print(f"min_hausdorff_lv: {min_hausdorff_lv}")
min_hausdorff_rv = df["hd_RV"].min()
print(f"min_hausdorff_rv: {min_hausdorff_rv}")
min_hausdorff_myo = df["hd_MYO"].min()
print(f"min_hausdorff_myo: {min_hausdorff_myo}")

min_assd_lv = df["assd_LV"].min()
print(f"min_assd_lv: {min_assd_lv}")
min_assd_rv = df["assd_RV"].min()
print(f"min_assd_rv: {min_assd_rv}")
min_assd_myo = df["assd_MYO"].min()
print(f"min_assd_myo: {min_assd_myo}")

In [None]:
max_hausdorff_lv = df["hd_LV"].max()
max_hausdorff_rv = df["hd_RV"].max()
max_hausdorff_myo = df["hd_MYO"].max()

max_assd_lv = df["assd_LV"].max()
max_assd_rv = df["assd_RV"].max()
max_assd_myo = df["assd_MYO"].max()

In [None]:
print(f"Mean IOU RV: {df['iou_RV'].mean()}")
print(f"Mean IOU LV: {df['iou_LV'].mean()}")
print(f"Mean IOU MYO: {df['iou_MYO'].mean()}")

print("--------------")

print(f"Mean DICE RV: {df['dice_RV'].mean()}")
print(f"Mean DICE LV: {df['dice_LV'].mean()}")
print(f"Mean DICE MYO: {df['dice_MYO'].mean()}")

print("--------------")

print(f"Mean Hausdorff RV: {df['hd_RV'].mean()}")
print(f"Mean Hausdorff LV: {df['hd_LV'].mean()}")
print(f"Mean Hausdorff MYO: {df['hd_MYO'].mean()}")

print("--------------")

print(f"Mean ASSD RV: {df['assd_RV'].mean()}")
print(f"Mean ASSD LV: {df['assd_LV'].mean()}")
print(f"Mean ASSD MYO: {df['assd_MYO'].mean()}")

In [None]:
df.groupby("phase").mean()

In [None]:
df.to_csv(os.path.join(preds_dir, "results.csv"), index=False)