In [41]:
import os
import torch
import pickle


import pandas as pd
import utils.print as print_f

from datetime import datetime
from data.dataset import REFLACXWithClinicalAndBoundingBoxDataset
from collections import OrderedDict

from utils.transforms import get_transform
from utils.engine import train_one_epoch, evaluate
from utils.save import  get_train_data

## Silence the chained-assignment warning from pandas (the pandas default is 'warn').
pd.set_option("mode.chained_assignment", None)

### Define your MIMIC folder path here.

In [2]:
# Root folder of the XAMI-MIMIC dataset.
# A raw string keeps the backslash literal: "\X" is not a valid escape sequence,
# so the plain string produced this value only by accident and triggers an
# invalid-escape warning on recent Python versions.
XAMI_MIMIC_PATH = r"D:\XAMI-MIMIC"

## Checking if GPU is available.

If a **GPU (CUDA)** is available, it will be used. Otherwise, the CPU will be used.

In [3]:
# Pick the compute device: CUDA when PyTorch reports a usable GPU, CPU otherwise.
use_gpu = torch.cuda.is_available()
if use_gpu:
    device = 'cuda'
else:
    device = 'cpu'
print(f"Will be using {device}")

Will be using cuda


# Initiate datasets and dataloaders
The batch size is also defined in this section. Since this project runs on a single 16GB GTX 3080 (laptop version), we can't use a batch size larger than 16.

In [4]:
## Prepare data
# Number of samples per batch, shared by all three dataloaders.
batch_size = 16

# One dataset per split. Only the training split requests the training-time
# transforms (get_transform(train=True)); val/test use the train=False variant.
train_dataset = REFLACXWithClinicalAndBoundingBoxDataset(
    XAMI_MIMIC_PATH=XAMI_MIMIC_PATH,
    split_str="train",
    transforms=get_transform(train=True),
)

val_dataset = REFLACXWithClinicalAndBoundingBoxDataset(
    XAMI_MIMIC_PATH=XAMI_MIMIC_PATH,
    split_str="val",
    transforms=get_transform(train=False),
)

test_dataset = REFLACXWithClinicalAndBoundingBoxDataset(
    XAMI_MIMIC_PATH=XAMI_MIMIC_PATH,
    split_str="test",
    transforms=get_transform(train=False),
)

# Every loader assembles batches with the dataset class's own collate_fn.
_collate_fn = REFLACXWithClinicalAndBoundingBoxDataset.collate_fn

# Training batches are reshuffled every epoch.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, collate_fn=_collate_fn
)

# Evaluation loaders keep a fixed order: shuffling brings no benefit when only
# evaluating and makes the order of evaluated samples differ between runs.
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, collate_fn=_collate_fn
)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, collate_fn=_collate_fn
)



## Define Model.

We define the models here. Two backbone examples are given in the code section below. MobileNet is a lightweight network, while ResNet is heavier but usually performs better. In our case, computational cost is not the most important factor; therefore, we chose a ResNet backbone with a feature pyramid network (FPN).

To use the pretrained model properly, we also fix (freeze) the first 2 layers of the ResNet backbone.

**When instantiating the model, we also have to define whether we want to use the clinical data, so the model can adjust for it.**

In [5]:
import torchvision
from models.rcnn import MultimodalFasterRCNN

## Weights of the first 2 layers of the ResNet backbone are fixed.
# Number of backbone layer groups left trainable, counted from the output side.
# This was previously obtained from torchvision's private helper
# `_validate_trainable_layers(True, None, 5, 3)`; the value is now spelled out so
# the notebook does not depend on an underscore-prefixed (non-public) function.
# NOTE(review): for those arguments the helper is expected to return its default
# of 3 (out of a maximum of 5) — confirm against the installed torchvision version.
trainable_backbone_layers = 3
backbone = torchvision.models.detection.backbone_utils.resnet_fpn_backbone(
    "resnet50", pretrained=True, trainable_layers=trainable_backbone_layers
)
# Channel count of the backbone output, exposed as an attribute on the backbone.
backbone.out_channels = 256


######################## For MobileNet backbone ########################
# from torchvision.models.detection.faster_rcnn import AnchorGenerator
# backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# backbone.out_channels = 1280
# anchor_generator = AnchorGenerator(
#     sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)
# )
# roi_pooler = torchvision.ops.MultiScaleRoIAlign(
#     featmap_names=["0"], output_size=7, sampling_ratio=2
# )
########################################################################

model = MultimodalFasterRCNN(
    backbone,
    # One class per label column plus one extra
    # (presumably the background class — confirm against MultimodalFasterRCNN).
    num_classes=len(train_dataset.labels_cols) + 1,
    # None for both: no custom anchor generator / RoI pooler is supplied here.
    rpn_anchor_generator=None,
    box_roi_pool=None,
    use_clinical=True, # Define whether we use clinical data in our model.
)

# move model to the right device
model.to(device)

MultimodalFasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=1e-05)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=1e-05)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=1e-05)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=1e-05)
     

# Check how large this model is.

In [6]:
# Total number of elements over all model parameters, printed with thousands separators.
total_param_elements = sum(param.numel() for param in model.parameters())
print(f"Model size: {total_param_elements:,}")

Model size: 54,694,171


## Define the parameters for training.

We define which parameters should be trained, and also define the optimiser and learning rate scheduler here.

In [7]:
# Keep only the parameters that still require gradients; these are the ones optimised.
params = list(filter(lambda p: p.requires_grad, model.parameters()))

# SGD with momentum and weight decay over the selected parameters.
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)
# Step schedule: the learning rate is multiplied by gamma (0.1) every 3 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

# Train.

We run training here. The number of epochs is defined here. After each epoch, the model is evaluated on the validation dataset. At the end, the model is evaluated on the test set.



In [8]:
# Total number of passes over the training set.
num_epochs = 200

# One entry per epoch: the value returned by train_one_epoch and by the
# validation-set evaluate call, respectively.
train_logers = []
val_evaluators = []

start_t = datetime.now()

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_loger = train_one_epoch(model, optimizer, train_dataloader, device, epoch, print_freq=10)
    train_logers.append(train_loger)

    # update the learning rate
    lr_scheduler.step()
    # evaluate on the validation dataset
    val_evaluator = evaluate(model, val_dataloader, device=device)
    val_evaluators.append(val_evaluator)

end_t = datetime.now()

# total_seconds() covers the whole elapsed time. timedelta.seconds is only the
# seconds component (it excludes full days), so it would under-report runs
# longer than 24 hours. int() keeps the printed value a whole number as before.
sec_took = int((end_t - start_t).total_seconds())

print_f.print_title(f"| Training Done, start testing! | Training time: [{sec_took}] seconds, Avg time / Epoch: [{sec_took/num_epochs}] seconds")

# Final evaluation on the held-out test set, run once after training.
test_evaluator = evaluate(model, test_dataloader, device=device)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Epoch: [0]  [ 0/30]  eta: 0:07:14  lr: 0.000177  loss: 2.5088 (2.5088)  loss_classifier: 1.7768 (1.7768)  loss_box_reg: 0.0323 (0.0323)  loss_objectness: 0.6949 (0.6949)  loss_rpn_box_reg: 0.0049 (0.0049)  time: 14.4821  data: 2.4319  max mem: 7983
Epoch: [0]  [10/30]  eta: 0:01:45  lr: 0.001900  loss: 2.4572 (2.3755)  loss_classifier: 1.7351 (1.6559)  loss_box_reg: 0.0189 (0.0204)  loss_objectness: 0.6941 (0.6931)  loss_rpn_box_reg: 0.0049 (0.0061)  time: 5.2762  data: 2.8528  max mem: 8332
Epoch: [0]  [20/30]  eta: 0:00:48  lr: 0.003622  loss: 1.4088 (1.6134)  loss_classifier: 0.9676 (1.0540)  loss_box_reg: 0.0236 (0.0372)  loss_objectness: 0.6819 (0.5150)  loss_rpn_box_reg: 0.0057 (0.0073)  time: 4.3318  data: 2.8577  max mem: 8332
Epoch: [0]  [29/30]  eta: 0:00:04  lr: 0.005000  loss: 0.4526 (1.2065)  loss_classifier: 0.1822 (0.7761)  loss_box_reg: 0.0325 (0.0374)  loss_objectness: 0.1071 (0.3861)  loss_rpn_box_reg: 0.0061 (0.0069)  time: 4.0265  data: 2.6290  max mem: 8332
Epoch: 

In [9]:
# Print the summary of the evaluator returned by the test-set evaluation
# (the output below lists bbox Average Precision / Average Recall values).
test_evaluator.summarize()

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.000


## Plot training Result.

In [65]:
# Bundle everything worth keeping from this run into one ordered record:
# the per-epoch training data plus the validation and test evaluators.
training_record = OrderedDict(
    [
        ("train_data", list(map(get_train_data, train_logers))),
        ("val_evaluators", val_evaluators),
        ("test_evaluator", test_evaluator),
    ]
)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Names of the values recorded per epoch, taken from the first epoch's record.
train_data_keys = training_record['train_data'][0].keys()

# One subplot row per recorded value, all sharing the epoch axis.
# squeeze=False makes plt.subplots always return a 2-D array of Axes; without it
# a single recorded value yields a bare Axes object and `subplots[i]` would fail.
fig, axes_grid = plt.subplots(
    len(train_data_keys), 1, figsize=(10, 5 * len(train_data_keys)), dpi=80, sharex=True, squeeze=False)
# Single column, so take it as a 1-D sequence of Axes.
subplots = axes_grid[:, 0]


fig.suptitle("Training Losses")

for i, k in enumerate(train_data_keys):
    subplots[i].set_title(k)
    # One point per epoch for the value named k.
    subplots[i].plot([data[k] for data in training_record["train_data"]], marker='o', label=k, color='steelblue')
    # subplots[i].legend(loc="upper left")

subplots[-1].set_xlabel('Epoch')
# Render the figure (replaces the previous empty plt.plot() + plt.pause() pair).
plt.show()

## Save model and training data.

In [42]:
# Build the base file name for this run from the last epoch index, whether the
# model uses clinical data, and the current timestamp.
if model.use_clinical:
    clinial_cond = "With"
else:
    clinial_cond = "Without"
current_time_string = datetime.now().strftime("%m-%d-%Y %H-%M-%S")
final_model_path = f"epoch{epoch}_{clinial_cond}Clincal_{current_time_string}"
# Any ":" or "." in the name is replaced with "_".
for unwanted_char in (":", "."):
    final_model_path = final_model_path.replace(unwanted_char, "_")

In [None]:
# Create the output folder if needed, so saving does not fail on a fresh checkout.
os.makedirs('trained_models', exist_ok=True)

# Store the model's state dict under trained_models/<final_model_path>.
torch.save(model.state_dict(), os.path.join('trained_models', final_model_path))

print(f"Model has been saved: {final_model_path}")

In [37]:
# Create the output folder if needed; open() does not create missing directories.
os.makedirs("training_records", exist_ok=True)

# Pickle the training record next to the model, using the same base file name.
with open(
    os.path.join("training_records", f"{final_model_path}.pkl"), "wb",
) as training_record_f:
    pickle.dump(training_record, training_record_f)

In [38]:
# # load testing.
# import pickle
# final_model_path = "epoch199_WithClincal_03-07-2022 18-04-34"
# with open(os.path.join("training_records", f"{final_model_path}.pkl"), 'rb') as f:
#     training_record = pickle.load(f)