In [None]:
import os, json, io
import numpy as np
import pandas as pd
# from models.load import TrainedModels

from utils.engine import get_iou_types, evaluate
from models.load import get_trained_model
from utils.print import print_title
from utils.init import reproducibility, clean_memory_get_device
from data.load import get_datasets, get_dataloaders
from data.paths import MIMIC_EYE_PATH
from tqdm import tqdm
from utils.train import  get_coco_eval_params
from utils.coco_eval import get_eval_params_dict
from data.strs import SourceStrs, TaskStrs
from IPython.display import clear_output
from coco_froc_analysis.froc.froc_curve import get_froc_curve, get_interpolate_froc
from pathlib import Path
from coco_froc_analysis.froc.froc_curve import TempRecord
## Silence pandas' chained-assignment check (SettingWithCopyWarning) for the whole session.
pd.options.mode.chained_assignment = None  # pandas default is 'warn'

## Silence ALL warnings for the session -- every category, not only UserWarning.
import warnings
warnings.filterwarnings("ignore")

## Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
from enum import Enum

class TrainedModels(Enum):
    """Checkpoints evaluated by this notebook.

    Each member's value is a checkpoint file name that
    ``get_trained_model_wiith_setup`` opens as ``trained_models/<value>``, so the
    strings must match the files on disk exactly -- including the
    ``dsetection`` spelling; do not "fix" them.

    Member order matters: the evaluation loop iterates ``TrainedModels`` in
    definition order.

    NOTE(review): the file names appear to encode validation AP, test AP,
    epoch and a timestamp; that is inferred from the naming pattern only.
    """

    mobilenet_baseline = "val_lesion-detection_ap_0_1655_test_lesion-detection_ap_0_1648_epoch50_03-15-2023 16-43-54_lesion_dsetection_baseline_mobilenet"  # mobilenet baseline
    mobilenet_with_fix = "val_lesion-detection_ap_0_1918_test_lesion-detection_ap_0_1903_epoch16_03-16-2023 11-34-10_lesion_dsetection_with_fixation_mobilenet"
    resnet18_baseline = "val_lesion-detection_ap_0_1973_test_lesion-detection_ap_0_2010_epoch22_03-16-2023 19-44-55_lesion_dsetection_baseline_resnet"
    resnet18_with_fix = "val_lesion-detection_ap_0_1951_test_lesion-detection_ap_0_2195_epoch12_03-17-2023 00-31-54_lesion_dsetection_with_fixation_resnet"
    densenet161_baseline = "val_lesion-detection_ap_0_1990_test_lesion-detection_ap_0_2085_epoch5_03-17-2023 08-53-33_lesion_dsetection_baseline_densenet161"
    densenet161_with_fix = "val_lesion-detection_ap_0_2120_test_lesion-detection_ap_0_2104_epoch12_03-17-2023 18-36-01_lesion_dsetection_with_fixation_densenet161"
    efficientnet_b5_baseline = "val_lesion-detection_ap_0_1898_test_lesion-detection_ap_0_2055_epoch5_03-17-2023 23-30-57_lesion_dsetection_baseline_efficientnet_b5"
    efficientnet_b5_with_fix = "val_lesion-detection_ap_0_2117_test_lesion-detection_ap_0_2190_epoch8_03-18-2023 12-29-20_lesion_dsetection_with_fixation_efficientnet_b5"
    efficientnet_b0_baseline = "val_lesion-detection_ap_0_1934_test_lesion-detection_ap_0_1858_epoch10_03-18-2023 23-50-47_lesion_dsetection_baseline_efficientnet_b0"
    efficientnet_b0_with_fix = "val_lesion-detection_ap_0_2191_test_lesion-detection_ap_0_2162_epoch10_03-18-2023 19-38-11_lesion_dsetection_with_fixation_efficientnet_b0"
    # The active convnext_base "with fixation" checkpoint is the *silent_report* run;
    # the commented-out lines below are alternative convnext_base runs kept for reference.
    convnext_base_with_fix = "val_lesion-detection_ap_0_2472_test_lesion-detection_ap_0_2637_epoch15_03-23-2023 20-20-07_lesion_dsetection_with_fixation_convnext_base_silent_report"
    # convnext_base_with_fix = "val_lesion-detection_ap_0_2610_test_lesion-detection_ap_0_2548_epoch22_03-22-2023 02-55-37_lesion_dsetection_with_fixation_convnext_base"
    convnext_base_baseline = "val_lesion-detection_ap_0_2426_test_lesion-detection_ap_0_2325_epoch20_03-22-2023 11-53-53_lesion_dsetection_baseline_convnext_base"
    # convnext_base_with_fix_silent = "val_lesion-detection_ap_0_2405_test_lesion-detection_ap_0_2543_epoch19_03-24-2023 13-57-29_lesion_dsetection_with_fixation_convnext_base_silent_report"
    # convnext_base_with_fix_full = "val_lesion-detection_ap_0_2602_test_lesion-detection_ap_0_2499_epoch22_03-24-2023 04-42-21_lesion_dsetection_with_fixation_convnext_base_full_report"
    vgg16_with_fix = "val_lesion-detection_ap_0_2301_test_lesion-detection_ap_0_2186_epoch22_03-20-2023 19-26-02_lesion_dsetection_with_fixation_vgg16"
    vgg16_baseline = "val_lesion-detection_ap_0_2113_test_lesion-detection_ap_0_2068_epoch12_03-21-2023 00-45-24_lesion_dsetection_baseline_vgg16"
    regnet_y_8gf_with_fix = "val_lesion-detection_ap_0_2267_test_lesion-detection_ap_0_2029_epoch12_03-21-2023 11-28-48_lesion_dsetection_with_fixation_regnet_y_8gf"
    regnet_y_8gf_baseline = "val_lesion-detection_ap_0_1883_test_lesion-detection_ap_0_1658_epoch13_03-21-2023 15-22-32_lesion_dsetection_baseline_regnet_y_8gf"

In [None]:
# Plot title for each model's FROC figure (used as `plot_title` in the
# evaluation loop). Every title is "<enum member name> (ap)", so derive the map
# from the enum: a hand-maintained copy raises KeyError in the evaluation loop
# as soon as a TrainedModels member is added without a matching entry.
# The " (ap)" suffix is kept verbatim from the original titles.
naming_map = {model: f"{model.name} (ap)" for model in TrainedModels}


In [None]:
# IoU thresholds 0.50, 0.55, ..., 0.95 for the COCO-style evaluation. They are
# built from integers so every entry is exactly the float literal it stands for
# (no accumulated float-step error).
normal_iou_thrs = np.arange(50, 100, 5) / 100

# Detector score thresholds to evaluate; each value is embedded in the output CSV names.
score_thresholds = [0.05]

In [None]:
from models.build import create_model_from_setup
import torch
def get_trained_model_wiith_setup(
    model_select, setup, device, checkpoint_dir="trained_models"
):
    """Build a model from ``setup`` and load its trained weights.

    Args:
        model_select: a ``TrainedModels`` member; its ``value`` is the checkpoint
            file name inside ``checkpoint_dir``.
        setup: the ``ModelSetup`` given to ``create_model_from_setup``. It must
            describe the architecture the checkpoint was trained with, otherwise
            ``load_state_dict`` (strict by default) raises.
        device: device the model is moved to and the checkpoint is mapped onto.
        checkpoint_dir: directory containing the checkpoint files; defaults to
            ``"trained_models"`` relative to the working directory.

    Returns:
        The model with weights loaded, on ``device``. This function does not
        call ``model.eval()``; callers do that.
    """
    model = create_model_from_setup(setup).to(device)

    checkpoint_path = os.path.join(checkpoint_dir, model_select.value)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # load_state_dict copies into the parameters that already live on `device`,
    # so a second .to(device) is unnecessary.
    model.load_state_dict(checkpoint["model_state_dict"])
    return model


# Correctly spelled alias; the misspelled name is kept for existing callers.
get_trained_model_with_setup = get_trained_model_wiith_setup

In [None]:
# fix backbones
from models.load import CLPretrainedLoadParams
from data.strs import TaskStrs, SourceStrs, FusionStrs

# ModelSetup overrides shared by every model in this notebook.
common_args = dict(
    # --- architecture / optimiser ---
    decoder_channels=[128, 64, 32, 16, 8],
    optimiser="sgd",
    lr=1e-3,  # values tried earlier: 3e-3, 1e-5
    sgb_momentum=0.9,  # key spelling kept as-is; presumably the ModelSetup field name
    weight_decay=1e-3,
    record_training_performance=False,
    warmup_epochs=10,
    # --- LR schedule / early stopping ---
    lr_scheduler="ReduceLROnPlateau",  # alternative: "MultiStepLR"
    # A factor this close to 1 barely changes the LR on each reduction
    # (assuming standard ReduceLROnPlateau semantics).
    reduceLROnPlateau_factor=(1 - 1e-10),
    reduceLROnPlateau_patience=9999999,  # only used for tracking customised early stop
    real_stop_patience=20,
    reduceLROnPlateau_full_stop=True,  # set to False to continue the run
    multiStepLR_milestones=[20, 40, 60, 80, 100],
    multiStepLR_gamma=0.5,
    gt_in_train_till=0,
    box_head_dropout_rate=0,
    model_warmup_epochs=0,  # stop fixing the weights in the backbone
    loss_warmup_epochs=0,  # should be larger than the model warm-up
    measure_test=True,
    use_dynamic_weight=True,
    # --- evaluation ---
    iou_thrs=np.arange(50, 100, 5) / 100,  # 0.50, 0.55, ..., 0.95
    maxDets=[1, 5, 10, 30],
    # --- clinical inputs ---
    normalise_clinical_num=True,
    use_clinical_df=False,
)

# Single-purpose overrides spread into every ModelSetup.
batch_size_4_args = dict(batch_size=4)

image_512_args = dict(image_size=512)

# Feeds the dataset's `bbox_to_mask` option: lesion boxes are NOT converted to masks.
no_bb_to_mask_args = dict(lesion_detection_use_mask=False)

# "fiaxtions" (sic) must stay misspelled: the evaluation loop reads
# `setup.fiaxtions_mode_input`.
full_report_args = dict(fiaxtions_mode_input="normal")


## Lesion-detection metrics listed in `performance_standards`, in this order:
## AP, FROC, AR. How ModelSetup consumes the list is defined outside this notebook.
lesion_detection_best_args = {
    "performance_standards": [
        {"task": TaskStrs.LESION_DETECTION, "metric": metric_name}
        for metric_name in ("ap", "froc", "ar")
    ]
}


# Input sources and tasks. "with_fix" models take X-rays plus fixations as
# input; baselines take X-rays only. Both perform lesion detection only.
lesion_detection_with_fix_args = dict(
    sources=[SourceStrs.XRAYS, SourceStrs.FIXATIONS],
    tasks=[TaskStrs.LESION_DETECTION],
)

lesion_detection_baseline_args = dict(
    sources=[SourceStrs.XRAYS],
    tasks=[TaskStrs.LESION_DETECTION],
)

# Fusion strategy passed to ModelSetup as `fusor`.
# "ElEMENTWISE" (sic) is the attribute's spelling in FusionStrs.
element_wise_sum_fusor_args = dict(fusor=FusionStrs.ElEMENTWISE_SUM)

# Narrow 64-unit widths for the model's internal layers. The clinical-branch
# widths (clinical_input_channels, clinical_conv_channels,
# clinical_expand_conv_channels) are intentionally not overridden.
small_model_args = {
    width_key: 64
    for width_key in (
        "mask_hidden_layers",
        "fuse_conv_channels",
        "representation_size",  # 32 was also tried
        "backbone_out_channels",
    )
}

def _backbone_args(backbone_name):
    """Return ModelSetup overrides selecting ``backbone_name`` with ``using_fpn`` disabled."""
    return {"backbone": backbone_name, "using_fpn": False}


# One override dict per backbone architecture.
mobilenet_args = _backbone_args("mobilenet_v3")
resnet18_args = _backbone_args("resnet18")
densenet_args = _backbone_args("densenet161")
efficientnet_b0_args = _backbone_args("efficientnet_b0")
efficientnet_b5_args = _backbone_args("efficientnet_b5")
vgg16_args = _backbone_args("vgg16")
regnet_y_8gf_args = _backbone_args("regnet_y_8gf")
convnext_base_args = _backbone_args("convnext_base")
resnet50_args = _backbone_args("resnet50")

In [None]:
from models.setup import ModelSetup


# One ModelSetup per trained checkpoint. All models share the same training,
# data and head overrides; only three things vary per entry:
#   * the setup name -- always the TrainedModels member's name,
#   * the backbone overrides,
#   * the source/task overrides (X-rays only vs. X-rays + fixations).
# Listing just those parts keeps the 16 entries from drifting apart, which the
# former copy-pasted ModelSetup(...) calls invited.
# NOTE(review): unexplained numbers kept from the original: 4564 3200 1345 9344
_SETUP_SPECS = [
    # (checkpoint, backbone overrides, source/task overrides)
    # mobilenet
    (TrainedModels.mobilenet_baseline, mobilenet_args, lesion_detection_baseline_args),
    (TrainedModels.mobilenet_with_fix, mobilenet_args, lesion_detection_with_fix_args),
    # resnet18
    (TrainedModels.resnet18_baseline, resnet18_args, lesion_detection_baseline_args),
    (TrainedModels.resnet18_with_fix, resnet18_args, lesion_detection_with_fix_args),
    # densenet161
    (TrainedModels.densenet161_baseline, densenet_args, lesion_detection_baseline_args),
    (TrainedModels.densenet161_with_fix, densenet_args, lesion_detection_with_fix_args),
    # efficientnet_b5
    (TrainedModels.efficientnet_b5_baseline, efficientnet_b5_args, lesion_detection_baseline_args),
    (TrainedModels.efficientnet_b5_with_fix, efficientnet_b5_args, lesion_detection_with_fix_args),
    # efficientnet_b0
    (TrainedModels.efficientnet_b0_baseline, efficientnet_b0_args, lesion_detection_baseline_args),
    (TrainedModels.efficientnet_b0_with_fix, efficientnet_b0_args, lesion_detection_with_fix_args),
    # convnext_base
    (TrainedModels.convnext_base_baseline, convnext_base_args, lesion_detection_baseline_args),
    (TrainedModels.convnext_base_with_fix, convnext_base_args, lesion_detection_with_fix_args),
    # vgg16
    (TrainedModels.vgg16_baseline, vgg16_args, lesion_detection_baseline_args),
    (TrainedModels.vgg16_with_fix, vgg16_args, lesion_detection_with_fix_args),
    # regnet_y_8gf
    (TrainedModels.regnet_y_8gf_baseline, regnet_y_8gf_args, lesion_detection_baseline_args),
    (TrainedModels.regnet_y_8gf_with_fix, regnet_y_8gf_args, lesion_detection_with_fix_args),
]


def _make_setup(name, backbone_args, source_task_args):
    """Build one ModelSetup from its varying parts plus the overrides shared by every model.

    Each override dict is spread separately (not pre-merged), so a key defined
    in two of them still raises ``TypeError`` instead of being silently
    overwritten.
    """
    return ModelSetup(
        name=name,
        **backbone_args,
        **source_task_args,
        **common_args,
        **batch_size_4_args,
        **image_512_args,
        **no_bb_to_mask_args,
        **full_report_args,
        **element_wise_sum_fusor_args,
        **small_model_args,
        **lesion_detection_best_args,
    )


setup_map = {
    checkpoint: _make_setup(checkpoint.name, backbone_args, source_task_args)
    for checkpoint, backbone_args, source_task_args in _SETUP_SPECS
}

# Fail here rather than halfway through the evaluation loop if a checkpoint
# has no setup.
_missing_setups = [m.name for m in TrainedModels if m not in setup_map]
if _missing_setups:
    raise KeyError(f"No ModelSetup defined for: {_missing_setups}")


In [None]:
# Evaluate every trained checkpoint on the TEST split only. For each model and
# each score threshold in `score_thresholds`, this cell:
#   1. runs the model once with the detector score threshold at 0 so every
#      detection is kept, and builds the FROC curve from those detections
#      (plot handed to get_froc_curve with froc_save_folder="./froc_figures");
#   2. re-runs evaluation at the configured score threshold, first over all
#      lesion classes together ("all"), then for each class on its own;
#   3. writes one single-row CSV per (model, class, threshold) to ./eval_results.
from IPython.display import display

# Average false positives per image at which summary sensitivities are reported.
FROC_FPS_PER_IMG = [0.5, 1, 2, 4]
# FP-per-image grid (1e-3 .. 1e2, 200 log-spaced points) for the stored FROC curve.
FROC_CURVE_FPS = np.logspace(-3, 2, 200).tolist()
EVAL_RESULTS_DIR = Path("./eval_results")
EVAL_RESULTS_DIR.mkdir(parents=True, exist_ok=True)

for select_model in tqdm(TrainedModels):
    clear_output()
    for score_thrs in score_thresholds:
        device = clean_memory_get_device()
        reproducibility()

        # The setup must be known before the model can be built from it.
        setup = setup_map[select_model]
        model = get_trained_model_wiith_setup(select_model, setup, device)
        model.eval()
        iou_types = get_iou_types(model, setup)

        dataset_params_dict = {
            "MIMIC_EYE_PATH": MIMIC_EYE_PATH,
            "labels_cols": setup.lesion_label_cols,
            "with_xrays_input": SourceStrs.XRAYS in setup.sources,
            "with_clincal_input": SourceStrs.CLINICAL in setup.sources,
            "with_fixations_input": SourceStrs.FIXATIONS in setup.sources,
            "fixations_mode_input": setup.fiaxtions_mode_input,
            "with_bboxes_label": TaskStrs.LESION_DETECTION in setup.tasks,
            "with_fixations_label": TaskStrs.FIXATION_GENERATION in setup.tasks,
            "fixations_mode_label": setup.fiaxtions_mode_label,
            "with_chexpert_label": TaskStrs.CHEXPERT_CLASSIFICATION in setup.tasks,
            "with_negbio_label": TaskStrs.NEGBIO_CLASSIFICATION in setup.tasks,
            "clinical_numerical_cols": setup.clinical_num,
            "clinical_categorical_cols": setup.clinical_cat,
            "image_size": setup.image_size,
            "image_mean": setup.image_mean,
            "image_std": setup.image_std,
            "with_clinical_label": setup.with_clinical_label,
            "normalise_clinical_num": setup.normalise_clinical_num,
            "bbox_to_mask": setup.lesion_detection_use_mask,
            "use_clinical_df": setup.use_clinical_df,
        }

        detect_eval_dataset, train_dataset, val_dataset, test_dataset = get_datasets(
            dataset_params_dict=dataset_params_dict,
        )

        train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
            train_dataset,
            val_dataset,
            test_dataset,
            batch_size=setup.batch_size,
        )

        # Only the test COCO object is used below. It is built with
        # use_iobb=True and a single 0.5 threshold.
        _, _, test_coco, _ = get_coco_eval_params(
            source_name=SourceStrs.XRAYS,
            task_name=TaskStrs.LESION_DETECTION,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            test_dataloader=test_dataloader,
            detect_eval_dataset=detect_eval_dataset,
            iou_thrs=np.array([0.5]),
            use_iobb=True,
            maxDets=[1, 5, 10, 30],
        )

        normal_range_eval_params_dict = get_eval_params_dict(
            detect_eval_dataset,
            iou_thrs=normal_iou_thrs,
        )

        roi_heads = model.task_performers["lesion-detection"].roi_heads

        # Pass 1: keep every detection (threshold 0) so the FROC curve spans the
        # full score range.
        roi_heads.score_thresh = 0
        froc_evaluator, _ = evaluate(
            setup=setup,
            model=model,
            data_loader=test_dataloader,
            device=device,
            params_dict=normal_range_eval_params_dict,
            coco=test_coco,
            iou_types=iou_types,
            return_dt_gt=True,
        )
        # All later passes use the threshold being evaluated. This used to be a
        # hard-coded 0.05, so `score_thrs` only changed the CSV file name.
        roi_heads.score_thresh = score_thrs

        stats, lls_accuracy, nlls_per_image = get_froc_curve(
            dataset=test_dataset,
            dts=froc_evaluator["lesion-detection"].all_dts,
            all_gts=froc_evaluator["lesion-detection"].all_gts,
            plot_title=naming_map[select_model],
            use_iou=True,
            n_sample_points=200,
            froc_save_folder="./froc_figures",
        )

        all_cat_ids = [
            detect_eval_dataset.disease_to_idx(d)
            for d in detect_eval_dataset.labels_cols
        ]

        # cat_id None means "all lesion classes pooled"; otherwise a single class.
        for cat_id in [None] + all_cat_ids:
            # A fresh list each time, so whatever evaluate() does with catIds
            # cannot alter all_cat_ids.
            normal_range_eval_params_dict["bbox"].catIds = (
                list(all_cat_ids) if cat_id is None else [cat_id]
            )

            test_evaluator, _ = evaluate(
                setup=setup,
                model=model,
                data_loader=test_dataloader,
                device=device,
                params_dict=normal_range_eval_params_dict,
                coco=test_coco,
                iou_types=iou_types,
            )
            bbox_eval = test_evaluator["lesion-detection"].coco_eval["bbox"]

            disease_str = (
                "all"
                if cat_id is None
                else detect_eval_dataset.label_idx_to_disease(cat_id)
            )

            froc_v = get_interpolate_froc(
                stats=stats,
                lls_accuracy=lls_accuracy,
                nlls_per_image=nlls_per_image,
                cat_id=cat_id,
                fps_per_img=FROC_FPS_PER_IMG,
                weight=True,
            )

            model_froc_curve = get_interpolate_froc(
                stats=stats,
                lls_accuracy=lls_accuracy,
                nlls_per_image=nlls_per_image,
                cat_id=cat_id,
                fps_per_img=FROC_CURVE_FPS,
                weight=True,
            )

            # Column names and order are unchanged from earlier runs, e.g.
            # "Sensitivity@ [avgFP=0.5]" and "mFROC@[0.5,1,2,4]".
            row = {
                "num_fps": bbox_eval.eval["num_fps"],  # @all_range, maxDet= p.maxDets[-1]
                "num_fns": bbox_eval.eval["num_fns"],
                "num_tps": bbox_eval.eval["num_tps"],
                # np.asarray(...).tolist() keeps this JSON-serialisable whether
                # stats is a list or a numpy array (as in pycocotools).
                "coco_states": json.dumps(np.asarray(bbox_eval.stats).tolist()),
            }
            for fp_per_img, sensitivity in zip(FROC_FPS_PER_IMG, froc_v):
                row[f"Sensitivity@ [avgFP={fp_per_img}]"] = sensitivity
            row[f"mFROC@[{','.join(map(str, FROC_FPS_PER_IMG))}]"] = froc_v.mean()
            row["froc_curve"] = json.dumps(model_froc_curve.tolist())
            row["x-axis"] = json.dumps(FROC_CURVE_FPS)

            df = pd.DataFrame([row])
            df.to_csv(
                EVAL_RESULTS_DIR / f"{select_model.value}_{disease_str}_{score_thrs}.csv"
            )

            print_title(f"{select_model.value}-{disease_str}")
            display(df)
