In [1]:
import sys
sys.path.append('./faster_RCNN/')

from cell_dataset import CellDataset
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

import os
import shutil
import numpy as np
import torch
import torch.utils.data
from PIL import Image
import pandas as pd
import cv2

from engine import train_one_epoch, evaluate
import utils
import transforms as T
import random

from sklearn.model_selection import train_test_split

In [2]:
def build_model(num_classes):
    """Return a pretrained Faster R-CNN (ResNet-50 FPN) with a freshly sized box head.

    Everything except the final box predictor keeps the weights loaded by
    ``fasterrcnn_resnet50_fpn(pretrained=True)``; the predictor is swapped for a
    new ``FastRCNNPredictor`` that emits ``num_classes`` class scores and the
    corresponding box regressions.

    Args:
        num_classes: number of output classes, background included.

    Returns:
        The modified torchvision detection model.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    roi_heads = detector.roi_heads
    # Reuse the input width of the existing classifier so the new head fits the
    # pretrained ROI feature extractor.
    feature_width = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(feature_width, num_classes)
    return detector


def get_transform(train, augmentation_dict=None):
    """Build the transform pipeline applied to each (image, target) sample.

    Args:
        train: when True, prepend one ``T.DataAugmentation`` per entry of
            ``augmentation_dict``; when False, only tensor conversion is applied.
        augmentation_dict: mapping ``aug_name -> factors``; ``factors[0]`` and
            ``factors[1]`` are forwarded as the 3rd/4th arguments of
            ``T.DataAugmentation`` (presumably a low/high factor range — confirm
            in transforms.py). ``None`` (the default) means "no augmentations",
            even when ``train`` is True.

    Returns:
        A ``T.Compose`` of the selected transforms, always ending with ``T.ToTensor()``.
    """
    transforms = []
    if train:
        # `or {}` lets train=True be used with the default augmentation_dict=None
        # instead of crashing on None.items().
        for aug_key, factors in (augmentation_dict or {}).items():
            # First argument 0.5 is presumably the per-sample apply probability.
            transforms.append(T.DataAugmentation(0.5, aug_key, factors[0], factors[1]))
    # Tensor conversion is needed in both modes, after any PIL-level augmentation.
    transforms.append(T.ToTensor())
    return T.Compose(transforms)

In [3]:
# --- Configuration -------------------------------------------------------------
ANN_PATH = "cellDetection_annotations.pkl"
IMG_DIR  = "../../download_from_drive/data/ProcessedO7"
# Augmentation name -> two factors forwarded to T.DataAugmentation via get_transform().
AUGMENTATION_DICT = {
    'contrast': [0.5, 1.5],
    'brightness': [0.4, 1.5],
    'saturation': [0.4, 1.2],
    'gamma': [1, 1.25],
    'hue': [-0.1, 0.1]
}
BATCH_SIZE = 4
SEED = 1234  # used for the train/valid split and to seed the global RNGs below
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed the Python, NumPy and torch global RNGs so that a fresh "Restart & Run All"
# is repeatable as far as these generators go (cuDNN kernels and DataLoader
# workers may still introduce nondeterminism).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def annotation_to_image_path(annotation_fn):
    """Map an annotation filename like '<dir>/<name>.json' to '<IMG_DIR>/<name>'.

    Keeps only the last '/'-separated component and drops everything from the
    first occurrence of '.json' onwards.
    """
    return os.path.join(IMG_DIR, annotation_fn.split("/")[-1].split(".json")[0])


# --- Load annotations ----------------------------------------------------------
# NOTE: pd.read_pickle unpickles arbitrary objects (can execute code) — only load
# annotation files from a trusted source.
ann_df = (
    pd.read_pickle(ANN_PATH)
    .assign(path=lambda df: df["filename"].apply(annotation_to_image_path))
    .drop(columns=["filename", "value"])
)

# --- Train / validation split --------------------------------------------------
# Split by image (not by annotation row) so no image contributes to both sets.
imgs = ann_df.path.unique()
train_imgs, valid_imgs = train_test_split(imgs, train_size=0.7, random_state=SEED)

train_df = ann_df.query("path in @train_imgs")
valid_df = ann_df.query("path in @valid_imgs")

# Sanity checks: the split is leak-free and loses no annotation rows.
assert set(train_imgs).isdisjoint(valid_imgs)
assert len(train_df) + len(valid_df) == len(ann_df)

In [4]:
# --- Datasets & data loaders ---------------------------------------------------
train_transforms = get_transform(train=True,  augmentation_dict=AUGMENTATION_DICT)
valid_transforms = get_transform(train=False, augmentation_dict=None)

train_ds = CellDataset(data_df=train_df, transforms=train_transforms)
valid_ds = CellDataset(data_df=valid_df, transforms=valid_transforms)

# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader requires an int; fall back to 0 (load in the main process).
NUM_WORKERS = os.cpu_count() or 0


def make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader with the notebook-wide batch/worker settings.

    Args:
        dataset: the dataset to iterate.
        shuffle: whether to reshuffle every epoch (True for training only).

    Returns:
        A ``torch.utils.data.DataLoader`` using ``utils.collate_fn`` to batch the
        detection samples.
    """
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        collate_fn=utils.collate_fn,
    )


train_dl = make_loader(train_ds, shuffle=True)
valid_dl = make_loader(valid_ds, shuffle=False)

In [5]:
# --- Training ------------------------------------------------------------------
# Every run starts from an empty checkpoint directory: checkpoints written by a
# previous run of this cell are deleted.
checkpoints_dir = 'trained_models'
if os.path.isdir(checkpoints_dir):
    shutil.rmtree(checkpoints_dir)
os.makedirs(checkpoints_dir)

# our dataset has two classes only - background and CELL
num_classes = 2
model = build_model(num_classes)
model.to(DEVICE)

# SGD over the trainable parameters only.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

# Learning-rate scheduler: divide the LR by 10 every 3 epochs.
# NOTE(review): over 30 epochs this decays the LR to 0.005 * 0.1**9 = 5e-12 for
# epochs 27-29 (already 5e-6 from epoch 9), so most later epochs barely update the
# weights. Consider a larger step_size or fewer epochs; left unchanged here.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)

num_epochs = 30
for epoch in range(num_epochs):
    # Train for one epoch, logging every 10 iterations.
    train_one_epoch(model, optimizer, train_dl, DEVICE, epoch, print_freq=10)

    # Advance the epoch-level LR schedule.
    lr_scheduler.step()

    # Evaluate on the validation split (there is no separate test split here).
    evaluate(model, valid_dl, device=DEVICE)

    # Save the model after this epoch. torch.save(model) pickles the whole
    # nn.Module, so loading it requires the same class definitions to be importable;
    # saving model.state_dict() would be more portable.
    torch.save(model, os.path.join(checkpoints_dir, f'checkpoint-{epoch:02d}.pt'))
    # Two blank lines between epochs in the log.
    print('\n')

Epoch: [0]  [ 0/36]  eta: 0:01:31  lr: 0.000148  loss: 6.0536 (6.0536)  loss_classifier: 0.6177 (0.6177)  loss_box_reg: 0.1642 (0.1642)  loss_objectness: 4.9627 (4.9627)  loss_rpn_box_reg: 0.3089 (0.3089)  time: 2.5430  data: 1.2861  max mem: 4442
Epoch: [0]  [10/36]  eta: 0:00:32  lr: 0.001575  loss: 1.8091 (2.9530)  loss_classifier: 0.5734 (0.5661)  loss_box_reg: 0.6997 (0.5954)  loss_objectness: 0.3777 (1.6050)  loss_rpn_box_reg: 0.1788 (0.1865)  time: 1.2612  data: 0.1870  max mem: 4924
Epoch: [0]  [20/36]  eta: 0:00:18  lr: 0.003002  loss: 1.5307 (2.1155)  loss_classifier: 0.4758 (0.4834)  loss_box_reg: 0.5587 (0.5752)  loss_objectness: 0.2768 (0.9261)  loss_rpn_box_reg: 0.0777 (0.1308)  time: 1.1018  data: 0.0526  max mem: 4924
Epoch: [0]  [30/36]  eta: 0:00:06  lr: 0.004429  loss: 0.9257 (1.6843)  loss_classifier: 0.3025 (0.4156)  loss_box_reg: 0.4384 (0.5103)  loss_objectness: 0.1015 (0.6547)  loss_rpn_box_reg: 0.0549 (0.1037)  time: 1.0750  data: 0.0295  max mem: 4924
Epoch: [