In [1]:
import os, gc, torch, PIL, pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


from utils.plot import disease_cmap
from utils.coco_eval import get_ar_ap, get_eval_params_dict
from our_radiologist.load import get_anns

from data.dataset import collate_fn
from utils.transforms import get_transform
from utils.pred import pred_thrs_check
from utils.engine import xami_evaluate
from models.load import get_trained_model

from models.load import TrainedModels
from collections import OrderedDict
from matplotlib.patches import Rectangle
from data.dataset import ReflacxDataset
from data.dataset import OurRadiologsitsDataset

## Suppress the assignement warning from pandas.
pd.options.mode.chained_assignment = None  # default='warn

## Supress user warning
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

%matplotlib inline

In [2]:
gc.collect()
# torch.cuda.memory_summary(device=None, abbreviated=False)

use_gpu = torch.cuda.is_available()
device = 'cuda' if use_gpu else 'cpu'
print(f"This notebook will running on device: [{device}]")

if use_gpu:
    torch.cuda.empty_cache()

This notebook will running on device: [cuda]


In [3]:
# 0.23623015873015873 0.11900175731858897 (0.01, thrs)
# 0.35710317460317464 0.1440243872224295 (0.01, None)
# 0.22956349206349205 0.11206579987412377 (0.05, thrs)
# 0.35710317460317464 0.14248711487254806 (0.05, None)
# 0.21103174603174607 0.1446907131542274 (0.3, None)

In [4]:
for select_model in  TrainedModels:
    print(select_model.value)

val_ar_0_5230_ap_0_2576_test_ar_0_5678_ap_0_2546_epoch28_WithoutClincal_03-28-2022 06-56-13_original
val_ar_0_4575_ap_0_2689_test_ar_0_4953_ap_0_2561_epoch40_WithoutClincal_03-28-2022 09-15-40_custom_without_clinical
val_ar_0_5363_ap_0_2963_test_ar_0_5893_ap_0_2305_epoch36_WithClincal_03-28-2022 20-06-43_custom_with_clinical
val_ar_0_5126_ap_0_2498_test_ar_0_5607_ap_0_2538_epoch18_WithClincal_03-28-2022 10-18-55_custom_with_clinical
val_ar_0_3993_ap_0_2326_test_ar_0_4957_ap_0_2390_epoch50_WithClincal_03-28-2022 16-06-00_custom_with_clinical
val_ar_0_4955_ap_0_2942_test_ar_0_5449_ap_0_2566_epoch28_WithClincal_03-28-2022 17-25-34_custom_with_clinical


In [5]:
XAMI_MIMIC_PATH =  "D:\XAMI-MIMIC"

labels_cols = [
    "Enlarged cardiac silhouette",
    "Atelectasis",
    "Pleural abnormality",
    "Consolidation",
    "Pulmonary edema",
    #  'Groundglass opacity', # 6th disease.
]

In [None]:
def save_iou_results(evaluator, suffix, model_path):
    ap_ar_dict = OrderedDict({thrs: []  for thrs in evaluator.coco_eval['bbox'].params.iouThrs})

    for thrs in evaluator.coco_eval['bbox'].params.iouThrs:
        test_ar, test_ap = get_ar_ap(
            evaluator,
            areaRng='all',
            maxDets= 10,
            iouThr=thrs,
        )

        ap_ar_dict[thrs].append({
            'ar': test_ar,
            'ap': test_ap,
        })

        print(f"[{thrs:.4f}] | AR [{test_ar:.4f}] | AP [{test_ap:.4f}]")

    ## iouThr=0.3
    # 0.5699603174603174 0.23895678925202643 (custom, with clinical)
    # 0.5992460317460317 0.2928311315704941 (custom, without clinical)
    # 0.5244047619047619 0.22751787710826604 (original, without clinical)

    with open(
        os.path.join("eval_results", f"{model_path}_{suffix}.pkl"), "wb",
    ) as training_record_f:
        pickle.dump(ap_ar_dict, training_record_f)

### iouThr=0.3, With score thrs.
# 0.45178571428571423 0.20556006215252126 (custom, with clinical)
# 0.481547619047619 0.25642166456776067 (custom, without clinical)
# 0.36710317460317465 0.1927187078516919(original, without clinical)

### IoBB Result

In [7]:
for select_model in  TrainedModels:
    model, train_info = get_trained_model(select_model, labels_cols, device, include_train_info=True)
    model.eval()

    dataset_params_dict = {
        "XAMI_MIMIC_PATH": XAMI_MIMIC_PATH,
        "with_clinical": train_info.model_setup.use_clinical,
        "dataset_mode": "unified",
        "bbox_to_mask": True,
        "labels_cols": labels_cols,
    }

    detect_eval_dataset = ReflacxDataset(
        **{**dataset_params_dict, "dataset_mode": "unified",},
        transforms=get_transform(train=False),
    )

    radiologists_anns = get_anns("radiologists_annotated", detect_eval_dataset)


    iou_thrs = np.array([0.00, 0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 , 0.35, 0.4 , 0.45, 0.5 , 0.55,
        0.6 , 0.65, 0.7 , 0.75, 0.8 , 0.85, 0.9 , 0.95, 1.])
        
    eval_params_dict = get_eval_params_dict(detect_eval_dataset, iou_thrs=iou_thrs)

    test_dataset = ReflacxDataset(
        **dataset_params_dict, split_str="test", transforms=get_transform(train=False),
    )

    val_dataset = ReflacxDataset(
        **dataset_params_dict, split_str="val", transforms=get_transform(train=False),
    )

    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn,
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn,
    )

    radiologist_dataset = OurRadiologsitsDataset(detect_eval_dataset, radiologists_anns)

    radiologist_dataloader = torch.utils.data.DataLoader(
        radiologist_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn,
    )

    model.roi_heads.score_thresh = 0.05

    score_thres= {
        'Enlarged cardiac silhouette': 0.3,
        'Atelectasis': 0.2,
        'Pleural abnormality': 0.1,
        'Consolidation': 0.05,
        'Pulmonary edema': 0.1,
    }

    test_evaluator = xami_evaluate(
        model,
        test_dataloader,
        device=device,
        params_dict=eval_params_dict,
        # score_thres=score_thres,
    )  

    val_evaluator = xami_evaluate(
        model,
        val_dataloader,
        device=device,
        params_dict=eval_params_dict,
        # score_thres=score_thres,
    )

    radiologist_evaluator = xami_evaluate(
            model, radiologist_dataloader, device=device, params_dict=eval_params_dict
    )

    save_iou_results(test_evaluator, "test", select_model.value)
    save_iou_results(val_evaluator, "val", select_model.value)
    save_iou_results(radiologist_evaluator, "our", select_model.value)


creating index...
index created!
creating index...
index created!
Test:  [ 0/23]  eta: 0:01:25  model_time: 2.7627 (2.7627)  evaluator_time: 0.2422 (0.2422)  time: 3.7008  data: 0.5309  max mem: 2329
Test:  [22/23]  eta: 0:00:01  model_time: 0.3301 (0.4320)  evaluator_time: 0.2801 (0.3158)  time: 1.2495  data: 0.5313  max mem: 2459
Test: Total time: 0:00:31 (1.3693 s / it)
Averaged stats: model_time: 0.3301 (0.4320)  evaluator_time: 0.2801 (0.3158)
Accumulating evaluation results...
DONE (t=0.04s).
Accumulating evaluation results...
DONE (t=0.04s).
IoU metric: bbox
 Average Precision  (AP) @[ IoBB=0.00:1.00 | area=   all | maxDets=100 ] = 0.215
 Average Precision  (AP) @[ IoBB=0.50      | area=   all | maxDets= 10 ] = 0.259
 Average Precision  (AP) @[ IoBB=0.75      | area=   all | maxDets= 10 ] = 0.149
 Average Precision  (AP) @[ IoBB=0.00:1.00 | area= small | maxDets= 10 ] = -1.000
 Average Precision  (AP) @[ IoBB=0.00:1.00 | area=medium | maxDets= 10 ] = -1.000
 Average Precision  (

In [8]:
raise StopIteration()

StopIteration: 

In [None]:
# clinical_cond =  "With clincal" if use_clinical else "Without clinical"
# print(f"{clinical_cond} | Average Recall [{test_ar:.4f}] | Average Precision [{test_ap:.4f}]")

# Plotting

In [9]:
score_thres= {
    'Enlarged cardiac silhouette': 0.4,
    'Atelectasis': 0.2,
    'Pleural abnormality': 0.2,
    'Consolidation': 0.1,
    'Pulmonary edema': 0.2,
}

In [None]:
def plot_individual_bbox(
    ann, target, pred, label_idx_to_disease, legend_elements, disease_color_code_map,
):
    n_pred_boxes = len(pred["boxes"])

    fig, axes = plt.subplots(
        1, n_pred_boxes + 2, figsize=((n_pred_boxes + 2) * 10, 10), dpi=80, sharex=True
    )

    fig.legend(handles=legend_elements, loc="upper right")
    img = PIL.Image.open(ann["image_path"]).convert("RGB")

    for ax in axes:
        ax.imshow(img)

    axes[0].set_title(f"REFLACX ({len(target['boxes'].detach().cpu().numpy())})")
    axes[1].set_title(f"Our Radiologists ({len(ann['boxes'].detach().cpu().numpy())})")


    fig.suptitle(f"{ann['image_path']}_({ann['encoding']})")
    for idx, (label, bbox, score) in enumerate(
        zip(
            pred["labels"].detach().cpu().numpy(),
            pred["boxes"].detach().cpu().numpy(),
            pred["scores"].detach().cpu().numpy(),
        )
    ):
        disease = label_idx_to_disease(label)
        c = disease_color_code_map[disease]

        axes[idx + 2].add_patch(
            Rectangle(
                (bbox[0], bbox[1]),
                bbox[2] - bbox[0],
                bbox[3] - bbox[1],
                fill=False,
                color=c,
                linewidth=2,
            )
        )

        axes[idx + 2].set_title(f"Prediction_{disease}")

        axes[idx + 2].text(
            bbox[0],
            bbox[1],
            f"{disease} ({score:.2f})",
            color="black",
            backgroundcolor=c,
        )

    reflacx_recs = []
    for label, bbox in zip(
        target["labels"].detach().cpu().numpy(), target["boxes"].detach().cpu().numpy()
    ):
        disease = label_idx_to_disease(label)
        c = disease_color_code_map[disease]
        reflacx_recs.append(
            Rectangle(
                (bbox[0], bbox[1]),
                bbox[2] - bbox[0],
                bbox[3] - bbox[1],
                fill=False,
                color=c,
                linewidth=2,
            )
        )
        axes[0].text(bbox[0], bbox[1], disease, color="black", backgroundcolor=c)

    for rec in reflacx_recs:
        axes[0].add_patch(rec)

    ann_recs = []
    for label, bbox in zip(
        ann["labels"].detach().cpu().numpy(), ann["boxes"].detach().cpu().numpy()
    ):
        disease = label_idx_to_disease(label)
        c = disease_color_code_map[disease]
        ann_recs.append(
            Rectangle(
                (bbox[0], bbox[1]),
                bbox[2] - bbox[0],
                bbox[3] - bbox[1],
                fill=False,
                color=c,
                linewidth=2,
            )
        )
        axes[1].text(bbox[0], bbox[1], disease, color="black", backgroundcolor=c)

    for rec in ann_recs:
        axes[1].add_patch(rec)

    plt.plot()
    plt.pause(0.01)

    return fig

In [None]:
model.roi_heads.score_thresh = 0.01 

for ann in radiologists_anns:
    idx= detect_eval_dataset.get_idxs_from_dicom_id(ann['dicom_id'])[0]
    model.eval()
    data = collate_fn([detect_eval_dataset[idx]])
    data = detect_eval_dataset.prepare_input_from_data(data, device)
    target = data[-1]
    pred = model(*data[:-1])
    pred = pred[0]
    pred = pred_thrs_check(pred, detect_eval_dataset, score_thres, device)

    fig = plot_individual_bbox(
        ann,
        target[0],
        pred,
        detect_eval_dataset.label_idx_to_disease,
        legend_elements,
        disease_cmap["solid"],
    )

    save_dir = os.path.join("radiologist_ann_results_separate", "with_clinical"if use_clinical else "without_clinical", ann['encoding'])
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, f"{ann['dicom_id']}.jpg"))


In [None]:
# def plot_three_bbox(
#     ann, target, pred, label_idx_to_disease, legend_elements, disease_color_code_map
# ):

#     fig, (gt_ax, pred_ax, ann_ax) = plt.subplots(1, 3, figsize=(30, 10), dpi=80, sharex=True)

#     fig.suptitle(f"{target['image_path']}_({ann['encoding']})")

#     fig.legend(handles=legend_elements, loc="upper right")

#     img = PIL.Image.open(target["image_path"]).convert("RGB")

#     gt_ax.imshow(img)
#     gt_ax.set_title(f"REFLACX ({len(target['boxes'].detach().cpu().numpy())})")
#     pred_ax.imshow(img)
#     pred_ax.set_title(f"Predictions ({len(pred['boxes'].detach().cpu().numpy())})")
#     ann_ax.imshow(img)
#     ann_ax.set_title(f"Our Radiologists ({len(ann['boxes'].detach().cpu().numpy())})")

#     # load image
#     gt_recs = []
#     pred_recs = []
#     ann_recs = []

#     for label, bbox, score in zip(
#         pred["labels"].detach().cpu().numpy(),
#         pred["boxes"].detach().cpu().numpy(),
#         pred["scores"].detach().cpu().numpy(),
#     ):
#         disease = label_idx_to_disease(label)
#         c = disease_color_code_map[disease]
#         pred_recs.append(
#             Rectangle(
#                 (bbox[0], bbox[1]),
#                 bbox[2] - bbox[0],
#                 bbox[3] - bbox[1],
#                 fill=False,
#                 color=c,
#                 linewidth=2,
#             )
#         )
#         pred_ax.text(
#             bbox[0],
#             bbox[1],
#             f"{disease} ({score:.2f})",
#             color="black",
#             backgroundcolor=c,
#         )

#     for rec in pred_recs:
#         pred_ax.add_patch(rec)

#     for label, bbox in zip(
#         target["labels"].detach().cpu().numpy(), target["boxes"].detach().cpu().numpy()
#     ):
#         disease = label_idx_to_disease(label)
#         c = disease_color_code_map[disease]
#         gt_recs.append(
#             Rectangle(
#                 (bbox[0], bbox[1]),
#                 bbox[2] - bbox[0],
#                 bbox[3] - bbox[1],
#                 fill=False,
#                 color=c,
#                 linewidth=2,
#             )
#         )
#         gt_ax.text(bbox[0], bbox[1], disease, color="black", backgroundcolor=c)

#     for rec in gt_recs:
#         gt_ax.add_patch(rec)


    
#     for label, bbox in zip(
#         ann["labels"].detach().cpu().numpy(), ann["boxes"].detach().cpu().numpy()
#     ):
#         disease = label_idx_to_disease(label)
#         c = disease_color_code_map[disease]
#         ann_recs.append(
#             Rectangle(
#                 (bbox[0], bbox[1]),
#                 bbox[2] - bbox[0],
#                 bbox[3] - bbox[1],
#                 fill=False,
#                 color=c,
#                 linewidth=2,
#             )
#         )
#         ann_ax.text(bbox[0], bbox[1], disease, color="black", backgroundcolor=c)

#     for rec in ann_recs:
#         ann_ax.add_patch(rec)

#     plt.plot()
#     plt.pause(0.01)

#     return fig


In [None]:
# model.roi_heads.score_thresh = 0.01 

# for ann in radiologists_anns:
#     idx= detect_eval_dataset.get_idxs_from_dicom_id(ann['dicom_id'])[0]
#     model.eval()
#     data = collate_fn([detect_eval_dataset[idx]])
#     data = detect_eval_dataset.prepare_input_from_data(data, device)
#     target = data[-1]
#     pred = model(*data[:-1])
#     pred = pred[0]
#     pred = pred_thrs_check(pred, detect_eval_dataset, score_thres, device)

#     fig = plot_three_bbox(
#         ann,
#         target[0],
#         pred,
#         detect_eval_dataset.label_idx_to_disease,
#         legend_elements,
#         disease_cmap["solid"],
#     )

#     save_dir = os.path.join("radiologist_ann_results", "with_clinical"if use_clinical else "without_clinical", ann['encoding'])
#     os.makedirs(save_dir, exist_ok=True)
#     fig.savefig(os.path.join(save_dir, f"{ann['dicom_id']}.jpg"))
