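# Experiment 4

- Base sketch transformer architecture plus a class-embedding layer for conditioning
- Multi-class, class-conditional model (trained on 4 QuickDraw classes: monkey, peanut, saxophone, pizza)
- Tokenizer: `DeltaPenPositionTokenizer`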

In [2]:
# Per its name, this cell moves the current working directory to the project root.
# It must run before the project-local imports in the next cell.
# NOTE(review): presumably required so that `dataset`, `models`, `runner`, etc.
# are importable and relative paths such as the "logs/..." log_dir resolve from
# the project root — TODO confirm against experiment_dir.py.
from experiment_dir import set_cwd_project_root

set_cwd_project_root()

In [None]:
import math

from dataset import QuickDrawDataset
from sketch_tokenizers import DeltaPenPositionTokenizer
from models import SketchTransformerConditional
from runner import SketchTrainer, sample
from prepare_data import stroke_to_bezier_single, clean_svg

# --- Experiment configuration ------------------------------------------------
# Classes the conditional model is trained on. The sampling cell passes each
# label's position in this list as `class_label`, so keep the order stable.
label_names = ["monkey", "peanut", "saxophone", "pizza"]

TOKENIZER_BINS = 32  # `bins` argument of DeltaPenPositionTokenizer
D_MODEL = 512        # transformer hidden size
N_HEAD = 8           # attention heads
N_LAYERS = 8         # transformer layers
MAX_LEN = 200        # `max_len` passed to the model (maximum sequence length)

training_config = {
    "batch_size": 128,
    "num_epochs": 15,
    "learning_rate": 1e-4,
    "log_dir": "logs/sketch_transformer_experiment_4",
    # Assumed to be train / val / test fractions of the whole dataset.
    "splits": [0.85, 0.075, 0.075],
    # "use_padding_mask" is deliberately left unset for this run, so the
    # trainer's own default applies — TODO confirm that default in runner.py.
}

# Fail fast on inconsistent settings before the slow download/tokenization.
# Presumably the model uses standard multi-head attention, which needs d_model
# to split evenly across heads.
assert D_MODEL % N_HEAD == 0, "d_model must be divisible by nhead"
assert math.isclose(sum(training_config["splits"]), 1.0), "splits should sum to 1"

# --- Data, tokenizer, model, trainer ------------------------------------------
dataset = QuickDrawDataset(label_names=label_names, download=True)
tokenizer = DeltaPenPositionTokenizer(bins=TOKENIZER_BINS)

model = SketchTransformerConditional(
    vocab_size=len(tokenizer.vocab),
    d_model=D_MODEL,
    nhead=N_HEAD,
    num_layers=N_LAYERS,
    max_len=MAX_LEN,
    num_classes=len(label_names),  # one class embedding per label above
)

trainer = SketchTrainer(model, dataset, tokenizer, training_config)

2025-11-12 12:49:53.196686: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-12 12:49:53.427054: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-12 12:49:54.438823: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


Using device: cuda


Downloading QuickDrawDataset files: 100%|██████████| 4/4 [00:07<00:00,  1.98s/it]
Loading QuickDrawDataset: 4it [00:07,  1.90s/it]
Tokenizing dataset: 100%|██████████| 4/4 [01:16<00:00, 19.15s/it]


No checkpoint found, starting fresh training.


Initial Eval: 100%|██████████| 361/361 [00:00<00:00, 490.80it/s]


In [2]:
# Train for the number of epochs configured in the setup cell.
# NOTE(review): "mixed" presumably refers to mixed-precision training — confirm in runner.py.
epochs_to_run = training_config["num_epochs"]
trainer.train_mixed(epochs_to_run)

Epoch 1/15 [train]: 100%|██████████| 3067/3067 [06:34<00:00,  7.77it/s]
Epoch 1/15 [val]: 100%|██████████| 361/361 [00:13<00:00, 25.95it/s]


Epoch 1 | Train Loss: 1.1123 | Val Loss: 0.9958


Epoch 2/15 [train]: 100%|██████████| 3067/3067 [06:37<00:00,  7.72it/s]
Epoch 2/15 [val]: 100%|██████████| 361/361 [00:13<00:00, 26.06it/s]


Epoch 2 | Train Loss: 0.9841 | Val Loss: 0.9584


Epoch 3/15 [train]: 100%|██████████| 3067/3067 [06:16<00:00,  8.15it/s]
Epoch 3/15 [val]: 100%|██████████| 361/361 [00:12<00:00, 27.95it/s]


Epoch 3 | Train Loss: 0.9577 | Val Loss: 0.9396


Epoch 4/15 [train]: 100%|██████████| 3067/3067 [06:09<00:00,  8.30it/s]
Epoch 4/15 [val]: 100%|██████████| 361/361 [00:12<00:00, 27.95it/s]


Epoch 4 | Train Loss: 0.9426 | Val Loss: 0.9278


Epoch 5/15 [train]: 100%|██████████| 3067/3067 [06:09<00:00,  8.30it/s]
Epoch 5/15 [val]: 100%|██████████| 361/361 [00:12<00:00, 27.95it/s]


Epoch 5 | Train Loss: 0.9325 | Val Loss: 0.9211


Epoch 6/15 [train]: 100%|██████████| 3067/3067 [06:09<00:00,  8.30it/s]
Epoch 6/15 [val]: 100%|██████████| 361/361 [00:12<00:00, 27.96it/s]


Epoch 6 | Train Loss: 0.9250 | Val Loss: 0.9173


Epoch 7/15 [train]: 100%|██████████| 3067/3067 [06:09<00:00,  8.30it/s]
Epoch 7/15 [val]: 100%|██████████| 361/361 [00:12<00:00, 27.96it/s]


Epoch 7 | Train Loss: 0.9196 | Val Loss: 0.9114


Epoch 8/15 [train]: 100%|██████████| 3067/3067 [06:09<00:00,  8.30it/s]
Epoch 8/15 [val]: 100%|██████████| 361/361 [00:12<00:00, 27.95it/s]


Epoch 8 | Train Loss: 0.9145 | Val Loss: 0.9081


Epoch 9/15 [train]: 100%|██████████| 3067/3067 [06:09<00:00,  8.30it/s]
Epoch 9/15 [val]: 100%|██████████| 361/361 [00:12<00:00, 27.96it/s]


Epoch 9 | Train Loss: 0.9105 | Val Loss: 0.9045


Epoch 10/15 [train]: 100%|██████████| 3067/3067 [06:15<00:00,  8.16it/s]
Epoch 10/15 [val]: 100%|██████████| 361/361 [00:13<00:00, 26.18it/s]


Epoch 10 | Train Loss: 0.9071 | Val Loss: 0.9037


Epoch 11/15 [train]: 100%|██████████| 3067/3067 [06:32<00:00,  7.82it/s]
Epoch 11/15 [val]: 100%|██████████| 361/361 [00:13<00:00, 26.34it/s]


Epoch 11 | Train Loss: 0.9038 | Val Loss: 0.9000


Epoch 12/15 [train]: 100%|██████████| 3067/3067 [06:24<00:00,  7.98it/s]
Epoch 12/15 [val]: 100%|██████████| 361/361 [00:13<00:00, 26.07it/s]


Epoch 12 | Train Loss: 0.9011 | Val Loss: 0.8993


Epoch 13/15 [train]: 100%|██████████| 3067/3067 [06:33<00:00,  7.80it/s]
Epoch 13/15 [val]: 100%|██████████| 361/361 [00:13<00:00, 26.05it/s]


Epoch 13 | Train Loss: 0.8984 | Val Loss: 0.8963


Epoch 14/15 [train]: 100%|██████████| 3067/3067 [06:31<00:00,  7.83it/s]
Epoch 14/15 [val]: 100%|██████████| 361/361 [00:13<00:00, 26.08it/s]


Epoch 14 | Train Loss: 0.8961 | Val Loss: 0.8950


Epoch 15/15 [train]: 100%|██████████| 3067/3067 [06:31<00:00,  7.82it/s]
Epoch 15/15 [val]: 100%|██████████| 361/361 [00:13<00:00, 26.34it/s]


Epoch 15 | Train Loss: 0.8938 | Val Loss: 0.8954


In [3]:
from IPython.display import HTML, display

# Draw a few sketches per class from the trained conditional model and render
# them side by side as an inline HTML gallery.
SAMPLES_PER_CLASS = 5
SAMPLING_KWARGS = dict(temperature=0.8, top_k=20, top_p=0.7, greedy=False)
STROKE_WIDTH = 0.3

# Look up the special-token ids once. Going through trainer.tokenizer for both
# the vocab lookups and decoding keeps encode/decode on the same object.
start_id = trainer.tokenizer.vocab["START"]
end_id = trainer.tokenizer.vocab["END"]

generations = []  # (generated tokens, cleaned SVG) per sample, in class order
cards = []        # one HTML snippet per sample

# NOTE: no random seed is set. Sampling is stochastic (greedy=False with
# temperature / top-k / top-p), so re-running this cell yields different sketches.
# Class ids are assumed to follow the order of label_names — TODO confirm
# against QuickDrawDataset's label mapping.
for class_idx, label_name in enumerate(label_names):
    for _ in range(SAMPLES_PER_CLASS):
        generated = sample(
            model=trainer.model,
            start_tokens=[start_id],
            eos_id=end_id,
            class_label=class_idx,
            **SAMPLING_KWARGS,
        )
        decoded_sketch = trainer.tokenizer.decode(generated, stroke_width=STROKE_WIDTH)
        decoded_sketch = clean_svg(stroke_to_bezier_single(decoded_sketch))

        cards.append(
            '<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;">'
            f"<b>Generated {label_name}</b><br>{decoded_sketch}</div>"
        )
        generations.append((generated, decoded_sketch))

generations_inline = "".join(cards)
display(HTML(generations_inline))