In [1]:
import os, gc, torch, PIL

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


from utils.eval import save_iou_results
from utils.engine import xami_evaluate
from models.load import get_trained_model
from utils.coco_eval import get_eval_params_dict
from data.dataset import ReflacxDataset, OurRadiologsitsDataset
from data.dataset import collate_fn
from utils.transforms import get_transform
from our_radiologist.load import get_anns
from utils.coco_eval import get_ar_ap
from utils.print import print_title

## Suppress the assignement warning from pandas.
pd.options.mode.chained_assignment = None  # default='warn

## Supress user warning
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

%matplotlib inline

In [2]:
# Drop unreachable Python objects left over from earlier runs in this kernel
# before a model is loaded.
gc.collect()

use_gpu = torch.cuda.is_available()
device = "cuda" if use_gpu else "cpu"
print(f"This notebook will run on device: [{device}]")

if use_gpu:
    # Release cached, currently unused CUDA memory back to the driver.
    torch.cuda.empty_cache()

This notebook will running on device: [cuda]


In [3]:
# Root directory of the XAMI-MIMIC dataset. Can be overridden with the
# XAMI_MIMIC_PATH environment variable. Raw string so the Windows backslash
# is not parsed as an (invalid) escape sequence.
XAMI_MIMIC_PATH = os.environ.get("XAMI_MIMIC_PATH", r"D:\XAMI-MIMIC")

# NOTE(review): the evaluation loop reads these flags from each ModelSetup,
# not from these module-level variables; kept in case other cells use them.
use_clinical = False
use_custom_modal = False
use_early_stop_model = True

# Setup evaluation parameters

In [4]:
from dataclasses import dataclass

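# Per-class detection score thresholds (commented out, not used): evaluation
# instead applies one global threshold via `model.roi_heads.score_thresh`.
# score_thres = {
#     "Enlarged cardiac silhouette": 0.3,
#     "Atelectasis": 0.2,
#     "Pleural abnormality": 0.1,
#     "Consolidation": 0.05,
#     "Pulmonary edema": 0.1,
# }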

@dataclass
class ModelSetup:
    """One trained-model variant to evaluate.

    The flags select which trained model ``get_trained_model`` loads, and
    ``use_clinical`` also controls ``with_clinical`` in the dataset parameters.
    """

    use_clinical: bool  # include clinical data (dataset + model)
    use_custom_modal: bool  # forwarded as use_custom_modal to get_trained_model
    use_early_stop_model: bool  # forwarded as use_early_stop_model to get_trained_model
    name: str  # unique label used to select this setup in SETUPS_TO_EVALUATE


# Every known model variant. Choose which ones run via SETUPS_TO_EVALUATE
# instead of commenting entries in and out.
ALL_KNOWN_MODEL_SETUPS = [
    ModelSetup(
        name="original",
        use_clinical=False,
        use_custom_modal=False,
        use_early_stop_model=True,
    ),
    ModelSetup(
        name="custom_without_clinical",
        use_clinical=False,
        use_custom_modal=True,
        use_early_stop_model=True,
    ),
    ModelSetup(
        name="custom_with_clinical",
        use_clinical=True,
        use_custom_modal=True,
        use_early_stop_model=True,
    ),
]

# Names of the setups to evaluate in this run.
SETUPS_TO_EVALUATE = {"custom_with_clinical"}

_unknown_setups = SETUPS_TO_EVALUATE - {s.name for s in ALL_KNOWN_MODEL_SETUPS}
if _unknown_setups:
    raise ValueError(f"Unknown model setups requested: {sorted(_unknown_setups)}")

# Selected setups, in the order they appear in ALL_KNOWN_MODEL_SETUPS.
all_model_setups = [s for s in ALL_KNOWN_MODEL_SETUPS if s.name in SETUPS_TO_EVALUATE]



# Run evaluation.

In [5]:
# Evaluate every selected model setup on three splits: REFLACX test, REFLACX val
# and the radiologist-annotated set ("our"). Each runs for all diseases together
# and for each disease alone. Pass 1 uses the standard IoU range and writes
# AP/AR tables to CSV. Pass 2 uses a full 0.00-1.00 IoU sweep and stores the
# results via save_iou_results.

# Directory for the AP/AR summary CSVs. Created up front so to_csv cannot fail
# on a missing directory after a long evaluation run.
EVAL_RESULTS_DIR = "eval_results"
EVAL_BATCH_SIZE = 4
# Minimum detection score kept by the ROI heads during evaluation.
ROI_SCORE_THRESH = 0.05

# IoU thresholds 0.50:0.05:0.95 for the AP/AR summary tables.
normal_iou_thrs = np.array([0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95])
# Full sweep 0.00:0.05:1.00. Spelled out (not np.linspace) so the float values
# stay exactly as before.
all_range_iou_thrs = np.array(
    [0.00, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
     0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.]
)


def _make_eval_loader(dataset):
    """Return an unshuffled (reproducible-order) evaluation DataLoader."""
    return torch.utils.data.DataLoader(
        dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False, collate_fn=collate_fn,
    )


def _category_selections(dataset):
    """Yield (disease_str, cat_ids): all diseases together first, then each alone."""
    all_disease_ids = [dataset.disease_to_idx(d) for d in dataset.labels_cols]
    yield "all", all_disease_ids
    for cat_id in all_disease_ids:
        yield dataset.label_idx_to_disease(cat_id), [cat_id]


def _evaluate_splits(model, loaders, params_dict, cat_ids):
    """Restrict params_dict to cat_ids and run xami_evaluate on every loader.

    params_dict's "bbox" and "segm" entries are updated in place.
    Returns {split_name: evaluator} in the same order as `loaders`.
    """
    params_dict["bbox"].catIds = cat_ids
    params_dict["segm"].catIds = cat_ids
    return {
        split_name: xami_evaluate(model, loader, device=device, params_dict=params_dict)
        for split_name, loader in loaders.items()
    }


def _ap_ar_table(evaluators):
    """Build one row per split with AP/AR (area "all", maxDets=10, iouThr=None)."""
    rows = []
    for split_name, evaluator in evaluators.items():
        ar, ap = get_ar_ap(evaluator, areaRng="all", maxDets=10, iouThr=None)
        rows.append(
            {
                "dataset": split_name,
                "AP@[IoU = 0.50:0.95]": ap,
                "AR@[IoU = 0.50:0.95]": ar,
            }
        )
    return pd.DataFrame(rows)


os.makedirs(EVAL_RESULTS_DIR, exist_ok=True)

for model_setup in all_model_setups:

    dataset_params_dict = {
        "XAMI_MIMIC_PATH": XAMI_MIMIC_PATH,
        "with_clinical": model_setup.use_clinical,
        "using_full_reflacx": False,
        "bbox_to_mask": True,
        "labels_cols": [
            "Enlarged cardiac silhouette",
            "Atelectasis",
            "Pleural abnormality",
            "Consolidation",
            "Pulmonary edema",
            #  'Groundglass opacity', # 6th disease.
        ],
    }

    # Dataset without a split_str; it supplies category ids/names and the base
    # for the radiologist-annotated set.
    detect_eval_dataset = ReflacxDataset(
        **dataset_params_dict, transforms=get_transform(train=False),
    )
    test_dataset = ReflacxDataset(
        **dataset_params_dict, split_str="test", transforms=get_transform(train=False),
    )
    val_dataset = ReflacxDataset(
        **dataset_params_dict, split_str="val", transforms=get_transform(train=False),
    )

    radiologists_ann = get_anns("radiologists_annotated", detect_eval_dataset)
    radiologist_dataset = OurRadiologsitsDataset(detect_eval_dataset, radiologists_ann)

    test_dataloader = _make_eval_loader(test_dataset)
    val_dataloader = _make_eval_loader(val_dataset)
    radiologist_dataloader = _make_eval_loader(radiologist_dataset)

    # Keys double as the CSV "dataset" labels and the save_iou_results name prefixes.
    eval_loaders = {
        "test": test_dataloader,
        "val": val_dataloader,
        "our": radiologist_dataloader,
    }

    model, model_path = get_trained_model(
        detect_eval_dataset,
        use_early_stop_model=model_setup.use_early_stop_model,
        use_custom_modal=model_setup.use_custom_modal,
        use_clinical=model_setup.use_clinical,
        device=device,
    )
    model.eval()

    normal_eval_params_dict = get_eval_params_dict(detect_eval_dataset, iou_thrs=normal_iou_thrs)
    all_range_eval_params_dict = get_eval_params_dict(detect_eval_dataset, iou_thrs=all_range_iou_thrs)

    model.roi_heads.score_thresh = ROI_SCORE_THRESH

    # Pass 1: standard IoU range -> AP/AR summary table per category.
    for disease_str, cat_ids in _category_selections(detect_eval_dataset):
        evaluators = _evaluate_splits(model, eval_loaders, normal_eval_params_dict, cat_ids)
        df = _ap_ar_table(evaluators)
        # NOTE(review): model_path is embedded verbatim in the file name; if it
        # contains path separators or is absolute, the CSV will not land in
        # EVAL_RESULTS_DIR as expected — confirm its format.
        df.to_csv(os.path.join(EVAL_RESULTS_DIR, f"{model_path}_{disease_str}.csv"))
        print_title(disease_str)
        print(df)

    # Pass 2: full IoU sweep -> per-IoU results for every split and category.
    for disease_str, cat_ids in _category_selections(detect_eval_dataset):
        evaluators = _evaluate_splits(model, eval_loaders, all_range_eval_params_dict, cat_ids)
        for split_name, evaluator in evaluators.items():
            save_iou_results(evaluator, f"{split_name}_{disease_str}", model_path)



{'rpn_nms_thresh': 0.3, 'box_detections_per_img': 6, 'box_nms_thresh': 0.2, 'rpn_score_thresh': 0.0, 'box_score_thresh': 0.05}
c1
None
creating index...
index created!
creating index...
index created!
creating index...
index created!
Test:  [ 0/26]  eta: 0:02:13  model_time: 4.0611 (4.0611)  evaluator_time: 0.2599 (0.2599)  time: 5.1309  data: 0.7089  max mem: 1199
Test:  [25/26]  eta: 0:00:01  model_time: 0.2731 (0.3998)  evaluator_time: 0.4540 (0.4224)  time: 1.4054  data: 0.6233  max mem: 1668
Test: Total time: 0:00:40 (1.5555 s / it)
Averaged stats: model_time: 0.2731 (0.3998)  evaluator_time: 0.4540 (0.4224)
Accumulating evaluation results...
DONE (t=0.03s).
Accumulating evaluation results...
DONE (t=0.03s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.069
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets= 10 ] = 0.154
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets= 10 ] = 0.043
 Average Precision  