# Initial experiments Part 2

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from dataset import QuickDrawDataset
from utils import DeltaPenPositionTokenizer
from prepare_data import stroke_to_rdp, stroke_to_bezier_single, clean_svg
from tqdm import tqdm
import pickle

In [15]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Seed PyTorch's generators so initialisation and sampling are repeatable.
seed = 42
torch.manual_seed(seed)
# `device` is a torch.device, not a str: comparing it with "cuda" is always
# False, so the CUDA seeding has to be guarded on `device.type` instead.
if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)

Using device: cuda


In [16]:
# Restrict this experiment to a single QuickDraw category.
labels = ["bulldozer"]
# Project helper; `download=True` is forwarded to it as-is.
training_data = QuickDrawDataset(labels=labels, download=True)
# NOTE(review): `bins=32` is presumably the coordinate quantisation resolution — confirm in utils.
tokenizer = DeltaPenPositionTokenizer(bins=32)


class SketchDataset(Dataset):
    """Fixed-length token sequences for next-token training on sketches.

    Each entry of ``svg_list`` is encoded with ``tokenizer``, truncated to
    ``max_len`` tokens and right-padded with the tokenizer's ``PAD`` id.
    The tokenized sequences are cached on disk as a pickle so later runs can
    skip tokenization.

    Args:
        svg_list: iterable of sketches accepted by ``tokenizer.encode``.
        tokenizer: object exposing ``vocab["PAD"]`` and ``encode(svg)``.
        max_len: length every stored sequence is truncated/padded to.
        cache_file: pickle path used as a cache; ``None`` disables caching.

    A cache is only reused when it is a list whose sequences all have length
    ``max_len`` and (when ``svg_list`` is sized) whose entry count matches
    ``svg_list``; otherwise the data is re-tokenized and the cache rewritten.
    """

    def __init__(
        self,
        svg_list,
        tokenizer,
        max_len=200,
        cache_file="sketch_bulldozer_dataset.pkl",
    ):
        self.data = []
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.pad_id = tokenizer.vocab["PAD"]

        cached = self._load_cache(cache_file)
        if cached is not None:
            if self._cache_matches(cached, svg_list, max_len):
                self.data = cached
                print(f"Loaded tokenized data from {cache_file}")
                return
            # A cache built with another max_len / sketch list would silently
            # train on the wrong data, so it is discarded and rebuilt.
            print(f"Ignoring stale cache {cache_file}; re-tokenizing")

        for svg in tqdm(svg_list, desc="Tokenizing SVGs"):
            # use RDP to reduce number of points in SVG
            # svg = stroke_to_rdp(svg, epsilon=1.0)  # tuning
            tokens = tokenizer.encode(svg)[:max_len]
            tokens = tokens + [self.pad_id] * (max_len - len(tokens))
            self.data.append(tokens)

        if cache_file is not None:
            with open(cache_file, "wb") as f:
                pickle.dump(self.data, f)
            print(f"Saved tokenized data to {cache_file}")

    @staticmethod
    def _load_cache(cache_file):
        """Return the unpickled cache content, or None when it is unavailable."""
        if cache_file is None:
            return None
        try:
            with open(cache_file, "rb") as f:
                # SECURITY: pickle.load can execute arbitrary code; only point
                # cache_file at files produced by this notebook.
                return pickle.load(f)
        except FileNotFoundError:
            return None
        except (pickle.UnpicklingError, EOFError):
            print(f"Cache {cache_file} is unreadable; re-tokenizing")
            return None

    @staticmethod
    def _cache_matches(cached, svg_list, max_len):
        """Check that cached sequences are consistent with the current arguments."""
        if not isinstance(cached, list):
            return False
        if hasattr(svg_list, "__len__") and len(cached) != len(svg_list):
            return False
        return all(len(seq) == max_len for seq in cached)

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)``: the sequence shifted by one token."""
        seq = self.data[idx]
        input_ids = torch.tensor(seq[:-1])
        target_ids = torch.tensor(seq[1:])
        return input_ids, target_ids

    def __len__(self):
        """Number of tokenized sketches."""
        return len(self.data)


# Build the training set with SketchDataset's default max_len and cache file.
dataset = SketchDataset(training_data, tokenizer)

Downloading QuickDraw files: 100%|██████████| 1/1 [00:00<00:00, 4476.31it/s]
Loading QuickDraw files: 100%|██████████| 1/1 [00:04<00:00,  4.86s/it]


Loaded tokenized data from sketch_bulldozer_dataset.pkl


In [24]:
def generate_square_subsequent_mask(sz: int):
    """Return the (sz, sz) boolean causal mask; True marks a future position to hide."""
    steps = torch.arange(sz)
    # Entry [row, col] is True exactly when col > row (strictly above the diagonal).
    return steps.unsqueeze(0) > steps.unsqueeze(1)


class SketchTransformer(nn.Module):
    """Decoder-style (causally masked) transformer over sketch tokens.

    Token embeddings plus learned positional embeddings are passed through a
    stack of ``nn.TransformerEncoderLayer`` blocks with a causal attention
    mask, then projected to vocabulary logits.

    Args:
        vocab_size: number of distinct tokens (embedding rows / output logits).
        d_model: embedding and hidden width.
        nhead: attention heads per layer.
        num_layers: number of encoder layers.
        max_len: longest sequence the positional embedding can represent.
    """

    def __init__(self, vocab_size, d_model=384, nhead=8, num_layers=6, max_len=200):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len

        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        # Layers are built with the default sequence-first layout, hence the
        # transposes in forward().
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=4 * d_model
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """
        x: (batch, seq_len) input tokens
        Returns: (batch, seq_len, vocab_size) logits

        Raises:
            ValueError: if seq_len exceeds ``max_len`` (the positional
                embedding has no entry for those positions).
        """
        _, seq_len = x.shape
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the model's max_len {self.max_len}"
            )

        steps = torch.arange(seq_len, device=x.device)
        positions = steps.unsqueeze(0)
        h = self.embed(x) + self.pos_embed(positions)  # (batch, seq_len, d_model)
        h = h.transpose(0, 1)  # -> (seq_len, batch, d_model)

        # Causal mask built directly on the input's device: True above the
        # diagonal means "may not attend to this future position".
        mask = steps.unsqueeze(0) > steps.unsqueeze(1)  # (seq_len, seq_len)

        h = self.transformer(h, mask=mask)  # (seq_len, batch, d_model)
        h = h.transpose(0, 1)  # back to (batch, seq_len, d_model)
        return self.fc_out(h)  # (batch, seq_len, vocab_size)


def train_model(model, dataloader, vocab_size, epochs=10, lr=1e-4, device="cuda", pad_id=None):
    """Train ``model`` for next-token prediction with AdamW and cross-entropy.

    Args:
        model: module mapping (batch, seq_len) token ids to
            (batch, seq_len, vocab_size) logits.
        dataloader: iterable yielding ``(input_ids, target_ids)`` batches.
        vocab_size: size of the logits' last dimension.
        epochs: number of passes over ``dataloader``.
        lr: AdamW learning rate.
        device: device the model and batches are moved to.
        pad_id: optional token id excluded from the loss (passed as
            ``ignore_index``). ``None`` (default) scores every position.

    Returns:
        list[float]: average batch loss of each epoch (NaN for an epoch that
        produced no batches).
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    if pad_id is None:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

    epoch_losses = []
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for input_ids, target_ids in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)

            # The model takes batch-first token ids and returns batch-first logits.
            logits = model(input_ids)  # (batch, seq_len, vocab_size)
            # reshape (not view) so non-contiguous logits/targets are accepted.
            loss = criterion(logits.reshape(-1, vocab_size), target_ids.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # Count batches ourselves: works for unsized iterables and avoids a
        # ZeroDivisionError on an empty loader.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1} Loss: {avg_loss:.4f}")

    return epoch_losses


# Pinned host memory only helps host->GPU copies, so enable it on CUDA only.
dataloader = DataLoader(
    dataset, batch_size=128, shuffle=True, pin_memory=(device.type == "cuda")
)

# d_model => model capacity (types of drawing features it can learn)
# nhead => model can attend to more positions in parallel
# num layers => model learns more hierarchical abstractions (patterns, shapes , layouts)
model = SketchTransformer(
    vocab_size=len(tokenizer.vocab), d_model=512, nhead=8, num_layers=6
)

# Resume from the saved checkpoint when one exists. On a fresh run the file has
# not been written yet, so keep the newly initialised model instead of crashing.
# SECURITY: weights_only=False unpickles arbitrary objects; only load
# checkpoints produced by this notebook.
checkpoint_path = "sketch_transformer_model_bulldozer_v2_deep.pth"
try:
    model = torch.load(checkpoint_path, map_location=device, weights_only=False)
    print(f"Resumed model from {checkpoint_path}")
except FileNotFoundError:
    print(f"No checkpoint at {checkpoint_path}; training from scratch")

train_model(
    model,
    dataloader,
    vocab_size=len(tokenizer.vocab),
    epochs=15,
    lr=1e-4,
    device=device,
)

Epoch 1/15:  44%|████▍     | 575/1293 [02:31<03:08,  3.80it/s]


KeyboardInterrupt: 

In [18]:
# Persist the whole model object (not just a state_dict); the training cell reads
# this same path back with torch.load(..., weights_only=False).
torch.save(model, "sketch_transformer_model_bulldozer_v2_deep.pth")
# model = torch.load("sketch_transformer_model_bulldozer_v2_deep.pth", map_location=device, weights_only=False)

In [39]:
def top_p_filtering(logits, p=0.9):
    """Nucleus (top-p) filtering along the last dimension.

    For every row, keeps the smallest set of highest-scoring tokens whose
    cumulative softmax probability exceeds ``p`` (the top token is always
    kept) and sets every other logit to ``-inf``.

    Args:
        logits: tensor of shape (..., vocab); modified in place.
        p: cumulative probability threshold. ``None`` or ``>= 1.0`` disables
            the filter.

    Returns:
        The same ``logits`` tensor, filtered.
    """
    if p is None or p >= 1.0:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Shift right by one so the token that crosses the threshold is kept too.
    sorted_remove = cumulative_probs > p
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False

    # Map the per-row decisions from sorted order back to vocabulary order.
    # Indexing columns with the flattened indices would mask the union of all
    # rows' removals in every row, which is wrong for batch sizes > 1.
    remove = torch.zeros_like(sorted_remove).scatter(-1, sorted_indices, sorted_remove)
    logits[remove] = -float("Inf")
    return logits

def top_k_filtering(logits, k):
    """Keep the ``k`` largest logits along the last dimension, mask the rest.

    Args:
        logits: tensor of shape (..., vocab); modified in place.
        k: number of candidates to keep per row. ``None`` or ``<= 0``
            disables the filter; values above the vocab size are clamped.

    Returns:
        The same ``logits`` tensor with every value below the k-th largest of
        its row set to ``-inf`` (ties with the k-th value are kept).
    """
    if k is None or k <= 0:
        return logits
    top_k = min(k, logits.size(-1))
    values, _ = torch.topk(logits, top_k, dim=-1)
    # `[..., -1:]` keeps a trailing axis for broadcasting and works for any
    # number of leading dimensions, not just (batch, vocab).
    threshold = values[..., -1:]
    logits[logits < threshold] = -float("Inf")
    return logits


def sample_sequence_feat(
    model,
    start_tokens,
    max_len=200,
    temperature=0.6,
    top_k=20,
    top_p=0.9,
    greedy=False,
    eos_id=None,
    device="cuda",
):
    """Autoregressively extend ``start_tokens`` with tokens from ``model``.

    Args:
        model: module mapping (1, seq_len) token ids to (1, seq_len, vocab) logits.
        start_tokens: non-empty prompt token ids.
        max_len: maximum total length of the returned sequence; a prompt that
            is already this long is returned unchanged.
        temperature: softmax temperature used when sampling (must be > 0).
        top_k: passed to ``top_k_filtering`` when sampling.
        top_p: passed to ``top_p_filtering`` when sampling.
        greedy: pick the highest-scoring token at each step instead of sampling.
            Temperature and the top-k/top-p filters are not applied in this
            mode, since none of them change which token scores highest.
        eos_id: optional token id that ends generation once produced (it is
            included in the output).
        device: device the token tensor is created on.

    Returns:
        list[int]: the prompt followed by the generated tokens.

    Raises:
        ValueError: if ``start_tokens`` is empty, or if sampling is requested
            with a non-positive temperature.
    """
    tokens = list(start_tokens)
    if not tokens:
        raise ValueError("start_tokens must contain at least one token")
    if not greedy and temperature <= 0:
        raise ValueError(
            "temperature must be > 0 when sampling; use greedy=True for deterministic decoding"
        )

    model.eval()
    tokens_tensor = torch.tensor([tokens], device=device, dtype=torch.long)

    # One no_grad context for the whole loop instead of re-entering it per step.
    with torch.no_grad():
        for _ in range(max_len - len(tokens)):
            logits = model(tokens_tensor)
            next_logits = logits[:, -1, :]

            if greedy:
                next_token = torch.argmax(next_logits, dim=-1).item()
            else:
                # The division allocates a new tensor, so the in-place filters
                # below never touch the model's output.
                next_logits = next_logits / temperature
                next_logits = top_k_filtering(next_logits, top_k)
                next_logits = top_p_filtering(next_logits, top_p)
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()

            tokens.append(next_token)
            if eos_id is not None and next_token == eos_id:
                break

            next_token_tensor = torch.tensor([[next_token]], device=device)
            tokens_tensor = torch.cat([tokens_tensor, next_token_tensor], dim=1)

    return tokens


from IPython.display import HTML, display

# Prompt with START only; stop early if the tokenizer defines an END token.
start_token = tokenizer.vocab["START"]
eos_token = tokenizer.vocab.get("END", None)
generations = []

for i in range(5):
    generated = sample_sequence_feat(
        model,
        start_tokens=[start_token],
        max_len=200,
        greedy=False,
        eos_id=eos_token,
        device=device,
    )
    # tokens -> SVG -> single bezier path -> cleaned SVG
    decoded_sketch = clean_svg(
        stroke_to_bezier_single(tokenizer.decode(generated, stroke_width=0.3))
    )
    generations.append((generated, decoded_sketch))

# One inline card per generated sketch, rendered side by side.
generations_inline = "".join(
    f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated</b><br>{svg}</div>'
    for _, svg in generations
)

display(HTML(generations_inline))

# temp<=0.5 fairly deterministic

# temp=0.8, top_k=20, top_p=0.9 more variety but still coherent

# *note: important features are usually preserved, but sketches are disorganized (the number-of-curves heuristic does not work well)*
# temp=1.0, top_k=20, top_p=0.75 more variety, some incoherent sequences

# *note: lower temp means less variety; sequences begin to repeat themselves more often*
# temp=0.55, top_k=20, top_p=0.9 good balance
# temp=0.6, top_k=30, top_p=0.9  good balance 

In [32]:


# Take the SVG string of the first (tokens, svg) pair stored in `generations`.
_, decoded_sketch = generations[0]

print(decoded_sketch)

# Hard-coded SVG snapshot shown in a wider card (250px, stroke-width 0.5).
# It is a literal: nothing is interpolated from `decoded_sketch`.
s=  f'<div style="display:inline-block; width: 250px; background-color: white; margin-right:10px;"><b>Generated</b><br><svg viewBox="0 0 32 32" xmlns="http://www.w3.org/2000/svg"><g stroke-width="0.5"><path d="M1,1C0.7408154984157409,3.591845015842592 -0.6792945534187917,8.342651239837494 0.017986217977667418,10.820137820223326C0.5591579055484708,12.74295807949461 12.073717497558519,12.845372513738523 13.44046763114847,11.55953236885153C14.199687379871902,10.84525752979562 14,7.609286420644203 14,6.856740080841096C14,6.752727727568782 14.01850823414632,2.0619400543864974 13.944974291716022,1.9449742917160215C12.53040137323382,-0.30509682911510727 4.278372225697419,1 2,1M3,15C1.4396482348577961,15 -1.9999479232739819,16.90419215754308 1.0934655498836956,18C3.9309313879904595,19.005141195989907 6.966702610306262,15 3,15M14,14C12.621294308542211,14 9.911625114048926,16.697230957687715 11.805658383430288,17.805658383430288C15.453665238459607,19.9405474124557 17.74133919794931,14.87066959897466 14,13M15,8C17,9.999999999999998 18.999999999999996,11.999999999999996 21,14C22.331124476802106,10.00662656959369 24.647858361377082,7 29,7C29.906971910535763,8.81394382107152 30.061039917684287,15.60844758515047 26.305246312594896,15C23.265947749119555,14.507626332556725 24,8.985826688339161 24,7C23.66666666666666,7.333333333333327 23.33333333333334,7.666666666666672 23,8C23,8.333333333333332 23,8.666666666666668 23,9M22,10C22,9.666666666666666 22,9.333333333333334 22,9" fill="none" stroke="#000000" /></g></svg></div>'

display(HTML(s))

<svg viewBox="0 0 32 32" xmlns="http://www.w3.org/2000/svg"><g stroke-width="0.3"><path d="M1,1C0.7408154984157409,3.591845015842592 -0.6792945534187917,8.342651239837494 0.017986217977667418,10.820137820223326C0.5591579055484708,12.74295807949461 12.073717497558519,12.845372513738523 13.44046763114847,11.55953236885153C14.199687379871902,10.84525752979562 14,7.609286420644203 14,6.856740080841096C14,6.752727727568782 14.01850823414632,2.0619400543864974 13.944974291716022,1.9449742917160215C12.53040137323382,-0.30509682911510727 4.278372225697419,1 2,1M3,15C1.4396482348577961,15 -1.9999479232739819,16.90419215754308 1.0934655498836956,18C3.9309313879904595,19.005141195989907 6.966702610306262,15 3,15M14,14C12.621294308542211,14 9.911625114048926,16.697230957687715 11.805658383430288,17.805658383430288C15.453665238459607,19.9405474124557 17.74133919794931,14.87066959897466 14,13M15,8C17,9.999999999999998 18.999999999999996,11.999999999999996 21,14C22.331124476802106,10.00662656959369 2

In [None]:
# note many sketches have missing parts or incomplete shapes (step 1: get a base sketch) : check the number of paths

# pseudo-heuristic: count the number of curves in the SVG

from prepare_data import count_curves


def _curve_count(entry):
    """Sort key: count_curves() of the SVG in a (tokens, svg) pair."""
    return count_curves(entry[1])


# Rank the generations so the sketches with the most curves come first.
generations_sorted = sorted(generations, key=_curve_count, reverse=True)

generations_inline = ""
for sketch in generations_sorted:
    card = f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated</b><br>{sketch[1]}</div>'
    generations_inline += card

display(HTML(generations_inline))

In [None]:
# Select a sketch from the dataset, remove tokens and let the model complete it
# Pick one training sketch, cut off its tail and let the model finish it.
selected_sketch = training_data[38946]
selected_tokens = tokenizer.encode(selected_sketch)

# Keep the first half of the token sequence as the prompt (drops 50%).
selected_tokens_partial = selected_tokens[: len(selected_tokens) // 2]
destroyed_sketch = tokenizer.decode(selected_tokens_partial)

comparison_inline = f'''<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Original</b><br>{selected_sketch}</div>
<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Partial</b><br>{destroyed_sketch}</div>'''

# Sampling settings shared by every completion attempt.
sampling_kwargs = dict(
    max_len=200,
    temperature=0.6,
    top_k=20,
    top_p=0.9,
    greedy=False,
    eos_id=eos_token,
    device=device,
)

for i in range(5):
    generated_completion = sample_sequence_feat(
        model, start_tokens=selected_tokens_partial, **sampling_kwargs
    )

    # tokens -> SVG -> single bezier path -> cleaned SVG
    generated_sketch = clean_svg(
        stroke_to_bezier_single(
            tokenizer.decode(generated_completion, stroke_width=0.3)
        )
    )
    comparison_inline += f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Completed {i}</b><br>{generated_sketch}</div>'

display(HTML(comparison_inline))

In [None]:
def compute_perplexity(model, tokens):
    """Per-sequence perplexity of ``tokens`` under ``model``.

    ``tokens`` is a (batch, seq_len) tensor of token ids. The model is put in
    eval mode, scores every next-token prediction, and the result is
    exp(mean negative log-likelihood) for each sequence, shape (batch,).
    """
    model.eval()
    inputs, targets = tokens[:, :-1], tokens[:, 1:]
    with torch.no_grad():
        logits = model(inputs)  # (batch, seq_len - 1, vocab)
        vocab = logits.size(-1)
        # Score all positions as one flat batch, then restore (batch, seq_len - 1).
        token_nll = F.cross_entropy(
            logits.reshape(-1, vocab),
            targets.reshape(-1),
            reduction='none',
        ).view(targets.shape)
        return token_nll.mean(dim=1).exp()

from IPython.display import HTML, display


# Score the first 20 training sketches with the model.
sketch_perplexities = []

for i in tqdm(range(20), desc="Computing perplexities"):
    sketch = training_data[i]
    tokens = tokenizer.encode(sketch)
    token_batch = torch.tensor([tokens], device=device)  # batch of one sequence
    perplexity = compute_perplexity(model, token_batch)
    decoded_sketch = tokenizer.decode(tokens, stroke_width=0.3)
    sketch_perplexities.append((perplexity.item(), decoded_sketch))

# Highest perplexity first.
sketch_perplexities = sorted(
    sketch_perplexities, key=lambda pair: pair[0], reverse=True
)

# Alternative ordering (disabled): perplexity divided by the SVG string length.
# sketch_perplexities.sort(key=lambda x: x[0] / len(x[1]), reverse=True)

sketches_inline = ""
for perp, sketch in sketch_perplexities:
    card = f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Perplexity: {perp:.2f}</b><br>{sketch}</div>'
    sketches_inline += card
display(HTML(sketches_inline))

# Sorting by perplexity does seem to highlight some of the worse sketches
