# Initial experiments: Denoising

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from dataset import QuickDrawDataset
from utils import AbsolutePenPositionTokenizer
from tqdm import tqdm
import pickle

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

seed = 42
torch.manual_seed(seed)
if device == "cuda":
    torch.cuda.manual_seed_all(seed)

Using device: cuda


In [9]:
labels = ["cat"]

training_data = QuickDrawDataset(
    labels=labels,
)

tokenizer = AbsolutePenPositionTokenizer(bins=64, additional_tokens=["MASK"])


class SketchDataset(Dataset):
    def __init__(
        self,
        svg_list,
        tokenizer,
        max_len=200,
        cache_file="sketch_tokenized_dataset.pkl",
    ):
        self.data = []
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.pad_id = tokenizer.vocab["PAD"]

        # Try to load from cache
        try:
            with open(cache_file, "rb") as f:
                self.data = pickle.load(f)
            print(f"Loaded tokenized data from {cache_file}")
        except FileNotFoundError:
            for svg in tqdm(svg_list, desc="Tokenizing SVGs"):
                tokens = tokenizer.encode(svg)
                # Truncate + pad
                tokens = tokens[:max_len]
                tokens = tokens + [self.pad_id] * (max_len - len(tokens))
                self.data.append(tokens)

            with open(cache_file, "wb") as f:
                pickle.dump(self.data, f)
            print(f"Saved tokenized data to {cache_file}")

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx])

    def __len__(self):
        return len(self.data)


dataset = SketchDataset(training_data, tokenizer, max_len=200)

Loading QuickDraw files: 100%|██████████| 1/1 [00:01<00:00,  1.84s/it]


Loaded tokenized data from sketch_tokenized_dataset.pkl


In [10]:
def corrupt_input(input_ids, tokenizer, mask_prob=0.2, dropout_prob=0.05):
    mask_token = tokenizer.vocab.get("MASK", None)
    pad_token = tokenizer.vocab["PAD"]
    assert mask_token is not None, "Tokenizer must have a [MASK] token"

    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    rand = torch.rand(batch_size, seq_len, device=device)
    mask = (rand < mask_prob) & (input_ids != pad_token)
    dropout = (rand < dropout_prob) & (input_ids != pad_token)

    # Create masked input
    corrupted = input_ids.clone()
    corrupted[mask] = mask_token
    # Remove some tokens (simulate missing strokes or words)
    corrupted[dropout] = pad_token
    return corrupted


class SketchAutoencoder(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=6, max_len=200):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len

        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            activation="gelu",
            batch_first=True,  # much easier shape handling
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        seq_len = x.size(1)
        pos = torch.arange(0, seq_len, device=x.device).unsqueeze(0)
        x = self.embed(x) + self.pos_embed(pos)
        h = self.encoder(x)
        h = self.norm(h)
        logits = self.fc_out(h)
        return logits

In [11]:
def train_denoising_autoencoder(
    model, dataloader, tokenizer, epochs=10, lr=1e-4, device="cuda"
):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    pad_token_id = tokenizer.vocab["PAD"]
    criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)
    vocab_size = len(tokenizer.vocab)

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for clean_ids in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            clean_ids = clean_ids.to(device)
            noisy_ids = corrupt_input(clean_ids, tokenizer)

            logits = model(noisy_ids)
            loss = criterion(logits.reshape(-1, vocab_size), clean_ids.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1} | DAE Loss: {total_loss / len(dataloader):.4f}")


dataloader = DataLoader(dataset, batch_size=64, shuffle=True, pin_memory=True)
model = SketchAutoencoder(
    vocab_size=len(tokenizer.vocab), d_model=256, nhead=8, num_layers=6
)

train_denoising_autoencoder(
    model, dataloader, tokenizer, epochs=20, lr=1e-3, device=device
)

Epoch 1/20: 100%|██████████| 1610/1610 [03:21<00:00,  7.98it/s]


Epoch 1 | DAE Loss: 1.5643


Epoch 2/20: 100%|██████████| 1610/1610 [03:24<00:00,  7.87it/s]


Epoch 2 | DAE Loss: 0.9862


Epoch 3/20: 100%|██████████| 1610/1610 [03:26<00:00,  7.78it/s]


Epoch 3 | DAE Loss: 0.9066


Epoch 4/20: 100%|██████████| 1610/1610 [03:28<00:00,  7.74it/s]


Epoch 4 | DAE Loss: 0.8709


Epoch 5/20: 100%|██████████| 1610/1610 [03:28<00:00,  7.73it/s]


Epoch 5 | DAE Loss: 0.8496


Epoch 6/20: 100%|██████████| 1610/1610 [03:27<00:00,  7.74it/s]


Epoch 6 | DAE Loss: 0.8335


Epoch 7/20:   3%|▎         | 45/1610 [00:05<03:25,  7.63it/s]


KeyboardInterrupt: 

In [None]:
from IPython.display import HTML, display

model.eval()
device = torch.device(device)
model = model.to(device)

svg_inline = ""
clean_seq = dataset[0].unsqueeze(0).to(device)  # (1, seq_len)

for _ in range(5):
    noisy_seq = corrupt_input(clean_seq, tokenizer, mask_prob=0.15, dropout_prob=0.05)

    with torch.no_grad():
        logits = model(noisy_seq)
        preds = torch.argmax(logits, dim=-1)  # (1, seq_len)

    denoised_seq = preds.squeeze(0).cpu().tolist()
    denoised_svg = tokenizer.decode(denoised_seq)
    svg_inline += f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>Generated</b><br>{denoised_svg}</div>'

display(HTML(svg_inline))

In [None]:
# TODO create a model with latent space and a model to sample from it