# Initial experiments Part 3

In [1]:
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import HTML, display
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import TUBerlinDataset, SketchyDataset
from utils import AbsoluteBezierPenPositionTokenizer
from prepare_data import stroke_to_bezier, convert_and_quantize_svg

In [2]:
# Prefer the GPU when one is visible; everything below moves tensors to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Seed the RNGs so that weight init, shuffling and sampling are repeatable.
seed = 42
torch.manual_seed(seed)
# Compare on `device.type`: a torch.device does not compare equal to a plain
# string, so the previous `device == "cuda"` test was always False and the
# CUDA-specific setup below was silently skipped.
if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)
    torch.cuda.empty_cache()

Using device: cuda


In [3]:
# Restrict both sketch sources to a single category for this experiment.
labels = ["cat"]
# Two separate sketch sources; both are constructed with download=True.
training_data0 = TUBerlinDataset(labels, download=True)
training_data1 = SketchyDataset(labels, download=True)
# NOTE(review): the tokenizer is built with bins=32 while BezierSketchDataset
# below quantizes with convert_and_quantize_svg(..., bins=256) — confirm the two
# resolutions are meant to differ.
tokenizer = AbsoluteBezierPenPositionTokenizer(bins=32)

class BezierSketchDataset(Dataset):
    """Fixed-length token sequences for next-token training on sketches.

    Each SVG in ``svg_lists`` is passed through ``stroke_to_bezier`` and
    ``convert_and_quantize_svg``, encoded with ``tokenizer.encode``, then
    truncated / right-padded with the tokenizer's ``"PAD"`` id to exactly
    ``max_len`` tokens. The resulting list of sequences is pickled to
    ``cache_file`` and reloaded from there on later constructions.

    Args:
        svg_lists: Iterable of iterables of SVG drawings to tokenize.
        tokenizer: Object exposing ``vocab`` (must contain ``"PAD"``) and
            ``encode``.
        max_len: Length every stored sequence is truncated / padded to.
        cache_file: Path of the pickle cache holding the tokenized sequences.

    Notes:
        * A cache whose sequences are not ``max_len`` long is treated as stale
          and rebuilt, so changing ``max_len`` no longer silently reuses data
          tokenized for a different length.
        * The cache is read with ``pickle.load``, which can execute arbitrary
          code; only point ``cache_file`` at files this notebook wrote.
        * NOTE(review): sequences longer than ``max_len`` are cut off, so any
          end-of-sequence marker the tokenizer may emit is lost for those
          samples — confirm this is intended.
    """

    def __init__(
        self,
        svg_lists,
        tokenizer,
        max_len=400,
        cache_file="sketch_bezier_dataset.pkl",
    ):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.pad_id = tokenizer.vocab["PAD"]

        cached = self._load_cache(cache_file)
        if cached is not None:
            self.data = cached
            print(f"Loaded tokenized data from {cache_file}")
            return

        self.data = self._tokenize(svg_lists)
        with open(cache_file, "wb") as f:
            pickle.dump(self.data, f)
        print(f"Saved tokenized data to {cache_file}")

    def _load_cache(self, cache_file):
        """Return the cached sequences, or None when the cache is missing or stale."""
        try:
            with open(cache_file, "rb") as f:
                data = pickle.load(f)
        except FileNotFoundError:
            return None

        # Every stored sequence must match the requested length; otherwise the
        # cache was written for a different max_len and has to be rebuilt.
        if any(len(seq) != self.max_len for seq in data):
            print(
                f"Ignoring stale cache {cache_file}: "
                f"sequence length does not match max_len={self.max_len}"
            )
            return None
        return data

    def _tokenize(self, svg_lists):
        """Convert every SVG to a token list of exactly ``max_len`` entries."""
        data = []
        max_num_tokens = 0
        for svg_list in svg_lists:
            for svg in tqdm(svg_list, desc="Tokenizing SVGs"):
                svg = stroke_to_bezier(svg, num_samples=100, maxError=15.0)
                q = convert_and_quantize_svg(svg, bins=256)
                tokens = self.tokenizer.encode(q)

                # Track the untruncated length so the printout shows how much
                # max_len is cutting off.
                max_num_tokens = max(max_num_tokens, len(tokens))

                # Truncate + pad to a fixed length.
                tokens = tokens[: self.max_len]
                tokens = tokens + [self.pad_id] * (self.max_len - len(tokens))
                data.append(tokens)

        print(f"Max number of tokens in a sequence: {max_num_tokens}")
        return data

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)``: the sequence and its one-step shift."""
        seq = self.data[idx]
        input_ids = torch.tensor(seq[:-1], dtype=torch.long)
        target_ids = torch.tensor(seq[1:], dtype=torch.long)
        return input_ids, target_ids

    def __len__(self):
        """Number of stored sequences."""
        return len(self.data)


# Tokenize both sources into one dataset. max_len and cache_file are left at
# their defaults (400 and "sketch_bezier_dataset.pkl"), so a second run loads
# the pickle instead of re-tokenizing.
dataset = BezierSketchDataset(
    svg_lists=[training_data0, training_data1],
    tokenizer=tokenizer,
)

Downloading TUBerlin files


Loading TUBerlin files: 100%|██████████| 1/1 [00:00<00:00, 90.45it/s]


Downloading Sketchy files


Loading Sketchy files: 100%|██████████| 1/1 [00:03<00:00,  3.06s/it]
Tokenizing SVGs: 100%|██████████| 80/80 [00:09<00:00,  8.55it/s]
Tokenizing SVGs: 100%|██████████| 508/508 [01:22<00:00,  6.15it/s]

Max number of tokens in a sequence: 920
Saved tokenized data to sketch_bezier_dataset.pkl





In [4]:
# from IPython.display import HTML, display

# svg_inline = ""
# for i in range(len(dataset)):
#     input_ids, target_ids, seq = dataset[i]
    
#     if len(seq) > 300:
#         svg = tokenizer.decode(seq)
#         svg_inline += f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Length: {len(seq)}</b><br>{svg}</div>'

# display(HTML(svg_inline))

In [5]:
def generate_square_subsequent_mask(sz: int):
    """Build the (sz, sz) boolean causal mask.

    Entry (i, j) is True exactly when j > i, i.e. when position i would be
    looking at a future position j and must be blocked.
    """
    position = torch.arange(sz)
    # Row vector of column indices vs. column vector of row indices:
    # broadcasting yields True strictly above the main diagonal.
    return position.unsqueeze(0) > position.unsqueeze(1)


class SketchTransformer(nn.Module):
    """Causally masked transformer that predicts the next sketch token.

    Token embeddings are summed with learned absolute position embeddings, run
    through a stack of ``nn.TransformerEncoderLayer`` blocks under a causal
    mask, and projected back to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Embedding / hidden width.
        nhead: Number of attention heads per layer.
        num_layers: Number of encoder layers.
        max_len: Size of the position-embedding table, i.e. the longest
            sequence ``forward`` accepts.
    """

    def __init__(self, vocab_size, d_model=384, nhead=8, num_layers=6, max_len=400):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len

        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=4 * d_model
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """
        x: (batch, seq_len) input tokens
        Returns: (batch, seq_len, vocab_size) logits

        Raises:
            IndexError: if seq_len exceeds ``max_len`` (the position table
                has no row for the extra positions).
        """
        _, seq_len = x.shape
        # Fail early with a readable message instead of the embedding layer's
        # generic "index out of range" (or a device-side assert on GPU).
        if seq_len > self.max_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_len={self.max_len}"
            )

        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0)
        x = self.embed(x) + self.pos_embed(positions)  # (batch, seq_len, d_model)
        # The encoder layers were built with the default batch_first=False.
        x = x.transpose(0, 1)  # -> (seq_len, batch, d_model)
        mask = generate_square_subsequent_mask(seq_len).to(x.device)  # causal mask (seq_len, seq_len)
        x = self.transformer(x, mask=mask)  # (seq_len, batch, d_model)
        x = x.transpose(0, 1)  # back to (batch, seq_len, d_model)
        logits = self.fc_out(x)  # (batch, seq_len, vocab_size)
        return logits


def train_model(
    model, dataloader, vocab_size, epochs=10, lr=1e-4, device="cuda", pad_id=None
):
    """Train ``model`` for next-token prediction with AdamW + cross-entropy.

    Args:
        model: Module mapping ``(batch, seq_len)`` token ids to
            ``(batch, seq_len, vocab_size)`` logits.
        dataloader: Sized iterable of ``(input_ids, target_ids)`` batches.
        vocab_size: Size of the last logits dimension.
        epochs: Number of passes over ``dataloader``.
        lr: AdamW learning rate.
        device: Device the model and every batch are moved to.
        pad_id: Optional target id to exclude from the loss. With the default
            ``None`` every position (padding included) contributes, as before.
            Note that a batch whose targets are all ``pad_id`` yields a NaN
            loss, since the mean is taken over zero positions.

    Prints the mean batch loss after every epoch; returns nothing.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # -100 is CrossEntropyLoss's own default ignore_index, so omitting pad_id
    # leaves the loss definition unchanged.
    criterion = nn.CrossEntropyLoss(ignore_index=-100 if pad_id is None else pad_id)
    # Guard the epoch average against an empty dataloader.
    num_batches = max(len(dataloader), 1)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for input_ids, target_ids in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)

            logits = model(input_ids)  # (batch, seq_len, vocab_size)
            # Flatten batch and time so each position is one classification
            # example; reshape also copes with non-contiguous tensors.
            loss = criterion(logits.reshape(-1, vocab_size), target_ids.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1} Loss: {total_loss/num_batches:.4f}")


dataloader = DataLoader(dataset, batch_size=128, shuffle=True, pin_memory=True)
model = SketchTransformer(
    vocab_size=len(tokenizer.vocab),
    d_model=384,
    nhead=8,
    num_layers=6,
    # Size the position table from the dataset instead of relying on two
    # separate defaults that merely happen to agree.
    max_len=dataset.max_len,
)

# d_model    => model capacity (types of drawing features it can learn)
# nhead      => model can attend to more positions in parallel
# num_layers => model learns more hierarchical abstractions (patterns, shapes, layouts)

train_model(
    model,
    dataloader,
    vocab_size=len(tokenizer.vocab),
    epochs=100,
    lr=1e-4,
    device=device,
)

Epoch 1/100: 100%|██████████| 5/5 [00:02<00:00,  2.44it/s]


Epoch 1 Loss: 5.0324


Epoch 2/100: 100%|██████████| 5/5 [00:01<00:00,  2.96it/s]


Epoch 2 Loss: 3.5292


Epoch 3/100: 100%|██████████| 5/5 [00:01<00:00,  3.01it/s]


Epoch 3 Loss: 3.2469


Epoch 4/100: 100%|██████████| 5/5 [00:01<00:00,  2.99it/s]


Epoch 4 Loss: 3.0555


Epoch 5/100: 100%|██████████| 5/5 [00:01<00:00,  3.00it/s]


Epoch 5 Loss: 3.0129


Epoch 6/100: 100%|██████████| 5/5 [00:01<00:00,  2.92it/s]


Epoch 6 Loss: 2.9547


Epoch 7/100: 100%|██████████| 5/5 [00:01<00:00,  2.93it/s]


Epoch 7 Loss: 2.9319


Epoch 8/100: 100%|██████████| 5/5 [00:01<00:00,  2.92it/s]


Epoch 8 Loss: 2.9435


Epoch 9/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 9 Loss: 2.8890


Epoch 10/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 10 Loss: 2.8759


Epoch 11/100: 100%|██████████| 5/5 [00:01<00:00,  2.92it/s]


Epoch 11 Loss: 2.8867


Epoch 12/100: 100%|██████████| 5/5 [00:01<00:00,  2.92it/s]


Epoch 12 Loss: 2.8921


Epoch 13/100: 100%|██████████| 5/5 [00:01<00:00,  2.92it/s]


Epoch 13 Loss: 2.8740


Epoch 14/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 14 Loss: 2.8697


Epoch 15/100: 100%|██████████| 5/5 [00:01<00:00,  2.94it/s]


Epoch 15 Loss: 2.8683


Epoch 16/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 16 Loss: 2.8531


Epoch 17/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 17 Loss: 2.8615


Epoch 18/100: 100%|██████████| 5/5 [00:01<00:00,  2.92it/s]


Epoch 18 Loss: 2.8283


Epoch 19/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 19 Loss: 2.8443


Epoch 20/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 20 Loss: 2.8330


Epoch 21/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 21 Loss: 2.8156


Epoch 22/100: 100%|██████████| 5/5 [00:01<00:00,  2.94it/s]


Epoch 22 Loss: 2.8193


Epoch 23/100: 100%|██████████| 5/5 [00:01<00:00,  2.93it/s]


Epoch 23 Loss: 2.8133


Epoch 24/100: 100%|██████████| 5/5 [00:01<00:00,  2.93it/s]


Epoch 24 Loss: 2.7968


Epoch 25/100: 100%|██████████| 5/5 [00:01<00:00,  2.93it/s]


Epoch 25 Loss: 2.7873


Epoch 26/100: 100%|██████████| 5/5 [00:01<00:00,  2.93it/s]


Epoch 26 Loss: 2.7884


Epoch 27/100: 100%|██████████| 5/5 [00:01<00:00,  2.93it/s]


Epoch 27 Loss: 2.7860


Epoch 28/100: 100%|██████████| 5/5 [00:01<00:00,  2.93it/s]


Epoch 28 Loss: 2.7727


Epoch 29/100: 100%|██████████| 5/5 [00:01<00:00,  2.92it/s]


Epoch 29 Loss: 2.7741


Epoch 30/100: 100%|██████████| 5/5 [00:01<00:00,  2.92it/s]


Epoch 30 Loss: 2.7463


Epoch 31/100: 100%|██████████| 5/5 [00:01<00:00,  2.93it/s]


Epoch 31 Loss: 2.7375


Epoch 32/100: 100%|██████████| 5/5 [00:01<00:00,  2.92it/s]


Epoch 32 Loss: 2.7410


Epoch 33/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 33 Loss: 2.7147


Epoch 34/100: 100%|██████████| 5/5 [00:01<00:00,  2.93it/s]


Epoch 34 Loss: 2.6982


Epoch 35/100: 100%|██████████| 5/5 [00:01<00:00,  2.92it/s]


Epoch 35 Loss: 2.6784


Epoch 36/100: 100%|██████████| 5/5 [00:01<00:00,  2.92it/s]


Epoch 36 Loss: 2.6557


Epoch 37/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 37 Loss: 2.6612


Epoch 38/100: 100%|██████████| 5/5 [00:01<00:00,  2.92it/s]


Epoch 38 Loss: 2.6271


Epoch 39/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 39 Loss: 2.6202


Epoch 40/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 40 Loss: 2.5851


Epoch 41/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 41 Loss: 2.5707


Epoch 42/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 42 Loss: 2.5383


Epoch 43/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 43 Loss: 2.5242


Epoch 44/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 44 Loss: 2.5004


Epoch 45/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 45 Loss: 2.4733


Epoch 46/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 46 Loss: 2.4477


Epoch 47/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 47 Loss: 2.4264


Epoch 48/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 48 Loss: 2.4018


Epoch 49/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 49 Loss: 2.3955


Epoch 50/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 50 Loss: 2.3576


Epoch 51/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 51 Loss: 2.3318


Epoch 52/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 52 Loss: 2.3069


Epoch 53/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 53 Loss: 2.2942


Epoch 54/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 54 Loss: 2.2698


Epoch 55/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 55 Loss: 2.2358


Epoch 56/100: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Epoch 56 Loss: 2.2261


Epoch 57/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 57 Loss: 2.1959


Epoch 58/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 58 Loss: 2.1791


Epoch 59/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 59 Loss: 2.1607


Epoch 60/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 60 Loss: 2.1355


Epoch 61/100: 100%|██████████| 5/5 [00:01<00:00,  2.86it/s]


Epoch 61 Loss: 2.1168


Epoch 62/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 62 Loss: 2.0913


Epoch 63/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 63 Loss: 2.0673


Epoch 64/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 64 Loss: 2.0457


Epoch 65/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 65 Loss: 2.0293


Epoch 66/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 66 Loss: 2.0139


Epoch 67/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 67 Loss: 1.9875


Epoch 68/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 68 Loss: 1.9840


Epoch 69/100: 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Epoch 69 Loss: 1.9612


Epoch 70/100: 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Epoch 70 Loss: 1.9340


Epoch 71/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 71 Loss: 1.9274


Epoch 72/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 72 Loss: 1.9063


Epoch 73/100: 100%|██████████| 5/5 [00:01<00:00,  2.87it/s]


Epoch 73 Loss: 1.8843


Epoch 74/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 74 Loss: 1.8856


Epoch 75/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 75 Loss: 1.8625


Epoch 76/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 76 Loss: 1.8483


Epoch 77/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 77 Loss: 1.8373


Epoch 78/100: 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Epoch 78 Loss: 1.8315


Epoch 79/100: 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Epoch 79 Loss: 1.8119


Epoch 80/100: 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Epoch 80 Loss: 1.7943


Epoch 81/100: 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Epoch 81 Loss: 1.7798


Epoch 82/100: 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Epoch 82 Loss: 1.7769


Epoch 83/100: 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Epoch 83 Loss: 1.7685


Epoch 84/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 84 Loss: 1.7494


Epoch 85/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 85 Loss: 1.7495


Epoch 86/100: 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Epoch 86 Loss: 1.7304


Epoch 87/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 87 Loss: 1.7274


Epoch 88/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 88 Loss: 1.7146


Epoch 89/100: 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Epoch 89 Loss: 1.6962


Epoch 90/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 90 Loss: 1.6958


Epoch 91/100: 100%|██████████| 5/5 [00:01<00:00,  2.87it/s]


Epoch 91 Loss: 1.6858


Epoch 92/100: 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Epoch 92 Loss: 1.6759


Epoch 93/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 93 Loss: 1.6654


Epoch 94/100: 100%|██████████| 5/5 [00:01<00:00,  2.87it/s]


Epoch 94 Loss: 1.6560


Epoch 95/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 95 Loss: 1.6556


Epoch 96/100: 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Epoch 96 Loss: 1.6527


Epoch 97/100: 100%|██████████| 5/5 [00:01<00:00,  2.86it/s]


Epoch 97 Loss: 1.6428


Epoch 98/100: 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Epoch 98 Loss: 1.6276


Epoch 99/100: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


Epoch 99 Loss: 1.6174


Epoch 100/100: 100%|██████████| 5/5 [00:01<00:00,  2.89it/s]

Epoch 100 Loss: 1.6209





In [14]:
def top_p_filtering(logits, p=0.9):
    """Nucleus (top-p) filtering over the last dimension of ``logits``.

    For each row, tokens are ranked by probability and the smallest prefix
    whose cumulative probability exceeds ``p`` is kept; every other logit is
    set to ``-inf``. The most likely token is always kept.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary.
        p: Cumulative-probability threshold. ``p >= 1`` disables filtering.

    Returns:
        ``logits`` itself, modified in place.
    """
    if p >= 1.0:
        # Nothing to cut; also avoids dropping tail tokens when float rounding
        # pushes the cumulative sum slightly above 1.
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cumulative_probs > p
    # Shift right by one so the token that crosses the threshold is kept too.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Map the per-row decisions from sorted order back to vocabulary order.
    # Scattering row by row keeps each row's removals independent; indexing
    # with the flattened indices would apply one row's removals to all rows.
    indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, -float("Inf"))
    return logits

def top_k_filtering(logits, k):
    """Keep the ``k`` largest logits per row and set the rest to ``-inf``.

    ``k <= 0`` disables the filter, and ``k`` is capped at the size of the last
    dimension. The tensor is edited in place and returned.
    """
    if k > 0:
        kept = min(k, logits.size(-1))
        # Smallest of the kept values per row, shaped to broadcast over the row.
        threshold = torch.topk(logits, kept).values[:, -1][..., None]
        logits.masked_fill_(logits < threshold, -float("Inf"))
    return logits


def sample_sequence_feat(
    model,
    start_token,
    max_len=200,
    temperature=1.0,
    top_k=60,
    top_p=0.9,
    greedy=False,
    eos_id=None,
    device="cuda",
):
    """Autoregressively generate a token sequence from ``model``.

    Args:
        model: Module mapping ``(1, seq_len)`` token ids to
            ``(1, seq_len, vocab)`` logits. If it exposes an integer
            ``max_len`` attribute, only the last ``max_len`` tokens are fed
            back in, so generation can run past the model's context size.
        start_token: Id of the first token of the sequence.
        max_len: Maximum length of the returned sequence (start token included).
        temperature: Divisor applied to the logits before sampling. Must be
            > 0 when sampling; ignored when ``greedy`` is set.
        top_k: Passed to ``top_k_filtering`` when sampling.
        top_p: Passed to ``top_p_filtering`` when sampling.
        greedy: Pick the arg-max token at every step instead of sampling.
        eos_id: Optional id that ends generation once produced (it is kept in
            the output).
        device: Device the running token tensor is created on.

    Returns:
        List of token ids, starting with ``start_token``.

    Raises:
        ValueError: if ``temperature <= 0`` while sampling.

    Side effects:
        Puts ``model`` in eval mode and leaves it there.
    """
    if not greedy and temperature <= 0:
        raise ValueError(
            "temperature must be > 0 when sampling; use greedy=True for "
            "deterministic decoding"
        )

    model.eval()
    tokens = [start_token]
    tokens_tensor = torch.tensor([tokens], device=device)
    # Context size of the model, when it advertises one.
    context_limit = getattr(model, "max_len", None)

    with torch.no_grad():
        for _step in range(max_len - 1):
            context = (
                tokens_tensor
                if context_limit is None
                else tokens_tensor[:, -context_limit:]
            )
            next_logits = model(context)[:, -1, :]

            if greedy:
                # Scaling by a positive temperature and top-k / top-p
                # filtering never change the arg-max, so skip them; this also
                # makes temperature=0 safe here.
                next_token = torch.argmax(next_logits, dim=-1).item()
            else:
                # The division yields a new tensor, so the in-place filters
                # below never touch the model's own output.
                next_logits = next_logits / temperature
                next_logits = top_k_filtering(next_logits, top_k)
                next_logits = top_p_filtering(next_logits, top_p)
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()

            tokens.append(next_token)

            if eos_id is not None and next_token == eos_id:
                break

            next_token_tensor = torch.tensor([[next_token]], device=device)
            tokens_tensor = torch.cat([tokens_tensor, next_token_tensor], dim=1)

    return tokens


start_token = tokenizer.vocab["START"]
# Generation only stops early when the vocabulary actually defines "END".
eos_token = tokenizer.vocab.get("END", None)

# Sample a handful of sketches and render them side by side as inline SVG cards.
generation_cards = []
for i in range(5):
    generated = sample_sequence_feat(
        model,
        start_token,
        max_len=200,
        temperature=1.0,
        greedy=False,
        eos_id=eos_token,
        device=device,
    )
    decoded_sketch = tokenizer.decode(generated, stroke_width=0.3)

    # print("Generated token sequence:", generated)
    # print("Decoded sketch:", decoded_sketch)
    generation_cards.append(
        f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated</b><br>{decoded_sketch}</div>'
    )

generations_inline = "".join(generation_cards)

# HTML / display come from the notebook's top-level import cell.
display(HTML(generations_inline))