In this tutorial we will train a simple object detector for one class. We want it to run in real time on low-end devices. The provided code is simplified to show the key steps of model creation, training and export to Lens Studio. If you'd like to obtain better precision, the first things to try are training for longer, using a higher input resolution, and increasing the 'width_mult' parameter in the config (MOBILENET_WIDTH_MULTIPLIER). A larger 'width_mult' makes the model slower and bigger, so be careful with it.

With the current model configuration, the lens runs in about 27 ms on an iPhone 6, and model inference takes about 14-17 ms of that. Doubling the input resolution along one side of the image roughly doubles the model inference time.

If you'd like to use your own training pipeline and a custom architecture, make sure that your model's outputs have the same format as those of the provided model. An example of conversion to ONNX is given in the last section.
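One way to check the output format of a custom model is to compare it with what this notebook's `Detector` returns in conversion mode (`set_conversion_mode(True)`, section 6): a `loc` tensor of shape (1, 4, H/8, W/8) and a `cls` heatmap of shape (1, 1, H/8, W/8) holding sigmoid scores, where H x W is the network input and 8 is FEATURE_MAP_SIZE_RATIO. The helper below is a hypothetical sketch (`check_detector_outputs` is not part of the notebook); it works on NumPy arrays, for example your model's outputs converted with `.numpy()`.

```python
import numpy as np


def check_detector_outputs(loc, cls, input_hw=(256, 128), ratio=8):
    """Hypothetical sanity check for a custom model's outputs.

    Expects the layout of the notebook's Detector in conversion mode:
      loc: (1, 4, H // ratio, W // ratio) box offsets and sizes
      cls: (1, 1, H // ratio, W // ratio) sigmoid scores in [0, 1]
    Returns a list of problems (empty if the format matches).
    """
    fh, fw = input_hw[0] // ratio, input_hw[1] // ratio
    problems = []
    if loc.shape != (1, 4, fh, fw):
        problems.append(f"loc has shape {loc.shape}, expected {(1, 4, fh, fw)}")
    if cls.shape != (1, 1, fh, fw):
        problems.append(f"cls has shape {cls.shape}, expected {(1, 1, fh, fw)}")
    if cls.size and (cls.min() < 0 or cls.max() > 1):
        problems.append("cls scores should be sigmoid outputs in [0, 1]")
    return problems


# Outputs with the expected layout for a 256x128 (H x W) input
loc = np.zeros((1, 4, 32, 16), dtype=np.float32)
cls = np.full((1, 1, 32, 16), 0.5, dtype=np.float32)
print(check_detector_outputs(loc, cls))  # []
```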



---



# Install libraries

First, we need to prepare our work environment and install the necessary Python packages. If you're using Google Colab and get the error message "ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.", you can safely ignore it.

The command below does not pin package versions, so it installs the latest releases and replaces any versions that are already installed. Some APIs used in this notebook may have changed in recent releases; if you run into import or signature errors, pin the packages to versions that work for you for better reproducibility.

### *You need to restart the kernel to use the installed packages.* (in Colab: Runtime->Restart runtime...)

In [None]:
%pip install -q numpy opencv-python-headless \
    torch torchvision pycocotools \
    albumentations tqdm matplotlib

# Imports

Google Colab already has these packages installed, but you might need to install some of them when running the notebook locally.

In [None]:
import cv2
import json
import itertools
import numpy as np
import os
import random
import shutil
import urllib.request

from collections import defaultdict
from pathlib import Path
from tqdm.notebook import tqdm
from zipfile import ZipFile

from matplotlib import pyplot as plt
from PIL import Image

import albumentations as albu

import torch
import torch.onnx as onnx
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision.models.mobilenet import ConvBNReLU
from torchvision.ops import box_iou, nms

# Set random seeds for the libraries to obtain reproducible results
RANDOM_SEED = 1337
random.seed(RANDOM_SEED)  # some albumentations versions draw from Python's random module
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Flip these values for slower training but more deterministic results.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

We recommend training on a GPU, but you can also run the model on a CPU.

In [None]:
DEVICE = torch.device('cpu')
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
    torch.cuda.manual_seed(RANDOM_SEED)
DEVICE


# Global variables for training
This training notebook uses the COCO dataset: http://cocodataset.org/ The annotations in this dataset belong to the COCO Consortium and are licensed under a Creative Commons Attribution 4.0 License: http://cocodataset.org/#termsofuse The images are part of Flickr and have corresponding licenses. To check the license of each image, please refer to the contents of http://images.cocodataset.org/annotations/image_info_unlabeled2017.zip

# 1. Available COCO classes
You can explore the available data here: http://cocodataset.org/#explore

COCO has 80 categories:

person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic light, fire hydrant, stop sign, parking meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports ball, kite, baseball bat, baseball glove, skateboard, surfboard, tennis racket, bottle, wine glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot dog, pizza, donut, cake, chair, couch, potted plant, bed, dining table, toilet, tv, laptop, mouse, remote, keyboard, cell phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy bear, hair drier, toothbrush

In [None]:
# Following classes will be united into single category
OBJECT_LABELS_UNION = ['car', 'truck', 'bus']
# You can create your own categories like these:
# OBJECT_LABELS_UNION = ['cat', 'dog']
# OBJECT_LABELS_UNION = ['bird']
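A misspelled name in OBJECT_LABELS_UNION does not raise an error: `getCatIds(catNms=...)` simply returns no id for it, so that class is silently left out. A quick check against the "categories" list of a COCO annotation file can catch this before the slow dataset construction. This is a minimal sketch; `validate_labels` is a hypothetical helper, and the three-entry category list is only an excerpt of COCO's 80 categories.

```python
def validate_labels(labels, categories):
    """Return the entries of `labels` that are not COCO category names.

    `categories` is the "categories" list of a COCO annotation file,
    i.e. dicts with at least "id" and "name" keys.
    """
    known = {category["name"] for category in categories}
    return [label for label in labels if label not in known]


# With the real file (after section 2):
# import json
# with open("annotations/instances_val2017.json") as f:
#     categories = json.load(f)["categories"]
categories = [{"id": 3, "name": "car"}, {"id": 6, "name": "bus"}, {"id": 8, "name": "truck"}]
print(validate_labels(['car', 'truck', 'bus'], categories))  # []
print(validate_labels(['car', 'motorbike'], categories))  # ['motorbike'] (COCO uses 'motorcycle')
```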

In [None]:
DATASET_PATH = Path('.')  # Path to the dataset
DIR_TO_SAVE_RESULTS = Path('centernet_model')  # Model snapshots will be saved here
os.makedirs(DIR_TO_SAVE_RESULTS, exist_ok=True)

In [None]:
# The number of epochs can be changed.
# More epochs mean longer training, but usually better model quality
# NUM_EPOCHS = 40 # => faster training
# NUM_EPOCHS = 70

NUM_EPOCHS = 100 # => better quality

Advanced constants

In [None]:
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
NUM_WORKERS = 4  # Number of workers used in PyTorch dataloader

INPUT_SIZE = (128, 256) # width & height
FEATURE_MAP_SIZE_RATIO = 8 # Ratio between input and output size (e.g. 128 / 16 = 8)

# Larger model might yield better results but will be slower.
# width_mult is a coefficient which defines how many filters
# to use from original mobilenet:
MOBILENET_WIDTH_MULTIPLIER = 0.3
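With INPUT_SIZE = (128, 256) and FEATURE_MAP_SIZE_RATIO = 8, the network predicts on a 32 x 16 heatmap (height x width). Each cell covers an 8 x 8 pixel patch of the input, and since `BoxCoder.encode` writes a single target per cell, each cell can hold at most one object center. The hypothetical helper below (not used elsewhere in the notebook) computes the grid for other settings; doubling one side of the input doubles the number of cells, which fits the roughly 2x inference time mentioned at the top.

```python
def output_grid(input_size, ratio):
    """Heatmap (height, width) for an input given as (width, height),
    using the same integer division as BoxCoder."""
    width, height = input_size
    return height // ratio, width // ratio


print(output_grid((128, 256), 8))  # (32, 16): 512 cells
print(output_grid((256, 256), 8))  # (32, 32): twice as many cells
```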

# 2. Download dataset (it might take some time)

http://cocodataset.org/#home

If you see the message "Disk is almost full" on Google Colab, press the "Ignore" button. There is enough space to extract the archive, and it will be deleted afterwards.

In [None]:
def download_and_unpack_file(link, filename, unpack=True):
    """ Download a dataset archive and optionally unpack it into DATASET_PATH """
    if (DATASET_PATH / filename).exists():
        print("{} already exists".format(filename))
        return
    archname = link.split('/')[-1]
    progress_bar = tqdm(desc=filename,
                        dynamic_ncols=True, leave=False,
                        mininterval=5, maxinterval=30,
                        unit='B', unit_scale=True,  # progress is counted in bytes
                        unit_divisor=1024)
    def update_progress(count, block_size, total_size):
        if progress_bar.total is None:
            progress_bar.reset(total_size)
        progress_bar.update(count * block_size - progress_bar.n)
    urllib.request.urlretrieve(link, archname, reporthook=update_progress)
    urllib.request.urlcleanup()
    progress_bar.close()
    if unpack:
        print("Unpacking the archive...")
        shutil.unpack_archive(archname, DATASET_PATH)
        os.remove(archname)
        print("Successfully downloaded and extracted archive")

In [None]:
os.listdir("./")

In [None]:
# The train set is too big for Colab, so we extract only the files used in training. You can unpack the whole dataset locally if you have enough free space
download_and_unpack_file('http://images.cocodataset.org/zips/train2017.zip', 'train2017', unpack=False)

In [None]:
download_and_unpack_file('http://images.cocodataset.org/annotations/annotations_trainval2017.zip', 'annotations')

In [None]:
download_and_unpack_file('http://images.cocodataset.org/zips/val2017.zip', 'val2017')

# 3. Data class

This class provides the interface for loading images and annotations and is compatible with PyTorch's DataLoader class. The annotations are loaded into memory from JSON files, while the images are loaded from disk each time __getitem__() is called.

In [None]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, phase, box_coder=None, transform=None):
        super().__init__()
        self.phase = phase
        self.class_names = OBJECT_LABELS_UNION

        self.dataset = torchvision.datasets.CocoDetection(
            DATASET_PATH / (phase + '2017'),
            annFile=DATASET_PATH / ('annotations/instances_'+phase+'2017.json'))
        self.transform = transform
        self.box_coder = box_coder

        categories = self.dataset.coco.getCatIds()
        categories = self.dataset.coco.loadCats(categories)
        self.all_possible_classes = [category['name'] for category in categories]
        self.negative_classes = set(self.all_possible_classes) - set(self.class_names)

        self.filter_dataset()

        if self.phase == "train":
            self.extract_images()

    def filter_dataset(self):
        """ Leave only classes specified in the OBJECT_LABELS_UNION """
        filter_categories = self.dataset.coco.getCatIds(catNms=self.class_names)
        
        min_area = 500.
        ann_ids = self.dataset.coco.getAnnIds(
            catIds=filter_categories, areaRng=[min_area, float('inf')],
            iscrowd=False)
        im_ids = {self.dataset.coco.anns[ann_idx]['image_id'] for ann_idx in ann_ids}

        if self.phase == "train":  # Use some part of remaining data to reduce the number of false positives
            num_false_positives = len(im_ids) // 20  # 5 percent
            # Look up each class id by name: getCatIds() returns ids in dataset order,
            # so zipping it with the (unordered) set of names would mismatch ids and names
            for neg_class in sorted(self.negative_classes):
                cls_id = self.dataset.coco.getCatIds(catNms=[neg_class])[0]
                neg_subset = set(self.dataset.coco.getAnnIds(catIds=[cls_id], areaRng=[min_area, float('inf')], iscrowd=False))
                neg_im_ids = {self.dataset.coco.anns[ann_idx]['image_id'] for ann_idx in neg_subset}
                neg_im_count = num_false_positives // len(self.negative_classes)
                if neg_class == 'person':
                    neg_im_count *= 10
                neg_im_count = min(neg_im_count, len(neg_im_ids - im_ids))
                neg_subset = sorted(neg_im_ids - im_ids)[:neg_im_count]
                im_ids.update(neg_subset)

        im_ids = list(im_ids)
        cat_ids = self.dataset.coco.getCatIds(catIds=filter_categories)

        self.dataset.ids = sorted(im_ids)
        self.dataset.coco.anns = {i: self.dataset.coco.anns[i] for i in ann_ids}
        self.dataset.coco.imgs = {i: self.dataset.coco.imgs[i] for i in im_ids}
        self.dataset.coco.cats = {i: self.dataset.coco.cats[i] for i in cat_ids}
        imgToAnns, catToImgs = defaultdict(list), defaultdict(list)

        for k, ann in self.dataset.coco.anns.items():
            imgToAnns[ann['image_id']].append(ann)
            catToImgs[ann['category_id']].append(ann['image_id'])
        self.dataset.coco.imgToAnns = imgToAnns
        self.dataset.coco.catToImgs = catToImgs

    def extract_images(self):
        """ Extract images which will be used """
        im_paths = []
        for im_id in self.dataset.ids:
            im_paths.append(self.dataset.coco.loadImgs(im_id)[0]['file_name'])

        if not os.path.isdir('./train2017'):
            os.mkdir('./train2017')

        with ZipFile('./train2017.zip', 'r') as archive:
            for image in tqdm(im_paths, dynamic_ncols=True, leave=False):
                archive.extract('train2017/' + image, './')
        return

    def parse_annotation(self, annotation):
        bboxes = []
        for anno in annotation:
            # Filter boxes with 0 area
            if (anno['bbox'][2] < 1) or (anno['bbox'][3] < 1) or (anno['iscrowd'] and self.phase != 'val'):
                continue
            # Boxes in form x_left, y_top, w, h
            bboxes.append(anno['bbox'])
        
        return {'bboxes': bboxes, 'labels': [1] * len(bboxes)}

    def __getitem__(self, index):
        image, annotation = self.dataset[index]
    
        annotations = self.parse_annotation(annotation)
        annotations['image'] = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        if self.transform:
            annotations = self.transform(annotations)
       
        if self.box_coder is None:
            return annotations

        annotations['bboxes'] = torch.tensor(annotations['bboxes'])
        encoded_bboxes, encoded_labels = self.box_coder.encode(annotations['bboxes'])

        annotations['bboxes'] = encoded_bboxes
        annotations['labels'] = encoded_labels
        return annotations

    def __len__(self):
        return len(self.dataset)
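The negative sampling in `filter_dataset` can be summarized as: take 5% of the number of positive images, split it evenly across all other classes, and give 'person' ten times the share. The sketch below reproduces only that arithmetic (`negative_image_budget` is a hypothetical helper, and the image count is made up); the real code additionally caps each share by the number of available images that contain no target class.

```python
def negative_image_budget(num_positive_images, negative_classes, boost=None):
    """Per-class number of negative images, mirroring Dataset.filter_dataset:
    5% of the positive image count, split evenly across negative classes,
    times an optional per-class multiplier."""
    boost = boost or {}
    total = num_positive_images // 20                 # 5 percent
    per_class = total // len(negative_classes)
    return {name: per_class * boost.get(name, 1) for name in negative_classes}


# Hypothetical: 20000 positive images, 77 negative classes (80 minus car, truck, bus)
negative_classes = ['person'] + ['class_{}'.format(i) for i in range(76)]
budget = negative_image_budget(20000, negative_classes, boost={'person': 10})
print(budget['person'], budget['class_0'], sum(budget.values()))  # 120 12 1032
```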

# 4. Define data preprocessing and augmentations

For training we apply several random, noise-like transformations as augmentations. Bounding boxes are adjusted accordingly.

Lens Studio feeds the camera input to the network as images with values in the range [0, 255], so we train the network on this input range without additional rescaling or normalization. You can add your own input normalization, but make sure that you rescale the weights of the network's first layer before exporting to ONNX.
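A minimal sketch of that rescaling, assuming a per-channel normalization (x - mean) / std applied to raw [0, 255] inputs. `fold_input_normalization` is a hypothetical helper, and the ImageNet statistics below are only an example. Dividing the first convolution's weights by std and moving the mean shift into a bias lets the exported network take raw camera values. The result is exact except at the image borders when the convolution uses zero padding, because padded zeros no longer correspond to a normalized value of zero.

```python
import torch
from torch import nn


def fold_input_normalization(conv, mean, std):
    """Return a copy of `conv` that takes raw inputs x and computes
    conv((x - mean) / std). Exact away from zero-padded borders."""
    if conv.groups != 1:
        raise ValueError("only ordinary (groups=1) convolutions are supported")
    dtype = conv.weight.dtype
    mean = torch.as_tensor(mean, dtype=dtype).view(1, -1, 1, 1)
    std = torch.as_tensor(std, dtype=dtype).view(1, -1, 1, 1)
    folded = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                       stride=conv.stride, padding=conv.padding,
                       dilation=conv.dilation, bias=True)
    with torch.no_grad():
        # Scale the weights of each input channel by 1 / std
        folded.weight.copy_(conv.weight / std)
        # The mean shift becomes a constant per output channel: -sum(W * mean / std)
        bias = conv.bias.clone() if conv.bias is not None else torch.zeros(conv.out_channels, dtype=dtype)
        folded.bias.copy_(bias - (conv.weight * mean / std).sum(dim=(1, 2, 3)))
    return folded


# Check on a convolution without padding, using ImageNet statistics scaled to [0, 255]
torch.manual_seed(0)
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
conv = nn.Conv2d(3, 8, kernel_size=3, padding=0, bias=False)
x = torch.rand(2, 3, 16, 16) * 255
expected = conv((x - torch.tensor(mean).view(1, 3, 1, 1)) / torch.tensor(std).view(1, 3, 1, 1))
folded = fold_input_normalization(conv, mean, std)
print(torch.allclose(folded(x), expected, atol=1e-3))  # True
```

In this notebook the first convolution should be `model.backbone[0][0]` (the stem of torchvision's MobileNetV2). It has no bias and is followed by BatchNorm, so a folded layer with its new bias could replace it before export.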

If you are working with your own dataset, it might be useful to tune the min_area parameter in albu.BboxParams().


In [None]:
def train_transform(annotations):
    image = annotations['image']
    
    size = (INPUT_SIZE[1], INPUT_SIZE[0]) # height, width
    scale = min(size[0] / image.shape[0], size[1] / image.shape[1])
    intermediate_size = int(image.shape[0] * scale), int(image.shape[1] * scale)
    augmentation = albu.Compose(
        [
            albu.RandomSizedBBoxSafeCrop(*intermediate_size),
            albu.HorizontalFlip(p=0.5),
            albu.HueSaturationValue(p=0.5),
            albu.RGBShift(p=0.5),
            albu.RandomBrightnessContrast(p=0.5),
            albu.MotionBlur(p=0.5),
            albu.PadIfNeeded(*size)
        ],
        albu.BboxParams(format='coco', min_area=500.,
                        min_visibility=0.3, label_fields=['labels'])
    )

    augmented = augmentation(**annotations)
    augmented['image'] = augmented['image'].astype(
        np.float32).transpose(2, 0, 1)
    return augmented

def validation_transform(annotations, with_bboxes=True):
    bbox_params = None
    if with_bboxes:
        bbox_params = albu.BboxParams(format='coco', min_area=500.,
                        min_visibility=0.3, label_fields=['labels'])
    
    image = annotations['image']
    size = (INPUT_SIZE[1], INPUT_SIZE[0])
    scale = min(size[0] / image.shape[0], size[1] / image.shape[1])
    intermediate_size = [int(dim * scale) for dim in image.shape[:2]]
    
    augmentation = albu.Compose(
        [
            albu.Resize(*intermediate_size),
            albu.PadIfNeeded(*size)
        ],
        bbox_params
    )

    augmented = augmentation(**annotations)
    augmented['image'] = augmented['image'].astype(
        np.float32).transpose(2, 0, 1)
    augmented['scale'] = scale

    augmented['in_size'] = image.shape[:2]
    augmented['out_size'] = size
    augmented['intermediate_size'] = intermediate_size
    return augmented
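`validation_transform` letterboxes the image: it resizes the image with a single `scale` that preserves the aspect ratio, then pads it (centered) to the network input size. `Detector.rescale_boxes` later undoes this by removing the padding offset and dividing by `scale`. The sketch below reproduces that mapping for one box in plain Python (all function names are hypothetical), without the clamping that `rescale_boxes` also performs.

```python
def letterbox_params(image_hw, input_hw):
    """Scale and centered padding for resizing an image (h, w) into the
    network input (h, w) while keeping the aspect ratio."""
    scale = min(input_hw[0] / image_hw[0], input_hw[1] / image_hw[1])
    resized_h, resized_w = int(image_hw[0] * scale), int(image_hw[1] * scale)
    pad_top = (input_hw[0] - resized_h) / 2
    pad_left = (input_hw[1] - resized_w) / 2
    return scale, pad_top, pad_left


def to_network(box, scale, pad_top, pad_left):
    """Map a COCO box (x_left, y_top, w, h) from image to network-input pixels."""
    x, y, w, h = box
    return (x * scale + pad_left, y * scale + pad_top, w * scale, h * scale)


def to_image(box, scale, pad_top, pad_left):
    """Inverse of to_network: remove the padding shift, then undo the scaling."""
    x, y, w, h = box
    return ((x - pad_left) / scale, (y - pad_top) / scale, w / scale, h / scale)


# A 640x480 (w x h) frame fed to the 128x256 (w x h) network input
scale, pad_top, pad_left = letterbox_params((480, 640), (256, 128))
print(scale, pad_top, pad_left)  # 0.2 80.0 0.0
net_box = to_network((100, 50, 200, 100), scale, pad_top, pad_left)
print([round(v, 3) for v in net_box])  # [20.0, 90.0, 40.0, 20.0]
print([round(v, 3) for v in to_image(net_box, scale, pad_top, pad_left)])  # [100.0, 50.0, 200.0, 100.0]
```

Most of a landscape frame ends up as padding in this tall, narrow input, which is worth keeping in mind when choosing INPUT_SIZE.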

# 5. Define box coder

The box coder transforms the annotations into a format suitable for training and decodes the outputs of the trained model. The ground-truth bounding boxes are transformed into heatmaps compatible with the network outputs.
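For a single box, `BoxCoder.encode` marks the heatmap cell that contains the normalized box center with 1 and stores the center's offset from that cell's center plus the normalized width and height. Its neighbors get soft scores exp(-1) (about 0.37) on the sides and exp(-2) (about 0.14) on the diagonals. The per-box arithmetic is sketched below with plain Python (`encode_center` and `decode_center` are hypothetical names) for the default 128x256 input and 16x32 grid.

```python
def encode_center(box, image_wh, grid_wh):
    """Encode one COCO box (x_left, y_top, w, h) in pixels like BoxCoder.encode:
    return the heatmap cell (iy, ix) and the target (dx, dy, w, h), normalized."""
    iw, ih = image_wh
    fw, fh = grid_wh
    x, y, w, h = box
    cx, cy = (x + w / 2) / iw, (y + h / 2) / ih           # normalized box center
    ix, iy = int(cx * fw), int(cy * fh)                   # cell containing the center
    cell_cx, cell_cy = (ix + 0.5) / fw, (iy + 0.5) / fh   # center of that cell
    return (iy, ix), (cx - cell_cx, cy - cell_cy, w / iw, h / ih)


def decode_center(cell, target, image_wh, grid_wh):
    """Invert encode_center back to a pixel box (x_left, y_top, w, h)."""
    iw, ih = image_wh
    fw, fh = grid_wh
    (iy, ix), (dx, dy, w, h) = cell, target
    cx, cy = (ix + 0.5) / fw + dx, (iy + 0.5) / fh + dy
    return ((cx - w / 2) * iw, (cy - h / 2) * ih, w * iw, h * ih)


cell, target = encode_center((32, 64, 32, 64), image_wh=(128, 256), grid_wh=(16, 32))
print(cell, target)  # (12, 6) (-0.03125, -0.015625, 0.25, 0.25)
print(decode_center(cell, target, image_wh=(128, 256), grid_wh=(16, 32)))  # (32.0, 64.0, 32.0, 64.0)
```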

In [None]:
class BoxCoder:
    def __init__(self, image_size, ratio):
        self.image_size = image_size
        self.fw, self.fh = (i // ratio for i in image_size)
        self.iw, self.ih = self.image_size

    def encode(self, boxes):
        """Transforms the ground truth annotations into form of heatmaps
        suitable for training of the object detector"""
        cls_targets = torch.zeros((self.fh, self.fw))
        loc_targets = torch.zeros((self.fh, self.fw, 4))

        boxes_locations = []
        if boxes.numel() > 0:
            boxes[:, 2:] = boxes[:, :2] + boxes[:, 2:]
            boxes /= torch.tensor([self.iw, self.ih, self.iw, self.ih])
            box_center_xy = (boxes[:, :2] + boxes[:, 2:]) / 2
            box_wh = (boxes[:, 2:] - boxes[:, :2])

            mask = (box_center_xy[:, 0] >= 0) & (box_center_xy[:, 1] >= 0)
            mask &= (box_center_xy[:, 0] < 1) & (box_center_xy[:, 1] < 1)

            box_center_xy, box_wh = box_center_xy[mask], box_wh[mask]

            for i, (xy, wh) in enumerate(zip(box_center_xy, box_wh)):
                (x, y), (w, h) = xy, wh

                ix, iy = int(x * self.fw), int(y * self.fh)
                cx, cy = (ix + 0.5) / self.fw, (iy + 0.5) / self.fh

                cls_targets[iy, ix] = 1
                loc_targets[iy, ix] = torch.tensor([x - cx, y - cy, w, h])
                boxes_locations.append((iy, ix))
            # Neighboring cells get soft targets exp(-(|dx| + |dy|)), a Gaussian-like falloff
            for iy, ix in boxes_locations:
                for dx, dy in itertools.product(range(-1, 2), range(-1, 2)):
                    if dx == dy == 0:
                        continue
                    nx = ix + dx
                    ny = iy + dy
                    if not 0 <= ny < self.fh or not 0 <= nx < self.fw:
                        continue
                    # Keep object centers at 1; where neighborhoods overlap, keep the larger value
                    if cls_targets[ny, nx] < 1:
                        cls_targets[ny, nx] = max(cls_targets[ny, nx].item(),
                                                  np.exp(-(abs(dx) + abs(dy))))

        return loc_targets, cls_targets

    def decode(self, loc_preds, cls_preds, score_thresh=0.5, nms_thresh=0.45,
               normalized=False, max_detections=200):
        """
        Decode predicted loc/cls back to real box locations and class labels
        """
        boxes = []
        labels = []
        scores = []

        cls_preds_thresh = cls_preds > score_thresh

        for x, y in itertools.product(range(self.fw), range(self.fh)):
            if not cls_preds_thresh[y, x]:
                continue

            box_params = loc_preds[y, x]
            cx = (x + 0.5) / self.fw + box_params[0]
            cy = (y + 0.5) / self.fh + box_params[1]
            bw, bh = box_params[2:] * 0.5

            boxes.append([cx - bw, cy - bh, cx + bw, cy + bh])
            labels.append(1)
            scores.append(cls_preds[y, x])

        boxes = torch.tensor(boxes)
        labels = torch.tensor(labels)
        scores = torch.tensor(scores)
        if boxes.numel() > 0:
            if not normalized:
                boxes *= torch.tensor([self.iw, self.ih, self.iw, self.ih])
            
            keep = nms(boxes, scores, nms_thresh)[:max_detections]
            boxes = boxes[keep]
            labels = labels[keep]
            scores = scores[keep]
            
            boxes[:, 2:] = boxes[:, 2:] - boxes[:, :2]
        return boxes, labels, scores
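`decode` visits every heatmap cell in a Python loop, which is fine for a 32 x 16 grid but can be vectorized. The sketch below is a NumPy version of the steps before NMS only (`decode_vectorized` is a hypothetical name). It returns normalized (x1, y1, x2, y2) boxes and scores; scaling to pixels, `torchvision.ops.nms` and the conversion back to (x, y, w, h) would stay as in `decode`. Cell order differs from the loop, which does not matter because NMS sorts by score.

```python
import numpy as np


def decode_vectorized(loc_preds, cls_preds, score_thresh=0.5):
    """Sketch of BoxCoder.decode before NMS.

    loc_preds: (fh, fw, 4) offsets and sizes; cls_preds: (fh, fw) scores.
    Returns normalized (x1, y1, x2, y2) boxes of shape (n, 4) and their scores.
    """
    fh, fw = cls_preds.shape
    ys, xs = np.nonzero(cls_preds > score_thresh)   # cells above the threshold
    params = loc_preds[ys, xs]                       # (n, 4)
    cx = (xs + 0.5) / fw + params[:, 0]
    cy = (ys + 0.5) / fh + params[:, 1]
    half_w, half_h = params[:, 2] / 2, params[:, 3] / 2
    boxes = np.stack([cx - half_w, cy - half_h, cx + half_w, cy + half_h], axis=1)
    return boxes, cls_preds[ys, xs]


heatmap = np.zeros((32, 16), dtype=np.float32)
offsets = np.zeros((32, 16, 4), dtype=np.float32)
heatmap[12, 6] = 0.9
offsets[12, 6] = [-0.03125, -0.015625, 0.25, 0.25]
boxes, scores = decode_vectorized(offsets, heatmap)
print(boxes, scores)  # one box at (0.25, 0.25, 0.5, 0.5) with score 0.9
```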

# 6. Model definition
This model is based on MobileNetV2 https://arxiv.org/abs/1801.04381. This is a good starting point for a general object detection model; however, a range of newer architectures optimized for mobile inference have appeared since, including MobileNetV3 https://arxiv.org/abs/1905.02244.

On average, this model runs in about 17 ms on an iPhone 6.

Our model uses pretrained MobileNetV2 weights. These weights assume RGB input scaled to [0, 1] and normalized with the ImageNet mean and standard deviation. We disregard this and simply use the pretrained weights as a good initialization for our network. You can experiment and check whether training the network from scratch, without pretrained weights, gives you better quality.
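When width_mult is below 1, the pretrained tensors are larger than the layers of the narrower model, so `Detector.get_backbone` keeps only the leading output channels (and, for multi-dimensional tensors, the leading input channels) of each pretrained tensor. This is a heuristic initialization rather than an equivalent smaller network. The sketch below shows the same slicing on NumPy arrays (`slice_pretrained` is a hypothetical name); the 16-filter stem at width 0.3 follows torchvision's channel rounding.

```python
import numpy as np


def slice_pretrained(pretrained, target_shape):
    """Fit a pretrained tensor into a narrower layer by keeping its leading
    output channels (dim 0) and, if present, its leading input channels (dim 1)."""
    if len(target_shape) == 0:
        return pretrained  # scalars such as BatchNorm's num_batches_tracked are kept
    sliced = pretrained[:target_shape[0]]
    if sliced.ndim > 1:
        sliced = sliced[:, :target_shape[1]]
    return sliced


# Stem convolution: 32 filters at width_mult=1.0 vs 16 filters at width_mult=0.3
stem = np.random.rand(32, 3, 3, 3)
print(slice_pretrained(stem, (16, 3, 3, 3)).shape)  # (16, 3, 3, 3)
# A 1x1 convolution narrowed on both the output and input side
proj = np.random.rand(24, 96, 1, 1)
print(slice_pretrained(proj, (8, 16, 1, 1)).shape)  # (8, 16, 1, 1)
```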

# Modifications:
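* ReLU6 is replaced with ReLU. Make sure that you do the same in your custom architectures, because CoreML doesn't support ReLU6.
* The last stages of the MobileNetV2 backbone are removed: only the first 14 modules of `features` are kept, which drops the 160- and 320-channel stages and the final 1x1 convolution.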

The detector is based on the CenterNet approach (Objects as Points, https://arxiv.org/pdf/1904.07850.pdf). It has two heads: classification (a heatmap with sigmoid activation) and localization.

# Important point regarding input and output ranges
Lens Studio feeds the camera input to the network as RGB images with values in the range [0, 255].


In [None]:
def convert_layers(model):
    """ Convert relu6 to relu for faster inference in libdnn """
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            model._modules[name] = convert_layers(model=module,)
        if isinstance(module, nn.ReLU6):
            model._modules[name] = nn.ReLU()
    return model

def _xavier_init_(m: nn.Module):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)

class Detector(nn.Module):
    def __init__(self, width_mult, box_transformer=None, test_transform=None):
        super().__init__()
        self.box_transformer = box_transformer
        self.test_transform = test_transform
        self.backbone = self.get_backbone(width_mult)

        neck_dim = 160
        self.neck = ConvBNReLU(self.backbone[-1].conv[-1].num_features,
                               neck_dim, kernel_size=1)

        self.smooth = nn.Conv2d(neck_dim, neck_dim, kernel_size=3, stride=1,
                                padding=1, groups=neck_dim, bias=True)

        self.cls_scores_out = nn.Sequential(
            nn.Conv2d(neck_dim, 1, kernel_size=3, padding=1, bias=True),
            nn.Sigmoid()
        )
        self.loc_out = nn.Conv2d(neck_dim, 4, kernel_size=3,
                                 padding=1, bias=True)
        self.to_convert = False

        self.neck.apply(_xavier_init_)
        self.cls_scores_out.apply(_xavier_init_)
        self.loc_out.apply(_xavier_init_)

        convert_layers(self)  # Replaces ReLU6 modules with ReLU in place

    def get_backbone(self, width_mult):
        """
        Using mobilenet_v2 pretrained on imagenet from torchvision library
        https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenet.py
        """
        model = torchvision.models.mobilenet_v2(width_mult=width_mult)
        state_dict = torchvision.models.mobilenet_v2(pretrained=True).state_dict()

        if width_mult != 1:
            target_dict = model.state_dict()
            for k in target_dict.keys():
                if len(target_dict[k].size()) == 0:
                    continue
                state_dict[k] = state_dict[k][:target_dict[k].size(0)]
                if len(state_dict[k].size()) > 1:
                    state_dict[k] = state_dict[k][:, :target_dict[k].size(1)]

        model.load_state_dict(state_dict)

        return nn.Sequential(*(list(model.features.children())[:14]))

    def forward(self, x):
        x = self.backbone(x)
        x = self.neck(x)
        x = self.smooth(F.interpolate(x, scale_factor=2))

        cls_scores = self.cls_scores_out(x)
        if self.to_convert:
            return self.loc_out(x), cls_scores

        cls_scores = torch.clamp(cls_scores, min=1e-4, max=1-1e-4)

        return self.loc_out(x).permute((0, 2, 3, 1)), cls_scores.squeeze()

    def set_conversion_mode(self, to_convert=False):
        """ Set True for export to ONNX in a libdnn-compatible format.
        The reshape operations (permute/squeeze) on the head outputs might
        work incorrectly in Lens Studio, so they are omitted from the ONNX graph.
        """
        self.to_convert = to_convert

    def load(self, model):
        self.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))

    def save(self, model_path):
        torch.save(self.state_dict(), model_path)

    def rescale_boxes(self, boxes, out_size, intermediate_size, scale):
        """ Removes padding shift and rescales bounding boxes to original
        input image size """
        outh, outw = out_size
        rh, rw = intermediate_size

        if rh < outh:
            boxes[:, 1] -= ((outh - rh) / 2)
            boxes[:, 1] = torch.clamp(boxes[:, 1], min=0, max=outh-1)
        if rw < outw:
            boxes[:, 0] -= ((outw - rw) / 2)
            boxes[:, 0] = torch.clamp(boxes[:, 0], min=0, max=outw-1)

        boxes /= scale
        
        return boxes

    def predict(self, sample, score_threshold=0.5, nms_thresh=0.45):
        """ sample (dict) {"image": cv2 BGR image} """
        sample = self.test_transform(sample, with_bboxes=False)
        images = torch.tensor(sample["image"]).unsqueeze(0)

        with torch.no_grad():
            boxes, cls = self.forward(images)

        boxes, labels, probs = self.box_transformer.decode(boxes[0], cls, score_threshold, nms_thresh)
        if len(boxes) == 0:
            return torch.tensor([]), torch.tensor([]), torch.tensor([])

        boxes = self.rescale_boxes(boxes, sample['out_size'], sample['intermediate_size'], sample['scale'])

        return boxes, labels, probs

# 7. Loss function

It consists of two parts: localization (the model estimates the coordinates of the objects' bounding boxes) and classification (the model predicts whether there is an object in each region of the heatmap).
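Written out (our reading of the `Loss` class below, which follows the CenterNet paper), with $p_i$ the predicted score and $y_i$ the target of heatmap cell $i$, $N$ the number of cells with $y_i = 1$, $\alpha$ = `pos_weights` = 2 and $\beta$ = `neg_weights` = 4:

$$
L_{\text{cls}} = -\frac{1}{N} \sum_i
\begin{cases}
(1 - p_i)^{\alpha} \log p_i & \text{if } y_i = 1, \\
(1 - y_i)^{\beta} \, p_i^{\alpha} \log (1 - p_i) & \text{otherwise,}
\end{cases}
$$

$$
L_{\text{loc}} = \frac{1}{N} \sum_{i:\, y_i = 1} \lVert \hat{b}_i - b_i \rVert_1, \qquad
L = L_{\text{loc}} + L_{\text{cls}},
$$

where $b_i = (dx, dy, w, h)$ is the box target written by `BoxCoder.encode` and $\hat{b}_i$ the prediction. The $(1 - y_i)^{\beta}$ factor reduces the penalty on the soft neighbor cells around each center. In the code, $N$ is clamped to at least 1 in the classification term, and 1e-5 is added to it in the localization term.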

In [None]:
class Loss(nn.Module):
    """ Focal loss is used for classification and L1-loss for regression. """
    def __init__(self):
        super(Loss, self).__init__()

    def cls_loss(self, pred, target, neg_weights=4, pos_weights=2):
        pos_mask = (target == 1).float()
        neg_mask = (target < 1).float()
        pos_loss = -torch.log(pred) * torch.pow(1 - pred, pos_weights) * pos_mask
        neg_loss = -torch.log(1 - pred) * torch.pow(pred, pos_weights) * torch.pow(1 - target, neg_weights) * neg_mask
        num_pos = max(pos_mask.float().sum(), 1)
        return (pos_loss.sum() + neg_loss.sum()) / num_pos

    def forward(self, loc_preds, loc_targets, cls_preds, cls_targets):
        batch_size = loc_preds.size(0)
        cls_targets = cls_targets.view(batch_size, -1)
        loc_targets = loc_targets.view(batch_size, -1, 4)
        cls_preds = cls_preds.view(batch_size, -1)
        loc_preds = loc_preds.view(batch_size, -1, 4)

        pos = cls_targets > 0.99  # Only object centers (target == 1) are positive; soft neighbor targets count as negatives
        mask = pos.unsqueeze(2).expand_as(loc_preds)
        
        loc_loss = F.l1_loss(loc_preds[mask], loc_targets[mask], reduction='sum')
        loc_loss /= pos.sum().item() + 1e-5

        cls_loss = self.cls_loss(cls_preds, cls_targets)

        loss = loc_loss + cls_loss
        return loss, loc_loss, cls_loss

# 8. Create dataloaders

PyTorch's DataLoader loads multiple images in parallel worker processes for faster training.

In [None]:
box_coder = BoxCoder(INPUT_SIZE, FEATURE_MAP_SIZE_RATIO)

In [None]:
val_dataset = Dataset('val', box_coder, validation_transform)
train_dataset = Dataset('train', box_coder, train_transform)

Create dataloaders for parallel loading of images

In [None]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    worker_init_fn=lambda _: np.random.seed(),
    drop_last=True
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE // 4,
    shuffle=False,
    num_workers=NUM_WORKERS,
    drop_last=False
)
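`worker_init_fn=lambda _: np.random.seed()` reseeds NumPy in each worker from fresh OS entropy. This keeps workers from producing identical augmentations, but it also means two runs with the same RANDOM_SEED see different augmentations. If you need reproducible runs, a sketch following the pattern from PyTorch's reproducibility notes derives each worker's NumPy and `random` seeds from the seed PyTorch already assigns to that worker; a seeded `torch.Generator` additionally fixes the shuffling order. The commented usage reuses this notebook's names and is untested here.

```python
import random

import numpy as np
import torch


def seed_worker(worker_id):
    """worker_init_fn deriving NumPy and `random` seeds from the per-worker
    seed PyTorch sets (torch.initial_seed() inside a worker)."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Possible usage with the notebook's objects:
# generator = torch.Generator()
# generator.manual_seed(RANDOM_SEED)
# train_loader = torch.utils.data.DataLoader(
#     train_dataset, batch_size=BATCH_SIZE, shuffle=True,
#     num_workers=NUM_WORKERS, worker_init_fn=seed_worker,
#     generator=generator, drop_last=True)
```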

# 9. Create and train the network

Model snapshots will be saved each epoch to the DIR_TO_SAVE_RESULTS directory

In [None]:
model = Detector(MOBILENET_WIDTH_MULTIPLIER, box_transformer=box_coder,
                 test_transform=validation_transform)

We'll also set up a learning rate scheduler that drops the learning rate when the validation loss plateaus.

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE,
                             weight_decay=1e-4)
model.to(DEVICE)  # Move model to the device selected for training

criterion = Loss()
print(f"Learning rate: {LEARNING_RATE}")
print("Using ReduceLROnPlateau scheduler.")
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', factor=0.5, patience=5, verbose=True)
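To see what `patience=5` means in practice, the sketch below steps a throwaway `ReduceLROnPlateau` with a made-up sequence of validation losses (not real training results): one improvement followed by a flat plateau. The learning rate is halved only after more than five consecutive epochs without improving on the best loss (by the default relative threshold of 1e-4). The separate `demo_` names avoid overwriting the notebook's optimizer and scheduler.

```python
import torch

# A dummy parameter and optimizer, only to observe the scheduler's decisions
param = torch.nn.Parameter(torch.zeros(1))
demo_optimizer = torch.optim.Adam([param], lr=1e-3)
demo_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    demo_optimizer, 'min', factor=0.5, patience=5)

val_losses = [1.0] * 8  # first value improves on the initial best (inf), then no progress
lrs = []
for val_loss in val_losses:
    demo_scheduler.step(val_loss)
    lrs.append(demo_optimizer.param_groups[0]['lr'])
print(lrs)  # six times 0.001, then 0.0005 twice
```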

In [None]:
print("Device used for training:", DEVICE)

Define train, validation functions and train the network

In [None]:
def train(loader, net, criterion, optimizer, device):
    net.train(True)
    total_loss = 0.
    total_regression_loss = 0.
    total_classification_loss = 0.

    progress_bar = tqdm(enumerate(loader), total=len(loader),
                        dynamic_ncols=True, leave=False)

    for i, data in progress_bar:
        images = data['image'].to(device)
        boxes = data['bboxes'].to(device)
        labels = data['labels'].to(device)

        optimizer.zero_grad()
        locs, cls = net(images)
        loss, regression_loss, classification_loss = criterion(locs, boxes,
                                                               cls, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_regression_loss += regression_loss.item()
        total_classification_loss += classification_loss.item()
        
        progress_bar.set_description("loss {:.4f},  regression loss {:.4f}, classification loss {:.4f}".format(
            total_loss / (i+1), total_regression_loss / (i+1), total_classification_loss / (i+1)))

def validate(loader, net, criterion, device):
    net.eval()
    total_loss = 0.
    total_regression_loss = 0.
    total_classification_loss = 0.
    loader_len = len(loader)
    for _, data in enumerate(loader):
        images = data['image'].to(device)
        boxes = data['bboxes'].to(device)
        labels = data['labels'].to(device)

        with torch.no_grad():
            locs, cls = net(images)
            loss, regression_loss, classification_loss = criterion(locs, boxes, cls, labels)

        total_loss += loss.item()
        total_regression_loss += regression_loss.item()
        total_classification_loss += classification_loss.item()
    return total_loss / loader_len, total_regression_loss / loader_len, total_classification_loss / loader_len

# Start training
for epoch in range(NUM_EPOCHS):
    train(train_loader, model, criterion, optimizer,
          device=DEVICE)
    
    val_loss, val_regression_loss, val_class_loss = validate(val_loader, model,
                                                             criterion, DEVICE)
    print("Epoch: {}, val loss {:.4f}, regression loss {:.4f}, classification loss {:.4f}".format(
          epoch, val_loss, val_regression_loss, val_class_loss))

    scheduler.step(val_loss)
    model_path = DIR_TO_SAVE_RESULTS / f"e-{epoch}-{val_loss:.3f}.pth"
    model.save(model_path)
    print(f"Saved model {model_path}")

# 10. Check out the predictions of trained detector

Set 'eval' mode for test inference

In [None]:
# model.load(DIR_TO_SAVE_RESULTS / 'e-5-1.583.pth')  # Use to load saved model snapshot
model.eval()
model = model.to('cpu')
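The commented `model.load` line above needs the snapshot name typed by hand. A small helper (hypothetical, not part of the notebook) can instead pick the snapshot with the lowest validation loss by parsing the e-{epoch}-{val_loss:.3f}.pth names written by the training loop, preferring the later epoch on ties. Keep in mind that the validation loss is only a proxy for detection quality.

```python
import re
from pathlib import Path

# Matches names written by the training loop: e-{epoch}-{val_loss:.3f}.pth
SNAPSHOT_NAME = re.compile(r"^e-(\d+)-(\d+\.\d+)\.pth$")


def pick_best(names):
    """Return the snapshot name with the lowest validation loss, or None."""
    best = None
    for name in names:
        match = SNAPSHOT_NAME.match(name)
        if match is None:
            continue  # skips other files, e.g. det.onnx or a 'nan' loss
        key = (float(match.group(2)), -int(match.group(1)))
        if best is None or key < best[0]:
            best = (key, name)
    return None if best is None else best[1]


def best_checkpoint(directory):
    """Path of the best snapshot in `directory`, or None if there is none."""
    name = pick_best(path.name for path in Path(directory).glob("e-*.pth"))
    return None if name is None else Path(directory) / name


# Possible usage in this notebook:
# model.load(best_checkpoint(DIR_TO_SAVE_RESULTS))
```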

Use the test image 'car_test_image.png'. Make sure you've uploaded it to the Google Colab environment. Try changing the score_threshold in model.predict() to see how the model handles less confident detections.

In [None]:
sample = {'image': cv2.imread('./car_test_image.png')}  # Load the image in BGR format

boxes, labels, probs = model.predict(sample, score_threshold=0.4)
for i in range(boxes.size(0)):
    # OpenCV drawing functions expect integer pixel coordinates
    x, y, w, h = (int(v) for v in boxes[i, :].tolist())
    label = f"{float(probs[i]):.2f}"
    cv2.rectangle(sample['image'], (x, y), (x + w, y + h), (255, 255, 0), 4)
    cv2.putText(sample['image'], label, (x + 20, y + 40),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 255), 2)
plt.figure(figsize=(10, 10))
plt.axis('off')
plt.imshow(cv2.cvtColor(sample['image'], cv2.COLOR_BGR2RGB))

# 11. Export the model to ONNX for further usage in Lens Studio
The 'det.onnx' file will be created in DIR_TO_SAVE_RESULTS.

Lens Studio sends an RGB image with values in the range [0, 255] to the network input. Our network is already trained for that, but if you trained the network with a different input range, you might need to adjust it when you import your ONNX file.

BatchNorm layers will be fused with convolution layers in Lens Studio, so there is no need to do it in PyTorch.

In [None]:
onnx_model_path = DIR_TO_SAVE_RESULTS / 'det.onnx'
dummy_input = torch.ones(1, 3, INPUT_SIZE[1], INPUT_SIZE[0],
                         dtype=torch.float32)

In [None]:
# model.load(DIR_TO_SAVE_RESULTS / 'e-75-1.091.pth')  # Use to load saved model snapshot
model.to('cpu')

model.set_conversion_mode(to_convert=True)
model = model.eval()

input_names = ['data']
output_names = ['loc', 'cls']

onnx.export(model, dummy_input, onnx_model_path, verbose=False,
            input_names=input_names, output_names=output_names,
            keep_initializers_as_inputs=True, opset_version=11)

model.set_conversion_mode(to_convert=False)
print("Successfully saved model as {}".format(onnx_model_path))