In [1]:
import torch
import pandas as pd


import utils.print as print_f
from data.dataset import REFLACXWithBoundingBoxesDataset, ReflacxAllCXRDataset
from utils.transforms import get_transform
from collections import OrderedDict

## Suppress the chained-assignment warning from pandas.
pd.options.mode.chained_assignment = None  # default='warn'

from utils.engine import official_evaluate, official_train_one_epoch
import matplotlib.pyplot as plt
from IPython.display import clear_output
from utils.save import  get_train_data
from datetime import datetime
import PIL
from matplotlib.patches import Rectangle
import os 

import pickle
from utils.engine import map_target_to_device

%matplotlib inline

In [2]:
import gc

# Run a garbage-collection pass first, then ask PyTorch to release its cached
# CUDA memory. Useful when re-running the notebook in a kernel that already
# held a model.
gc.collect()
torch.cuda.empty_cache()

In [3]:
torch.cuda.memory_summary(device=None, abbreviated=False)



In [4]:
# Root folder of the XAMI-MIMIC dataset on the local machine.
# Raw string: "\X" is not a valid escape sequence, so the non-raw literal only
# worked by accident and triggers a warning on recent Python versions.
XAMI_MIMIC_PATH = r"D:\XAMI-MIMIC"

In [5]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
if use_gpu:
    device = 'cuda'
else:
    device = 'cpu'
print(f"Will be using {device}")

Will be using cuda


In [6]:
# Disease labels the detector is trained on; the list is passed to every
# dataset split unchanged.
labels_cols = [
    "Enlarged cardiac silhouette",
    "Atelectasis",
    "Pleural abnormality",
    "Consolidation",
    "Pulmonary edema",
    #  'Groundglass opacity', # 6th disease.
]

# One dataset per split. Only the training split requests the training
# transforms (get_transform(train=True)).
train_dataset = REFLACXWithBoundingBoxesDataset(
    XAMI_MIMIC_PATH=XAMI_MIMIC_PATH,
    labels_cols=labels_cols,
    split_str="train",
    transforms=get_transform(train=True),
)

val_dataset = REFLACXWithBoundingBoxesDataset(
    XAMI_MIMIC_PATH=XAMI_MIMIC_PATH,
    labels_cols=labels_cols,
    split_str="val",
    transforms=get_transform(train=False),
)

test_dataset = REFLACXWithBoundingBoxesDataset(
    XAMI_MIMIC_PATH=XAMI_MIMIC_PATH,
    labels_cols=labels_cols,
    split_str="test",
    transforms=get_transform(train=False),
)

batch_size = 4

# Training batches are reshuffled every epoch.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=REFLACXWithBoundingBoxesDataset.collate_fn,
)

# Evaluation loaders iterate in a fixed order: shuffling gives no benefit for
# evaluation and only makes runs harder to compare and debug.
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=REFLACXWithBoundingBoxesDataset.collate_fn,
)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=REFLACXWithBoundingBoxesDataset.collate_fn,
)


In [7]:
train_dataset[0]

(tensor([[[0.5608, 0.5608, 0.5569,  ..., 0.0000, 0.0000, 0.0000],
          [0.5569, 0.5647, 0.5647,  ..., 0.0000, 0.0000, 0.0000],
          [0.5490, 0.5569, 0.5608,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.7686, 0.7686, 0.7647,  ..., 1.0000, 1.0000, 1.0000],
          [0.7608, 0.7647, 0.7608,  ..., 1.0000, 1.0000, 1.0000],
          [0.7529, 0.7569, 0.7569,  ..., 1.0000, 1.0000, 1.0000]],
 
         [[0.5608, 0.5608, 0.5569,  ..., 0.0000, 0.0000, 0.0000],
          [0.5569, 0.5647, 0.5647,  ..., 0.0000, 0.0000, 0.0000],
          [0.5490, 0.5569, 0.5608,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.7686, 0.7686, 0.7647,  ..., 1.0000, 1.0000, 1.0000],
          [0.7608, 0.7647, 0.7608,  ..., 1.0000, 1.0000, 1.0000],
          [0.7529, 0.7569, 0.7569,  ..., 1.0000, 1.0000, 1.0000]],
 
         [[0.5608, 0.5608, 0.5569,  ..., 0.0000, 0.0000, 0.0000],
          [0.5569, 0.5647, 0.5647,  ..., 0.0000, 0.0000, 0.0000],
          [0.5490, 0.5569, 0.5608,  ...,

In [8]:
# from data.dataset import PennFudanDataset

# train_dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
# test_dataset = PennFudanDataset('PennFudanPed', get_transform(train=False))

# indices = torch.randperm(len(train_dataset)).tolist()
# dataset = torch.utils.data.Subset(train_dataset, indices[:-50])
# dataset_test = torch.utils.data.Subset(test_dataset, indices[-50:])


# train_dataloader = torch.utils.data.DataLoader(
#     dataset, batch_size=4, shuffle=True,
#     collate_fn=REFLACXWithBoundingBoxesDataset.collate_fn)

# test_dataloader = torch.utils.data.DataLoader(
#     dataset_test, batch_size=4, shuffle=False,
#     collate_fn=REFLACXWithBoundingBoxesDataset.collate_fn)

# # val_dataloader = test_dataloader

In [9]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def get_model_instance_segmentation(num_classes):
    """Build a Mask R-CNN whose box and mask heads predict ``num_classes`` classes.

    Starts from torchvision's ``maskrcnn_resnet50_fpn`` with ``pretrained=True``
    and replaces both the box predictor and the mask predictor with freshly
    initialised heads sized for ``num_classes``.

    Args:
        num_classes: Number of classes the new heads output.
            NOTE(review): presumably this count includes the background
            class (the notebook passes ``len(labels_cols) + 1``) - confirm.

    Returns:
        The torchvision Mask R-CNN model with the replaced heads.
    """
    # Imported locally because the notebook only imports ``torchvision`` in a
    # later cell; without this the function depends on cell execution order.
    import torchvision

    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        pretrained=True,
        rpn_nms_thresh=0.5,
        box_detections_per_img=10,
        box_nms_thresh=0.2,
    )

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask, hidden_layer, num_classes
    )

    return model


In [10]:
import torchvision
from torchvision.models.detection import FasterRCNN

# trainable_backbone_layers = torchvision.models.detection.backbone_utils._validate_trainable_layers(
#     True, None, 5, 3
# )
# backbone = torchvision.models.detection.backbone_utils.resnet_fpn_backbone(
#     "resnet50", pretrained=True, trainable_layers=trainable_backbone_layers
# )
# backbone.out_channels = 256


######################## For MobileNet backbone ########################
# from torchvision.models.detection.faster_rcnn import AnchorGenerator
# backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# backbone.out_channels = 1280

# model = FasterRCNN(
#     backbone,
#     num_classes=len(train_dataset.labels_cols) + 1,
#     rpn_anchor_generator=None,
#     box_roi_pool=None,
# )

# anchor_generator = AnchorGenerator(
#     sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)
# )
# roi_pooler = torchvision.ops.MultiScaleRoIAlign(
#     featmap_names=["0"], output_size=7, sampling_ratio=2
# )
########################################################################


# ResNet (Smaller model.)
# model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
#     pretrained_backbone=True, num_classes=len(train_dataset.labels_cols) + 1,
# )

# model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
#     pretrained_backbone=True, num_classes=len(train_dataset.labels_cols) + 1,
# )

# testing with PennFudanDataset

# Trainable fasterrcnn
# model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
#     # pretrained_backbone=True, num_classes=2,
#     pretrained=True,
#     # num_classes=2,
#     rpn_nms_thresh=0.5,
#     box_detections_per_img=10,
#     box_nms_thresh=0.2,
# )

# model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
# model = get_model_instance_segmentation(len(train_dataset.labels_cols) + 1)

# Keyword arguments forwarded unchanged to torchvision's Mask R-CNN factory.
_maskrcnn_kwargs = dict(
    pretrained=True,
    rpn_nms_thresh=0.5,
    box_detections_per_img=10,
    box_nms_thresh=0.2,
)

# NOTE(review): the pretrained prediction heads are used as-is here (the
# head-replacing `get_model_instance_segmentation` call is commented out
# above) - confirm the pretrained label space is what this experiment wants.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(**_maskrcnn_kwargs)

# Last expression of the cell: moves the model to `device` and shows its repr.
model.to(device)


MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(in

In [11]:
# Total number of elements across all model parameters.
print(f"Model size: {sum(param.nelement() for param in model.parameters()):,}")

Model size: 44,401,393


In [12]:
# Only parameters that still require gradients are handed to the optimizer.
params = list(filter(lambda p: p.requires_grad, model.parameters()))

# Placeholder so the training loop's `lr_scheduler is None` check still works
# if the scheduler below is disabled.
lr_scheduler = None

# Alternative that was tried: Adam(params, lr=0.05, weight_decay=0.0005)
# without any scheduler.

# Original setting (works reasonably): SGD with momentum and weight decay.
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# Step schedule: multiply the learning rate by 0.5 every 3 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.5,
)

In [13]:
def plot_loss(train_logers):
    """Redraw the training-loss curves, one subplot per logged metric.

    The cell output is cleared first so that each call replaces the previous
    plot instead of appending to it.

    Args:
        train_logers: Sequence of per-epoch training loggers. Each one is
            converted with ``get_train_data``; the result is used as a mapping
            from metric name to that epoch's value, and the metric names are
            taken from the first epoch.
    """
    clear_output()

    train_data = [get_train_data(loger) for loger in train_logers]
    if not train_data:
        # Nothing has been logged yet, so there is nothing to draw.
        return

    train_data_keys = list(train_data[0].keys())
    if not train_data_keys:
        return

    # squeeze=False always returns a 2-D array of axes. Without it a single
    # metric yields one bare Axes object, which cannot be indexed.
    fig, axes_grid = plt.subplots(
        len(train_data_keys),
        figsize=(10, 5 * len(train_data_keys)),
        dpi=80,
        sharex=True,
        squeeze=False,
    )
    axes = axes_grid[:, 0]

    fig.suptitle("Training Losses")

    for ax, k in zip(axes, train_data_keys):
        ax.set_title(k)
        ax.plot([data[k] for data in train_data], marker='o', label=k, color='steelblue')

    axes[-1].set_xlabel('Epoch')

    # Kept from the original: the short pause gives the backend a chance to
    # draw the figure while the training loop is still running.
    plt.pause(0.01)


In [14]:
import numpy as np

def transparent_cmap(cmap, N=255):
    """Return a copy of ``cmap`` whose alpha channel ramps linearly from 0 to 0.8.

    The input colormap is left untouched. The original implementation edited
    the object it was given, which silently changed shared colormaps such as
    ``plt.cm.Reds`` for every later user.

    Args:
        cmap: Colormap-like object exposing ``_init()`` and an RGBA lookup
            table ``_lut`` (one row per entry, alpha in the last column).
        N: Kept for backward compatibility and no longer used. The ramp length
            is taken from the lookup table itself; previously it was ``N + 4``,
            which raised a shape error whenever it did not match the table.

    Returns:
        A new colormap object with the modified alpha column.
    """
    import copy

    t_cmap = copy.copy(cmap)
    t_cmap._init()

    # Work on a private copy of the table so the source colormap can never
    # share (and see) the modified alpha values.
    lut = np.array(t_cmap._lut, copy=True)
    lut[:, -1] = np.linspace(0, 0.8, lut.shape[0])
    t_cmap._lut = lut
    return t_cmap

In [15]:
# Colour lookup per disease, in two flavours:
#   "transparent" - colormaps with an alpha ramp, for overlaying masks;
#   "solid"       - plain matplotlib colour names, for boxes and text labels.
disease_cmap = {
    "transparent": {
        "Enlarged cardiac silhouette": transparent_cmap(plt.cm.autumn),
        "Atelectasis": transparent_cmap(plt.cm.Reds),
        "Pleural abnormality": transparent_cmap(plt.cm.Oranges),
        "Consolidation": transparent_cmap(plt.cm.Greens),
        "Pulmonary edema": transparent_cmap(plt.cm.Blues),
    },
    # The original wrapped this mapping in an extra pair of braces, i.e. a set
    # literal containing a dict, which raises
    # "TypeError: unhashable type: 'dict'" as soon as the cell runs.
    "solid": {
        "Enlarged cardiac silhouette": "yellow",
        "Atelectasis": "red",
        "Pleural abnormality": "orange",
        "Consolidation": "lightgreen",
        "Pulmonary edema": "dodgerblue",
    },
}

# Disease name -> alpha-ramped colormap used when overlaying segmentation masks.
transparent_disease_color_code_map = {
    disease: transparent_cmap(base_cmap)
    for disease, base_cmap in (
        ("Enlarged cardiac silhouette", plt.cm.autumn),
        ("Atelectasis", plt.cm.Reds),
        ("Pleural abnormality", plt.cm.Oranges),
        ("Consolidation", plt.cm.Greens),
        ("Pulmonary edema", plt.cm.Blues),
    )
}


In [16]:
# Disease name -> solid matplotlib colour used for bounding boxes and labels.
disease_color_code_map = dict(
    [
        ('Enlarged cardiac silhouette', 'yellow'),
        ('Atelectasis', 'red'),
        ('Pleural abnormality', 'orange'),
        ('Consolidation', 'lightgreen'),
        ('Pulmonary edema', 'dodgerblue'),
    ]
)

In [18]:
def plot_seg(target, pred, legend_elements, transparent_disease_color_code_map, seg_thres = 0):
    """Show ground-truth and predicted segmentation masks side by side.

    Both panels display the image found at ``target["image_path"]``; every
    mask is drawn on top of it using the colormap registered for its disease.

    Args:
        target: Mapping with ``"image_path"`` plus tensor entries ``"labels"``
            and ``"masks"`` (each is detached and converted to numpy here).
        pred: Sequence whose first element is a mapping with tensor entries
            ``"labels"`` and ``"masks"``.
        legend_elements: Handles passed to ``fig.legend``.
        transparent_disease_color_code_map: Mapping from disease name to the
            colormap used for that disease's mask overlay.
        seg_thres: Predicted mask values strictly greater than this threshold
            count as foreground.

    Note:
        Label indices are translated to disease names through the notebook's
        global ``train_dataset``.
    """
    image_path = target["image_path"]

    figure, axes = plt.subplots(1, 2, figsize=(20, 10), dpi=80, sharex=True)
    truth_axis, prediction_axis = axes
    figure.suptitle(image_path)

    background = PIL.Image.open(image_path).convert("RGB")
    for axis, title in ((truth_axis, "Ground Truth"), (prediction_axis, "Predictions")):
        axis.imshow(background)
        axis.set_title(title)

    figure.legend(handles=legend_elements, loc="upper right")

    def overlay(axis, disease, binary_mask):
        # Scale the mask by 255 and draw it with the disease's colormap.
        axis.imshow(
            PIL.Image.fromarray(binary_mask * 255),
            transparent_disease_color_code_map[disease],
            interpolation="none",
            alpha=0.7,
        )

    # Ground-truth masks are drawn exactly as stored in the target.
    truth_labels = target["labels"].detach().cpu().numpy()
    truth_masks = target["masks"].detach().cpu().numpy()
    for label, raw_mask in zip(truth_labels, truth_masks):
        overlay(truth_axis, train_dataset.label_index_to_disease(label), raw_mask)

    # Predicted masks are squeezed and binarised with `seg_thres` first.
    predicted_labels = pred[0]["labels"].detach().cpu().numpy()
    predicted_masks = pred[0]["masks"].detach().cpu().numpy()
    for label, raw_mask in zip(predicted_labels, predicted_masks):
        disease = train_dataset.label_index_to_disease(label)
        thresholded = (raw_mask.squeeze() > seg_thres).astype(np.uint8)
        overlay(prediction_axis, disease, thresholded)


In [19]:
def plot_bbox(target, pred, legend_elements, disease_color_code_map):
    """Show ground-truth and predicted bounding boxes side by side.

    Both panels display the image found at ``target["image_path"]``. Each box
    is outlined in its disease's colour and tagged with the disease name at
    the box's first corner; the panel titles report the number of boxes.

    Args:
        target: Mapping with ``"image_path"`` plus tensor entries ``"labels"``
            and ``"boxes"``; boxes are read as ``(x0, y0, x1, y1)``.
        pred: Sequence whose first element is a mapping with tensor entries
            ``"labels"`` and ``"boxes"``.
        legend_elements: Handles passed to ``fig.legend``.
        disease_color_code_map: Mapping from disease name to a matplotlib
            colour.

    Note:
        Label indices are translated to disease names through the notebook's
        global ``train_dataset``.
    """
    image_path = target["image_path"]

    figure, (truth_axis, prediction_axis) = plt.subplots(
        1, 2, figsize=(20, 10), dpi=80, sharex=True
    )
    figure.suptitle(image_path)
    figure.legend(handles=legend_elements, loc="upper right")

    background = PIL.Image.open(image_path).convert("RGB")

    truth_axis.imshow(background)
    truth_axis.set_title(f"Ground Truth ({len(target['boxes'].detach().cpu().numpy())})")
    prediction_axis.imshow(background)
    prediction_axis.set_title(f"Predictions ({len(pred[0]['boxes'].detach().cpu().numpy())})")

    def annotate(axis, detections):
        # Text tags are added while iterating; the outlines are attached to
        # the axis afterwards, in the same order.
        labels = detections['labels'].detach().cpu().numpy()
        boxes = detections['boxes'].detach().cpu().numpy()
        outlines = []
        for label, bbox in zip(labels, boxes):
            disease = train_dataset.label_index_to_disease(label)
            colour = disease_color_code_map[disease]
            left, top = bbox[0], bbox[1]
            width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
            outlines.append(
                Rectangle((left, top), width, height, fill=False, color=colour, linewidth=2)
            )
            axis.text(left, top, disease, color="black", backgroundcolor=colour)
        for outline in outlines:
            axis.add_patch(outline)

    annotate(prediction_axis, pred[0])
    annotate(truth_axis, target)

    plt.plot()
    plt.pause(0.01)



In [31]:
[1,2, 3][:-1]

[1, 2]

In [20]:
def plot_result(
    model,
    dataset,
    idx,
    legend_elements,
    disease_color_code_map,
    transparent_disease_color_code_map,
    seg_thres=0.5,
):
    """Run the model on one dataset sample and plot prediction vs. ground truth.

    Args:
        model: Detection model, called as ``model([image])``. The training
            loop switches it to ``eval()`` before calling this function.
        dataset: Indexable dataset returning ``(image, target)`` pairs.
        idx: Index of the sample to visualise.
        legend_elements: Legend handles forwarded to the plotting helper.
        disease_color_code_map: Disease name -> colour, used for box plots.
        transparent_disease_color_code_map: Disease name -> colormap for mask
            overlays. Currently unused because the ``plot_seg`` call below is
            disabled.
        seg_thres: Mask threshold for ``plot_seg``; currently unused for the
            same reason.

    Note:
        Relies on the notebook-level ``device`` variable.
    """
    image, target = dataset[idx]
    target = map_target_to_device(target, device)
    image = image.to(device)

    # Visualisation only: skip autograd bookkeeping so no computation graph
    # (and the memory it holds) is created for each preview.
    with torch.no_grad():
        pred = model([image])

    plot_bbox(target, pred, legend_elements, disease_color_code_map)
    # plot_seg(
    #     target,
    #     pred,
    #     legend_elements,
    #     transparent_disease_color_code_map,
    #     seg_thres=seg_thres,
    # )


In [21]:
num_epochs = 10

train_logers = []
val_evaluators = []

# Legend handles shared by every result plot: one coloured swatch per disease.
# `legend_elements` was used below without being defined in any cell of this
# notebook, so a fresh "Restart & Run All" stopped here with a NameError.
legend_elements = [
    Rectangle((0, 0), 1, 1, facecolor=color, edgecolor=color, label=disease)
    for disease, color in disease_color_code_map.items()
]

# Training-set samples that are visualised after every epoch.
preview_indices = (0, 45, 88)

start_t = datetime.now()

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    model.train()

    train_loger = official_train_one_epoch(
        model, optimizer, train_dataloader, device, epoch, print_freq=10
    )
    train_logers.append(train_loger)

    # Switch to inference mode for the previews and the evaluation below.
    model.eval()

    plot_loss(train_logers)
    for preview_idx in preview_indices:
        plot_result(
            model,
            train_dataset,
            preview_idx,
            legend_elements,
            disease_color_code_map,
            transparent_disease_color_code_map,
            seg_thres=0.5,
        )

    # update the learning rate
    if lr_scheduler is not None:
        lr_scheduler.step()

    # evaluate on the validation dataset
    val_evaluator = official_evaluate(model, val_dataloader, device=device)
    val_evaluators.append(val_evaluator)

end_t = datetime.now()

# total_seconds() covers the whole duration; `.seconds` is only the seconds
# component and wraps around once the run exceeds 24 hours.
sec_took = int((end_t - start_t).total_seconds())

print_f.print_title(
    f"| Training Done, start testing! | Training time: [{sec_took}] seconds, Avg time / Epoch: [{sec_took/num_epochs}] seconds"
)

test_evaluator = official_evaluate(model, test_dataloader, device=device)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


AssertionError: 

In [25]:
def three_string_ha(a, b,c="Okay"):
    """Scratch helper: join ``a``, ``b`` and ``c`` with fixed filler text.

    Args:
        a: First value.
        b: Second value.
        c: Third value; defaults to ``"Okay"``.

    Returns:
        The string ``a + "dsadasd" + b + "dasd" + c`` built via ``format``.
    """
    template = "{}dsadasd{}dasd{}"
    return template.format(a, b, c)

In [26]:
three_string_ha(*("hahaha", "hello", "word")[:-1])

'hahahadsadasdhellodasdOkay'

In [None]:
test_evaluator.summarize()

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.112
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.249
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.092
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.112
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.195
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.246
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.246
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.246
IoU metric: segm
 Ave

In [None]:
# Everything needed to inspect this run later: per-epoch training data, the
# validation evaluator of every epoch, and the final test evaluator.
training_record = OrderedDict()
training_record["train_data"] = [get_train_data(loger) for loger in train_logers]
training_record["val_evaluators"] = val_evaluators
training_record["test_evaluator"] = test_evaluator

In [None]:
# File stem for this run's artefacts: last epoch index, the clinical-data
# condition and a timestamp. ":" and "." are replaced with "_".
clinial_cond = "Without"
current_time_string = datetime.now().strftime("%m-%d-%Y %H-%M-%S")
final_model_path = f"epoch{epoch}_{clinial_cond}Clincal_{current_time_string}".translate(
    str.maketrans({":": "_", ".": "_"})
)

In [None]:
# Create the output folder if needed, so that a finished training run is not
# lost to a missing directory.
os.makedirs("trained_models", exist_ok=True)

# Only the weights (state dict) are stored, not the full model object.
torch.save(model.state_dict(), os.path.join("trained_models", final_model_path))

print(f"Model has been saved: {final_model_path}")

Model has been saved: epoch9_WithoutClincal_03-10-2022 19-00-48


In [None]:
# Create the output folder if needed before writing the pickled record.
os.makedirs("training_records", exist_ok=True)

with open(
    os.path.join("training_records", f"{final_model_path}.pkl"), "wb",
) as training_record_f:
    pickle.dump(training_record, training_record_f)

In [None]:
# epoch9_WithoutClincal_03-10-2022 05-35-31 => all REFLACX