In [None]:
'''
Author       : Aditya Jain
Date Started : 10th August, 2021
About        : This is the training file for the DL-based localization module
'''

#### Loading Experiment Manager

In [None]:
import os

from comet_ml import Experiment

# The Comet API key is a credential and must not be stored in the notebook.
# Export COMET_API_KEY in the environment before starting the kernel. When the
# variable is unset, None is passed and comet_ml falls back to its own
# configuration lookup (environment / .comet.config).
# NOTE(review): the key that used to be hardcoded here has been exposed and
# should be revoked and regenerated in the Comet account settings.
experiment = Experiment(
    api_key=os.environ.get('COMET_API_KEY'),
    project_name='mothai',
    workspace='adityajain07'
)

# Tag the run so it can be filtered in the Comet project view.
experiment.add_tag('DL_Localiz_A1')

#### Loading Library

In [None]:
import torchvision.models as torchmodels
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torch
import utils
from torch.utils.data import random_split
from torch import nn
from torchsummary import summary
import json
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn.functional as F
import torch.optim as optim
import datetime
import time

from localizdataset import LocalizDataset

#### Loading Model

In [None]:
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Two outputs for the box classifier: one foreground object class + background.
# NOTE(review): the earlier comment called the foreground class "person";
# presumably it is the localized insect for this project - confirm.
num_classes = 2

# Faster R-CNN (ResNet-50 + FPN backbone) initialised with COCO-pretrained weights.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Swap the pretrained box-predictor head for a fresh one with `num_classes`
# outputs; the new head takes the same input width as the head it replaces.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# print(model)
# print(summary(model, (3,224,224)))  # keras-type model summary

#### Loading Data

In [None]:
def collate_fn(batch):
    """Regroup a list of samples into one tuple per sample field.

    Args:
        batch: iterable of samples, each sample itself an iterable of fields
            (the DataLoaders in this notebook pass what the dataset returns).

    Returns:
        A tuple holding one tuple per field position, e.g.
        [(a1, b1), (a2, b2)] -> ((a1, a2), (b1, b2)). Nothing is stacked or
        converted; an empty batch gives an empty tuple.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:
# ---- Run configuration ------------------------------------------------------
root_dir   = '/scratch/Localization/'   # dataset root handed to LocalizDataset
BATCH_SIZE = 32                         # samples per DataLoader batch
TRAIN_PER  = 0.85                       # fraction of the data used for training
NUM_EPOCHS = 1                          # intended number of training epochs
EARLY_STOP = 4                          # epochs without a new best validation loss before stopping

# Minute-resolution timestamp of this run, embedded in the checkpoint file name.
DTSTR     = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
SAVE_PATH = ''.join(['/home/mila/a/aditya.jain/logs/v1_localizmodel_', DTSTR, '.pt'])

In [None]:
# The only image transform is conversion to a tensor.
transformer        = transforms.Compose([
                        transforms.ToTensor()])
data               = LocalizDataset(root_dir, transformer)

# TRAIN_PER of the samples go to training; the remainder is the validation set.
train_size         = int(TRAIN_PER*len(data))
val_size           = len(data)-train_size

# No generator/seed is passed, so split membership differs from run to run.
train_set, val_set = random_split(data, [train_size, val_size])

# Only the training loader shuffles. The validation loader keeps a fixed order
# so every epoch evaluates the same batches, which removes batch composition as
# a source of variation when validation losses are compared across epochs.
train_dataloader   = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_dataloader     = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

#### Loading Loss function and Optimizer

In [None]:
# SGD with momentum over every parameter of the detection model.
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.001,
    momentum=0.9,
)

#### Model Training

In [None]:
def _batch_loss(model, image_batch, label_batch, device):
    """Return the summed loss of `model` for one collated batch.

    Args:
        model: detection model that returns a dict of losses when called with
            (images, targets) in training mode.
        image_batch: per-sample images as produced by `collate_fn`.
            Assumes each item is a tensor - TODO confirm against LocalizDataset.
        label_batch: per-sample targets as produced by `collate_fn`.
            Assumes each item is a dict (torchvision detection models expect
            dicts of tensors) - TODO confirm against LocalizDataset.
        device: device the batch is moved to before the forward pass.

    Returns:
        The sum of all values in the loss dict returned by the model.
    """
    images  = [image.to(device) for image in image_batch]
    # Non-tensor entries are passed through untouched.
    targets = [{key: value.to(device) if torch.is_tensor(value) else value
                for key, value in target.items()}
               for target in label_batch]
    output  = model(images, targets)
    return sum(loss for loss in output.values())


# Use the device selected earlier; previously `device` was computed but neither
# the model nor the batches were moved, so training always ran on the CPU.
# Module.to() converts the parameters in place, so the optimizer created before
# this cell keeps pointing at the same parameter objects.
model.to(device)
model.train()
lowest_val_loss = float('inf')
early_stp_count = 0

# NUM_EPOCHS is the constant defined in the configuration cell (the loop used
# the undefined name `num_epochs`, which raised NameError).
for epoch in range(NUM_EPOCHS):
    train_loss = 0
    val_loss   = 0

    for image_batch, label_batch in train_dataloader:
        total_loss  = _batch_loss(model, image_batch, label_batch, device)
        train_loss += total_loss.item()

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    # The model is deliberately left in train mode: torchvision detection
    # models return the loss dict only in training mode (eval mode returns
    # detections). no_grad() keeps the validation pass from building autograd
    # graphs, since no backward pass follows.
    with torch.no_grad():
        for image_batch, label_batch in val_dataloader:
            total_loss = _batch_loss(model, image_batch, label_batch, device)
            val_loss  += total_loss.item()

    # Checkpoint whenever the summed validation loss reaches a new minimum.
    if val_loss < lowest_val_loss:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss},
            SAVE_PATH)
        lowest_val_loss = val_loss
        early_stp_count = 0
    else:
        early_stp_count += 1

    # Both metrics are sums over batches, not per-batch means.
    experiment.log_metric("loss_training", train_loss, epoch=epoch)
    experiment.log_metric("loss_validation", val_loss, epoch=epoch)

    # Stop once EARLY_STOP consecutive epochs brought no new best validation loss.
    if early_stp_count >= EARLY_STOP:
        break


In [None]:
# Close the Comet experiment that was opened at the top of the notebook.
experiment.end()