In [1]:
# Disable Jedi-based tab completion in this kernel (IPython then uses its own completer).
%config Completer.use_jedi = False

In [2]:
import os
from tqdm.notebook import tqdm
from skimage import io
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import albumentations
import torch

from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

In [3]:
import sys
sys.path.append('../')

from utils.general import *
import utils.dataload as d
from models import model_selector
from utils.data_augmentation import data_augmentation_selector
from medpy.metric.binary import hd, dc, jc, assd

from utils.neural import *
from utils.datasets import *
from utils.metrics import *

In [4]:
def find_path(directory, filename):
    """Return the first path below ``directory`` (searched recursively) matching ``filename``.

    ``filename`` is passed to ``Path.rglob`` and may therefore be a glob pattern.
    Returns ``None`` when nothing matches.
    """
    matches = Path(directory).rglob(filename)
    return next(matches, None)

In [5]:
def save_pred(image, mask, pred_mask, case, metric_value, descriptor,
              output_root="lvsc2acdc_preds_overlays", metric_name="Jaccard"):
    """Save a 3-panel figure: input image, ground-truth overlay, predicted overlay.

    Parameters
    ----------
    image : array-like
        Image drawn in grayscale in every panel (presumably a 2-D H x W slice — TODO confirm).
    mask, pred_mask : numpy arrays
        Ground-truth and predicted masks; pixels equal to 0 are left transparent.
    case : str
        Title prefix and output file stem (``<case>.jpg``).
    metric_value : float
        Score printed in the title with 4 decimals.
    descriptor : str
        Sub-directory of ``output_root`` the figure is written to.
    output_root : str, optional
        Root directory for the overlays; created if missing.
    metric_name : str, optional
        Label printed before ``metric_value`` in the title.
    """
    fig, axes = plt.subplots(1, 3, figsize=(17, 10))

    # (overlay, title) per panel; the first panel shows the image alone.
    panels = [
        (None, "Input Image"),
        (mask, "Ground-truth"),
        (pred_mask, "Automatic Segmentation"),
    ]
    for ax, (overlay, title) in zip(axes, panels):
        ax.imshow(image, cmap="gray")
        if overlay is not None:
            masked = np.ma.masked_where(overlay == 0, overlay)
            ax.imshow(masked, cmap="hsv", interpolation="bilinear", alpha=0.33)
        ax.axis("off")
        ax.set_title(title)

    # Called after the panel titles exist so they are taken into account.
    fig.tight_layout(pad=3)
    fig.suptitle(f"{case} - {metric_name} {metric_value:.4f}", y=0.9)

    parent_dir = os.path.join(output_root, descriptor)
    os.makedirs(parent_dir, exist_ok=True)
    try:
        fig.savefig(os.path.join(parent_dir, f"{case}.jpg"), dpi=300)
    finally:
        # Close this specific figure (even if saving fails) so a long loop does not leak figures.
        plt.close(fig)

In [6]:
# Jaccard bin edges and one label per bin (len(values_desc) == len(value_ranges) - 1).
value_ranges = [0, 0.25, 0.5, 0.75, 1]
values_desc = ["awful", "average", "good", "excellent"]

for low, high, desc in zip(value_ranges[:-1], value_ranges[1:], values_desc):
    print(f"{low} - {high}: {desc}")

0 - 0.25: awful
0.25 - 0.5: average
0.5 - 0.75: good
0.75 - 1: excellent


In [7]:
# Label value in the ACDC ground-truth masks that the evaluation binarizes as the
# target structure (ground truth becomes 1 where mask == LV_INDEX, else 0).
# NOTE(review): confirm against the ACDC label map that 2 is the class the
# LVSC-trained model is meant to be compared with.
LV_INDEX = 2

In [8]:
# Single-output segmentation network, restored from the LVSC checkpoint on GPU 0.
model = model_selector(
    "segmentation",
    "resnet34_unet_imagenet_encoder_scse_hypercols",
    num_classes=1,
    from_swa=False,
    in_channels=3,
    devices=[0],
    checkpoint="../checks/lvsc.pt",
)

# Only the third returned transform (the validation one) is used for inference.
_, _, val_aug = data_augmentation_selector("lvsc2d", 224, 224, "padd")


--- Frosted pretrained backbone! ---
Model total number of parameters: 35740845
Loaded model from checkpoint: ../checks/lvsc.pt
Using LVSC 2D Segmentation Data Augmentation Combinations
Padding masks!
Padding masks!


In [9]:
# Imported explicitly instead of relying on the `from utils.* import *` star imports.
from torch.utils.data import DataLoader

# The evaluation loop unwraps every batch field with [0], so it only supports one volume per batch.
batch_size = 1
assert batch_size == 1, "the evaluation loop indexes every batch field with [0]"
add_depth = True
normalization = "standardize"

# Full ACDC training split, evaluated with the validation-time transform.
train_dataset = ACDC173Dataset(
    mode="full_train", transform=val_aug, img_transform=[],
    add_depth=add_depth, normalization=normalization, relative_path="../"
)

# Deterministic order (no shuffle, keep the last batch) so per-slice metrics are reproducible.
acdc_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=False, pin_memory=True,
    drop_last=False, collate_fn=train_dataset.custom_collate
)

In [10]:
def _quality_bin(score):
    """Map a Jaccard ``score`` to an index into ``values_desc``.

    Bins are lower-inclusive and use the inner edges of ``value_ranges``
    (with [0, 0.25, 0.5, 0.75, 1]: [0, 0.25) -> 0, [0.25, 0.5) -> 1,
    [0.5, 0.75) -> 2, [0.75, ...) -> 3). Scores below the first inner edge map
    to bin 0. Returns None for NaN, so such slices are never plotted.
    """
    if score != score:  # NaN is the only value not equal to itself
        return None
    for bin_idx, upper_edge in enumerate(value_ranges[1:-1]):
        if score < upper_edge:
            return bin_idx
    return len(values_desc) - 1


model.eval()

metrics = {'img_id': [], 'iou': [], 'dice': [], 'hd': [], 'assd': [], 'slice': [], 'phase': []}
MAX_PLOTS_PER_BIN = 50
# Remaining overlay budget per quality bin, aligned with values_desc.
plot_per_range = [MAX_PLOTS_PER_BIN] * len(values_desc)
preds_dir = "lvsc2acdc_preds"
os.makedirs(preds_dir, exist_ok=True)

with torch.no_grad():
    for batch in tqdm(acdc_loader):
        # The loader uses batch_size=1, so every per-volume field is unwrapped with [0].
        img_id = batch["img_id"][0]
        img_phase = batch["phase"][0]
        # NOTE(review): squeeze() drops every size-1 dim (the batch dim and, for a
        # single-slice volume, the slice dim as well) — confirm the model input shape.
        image = batch["image"].squeeze().cuda()
        prob_preds = model(image)

        original_masks = batch["original_mask"][0]
        original_img = batch["original_img"][0]
        mask_affine = batch["mask_affine"][0]
        mask_header = batch["mask_header"][0]

        res_mask = []
        for pred_indx, single_pred in enumerate(prob_preds):
            if torch.is_tensor(original_masks[pred_indx]):
                original_mask = original_masks[pred_indx].data.cpu().numpy().astype(np.uint8).squeeze()
            else:  # numpy array
                original_mask = original_masks[pred_indx].astype(np.uint8)

            # Binarize the ground truth: 1 where the label equals LV_INDEX, else 0.
            original_mask = np.where(original_mask == LV_INDEX, 1, 0)

            # Sigmoid probabilities brought to the ground-truth shape (presumably undoing
            # the "padd" transform), then thresholded at 0.5.
            pred_mask = reshape_masks(
                torch.sigmoid(single_pred).squeeze(0).data.cpu().numpy(),
                original_mask.shape, "padd"
            )
            lv_pred_mask = np.where(pred_mask > 0.5, 1, 0).astype(np.int32)

            jc_score = jaccard_coef(lv_pred_mask, original_mask)
            metrics['iou'].append(jc_score)
            metrics['dice'].append(dice_coef(lv_pred_mask, original_mask))
            metrics['hd'].append(secure_hd(lv_pred_mask, original_mask))
            metrics['assd'].append(secure_assd(lv_pred_mask, original_mask))
            metrics['img_id'].append(img_id)
            metrics['slice'].append(pred_indx)
            metrics['phase'].append(img_phase)

            res_mask.append(torch.tensor(lv_pred_mask))

            # Save up to MAX_PLOTS_PER_BIN overlays per quality bin. Boundary scores
            # (0.25, 0.5, 0.75) now land in the upper bin instead of being skipped.
            bin_idx = _quality_bin(jc_score)
            if bin_idx is not None and plot_per_range[bin_idx] > 0:
                plot_per_range[bin_idx] -= 1
                save_pred(
                    original_img[pred_indx], original_mask, lv_pred_mask,
                    f"{img_id}_slice{pred_indx}_{img_phase}", jc_score, values_desc[bin_idx]
                )

        # One NIfTI file per volume, named <prefix-before-first-underscore>_<phase>.nii.gz,
        # written with the batch's mask_affine / mask_header.
        res_mask = torch.stack(res_mask)
        pred_name = img_id.split("_")[0] + f"_{img_phase}.nii.gz"
        d.save_nii(os.path.join(preds_dir, pred_name), res_mask, mask_affine, mask_header)

  0%|          | 0/200 [00:00<?, ?it/s]

In [11]:
# One row per evaluated slice; columns are the keys of `metrics`.
df = pd.DataFrame.from_dict(metrics)
df.head(5)

Unnamed: 0,img_id,iou,dice,hd,assd,slice,phase
0,patient051_frame01,0.819539,0.90082,2.0,0.570074,0,ED
1,patient051_frame01,0.811881,0.896175,2.0,0.697133,1,ED
2,patient051_frame01,0.8,0.888889,2.0,0.751975,2,ED
3,patient051_frame01,0.740084,0.85063,2.236068,0.967716,3,ED
4,patient051_frame01,0.689941,0.816527,2.236068,1.022536,4,ED


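## Get metrics, replacing undefined distance values with the maximum observed value

The first summary below is computed on the raw values. `hd` and `assd` may contain a `-1` placeholder, presumably returned by `secure_hd` / `secure_assd` when a distance cannot be computed (e.g. for an empty mask). These placeholders pull the distance means down. They are then replaced with the maximum observed value of each column, and the means are recomputed.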

In [12]:
# Raw means: hd/assd may still include -1 placeholders at this point.
for column, label in (("iou", "IOU"), ("dice", "DICE"), ("hd", "Hausdorff"), ("assd", "ASSD")):
    print(f"Mean {label}: {df[column].mean()}")

Mean IOU: 0.7254411098415574
Mean DICE: 0.8217171434543292
Mean Hausdorff: 3.5935598961131356
Mean ASSD: 0.9812730155705777


In [13]:
# Largest observed distances, used below as the fill value for -1 placeholders.
max_hausdorff, max_assd = df[["hd", "assd"]].max()

In [14]:
# Replace the -1 placeholder with the per-column maximum. The result is assigned back
# instead of calling df["hd"].replace(..., inplace=True): under pandas Copy-on-Write that
# chained in-place call only updates a temporary and leaves df unchanged (and the
# warning pandas emits would be hidden by the global warnings filter at the top).
df["hd"] = df["hd"].replace(-1, max_hausdorff)
df["assd"] = df["assd"].replace(-1, max_assd)

In [15]:
# Means after the -1 distance placeholders were replaced with the column maxima.
summary_columns = {"iou": "IOU", "dice": "DICE", "hd": "Hausdorff", "assd": "ASSD"}
for column, label in summary_columns.items():
    print(f"Mean {label}: {df[column].mean()}")

Mean IOU: 0.7254411098415574
Mean DICE: 0.8217171434543292
Mean Hausdorff: 6.2111488195291376
Mean ASSD: 2.9107317539941318


In [18]:
# Per-phase (ED/ES) means of the numeric columns. numeric_only=True is required on
# pandas >= 2.0, where averaging the string column `img_id` raises a TypeError.
df.groupby("phase").mean(numeric_only=True)

Unnamed: 0_level_0,iou,dice,hd,assd,slice
phase,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
ED,0.723024,0.827199,4.396858,1.858859,4.555205
ES,0.727859,0.816235,8.02544,3.962604,4.555205


In [19]:
# Persist the per-slice metrics table (without the RangeIndex column).
df.to_csv(path_or_buf="results.csv", index=False)