In [1]:
%config Completer.use_jedi = False

In [2]:
import os
from tqdm.notebook import tqdm
from skimage import io
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import albumentations
import torch

from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

In [3]:
import sys
sys.path.append('../')

from utils.general import *
import utils.dataload as d
from models import model_selector
from utils.data_augmentation import data_augmentation_selector
from medpy.metric.binary import hd, dc, jc, assd

from utils.neural import *
from utils.datasets import *
from utils.metrics import *

In [4]:
def find_path(directory, filename):
    for path in Path(directory).rglob(filename):
        return path

In [5]:
def save_pred(image, mask, pred_mask, case, metric_value, descriptor):
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(17, 10))
    fig.tight_layout(pad=3)  # Set spacing between plots

    ax1.imshow(image, cmap="gray")
    ax1.axis("off")
    ax1.set_title("Input Image")

    masked_lv = np.ma.masked_where(mask == 0, mask)
    ax2.imshow(image, cmap="gray")
    ax2.imshow(masked_lv, 'jet', interpolation='bilinear', alpha=0.33, vmax=3)
    ax2.axis("off")
    ax2.set_title("Ground-truth")

    masked_lv = np.ma.masked_where(pred_mask == 0, pred_mask)
    ax3.imshow(image, cmap="gray")
    ax3.imshow(masked_lv, 'jet', interpolation='bilinear', alpha=0.33, vmax=3)
    ax3.axis("off")
    ax3.set_title("Automatic Segmentation")

    fig.suptitle(f"{case} - Jaccard {metric_value:.4f}", y=0.9)
    parent_dir = os.path.join("Plots/LVSC_PlotPreds", descriptor)
    os.makedirs(parent_dir, exist_ok=True)
    plt.savefig(os.path.join(parent_dir, f"{case}.jpg"), dpi=300)
    plt.close()

In [6]:
value_ranges = [0, 0.25, 0.5, 0.75, 1]
values_desc = ["awful", "average", "good", "excellent"]

for i in range(4):
    print(f"{value_ranges[i]} - {value_ranges[i+1]}: {values_desc[i]}")

0 - 0.25: awful
0.25 - 0.5: average
0.5 - 0.75: good
0.75 - 1: excellent


In [7]:
LV_INDEX = 2

In [8]:
lvsc_model = model_selector(
    "segmentation", "resnet34_unet_imagenet_encoder_scse_hypercols", 1, from_swa=True,
    in_channels=3, devices=[0], checkpoint="../checks/LVSC/n1_100_swa.pt"
)

lvsc_model.eval()
print("LVSC model loaded")


--- Frosted pretrained backbone! ---
Model total number of parameters: 35740845
Loaded model from checkpoint: ../checks/LVSC/n1_100_swa.pt
LVSC model loaded


In [9]:
_, _, val_aug = data_augmentation_selector(
    "none", 224, 224, "padd"
)

Using None Data Augmentation
Padding masks!
Padding masks!


In [10]:
batch_size = 16
add_depth = True
normalization = "standardize"

test_dataset = LVSC2Dataset(
    mode="test", transform=val_aug, img_transform=[],
    add_depth=add_depth, normalization=normalization, relative_path="../"
)

lvsc_loader =  DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True,
    drop_last=False, collate_fn=test_dataset.custom_collate
)

In [11]:
len(test_dataset)

18462

In [12]:
df_dict = {"patient":[], "slice":[], "phase":[]}
for path in test_dataset.data:
    path_slice = int(path[path.find("SA")+len("SA"):path.find("_ph")])
    path_phase = int(path[path.find("_ph")+len("_ph"):path.find(".dcm")])
    patient = path.split("/")[-2]
    
    df_dict["patient"].append(patient)
    df_dict["slice"].append(path_slice)
    df_dict["phase"].append(path_phase)

df_info = pd.DataFrame(df_dict)
df_info.head()

Unnamed: 0,patient,slice,phase
0,DET0009501,7,0
1,DET0009501,7,1
2,DET0009501,7,10
3,DET0009501,7,11
4,DET0009501,7,12


In [13]:
lvsc_model.eval()

metrics = {'img_id':[], 'iou':[], 'dice':[], 'hd':[], 'assd': []}
plot_per_range = [50,50,50,50] # ["awful", "average", "good", "excellent"]

with torch.no_grad():
    for batch_indx, batch in enumerate(tqdm(lvsc_loader)):
        img_id = batch["img_id"]
        image = batch["image"].cuda()
        prob_preds = lvsc_model(image)
    
        original_masks = batch["original_mask"]
        original_img = batch["original_img"]
        
        for pred_indx, single_pred in enumerate(prob_preds):
            
            path = img_id[pred_indx]
            path_slice = int(path[path.find("SA")+len("SA"):path.find("_ph")])
            path_phase = int(path[path.find("_ph")+len("_ph"):])
            path_patient = path.split("/")[0]
            middle_phase = df_info.loc[(df_info["patient"]==path_patient)]["phase"].max()//2
            current_phase = None
            
            if path_phase == 0:
                current_phase = "ED"
            elif path_phase == middle_phase:
                current_phase = "ES"
                            
            #if current_phase is None:
            #    continue
            
        
            if torch.is_tensor(original_masks[pred_indx]):
                original_mask = original_masks[pred_indx].data.cpu().numpy().astype(np.uint8).squeeze()
            else:  # numpy array
                original_mask = original_masks[pred_indx].astype(np.uint8)
            
            pred_mask = reshape_masks(
                torch.sigmoid(single_pred).squeeze(0).data.cpu().numpy(),
                original_mask.shape, "padd"
            )
            lv_pred_mask = np.where(pred_mask > 0.5, 1, 0).astype(np.int32)
        
            jc_score = jaccard_coef(lv_pred_mask, original_mask)
            dc_score = dice_coef(lv_pred_mask, original_mask)
            hd_score = secure_hd(lv_pred_mask, original_mask)
            assd_score = secure_assd(lv_pred_mask, original_mask)
            
            case = img_id[pred_indx]
            
            metrics['img_id'].append(case)
            metrics['iou'].append(jc_score)
            metrics['dice'].append(dc_score)
            metrics['hd'].append(hd_score)
            metrics['assd'].append(assd_score)
            
            """
            save_name = os.path.join(preds_dir, f"{case}.png")
            cv2.imwrite(save_name, lv_pred_mask)
            in_ = cv2.imread(save_name, cv2.IMREAD_GRAYSCALE)
            if not np.allclose(in_, lv_pred_mask):
                raise AssertionError('File read error: {:s}'.format(save_name))
            """    
            
            if jc_score <0.25:
                if plot_per_range[0] > 0:
                    plot_per_range[0] = plot_per_range[0] - 1
                    save_pred(
                        original_img[pred_indx], original_mask, lv_pred_mask, 
                        case, jc_score, values_desc[0]
                    )
            elif jc_score>0.25 and jc_score<0.5:
                if plot_per_range[1] > 0:
                    plot_per_range[1] = plot_per_range[1] - 1
                    save_pred(
                        original_img[pred_indx], original_mask, lv_pred_mask, 
                        case, jc_score, values_desc[1]
                    )
            elif jc_score>0.5 and jc_score<0.75:
                if plot_per_range[2] > 0:
                    plot_per_range[2] = plot_per_range[2] - 1
                    save_pred(
                        original_img[pred_indx], original_mask, lv_pred_mask, 
                        case, jc_score, values_desc[2]
                    )
            elif jc_score>0.75:
                if plot_per_range[3] > 0:
                    plot_per_range[3] = plot_per_range[3] - 1
                    save_pred(
                        original_img[pred_indx], original_mask, lv_pred_mask, 
                        case, jc_score, values_desc[3]
                    )
        if np.array(plot_per_range).sum()==0:
            break

  0%|          | 0/1154 [00:00<?, ?it/s]

In [19]:
df = pd.DataFrame(metrics)
df.head()

Unnamed: 0,img_id,iou,dice,hd,assd
0,DET0009501_SA7_ph0,7.54717e-14,7.54717e-14,-1,-1
1,DET0009501_SA7_ph1,7.326007e-14,7.326007e-14,-1,-1
2,DET0009501_SA7_ph10,6.6313e-14,6.6313e-14,-1,-1
3,DET0009501_SA7_ph11,6.811989e-14,6.811989e-14,-1,-1
4,DET0009501_SA7_ph12,7.012623e-14,7.012623e-14,-1,-1


## Get metrics by replacing infinite distance values with max value

In [14]:
print(f"Mean IOU: {df['iou'].mean()}")
print(f"Mean DICE: {df['dice'].mean()}")
print(f"Mean Hausdorff: {df['hd'].mean()}")
print(f"Mean ASSD: {df['assd'].mean()}")

Mean IOU: 0.6827378761524501
Mean DICE: 0.7934173178549807
Mean Hausdorff: 6.071080533710504
Mean ASSD: 1.6513260025628804


In [15]:
max_hausdorff = df["hd"].max()
max_assd = df["assd"].max()

In [16]:
df["hd"].replace(-1, max_hausdorff, inplace=True)
df["assd"].replace(-1, max_assd, inplace=True)

In [17]:
print(f"Mean IOU: {df['iou'].mean()}")
print(f"Mean DICE: {df['dice'].mean()}")
print(f"Mean Hausdorff: {df['hd'].mean()}")
print(f"Mean ASSD: {df['assd'].mean()}")

Mean IOU: 0.6827378761524501
Mean DICE: 0.7934173178549807
Mean Hausdorff: 7.619092224630712
Mean ASSD: 2.443492124753212


In [18]:
df.to_csv("results.csv", index=False)