# Initial experiments Part 2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from dataset import QuickDrawDataset
from utils import AbsolutePenPositionTokenizer
from tqdm import tqdm
import pickle

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

seed = 42
torch.manual_seed(seed)
if device == "cuda":
    torch.cuda.manual_seed_all(seed)

In [None]:
labels = ["cat"]

training_data = QuickDrawDataset(
    labels=labels,
)

tokenizer = AbsolutePenPositionTokenizer(bins=64)

class SketchDataset(Dataset):
    def __init__(self, svg_list, tokenizer, max_len=200, cache_file="sketch_tokenized_dataset.pkl"):
        self.data = []
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.pad_id = tokenizer.vocab["PAD"]
        
        # Try to load from cache
        try:
            with open(cache_file, "rb") as f:
                self.data = pickle.load(f)
            print(f"Loaded tokenized data from {cache_file}")
        except FileNotFoundError:
            for svg in tqdm(svg_list, desc="Tokenizing SVGs"):
                tokens = tokenizer.encode(svg)
                # Truncate + Pad
                tokens = tokens[:max_len]
                tokens = tokens + [self.pad_id] * (max_len - len(tokens))
                self.data.append(tokens)
                
            # Save to cache
            with open(cache_file, "wb") as f:
                pickle.dump(self.data, f)
            print(f"Saved tokenized data to {cache_file}")

    def __getitem__(self, idx):
        seq = self.data[idx]
        input_ids = torch.tensor(seq[:-1])
        target_ids = torch.tensor(seq[1:])
        return input_ids, target_ids
    
    def __len__(self):
        return len(self.data)
    
dataset = SketchDataset(training_data, tokenizer, max_len=200)

In [None]:
def generate_square_subsequent_mask(sz: int):
    """Causal mask to stop attention to future positions"""
    return torch.triu(torch.ones(sz, sz), diagonal=1).bool()

import math

class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Attention Is All You Need)"""
    def __init__(self, d_model, max_len):
        super().__init__()
        self.d_model = d_model
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        x: (batch, seq_len, d_model)
        returns: (batch, seq_len, d_model)
        """
        seq_len = x.size(1)
        # slice along sequence dimension, not embedding dimension
        return x + self.pe[:, :seq_len, :]

class SketchTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=6, max_len=200):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len

        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = SinusoidalPositionalEncoding(d_model, max_len=max_len)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=4*d_model, activation='gelu', batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        self.norm = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """
        x: (batch, seq_len) input tokens
        Returns: (batch, seq_len, vocab_size) logits
        """
        batch_size, seq_len = x.shape
        mask = generate_square_subsequent_mask(seq_len).to(x.device)
        x = self.embed(x) * math.sqrt(self.d_model)
        x = self.pos_encoder(x)
        x = self.transformer(x, mask=mask)
        x = self.norm(x)
        logits = self.fc_out(x)
        
        return logits
    
def train_model(model, dataloader, vocab_size, epochs=10, lr=1e-4, device="cuda"):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore pad token

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for input_ids, target_ids in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)

            # Transformer expects shape (seq_len, batch, d_model)
            logits = model(input_ids)  # (seq_len, batch, vocab_size)
            loss = criterion(logits.view(-1, vocab_size), target_ids.view(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1} Loss: {total_loss/len(dataloader):.4f}")


dataloader = DataLoader(dataset, batch_size=256, shuffle=True, pin_memory=True)
model = SketchTransformer(vocab_size=len(tokenizer.vocab), d_model=256, nhead=8, num_layers=6)

train_model(model, dataloader, vocab_size=len(tokenizer.vocab), epochs=40, lr=1e-3, device=device)

Epoch 1/40: 100%|██████████| 403/403 [01:32<00:00,  4.35it/s]


Epoch 1 Loss: 2.3922


Epoch 2/40: 100%|██████████| 403/403 [01:32<00:00,  4.33it/s]


Epoch 2 Loss: 1.7782


Epoch 3/40:   4%|▍         | 17/403 [00:04<01:34,  4.07it/s]


KeyboardInterrupt: 

In [None]:
# save the model
# torch.save(model, "sketch_transformer_cat_decoder_checkpoint.pth")

In [73]:
# model = torch.load("sketch_transformer_cat_decoder_checkpoint.pth", map_location=device, weights_only=False)

def sample_sequence(model, start_token, max_len=200, temperature=1.0, greedy=False, eos_id=None, device="cuda"):
    model.eval()
    
    tokens = [start_token]
    tokens_tensor = torch.tensor([tokens], device=device)  # (1, 1)

    for _ in range(max_len - 1):
        with torch.no_grad():
            logits = model(tokens_tensor)  # (batch=1, seq_len, vocab_size)
            next_logits = logits[:, -1, :] / temperature  # take last step
            probs = F.softmax(next_logits, dim=-1)

            if greedy:
                next_token = torch.argmax(probs, dim=-1).item()
            else:
                next_token = torch.multinomial(probs, num_samples=1).item()

        tokens.append(next_token)

        # stop if EOS reached
        if eos_id is not None and next_token == eos_id:
            break

        next_token_tensor = torch.tensor([[next_token]], device=device)
        tokens_tensor = torch.cat([tokens_tensor, next_token_tensor], dim=1)

    return tokens

# TODO Top-k filtering, Top-p filtering (nucleus)

start_token = tokenizer.vocab["START"]
eos_token = tokenizer.vocab.get("END", None)

generated = sample_sequence(
    model, 
    start_token, 
    max_len=200, 
    temperature=0.5, 
    greedy=False, 
    eos_id=eos_token, 
    device=device
)

print("Generated token sequence:", generated)
decoded_sketch = tokenizer.decode(generated, stroke_width=0.4)
print("Decoded sketch:", decoded_sketch)

from IPython.display import HTML, display

display(HTML(f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated</b><br>{decoded_sketch}</div>'))

Generated token sequence: [4098, 4096, 650, 583, 515, 704, 1027, 1416, 1737, 2443, 2506, 3008, 3073, 3344, 4096, 207, 19, 25, 226, 361, 555, 812, 1647, 2094, 2472, 2725, 2976, 2902, 4096, 1630, 1628, 4096, 1438, 1440, 4096, 1625, 4096, 1881, 4096, 1760, 1698, 4096, 1890, 4096, 1887, 4096, 2142, 2460, 4096, 1699, 4096, 1443, 4096, 1574, 4096, 1449, 4096, 1573, 4096, 1447, 4096, 2907, 4096, 2532, 4096, 2090, 4096, 2087, 4096, 2469, 4096, 1960, 4096, 1513, 4099]
Decoded sketch: <svg viewBox="0 0 64 64"><g stroke-width="0.4">
<path d="M 10 10 L 9 7 L 8 3 L 11 0 L 16 3 L 22 8 L 27 9 L 38 11 L 39 10 L 47 0 L 48 1 L 52 16" stroke="black" fill="none"/>
<path d="M 3 15 L 0 19 L 0 25 L 3 34 L 5 41 L 8 43 L 12 44 L 25 47 L 32 46 L 38 40 L 42 37 L 46 32 L 45 22" stroke="black" fill="none"/>
<path d="M 25 30 L 25 28" stroke="black" fill="none"/>
<path d="M 22 30 L 22 32" stroke="black" fill="none"/>
<path d="M 25 25" stroke="black" fill="none"/>
<path d="M 29 25" stroke="black" fill="none"/>
<path 

In [None]:
def top_p_filtering(logits, p=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    logits[:, indices_to_remove] = -float("Inf")
    return logits

def top_k_filtering(logits, k):
    if k <= 0:
        return logits
    top_k = min(k, logits.size(-1))
    values, _ = torch.topk(logits, top_k)
    min_values = values[:, -1].unsqueeze(-1)
    logits[logits < min_values] = -float("Inf")
    return logits

def sample_sequence_feat(
    model, start_token, max_len=200, temperature=1.0,
    top_k=60, top_p=0.9,
    greedy=False, eos_id=None, device="cuda"
):
    model.eval()
    tokens = [start_token]
    tokens_tensor = torch.tensor([tokens], device=device)

    for _ in range(max_len - 1):
        with torch.no_grad():
            logits = model(tokens_tensor)
            next_logits = logits[:, -1, :] / temperature

            # top-k / top-p filtering
            next_logits = top_k_filtering(next_logits, top_k)
            next_logits = top_p_filtering(next_logits, top_p)

            probs = F.softmax(next_logits, dim=-1)

            if greedy:
                next_token = torch.argmax(probs, dim=-1).item()
            else:
                next_token = torch.multinomial(probs, num_samples=1).item()

        tokens.append(next_token)

        if eos_id is not None and next_token == eos_id:
            break

        next_token_tensor = torch.tensor([[next_token]], device=device)
        tokens_tensor = torch.cat([tokens_tensor, next_token_tensor], dim=1)

    return tokens


start_token = tokenizer.vocab["START"]
eos_token = tokenizer.vocab.get("END", None)

generated = sample_sequence_feat(
    model, 
    start_token, 
    max_len=200, 
    temperature=1.0, 
    greedy=False, 
    eos_id=eos_token, 
    device=device
)

print("Generated token sequence:", generated)
decoded_sketch = tokenizer.decode(generated, stroke_width=0.4)
print("Decoded sketch:", decoded_sketch)

from IPython.display import HTML, display

display(HTML(f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated</b><br>{decoded_sketch}</div>'))