In [1]:
%%capture
! pip install transformers wandb

In [2]:
! wandb login

[34m[1mwandb[0m: Currently logged in as: [33mrayanren[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
import wandb

In [4]:
wandb.login()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrayanren[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
from typing import List, Optional
import urllib.request
from tqdm.auto import tqdm
from pathlib import Path
import requests
import torch
import math
import numpy as np
import os
import glob


def get_quickdraw_class_names():
    """
    TODO - Check performance w/ gsutil in colab. The following command downloads all files to ./data
    `gsutil cp gs://quickdraw_dataset/full/numpy_bitmap/* ./data`
    """
    url = "https://raw.githubusercontent.com/googlecreativelab/quickdraw-dataset/master/categories.txt"
    r = requests.get(url)
    classes = [x.replace(' ', '_') for x in r.text.splitlines()]
    return classes


def download_quickdraw_dataset(root="./data", limit: Optional[int] = None, class_names: List[str]=None):
    if class_names is None:
        class_names = get_quickdraw_class_names()

    root = Path(root)
    root.mkdir(exist_ok=True, parents=True)
    url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'

    print("Downloading Quickdraw Dataset...")
    for class_name in tqdm(class_names[:limit]):
        fpath = root / f"{class_name}.npy"
        if not fpath.exists():
            urllib.request.urlretrieve(f"{url}{class_name.replace('_', '%20')}.npy", fpath)


def load_quickdraw_data(root="./data", max_items_per_class=5000):
    all_files = Path(root).glob('*.npy')

    x = np.empty([0, 784], dtype=np.uint8)
    y = np.empty([0], dtype=np.long)
    class_names = []

    print(f"Loading {max_items_per_class} examples for each class from the Quickdraw Dataset...")
    for idx, file in enumerate(tqdm(sorted(all_files))):
        data = np.load(file, mmap_mode='r')
        data = data[0: max_items_per_class, :]
        labels = np.full(data.shape[0], idx)
        x = np.concatenate((x, data), axis=0)
        y = np.append(y, labels)

        class_names.append(file.stem)

    return x, y, class_names


class QuickDrawDataset(torch.utils.data.Dataset):
    def __init__(self, root, max_items_per_class=5000, class_limit=None):
        super().__init__()
        self.root = root
        self.max_items_per_class = max_items_per_class
        self.class_limit = class_limit

        download_quickdraw_dataset(self.root, self.class_limit)
        self.X, self.Y, self.classes = load_quickdraw_data(self.root, self.max_items_per_class)

    def __getitem__(self, idx):
        x = (self.X[idx] / 255.).astype(np.float32).reshape(1, 28, 28)
        y = self.Y[idx]

        return torch.from_numpy(x), y.item()

    def __len__(self):
        return len(self.X)

    def collate_fn(self, batch):
        x = torch.stack([item[0] for item in batch])
        y = torch.LongTensor([item[1] for item in batch])
        return {'pixel_values': x, 'labels': y}
    
    def split(self, pct=0.1):
        num_classes = len(self.classes)
        indices = torch.randperm(len(self)).tolist()
        n_val = math.floor(len(indices) * pct)
        train_ds = torch.utils.data.Subset(self, indices[:-n_val])
        val_ds = torch.utils.data.Subset(self, indices[-n_val:])
        return train_ds, val_ds

In [6]:
import torch
from transformers import Trainer
from transformers.modeling_utils import ModelOutput


class QuickDrawTrainer(Trainer):

    def compute_loss(self, model, inputs, return_outputs=False):
        logits = model(inputs["pixel_values"])
        labels = inputs.get("labels")

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return (loss, ModelOutput(logits=logits, loss=loss)) if return_outputs else loss

# Taken from timm - https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/metrics.py
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = min(max(topk), output.size()[1])
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]


def quickdraw_compute_metrics(p):
    acc1, acc5 = accuracy(
        torch.tensor(p.predictions),
        torch.tensor(p.label_ids), topk=(1, 5)
    )
    return {'acc1': acc1, 'acc5': acc5}

In [7]:
device = "cuda" if torch.cuda.is_available else "cpu"

In [8]:
import torch
from torch import nn
from transformers import TrainingArguments
from datetime import datetime

data_dir = './data'
max_examples_per_class = 20000
train_val_split_pct = .1

ds = QuickDrawDataset(data_dir, max_examples_per_class)
num_classes = len(ds.classes)
train_ds, val_ds = ds.split(train_val_split_pct)

Downloading Quickdraw Dataset...


  0%|          | 0/345 [00:00<?, ?it/s]

Loading 20000 examples for each class from the Quickdraw Dataset...


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  y = np.empty([0], dtype=np.long)


  0%|          | 0/345 [00:00<?, ?it/s]

ValueError: ignored

In [None]:
model = nn.Sequential(
    nn.Conv2d(1, 32, 3, 1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, 1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, 1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(128, num_classes)
).to(device)

In [None]:
timestamp = datetime.now().strftime('%Y-%m-%d-%H%M%S')
training_args = TrainingArguments(
    output_dir=f'./outputs_20k_{timestamp}',
    evaluation_strategy='epoch',
    save_strategy='epoch',
    report_to=['wandb', 'tensorboard'],  # Update to just tensorboard if not using wandb
    logging_strategy='steps',
    logging_steps=100,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    learning_rate=0.003,
    fp16=torch.cuda.is_available(),
    num_train_epochs=20,
    run_name=f"quickdraw-med-{timestamp}",  # Can remove if not using wandb
    warmup_steps=10000,
    save_total_limit=5,
)

trainer = QuickDrawTrainer(
    model,
    training_args,
    data_collator=ds.collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=None,
    compute_metrics=quickdraw_compute_metrics,
)

# Training
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# Evaluation
eval_results = trainer.evaluate()
trainer.log_metrics("eval", eval_results)
trainer.save_metrics("eval", eval_results)

In [None]:
def predict(img):
  img = torch.tensor(img).unsqueeze(0).unsqueeze(0) / 255
  with torch.no_grad():
    logit = model(img)
  probs = torch.softmax(logit[0], 0)
  values, indeces = torch.topk(probs, 5)
  confidences = {get_quickdraw_class_names()[i]: val for i, val in zip(indeces, values)}
  return confidences
#!pip install gradio
import gradio as gr
gr.Interface(fn=predict, inputs="sketchpad", outputs="label", live=True).launch()