## Try out our model here.

We test our mutli-modal Faster R-CNN with MIMIC dataset here.

In [1]:
import torch
import pandas as pd

from data.dataset import REFLACXWithClinicalAndBoundingBoxDataset
from utils.transforms import get_transform

## Suppress the assignement warning from pandas.
pd.options.mode.chained_assignment = None  # default='warn'

### Define your MIMIC folde path here.

In [2]:
XAMI_MIMIC_PATH = "D:\XAMI-MIMIC"

# Initiate datasets and dataloaders
The batch size is also defined in this section. For testing purpose, we only set it as 2.

In [3]:
train_dataset = REFLACXWithClinicalAndBoundingBoxDataset(
    XAMI_MIMIC_PATH=XAMI_MIMIC_PATH,
    split_str="train",
    transforms=get_transform(train=True),
)

val_dataset = REFLACXWithClinicalAndBoundingBoxDataset(
    XAMI_MIMIC_PATH=XAMI_MIMIC_PATH,
    split_str="val",
    transforms=get_transform(train=False),
)

test_dataset = REFLACXWithClinicalAndBoundingBoxDataset(
    XAMI_MIMIC_PATH=XAMI_MIMIC_PATH,
    split_str="test",
    transforms=get_transform(train=False),
)

batch_size = 2

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, collate_fn=REFLACXWithClinicalAndBoundingBoxDataset.collate_fn
)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=True, collate_fn=REFLACXWithClinicalAndBoundingBoxDataset.collate_fn
)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True, collate_fn=REFLACXWithClinicalAndBoundingBoxDataset.collate_fn
)


## Example instance from dataset
We show what's inside a single instance. It will provide:

- Images
- Clinical data
- Targets (Dictionary)

And, inside the target, there're:

- boxes (bounding boxes of abnormality)
- lable (disease index (Note: the class **0** means the background))
- image_id (idx to get that image)
- area (the areas that bouding boxes contain)
- iscrowd (if it's a place with multiple bouding boxes, we assume all the the bouding boxes are not crowd.)

In [4]:
train_dataset[0]

(tensor([[[0.5608, 0.5608, 0.5569,  ..., 0.0000, 0.0000, 0.0000],
          [0.5569, 0.5647, 0.5647,  ..., 0.0000, 0.0000, 0.0000],
          [0.5490, 0.5569, 0.5608,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.7686, 0.7686, 0.7647,  ..., 1.0000, 1.0000, 1.0000],
          [0.7608, 0.7647, 0.7608,  ..., 1.0000, 1.0000, 1.0000],
          [0.7529, 0.7569, 0.7569,  ..., 1.0000, 1.0000, 1.0000]],
 
         [[0.5608, 0.5608, 0.5569,  ..., 0.0000, 0.0000, 0.0000],
          [0.5569, 0.5647, 0.5647,  ..., 0.0000, 0.0000, 0.0000],
          [0.5490, 0.5569, 0.5608,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.7686, 0.7686, 0.7647,  ..., 1.0000, 1.0000, 1.0000],
          [0.7608, 0.7647, 0.7608,  ..., 1.0000, 1.0000, 1.0000],
          [0.7529, 0.7569, 0.7569,  ..., 1.0000, 1.0000, 1.0000]],
 
         [[0.5608, 0.5608, 0.5569,  ..., 0.0000, 0.0000, 0.0000],
          [0.5569, 0.5647, 0.5647,  ..., 0.0000, 0.0000, 0.0000],
          [0.5490, 0.5569, 0.5608,  ...,

## Define Model.

We define he models here. Two backbone examples are in the below code section. The MobileNet is a light weight network, and ResNet is heavier, but usually perform better. In our case, the calculation is not the most important factor; therefore, we chose ResNet with feature pyramid networks (FPN) backbone.

In [5]:
import torchvision
from models.rcnn import MultimodalFasterRCNN

trainable_backbone_layers = torchvision.models.detection.backbone_utils._validate_trainable_layers(
    True, None, 5, 3
)
backbone = torchvision.models.detection.backbone_utils.resnet_fpn_backbone(
    "resnet50", pretrained=True, trainable_layers=trainable_backbone_layers
)
backbone.out_channels = 256


######################## For MobileNet backbone ########################
# from torchvision.models.detection.faster_rcnn import AnchorGenerator
# backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# backbone.out_channels = 1280
# anchor_generator = AnchorGenerator(
#     sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)
# )
# roi_pooler = torchvision.ops.MultiScaleRoIAlign(
#     featmap_names=["0"], output_size=7, sampling_ratio=2
# )
########################################################################

model = MultimodalFasterRCNN(
    backbone,
    num_classes=len(train_dataset.labels_cols) + 1,
    rpn_anchor_generator=None,
    box_roi_pool=None,
    use_clinical=True,
)


## Prepare data to feed

We prepare three main data to test the model:

- CXR image
- Clinical data
- Target

And, for each data, we adjust the format to what the model expect.

In [6]:
images, clinical_num, clinical_cat, targets = next(iter(train_dataloader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]

# Test Feedforawrd (Training)

In [7]:
output = model(images, (clinical_num, clinical_cat), targets)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


## Results we get.
Four different losses are given in the result, we will use these losses to optimise the network while training. 

In [8]:
output

{'loss_classifier': tensor(1.7564, grad_fn=<NllLossBackward0>),
 'loss_box_reg': tensor(0., grad_fn=<DivBackward0>),
 'loss_objectness': tensor(0.6937, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
 'loss_rpn_box_reg': tensor(0., grad_fn=<DivBackward0>)}

# Test Feedforawrd (Evaluation)

In [9]:
model.eval()
pred = model(images, (clinical_num, clinical_cat))

## Results we get.
If we set the model to evaluation mode and don't pass the target to the forward function, the model will output prediction (detections). In the below sections, we show what's inside the detection of first instance (idx=0).

### Detection.

A detection contain *boxes*, *lables*, and *scores*.

- *boxes*: All the bounding boxes for this image. 
- *lables*: Labels corresponded to the bounding boxes.
- *score*: Score (Confidence) for each boudning box.

In [10]:
pred[0].keys()

dict_keys(['boxes', 'labels', 'scores'])

In [11]:
pred[0]['boxes']

tensor([[1.7698e+02, 1.2941e+03, 6.2344e+02, 2.3637e+03],
        [8.1306e+02, 1.5948e+03, 9.7100e+02, 1.7924e+03],
        [3.1414e+02, 1.5101e+03, 9.4729e+02, 2.2719e+03],
        [3.2322e+02, 8.4963e+02, 9.4594e+02, 1.6075e+03],
        [8.1714e+02, 3.7854e+00, 2.0837e+03, 1.0938e+03],
        [1.6042e+03, 1.0259e+03, 2.0531e+03, 2.0976e+03],
        [1.5851e+00, 4.8525e+02, 6.1034e+02, 2.6296e+03],
        [7.7397e+02, 1.6425e+03, 9.3297e+02, 1.8400e+03],
        [8.3674e+02, 1.2631e+03, 1.0656e+03, 1.8032e+03],
        [5.5690e+02, 3.7163e+02, 1.1832e+03, 1.1342e+03],
        [7.9909e+02, 1.4023e+03, 1.0283e+03, 1.9429e+03],
        [7.1944e+02, 1.4034e+03, 9.4915e+02, 1.9444e+03],
        [4.7966e+02, 2.9425e+00, 1.1110e+03, 6.6280e+02],
        [3.2187e+02, 8.4664e+02, 9.4967e+02, 1.6122e+03],
        [7.7631e+02, 1.5262e+03, 9.3498e+02, 1.7207e+03],
        [3.4360e+00, 1.3271e+03, 9.6685e+02, 2.8294e+03],
        [6.3905e+02, 8.6920e+01, 1.2667e+03, 8.4386e+02],
        [5.138

In [12]:
pred[0]['labels']

tensor([4, 1, 4, 4, 4, 4, 4, 1, 1, 4, 1, 1, 4, 1, 1, 4, 4, 4, 4, 1, 1, 4, 4, 4,
        1, 4, 1, 1, 5, 4, 4, 4, 4, 4, 4, 1, 4, 4, 1, 1, 1, 4, 1, 4, 5, 1, 4, 5,
        1, 5, 4, 4, 5, 4, 5, 4, 1, 1, 4, 4, 1, 4, 1, 4, 1, 1, 5, 1, 4, 1, 1, 1,
        1, 1, 4, 4, 1, 1, 1, 1, 4, 4, 4, 1, 1, 1, 1, 4, 5, 1, 4, 1, 1, 4, 1, 4,
        1, 1, 5, 4])

In [13]:
pred[0]['scores']

tensor([0.1763, 0.1745, 0.1740, 0.1740, 0.1736, 0.1734, 0.1733, 0.1732, 0.1732,
        0.1731, 0.1731, 0.1731, 0.1730, 0.1727, 0.1727, 0.1727, 0.1726, 0.1725,
        0.1725, 0.1723, 0.1723, 0.1722, 0.1722, 0.1721, 0.1721, 0.1720, 0.1720,
        0.1720, 0.1720, 0.1719, 0.1719, 0.1719, 0.1719, 0.1719, 0.1718, 0.1717,
        0.1717, 0.1716, 0.1716, 0.1715, 0.1715, 0.1715, 0.1714, 0.1714, 0.1714,
        0.1713, 0.1713, 0.1713, 0.1713, 0.1713, 0.1713, 0.1713, 0.1712, 0.1712,
        0.1712, 0.1712, 0.1712, 0.1711, 0.1711, 0.1711, 0.1711, 0.1710, 0.1710,
        0.1710, 0.1710, 0.1710, 0.1710, 0.1710, 0.1710, 0.1710, 0.1710, 0.1709,
        0.1709, 0.1709, 0.1709, 0.1709, 0.1709, 0.1708, 0.1708, 0.1708, 0.1708,
        0.1708, 0.1708, 0.1708, 0.1708, 0.1708, 0.1708, 0.1708, 0.1708, 0.1708,
        0.1707, 0.1707, 0.1707, 0.1707, 0.1707, 0.1707, 0.1707, 0.1706, 0.1706,
        0.1706], grad_fn=<IndexBackward0>)