In [1]:
import gc
import torch
import pandas as pd

from utils.plot import plot_result, plot_loss, get_legend_elements, DISEASE_CMAP
from models.load import TrainedModels, get_trained_model
from data.constants import XAMI_MIMIC_PATH, DEFAULT_REFLACX_LABEL_COLS

from utils.init import reproducibility, clean_memory_get_device
from data.datasets import ReflacxDataset, collate_fn
from data.transforms import get_transform

## Suppress the assignement warning from pandas.
pd.options.mode.chained_assignment = None  # default='warn

## Supress user warning
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

%matplotlib inline

In [2]:
from enum import Enum

class TrainedModels(Enum):
    custom_with_clinical_no_pretrained_ar = "val_ar_0_5220_ap_0_2513_test_ar_0_5590_ap_0_2442_epoch150_WithClincal_04-13-2022 20-13-47_custom_with_clinical_no_pretrained"
    custom_with_clinical_no_pretrained_ap = "val_ar_0_4554_ap_0_2582_test_ar_0_5254_ap_0_2405_epoch139_WithClincal_04-13-2022 19-05-08_custom_with_clinical_no_pretrained"
    custom_with_clinical_no_pretrained_final = "val_ar_0_4523_ap_0_2251_test_ar_0_5103_ap_0_2464_epoch200_WithClincal_04-14-2022 01-32-28_custom_with_clinical_no_pretrained"

    custom_without_clinical_no_pretrained_ar = "val_ar_0_5645_ap_0_2659_test_ar_0_6263_ap_0_2533_epoch145_WithoutClincal_04-13-2022 08-47-34_custom_without_clinical_no_pretrained"
    custom_without_clinical_no_pretrained_ap = "val_ar_0_5512_ap_0_2962_test_ar_0_5999_ap_0_2319_epoch93_WithoutClincal_04-12-2022 09-15-28_custom_without_clinical_no_pretrained"
    custom_without_clinical_no_pretrained_final = "val_ar_0_3757_ap_0_1699_test_ar_0_4421_ap_0_1819_epoch200_WithoutClincal_04-13-2022 14-58-52_custom_without_clinical_no_pretrained"

    custom_without_clinical_swim_ap = "val_ar_0_5307_ap_0_2054_test_ar_0_5321_ap_0_1726_epoch87_WithoutClincal_04-17-2022 06-51-10_custom_without_clinical_swim"
    custom_without_clinical_swim_ar = "val_ar_0_5313_ap_0_1540_test_ar_0_5906_ap_0_1486_epoch59_WithoutClincal_04-17-2022 05-48-30_custom_without_clinical_swim"
    custom_without_clinical_swim_final = "val_ar_0_2175_ap_0_1390_test_ar_0_2231_ap_0_0901_epoch200_WithoutClincal_04-17-2022 10-49-59_custom_without_clinical_swim"

    custom_with_clinical_swim_ap = "val_ar_0_5081_ap_0_2210_test_ar_0_5392_ap_0_1725_epoch95_WithClincal_04-17-2022 15-26-02_custom_with_clinical_swim"
    custom_with_clinical_swim_ar = "val_ar_0_5377_ap_0_1821_test_ar_0_4561_ap_0_1193_epoch67_WithClincal_04-17-2022 14-02-28_custom_with_clinical_swim"
    custom_with_clinical_swim_final = "val_ar_0_2752_ap_0_1293_test_ar_0_3391_ap_0_1097_epoch200_WithClincal_04-17-2022 20-31-24_custom_with_clinical_swim"


In [3]:
device = clean_memory_get_device()
reproducibility()

This notebook will running on device: [CUDA]


In [4]:
model, train_info, _ = get_trained_model(
    TrainedModels.custom_with_clinical_no_pretrained_ar,
    DEFAULT_REFLACX_LABEL_COLS,
    device,
    image_size=512,
    rpn_nms_thresh=0.3,
    box_detections_per_img=10,
    box_nms_thresh=0.2,
    rpn_score_thresh=0.0,
    box_score_thresh=0.05,
)



Load custom model
Using ResNet50 as backbone
Not using pretrained model.
Found optimizer for this model.
Model size: 62,334,753
Using SGD as optimizer with lr=0.0005


In [5]:
print(train_info)

ModelSetup(use_clinical=True, use_custom_model=True, use_early_stop_model=True, name='custom_with_clinical_no_pretrained', best_ar_val_model_path=None, best_ap_val_model_path=None, final_model_path=None, backbone='resnet50', optimiser='sgd', lr=0.0005, weight_decay=5e-05, pretrained=False, record_training_performance=True, dataset_mode='unified')

Best AP validation model has been saved to: [val_ar_0_4554_ap_0_2582_test_ar_0_5254_ap_0_2405_epoch139_WithClincal_04-13-2022 19-05-08_custom_with_clinical_no_pretrained]
Best AR validation model has been saved to: [val_ar_0_5056_ap_0_2360_test_ar_0_5891_ap_0_2176_epoch90_WithClincal_04-12-2022 18-38-40_custom_with_clinical_no_pretrained]
The final model has been saved to: [val_ar_0_5220_ap_0_2513_test_ar_0_5590_ap_0_2442_epoch150_WithClincal_04-13-2022 20-13-47_custom_with_clinical_no_pretrained]



In [6]:
model

MultimodalMaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=1e-05)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=1e-05)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=1e-05)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=1e-05)
       

In [7]:
train_info.best_ap_val_model_path

'val_ar_0_4554_ap_0_2582_test_ar_0_5254_ap_0_2405_epoch139_WithClincal_04-13-2022 19-05-08_custom_with_clinical_no_pretrained'

In [8]:
labels_cols = [
    "Enlarged cardiac silhouette",
    "Atelectasis",
    "Pleural abnormality",
    "Consolidation",
    "Pulmonary edema",
    #  'Groundglass opacity', # 6th disease.
]

dataset_params_dict = {
    "XAMI_MIMIC_PATH": XAMI_MIMIC_PATH,
    "with_clinical": train_info.model_setup.use_clinical,
    "dataset_mode": "normal",
    "bbox_to_mask": True,
    "labels_cols": labels_cols,
}

test_dataset = ReflacxDataset(
    **dataset_params_dict, split_str="test", transforms=get_transform(train=False),
)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn,
)

In [9]:
data = next(iter(test_dataloader))
data = test_dataset.prepare_input_from_data(data, device)

In [10]:
model.eval()
pred = model(*data[:-1])

In [11]:
pred[0]['boxes']

tensor([[2009.1520, 1412.4745, 2661.3843, 1635.9585],
        [ 383.4301, 1230.4435, 1129.4485, 1837.2769],
        [ 385.1847, 1184.6702, 1124.4093, 1837.5146],
        [2266.9734, 1610.6865, 2514.2390, 1799.1272],
        [ 389.0894, 1162.3492, 1122.1621, 1824.5599],
        [1941.1747, 1269.6035, 2569.3140, 1768.7753]], device='cuda:0',
       grad_fn=<StackBackward0>)