## Imports

In [1]:
import datetime
import os

import torch
import torch.utils.data

from torch.optim.lr_scheduler import StepLR
from torch.optim import SGD

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor, MaskRCNN_ResNet50_FPN_Weights

from _engine import train_one_epoch, evaluate
import _transforms as T
import _utils as utils

## Dataset Root Path (change to match your environment)

In [2]:
# Dataset root directory. Set the PARCEL_DATASET_ROOT environment variable to point
# at your own copy instead of editing the notebook; otherwise the original path is used.
root_path = os.environ.get('PARCEL_DATASET_ROOT', '/home/aghosh57/Kerner-Lab/maskrcnn-rpn/dataset/')

## Import the Dataset Class

In [3]:
from _dataset import ParcelDataset

## Utility Functions

In [4]:
def get_transform(train):
    """Build the preprocessing pipeline applied to each (image, target) sample.

    Args:
        train: when truthy, also apply training-time augmentation
            (a random horizontal flip with probability 0.5).

    Returns:
        A ``T.Compose`` wrapping the selected transforms, in order.
    """
    # The PIL image is always converted to a PyTorch tensor first.
    pipeline = [T.ToTensor()]
    if train:
        # Augmentation: randomly flip the training image and its ground truth.
        pipeline.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(pipeline)

## Build the Model

In [5]:
def get_instance_segmentation_model(num_classes, hidden_layer=256, pretrained=True):
    """Build a torchvision Mask R-CNN (ResNet-50 FPN) with heads sized for ``num_classes``.

    Only the final box predictor and mask predictor are replaced; everything else
    comes from ``torchvision.models.detection.maskrcnn_resnet50_fpn``.

    Args:
        num_classes: number of output classes, background included.
        hidden_layer: channel width of the replacement mask predictor's hidden
            layer (default 256, the previously hard-coded value).
        pretrained: when True (default), load the COCO-pretrained detector weights
            ``MaskRCNN_ResNet50_FPN_Weights.DEFAULT``; when False, build the model
            with ``weights=None``.

    Returns:
        The Mask R-CNN model with the new box and mask prediction heads.
    """
    # Resolve the weight enum at call time so defining this function never touches it.
    weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT if pretrained else None
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=weights)

    # Replace the box classifier/regressor, keeping the feature width the
    # original head consumed.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask head the same way: same input channels, new class count.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    return model

## Hyperparameters

In [6]:
# DataLoader settings: batch sizes for the training and held-out loaders,
# and the number of worker processes each loader uses.
TRAIN_BATCH_SIZE = 16
TEST_BATCH_SIZE = 2
NUM_WORKERS = 4

# SGD optimizer settings.
LEARNING_RATE = 5e-3
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4

# StepLR schedule: the learning rate is multiplied by GAMMA every STEP_SIZE epochs.
# Note: while NUM_EPOCHS <= STEP_SIZE the decay never takes effect during training.
STEP_SIZE = 3
GAMMA = 0.1
NUM_EPOCHS = 1

# Output classes, background included: background + one foreground class.
NUM_CLASSES = 2

## Set the device & Clear Cache

In [7]:
# Release memory held by PyTorch's CUDA caching allocator (e.g. from earlier runs
# of these cells), then prefer the GPU and fall back to the CPU when none is visible.
torch.cuda.empty_cache()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Build the dataset and generate the dataloaders

In [8]:
# use our dataset and defined transformations
dataset = ParcelDataset(root_path, get_transform(train=True))
dataset_test = ParcelDataset(root_path, get_transform(train=False))

# split the dataset in train and test set
torch.manual_seed(1)
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-50])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,
    collate_fn=utils.collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=TEST_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS,
    collate_fn=utils.collate_fn)

## Model, Optimizer & LR Scheduler

In [9]:
# get the model using our helper function
rpn_model = get_instance_segmentation_model(NUM_CLASSES)
# move model to the right device
rpn_model.to(device)

# construct an optimizer
params = [p for p in rpn_model.parameters() if p.requires_grad]
optimizer = SGD(params, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

# and a learning rate scheduler which decreases the learning rate by
# 10x every 3 epochs
lr_scheduler = StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

## Training & Evaluation

In [10]:
for epoch in range(NUM_EPOCHS):
    # Train for one epoch, logging every 10 iterations.
    train_one_epoch(rpn_model, optimizer, data_loader, device, epoch, print_freq=10)

    # Advance the StepLR schedule once per epoch, after that epoch's optimizer steps.
    lr_scheduler.step()

    # Evaluate on the held-out split.
    evaluate(rpn_model, data_loader_test, device=device)

    # Return cached-but-unused CUDA memory between epochs.
    torch.cuda.empty_cache()

# Save the weights under a timestamped name. str(datetime.now()) would embed a space
# and colons (e.g. "2024-05-01 13:45:12.123456"), which are invalid in Windows
# filenames and awkward in shells, so use a filesystem-safe format instead.
checkpoint_timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
torch.save(rpn_model.state_dict(), f"rpn_model_{checkpoint_timestamp}.pth")

Epoch: [0]  [  0/401]  eta: 0:16:34  lr: 0.000017  loss: 8.1094 (8.1094)  loss_classifier: 0.6765 (0.6765)  loss_box_reg: 0.8597 (0.8597)  loss_mask: 2.5982 (2.5982)  loss_objectness: 3.7684 (3.7684)  loss_rpn_box_reg: 0.2067 (0.2067)  time: 2.4812  data: 0.6632  max mem: 14604
Epoch: [0]  [ 10/401]  eta: 0:04:39  lr: 0.000142  loss: 6.2139 (5.9833)  loss_classifier: 0.6529 (0.6440)  loss_box_reg: 0.7336 (0.7294)  loss_mask: 1.6778 (1.8192)  loss_objectness: 2.4908 (2.5611)  loss_rpn_box_reg: 0.2113 (0.2296)  time: 0.7141  data: 0.0885  max mem: 14604
Epoch: [0]  [ 20/401]  eta: 0:03:59  lr: 0.000267  loss: 3.6852 (4.5511)  loss_classifier: 0.5755 (0.5843)  loss_box_reg: 0.7234 (0.7252)  loss_mask: 1.2092 (1.3835)  loss_objectness: 0.9519 (1.6359)  loss_rpn_box_reg: 0.2020 (0.2222)  time: 0.5366  data: 0.0303  max mem: 14604
Epoch: [0]  [ 30/401]  eta: 0:03:42  lr: 0.000392  loss: 2.5032 (3.8464)  loss_classifier: 0.4961 (0.5534)  loss_box_reg: 0.7206 (0.7111)  loss_mask: 0.6042 (1.121