# Example training setup

In [None]:
%pip install -r requirements.txt

In [1]:
from dataset import QuickDrawDataset
from sketch_tokenizers import DeltaPenPositionTokenizer
from models import SketchTransformerConditional
from runner import SketchTrainer, sample, device
import torch

# `device` comes from runner. Compare via str() so the check works whether it is a
# plain string or a torch.device (a torch.device does not compare equal to a str),
# and so indexed devices such as "cuda:0" are matched too.
if str(device).startswith("cuda"):
    # Empty PyTorch's CUDA cache to free up GPU memory (e.g. left over from an
    # earlier run in the same kernel).
    torch.cuda.empty_cache()

Using device: cuda


In [None]:
# Classes to train on. The sampling cell below passes each name's list index
# as `class_label`, so the order of this list defines the class ids.
label_names = ["bird", "crab", "guitar"]
dataset = QuickDrawDataset(label_names=label_names)
tokenizer = DeltaPenPositionTokenizer(bins=32)

# Class-conditional transformer; the vocabulary size is taken from the tokenizer
# and one conditioning class is allocated per entry in `label_names`.
model = SketchTransformerConditional(
    vocab_size=len(tokenizer.vocab),
    d_model=512,
    nhead=8,
    num_layers=8,
    max_len=200,
    num_classes=len(label_names),
)

training_config = dict(
    batch_size=128,
    num_epochs=15,
    learning_rate=1e-4,
    log_dir="logs/sketch_transformer_example",
    # presumably train / validation / test fractions (they sum to 1.0) — confirm in runner
    splits=[0.85, 0.1, 0.05],
)

In [None]:
# Wrap the model, data and tokenizer configured above in a trainer and run training.
trainer = SketchTrainer(model, dataset, tokenizer, training_config)
num_epochs = training_config["num_epochs"]
# NOTE(review): "mixed" presumably refers to mixed-precision training — confirm in runner.SketchTrainer.
trainer.train_mixed(num_epochs)

## Sample from the trained model

In [None]:
from IPython.display import HTML, display
from prepare_data import stroke_to_bezier_single, clean_svg

# Draw 5 samples for every trained class and render them side by side.
# `generations` keeps (token ids, cleaned sketch markup) pairs for later inspection;
# `generations_inline` is the concatenated HTML that gets displayed.
generations = []
sketch_cards = []

for class_id, label_name in enumerate(label_names):
    # class_id is the label's position in `label_names`, used as the conditioning class.
    for _ in range(5):
        token_ids = sample(
            model=trainer.model,
            start_tokens=[trainer.tokenizer.vocab["START"]],
            temperature=0.8,
            top_k=20,
            top_p=0.7,
            greedy=False,
            eos_id=trainer.tokenizer.vocab["END"],
            class_label=class_id,
        )
        # Decode the tokens, run them through stroke_to_bezier_single, then clean_svg.
        sketch_markup = clean_svg(
            stroke_to_bezier_single(
                trainer.tokenizer.decode(token_ids, stroke_width=0.3)
            )
        )
        sketch_cards.append(
            f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated {label_name}</b><br>{sketch_markup}</div>'
        )
        generations.append((token_ids, sketch_markup))

generations_inline = "".join(sketch_cards)
display(HTML(generations_inline))

# Using pre-trained models

In [None]:
%pip install gdown

# Fetch pre-trained checkpoints with gdown into `_site/`; the next cell loads the
# medium checkpoint from that path.
# NOTE(review): presumably re-downloads on every run, and assumes gdown can write
# to `_site/` (confirm it creates the directory if missing).

# Small checkpoint: trained with a different tokenizer than the
# DeltaPenPositionTokenizer(bins=32) used by the sampling cell below, so it
# presumably needs a matching tokenizer before it can be sampled.
# !gdown 19HdVCrTVS2E7z5cR6BUK8DO-6qgghn1R  -O _site/model_checkpoint_small.pt

# Medium checkpoint (used below).
!gdown 1AnsoHUmOKF1op5vemKKORdj2QVPYMnz5 -O _site/model_checkpoint_medium.pt

In [2]:
# Sample using an existing checkpoint.
# The file holds a whole pickled model object (it is passed straight to `sample`
# below), which is why weights_only=False is required.
# SECURITY: weights_only=False unpickles arbitrary Python objects and can execute
# code — only load checkpoints from a trusted source.
model = torch.load(
    "_site/model_checkpoint_medium.pt",
    map_location=device,
    weights_only=False,
)
# A pickled module keeps the train/eval mode it was saved in; switch to eval mode
# so layers like dropout behave as at inference time. Harmless if `sample`
# already does this. Trailing `;` suppresses the module repr in the cell output.
model.eval();

In [3]:
from IPython.display import HTML, display
from sketch_tokenizers import DeltaPenPositionTokenizer

from prepare_data import stroke_to_bezier_single, clean_svg

# Tokenizer for the medium checkpoint (the download cell notes that the small
# checkpoint uses a different one).
tokenizer = DeltaPenPositionTokenizer(bins=32)

# NOTE(review): each label's list index is passed as `class_label`, so this order
# must match the label order the checkpoint was trained with — confirm.
labels = ["cake", "butterfly", "flower", "mug", "sea turtle"]

# Draw 4 samples per label; keep (token ids, cleaned sketch markup) pairs and
# render all samples side by side as HTML.
generations = []
sketch_cards = []

for class_id, label_name in enumerate(labels):
    for _ in range(4):
        token_ids = sample(
            model=model,
            start_tokens=[tokenizer.vocab["START"]],
            temperature=0.8,
            top_k=20,
            top_p=0.7,
            greedy=False,
            eos_id=tokenizer.vocab["END"],
            class_label=class_id,
        )
        # Decode the tokens, run them through stroke_to_bezier_single, then clean_svg.
        sketch_markup = clean_svg(
            stroke_to_bezier_single(
                tokenizer.decode(token_ids, stroke_width=0.3)
            )
        )
        sketch_cards.append(
            f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated {label_name}</b><br>{sketch_markup}</div>'
        )
        generations.append((token_ids, sketch_markup))

generations_inline = "".join(sketch_cards)
display(HTML(generations_inline))