# Sampling from trained models

In [None]:
from experiment_dir import set_cwd_project_root

# Run this cell first. set_cwd_project_root presumably changes the working
# directory to the repository root so that the project-local imports below
# (tokenizers, models, runner, prepare_data) and relative paths such as the
# "logs/..." checkpoint path resolve — confirm in experiment_dir.
set_cwd_project_root()

In [None]:
import torch
from tokenizers import DeltaPenPositionTokenizer, AbsolutePenPositionTokenizer
from models import SketchTransformer, SketchTransformerConditional
from runner import sample
from prepare_data import stroke_to_bezier_single, clean_svg

# Prefer the GPU when one is visible; `device` is later used as the
# checkpoint's map_location and as the sampling device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Fix the RNG seed so stochastic (non-greedy) sampling is reproducible.
seed = 42
torch.manual_seed(seed)
# `device` is a torch.device, and a torch.device never compares equal to a
# plain string, so `device == "cuda"` was always False and this branch never
# ran. Compare the device *type* instead.
if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)

Using device: cuda


In [None]:
# NOTE(review): weights_only=False unpickles arbitrary Python objects. It is
# needed here because the checkpoint appears to hold the whole pickled model
# rather than a state_dict, but only load checkpoints you trust.
checkpoint_path = "logs/sketch_transformer_example2/SketchTransformer_DeltaPenPositionTokenizer-q32_checkpoint14.pth"
model = torch.load(
    checkpoint_path,
    map_location=device,
    weights_only=False,
)
# Switch dropout/normalisation layers to inference behaviour before sampling
# (harmless if runner.sample already does this — not visible from here).
model.eval()

# Must match the tokenizer the checkpoint was trained with; the filename
# indicates DeltaPenPositionTokenizer with 32 quantisation bins ("q32").
tokenizer = DeltaPenPositionTokenizer(bins=32)

# Smoke-test one stochastic sample. Pass `device` explicitly, as the sampling
# loop does, so sample() presumably builds its inputs on the model's device.
generated = sample(
    model=model,
    start_tokens=[tokenizer.vocab["START"]],
    temperature=1.0,
    greedy=False,
    eos_id=tokenizer.vocab["END"],
    device=device,
)

In [None]:
# Import display helpers up front so a missing IPython fails before the
# (comparatively expensive) sampling loop runs.
from IPython.display import HTML, display

# Number of independent sketches to sample and show side by side.
num_samples = 5

# Each entry is (raw generated tokens, cleaned sketch markup) for later inspection.
generations = []
# One HTML card per sample; joined once after the loop instead of growing a
# string with += on every iteration.
cards = []

for i in range(num_samples):
    generated = sample(
        model,
        start_tokens=[tokenizer.vocab["START"]],
        max_len=200,
        greedy=False,
        eos_id=tokenizer.vocab["END"],
        device=device,
    )
    # Tokens -> drawable markup (presumably SVG, given clean_svg below),
    # then post-process the strokes and tidy the markup.
    decoded_sketch = tokenizer.decode(generated, stroke_width=0.3)
    decoded_sketch = stroke_to_bezier_single(decoded_sketch)
    decoded_sketch = clean_svg(decoded_sketch)

    cards.append(
        f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated {i}</b><br>{decoded_sketch}</div>'
    )
    generations.append((generated, decoded_sketch))

generations_inline = "".join(cards)
display(HTML(generations_inline))