In [None]:
import os

from skimage import io
import pylab as plt
import wandb
import numpy as np
import time

from ml_utils.model.faster_rcnn import load_model_for_training
from ml_utils.train.dataloader_bbox import get_data_loaders
from ml_utils.train.train_bbox import train
from ml_utils.utils.visualize import draw_bbox

### Specify parameters

In [None]:
# Input images, their bounding-box annotations, and the output directory for
# trained weights (paths are relative to the notebook's working directory).
input_dir = '../example_data/img'
bbox_fn = '../example_data/bboxes.csv'
model_dir = '../outputs/model'

# Weights & Biases project that receives the training run.
project_name = 'test_project'

# Data-loader settings.
batch_size = 2
val_fraction = 0.2
num_workers = 2

# Training hyperparameters; this dict is handed to both wandb.init and train.
config = {
    'num_epochs': 20,
    'lr': 0.01,
    'momentum': 0.9,
    'weight_decay': 0.0005,
    'step_size': 3,
    'gamma': 0.1,
    'detection_thr': 0.1,
    'overlap_thr': 0.1,
    'dist_thr': 10,
}

# When False, the W&B run is recorded offline instead of being synced.
log_progress = True

### Preview training images with their ground-truth boxes

In [None]:
tr_dl, val_dl = get_data_loaders(bbox_fn,
                                 input_dir=input_dir,
                                 val_fraction=val_fraction, 
                                 batch_size=batch_size, 
                                 num_workers=num_workers)

# How many training images to preview before stopping.
n_examples = 4

n_shown = 0
for images, targets, image_ids in tr_dl:
    for image, target, image_id in zip(images, targets, image_ids):
        boxes = target['boxes'].cpu().numpy().astype(np.int32)
        # Channel-first tensor -> channel-last array for plotting.
        sample = image.permute(1, 2, 0).cpu().numpy()

        for box in boxes:
            # Coordinates are passed to draw_bbox as (box[1], box[0], box[3], box[2]);
            # presumably boxes are (x1, y1, x2, y2) and draw_bbox wants (row, col)
            # order — TODO confirm against draw_bbox.
            sample = draw_bbox(sample, [box[1], box[0], box[3], box[2]], color=(1, 0, 0))

        # skimage.io.imshow is deprecated; plot directly with matplotlib.
        plt.imshow(sample)
        plt.title(str(image_id))
        plt.show()
        n_shown += 1
        if n_shown >= n_examples:
            break
    # Stop iterating the loader too; breaking only the inner loop would keep
    # pulling (and showing one image from) every remaining batch.
    if n_shown >= n_examples:
        break

### Load data and model

In [None]:
# Loader settings shared by the train/validation split below.
loader_kwargs = dict(input_dir=input_dir,
                     val_fraction=val_fraction,
                     batch_size=batch_size,
                     num_workers=num_workers)

# Fresh train/validation loaders for the training run.
# NOTE(review): if get_data_loaders splits randomly, this split may differ from
# the one used in the preview cell — TODO confirm / seed if that matters.
tr_dl, val_dl = get_data_loaders(bbox_fn, **loader_kwargs)

# Faster R-CNN model prepared for training (ml_utils.model.faster_rcnn).
model = load_model_for_training()

### Train the model

In [None]:
%%time

if log_progress is False:
    os.environ['WANDB_MODE'] = 'offline'
    
wandb.init(project=project_name, config=config)

train(model, tr_dl, val_dl, config=config, log_progress=log_progress, model_dir=model_dir)

wandb.finish()