## Install required libraries

In [None]:
!pip install hub
!pip install matplotlib
!pip install torch

## Imports

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset
import torch
from torch.utils.data import random_split
import hub

In [None]:
!hub login

## Load the dataset

In [None]:
# Load the dataset handle from Hub (presumably the COCO training split, going by the path name).
# NOTE(review): the path repeats the "activeloop" segment - confirm this is the intended dataset path.
ds = hub.load("activeloop/activeloop/coco_train")

## Define a model

In [None]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torchvision


def get_model_instance_segmentation(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a fresh box head.

    Despite the name, the returned model is an object detector: it is
    created with ``fasterrcnn_resnet50_fpn(pretrained=False)`` and its
    box predictor is replaced by a ``FastRCNNPredictor`` that outputs
    ``num_classes`` classes.

    Args:
        num_classes: Number of output classes for the new box predictor.

    Returns:
        The detection model with the replaced box predictor.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)

    # Reuse the input width of the stock classification layer so the new
    # head plugs into the existing ROI feature extractor.
    stock_head = detector.roi_heads.box_predictor
    feature_width = stock_head.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(feature_width, num_classes)

    return detector

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training configuration read by train() as module-level globals.
num_classes = 92
num_epochs = 10

# nn.Module.to() returns the module itself, so build and move in one step.
model = get_model_instance_segmentation(num_classes).to(device)

# Only parameters that require gradients are handed to the optimizer.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)


def train(data_loader):
    """Train the module-level ``model`` on ``data_loader`` for ``num_epochs`` epochs.

    Relies on the module-level ``model``, ``optimizer``, ``device`` and
    ``num_epochs``. Every batch is expected to be a pair
    ``(imgs, annotations)`` of equally long sequences; a position whose
    image or annotation is ``None`` is dropped as a pair so the remaining
    images and targets stay aligned. Batches with no usable sample are
    skipped instead of being passed to the model.

    Args:
        data_loader: Sized iterable yielding ``(imgs, annotations)`` batches,
            where each annotation is a dict of tensors.
    """
    len_dataloader = len(data_loader)

    for _ in range(num_epochs):
        model.train()
        print("Start training")
        for i, (imgs, annotations) in enumerate(data_loader, start=1):
            # Filter image/target pairs together: filtering the two lists
            # independently could pair an image with the wrong target.
            usable = [
                (img, target)
                for img, target in zip(imgs, annotations)
                if img is not None and target is not None
            ]
            if not usable:
                print(f'Iteration: {i}/{len_dataloader}, skipped (no usable samples)')
                continue

            imgs = [img.to(device) for img, _target in usable]
            annotations = [
                {k: v.to(device) for k, v in target.items()}
                for _img, target in usable
            ]

            # In training mode the model returns a dict of loss terms.
            loss_dict = model(imgs, annotations)
            losses = sum(loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            # .item() prints a plain number rather than the tensor repr.
            print(f'Iteration: {i}/{len_dataloader}, Loss: {losses.item()}')


## Convert to PyTorch dataset

In [None]:
def get_transform():
    """Return the image transform pipeline: a ``Compose`` holding only ``ToTensor``."""
    to_tensor = torchvision.transforms.ToTensor()
    return torchvision.transforms.Compose([to_tensor])

In [None]:
import os
import torch
import torch.utils.data
import torchvision

class CocoDataset(torch.utils.data.Dataset):
    """Map-style dataset turning Hub samples into ``(image, target)`` pairs.

    Each element of ``ds`` must be indexable by ``'image'`` and
    ``'objects'``, and both entries must expose ``.compute()``. The computed
    objects must provide ``'bbox'``, ``'label'``, ``'area'`` and
    ``'is_crowd'`` sequences.

    Args:
        ds: The underlying indexable, sized dataset.
        transforms: Optional callable applied to the computed image.
    """

    def __init__(self, ds, transforms=None):
        self.transforms = transforms
        self.ds = ds

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``, or ``(None, None)``.

        ``(None, None)`` is returned when the sample has no bounding boxes,
        so callers must filter such samples out. Otherwise ``target`` is a
        dict with ``boxes`` (float32, ``[xmin, ymin, xmax, ymax]`` rows),
        ``labels`` (int64), ``area`` (float32) and ``iscrowd`` (int64).
        """
        # Index the underlying dataset once; a second lookup would repeat
        # whatever work the backing store does per access.
        sample = self.ds[index]
        objs = sample['objects'].compute()
        bboxes = objs['bbox']
        num_objs = len(bboxes)

        # Check for "no objects" before touching the image so empty samples
        # never pay for computing it.
        if num_objs == 0:
            return None, None

        # Stored boxes are [x, y, width, height]; convert to corner format.
        boxes = []
        for box in bboxes:
            xmin, ymin = box[0], box[1]
            boxes.append([xmin, ymin, xmin + box[2], ymin + box[3]])

        my_annotation = {
            "boxes": torch.as_tensor(boxes, dtype=torch.float32),
            "labels": torch.tensor(objs['label'], dtype=torch.int64),
            # Only the first num_objs areas are used, one per bounding box.
            "area": torch.as_tensor(list(objs['area'][:num_objs]), dtype=torch.float32),
            "iscrowd": torch.tensor(objs['is_crowd'], dtype=torch.int64),
        }

        img = sample['image'].compute()
        if self.transforms is not None:
            img = self.transforms(img)

        return img, my_annotation

    def __len__(self):
        """Return the number of samples in the underlying dataset."""
        return len(self.ds)

## Train


In [None]:
# Wrap the loaded Hub dataset so it can be fed to a PyTorch DataLoader;
# images go through the transform pipeline returned by get_transform().
torch_ds_train = CocoDataset(ds, transforms=get_transform())

def collate_fn(batch):
    """Collate detection samples into parallel tuples, dropping unusable ones.

    A sample is unusable when it is ``None`` or when any of its fields is
    ``None`` (for example an ``(image, target)`` pair returned as
    ``(None, None)``). Dropped samples are replaced by repeating usable
    samples from the same batch, so the batch keeps its original size.

    Args:
        batch: List of samples, each a tuple such as ``(image, target)``.

    Returns:
        A tuple with one inner tuple per sample field, for example
        ``(images, targets)``. When no sample is usable, a pair of empty
        tuples is returned so that two-way unpacking still works; callers
        should skip such batches.
    """
    usable = [
        sample
        for sample in batch
        if sample is not None and all(field is not None for field in sample)
    ]
    if not usable:
        return (), ()

    # Refill to the original batch size by cycling through usable samples.
    missing = len(batch) - len(usable)
    padded = usable + [usable[k % len(usable)] for k in range(missing)]
    return tuple(zip(*padded))
    
# Shuffled batches of 4 samples; collate_fn assembles each batch into
# per-field tuples instead of the default tensor stacking.
train_dataloader = torch.utils.data.DataLoader(
        torch_ds_train,
        batch_size=4,
        shuffle=True,
        collate_fn=collate_fn
    )
# Run the full training loop on the module-level model.
train(train_dataloader)
# Persist the trained model object to disk.
# NOTE(review): this saves the whole model object rather than model.state_dict();
# confirm that loading the entire object back is the intended workflow.
torch.save(model, "/tmp/model_coco.pth")