# Experiment 0:

- Our base sketch transformer architecture
- Single class
- AbsolutePenPositionTokenizer

In [None]:
import os

# Switch the working directory to the parent of the notebook's folder
# (presumably so the project modules imported in the next cell resolve).
#
# A bare os.chdir("..") is relative to the *current* directory, so every
# re-execution of this cell would climb one level higher. Remember the
# directory seen on the first execution and always resolve the target from it,
# which makes the cell safe to re-run within the same kernel session.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [7]:
import torch
from torch.utils.data import DataLoader, random_split
from dataset import QuickDrawDataset, SketchDataset
from tokenizers import AbsolutePenPositionTokenizer
from models import SketchTransformer
from runner import train_model, sample_sequence_feat

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Fix the RNG seed so that splitting, initialisation and sampling are repeatable.
seed = 42
torch.manual_seed(seed)
# `device` is a torch.device, not a str: comparing it to "cuda" is never true,
# so the CUDA seeding branch was silently skipped. Compare the device *type*.
if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)

Using device: cuda


In [None]:
# Tokenizer configured with 32 bins; it is handed to SketchDataset below and
# reused later for vocabulary lookups and decoding.
tokenizer = AbsolutePenPositionTokenizer(bins=32)

# Single-class experiment: only the "cat" label is requested, with download enabled.
training_data = QuickDrawDataset(
    label_names=["cat"],
    download=True,
)

# Wrap the raw drawings together with the tokenizer, using max_len=200.
dataset = SketchDataset(
    training_data,
    tokenizer,
    max_len=200,
)

Downloading QuickDrawDataset files: 100%|██████████| 1/1 [00:00<00:00, 5329.48it/s]
Loading QuickDrawDataset: 1it [00:00, 484.00it/s]
Tokenizing dataset: 100%|██████████| 1/1 [00:00<00:00, 1288.97it/s]


In [None]:
# Train / validation / test fractions. The test share absorbs the rounding
# remainder so the three sizes always sum to len(dataset).
splits = (0.8, 0.1, 0.1)
train_size = int(splits[0] * len(dataset))
val_size = int(splits[1] * len(dataset))
test_size = len(dataset) - train_size - val_size

# Use a dedicated, explicitly seeded generator for the partition. Relying on the
# global RNG makes the split depend on how many random numbers were drawn
# beforehand, so re-running this cell would reshuffle samples between the
# train and validation sets.
split_generator = torch.Generator().manual_seed(seed)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

batch_size = 256
# Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
pin_memory = device.type == "cuda"
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory
)

# The vocabulary size is needed both by the model and by the training routine.
vocab_size = len(tokenizer.vocab)

model = SketchTransformer(
    vocab_size=vocab_size,
    d_model=256,
    nhead=8,
    num_layers=6,
    max_len=200,
)

# Hyper-parameters passed to train_model alongside the log directory; the values
# are read back from the model and tokenizer objects rather than repeated here.
hparams = dict(
    model_class=type(model).__name__,
    num_layers=model.num_layers,
    d_model=model.d_model,
    nhead=model.nhead,
    max_len=model.max_len,
    tokenizer_class=type(tokenizer).__name__,
    tokenizer_bins=tokenizer.bins,
)

# NOTE(review): the output saved under this cell reports 15 epochs per run, so
# it appears to predate the current epochs=40 setting - re-run to refresh it.
train_model(
    model,
    train_loader,
    val_loader,
    vocab_size=vocab_size,
    epochs=40,
    lr=1e-4,
    device=device,
    hparams=hparams,
    log_dir="logs/sketch_experiments_0",
)

Epoch 1/15 [train]: 100%|██████████| 322/322 [01:07<00:00,  4.74it/s]
Epoch 1/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 14.27it/s]


Epoch 1 | Train Loss: 2.2210 | Val Loss: 2.0867


Epoch 2/15 [train]: 100%|██████████| 322/322 [01:07<00:00,  4.74it/s]
Epoch 2/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 14.25it/s]


Epoch 2 | Train Loss: 1.9830 | Val Loss: 1.7907


Epoch 3/15 [train]: 100%|██████████| 322/322 [01:08<00:00,  4.70it/s]
Epoch 3/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 14.03it/s]


Epoch 3 | Train Loss: 1.6681 | Val Loss: 1.5312


Epoch 4/15 [train]: 100%|██████████| 322/322 [01:09<00:00,  4.67it/s]
Epoch 4/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 13.90it/s]


Epoch 4 | Train Loss: 1.4989 | Val Loss: 1.4109


Epoch 5/15 [train]: 100%|██████████| 322/322 [01:07<00:00,  4.74it/s]
Epoch 5/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 14.33it/s]


Epoch 5 | Train Loss: 1.4140 | Val Loss: 1.3463


Epoch 6/15 [train]: 100%|██████████| 322/322 [01:07<00:00,  4.74it/s]
Epoch 6/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 14.18it/s]


Epoch 6 | Train Loss: 1.3629 | Val Loss: 1.3055


Epoch 7/15 [train]: 100%|██████████| 322/322 [01:08<00:00,  4.69it/s]
Epoch 7/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 14.49it/s]


Epoch 7 | Train Loss: 1.3279 | Val Loss: 1.2754


Epoch 8/15 [train]: 100%|██████████| 322/322 [01:08<00:00,  4.67it/s]
Epoch 8/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 14.20it/s]


Epoch 8 | Train Loss: 1.3017 | Val Loss: 1.2527


Epoch 9/15 [train]: 100%|██████████| 322/322 [01:08<00:00,  4.69it/s]
Epoch 9/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 13.85it/s]


Epoch 9 | Train Loss: 1.2813 | Val Loss: 1.2346


Epoch 10/15 [train]: 100%|██████████| 322/322 [01:08<00:00,  4.70it/s]
Epoch 10/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 14.24it/s]


Epoch 10 | Train Loss: 1.2642 | Val Loss: 1.2203


Epoch 11/15 [train]: 100%|██████████| 322/322 [01:08<00:00,  4.70it/s]
Epoch 11/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 14.15it/s]


Epoch 11 | Train Loss: 1.2497 | Val Loss: 1.2069


Epoch 12/15 [train]: 100%|██████████| 322/322 [01:08<00:00,  4.70it/s]
Epoch 12/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 14.05it/s]


Epoch 12 | Train Loss: 1.2372 | Val Loss: 1.1971


Epoch 13/15 [train]: 100%|██████████| 322/322 [01:08<00:00,  4.68it/s]
Epoch 13/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 13.92it/s]


Epoch 13 | Train Loss: 1.2264 | Val Loss: 1.1878


Epoch 14/15 [train]: 100%|██████████| 322/322 [01:08<00:00,  4.73it/s]
Epoch 14/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 14.01it/s]


Epoch 14 | Train Loss: 1.2168 | Val Loss: 1.1795


Epoch 15/15 [train]: 100%|██████████| 322/322 [01:08<00:00,  4.69it/s]
Epoch 15/15 [val]: 100%|██████████| 41/41 [00:02<00:00, 14.32it/s]


Epoch 15 | Train Loss: 1.2084 | Val Loss: 1.1723


In [None]:
# model = torch.load("sketch_transformer_cat_checkpoint1.pth", map_location=device, weights_only=False)

In [24]:
from IPython.display import HTML, display

# Sample five sequences from the trained model (temperature 1.0, non-greedy),
# decode each one and show the results side by side as inline HTML cards.
cards = []
for _ in range(5):
    token_ids = sample_sequence_feat(
        model,
        start_tokens=[tokenizer.vocab["START"]],
        max_len=200,
        temperature=1.0,
        greedy=False,
        eos_id=tokenizer.vocab["END"],
        device=device,
    )
    # The decoded value is embedded directly into the HTML (presumably SVG markup).
    sketch_markup = tokenizer.decode(token_ids, stroke_width=0.3)
    cards.append(
        f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated</b><br>{sketch_markup}</div>'
    )

generations_inline = "".join(cards)

display(HTML(generations_inline))

In [13]:
# Example generation
#
# A hard-coded sample kept for reference: the token sequence below and the SVG
# it corresponds to. In this sample each coordinate token equals x * 64 + y on
# the 64x64 viewBox (e.g. 1420 -> "22 12", 1816 -> "28 24"), 4096 appears before
# every stroke, and 4098 / 4099 open and close the sequence (presumably the
# START / END tokens).
# NOTE(review): a 64-wide grid does not match the bins=32 tokenizer configured
# earlier in this notebook, so this sample appears to come from a different
# tokenizer setting - confirm before comparing it with fresh generations.

from IPython.display import HTML, display

# Generated token sequence: [4098, 4096, 1420, 1420, 1485, 1807, 1997, 2250, 2565, 2817, 2944, 3012, 3210, 3407, 3540, 3547, 3552, 3173, 2858, 2413, 1710, 1454, 1131, 873, 610, 412, 217, 275, 461, 768, 897, 1228, 1421, 4096, 1816, 1751, 4096, 2264, 2200, 4096, 2142, 1952, 2082, 2081, 2078, 4096, 2016, 2085, 1958, 1702, 1507, 1569, 4096, 2016, 2276, 2406, 2596, 4096, 3033, 4051, 4096, 3100, 3420, 3934, 4096, 3106, 3815, 4096, 857, 24, 4096, 926, 609, 4099]
svg_inline = """<svg viewBox="0 0 64 64"><g stroke-width="0.8">
<path d="M 22 12 L 22 12 L 23 13 L 28 15 L 31 13 L 35 10 L 40 5 L 44 1 L 46 0 L 47 4 L 50 10 L 53 15 L 55 20 L 55 27 L 55 32 L 49 37 L 44 42 L 37 45 L 26 46 L 22 46 L 17 43 L 13 41 L 9 34 L 6 28 L 3 25 L 4 19 L 7 13 L 12 0 L 14 1 L 19 12 L 22 13" stroke="black" fill="none"/>
<path d="M 28 24 L 27 23" stroke="black" fill="none"/>
<path d="M 35 24 L 34 24" stroke="black" fill="none"/>
<path d="M 33 30 L 30 32 L 32 34 L 32 33 L 32 30" stroke="black" fill="none"/>
<path d="M 31 32 L 32 37 L 30 38 L 26 38 L 23 35 L 24 33" stroke="black" fill="none"/>
<path d="M 31 32 L 35 36 L 37 38 L 40 36" stroke="black" fill="none"/>
<path d="M 47 25 L 63 19" stroke="black" fill="none"/>
<path d="M 48 28 L 53 28 L 61 30" stroke="black" fill="none"/>
<path d="M 48 34 L 59 39" stroke="black" fill="none"/>
<path d="M 13 25 L 0 24" stroke="black" fill="none"/>
<path d="M 14 30 L 9 33" stroke="black" fill="none"/>
</g></svg>"""
# The card shows a generated sequence (see the comments above), so label it
# "Generated" - matching the cards produced by the sampling cell - not "Input".
display(
    HTML(
        f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated</b><br>{svg_inline}</div>'
    )
)