In [None]:
import sys
sys.path.append('./faster_RCNN/')

from cell_dataset import CellDataset
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

import os
import shutil
import numpy as np
import torch
import torch.utils.data
from PIL import Image
import pandas as pd
import cv2

from engine import train_one_epoch, evaluate
import utils
import transforms as T
import random
import pickle

from sklearn.model_selection import train_test_split

In [None]:
def build_model(num_classes):
    """Create a Faster R-CNN (ResNet-50 FPN backbone, pretrained=True) whose
    box predictor is swapped for one that outputs ``num_classes`` classes.

    Args:
        num_classes: number of output classes, counting background as a class
            (this notebook uses 2: background and CELL).

    Returns:
        The detection model with its new, untrained ``FastRCNNPredictor`` head.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    box_head = detector.roi_heads
    # The replacement head must accept the same feature width as the old one.
    n_features = box_head.box_predictor.cls_score.in_features
    box_head.box_predictor = FastRCNNPredictor(n_features, num_classes)
    return detector


def get_transform(train, augmentation_dict=None):
    """Build the transform pipeline passed to ``CellDataset``.

    Args:
        train: when True, each entry of ``augmentation_dict`` is turned into a
            ``T.DataAugmentation`` step placed before tensor conversion; when
            False, images are only converted to tensors.
        augmentation_dict: mapping ``aug_name -> factors``; ``factors[0]`` and
            ``factors[1]`` are forwarded to ``T.DataAugmentation``. ``None``
            (the default) means "no augmentation" instead of crashing.

    Returns:
        A ``T.Compose`` of the selected transforms; ``T.ToTensor()`` is always last.
    """
    transforms = []
    if train:
        for aug_key, factors in (augmentation_dict or {}).items():
            # NOTE(review): 0.5 is presumably the per-image application
            # probability of this augmentation — confirm in transforms.DataAugmentation.
            transforms.append(T.DataAugmentation(0.5, aug_key, factors[0], factors[1]))
        # Geometric augmentation (T.RandomHorizontalFlip(0.5)) is currently disabled.
    transforms.append(T.ToTensor())
    return T.Compose(transforms)

In [None]:
ANN_PATH = "cellDetection_annotations.pkl"
IMG_DIR  = "../../download_from_drive/data/ProcessedO7"
# Training-time augmentations: name -> two factors forwarded to
# T.DataAugmentation by get_transform.
AUGMENTATION_DICT = {
    'contrast': [0.5, 1.5],
    'brightness': [0.4, 1.5],
    'saturation': [0.4, 1.2],
    'gamma': [1, 1.25],
    'hue': [-0.1, 0.1]
}
BATCH_SIZE = 4
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Seed for the train/validation split and for the global RNGs used by
# augmentation and loader shuffling, so Restart & Run All is reproducible.
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): read_pickle can execute arbitrary code — only load trusted files.
ann_df = (
    pd.read_pickle(ANN_PATH)
    # Point each row at its image under IMG_DIR: keep the basename of
    # "filename" and strip the ".json" suffix.
    .assign(path=lambda df: df.filename.apply(
        lambda fn: os.path.join(IMG_DIR, fn.split("/")[-1].split(".json")[0])))
    .drop(columns=["filename", "value"])
)

# Split by image, not by row, so every annotation row of an image lands in
# the same subset.
imgs = ann_df.path.unique()
train_imgs, valid_imgs = train_test_split(imgs, train_size=0.7, random_state=SEED)

train_df = ann_df.query("path in @train_imgs")
valid_df = ann_df.query("path in @valid_imgs")

assert set(train_imgs).isdisjoint(valid_imgs), "image leaked across train/valid"
assert len(train_df) + len(valid_df) == len(ann_df), "rows lost in the split"

In [None]:
# Training images go through AUGMENTATION_DICT; validation images are only
# converted to tensors.
train_transforms = get_transform(train=True,  augmentation_dict=AUGMENTATION_DICT)
valid_transforms = get_transform(train=False, augmentation_dict=None)

train_ds = CellDataset(data_df=train_df, transforms=train_transforms)
valid_ds = CellDataset(data_df=valid_df, transforms=valid_transforms)

# os.cpu_count() returns None when the count cannot be determined, and
# DataLoader requires an int — fall back to 0 (load in the main process).
NUM_WORKERS = os.cpu_count() or 0

# Settings shared by both loaders. utils.collate_fn replaces the default
# collate, presumably because per-image targets vary in size — confirm in utils.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=utils.collate_fn,
)

train_dl = torch.utils.data.DataLoader(dataset=train_ds, shuffle=True, **_loader_kwargs)
valid_dl = torch.utils.data.DataLoader(dataset=valid_ds, shuffle=False, **_loader_kwargs)

In [None]:
# Start every run with an empty checkpoint directory.
# NOTE(review): this permanently deletes checkpoints from any previous run.
checkpoints_dir = 'trained_models'
if os.path.isdir(checkpoints_dir):
    shutil.rmtree(checkpoints_dir)
os.makedirs(checkpoints_dir)

# Two classes: background + CELL.
num_classes = 2
model = build_model(num_classes)
model.to(DEVICE)

# SGD over the trainable parameters only.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Multiply the learning rate by 0.1 every 3 epochs. The scheduler is stepped
# once per epoch below, so with num_epochs = 3 the decay only takes effect
# after the last training epoch.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

num_epochs = 3
for epoch in range(num_epochs):
    # One pass over the training set, logging every 10 iterations.
    train_one_epoch(model, optimizer, train_dl, DEVICE, epoch, print_freq=10)
    lr_scheduler.step()

    # Score the model on the validation split.
    evaluate(model, valid_dl, device=DEVICE)

    # Save the whole model object after each epoch, e.g. trained_models/checkpoint-00.pt.
    checkpoint_path = os.path.join(checkpoints_dir, f"checkpoint-{epoch:02d}.pt")
    torch.save(model, checkpoint_path)
    # Two blank lines between epochs in the log.
    print("\n")
    print()

In [None]:
# Persist the validation image paths so the same held-out split can be reused.
# The `with` block guarantees the file is flushed and closed; the original
# inline open() left the handle to the garbage collector.
with open("valid_imgs.pkl", "wb") as valid_imgs_file:
    pickle.dump(valid_imgs, valid_imgs_file)