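# Initial experiments

Train a small causal Transformer on tokenized QuickDraw "cat" sketches, then sample new sketches from it.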

In [3]:
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import HTML, display
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import QuickDrawDataset
from utils import AbsolutePenPositionTokenizer

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

seed = 42
torch.manual_seed(seed)
# `device` is a torch.device, so compare its .type; comparing the device
# object itself to the str "cuda" is never True.
if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)

Using device: cuda


In [5]:
# QuickDraw categories to train on.
labels = ["cat"]
training_data = QuickDrawDataset(labels=labels)

# Quantization resolution for pen positions (decoded sketches use a 64x64 viewBox).
tokenizer = AbsolutePenPositionTokenizer(bins=64)


class SketchDataset(Dataset):
    """Fixed-length token sequences for next-token prediction.

    Each item of ``svg_list`` is encoded with ``tokenizer.encode``, truncated to
    ``max_len`` ids and right-padded with ``tokenizer.vocab["PAD"]``. If the
    tokenizer appends an end token, truncation drops it for long sketches.
    ``__getitem__`` returns the shifted pair ``(seq[:-1], seq[1:])``.

    The tokenized sequences are cached as a pickle in ``cache_file``. A cache
    whose sequences are not exactly ``max_len`` long is ignored and rebuilt.
    The cache does NOT notice changes to ``svg_list`` or the tokenizer, so use a
    different ``cache_file`` (or delete it) after changing labels or tokenizer
    settings. Pass ``cache_file=None`` to disable caching.
    """

    def __init__(
        self,
        svg_list,
        tokenizer,
        max_len=200,
        cache_file="sketch_tokenized_dataset.pkl",
    ):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.pad_id = tokenizer.vocab["PAD"]

        cached = self._load_cache(cache_file)
        if cached is not None:
            self.data = cached
            print(f"Loaded tokenized data from {cache_file}")
            return

        self.data = [
            self._encode(svg) for svg in tqdm(svg_list, desc="Tokenizing SVGs")
        ]
        if cache_file is not None:
            with open(cache_file, "wb") as f:
                pickle.dump(self.data, f)
            print(f"Saved tokenized data to {cache_file}")

    def _encode(self, svg):
        """Tokenize one SVG and truncate / right-pad it to exactly ``max_len`` ids."""
        tokens = list(self.tokenizer.encode(svg))[: self.max_len]
        return tokens + [self.pad_id] * (self.max_len - len(tokens))

    def _load_cache(self, cache_file):
        """Return the cached sequences, or None when there is no usable cache."""
        if cache_file is None:
            return None
        try:
            # NOTE: pickle.load can execute arbitrary code; only point
            # cache_file at files this notebook wrote itself.
            with open(cache_file, "rb") as f:
                data = pickle.load(f)
        except FileNotFoundError:
            return None
        # A cache built with a different max_len would silently produce
        # sequences of the wrong length (possibly longer than the model's
        # positional table), so rebuild instead of reusing it.
        if not isinstance(data, list) or any(len(seq) != self.max_len for seq in data):
            print(
                f"Ignoring {cache_file}: cached sequences do not match "
                f"max_len={self.max_len}"
            )
            return None
        return data

    def __getitem__(self, idx):
        seq = self.data[idx]
        # Shift by one: the model predicts token t+1 from tokens <= t.
        input_ids = torch.tensor(seq[:-1], dtype=torch.long)
        target_ids = torch.tensor(seq[1:], dtype=torch.long)
        return input_ids, target_ids

    def __len__(self):
        return len(self.data)


# Tokenize the training sketches (or reuse the cached tokenization).
dataset = SketchDataset(svg_list=training_data, tokenizer=tokenizer, max_len=200)

Loading QuickDraw files: 100%|██████████| 1/1 [00:03<00:00,  3.15s/it]


Loaded tokenized data from sketch_tokenized_dataset.pkl


In [6]:
def generate_square_subsequent_mask(sz: int):
    """Boolean causal attention mask of shape (sz, sz).

    Entry [i, j] is True exactly when j > i, i.e. for the future positions a
    query at position i must not attend to (PyTorch treats True in a boolean
    attention mask as "masked out"). The mask is created on the CPU.
    """
    query_pos = torch.arange(sz).unsqueeze(1)  # column vector (sz, 1)
    key_pos = torch.arange(sz).unsqueeze(0)  # row vector (1, sz)
    return key_pos > query_pos


class SketchTransformer(nn.Module):
    """Causal Transformer language model over sketch tokens.

    Token embeddings plus learned positional embeddings feed a stack of
    ``nn.TransformerEncoderLayer``s run under a causal mask, so position t only
    attends to positions <= t. A linear head maps each position to logits over
    the vocabulary.

    Args:
        vocab_size: number of token ids (embedding rows and output classes).
        d_model: embedding / hidden width.
        nhead: attention heads per layer (must divide ``d_model``).
        num_layers: number of encoder layers.
        max_len: size of the positional-embedding table, i.e. the longest
            sequence ``forward`` accepts.
    """

    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=6, max_len=200):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len

        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)

        # Layers use PyTorch's default sequence-first layout (batch_first=False),
        # hence the transposes in forward(). Kept as-is so pickled checkpoints of
        # this class still load and run. (activation='gelu' is an untried variant.)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: (batch, seq_len) integer token ids with ``seq_len <= max_len``.

        Returns:
            (batch, seq_len, vocab_size) logits; position t depends only on
            ``x[:, :t + 1]``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        _, seq_len = x.shape
        if seq_len > self.max_len:
            # Fail clearly instead of with an opaque embedding index error
            # (a device-side assert on CUDA).
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.max_len}"
            )
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)  # (1, seq_len)

        h = self.embed(x) + self.pos_embed(positions)  # (batch, seq_len, d_model)
        h = h.transpose(0, 1)  # (seq_len, batch, d_model)

        causal_mask = generate_square_subsequent_mask(seq_len).to(x.device)
        h = self.transformer(h, mask=causal_mask)  # (seq_len, batch, d_model)
        h = h.transpose(0, 1)  # (batch, seq_len, d_model)

        return self.fc_out(h)  # (batch, seq_len, vocab_size)


def train_model(
    model, dataloader, vocab_size, epochs=10, lr=1e-4, device="cuda", pad_id=None
):
    """Train `model` on next-token prediction with AdamW and cross-entropy.

    Args:
        model: module mapping (batch, seq_len) token ids to
            (batch, seq_len, vocab_size) logits, e.g. SketchTransformer.
        dataloader: yields (input_ids, target_ids) batches of equal shape.
        vocab_size: number of logit classes; used to flatten logits for the loss.
        epochs: number of passes over `dataloader`.
        lr: AdamW learning rate.
        device: device the model and each batch are moved to.
        pad_id: target id excluded from the loss. Defaults to
            ``dataloader.dataset.pad_id`` when the dataset exposes one
            (SketchDataset sets it from the tokenizer's "PAD" entry), else 0.

    The model is moved to `device` and trained in place; the mean per-batch
    loss is printed after every epoch.
    """
    # Ignore padding by its actual id. Small ids such as 0 appear to be
    # pen-position tokens, so a hard-coded 0 would drop real targets while
    # still training the model to emit padding.
    if pad_id is None:
        pad_id = getattr(getattr(dataloader, "dataset", None), "pad_id", 0)

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for input_ids, target_ids in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)

            logits = model(input_ids)  # (batch, seq_len, vocab_size)
            # reshape (not view): the logits may be non-contiguous after the
            # model's internal transposes.
            loss = criterion(logits.reshape(-1, vocab_size), target_ids.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        print(f"Epoch {epoch+1} Loss: {total_loss / max(num_batches, 1):.4f}")


# Training is expensive, so it is off by default and the following cells load a
# saved checkpoint instead. Set TRAIN = True to train this model from scratch.
# (To resume instead: model = torch.load("sketch_transformer_cat_checkpoint0.pth", weights_only=False))
TRAIN = False

dataloader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    # Pinned host memory only speeds up host-to-GPU copies; on CPU it just warns.
    pin_memory=device.type == "cuda",
)
model = SketchTransformer(
    vocab_size=len(tokenizer.vocab), d_model=256, nhead=8, num_layers=6
)

if TRAIN:
    train_model(
        model,
        dataloader,
        vocab_size=len(tokenizer.vocab),
        epochs=40,
        lr=1e-4,
        device=device,
    )



In [7]:
# save the model
# torch.save(model, "sketch_transformer_cat_checkpoint1.pth")

In [8]:
# Restore the whole pickled module saved with torch.save(model, ...).
# NOTE: weights_only=False unpickles arbitrary objects, so only load trusted
# checkpoints. Unpickling looks up SketchTransformer by name, so the cell
# defining it must have been run first.
model = torch.load(
    "sketch_transformer_cat_checkpoint1.pth",
    map_location=device,
    weights_only=False,
)


def sample_sequence(
    model,
    start_token,
    max_len=200,
    temperature=1.0,
    greedy=False,
    eos_id=None,
    device="cuda",
):
    """Autoregressively generate tokens from `model` with temperature sampling.

    The full prefix is re-fed to the model at every step (no KV cache).

    Args:
        model: module mapping (1, t) token ids to (1, t, vocab_size) logits.
        start_token: id of the first token.
        max_len: maximum length of the returned list, including `start_token`.
        temperature: softmax temperature for sampling; must be > 0 unless
            `greedy` (argmax does not depend on a positive temperature).
        greedy: take the most likely token instead of sampling.
        eos_id: stop right after this id is produced (it is included).
        device: device for the token tensors.

    Returns:
        list[int] of generated ids, starting with `start_token`.

    Raises:
        ValueError: if sampling (``greedy=False``) with ``temperature <= 0``.
    """
    if not greedy and temperature <= 0:
        raise ValueError(f"temperature must be > 0 when sampling, got {temperature}")

    model.eval()
    tokens = [start_token]
    tokens_tensor = torch.tensor([tokens], device=device)  # (1, 1)

    with torch.no_grad():
        for _ in range(max_len - 1):
            next_logits = model(tokens_tensor)[:, -1, :]  # (1, vocab_size), last step

            if greedy:
                next_token = torch.argmax(next_logits, dim=-1).item()
            else:
                probs = F.softmax(next_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()

            tokens.append(next_token)
            if eos_id is not None and next_token == eos_id:
                break

            next_token_tensor = torch.tensor([[next_token]], device=device)
            tokens_tensor = torch.cat([tokens_tensor, next_token_tensor], dim=1)

    return tokens


# Plain temperature sampling; top-k / top-p filtering is tried in a later cell.
start_token = tokenizer.vocab["START"]
eos_token = tokenizer.vocab.get("END", None)

generated = sample_sequence(
    model,
    start_token,
    max_len=200,
    temperature=0.5,
    greedy=False,
    eos_id=eos_token,
    device=device,
)

print("Generated token sequence:", generated)
decoded_sketch = tokenizer.decode(generated)
print("Decoded sketch:", decoded_sketch)

# Render the decoded SVG inline.
display(
    HTML(
        f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated</b><br>{decoded_sketch}</div>'
    )
)

Generated token sequence: [4098, 4096, 345, 400, 326, 452, 517, 711, 1107, 4096, 660, 724, 4096, 2385, 2446, 2816, 3020, 4096, 2512, 2318, 2126, 1934, 1488, 4096, 280, 224, 295, 365, 494, 4096, 3216, 3218, 4096, 2318, 2383, 4096, 2575, 2576, 4096, 2383, 2384, 4096, 1635, 1763, 4096, 1443, 1317, 4096, 1442, 1444, 4096, 1634, 1698, 4096, 2271, 2336, 4096, 2527, 2462, 4096, 2340, 2532, 4096, 1962, 2155, 4096, 873, 550, 4096, 877, 686, 4096, 878, 623, 4096, 943, 626, 4096, 2848, 3040, 4096, 2853, 3047, 4096, 2987, 3052, 4096, 2740, 2996, 4096, 1065, 873, 4096, 879, 687, 4096, 1070, 1071, 4096, 1011, 1077, 4096, 1011, 1079, 4096, 1011, 4096, 1140, 4096, 1271, 4096, 1076, 4096, 1077, 563, 4096, 1464, 1081, 4096, 1781, 4096, 1912, 4096, 2426, 4096, 1909, 4096, 1974, 4096, 2234, 4096, 3648, 4096, 1849, 4096, 2045, 4096, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 4097, 40

In [11]:
def top_p_filtering(logits, p=0.9):
    """Nucleus (top-p) filtering along the last dimension.

    Within each row, keeps the highest-probability tokens up to and including
    the first one at which the cumulative probability exceeds `p`, and sets
    all other logits to -inf. The most likely token is always kept. Works on
    1-D or batched (..., vocab) input, filtering every row independently.

    Returns a new tensor; `logits` is not modified.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Drop a token once the mass *before* it already exceeds p: shift the
    # mask right by one so the token that crosses p is kept.
    sorted_to_remove = cumulative_probs > p
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False

    # Map the mask from sorted order back to vocabulary order, row by row.
    to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
    return logits.masked_fill(to_remove, float("-inf"))


def top_k_filtering(logits, k):
    """Keep only the `k` largest logits along the last dimension; set the rest to -inf.

    Works on 1-D or batched (..., vocab) input. Values tied with the k-th
    largest are kept, so slightly more than `k` tokens can survive.
    ``k`` of None or <= 0 disables filtering and returns `logits` itself.

    Returns a new tensor when filtering; `logits` is not modified.
    """
    if k is None or k <= 0:
        return logits
    top_k = min(k, logits.size(-1))
    kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_largest, float("-inf"))


def sample_sequence_feat(
    model,
    start_token,
    max_len=200,
    temperature=1.0,
    top_k=60,
    top_p=0.9,
    greedy=False,
    eos_id=None,
    device="cuda",
):
    """Autoregressively generate tokens with temperature, top-k and top-p sampling.

    Args:
        model: module mapping (1, t) token ids to (1, t, vocab_size) logits.
        start_token: id of the first token.
        max_len: maximum length of the returned list, including `start_token`.
        temperature: softmax temperature; must be > 0 unless `greedy`.
        top_k: keep only the k most likely tokens per step; None or <= 0 disables.
        top_p: nucleus threshold; None or >= 1.0 disables.
        greedy: take the most likely token. Temperature and filtering cannot
            change the argmax, so they are skipped in this mode.
        eos_id: stop right after this id is produced (it is included).
        device: device for the token tensors.

    Returns:
        list[int] of generated ids, starting with `start_token`.

    Raises:
        ValueError: if sampling (``greedy=False``) with ``temperature <= 0``.
    """
    if not greedy and temperature <= 0:
        raise ValueError(f"temperature must be > 0 when sampling, got {temperature}")

    model.eval()
    tokens = [start_token]
    tokens_tensor = torch.tensor([tokens], device=device)  # (1, 1)

    with torch.no_grad():
        for _ in range(max_len - 1):
            next_logits = model(tokens_tensor)[:, -1, :]  # (1, vocab_size), last step

            if greedy:
                next_token = torch.argmax(next_logits, dim=-1).item()
            else:
                next_logits = next_logits / temperature
                if top_k is not None and top_k > 0:
                    next_logits = top_k_filtering(next_logits, top_k)
                if top_p is not None and top_p < 1.0:
                    next_logits = top_p_filtering(next_logits, top_p)
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()

            tokens.append(next_token)
            if eos_id is not None and next_token == eos_id:
                break

            next_token_tensor = torch.tensor([[next_token]], device=device)
            tokens_tensor = torch.cat([tokens_tensor, next_token_tensor], dim=1)

    return tokens


start_token = tokenizer.vocab["START"]
eos_token = tokenizer.vocab.get("END", None)

# Temperature 1.0 with top-k / nucleus truncation (values made explicit here).
generated = sample_sequence_feat(
    model,
    start_token,
    max_len=200,
    temperature=1.0,
    top_k=60,
    top_p=0.9,
    greedy=False,
    eos_id=eos_token,
    device=device,
)

print("Generated token sequence:", generated)
decoded_sketch = tokenizer.decode(generated)
print("Decoded sketch:", decoded_sketch)

# Render the decoded SVG inline.
display(
    HTML(
        f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated</b><br>{decoded_sketch}</div>'
    )
)

Generated token sequence: [4098, 4096, 1101, 1162, 1415, 1605, 1991, 2313, 2628, 2880, 2947, 2954, 3150, 3412, 3545, 3549, 3232, 2916, 2343, 2024, 1641, 1445, 1119, 919, 979, 1037, 971, 1096, 1093, 4096, 1553, 1680, 1744, 1680, 4096, 2576, 2642, 2579, 2706, 2640, 4096, 2069, 2136, 2137, 2070, 4096, 2201, 2270, 2144, 2015, 1884, 4096, 2011, 2083, 2403, 2592, 4096, 2139, 2079, 1951, 1949, 4096, 2203, 2334, 4096, 2649, 3286, 3794, 4096, 2906, 4055, 4096, 2782, 3551, 4096, 1179, 795, 30, 4096, 1309, 865, 423, 4096, 1378, 1061, 938, 4096, 1487, 1552, 1681, 1746, 1617, 4096, 2321, 2256, 2383, 4099]
Decoded sketch: <svg viewBox="0 0 64 64"><g stroke-width="0.8">
<path d="M 17 13 L 18 10 L 22 7 L 25 5 L 31 7 L 36 9 L 41 4 L 45 0 L 46 3 L 46 10 L 49 14 L 53 20 L 55 25 L 55 29 L 50 32 L 45 36 L 36 39 L 31 40 L 25 41 L 22 37 L 17 31 L 14 23 L 15 19 L 16 13 L 15 11 L 17 8 L 17 5" stroke="black" fill="none"/>
<path d="M 24 17 L 26 16 L 27 16 L 26 16" stroke="black" fill="none"/>
<path d="M 40 16 L 

In [10]:
# A saved sketch, kept verbatim for visual comparison; its token sequence was:
# Generated token sequence: [4098, 4096, 1420, 1420, 1485, 1807, 1997, 2250, 2565, 2817, 2944, 3012, 3210, 3407, 3540, 3547, 3552, 3173, 2858, 2413, 1710, 1454, 1131, 873, 610, 412, 217, 275, 461, 768, 897, 1228, 1421, 4096, 1816, 1751, 4096, 2264, 2200, 4096, 2142, 1952, 2082, 2081, 2078, 4096, 2016, 2085, 1958, 1702, 1507, 1569, 4096, 2016, 2276, 2406, 2596, 4096, 3033, 4051, 4096, 3100, 3420, 3934, 4096, 3106, 3815, 4096, 857, 24, 4096, 926, 609, 4099]
# NOTE(review): the line above calls this sequence generated, but the panel
# below is titled "Input"; confirm which label is intended.

svg_inline = '''<svg viewBox="0 0 64 64"><g stroke-width="0.8">
<path d="M 22 12 L 22 12 L 23 13 L 28 15 L 31 13 L 35 10 L 40 5 L 44 1 L 46 0 L 47 4 L 50 10 L 53 15 L 55 20 L 55 27 L 55 32 L 49 37 L 44 42 L 37 45 L 26 46 L 22 46 L 17 43 L 13 41 L 9 34 L 6 28 L 3 25 L 4 19 L 7 13 L 12 0 L 14 1 L 19 12 L 22 13" stroke="black" fill="none"/>
<path d="M 28 24 L 27 23" stroke="black" fill="none"/>
<path d="M 35 24 L 34 24" stroke="black" fill="none"/>
<path d="M 33 30 L 30 32 L 32 34 L 32 33 L 32 30" stroke="black" fill="none"/>
<path d="M 31 32 L 32 37 L 30 38 L 26 38 L 23 35 L 24 33" stroke="black" fill="none"/>
<path d="M 31 32 L 35 36 L 37 38 L 40 36" stroke="black" fill="none"/>
<path d="M 47 25 L 63 19" stroke="black" fill="none"/>
<path d="M 48 28 L 53 28 L 61 30" stroke="black" fill="none"/>
<path d="M 48 34 L 59 39" stroke="black" fill="none"/>
<path d="M 13 25 L 0 24" stroke="black" fill="none"/>
<path d="M 14 30 L 9 33" stroke="black" fill="none"/>
</g></svg>'''
display(HTML(f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Input</b><br>{svg_inline}</div>'))

