## Imports

In [1]:
import datetime
import os

import torch
import torch.utils.data

from torch.optim.lr_scheduler import StepLR
from torch.optim import SGD

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor, MaskRCNN_ResNet50_FPN_Weights

from _engine import train_one_epoch, evaluate, test_one_epoch
import _transforms as T
import _utils as utils

## Dataset root path (change to match your machine)

In [2]:
# Root directory of the parcel dataset. Set the PARCEL_DATASET_ROOT environment
# variable to point elsewhere without editing the notebook; otherwise the
# original machine-specific path below is used.
root_path = os.environ.get('PARCEL_DATASET_ROOT', '/home/aghosh57/Kerner-Lab/maskrcnn-rpn/dataset/')

## Import the dataset class

In [3]:
from _dataset import ParcelDataset

## Utility functions

In [4]:
def get_transform(train):
    """Build the preprocessing pipeline applied to each dataset sample.

    Args:
        train: when truthy, add a random horizontal flip (probability 0.5)
            that flips the image together with its ground truth, as data
            augmentation.

    Returns:
        A ``T.Compose`` whose first step converts the PIL image to a tensor.
    """
    steps = [T.ToTensor()]
    if train:
        steps += [T.RandomHorizontalFlip(0.5)]
    return T.Compose(steps)

## Build the model

In [5]:
def get_instance_segmentation_model(num_classes, hidden_layer=256):
    """Build a COCO-pretrained Mask R-CNN (ResNet-50 FPN) with new class heads.

    The box predictor and the mask predictor are replaced with freshly
    initialised heads sized for ``num_classes``. The rest of the network
    keeps torchvision's default pretrained weights.

    Args:
        num_classes: number of output classes, including background.
        hidden_layer: channel width of the hidden layer in the replacement
            mask head (defaults to 256, the previously hard-coded value).

    Returns:
        The modified torchvision Mask R-CNN model.
    """
    # Load an instance segmentation model pre-trained on COCO.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)

    # Replace the box classifier/regressor head, keeping its input width.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask head, keeping its input channel count.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    return model

## Hyperparameters

In [6]:
# Batch sizes for the training and held-out data loaders.
TRAIN_BATCH_SIZE, TEST_BATCH_SIZE = 8, 2

# Worker processes per DataLoader.
NUM_WORKERS = 4

# SGD optimiser settings.
LEARNING_RATE = 5e-3
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4

# StepLR schedule: the learning rate is multiplied by GAMMA every STEP_SIZE epochs.
# NOTE(review): with these values the LR drops 10x every 3 epochs, so it is
# below 1e-7 after 15 of the NUM_EPOCHS epochs. Confirm this schedule is intended.
STEP_SIZE = 3
GAMMA = 0.1

NUM_EPOCHS = 33

# Background plus one foreground class.
NUM_CLASSES = 2

## Select the device and clear the CUDA cache

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Release unoccupied cached GPU memory held by PyTorch's caching allocator
# (does nothing when CUDA has not been initialised).
torch.cuda.empty_cache()

## Build the dataset and generate the dataloaders

In [8]:
# Two views of the same data: one with training augmentation, one without.
dataset = ParcelDataset(root_path, get_transform(train=True))
dataset_test = ParcelDataset(root_path, get_transform(train=False))

# Hold out NUM_TEST_SAMPLES samples, chosen by a seeded random permutation, for
# evaluation. Without this guard, a dataset this small would silently yield an
# empty training split.
NUM_TEST_SAMPLES = 50
if len(dataset) <= NUM_TEST_SAMPLES:
    raise ValueError(
        f'ParcelDataset at {root_path!r} has {len(dataset)} samples; '
        f'need more than {NUM_TEST_SAMPLES} to hold out {NUM_TEST_SAMPLES} for testing.')

# Seeding the global RNG fixes the split. It also seeds later consumers of the
# global RNG, such as the shuffling DataLoader below.
torch.manual_seed(1)
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-NUM_TEST_SAMPLES])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-NUM_TEST_SAMPLES:])

# Training and evaluation data loaders; the detection collate_fn keeps
# variable-sized samples as per-sample lists.
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,
    collate_fn=utils.collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=TEST_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS,
    collate_fn=utils.collate_fn)

## Model, optimizer & LR scheduler

In [9]:
# Build the model and resume its weights from a saved checkpoint.
rpn_model = get_instance_segmentation_model(NUM_CLASSES)

# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved on GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints.
rpn_model.load_state_dict(torch.load('model_chkpts/rpn_model_1.2.pth', map_location=device))

# Move the model to the selected device.
rpn_model.to(device)

# Optimise only the parameters that require gradients.
params = [p for p in rpn_model.parameters() if p.requires_grad]
optimizer = SGD(params, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

# Multiply the learning rate by GAMMA every STEP_SIZE epochs.
# NOTE(review): only model weights are resumed. The optimizer momentum and the LR
# schedule restart from scratch. Confirm this is intended when continuing training.
lr_scheduler = StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

## Training & Evaluation

In [10]:
# Lowest training loss seen so far; any finite loss beats the initial value.
best_epoch_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    # Train for one epoch, logging every 10 iterations.
    train_metric = train_one_epoch(rpn_model, optimizer, data_loader, device, epoch, print_freq=10)

    # NOTE(review): if the meters are torchvision-reference SmoothedValue objects,
    # `.avg` is a moving average over the last window of iterations, not the full
    # epoch mean (`global_avg`). Confirm against _utils.
    epoch_loss = train_metric.meters['loss'].avg
    print('Train Loss: ' + str(epoch_loss))

    # Save a checkpoint whenever the training loss improves; the filename
    # embeds the loss rounded to 2 decimals.
    # NOTE(review): selection uses training loss, not the held-out evaluation below.
    if epoch_loss < best_epoch_loss:
        best_epoch_loss = epoch_loss
        torch.save(rpn_model.state_dict(), 'model_chkpts/rpn_model_' + str(round(best_epoch_loss, 2)) + '.pth')

    # Advance the StepLR schedule once per epoch.
    lr_scheduler.step()

    # Evaluate on the held-out split.
    evaluate(rpn_model, data_loader_test, device=device)

    torch.cuda.empty_cache()

Epoch: [0]  [  0/801]  eta: 2:41:39  lr: 0.000011  loss: 1.2703 (1.2703)  loss_classifier: 0.3203 (0.3203)  loss_box_reg: 0.4163 (0.4163)  loss_mask: 0.3252 (0.3252)  loss_objectness: 0.1270 (0.1270)  loss_rpn_box_reg: 0.0814 (0.0814)  time: 12.1095  data: 0.7325  max mem: 7417
Epoch: [0]  [ 10/801]  eta: 0:17:50  lr: 0.000074  loss: 1.2703 (1.2769)  loss_classifier: 0.3211 (0.3105)  loss_box_reg: 0.3884 (0.3818)  loss_mask: 0.3285 (0.3271)  loss_objectness: 0.1460 (0.1442)  loss_rpn_box_reg: 0.0900 (0.1134)  time: 1.3530  data: 0.0824  max mem: 7613
Epoch: [0]  [ 20/801]  eta: 0:10:55  lr: 0.000136  loss: 1.2640 (1.2616)  loss_classifier: 0.2988 (0.3026)  loss_box_reg: 0.3602 (0.3705)  loss_mask: 0.3337 (0.3309)  loss_objectness: 0.1346 (0.1389)  loss_rpn_box_reg: 0.1069 (0.1187)  time: 0.2751  data: 0.0169  max mem: 7613
Epoch: [0]  [ 30/801]  eta: 0:08:27  lr: 0.000199  loss: 1.2636 (1.2597)  loss_classifier: 0.2929 (0.2992)  loss_box_reg: 0.3735 (0.3715)  loss_mask: 0.3176 (0.3254)