In [1]:
NEW = '../models/some2.pth'
OLD = '../models/some1.pth'

In [2]:
# General Purpose
import os
import math
from tqdm import tqdm
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
matplotlib.rcParams['text.color'] = 'blue'

# Torch specifics
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.optimizer import Optimizer

# Task specific
import utils
import transforms as T
from engine import train_one_epoch, evaluate
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

In [None]:
## check if new model already exists and if old model exists or not!
if os.path.exists(OLD) is False:
    raise Exception('no such model ' + NEW.split('/')[-1] + ' exixts !')
if os.path.exists(NEW):
    raise Exception('model ' + NEW.split('/')[-1] + ' already exixts !')

In [3]:
## check device: cpu or gpu(cuda)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
## train examples
train_df = pd.read_csv("../csv/annotations_train.csv")
train_df.head()

Unnamed: 0,image,xmin,ymix,xmax,ymax,class
0,frontFar/BLR-2018-03-22_17-39-26_2_frontFar/00...,641,431,658,446,0
1,frontFar/BLR-2018-03-22_17-39-26_2_frontFar/00...,661,397,711,452,1
2,frontFar/BLR-2018-03-22_17-39-26_2_frontFar/00...,593,405,639,466,2
3,frontFar/BLR-2018-03-22_17-39-26_2_frontFar/00...,525,405,560,449,3
4,frontFar/BLR-2018-03-22_17-39-26_2_frontFar/00...,555,341,680,483,4


In [5]:
## validation examples
val_df = pd.read_csv("../csv/annotations_val.csv")
val_df.head()

Unnamed: 0,image,xmin,ymix,xmax,ymax,class
0,frontFar/BLR-2018-04-16_16-14-27_frontFar/0006...,401,392,619,488,4
1,frontFar/BLR-2018-04-16_16-14-27_frontFar/0006...,638,381,835,497,4
2,frontFar/BLR-2018-04-16_16-14-27_frontFar/0006...,733,379,928,493,4
3,frontFar/BLR-2018-04-16_16-14-27_frontFar/0006...,1093,486,1172,707,7
4,frontFar/BLR-2018-04-16_16-14-27_frontFar/0006...,1101,426,1144,463,5


In [6]:
## class labels
labels_idx = pd.read_csv('../csv/labels.csv')
labels_idx.head()

Unnamed: 0,class,idx
0,car,0
1,bus,1
2,autorickshaw,2
3,vehicle fallback,3
4,truck,4


In [7]:
## labels dict
labels_dict = {labels_idx.iloc[i,1] : labels_idx.iloc[i,0] for i in range(len(labels_idx))}
labels_dict

{0: 'car',
 1: 'bus',
 2: 'autorickshaw',
 3: 'vehicle fallback',
 4: 'truck',
 5: 'motorcycle',
 6: 'rider',
 7: 'person',
 8: 'bicycle',
 9: 'animal',
 10: 'traffic sign',
 11: 'train',
 12: 'trailer',
 13: 'traffic light',
 14: 'caravan'}

In [8]:
## transform
def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

In [None]:
## custom dataset
class Data(Dataset):
    """IDD dataset."""

    def __init__(self, csv_file, root_dir, transforms=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.frame = pd.read_csv(csv_file)
        self.group = self.frame.groupby('image')
        self.unique_frame = self.frame.image.unique()
        self.root_dir = root_dir
        self.transforms = transforms

    def __len__(self):
        return len(self.unique_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.unique_frame[idx])
        image = Image.open(img_name, mode='r').convert('RGB')
        boxes = torch.Tensor(self.group.get_group(self.unique_frame[idx]).iloc[:,1:-1].values)
        labels = self.group.get_group(self.unique_frame[idx]).iloc[:,-1].values
        target = {"boxes": boxes, "labels":labels}

        if self.transforms is not None:
            image, target = self.transforms(image, target)
        
        return image, target

In [None]:
## dataloader
train = Data("../csv/annotations_train.csv", 'data/JPEGImages', get_transform(train=True))
train_dl =  torch.utils.data.DataLoader(train, batch_size=2, shuffle=True, num_workers=4, collate_fn=utils.collate_fn)

In [None]:
## model build
def build(num_labels):
    model = fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_labels)
    return model.cuda()

In [None]:
## plot results
def plot(image, target, idx, save=False):
    plt.imshow(image.permute(1, 2, 0).cpu())
    for i in range(target[idx]['boxes'].shape[0]):
        x1, y1, x2, y2 = target[idx]['boxes'][i][0].cpu(), target[idx]['boxes'][i][1].cpu(), target[idx]['boxes'][i][2].cpu(), target[idx]['boxes'][i][3].cpu()
        plt.gca().add_patch(patches.Rectangle((x1, y1), x2-x1, y2-y1, label=labels_dict[target[idx]['labels'][i].cpu().item()], edgecolor='r', linewidth=1, facecolor='none'))
        if save is True:
            plt.savefig('../test/new.png', bbox_inches='tight')

In [None]:
## predict
def predict(model, img):
    img = img.cuda()
    model.eval()
    with torch.no_grad():
        preds = model([img])
    return preds

In [None]:
## start model training
model = build(len(labels_idx))
model.to(device)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [None]:
## load previous model
checkpoint = torch.load(OLD)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
## train example test
img, target = next(iter(train))
target

{'boxes': tensor([[622., 431., 639., 446.],
         [569., 397., 619., 452.],
         [641., 405., 687., 466.]]), 'labels': array([0, 1, 2])}

In [None]:
## training begins
num_epochs = 10
for epoch in tqdm(range(num_epochs)):
    train_one_epoch(model, optimizer, train_dl, device, epoch, print_freq=50)
    lr_scheduler.step()

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: [0]  [    0/15785]  eta: 4:02:47  lr: 0.000010  loss: 0.1640 (0.1640)  loss_classifier: 0.0795 (0.0795)  loss_box_reg: 0.0396 (0.0396)  loss_objectness: 0.0082 (0.0082)  loss_rpn_box_reg: 0.0367 (0.0367)  time: 0.9229  data: 0.3291  max mem: 4033
Epoch: [0]  [   50/15785]  eta: 1:36:22  lr: 0.000260  loss: 0.3915 (0.4990)  loss_classifier: 0.1957 (0.2075)  loss_box_reg: 0.1043 (0.1251)  loss_objectness: 0.0207 (0.0324)  loss_rpn_box_reg: 0.0958 (0.1340)  time: 0.3599  data: 0.0047  max mem: 4930
Epoch: [0]  [  100/15785]  eta: 1:35:19  lr: 0.000509  loss: 0.5827 (0.5316)  loss_classifier: 0.2093 (0.2173)  loss_box_reg: 0.1488 (0.1322)  loss_objectness: 0.0231 (0.0343)  loss_rpn_box_reg: 0.1445 (0.1478)  time: 0.3578  data: 0.0047  max mem: 5040
Epoch: [0]  [  150/15785]  eta: 1:33:30  lr: 0.000759  loss: 0.4875 (0.5257)  loss_classifier: 0.2180 (0.2144)  loss_box_reg: 0.1257 (0.1291)  loss_objectness: 0.0262 (0.0347)  loss_rpn_box_reg: 0.1067 (0.1475)  time: 0.3479  data: 0.0046

Epoch: [0]  [ 1650/15785]  eta: 1:22:26  lr: 0.005000  loss: 0.5643 (0.5414)  loss_classifier: 0.2038 (0.2121)  loss_box_reg: 0.1216 (0.1271)  loss_objectness: 0.0289 (0.0394)  loss_rpn_box_reg: 0.1337 (0.1628)  time: 0.3493  data: 0.0045  max mem: 5353
Epoch: [0]  [ 1700/15785]  eta: 1:22:08  lr: 0.005000  loss: 0.3973 (0.5426)  loss_classifier: 0.1510 (0.2125)  loss_box_reg: 0.0963 (0.1274)  loss_objectness: 0.0301 (0.0399)  loss_rpn_box_reg: 0.0593 (0.1629)  time: 0.3526  data: 0.0046  max mem: 5353
Epoch: [0]  [ 1750/15785]  eta: 1:21:51  lr: 0.005000  loss: 0.4874 (0.5434)  loss_classifier: 0.1955 (0.2126)  loss_box_reg: 0.1138 (0.1276)  loss_objectness: 0.0266 (0.0398)  loss_rpn_box_reg: 0.1453 (0.1633)  time: 0.3521  data: 0.0045  max mem: 5353
Epoch: [0]  [ 1800/15785]  eta: 1:21:37  lr: 0.005000  loss: 0.4661 (0.5419)  loss_classifier: 0.2027 (0.2122)  loss_box_reg: 0.1131 (0.1274)  loss_objectness: 0.0389 (0.0400)  loss_rpn_box_reg: 0.1131 (0.1623)  time: 0.3484  data: 0.0048

Epoch: [0]  [ 3300/15785]  eta: 1:13:44  lr: 0.005000  loss: 0.6206 (0.5535)  loss_classifier: 0.2538 (0.2158)  loss_box_reg: 0.1466 (0.1288)  loss_objectness: 0.0371 (0.0421)  loss_rpn_box_reg: 0.1857 (0.1668)  time: 0.3502  data: 0.0046  max mem: 5353
Epoch: [0]  [ 3350/15785]  eta: 1:13:26  lr: 0.005000  loss: 0.4270 (0.5532)  loss_classifier: 0.1482 (0.2157)  loss_box_reg: 0.1190 (0.1287)  loss_objectness: 0.0421 (0.0422)  loss_rpn_box_reg: 0.0785 (0.1666)  time: 0.3540  data: 0.0045  max mem: 5353
Epoch: [0]  [ 3400/15785]  eta: 1:13:08  lr: 0.005000  loss: 0.5413 (0.5541)  loss_classifier: 0.1955 (0.2160)  loss_box_reg: 0.1244 (0.1288)  loss_objectness: 0.0303 (0.0421)  loss_rpn_box_reg: 0.1601 (0.1672)  time: 0.3483  data: 0.0047  max mem: 5353
Epoch: [0]  [ 3450/15785]  eta: 1:12:50  lr: 0.005000  loss: 0.5979 (0.5543)  loss_classifier: 0.2320 (0.2161)  loss_box_reg: 0.1352 (0.1288)  loss_objectness: 0.0287 (0.0420)  loss_rpn_box_reg: 0.2166 (0.1673)  time: 0.3527  data: 0.0047

Epoch: [0]  [ 4950/15785]  eta: 1:04:32  lr: 0.005000  loss: 0.4847 (0.5546)  loss_classifier: 0.1727 (0.2163)  loss_box_reg: 0.1008 (0.1291)  loss_objectness: 0.0323 (0.0425)  loss_rpn_box_reg: 0.0647 (0.1667)  time: 0.3487  data: 0.0046  max mem: 5372
Epoch: [0]  [ 5000/15785]  eta: 1:04:13  lr: 0.005000  loss: 0.6031 (0.5549)  loss_classifier: 0.1807 (0.2163)  loss_box_reg: 0.1149 (0.1291)  loss_objectness: 0.0365 (0.0425)  loss_rpn_box_reg: 0.2400 (0.1670)  time: 0.3486  data: 0.0047  max mem: 5372
Epoch: [0]  [ 5050/15785]  eta: 1:03:55  lr: 0.005000  loss: 0.4565 (0.5549)  loss_classifier: 0.1951 (0.2163)  loss_box_reg: 0.1193 (0.1291)  loss_objectness: 0.0313 (0.0425)  loss_rpn_box_reg: 0.1198 (0.1670)  time: 0.3434  data: 0.0045  max mem: 5372
Epoch: [0]  [ 5100/15785]  eta: 1:03:35  lr: 0.005000  loss: 0.6184 (0.5553)  loss_classifier: 0.2213 (0.2165)  loss_box_reg: 0.1498 (0.1292)  loss_objectness: 0.0302 (0.0425)  loss_rpn_box_reg: 0.1535 (0.1671)  time: 0.3410  data: 0.0046

Epoch: [0]  [ 6600/15785]  eta: 0:54:19  lr: 0.005000  loss: 0.3894 (0.5551)  loss_classifier: 0.1845 (0.2166)  loss_box_reg: 0.1218 (0.1290)  loss_objectness: 0.0290 (0.0426)  loss_rpn_box_reg: 0.0988 (0.1668)  time: 0.3455  data: 0.0047  max mem: 5372
Epoch: [0]  [ 6650/15785]  eta: 0:54:01  lr: 0.005000  loss: 0.5505 (0.5550)  loss_classifier: 0.1802 (0.2166)  loss_box_reg: 0.1192 (0.1291)  loss_objectness: 0.0354 (0.0427)  loss_rpn_box_reg: 0.1762 (0.1667)  time: 0.3490  data: 0.0046  max mem: 5372
Epoch: [0]  [ 6700/15785]  eta: 0:53:44  lr: 0.005000  loss: 0.4792 (0.5551)  loss_classifier: 0.2126 (0.2166)  loss_box_reg: 0.1185 (0.1291)  loss_objectness: 0.0304 (0.0426)  loss_rpn_box_reg: 0.1110 (0.1667)  time: 0.3694  data: 0.0048  max mem: 5372
Epoch: [0]  [ 6750/15785]  eta: 0:53:26  lr: 0.005000  loss: 0.5768 (0.5553)  loss_classifier: 0.2240 (0.2167)  loss_box_reg: 0.1207 (0.1291)  loss_objectness: 0.0265 (0.0427)  loss_rpn_box_reg: 0.1253 (0.1669)  time: 0.3548  data: 0.0046

Epoch: [0]  [ 8250/15785]  eta: 0:44:34  lr: 0.005000  loss: 0.5175 (0.5546)  loss_classifier: 0.1712 (0.2164)  loss_box_reg: 0.1130 (0.1289)  loss_objectness: 0.0366 (0.0432)  loss_rpn_box_reg: 0.1219 (0.1660)  time: 0.3539  data: 0.0047  max mem: 5372
Epoch: [0]  [ 8300/15785]  eta: 0:44:17  lr: 0.005000  loss: 0.4285 (0.5548)  loss_classifier: 0.2025 (0.2165)  loss_box_reg: 0.1002 (0.1290)  loss_objectness: 0.0282 (0.0433)  loss_rpn_box_reg: 0.0808 (0.1660)  time: 0.3563  data: 0.0045  max mem: 5372
Epoch: [0]  [ 8350/15785]  eta: 0:43:59  lr: 0.005000  loss: 0.5785 (0.5551)  loss_classifier: 0.2394 (0.2167)  loss_box_reg: 0.1348 (0.1291)  loss_objectness: 0.0337 (0.0433)  loss_rpn_box_reg: 0.1219 (0.1660)  time: 0.3526  data: 0.0042  max mem: 5372
Epoch: [0]  [ 8400/15785]  eta: 0:43:41  lr: 0.005000  loss: 0.5175 (0.5551)  loss_classifier: 0.1821 (0.2168)  loss_box_reg: 0.1269 (0.1291)  loss_objectness: 0.0330 (0.0432)  loss_rpn_box_reg: 0.1088 (0.1660)  time: 0.3507  data: 0.0045

In [None]:
## torch model save
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, NEW)

In [None]:
## torch load new model
checkpoint = torch.load(NEW)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

In [None]:
## validation example test
val = Data("../csv/annotations_val.csv", 'data/JPEGImages', get_transform(train=False))
val_dl =  torch.utils.data.DataLoader(train, batch_size=2, shuffle=True, num_workers=4, collate_fn=utils.collate_fn)
img, target = next(iter(train))
print(target)

In [None]:
## predictions and plot
preds = predict(model, img)
plot(img, preds, 0, True)