In [1]:
import os, gc, torch

import numpy as np
import pandas as pd
# from models.load import TrainedModels

from utils.eval import save_iou_results
from utils.engine import xami_evaluate, get_iou_types
from models.load import get_trained_model
from utils.coco_eval import get_eval_params_dict
from data.datasets import  OurRadiologsitsDataset, collate_fn
from our_radiologist.load import get_anns
from utils.coco_utils import get_cocos, get_coco_api_from_dataset
from utils.eval import get_ap_ar
from utils.print import print_title
from utils.init import reproducibility, clean_memory_get_device
from data.load import get_datasets, get_dataloaders
from data.constants import XAMI_MIMIC_PATH, DEFAULT_REFLACX_LABEL_COLS
from utils.constants import full_iou_thrs, iou_thrs_5to95
from data.load  import seed_worker, get_dataloader_g
from tqdm import tqdm

## Suppress the chained-assignment warning from pandas.
pd.options.mode.chained_assignment = None  # default='warn'

## Suppress all warnings of category UserWarning.
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

%matplotlib inline

In [2]:
# Device handle that is later passed to get_trained_model and xami_evaluate.
# NOTE(review): the helper's name suggests it also frees cached memory before
# choosing the device -- confirm in utils.init.
device = clean_memory_get_device()
# NOTE(review): presumably seeds the random number generators so that runs are
# repeatable -- confirm in utils.init.
reproducibility()

This notebook will running on device: [CUDA]


In [3]:
from enum import Enum

# class TrainedModels(Enum):
#     CXR_Clinial_fusion1_fusion2="val_ar_0_5436_ap_0_1911_test_ar_0_5476_ap_0_3168_epoch49_WithClincal_05-23-2022 12-06-22_CXR_Clinical_roi_heads_spatialisation"
#     CXR_Clinical_fusion1 = "val_ar_0_5476_ap_0_1984_test_ar_0_6038_ap_0_2757_epoch41_WithClincal_05-30-2022 08-01-54_CXR_Clinical_fusion1"
#     CXR_Clinical_fusion2= "val_ar_0_4369_ap_0_2098_test_ar_0_4940_ap_0_2218_epoch58_WithClincal_05-30-2022 13-58-43_CXR_Clinical_fusion2"
#     CXR="val_ar_0_5659_ap_0_1741_test_ar_0_5390_ap_0_1961_epoch36_WithoutClincal_05-29-2022 12-29-51_CXR"

class TrainedModels(Enum):
    """Trained checkpoints that the evaluation cell can iterate over.

    Each member is handed to ``get_trained_model`` to load the checkpoint, and
    its ``.value`` string is reused as the model identifier when results are
    saved (``save_iou_results`` and the per-model CSV file name).

    The member names list the clinical features the model was trained with.
    NOTE(review): the value strings appear to encode validation/test AR and AP,
    the epoch, a timestamp and the run name -- confirm against the training
    code before relying on that layout.
    """

    all = "val_ar_0_5436_ap_0_1911_test_ar_0_5476_ap_0_3168_epoch49_WithClincal_05-23-2022 12-06-22_CXR_Clinical_roi_heads_spatialisation"
    age_temperature_heartrate_resprate_o2sat = "val_ar_0_4686_ap_0_1828_test_ar_0_5400_ap_0_2989_epoch44_WithClincal_09-16-2022 10-24-59_MDF-Net (age+temp+heartrate+resprate+o2sat)"
    gender = "val_ar_0_4993_ap_0_2692_test_ar_0_4841_ap_0_2037_epoch44_WithClincal_09-16-2022 13-22-05_MDF-Net (gender)"
    gender_age = "val_ar_0_4841_ap_0_2159_test_ar_0_5575_ap_0_2395_epoch34_WithClincal_09-16-2022 18-15-26_MDF-Net (gender+age)"
    gender_temp = "val_ar_0_4781_ap_0_2228_test_ar_0_5787_ap_0_2335_epoch50_WithClincal_09-16-2022 22-16-49_MDF-Net (gender+temp)"
    gender_heartrate = "val_ar_0_4938_ap_0_2016_test_ar_0_5225_ap_0_2671_epoch33_WithClincal_09-17-2022 06-11-20_MDF-Net (gender+heartrate)"
    gender_resprate = "val_ar_0_3951_ap_0_1903_test_ar_0_5381_ap_0_2582_epoch36_WithClincal_09-17-2022 09-02-43_MDF-Net (gender+resprate)"
    gender_age_temp = "val_ar_0_4386_ap_0_2188_test_ar_0_5029_ap_0_2371_epoch30_WithClincal_09-17-2022 11-29-19_MDF-Net (gender+age+temp)"
    gender_age_heartrate = "val_ar_0_4868_ap_0_2027_test_ar_0_6048_ap_0_2641_epoch44_WithClincal_09-21-2022 05-45-16_MDF-Net (gender+age+heartrate)"
    gender_age_resprate = "val_ar_0_5659_ap_0_2302_test_ar_0_6194_ap_0_2798_epoch44_WithClincal_09-21-2022 08-38-02_MDF-Net (gender+age+resprate)"
    # Disabled: an alternative checkpoint for the gender+age+heartrate model.
    # gender_age_heartrate_final = "val_ar_0_4312_ap_0_1680_test_ar_0_5492_ap_0_2980_epoch50_WithClincal_09-18-2022 05-04-22_MDF-Net (gender+age+heartrate)"
    # Feature combination noted here without a checkpoint:
    # age_temperature_heartrate_resprate_o2sat_sbp_dbp_pain_acuity_gender

In [4]:
# select_model = TrainedModels.gender_age_heartrate_final

# model, train_info, _, _ = get_trained_model(
#             select_model,
#             DEFAULT_REFLACX_LABEL_COLS,
#             device,
#             rpn_nms_thresh=0.3,
#             box_detections_per_img=10,
#             box_nms_thresh=0.2,
#             rpn_score_thresh=0.0,
#             box_score_thresh=0.05,
#         )

In [5]:
# train_info.best_ap_val_model_path

In [6]:
# IoU threshold grids handed to get_eval_params_dict in the evaluation cell.
normal_iou_thrs, all_range_iou_thrs = iou_thrs_5to95, full_iou_thrs

# Box score thresholds; each model is loaded and evaluated once per entry.
score_thresholds = [0.05]

# Run evaluation.

In [7]:
from data.constants import DEFAULT_MIMIC_CLINICAL_CAT_COLS, DEFAULT_MIMIC_CLINICAL_NUM_COLS


# Evaluate every listed checkpoint at every score threshold, once over all
# lesion categories together and once per individual category. For each
# combination the evaluators are saved with save_iou_results and an AP/AR
# summary at IoBB = 0.50 is written to eval_results/ and printed.

# The CSV summaries below are written into this folder; create it up front so
# the first to_csv call cannot fail on a fresh checkout.
os.makedirs("eval_results", exist_ok=True)

for select_model in tqdm(
    [
        TrainedModels.all,
        TrainedModels.age_temperature_heartrate_resprate_o2sat,
        TrainedModels.gender,
        TrainedModels.gender_age,
        TrainedModels.gender_temp,
        TrainedModels.gender_heartrate,
        TrainedModels.gender_resprate,
        TrainedModels.gender_age_temp,
        TrainedModels.gender_age_heartrate,
        TrainedModels.gender_age_resprate,
        # TrainedModels.CXR_Clinical_fusion1,
        # TrainedModels.CXR_Clinical_fusion2,
        # TrainedModels.CXR_Clinial_fusion1_fusion2,
        # TrainedModels.CXR
    ]
):

    for score_thrs in score_thresholds:

        # The score threshold is applied when the model is built
        # (box_score_thresh), so the checkpoint is reloaded for each threshold.
        model, train_info, _, _ = get_trained_model(
            select_model,
            DEFAULT_REFLACX_LABEL_COLS,
            device,
            rpn_nms_thresh=0.3,
            box_detections_per_img=10,
            box_nms_thresh=0.2,
            rpn_score_thresh=0.0,
            box_score_thresh=score_thrs,
        )

        model.eval()

        model_setup = train_info.model_setup
        iou_types = get_iou_types(model, model_setup)

        # Build the datasets with the same options the checkpoint was trained with.
        dataset_params_dict = {
            "XAMI_MIMIC_PATH": XAMI_MIMIC_PATH,
            "with_clinical": model_setup.use_clinical,
            "dataset_mode": model_setup.dataset_mode,
            "bbox_to_mask": model_setup.use_mask,
            "normalise_clinical_num": model_setup.normalise_clinical_num,
            "labels_cols": DEFAULT_REFLACX_LABEL_COLS,
        }

        detect_eval_dataset, train_dataset, val_dataset, test_dataset = get_datasets(
            dataset_params_dict=dataset_params_dict,
        )

        train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
            train_dataset, val_dataset, test_dataset, batch_size=4,
        )

        if model_setup.use_clinical:
            # A setup object may lack either feature list. Fill in each missing
            # list on its own, so an existing list is never overwritten and a
            # missing categorical list no longer goes undetected.
            if not hasattr(model_setup, "including_clinical_num"):
                model_setup.including_clinical_num = DEFAULT_MIMIC_CLINICAL_NUM_COLS
            if not hasattr(model_setup, "including_clinical_cat"):
                model_setup.including_clinical_cat = DEFAULT_MIMIC_CLINICAL_CAT_COLS

            # Every dataset that feeds the model must expose the same clinical features.
            for clinical_dataset in (
                train_dataloader.dataset,
                val_dataloader.dataset,
                test_dataloader.dataset,
                detect_eval_dataset,
            ):
                clinical_dataset.set_clinical_features_used(
                    model_setup.including_clinical_num,
                    model_setup.including_clinical_cat,
                )

        train_coco, val_coco, test_coco = get_cocos(
            train_dataloader, val_dataloader, test_dataloader
        )

        # The radiologist-annotated set is still built here although its
        # evaluation is currently disabled (see eval_splits below).
        radiologists_ann = get_anns("radiologists_annotated", detect_eval_dataset)

        radiologist_dataset = OurRadiologsitsDataset(
            detect_eval_dataset, radiologists_ann
        )
        radiologist_dataloader = torch.utils.data.DataLoader(
            radiologist_dataset,
            batch_size=4,
            shuffle=False,
            collate_fn=collate_fn,
            worker_init_fn=seed_worker,
            generator=get_dataloader_g(0),
        )

        radiologists_coco = get_coco_api_from_dataset(radiologist_dataloader.dataset)

        # Only the full-range parameters are used by the evaluation below; the
        # 5-to-95 variant is built but not passed to any evaluator in this cell.
        normal_eval_params_dict = get_eval_params_dict(
            detect_eval_dataset, iou_thrs=normal_iou_thrs,
        )

        all_range_eval_params_dict = get_eval_params_dict(
            detect_eval_dataset, iou_thrs=all_range_iou_thrs,
        )

        # (name, dataloader, COCO ground truth) for every split to evaluate.
        # To re-enable the radiologist comparison, append
        # ("our", radiologist_dataloader, radiologists_coco).
        eval_splits = [
            ("train", train_dataloader, train_coco),
            ("test", test_dataloader, test_coco),
            ("val", val_dataloader, val_coco),
        ]

        # Category ids of all labels, computed once per dataset.
        label_cat_ids = [
            detect_eval_dataset.disease_to_idx(d)
            for d in detect_eval_dataset.labels_cols
        ]

        # None stands for "all categories together"; the rest are evaluated one by one.
        all_cat_ids = [None] + label_cat_ids

        for cat_id in all_cat_ids:
            # Fresh list each pass so the evaluation parameters never share
            # state with label_cat_ids.
            cat_ids = list(label_cat_ids) if cat_id is None else [cat_id]

            all_range_eval_params_dict["bbox"].catIds = cat_ids
            all_range_eval_params_dict["segm"].catIds = cat_ids

            evaluators = {}
            for split_name, split_dataloader, split_coco in eval_splits:
                evaluators[split_name], _ = xami_evaluate(
                    model_setup,
                    model,
                    split_dataloader,
                    device=device,
                    params_dict=all_range_eval_params_dict,
                    coco=split_coco,
                    iou_types=iou_types,
                )

            if cat_id is None:
                disease_str = "all"
            else:
                disease_str = detect_eval_dataset.label_idx_to_disease(cat_id)

            for split_name, split_evaluator in evaluators.items():
                save_iou_results(
                    split_evaluator,
                    f"{split_name}_{disease_str}_score_thrs{score_thrs}",
                    select_model.value,
                )

            # One summary row per split, in the same order as eval_splits.
            summary_rows = []
            for split_name, split_evaluator in evaluators.items():
                split_ap_ar = get_ap_ar(
                    split_evaluator, areaRng="all", maxDets=10, iouThr=0.5,
                )
                summary_rows.append(
                    {
                        "dataset": split_name,
                        "AP@[IoBB = 0.50]": split_ap_ar["ap"],
                        "AR@[IoBB = 0.50]": split_ap_ar["ar"],
                    }
                )

            df = pd.DataFrame(summary_rows)

            df.to_csv(
                os.path.join(
                    "eval_results",
                    f"{select_model.value}_{disease_str}_score_thrs{score_thrs}.csv",
                )
            )
            print_title(disease_str)
            print(df)


  0%|          | 0/1 [00:00<?, ?it/s]

Load custom model
Using pretrained backbone. mobilenet_v3
Using pretrained backbone. mobilenet_v3
Mask Hidden Layers 256
Using SGD as optimizer with lr=0.001
creating index...
index created!
creating index...
index created!
creating index...
index created!
creating index...
index created!
creating index...
index created!
creating index...
index created!
Evaluation:  [  0/117]  eta: 0:12:56  loss: 0.3390 (0.3390)  loss_classifier: 0.0252 (0.0252)  loss_box_reg: 0.0332 (0.0332)  loss_mask: 0.2564 (0.2564)  loss_objectness: 0.0206 (0.0206)  loss_rpn_box_reg: 0.0036 (0.0036)  model_time: 5.1846 (5.1846)  evaluator_time: 0.5761 (0.5761)  time: 6.6360  data: 0.6812  max mem: 1265
Evaluation:  [100/117]  eta: 0:00:26  loss: 0.3836 (0.3951)  loss_classifier: 0.0146 (0.0183)  loss_box_reg: 0.0087 (0.0133)  loss_mask: 0.3390 (0.3479)  loss_objectness: 0.0102 (0.0139)  loss_rpn_box_reg: 0.0009 (0.0017)  model_time: 0.1900 (0.2601)  evaluator_time: 0.3741 (0.4689)  time: 1.4147  data: 0.7038  max 

100%|██████████| 1/1 [29:24<00:00, 1764.78s/it]

  dataset  AP@[IoBB = 0.50]  AR@[IoBB = 0.50]
0   train          0.233511          0.634409
1    test          0.105758          0.611111
2     val          0.251647          0.631579





In [8]:

# NOTE: the model that showed the problem under investigation (the problem
# itself is not described in this notebook) was
# <TrainedModels.age_temperature_heartrate_resprate_o2sat: 'val_ar_0_4686_ap_0_1828_test_ar_0_5400_ap_0_2989_epoch44_WithClincal_09-16-2022 10-24-59_MDF-Net (age+temp+heartrate+resprate+o2sat)'>.
# `select_model` is the loop variable left over from the evaluation cell, so
# this displays whichever model that loop handled last.
select_model

<TrainedModels.gender_age_resprate: 'val_ar_0_5659_ap_0_2302_test_ar_0_6194_ap_0_2798_epoch44_WithClincal_09-21-2022 08-38-02_MDF-Net (gender+age+resprate)'>