# Experiment 2: Class-conditional sketch transformer

- Our base sketch transformer architecture + class embedding layer
- Multi-class Conditional Model
- DeltaPenPositionTokenizer

In [26]:
from experiment_dir import set_cwd_project_root

# Must run before the project imports in the next cell. Judging by its name,
# this switches the working directory to the project root so that those
# imports and relative paths resolve — TODO confirm against experiment_dir.
set_cwd_project_root()

In [None]:
from dataset import QuickDrawDataset
from sketch_tokenizers import DeltaPenPositionTokenizer
from models import SketchTransformerConditional
from runner import SketchTrainer, sample
from prepare_data import stroke_to_bezier_single, clean_svg

# Classes the model is conditioned on. The sampling cell refers to a class by
# its index in this list.
label_names = ["bird", "crab", "guitar"]

# Data source and tokenizer (construction order kept: dataset first).
dataset = QuickDrawDataset(download=True, label_names=label_names)
tokenizer = DeltaPenPositionTokenizer(bins=32)

# Settings handed to SketchTrainer. "splits" is presumably the
# train/val/test fractions — confirm against the trainer.
training_config = {
    "batch_size": 128,
    "num_epochs": 15,
    "learning_rate": 1e-4,
    "log_dir": "logs/sketch_transformer_experiment_2",
    "splits": [0.85, 0.1, 0.05],
    # "use_padding_mask": True,
}

# Conditional transformer; vocabulary size and class count are derived from
# the tokenizer and the label list rather than hard-coded.
model = SketchTransformerConditional(
    num_classes=len(label_names),
    vocab_size=len(tokenizer.vocab),
    max_len=200,
    num_layers=8,
    nhead=8,
    d_model=512,
)

trainer = SketchTrainer(model, dataset, tokenizer, training_config)

Downloading QuickDrawDataset files: 100%|██████████| 3/3 [00:04<00:00,  1.60s/it]
Loading QuickDrawDataset: 3it [00:12,  4.01s/it]
Tokenizing dataset: 100%|██████████| 3/3 [01:44<00:00, 34.94s/it]


No checkpoints found, starting from scratch.


Initial Eval: 100%|██████████| 258/258 [00:01<00:00, 196.73it/s]


In [6]:
# Train for the number of epochs configured in training_config["num_epochs"].
# NOTE(review): the saved output for this cell ends in a KeyboardInterrupt
# during epoch 12 of 15, so later cells may be using a model that was trained
# for fewer epochs than configured.
trainer.train_mixed(training_config["num_epochs"])

Epoch 1/15 [train]: 100%|██████████| 2191/2191 [05:41<00:00,  6.41it/s]
Epoch 1/15 [val]: 100%|██████████| 258/258 [00:12<00:00, 21.10it/s]


Epoch 1 | Train Loss: 1.2621 | Val Loss: 1.1164


Epoch 2/15 [train]: 100%|██████████| 2191/2191 [05:35<00:00,  6.53it/s]
Epoch 2/15 [val]: 100%|██████████| 258/258 [00:12<00:00, 21.34it/s]


Epoch 2 | Train Loss: 1.1201 | Val Loss: 1.0720


Epoch 3/15 [train]: 100%|██████████| 2191/2191 [05:33<00:00,  6.56it/s]
Epoch 3/15 [val]: 100%|██████████| 258/258 [00:12<00:00, 21.50it/s]


Epoch 3 | Train Loss: 1.0758 | Val Loss: 1.0467


Epoch 4/15 [train]: 100%|██████████| 2191/2191 [05:37<00:00,  6.49it/s]
Epoch 4/15 [val]: 100%|██████████| 258/258 [00:12<00:00, 21.30it/s]


Epoch 4 | Train Loss: 1.0523 | Val Loss: 1.0276


Epoch 5/15 [train]: 100%|██████████| 2191/2191 [05:36<00:00,  6.51it/s]
Epoch 5/15 [val]: 100%|██████████| 258/258 [00:12<00:00, 20.94it/s]


Epoch 5 | Train Loss: 1.0352 | Val Loss: 1.0136


Epoch 6/15 [train]: 100%|██████████| 2191/2191 [05:41<00:00,  6.42it/s]
Epoch 6/15 [val]: 100%|██████████| 258/258 [00:12<00:00, 20.83it/s]


Epoch 6 | Train Loss: 1.0225 | Val Loss: 1.0045


Epoch 7/15 [train]: 100%|██████████| 2191/2191 [05:39<00:00,  6.44it/s]
Epoch 7/15 [val]: 100%|██████████| 258/258 [00:12<00:00, 21.15it/s]


Epoch 7 | Train Loss: 1.0127 | Val Loss: 0.9969


Epoch 8/15 [train]: 100%|██████████| 2191/2191 [05:35<00:00,  6.52it/s]
Epoch 8/15 [val]: 100%|██████████| 258/258 [00:12<00:00, 21.45it/s]


Epoch 8 | Train Loss: 1.0047 | Val Loss: 0.9930


Epoch 9/15 [train]: 100%|██████████| 2191/2191 [05:38<00:00,  6.47it/s]
Epoch 9/15 [val]: 100%|██████████| 258/258 [00:12<00:00, 20.88it/s]


Epoch 9 | Train Loss: 0.9982 | Val Loss: 0.9874


Epoch 10/15 [train]: 100%|██████████| 2191/2191 [05:35<00:00,  6.54it/s]
Epoch 10/15 [val]: 100%|██████████| 258/258 [00:12<00:00, 21.48it/s]


Epoch 10 | Train Loss: 0.9926 | Val Loss: 0.9826


Epoch 11/15 [train]: 100%|██████████| 2191/2191 [05:39<00:00,  6.46it/s]
Epoch 11/15 [val]: 100%|██████████| 258/258 [00:12<00:00, 21.02it/s]


Epoch 11 | Train Loss: 0.9878 | Val Loss: 0.9794


Epoch 12/15 [train]:   7%|▋         | 146/2191 [00:22<05:16,  6.45it/s]


KeyboardInterrupt: 

In [None]:
from IPython.display import HTML, display

# Sample a few sketches per class from the trained model and show them inline.

# How many sketches to draw for each class label.
SAMPLES_PER_CLASS = 5

# Decoding knobs forwarded unchanged to `sample`.
SAMPLING_KWARGS = {
    "temperature": 0.8,
    "top_k": 20,
    "top_p": 0.7,
    "greedy": False,
}

# Inline CSS for the small card that wraps each rendered sketch.
CARD_STYLE = "display:inline-block; width: 150px; background-color: white; margin-right:10px;"

# Use the trainer's tokenizer for both the special-token ids and decoding, so
# the two cannot drift apart if the global `tokenizer` is rebound in the kernel.
# Assumes trainer.tokenizer exposes `decode` like the tokenizer passed to
# SketchTrainer — TODO confirm.
sketch_tokenizer = trainer.tokenizer
start_id = sketch_tokenizer.vocab["START"]
end_id = sketch_tokenizer.vocab["END"]

generations = []  # one (sampled tokens, cleaned decoded sketch) pair per sample
html_cards = []   # one HTML card per sample, joined once after the loop

for class_idx, label_name in enumerate(label_names):
    for sample_idx in range(SAMPLES_PER_CLASS):
        generated = sample(
            model=trainer.model,
            start_tokens=[start_id],  # fresh list per call, as in the original
            eos_id=end_id,
            class_label=class_idx,  # class is identified by its index in label_names
            **SAMPLING_KWARGS,
        )

        # Decode the sampled tokens, then post-process with the project helpers.
        # The result is embedded as raw markup below (presumably SVG — confirm).
        decoded_sketch = sketch_tokenizer.decode(generated, stroke_width=0.3)
        decoded_sketch = stroke_to_bezier_single(decoded_sketch)
        decoded_sketch = clean_svg(decoded_sketch)

        html_cards.append(
            f'<div style="{CARD_STYLE}"><b>Generated {label_name}</b><br>{decoded_sketch}</div>'
        )
        generations.append((generated, decoded_sketch))
        # To keep a sketch on disk, write `decoded_sketch` to a file here
        # (e.g. named from label_name and sample_idx).

generations_inline = "".join(html_cards)

display(HTML(generations_inline))