In [1]:
import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
try:
    import xml.etree.cElementTree as ET
except ImportError:
    import xml.etree.ElementTree as ET

%matplotlib inline


In [2]:
# --- Notebook configuration -------------------------------------------------
# Where the dataset lives, relative to this notebook.
root = "../datasets/pp_smoke"

# Foreground labels; the detector gets one extra slot for the background class.
classes = ['smoke']
num_classes = 1 + len(classes)

# Training hyper-parameters.
batch_size, num_epochs = 3, 5

In [3]:
import Dataset
import utils

# Two views of the same dataset root: the training split is built with
# horizontal_flip enabled, the test split with it disabled.
dataset_train = Dataset.DatasetGen(
    root,
    Dataset.get_transform(horizontal_flip=True),
    train=True,
)
dataset_test = Dataset.DatasetGen(
    root,
    Dataset.get_transform(horizontal_flip=False),
    train=False,
)

# Training loader: shuffled, and the final incomplete batch is dropped.
data_loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    collate_fn=utils.collate_fn,
)

# Test loader: fixed order, keeps every sample.
data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    collate_fn=utils.collate_fn,
)

In [4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Faster R-CNN (ResNet-50 + FPN) created with pretrained=True.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Display the selected device as the cell output.
device

device(type='cuda')

In [5]:
# Anchor configuration for the RPN. `sizes` and `aspect_ratios` are parallel
# sequences with five entries each: entry i of one pairs with entry i of the
# other (presumably one entry per feature map of the FPN backbone - TODO confirm).
anchor_sizes = ((32,), (24,), (24,), (16,), (8,))
anchor_aspect_ratios = (
    (1.0, 1.0, 1.0, 1.0),
    (0.8, 1.0, 1.0, 1.0),
    (1.0, 0.8, 1.0, 1.0),
    (1.0, 1.0, 1.0, 1.0),
    (1.0, 1.0, 1.0, 1.0),
)

# NOTE(review): most ratio values are repeated within an entry, which looks
# like it yields identically shaped anchors at each location. Left as-is
# because the per-location anchor count (4) sizes the RPN head - confirm
# whether the repetition is intentional.
anchor_generator = torchvision.models.detection.rpn.AnchorGenerator(
    sizes=anchor_sizes,
    aspect_ratios=anchor_aspect_ratios,
)

In [6]:
# Plug the custom anchor generator into the RPN and rebuild the RPN head so
# that its anchor count matches the new generator.
model.rpn.anchor_generator = anchor_generator
model.rpn.head = torchvision.models.detection.faster_rcnn.RPNHead(
    # Read the channel width from the backbone instead of hard-coding 256, so
    # the head stays correct if the backbone is ever swapped.
    model.backbone.out_channels,
    # RPNHead takes a single anchor count; this uses the first entry and
    # therefore assumes every entry of the generator has the same count.
    anchor_generator.num_anchors_per_location()[0],
)

# Get the number of input features of the existing box classifier ...
in_features = model.roi_heads.box_predictor.cls_score.in_features
# ... and replace the pre-trained head with a new one sized for num_classes.
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
    in_features,
    num_classes,
)


In [7]:
# Move the model onto the selected device before building the optimizer.
model.to(device)

# Hand only the parameters that still require gradients to the optimizer.
params = list(filter(lambda p: p.requires_grad, model.parameters()))

# SGD with momentum and weight decay.
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# Step schedule: multiply the learning rate by 0.1 every 3 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

In [10]:
from engine import train_one_epoch, evaluate

# Train for num_epochs, evaluating on the test loader after every epoch, then
# save the model to "model.pth" in the current working directory.
if True:  # set to False to skip training and saving when re-running the notebook
    for epoch in range(num_epochs):
        train_one_epoch(model, optimizer, data_loader_train, device, epoch, print_freq=1)
        # Advance the learning-rate schedule once per epoch.
        lr_scheduler.step()
        evaluate(model, data_loader_test, device=device)

    print("Training complete!")

    # Wait for outstanding GPU work before saving. Only do this when actually
    # running on CUDA: the device selection falls back to CPU, and
    # torch.cuda.synchronize() raises on machines without CUDA.
    if device.type == 'cuda':
        torch.cuda.synchronize()

    # NOTE: this pickles the whole model object (not just its state_dict), so
    # loading it later needs the same model classes to be importable.
    print("Saving model...")
    torch.save(model, "model.pth")
    print("Model saving complete!")


Epoch: [0]  [  0/258]  eta: 0:01:48  lr: 0.000024  loss: 0.3630 (0.3630)  loss_classifier: 0.0120 (0.0120)  loss_box_reg: 0.0304 (0.0304)  loss_objectness: 0.0601 (0.0601)  loss_rpn_box_reg: 0.2605 (0.2605)  time: 0.4189  data: 0.0186  max mem: 6359
Epoch: [0]  [  1/258]  eta: 0:01:39  lr: 0.000044  loss: 0.3630 (1.2315)  loss_classifier: 0.0022 (0.0071)  loss_box_reg: 0.0090 (0.0197)  loss_objectness: 0.0601 (0.1301)  loss_rpn_box_reg: 0.2605 (1.0747)  time: 0.3870  data: 0.0119  max mem: 6359
Epoch: [0]  [  2/258]  eta: 0:01:45  lr: 0.000063  loss: 0.3906 (0.9512)  loss_classifier: 0.0086 (0.0076)  loss_box_reg: 0.0162 (0.0185)  loss_objectness: 0.0627 (0.1076)  loss_rpn_box_reg: 0.3030 (0.8174)  time: 0.4110  data: 0.0126  max mem: 6359
Epoch: [0]  [  3/258]  eta: 0:01:39  lr: 0.000083  loss: 0.3906 (1.0173)  loss_classifier: 0.0086 (0.0082)  loss_box_reg: 0.0162 (0.0207)  loss_objectness: 0.0627 (0.1030)  loss_rpn_box_reg: 0.3030 (0.8854)  time: 0.3899  data: 0.0108  max mem: 6359
