In [1]:
import getpass
import json
import os

import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as F
from datasets import load_dataset
from huggingface_hub import login
from IPython.display import display
from PIL import Image, ImageDraw
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from tqdm import tqdm

### Dataset
The detector is trained on a Hugging Face dataset (`whENbhAI/doodle_512`) that was built from the QuickDraw dataset using `shapesParser.ipynb`.

In [2]:
# Never hard-code the Hub token in the notebook: read it from the HF_TOKEN
# environment variable, or prompt for it interactively if it is not set.
login(token = os.environ.get("HF_TOKEN") or getpass.getpass("Hugging Face token: "))

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [3]:
# Shapes dataset on the Hugging Face Hub (built with shapesParser.ipynb).
hf_dataset = load_dataset(path = "whENbhAI/doodle_512")

Downloading readme:   0%|          | 0.00/417 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/434M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/434M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/434M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/20000 [00:00<?, ? examples/s]

In [4]:
class Loader:
    """Mini-batch iterator over a dataset with a fixed train/validation split.

    The dataset is shuffled once at construction. Rows ``[0, train_max)`` of
    that shuffled order form the training split and rows ``[train_max, len)``
    form the validation split. ``reset()`` reshuffles only the training rows,
    so the split stays fixed across epochs and validation rows never leak
    into training.

    Positions (``index``, ``len(loader)``) are counted in rows, not batches.

    Args:
        dataset: Object supporting ``shuffle()``, ``len()`` and indexing with a
            list of row indices that returns a column-oriented batch
            (e.g. a Hugging Face ``datasets.Dataset``).
        batch_size: Rows per batch; the last batch of a split may be smaller.
        collator_fn: Callable turning a column-oriented batch into model inputs.
        train_max: Number of rows in the training split.
        mode: ``"train"`` selects the training split; any other value selects
            the validation split.
    """

    def __init__(self, dataset, batch_size, collator_fn, train_max = 18000, mode = "train"):
        # One-time shuffle: this fixes which rows belong to which split.
        self.dataset = dataset.shuffle()
        self.collator_fn = collator_fn
        self.len = len(self.dataset)
        self.batch_size = batch_size
        self.train_max = train_max
        self.mode = mode
        # Clamp so an oversized train_max cannot run past the data.
        self._train_end = min(train_max, self.len)
        # Visiting order of the training rows; only this list is reshuffled per epoch.
        self._train_order = list(range(self._train_end))
        self.index = self._start()

    def _start(self):
        """First row index of the active split."""
        return 0 if self.mode == "train" else self._train_end

    def _stop(self):
        """Exclusive end row index of the active split."""
        return self._train_end if self.mode == "train" else self.len

    def hasNext(self):
        """Return True if at least one full batch remains in the active split."""
        return self.index + self.batch_size <= self._stop()

    def reset(self):
        """Rewind to the start of the active split; in train mode also reshuffle training rows."""
        if self.mode == "train":
            np.random.shuffle(self._train_order)
        self.index = self._start()

    def __iter__(self):
        return self

    def __next__(self):
        stop = self._stop()
        if self.index >= stop:
            raise StopIteration
        # Clamp to the split boundary so the last training batch cannot pull
        # in validation rows (and the last validation batch cannot overrun).
        end = min(self.index + self.batch_size, stop)
        if self.mode == "train":
            rows = self._train_order[self.index:end]
        else:
            rows = list(range(self.index, end))
        batch = self.collator_fn(self.dataset[rows])
        self.index = end
        return batch

    def __len__(self):
        """Number of rows (not batches) in the active split."""
        return self._stop() - self._start()

    def train(self):
        """Switch to the training split and rewind to its start (without reshuffling)."""
        self.mode = "train"
        self.index = self._start()

    def validate(self):
        """Switch to the validation split and rewind to its start."""
        self.mode = "validation"
        self.index = self._start()

In [5]:
# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [6]:
def transform(img):
    """Convert a PIL image into a float tensor for the detector.

    ``F.pil_to_tensor`` returns a ``(C, H, W)`` tensor of raw pixel values;
    dividing by 255 rescales 8-bit images to floats in ``[0, 1]``.
    """
    return F.pil_to_tensor(img) / 255.0

# Shape name -> class id. Ids start at 1; id 0 is presumably the background
# class (the detector head is sized for these 7 shapes plus one extra class).
label_map = {
    name: class_id
    for class_id, name in enumerate(
        ["circle", "octagon", "hexagon", "star", "square", "triangle", "line"],
        start = 1,
    )
}
def collator(batch):
    """Convert a column-oriented dataset batch into Faster R-CNN inputs on ``device``.

    Args:
        batch: dict with an ``"image"`` column of PIL images and a ``"shapes"``
            column whose entries hold ``"boxes"`` (presumably one
            ``[x1, y1, x2, y2]`` box per shape — confirm against the dataset)
            and ``"labels"`` (shape names that are keys of ``label_map``).

    Returns:
        ``(images, targets)``: a list of image tensors and a list of target
        dicts with float ``"boxes"`` of shape ``[N, 4]`` and int64 ``"labels"``.
    """
    images = [transform(img).to(device) for img in batch["image"]]
    targets = []
    for shapes in batch["shapes"]:
        # reshape(-1, 4) keeps an image with no shapes as an empty [0, 4]
        # tensor; torch.tensor([]) alone would be shape [0], which the
        # detector rejects as a malformed target.
        boxes = torch.tensor(shapes["boxes"], dtype = torch.float, device = device).reshape(-1, 4)
        labels = torch.tensor(
            [label_map[name] for name in shapes["labels"]], dtype = torch.int64, device = device
        )
        targets.append({"boxes": boxes, "labels": labels})
    return (images, targets)

In [7]:
# Training-mode loader over the "train" split: batches of 16 rows, collated onto `device`.
loader = Loader(hf_dataset["train"], batch_size = 16, collator_fn = collator)

### Model
We fine-tune a COCO-pretrained Faster R-CNN (ResNet-50 FPN backbone) on this detection problem: 7 shape classes plus the background class, i.e. 8 output classes.

In [8]:
def save_model(model, name):
    """Save ``model.state_dict()`` to the file path ``name`` with ``torch.save``."""
    torch.save(model.state_dict(), name)

def initDetector(num_classes = None):
    """Build a COCO-pretrained Faster R-CNN (ResNet-50 FPN) with a fresh box predictor.

    Args:
        num_classes: Number of output classes including background. Defaults to
            ``len(label_map) + 1`` so the head stays in sync with the label map.

    Returns:
        The model, moved to ``device``.
    """
    if num_classes is None:
        # +1 for the background class (label ids in label_map start at 1).
        num_classes = len(label_map) + 1
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained = True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Replace the pretrained classification/regression head with one sized for our classes.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model.to(device)

def load_model(name, backup = initDetector, frommem = True):
    """Build a model with ``backup()`` and optionally load saved weights into it.

    Args:
        name: Path of a checkpoint written by ``save_model``.
        backup: Zero-argument factory returning a freshly initialised model.
        frommem: If False, skip loading and return the fresh model.

    Returns:
        The model, holding the weights from ``name`` if that file exists.

    Raises:
        Errors from ``torch.load`` / ``load_state_dict`` other than a missing
        file (e.g. ``RuntimeError`` on an architecture mismatch). They propagate
        so an incompatible checkpoint is not silently replaced by fresh weights,
        which the next ``save_model`` call would then overwrite.
    """
    model = backup()
    if not frommem:
        print("Initializing from scratch.")
        return model
    # Map tensors onto the model's own device so a checkpoint saved on a GPU
    # can be loaded on a CPU-only machine.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else "cpu"
    try:
        state_dict = torch.load(name, map_location = map_location)
    except FileNotFoundError:
        print("Couldn't find model. Initializing from scratch.")
        return model
    model.load_state_dict(state_dict)
    print("Loaded model successfully.")
    return model

In [9]:
# Resume from the same relative path the training loop saves to
# ("detector_v1.pth"), so resuming also works outside /kaggle/working.
# load_model falls back to fresh weights when the file does not exist.
model = load_model("detector_v1.pth")

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100%|██████████| 160M/160M [00:02<00:00, 69.4MB/s] 


Loaded model successfully.


In [10]:
NUM_EPOCHS = 3
# SGD with momentum and L2 weight decay over all model parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr = 0.002,
    momentum = 0.9,
    weight_decay = 0.0005,
)

In [None]:
for epoch in range(NUM_EPOCHS):
    # Re-enter training mode every epoch: visualizePerf() calls model.eval(),
    # and in eval mode the detector returns detections instead of a loss dict.
    model.train()
    epoch_loss = 0.0  # sum of per-batch losses for this epoch
    loader.reset()
    with tqdm(total=len(loader), desc="Processing batches", dynamic_ncols=True) as pbar:
        for images, targets in loader:
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # Advance by the real batch size; the last batch may be smaller.
            pbar.update(len(images))
    save_model(model, "detector_v1.pth")
    print(epoch_loss)

Processing batches: 100%|██████████| 18000/18000 [28:39<00:00, 10.47it/s]


935.2554804086685


Processing batches:  39%|███▉      | 7088/18000 [11:16<17:26, 10.43it/s]

### Evaluation
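We train the model until the loss plateaus. Detection quality is currently judged qualitatively, by visually inspecting the predicted boxes on sample images; no quantitative metric (e.g. mAP) is computed yet.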

In [None]:
def visualizePerf(index = 0, score_threshold = 0.7):
    """Run the detector on one dataset image and display its confident detections.

    Args:
        index: Row of ``hf_dataset["train"]`` to visualise. This split also
            feeds training, so the image may have been seen during training.
        score_threshold: Minimum detection score for a box to be drawn.
    """
    img = hf_dataset["train"][index]["image"]
    was_training = model.training
    model.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model([transform(img).to(device)])[0]
    # Restore the previous mode so a later training cell still gets loss dicts.
    model.train(was_training)
    bboxes = [
        (box.cpu().numpy(), label.cpu().numpy(), score.cpu().numpy())
        for box, label, score in zip(output["boxes"], output["labels"], output["scores"])
        if score >= score_threshold
    ]
    print(bboxes)
    id_to_name = {class_id: name for name, class_id in label_map.items()}
    # Draw on an RGB copy so coloured annotations work whatever the source
    # image mode, and the dataset image itself is left untouched.
    canvas = img.convert("RGB")
    draw = ImageDraw.Draw(canvas)
    for bounds, label, _ in bboxes:
        x0, y0, x1, y1 = (float(v) for v in bounds)
        draw.rectangle([x0, y0, x1, y1], outline = "red")
        draw.text((x0, y0), id_to_name.get(int(label), str(int(label))), (0, 0, 0))
    display(canvas)
    
# visualizePerf()