# Experiment 0: Baseline

- Our base sketch transformer architecture
- Single class (`cat`)
- `AbsolutePenPositionTokenizer` (`bins=32`)

In [None]:
# Bootstrap cell: run this first. The helper presumably switches the working
# directory to the project root so the project-local modules imported in later
# cells resolve (inferred from its name — confirm in experiment_dir).
from experiment_dir import set_cwd_project_root

set_cwd_project_root()

In [None]:
from dataset import QuickDrawDataset
from sketch_tokenizers import AbsolutePenPositionTokenizer
from models import SketchTransformer
from runner import SketchTrainer, sample

# Single-class experiment: only the "cat" label is requested from QuickDraw.
dataset = QuickDrawDataset(label_names=["cat"], download=True)
# bins=32 is presumably the quantization resolution for pen positions —
# confirm in sketch_tokenizers.
tokenizer = AbsolutePenPositionTokenizer(bins=32)
# The vocabulary size is read from the tokenizer so the model's token
# dimension always matches the tokenizer that produced the training data.
model = SketchTransformer(
    vocab_size=len(tokenizer.vocab), d_model=384, nhead=8, num_layers=8, max_len=200
)

# Hyperparameters handed to SketchTrainer; the training cell also reads
# "num_epochs" from this dict.
training_config = {
    "batch_size": 128,
    "num_epochs": 20,
    "learning_rate": 1e-4,
    "log_dir": "logs/sketch_transformer_experiment_0",
    # Presumably train/val/test fractions — the saved run's 685 train vs. 81
    # val batches is consistent with 0.85/0.10; confirm in runner.
    "splits": [0.85, 0.1, 0.05],
    # Optional padding-mask switch, left disabled for this experiment (the
    # trainer's default applies). Uncomment to enable:
    # "use_padding_mask": True,
}

# Author's notes on what "use_padding_mask": True does:
# - Attention layers: padded tokens are ignored when computing attention weights.
# - Loss function: padded positions are ignored when computing the loss.
# It is more computationally expensive, but can give better results for
# variable-length sequences (like sketches).

trainer = SketchTrainer(model, dataset, tokenizer, training_config)

Downloading QuickDrawDataset files: 100%|██████████| 1/1 [00:00<00:00, 5475.59it/s]
Loading QuickDrawDataset: 1it [00:00, 481.83it/s]
Tokenizing dataset: 100%|██████████| 1/1 [00:00<00:00, 1453.33it/s]


No checkpoints found, starting from scratch.


Initial Eval: 100%|██████████| 81/81 [00:00<00:00, 197.07it/s]


In [None]:
# Train for the configured number of epochs using the trainer's
# mixed-precision entry point. `trainer.train()` is the alternative entry
# point, kept disabled for reference (presumably the non-mixed-precision
# loop — confirm in runner).
# NOTE: the saved output of this cell ends with a manual KeyboardInterrupt
# during epoch 17 of 20, so the saved logs cover only 16 completed epochs.
# trainer.train()

trainer.train_mixed(training_config["num_epochs"])  # mixed precision training

Epoch 1/20 [train]: 100%|██████████| 685/685 [01:13<00:00,  9.37it/s]
Epoch 1/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 32.76it/s]


Epoch 1 | Train Loss: 1.9635 | Val Loss: 1.5408


Epoch 2/20 [train]: 100%|██████████| 685/685 [01:13<00:00,  9.35it/s]
Epoch 2/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 32.47it/s]


Epoch 2 | Train Loss: 1.4311 | Val Loss: 1.3162


Epoch 3/20 [train]: 100%|██████████| 685/685 [01:11<00:00,  9.56it/s]
Epoch 3/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 33.53it/s]


Epoch 3 | Train Loss: 1.3073 | Val Loss: 1.2447


Epoch 4/20 [train]: 100%|██████████| 685/685 [01:11<00:00,  9.52it/s]
Epoch 4/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 33.27it/s]


Epoch 4 | Train Loss: 1.2525 | Val Loss: 1.2019


Epoch 5/20 [train]: 100%|██████████| 685/685 [01:10<00:00,  9.67it/s]
Epoch 5/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 34.11it/s]


Epoch 5 | Train Loss: 1.2183 | Val Loss: 1.1760


Epoch 6/20 [train]: 100%|██████████| 685/685 [01:11<00:00,  9.53it/s]
Epoch 6/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 33.52it/s]


Epoch 6 | Train Loss: 1.1943 | Val Loss: 1.1586


Epoch 7/20 [train]: 100%|██████████| 685/685 [01:11<00:00,  9.65it/s]
Epoch 7/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 34.02it/s]


Epoch 7 | Train Loss: 1.1759 | Val Loss: 1.1472


Epoch 8/20 [train]: 100%|██████████| 685/685 [01:12<00:00,  9.50it/s]
Epoch 8/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 32.17it/s]


Epoch 8 | Train Loss: 1.1614 | Val Loss: 1.1338


Epoch 9/20 [train]: 100%|██████████| 685/685 [01:13<00:00,  9.31it/s]
Epoch 9/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 32.81it/s]


Epoch 9 | Train Loss: 1.1495 | Val Loss: 1.1266


Epoch 10/20 [train]: 100%|██████████| 685/685 [01:13<00:00,  9.30it/s]
Epoch 10/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 32.51it/s]


Epoch 10 | Train Loss: 1.1393 | Val Loss: 1.1181


Epoch 11/20 [train]: 100%|██████████| 685/685 [01:13<00:00,  9.31it/s]
Epoch 11/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 32.71it/s]


Epoch 11 | Train Loss: 1.1301 | Val Loss: 1.1111


Epoch 12/20 [train]: 100%|██████████| 685/685 [01:13<00:00,  9.31it/s]
Epoch 12/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 32.49it/s]


Epoch 12 | Train Loss: 1.1221 | Val Loss: 1.1069


Epoch 13/20 [train]: 100%|██████████| 685/685 [01:13<00:00,  9.32it/s]
Epoch 13/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 32.77it/s]


Epoch 13 | Train Loss: 1.1150 | Val Loss: 1.1017


Epoch 14/20 [train]: 100%|██████████| 685/685 [01:13<00:00,  9.36it/s]
Epoch 14/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 32.00it/s]


Epoch 14 | Train Loss: 1.1087 | Val Loss: 1.0985


Epoch 15/20 [train]: 100%|██████████| 685/685 [01:13<00:00,  9.29it/s]
Epoch 15/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 32.45it/s]


Epoch 15 | Train Loss: 1.1026 | Val Loss: 1.0941


Epoch 16/20 [train]: 100%|██████████| 685/685 [01:13<00:00,  9.33it/s]
Epoch 16/20 [val]: 100%|██████████| 81/81 [00:02<00:00, 32.68it/s]


Epoch 16 | Train Loss: 1.0972 | Val Loss: 1.0925


Epoch 17/20 [train]:  12%|█▏        | 84/685 [00:09<01:04,  9.32it/s]


KeyboardInterrupt: 

In [None]:
# Sample a few sketches from the trainer's current model weights and render
# them side by side as inline HTML cards.
from IPython.display import HTML, display

NUM_SAMPLES = 5  # how many sketches to sample and display

# The START/END token ids do not change between samples, so look them up once.
start_id = trainer.tokenizer.vocab["START"]
end_id = trainer.tokenizer.vocab["END"]

cards = []
for _ in range(NUM_SAMPLES):
    generated = sample(
        model=trainer.model,
        # A fresh list per call, in case `sample` extends its prompt in place.
        start_tokens=[start_id],
        temperature=1.0,
        greedy=False,  # stochastic sampling rather than greedy decoding
        eos_id=end_id,
    )
    # decode() output is interpolated straight into HTML below, so it is
    # expected to be markup (presumably an SVG string — confirm in sketch_tokenizers).
    decoded_sketch = tokenizer.decode(generated, stroke_width=0.3)
    cards.append(
        f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated</b><br>{decoded_sketch}</div>'
    )

# Join once instead of growing a string with += inside the loop.
generations_inline = "".join(cards)

display(HTML(generations_inline))

In [3]:
# Example generation: a previously generated sketch stored as a static SVG so
# it can be displayed without running training or sampling.
from IPython.display import HTML, display

# Token layout, as can be checked against the SVG paths below: a position
# token encodes a point on the 64x64 canvas as x * 64 + y (e.g. 1420 -> "22 12",
# 4051 -> "63 19"); 4096 precedes each of the 11 strokes; 4098 opens and 4099
# closes the sequence.
# NOTE(review): position ids up to 4095, the 64x64 viewBox and stroke-width 0.8
# do not match this notebook's tokenizer setup (bins=32) and sampling cell
# (stroke_width=0.3) — this example presumably comes from a different run; confirm.
# Generated token sequence: [4098, 4096, 1420, 1420, 1485, 1807, 1997, 2250, 2565, 2817, 2944, 3012, 3210, 3407, 3540, 3547, 3552, 3173, 2858, 2413, 1710, 1454, 1131, 873, 610, 412, 217, 275, 461, 768, 897, 1228, 1421, 4096, 1816, 1751, 4096, 2264, 2200, 4096, 2142, 1952, 2082, 2081, 2078, 4096, 2016, 2085, 1958, 1702, 1507, 1569, 4096, 2016, 2276, 2406, 2596, 4096, 3033, 4051, 4096, 3100, 3420, 3934, 4096, 3106, 3815, 4096, 857, 24, 4096, 926, 609, 4099]
svg_inline = """<svg viewBox="0 0 64 64"><g stroke-width="0.8">
<path d="M 22 12 L 22 12 L 23 13 L 28 15 L 31 13 L 35 10 L 40 5 L 44 1 L 46 0 L 47 4 L 50 10 L 53 15 L 55 20 L 55 27 L 55 32 L 49 37 L 44 42 L 37 45 L 26 46 L 22 46 L 17 43 L 13 41 L 9 34 L 6 28 L 3 25 L 4 19 L 7 13 L 12 0 L 14 1 L 19 12 L 22 13" stroke="black" fill="none"/>
<path d="M 28 24 L 27 23" stroke="black" fill="none"/>
<path d="M 35 24 L 34 24" stroke="black" fill="none"/>
<path d="M 33 30 L 30 32 L 32 34 L 32 33 L 32 30" stroke="black" fill="none"/>
<path d="M 31 32 L 32 37 L 30 38 L 26 38 L 23 35 L 24 33" stroke="black" fill="none"/>
<path d="M 31 32 L 35 36 L 37 38 L 40 36" stroke="black" fill="none"/>
<path d="M 47 25 L 63 19" stroke="black" fill="none"/>
<path d="M 48 28 L 53 28 L 61 30" stroke="black" fill="none"/>
<path d="M 48 34 L 59 39" stroke="black" fill="none"/>
<path d="M 13 25 L 0 24" stroke="black" fill="none"/>
<path d="M 14 30 L 9 33" stroke="black" fill="none"/>
</g></svg>"""
# Show the SVG inside the same fixed-width, white-background inline card
# styling that the sampling cell uses for its generated sketches.
display(
    HTML(
        f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><br>{svg_inline}</div>'
    )
)