In [1]:
from data.transforms import get_tensorise_h_flip_transform
from data.constants import DEFAULT_REFLACX_LABEL_COLS
from data.paths import MIMIC_EYE_PATH
from data.datasets import ReflacxDataset
from data.strs import TaskStrs, SourceStrs, FusionStrs
from models.setup import ModelSetup
import pandas as pd

# Turn off pandas' chained-assignment (SettingWithCopyWarning) check for the whole session.
# NOTE(review): this also hides genuine chained-assignment bugs; presumably it is set
# because the project's data-loading code triggers the warning — confirm before removing.
pd.options.mode.chained_assignment = None


# Hyper-parameters shared by every experiment in this notebook. Per-run choices
# (sources, tasks, fusor, performance-standard task/metric) are kept in the
# smaller dicts below and unpacked into ModelSetup alongside this one.
# `pretrained` is not passed, so ModelSetup's own default applies.
common_args = dict(
    # decoder
    decoder_channels=[128, 64, 32, 16, 8],
    # optimisation
    optimiser="sgd",
    lr=1e-3,  # TODO: try x8 (or simply 1e-2) to test whether training explodes.
    batch_size=4,
    weight_decay=1e-9,  # 1E-5 noted as an alternative value
    image_size=512,
    record_training_performance=False,
    warmup_epochs=0,
    # LR scheduling.
    # NOTE(review): patience=999 presumably means plateau-based decay almost never fires.
    lr_scheduler="ReduceLROnPlateau",
    reduceLROnPlateau_factor=0.1,
    reduceLROnPlateau_patience=999,
    reduceLROnPlateau_full_stop=True,
    # presumably only read when lr_scheduler == "MultiStepLR" — confirm in ModelSetup
    multiStepLR_milestones=100,
    multiStepLR_gamma=0.1,
    # model / training behaviour
    use_mask=False,
    gt_in_train_till=0,
    box_head_dropout_rate=0,
    model_warmup_epochs=20,
    loss_warmup_epochs=10,
    measure_test=True,
)

# Task/metric pairs passed to ModelSetup as the "performance standard"
# (presumably what selects the best checkpoint — confirm in ModelSetup).
chexpert_best_args = dict(
    performance_standard_task=TaskStrs.CHEXPERT_CLASSIFICATION,
    performance_standard_metric="auc",
)

lesion_detection_best_args = dict(
    performance_standard_task=TaskStrs.LESION_DETECTION,
    performance_standard_metric="ap",
)

def _lesion_detection_experiment_args(sources: list) -> dict:
    """Build the per-experiment ModelSetup kwargs for the fixation ablation.

    Both arms of the with/without-fixations comparison must train the same task
    list with the same fusor; only the input ``sources`` may differ. Building both
    dicts from this one function keeps the arms from drifting apart. Fresh lists
    are created on every call, so the two dicts never share a mutable object.

    Args:
        sources: input sources for this arm
            (e.g. ``[SourceStrs.XRAYS, SourceStrs.CLINICAL]`` for a clinical arm).

    Returns:
        A dict with ``"sources"``, ``"tasks"`` and ``"fusor"`` keys, in that order.
    """
    return {
        "sources": list(sources),
        "tasks": [
            TaskStrs.LESION_DETECTION,
            # TaskStrs.FIXATION_GENERATION,
            # TaskStrs.CHEXPERT_CLASSIFICATION,
            # TaskStrs.NEGBIO_CLASSIFICATION,
        ],
        "fusor": FusionStrs.ElEMENTWISE_SUM,
    }


# Ablation arms: identical except that only the first one sees fixation maps.
with_fix_args = _lesion_detection_experiment_args([SourceStrs.XRAYS, SourceStrs.FIXATIONS])
without_fix_args = _lesion_detection_experiment_args([SourceStrs.XRAYS])

# Reduced channel / hidden sizes to keep the model small.
# The clinical_* channel settings and backbone_out_channels=None were left
# commented out in this configuration; they are not passed here.
small_model_args = dict(
    mask_hidden_layers=64,
    fuse_conv_channels=64,
    representation_size=64,  # 32 noted as an alternative
    backbone_out_channels=64,
)

# Backbone selection.
# NOTE(review): despite the variable name, the backbone chosen here is ResNet-18;
# "mobilenet_v3" is the commented alternative. The name is kept because the
# ModelSetup call below unpacks `mobilenet_args`.
mobilenet_args = dict(
    backbone="resnet18",  # alternative: "mobilenet_v3"
    using_fpn=False,
)

# ModelSetup for the lesion-detection run that takes X-rays plus fixation maps
# as input. The run name reflects what is configured: the only task in
# `with_fix_args` and the performance-standard task in
# `lesion_detection_best_args` are both lesion detection (not CheXpert
# classification, which the old name "chexpert_with_fix" suggested).
# The arg dicts are unpacked directly in the call, not pre-merged. That way, a
# key defined in two of them raises TypeError ("got multiple values for keyword
# argument") instead of being silently overridden.
setup = ModelSetup(
    name="lesion_detection_with_fix",
    **lesion_detection_best_args,
    **mobilenet_args,
    **small_model_args,
    **common_args,
    **with_fix_args,
)


# Keyword arguments for ReflacxDataset, derived from `setup`, so the dataset
# yields the inputs and labels matching the configured sources and tasks.
# NOTE: the misspelled keys ("with_clincal_input", "fiaxtions_mode_input",
# "fiaxtions_mode_label") are unpacked as keyword arguments into ReflacxDataset,
# so they presumably mirror its parameter names — do not correct them here alone.
dataset_params_dict = dict(
    MIMIC_EYE_PATH=MIMIC_EYE_PATH,
    labels_cols=setup.lesion_label_cols,
    # inputs: each enabled when the matching source is configured
    with_xrays_input=SourceStrs.XRAYS in setup.sources,
    with_clincal_input=SourceStrs.CLINICAL in setup.sources,
    with_fixations_input=SourceStrs.FIXATIONS in setup.sources,
    fiaxtions_mode_input=setup.fiaxtions_mode_input,
    # labels: each enabled when the matching task is configured
    with_bboxes_label=TaskStrs.LESION_DETECTION in setup.tasks,
    with_fixations_label=TaskStrs.FIXATION_GENERATION in setup.tasks,
    fiaxtions_mode_label=setup.fiaxtions_mode_label,
    with_chexpert_label=TaskStrs.CHEXPERT_CLASSIFICATION in setup.tasks,
    with_negbio_label=TaskStrs.NEGBIO_CLASSIFICATION in setup.tasks,
    # clinical feature columns and image preprocessing
    clinical_numerical_cols=setup.clinical_num,
    clinical_categorical_cols=setup.clinical_cat,
    image_size=setup.image_size,
    image_mean=setup.image_mean,
    image_std=setup.image_std,
)

# Dataset instance used to inspect a single sample in the cells below.
# NOTE(review): despite its name, this is built from the "train" split.
# NOTE(review): random_flip=True presumably applies a random horizontal flip per
# sample, and no seed is set in this notebook, so the displayed outputs may
# differ between runs.
test_dataset = ReflacxDataset(
    split_str="train",
    random_flip=True,
    **dataset_params_dict,
)


In [2]:
# Pull the first (input, targets) pair from the dataset for inspection.
# NOTE(review): `input` shadows the Python builtin; later cells read it, so the
# name is left unchanged here.
[input, targets] = test_dataset[0]

In [3]:
# Show the target mapping for the sample (bare last expression = cell output).
targets

OrderedDict([('lesion-detection',
              {'boxes': tensor([[ 734., 1204., 2211., 2175.]], dtype=torch.float64),
               'labels': tensor([2]),
               'image_id': tensor([0]),
               'area': tensor([1434167.], dtype=torch.float64),
               'iscrowd': tensor([0]),
               'dicom_id': '34cedb74-d0996b40-6d218312-a9174bea-d48dc033',
               'image_path': 'C:\\Users\\mike8\\mimic-eye\\patient_18111516\\CXR-JPG\\s55032240\\34cedb74-d0996b40-6d218312-a9174bea-d48dc033.jpg',
               'original_image_size': torch.Size([3056, 2544])})])

In [4]:
# Shape of the X-ray image tensor; the spatial size should match image_size=512 from common_args.
input['xrays']['images'].shape


torch.Size([3, 512, 512])

In [5]:
# Shape of the fixation-map tensor; compare with the X-ray tensor shape shown in the previous cell.
input['fixations']['images'].shape


torch.Size([3, 512, 512])