In [1]:
import pandas as pd
import numpy as np

import utils.print as print_f

from utils.engine import xami_evaluate, get_iou_types
from utils.plot import plot_losses, plot_ap_ars

from models.setup import ModelSetup
from models.build import create_multimodal_rcnn_model
from models.train import TrainingInfo
from utils.save import check_best, end_train, get_data_from_metric_logger
from data.load import get_datasets, get_dataloaders
from IPython.display import clear_output
from utils.eval import get_ap_ar, get_ap_ar_for_train_val
from utils.train import get_optimiser, get_lr_scheduler, print_params_setup, get_coco_eval_params, get_dynamic_loss, get_params
from utils.init import reproducibility, clean_memory_get_device
from data.constants import DEFAULT_REFLACX_LABEL_COLS
from data.paths import MIMIC_EYE_PATH
from datetime import datetime



## Suppress the assignement warning from pandas.r
pd.options.mode.chained_assignment = None  # default='warn'

## Supress user warning
import warnings
warnings.filterwarnings("ignore")

%matplotlib inline

In [2]:
device = clean_memory_get_device()
reproducibility()


This notebook will running on device: [CPU]


In [3]:
from data.transforms import get_tensorise_h_flip_transform
from data.constants import DEFAULT_REFLACX_LABEL_COLS
from data.paths import MIMIC_EYE_PATH
from data.datasets import ReflacxObjectDetectionDataset
from torch.utils.data import DataLoader
from data.datasets import collate_fn
from data.load import seed_worker, get_dataloader_g

dataset_params_dict = {
    "MIMIC_EYE_PATH": MIMIC_EYE_PATH,
    # "with_clinical": model_setup.use_clinical,
    "bbox_to_mask": True,
    "labels_cols": DEFAULT_REFLACX_LABEL_COLS,
}

train_dataset = ReflacxObjectDetectionDataset(
        **dataset_params_dict, split_str="train", transforms=get_tensorise_h_flip_transform(train=False), 
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
    worker_init_fn=seed_worker,
    generator=get_dataloader_g(0),
)

In [4]:
from models.components.feature_extractors import ImageFeatureExtractor
from models.components.fusors import NoActionFusor
from models.components.task_performers import ObjectDetectionWithMaskParameters, ObjectDetectionWithMaskPerformer
from models.frameworks import ExtractFusePerform
from models.backbones import get_normal_backbone
from models.setup import ModelSetup
from data.constants import DEFAULT_REFLACX_LABEL_COLS



In [5]:
from utils.init import reproducibility, clean_memory_get_device

device = clean_memory_get_device()
reproducibility()

This notebook will running on device: [CPU]


In [6]:
setup = ModelSetup()
backbone = get_normal_backbone(setup)
image_extractor = ImageFeatureExtractor(backbone)
fusor = NoActionFusor()
params = ObjectDetectionWithMaskParameters()
performer = ObjectDetectionWithMaskPerformer(
    params,
    image_extractor.backbone.out_channels,
    len(DEFAULT_REFLACX_LABEL_COLS) + 1
)
# get the backbone
model = ExtractFusePerform(
    feature_extractors={"image": image_extractor},
    fusor=fusor,
    task_performers={"object-detection": performer},
)


Using pretrained backbone. mobilenet_v3


In [7]:
setup = ModelSetup()
backbone = get_normal_backbone(setup)
image_extractor = ImageFeatureExtractor(backbone)
fusor = NoActionFusor()
params = ObjectDetectionWithMaskParameters()
performer = ObjectDetectionWithMaskPerformer(
    params,
    image_extractor.backbone.out_channels,
    len(DEFAULT_REFLACX_LABEL_COLS) + 1
)
# get the backbone
model = ExtractFusePerform(
    feature_extractors={"image": image_extractor},
    fusor=fusor,
    task_performers={"object-detection": performer},
)


Using pretrained backbone. mobilenet_v3


In [8]:
data = next(iter(train_dataloader))
data = train_dataset.prepare_input_from_data(data, device)


In [9]:
print_f.print_title("Preparing for the training.")
train_info = TrainingInfo(setup)



In [10]:
model.to(device)

ExtractFusePerform(
  (fusor): NoActionFusor()
)

In [11]:
from data.load import get_datasets, get_dataloaders
from utils.coco_utils import get_cocos
from utils.coco_eval import get_eval_params_dict


detect_eval_dataset, train_dataset, val_dataset, test_dataset = get_datasets(
        dataset_params_dict=dataset_params_dict,
    )

train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
    train_dataset, val_dataset, test_dataset, batch_size=setup.batch_size,
)

train_coco, val_coco, test_coco = get_cocos(
    train_dataloader, val_dataloader, test_dataloader
)

creating index...
index created!
creating index...
index created!
creating index...
index created!


In [12]:
iou_thrs = np.array([0.5])
use_iobb = True

eval_params_dict = get_eval_params_dict(
    detect_eval_dataset, iou_thrs=iou_thrs, use_iobb=use_iobb,
)

creating index...
index created!


In [13]:
from models.dynamic_loss import DynamicWeightedLoss

# loss_keys = [
#     "loss_classifier",
#     "loss_box_reg",
#     "loss_objectness",
#     "loss_rpn_box_reg"
# ]

loss_keys = [
    "object-detection_loss_box_reg",
    "object-detection_loss_classifier",
    "object-detection_loss_mask",
    "object-detection_loss_objectness",
    "object-detection_loss_rpn_box_reg",
]

dynamic_loss_weight = DynamicWeightedLoss(
    keys=loss_keys + ["loss_mask"] if setup.use_mask else loss_keys
)
dynamic_loss_weight.to(device)

DynamicWeightedLoss(
  (params): ParameterDict(
      (loss_mask): Parameter containing: [torch.FloatTensor of size 1]
      (object-detection_loss_box_reg): Parameter containing: [torch.FloatTensor of size 1]
      (object-detection_loss_classifier): Parameter containing: [torch.FloatTensor of size 1]
      (object-detection_loss_mask): Parameter containing: [torch.FloatTensor of size 1]
      (object-detection_loss_objectness): Parameter containing: [torch.FloatTensor of size 1]
      (object-detection_loss_rpn_box_reg): Parameter containing: [torch.FloatTensor of size 1]
  )
)

In [14]:
print_params_setup(model)

params = [p for p in model.parameters() if p.requires_grad]
if dynamic_loss_weight:
    params += [p for p in dynamic_loss_weight.parameters() if p.requires_grad]

iou_types = get_iou_types(model, setup)
optimizer = get_optimiser(params, setup)
lr_scheduler = get_lr_scheduler(optimizer, setup)

current_time = datetime.now()

print_f.print_title(
    f"Start training. Preparing Took [{ (current_time - train_info.start_t).seconds}] sec"
)

train_info.start_t = datetime.now()

val_loss = None

[model]: 0
Using SGD as optimizer with lr=0.0005


In [15]:
# model.train()


# outputs = model({"image": data[0]}, targets={
#                            "object-detection": data[1]})

In [16]:
# # get all loses values from it.

# outputs

In [17]:
train_dataloader.dataset

<data.datasets.ReflacxObjectDetectionDataset at 0x1201edba0>

In [18]:
from utils.engine import train_one_epoch

train_info.last_train_evaluator, train_loger = train_one_epoch(
    setup=setup,
    model=model,
    optimizer=optimizer,
    data_loader=train_dataloader,
    device=device,
    epoch=train_info.epoch,
    print_freq=10,
    iou_types=iou_types,
    coco=train_coco,
    score_thres=None,
    evaluate_on_run=True,
    params_dict=eval_params_dict,
    dynamic_loss_weight=dynamic_loss_weight,
)

Epoch: [0]  [  0/531]  eta: 1:27:52  lr: 0.000500  loss: 3.2896 (3.2896)  object-detection_loss_classifier: 1.8488 (1.8488)  object-detection_loss_box_reg: 0.0343 (0.0343)  object-detection_loss_mask: 0.7119 (0.7119)  object-detection_loss_objectness: 0.6929 (0.6929)  object-detection_loss_rpn_box_reg: 0.0018 (0.0018)  time: 9.9292  data: 0.2835
Epoch: [0]  [ 10/531]  eta: 1:21:45  lr: 0.000500  loss: 3.1899 (3.1952)  object-detection_loss_classifier: 1.8457 (1.8468)  object-detection_loss_box_reg: 0.0343 (0.0340)  object-detection_loss_mask: 0.6128 (0.6170)  object-detection_loss_objectness: 0.6929 (0.6930)  object-detection_loss_rpn_box_reg: 0.0037 (0.0044)  time: 9.4155  data: 0.3040
Epoch: [0]  [ 20/531]  eta: 1:18:22  lr: 0.000500  loss: 3.1474 (3.1945)  object-detection_loss_classifier: 1.8454 (1.8463)  object-detection_loss_box_reg: 0.0222 (0.0258)  object-detection_loss_mask: 0.5936 (0.6255)  object-detection_loss_objectness: 0.6929 (0.6929)  object-detection_loss_rpn_box_reg: 

KeyboardInterrupt: 