# Sketch feature prediction

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from dataset import QuickDrawDataset
from utils import DeltaPenPositionTokenizer
from prepare_data import stroke_to_rdp
from tqdm import tqdm
import pickle
import pandas as pd

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

seed = 42
torch.manual_seed(seed)
if device == "cuda":
    torch.cuda.manual_seed_all(seed)

Using device: cuda


In [3]:
labels = ["cat"]
training_data = QuickDrawDataset(labels=labels, download=True)
tokenizer = DeltaPenPositionTokenizer(bins=64)


class SketchDataset(Dataset):
    def __init__(
        self,
        svg_list,
        tokenizer,
        max_len=200,
        cache_file="sketch_new.pkl",
        csv_path="dataset_splitter/quickdrawdataset_marked.csv",
    ):
        self.data = []
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.pad_id = tokenizer.vocab["PAD"]

        # Try to load from cache
        try:
            with open(cache_file, "rb") as f:
                self.data = pickle.load(f)
            print(f"Loaded tokenized data from {cache_file}")
        except FileNotFoundError:
            for svg in tqdm(svg_list, desc="Tokenizing SVGs"):
                svg = stroke_to_rdp(svg, epsilon=2.0)  # tuning
                tokens = tokenizer.encode(svg)
                tokens = tokens[:max_len]
                tokens = tokens + [self.pad_id] * (max_len - len(tokens))
                self.data.append(tokens)

            with open(cache_file, "wb") as f:
                pickle.dump(self.data, f)
            print(f"Saved tokenized data to {cache_file}")

        # load labels
        df = pd.read_csv(csv_path)
        self.recognizable = df["recognizable"].astype(int).tolist()
        self.feature_complete = df["feature_complete"].astype(int).tolist()
        self.data_cp = self.data[:len(self.recognizable)]

    def __getitem__(self, idx):
        seq = torch.tensor(self.data_cp[idx], dtype=torch.long)
        labels = torch.tensor(
            [self.recognizable[idx], self.feature_complete[idx]], dtype=torch.float
        )
        return seq, labels

    def __len__(self):
        return len(self.data_cp)

dataset = SketchDataset(training_data, tokenizer)

Downloading QuickDraw files: 100%|██████████| 1/1 [00:00<00:00, 5592.41it/s]
Loading QuickDraw files: 100%|██████████| 1/1 [00:03<00:00,  3.23s/it]
Tokenizing SVGs: 100%|██████████| 103031/103031 [02:04<00:00, 827.17it/s]


Saved tokenized data to sketch_new.pkl


In [4]:
class SketchClassifier(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(d_model, 2)  # 2 outputs: recognizable & complete
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: [B, T]
        x = self.embedding(x)  # [B, T, D]
        x = x.transpose(0, 1)  # Transformer expects [T, B, D]
        x = self.transformer(x)  # [T, B, D]
        x = x.transpose(0, 1)  # [B, T, D]
        x = x.mean(dim=1)      # simple average pooling
        x = self.fc(x)
        return self.sigmoid(x)


dataloader = DataLoader(dataset, batch_size=128, shuffle=True, pin_memory=True)
model = SketchClassifier(vocab_size=len(tokenizer.vocab)).to(device)

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(60):
    model.train()
    total_loss = 0.0
    for seqs, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        seqs, labels = seqs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(seqs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} loss: {total_loss/len(dataloader):.4f}")

Epoch 1: 100%|██████████| 8/8 [00:00<00:00,  8.57it/s]


Epoch 1 loss: 0.6199


Epoch 2: 100%|██████████| 8/8 [00:00<00:00, 11.62it/s]


Epoch 2 loss: 0.5980


Epoch 3: 100%|██████████| 8/8 [00:00<00:00, 11.56it/s]


Epoch 3 loss: 0.5965


Epoch 4: 100%|██████████| 8/8 [00:00<00:00, 11.59it/s]


Epoch 4 loss: 0.5873


Epoch 5: 100%|██████████| 8/8 [00:00<00:00, 11.58it/s]


Epoch 5 loss: 0.5939


Epoch 6: 100%|██████████| 8/8 [00:00<00:00, 11.58it/s]


Epoch 6 loss: 0.5839


Epoch 7: 100%|██████████| 8/8 [00:00<00:00, 11.57it/s]


Epoch 7 loss: 0.5929


Epoch 8: 100%|██████████| 8/8 [00:00<00:00, 11.57it/s]


Epoch 8 loss: 0.5908


Epoch 9: 100%|██████████| 8/8 [00:00<00:00, 11.54it/s]


Epoch 9 loss: 0.5863


Epoch 10: 100%|██████████| 8/8 [00:00<00:00, 11.57it/s]


Epoch 10 loss: 0.5790


Epoch 11: 100%|██████████| 8/8 [00:00<00:00, 11.58it/s]


Epoch 11 loss: 0.5840


Epoch 12: 100%|██████████| 8/8 [00:00<00:00, 11.59it/s]


Epoch 12 loss: 0.5949


Epoch 13: 100%|██████████| 8/8 [00:00<00:00, 11.57it/s]


Epoch 13 loss: 0.5757


Epoch 14: 100%|██████████| 8/8 [00:00<00:00, 11.57it/s]


Epoch 14 loss: 0.5827


Epoch 15: 100%|██████████| 8/8 [00:00<00:00, 11.56it/s]


Epoch 15 loss: 0.5798


Epoch 16: 100%|██████████| 8/8 [00:00<00:00, 11.56it/s]


Epoch 16 loss: 0.5997


Epoch 17: 100%|██████████| 8/8 [00:00<00:00, 11.57it/s]


Epoch 17 loss: 0.5864


Epoch 18: 100%|██████████| 8/8 [00:00<00:00, 11.55it/s]


Epoch 18 loss: 0.5875


Epoch 19: 100%|██████████| 8/8 [00:00<00:00, 11.58it/s]


Epoch 19 loss: 0.5756


Epoch 20: 100%|██████████| 8/8 [00:00<00:00, 11.58it/s]


Epoch 20 loss: 0.5927


Epoch 21: 100%|██████████| 8/8 [00:00<00:00, 11.53it/s]


Epoch 21 loss: 0.5845


Epoch 22: 100%|██████████| 8/8 [00:00<00:00, 11.52it/s]


Epoch 22 loss: 0.5844


Epoch 23: 100%|██████████| 8/8 [00:00<00:00, 11.52it/s]


Epoch 23 loss: 0.5863


Epoch 24: 100%|██████████| 8/8 [00:00<00:00, 11.53it/s]


Epoch 24 loss: 0.6010


Epoch 25: 100%|██████████| 8/8 [00:00<00:00, 11.53it/s]


Epoch 25 loss: 0.5855


Epoch 26: 100%|██████████| 8/8 [00:00<00:00, 11.52it/s]


Epoch 26 loss: 0.5799


Epoch 27: 100%|██████████| 8/8 [00:00<00:00, 11.51it/s]


Epoch 27 loss: 0.5714


Epoch 28: 100%|██████████| 8/8 [00:00<00:00, 11.52it/s]


Epoch 28 loss: 0.5871


Epoch 29: 100%|██████████| 8/8 [00:00<00:00, 11.55it/s]


Epoch 29 loss: 0.5865


Epoch 30: 100%|██████████| 8/8 [00:00<00:00, 11.54it/s]


Epoch 30 loss: 0.5903


Epoch 31: 100%|██████████| 8/8 [00:00<00:00, 11.53it/s]


Epoch 31 loss: 0.5746


Epoch 32: 100%|██████████| 8/8 [00:00<00:00, 11.53it/s]


Epoch 32 loss: 0.5826


Epoch 33: 100%|██████████| 8/8 [00:00<00:00, 11.52it/s]


Epoch 33 loss: 0.5803


Epoch 34: 100%|██████████| 8/8 [00:00<00:00, 11.53it/s]


Epoch 34 loss: 0.5703


Epoch 35: 100%|██████████| 8/8 [00:00<00:00, 11.54it/s]


Epoch 35 loss: 0.5756


Epoch 36: 100%|██████████| 8/8 [00:00<00:00, 11.50it/s]


Epoch 36 loss: 0.5533


Epoch 37: 100%|██████████| 8/8 [00:00<00:00, 11.50it/s]


Epoch 37 loss: 0.5568


Epoch 38: 100%|██████████| 8/8 [00:00<00:00, 11.49it/s]


Epoch 38 loss: 0.5575


Epoch 39: 100%|██████████| 8/8 [00:00<00:00, 11.52it/s]


Epoch 39 loss: 0.6034


Epoch 40: 100%|██████████| 8/8 [00:00<00:00, 11.53it/s]


Epoch 40 loss: 0.5634


Epoch 41: 100%|██████████| 8/8 [00:00<00:00, 11.52it/s]


Epoch 41 loss: 0.5581


Epoch 42: 100%|██████████| 8/8 [00:00<00:00, 11.52it/s]


Epoch 42 loss: 0.5693


Epoch 43: 100%|██████████| 8/8 [00:00<00:00, 11.53it/s]


Epoch 43 loss: 0.5511


Epoch 44: 100%|██████████| 8/8 [00:00<00:00, 11.51it/s]


Epoch 44 loss: 0.5554


Epoch 45: 100%|██████████| 8/8 [00:00<00:00, 11.49it/s]


Epoch 45 loss: 0.5172


Epoch 46: 100%|██████████| 8/8 [00:00<00:00, 11.50it/s]


Epoch 46 loss: 0.4783


Epoch 47: 100%|██████████| 8/8 [00:00<00:00, 11.52it/s]


Epoch 47 loss: 0.4806


Epoch 48: 100%|██████████| 8/8 [00:00<00:00, 11.50it/s]


Epoch 48 loss: 0.4344


Epoch 49: 100%|██████████| 8/8 [00:00<00:00, 11.49it/s]


Epoch 49 loss: 0.4240


Epoch 50: 100%|██████████| 8/8 [00:00<00:00, 11.46it/s]


Epoch 50 loss: 0.3379


Epoch 51: 100%|██████████| 8/8 [00:00<00:00, 11.45it/s]


Epoch 51 loss: 0.3484


Epoch 52: 100%|██████████| 8/8 [00:00<00:00, 11.49it/s]


Epoch 52 loss: 0.2967


Epoch 53: 100%|██████████| 8/8 [00:00<00:00, 11.46it/s]


Epoch 53 loss: 0.2862


Epoch 54: 100%|██████████| 8/8 [00:00<00:00, 11.43it/s]


Epoch 54 loss: 0.2393


Epoch 55: 100%|██████████| 8/8 [00:00<00:00, 11.42it/s]


Epoch 55 loss: 0.2209


Epoch 56: 100%|██████████| 8/8 [00:00<00:00, 11.45it/s]


Epoch 56 loss: 0.1894


Epoch 57: 100%|██████████| 8/8 [00:00<00:00, 11.45it/s]


Epoch 57 loss: 0.1577


Epoch 58: 100%|██████████| 8/8 [00:00<00:00, 11.46it/s]


Epoch 58 loss: 0.1578


Epoch 59: 100%|██████████| 8/8 [00:00<00:00, 11.46it/s]


Epoch 59 loss: 0.1545


Epoch 60: 100%|██████████| 8/8 [00:00<00:00, 11.44it/s]

Epoch 60 loss: 0.1076





In [8]:
# torch.save(model, "sketch_transformer_model_cat_v2_deep.pth")
# model = torch.load("sketch_transformer_model_cat_v2_init.pth", map_location=device, weights_only=False)
from IPython.display import HTML, display
svgs_inline = ""

print(len(dataset.recognizable))

model.eval()

with torch.no_grad():
    for svg_tokens in tqdm(dataset.data[len(dataset.recognizable): len(dataset.recognizable)+15], desc="Evaluating on unlabeled data"):

        decoded_svg = tokenizer.decode(svg_tokens)
        x = torch.tensor(svg_tokens, dtype=torch.long).unsqueeze(0).to(device)
        output = model(x)
        output = output.squeeze(0).cpu().numpy()
        svgs_inline += f'<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>recognizable prob: {float(output[0])}, complete prob: {float(output[1])}</b><br>{decoded_svg}</div>'
        
display(HTML(svgs_inline))


908


Evaluating on unlabeled data: 100%|██████████| 15/15 [00:00<00:00, 101.21it/s]
