# Experiment 1:

- Our base sketch transformer architecture
- Single class
- DeltaPenPositionTokenizer

__additional__
- Post processing (bezier inference)
- Sketch completion
- Perplexity, curve count heuristic

In [17]:
from experiment_dir import set_cwd_project_root

# Presumably switches the working directory to the project root (going by the
# helper's name) so that the project-local imports in the next cell resolve.
# TODO confirm against experiment_dir.
set_cwd_project_root()

In [None]:
import torch
import torch.nn.functional as F
from dataset import QuickDrawDataset
from sketch_tokenizers import DeltaPenPositionTokenizer
from models import SketchTransformer
from runner import SketchTrainer, sample, device
from prepare_data import stroke_to_bezier_single, clean_svg

# Seed PyTorch's global RNG before any model/trainer object is built so that
# whatever draws from it afterwards (presumably weight init, shuffling and
# sampling -- confirm in models/runner) is repeatable on "Restart & Run All".
# NOTE(review): Python's `random`, NumPy and non-deterministic CUDA kernels are
# not covered by this seed.
SEED = 0
torch.manual_seed(SEED)

# Single-class experiment: only the "cat" label is loaded.
dataset = QuickDrawDataset(label_names=["cat"], download=True)
tokenizer = DeltaPenPositionTokenizer(bins=32)

# Rough intuition for the architecture knobs:
#   d_model    => model capacity (types of drawing features it can learn)
#   nhead      => model can attend to more positions in parallel
#   num_layers => model learns more hierarchical abstractions (patterns, shapes, layouts)
model = SketchTransformer(
    vocab_size=len(tokenizer.vocab),  # vocabulary size follows the tokenizer
    d_model=512,
    nhead=16,
    num_layers=6,
    max_len=200,
)

# Everything the trainer is configured with lives here; the training cell reads
# "num_epochs" back from this dict.
training_config = {
    "batch_size": 128,
    "num_epochs": 15,
    "learning_rate": 1e-4,
    "log_dir": "logs/sketch_transformer_experiment_1_large2",
    "splits": [0.85, 0.1, 0.05],
    "use_padding_mask": True,
}

trainer = SketchTrainer(model, dataset, tokenizer, training_config)

Downloading QuickDrawDataset files: 100%|██████████| 1/1 [00:00<00:00, 6114.15it/s]
Loading QuickDrawDataset: 1it [00:00, 169.06it/s]
Tokenizing dataset: 100%|██████████| 1/1 [00:00<00:00, 1407.96it/s]


No checkpoints found, starting from scratch.


Initial Eval: 100%|██████████| 81/81 [00:00<00:00, 199.93it/s]


In [19]:
# Train for the number of epochs configured in training_config["num_epochs"],
# so the schedule stays in sync with the setup cell.
# NOTE(review): "mixed" presumably refers to mixed-precision training -- confirm
# in runner.SketchTrainer.
trainer.train_mixed(training_config["num_epochs"])

Epoch 1/15 [train]: 100%|██████████| 685/685 [01:39<00:00,  6.86it/s]
Epoch 1/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.85it/s]


Epoch 1 | Train Loss: 4.1210 | Val Loss: 3.5910


Epoch 2/15 [train]: 100%|██████████| 685/685 [01:40<00:00,  6.83it/s]
Epoch 2/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.77it/s]


Epoch 2 | Train Loss: 3.5290 | Val Loss: 3.4184


Epoch 3/15 [train]: 100%|██████████| 685/685 [01:40<00:00,  6.81it/s]
Epoch 3/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.77it/s]


Epoch 3 | Train Loss: 3.4062 | Val Loss: 3.3226


Epoch 4/15 [train]: 100%|██████████| 685/685 [01:40<00:00,  6.82it/s]
Epoch 4/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.81it/s]


Epoch 4 | Train Loss: 3.3375 | Val Loss: 3.2684


Epoch 5/15 [train]: 100%|██████████| 685/685 [01:40<00:00,  6.80it/s]
Epoch 5/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.65it/s]


Epoch 5 | Train Loss: 3.2894 | Val Loss: 3.2336


Epoch 6/15 [train]: 100%|██████████| 685/685 [01:40<00:00,  6.80it/s]
Epoch 6/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.61it/s]


Epoch 6 | Train Loss: 3.2529 | Val Loss: 3.2066


Epoch 7/15 [train]: 100%|██████████| 685/685 [01:40<00:00,  6.81it/s]
Epoch 7/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.80it/s]


Epoch 7 | Train Loss: 3.2241 | Val Loss: 3.1860


Epoch 8/15 [train]: 100%|██████████| 685/685 [01:40<00:00,  6.81it/s]
Epoch 8/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.59it/s]


Epoch 8 | Train Loss: 3.2001 | Val Loss: 3.1643


Epoch 9/15 [train]: 100%|██████████| 685/685 [01:40<00:00,  6.80it/s]
Epoch 9/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.80it/s]


Epoch 9 | Train Loss: 3.1789 | Val Loss: 3.1559


Epoch 10/15 [train]: 100%|██████████| 685/685 [01:40<00:00,  6.80it/s]
Epoch 10/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.71it/s]


Epoch 10 | Train Loss: 3.1609 | Val Loss: 3.1394


Epoch 11/15 [train]: 100%|██████████| 685/685 [01:40<00:00,  6.79it/s]
Epoch 11/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.61it/s]


Epoch 11 | Train Loss: 3.1451 | Val Loss: 3.1334


Epoch 12/15 [train]: 100%|██████████| 685/685 [01:41<00:00,  6.78it/s]
Epoch 12/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.68it/s]


Epoch 12 | Train Loss: 3.1302 | Val Loss: 3.1226


Epoch 13/15 [train]: 100%|██████████| 685/685 [01:40<00:00,  6.79it/s]
Epoch 13/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.75it/s]


Epoch 13 | Train Loss: 3.1170 | Val Loss: 3.1117


Epoch 14/15 [train]: 100%|██████████| 685/685 [01:40<00:00,  6.78it/s]
Epoch 14/15 [val]: 100%|██████████| 81/81 [00:04<00:00, 19.74it/s]


Epoch 14 | Train Loss: 3.1051 | Val Loss: 3.1070


Epoch 15/15 [train]: 100%|██████████| 685/685 [01:39<00:00,  6.91it/s]
Epoch 15/15 [val]: 100%|██████████| 81/81 [00:03<00:00, 20.52it/s]


Epoch 15 | Train Loss: 3.0936 | Val Loss: 3.1008


In [21]:
from IPython.display import HTML, display

# Unconditional generation: each draw is seeded with just the START token and
# is given the END token id as `eos_id`.
start_id = trainer.tokenizer.vocab["START"]
end_id = trainer.tokenizer.vocab["END"]

generations = []  # (sampled tokens, post-processed sketch markup) per draw
cards = []  # one HTML card per draw, concatenated below

for _draw in range(5):
    generated = sample(
        model=trainer.model,
        start_tokens=[start_id],
        temperature=1.0,
        top_k=60,
        top_p=0.9,
        greedy=False,
        eos_id=end_id,
    )

    # Post-processing chain: decode tokens -> bezier inference -> SVG clean-up.
    decoded_sketch = clean_svg(
        stroke_to_bezier_single(tokenizer.decode(generated, stroke_width=0.3))
    )

    cards.append(
        f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated</b><br>{decoded_sketch}</div>'
    )
    generations.append((generated, decoded_sketch))

generations_inline = "".join(cards)

display(HTML(generations_inline))

# temp<=0.5 fairly deterministic
# temp=0.8, top_k=20, top_p=0.9 more variety but still coherent

# *note: important features are usually preserved, but sketches are disorganized (the number-of-curves heuristic does not work well)*
# temp=1.0, top_k=20, top_p=0.75 more variety, some incoherent sequences

#  *note that a lower temp means less variety; notice that sequences begin to repeat themselves more often*
# temp=0.55, top_k=20, top_p=0.9 good balance
# temp=0.6, top_k=30, top_p=0.9  good balance

In [22]:
# note: many sketches have missing parts or incomplete shapes (step 1: get a base sketch): check the number of paths
# pseudo-heuristic: count the number of curves in the SVG

from prepare_data import count_curves

# Order the sampled sketches from most to fewest curves, using the curve count
# of each sketch's markup (second tuple element) as the sort key.
generations_sorted = sorted(
    generations,
    key=lambda pair: count_curves(pair[1]),
    reverse=True,
)

# Render the ranked sketches side by side.
generations_inline = "".join(
    f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated</b><br>{markup}</div>'
    for _tokens, markup in generations_sorted
)

display(HTML(generations_inline))

## Sketch completion

Visualization of how the model has learned the underlying representation of probable sketches for a given label.

A sketch from the dataset is truncated to the first half of its tokens, and the model samples several completions of it.

In [23]:
# Pick one sketch from the dataset, keep only the first half of its token
# sequence, and let the model continue from that prefix.
selected_sketch = dataset[38946]
selected_tokens = tokenizer.encode(selected_sketch)

# Keep the leading 50% of the tokens as the prompt.
keep_count = len(selected_tokens) // 2
selected_tokens_partial = selected_tokens[:keep_count]
destroyed_sketch = tokenizer.decode(selected_tokens_partial)

# First two panels: the untouched sketch and the truncated prompt.
panels = [
    f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Original</b><br>{selected_sketch}</div>\n'
    f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Partial</b><br>{destroyed_sketch}</div>'
]

# Five independent completions of the same prefix.
for attempt in range(5):
    generated = sample(
        model=trainer.model,
        start_tokens=selected_tokens_partial,
        temperature=1.0,
        greedy=False,
        top_k=20,
        top_p=0.65,
        eos_id=trainer.tokenizer.vocab["END"],
    )

    # Same post-processing chain as for unconditional samples.
    generated_sketch = clean_svg(
        stroke_to_bezier_single(tokenizer.decode(generated, stroke_width=0.3))
    )
    panels.append(
        f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Completed {attempt}</b><br>{generated_sketch}</div>'
    )

comparison_inline = "".join(panels)

display(HTML(comparison_inline))

## Sketch representation

Visualization of how the model has learned the underlying representation of probable sketches for a given label.

Sort items from the dataset based on perplexity.

In [24]:
def compute_perplexity(model, tokens):
    """Return the per-sequence perplexity of `tokens` under `model`.

    The model is switched to eval mode and run without gradient tracking on
    every token except the last; its outputs are scored against the same
    sequence shifted left by one position (next-token prediction).

    Args:
        model: module called as ``model(inputs)`` whose output is transposed
            to (batch, vocab, seq_len) before scoring.
        tokens: 2-D integer tensor of token ids, indexed as [batch, position].

    Returns:
        1-D tensor with one value per batch row: ``exp`` of the mean
        cross-entropy over that row's positions. Every position counts
        equally; nothing is masked out.
    """
    model.eval()
    with torch.no_grad():
        inputs, targets = tokens[:, :-1], tokens[:, 1:]
        logits = model(inputs)

        # cross_entropy wants the class dimension second: (batch, vocab, seq_len)
        per_token_nll = F.cross_entropy(
            logits.permute(0, 2, 1),
            targets,
            reduction="none",
        )

        mean_nll = per_token_nll.mean(dim=1)
        return mean_nll.exp()


from IPython.display import HTML, display

# Score the first 20 dataset sketches: each is encoded, evaluated as a batch of
# one, and stored as (perplexity, decoded sketch).
sketch_perplexities = []
for idx in range(20):
    token_ids = tokenizer.encode(dataset[idx])
    batch = torch.tensor([token_ids], device=device)
    score = compute_perplexity(model, batch).item()
    rendered = tokenizer.decode(token_ids, stroke_width=0.3)
    sketch_perplexities.append((score, rendered))

# Highest perplexity first.
sketch_perplexities = sorted(
    sketch_perplexities, key=lambda entry: entry[0], reverse=True
)

# Alternative ordering (not used): normalise by length, e.g.
#   key=lambda x: x[0] / len(x[1])
# Note that x[1] is the decoded sketch, not the token list.

sketches_inline = "".join(
    f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Perplexity: {perp:.2f}</b><br>{drawing}</div>'
    for perp, drawing in sketch_perplexities
)

display(HTML(sketches_inline))

# Sorting by perplexity does seem to highlight some of the worse sketches