In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from dataset import QuickDrawDataset
from utils import svg_stroke5
from tqdm import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
if device == "cuda":
    torch.cuda.manual_seed_all(seed)
    
def normalize_strokes(batch, scale=64.0):
    batch = batch.clone().float()
    # Normalize delta x, delta y
    batch[..., 0:2] = batch[..., 0:2] / scale
    # Leave pen_state as is (0, 1)
    return batch

def denormalize_strokes(batch, scale=64.0):
    batch = batch.clone()
    batch[..., 0:2] = batch[..., 0:2] * scale
    return batch



Using device: cuda


In [3]:
labels = ["cat"]

def to_tensor_svg_stroke5(svg_content):
    return svg_stroke5(svg_content, bins=64).to(dtype=torch.int8)

training_data = QuickDrawDataset(
    labels=labels,
    base_transform=to_tensor_svg_stroke5,
    cache_file="data/quickdraw_stoke5_cache.pt",
)

Loading QuickDraw files: 100%|██████████| 1/1 [00:03<00:00,  3.18s/it]


In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# -------------------------
# Hyperparameters
# -------------------------
def default_hparams():
    return {
        "decay_rate": 0.9999,
        "grad_clip": 1.0,
        # KL annealing
        "kl_weight": 0.5,
        "kl_weight_start": 0.01,
        "kl_decay_rate": 0.99995,
        "kl_tolerance": 0.2,
        # architecture
        "z_size": 128,
        "enc_rnn_size": 256,
        "dec_rnn_size": 512,
        "num_mixture": 20,
        # data / training
        "max_seq_len": 200,
        "batch_size": 100,
        "learning_rate": 1e-3,
        "epochs": 25,
    }

# -------------------------
# Model
# -------------------------
class SketchRNNVAE(nn.Module):
    def __init__(self, hps):
        super().__init__()
        self.hps = hps
        in_dim = 5  # input vector size (dx, dy, p1, p2, p3)
        self.enc_rnn_size = hps["enc_rnn_size"]
        self.dec_rnn_size = hps["dec_rnn_size"]
        self.z_size = hps["z_size"]
        self.M = hps["num_mixture"]

        # Encoder: single-layer bidirectional LSTM (batch_first)
        self.encoder = nn.LSTM(input_size=in_dim, hidden_size=self.enc_rnn_size,
                               num_layers=1, batch_first=True, bidirectional=True)

        # Mu and logvar
        self.fc_mu = nn.Linear(self.enc_rnn_size * 2, self.z_size)
        self.fc_logvar = nn.Linear(self.enc_rnn_size * 2, self.z_size)

        # Decoder initial state from z -> produce 2 * dec_rnn_size (h0 and c0 concatenated)
        self.z_to_init = nn.Linear(self.z_size, 2 * self.dec_rnn_size)
        # tanh activation (mirrors Keras initial_state with tanh)
        self.tanh = nn.Tanh()

        # Decoder LSTM: we will feed (decoder_input concat tile(z))
        self.decoder = nn.LSTM(input_size=in_dim + self.z_size, hidden_size=self.dec_rnn_size,
                               num_layers=1, batch_first=True)

        # Output dense: pen logits (3) + M * 6 (pi, mu1, mu2, sigma1, sigma2, rho)
        self.n_out = 3 + self.M * 6
        self.output_layer = nn.Linear(self.dec_rnn_size, self.n_out)

    def encode(self, x, lengths=None):
        # x: (B, T, 5)
        out, (h_n, c_n) = self.encoder(x)  # out (B, T, 2*enc_rnn_size)
        # h_n shape: (num_layers * num_directions, B, hidden)
        if h_n.size(0) == 2:
            # single layer, bidirectional -> two direction hidden states
            h_forward = h_n[0]   # (B, hidden)
            h_backward = h_n[1]  # (B, hidden)
            h = torch.cat([h_forward, h_backward], dim=1)  # (B, 2*enc_rnn_size)
        else:
            # fallback: use last output timestep
            h = out[:, -1, :]

        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, decoder_inputs, z, initial_states=None):
        # decoder_inputs: (B, T_dec, 5)
        # z: (B, z_size)
        # initial_states: tuple (h0, c0) each (1, B, dec_rnn_size) optional
        B, T, _ = decoder_inputs.shape
        # tile z to (B, T, z_size)
        z_tile = z.unsqueeze(1).expand(-1, T, -1)
        dec_in = torch.cat([decoder_inputs, z_tile], dim=2)  # (B, T, 5+z_size)
        if initial_states is None:
            # compute from z
            init = self.tanh(self.z_to_init(z))  # (B, 2*dec_rnn_size)
            h0 = init[:, :self.dec_rnn_size].unsqueeze(0).contiguous()  # (1, B, dec_rnn_size)
            c0 = init[:, self.dec_rnn_size:].unsqueeze(0).contiguous()
            initial_states = (h0, c0)
        out, (h_n, c_n) = self.decoder(dec_in, initial_states)
        y = self.output_layer(out)  # (B, T, n_out)
        return y, (h_n, c_n)

    def forward(self, encoder_inputs, decoder_inputs):
        # encoder_inputs: (B, T_enc, 5) for encoder
        # decoder_inputs: (B, T_dec, 5) for teacher forcing
        mu, logvar = self.encode(encoder_inputs)
        z = self.reparameterize(mu, logvar)
        y, _ = self.decode(decoder_inputs, z)
        return y, mu, logvar

    # helper to extract MDN parameters from network output
    def get_mixture_coeffs(self, out):
        # out: (B, T, n_out)
        B, T, _ = out.shape
        pen_logits = out[:, :, :3]  # (B,T,3)

        md_params = out[:, :, 3:]  # (B, T, M*6)
        md_params = md_params.view(B, T, self.M, 6)  # (B,T,M,6)

        # convention used here: md_params[..., 0] = pi logits
        z_pi_logits = md_params[..., 0]   # (B,T,M)
        z_mu1 = md_params[..., 1]
        z_mu2 = md_params[..., 2]
        z_logsigma1 = md_params[..., 3]
        z_logsigma2 = md_params[..., 4]
        z_rho = md_params[..., 5]

        # apply transforms
        z_pi = F.softmax(z_pi_logits, dim=-1)        # (B,T,M)
        z_sigma1 = torch.exp(z_logsigma1)           # positive
        z_sigma2 = torch.exp(z_logsigma2)           # positive
        z_rho = torch.tanh(z_rho)                   # in (-1,1)
        z_pen = F.softmax(pen_logits, dim=-1)       # (B,T,3)

        return {
            "pi": z_pi,
            "mu1": z_mu1,
            "mu2": z_mu2,
            "sigma1": z_sigma1,
            "sigma2": z_sigma2,
            "rho": z_rho,
            "pen": z_pen,
            "pen_logits": pen_logits,
            "pi_logits": z_pi_logits
        }

    # evaluate 2D normal pdf for all mixtures (not used directly in loss below)
    @staticmethod
    def bivariate_normal_pdf(x1, x2, mu1, mu2, s1, s2, rho, eps=1e-6):
        x1_exp = x1.unsqueeze(-1)
        x2_exp = x2.unsqueeze(-1)
        norm1 = x1_exp - mu1
        norm2 = x2_exp - mu2
        s1s2 = s1 * s2 + eps
        z = (norm1 / (s1 + eps)) ** 2 + (norm2 / (s2 + eps)) ** 2 - 2 * (rho * norm1 * norm2) / (s1s2)
        neg_rho = 1 - rho ** 2
        exp_term = torch.exp(-z / (2 * (neg_rho + eps)))
        denom = 2 * math.pi * s1s2 * torch.sqrt(neg_rho + eps)
        pdf = exp_term / (denom + eps)
        return pdf

# -------------------------
# Loss helpers
# -------------------------
def kl_loss_from_mu_logvar(mu, logvar, kl_tolerance):
    # mu, logvar: (B, z_size)
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    kl = torch.sum(kl, dim=1)  # sum over latent dims
    kl = torch.mean(kl)        # average over batch
    return torch.clamp(kl, min=kl_tolerance)

def mdn_pen_loss(y_true, y_pred, model):
    B, T, _ = y_true.shape
    comps = model.get_mixture_coeffs(y_pred)
    pi_logits = comps["pi_logits"]
    mu1, mu2 = comps["mu1"], comps["mu2"]
    s1, s2 = comps["sigma1"], comps["sigma2"]
    rho = comps["rho"]
    pen_logits = comps["pen_logits"]

    x1 = y_true[:, :, 0]
    x2 = y_true[:, :, 1]
    pen_true = y_true[:, :, 2:5]

    # Expand targets to match mixtures
    x1e = x1.unsqueeze(-1)
    x2e = x2.unsqueeze(-1)

    # Equation (24–25) from paper
    norm1 = (x1e - mu1) / (s1 + 1e-6)
    norm2 = (x2e - mu2) / (s2 + 1e-6)
    z = norm1**2 + norm2**2 - 2 * rho * norm1 * norm2
    neg_rho = 1 - rho**2 + 1e-6

    log_component = -torch.log(2 * math.pi * s1 * s2 * torch.sqrt(neg_rho) + 1e-8) - z / (2 * neg_rho)

    # Mix with pi
    log_pi = F.log_softmax(pi_logits, dim=-1)
    log_mix = torch.logsumexp(log_pi + log_component, dim=-1)

    gmm_loss = -log_mix

    # Mask beyond end-of-sequence
    fs = 1.0 - pen_true[:, :, 2]
    gmm_loss = gmm_loss * fs

    # Pen loss
    pen_target = torch.argmax(pen_true, dim=-1)
    pen_loss_all = F.cross_entropy(
        pen_logits.reshape(-1, 3),
        pen_target.view(-1),
        reduction="none"
    ).view(B, T)
    pen_loss = pen_loss_all * fs

    total = gmm_loss + pen_loss
    denom = torch.sum(fs) + 1e-8
    recon_loss = torch.sum(total) / denom
    return recon_loss, gmm_loss.mean().item(), pen_loss.mean().item()

# -------------------------
# Data normalization helper (was missing)
# -------------------------
def normalize_strokes(batch):
    """
    Normalize dx,dy across the batch by their standard deviation.
    Expects batch shape (B, T, 5) where [:,:,0:2] = dx,dy and [:,:,2:5] is one-hot pen.
    """
    coords = batch[:, :, :2]
    # compute std across batch & time for dx/dy
    std = coords.reshape(-1, 2).std(dim=0)
    # handle zero std
    std_safe = std.clone()
    std_safe[std_safe == 0] = 1.0
    coords = coords / std_safe.view(1, 1, 2)
    batch = batch.clone()
    batch[:, :, :2] = coords
    # Ensure pen states are still floats (no change). If they are not one-hot, we leave as-is.
    return batch

# -------------------------
# Example training loop
# -------------------------
hps = default_hparams()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SketchRNNVAE(hps).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=hps["learning_rate"])

# DataLoader: handle datasets that yield single-tensor items or tuples
dataloader = DataLoader(training_data, batch_size=hps["batch_size"], shuffle=True, pin_memory=True, drop_last=True)

# KL weight setup
kl_weight = hps["kl_weight_start"]
kl_decay_rate = hps["kl_decay_rate"]

max_seq_length = hps["max_seq_len"]

for epoch in range(hps["epochs"]):
    model.train()
    total_loss, total_recon, total_kl = 0.0, 0.0, 0.0
    n_batches = 0

    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=False):
        batch = batch.to(device).float()  # (B, T, 5)

        # normalize strokes (dx, dy only)
        batch = normalize_strokes(batch)

        # shift inputs for teacher forcing:
        # encoder gets full sequence length T
        # decoder receives first T-1 tokens (teacher forcing), target is next-step for those T-1 outputs
        decoder_input = batch[:, :max_seq_length - 1, :]   # (B, T-1, 5) teacher forcing input
        target_output = batch[:, 1:max_seq_length, :]      # (B, T-1, 5) next-step targets

        # forward
        y_pred, mu, logvar = model(batch, decoder_input)

        # losses
        recon_loss, gmm_mean, pen_mean = mdn_pen_loss(target_output, y_pred, model)
        kl = kl_loss_from_mu_logvar(mu, logvar, hps["kl_tolerance"])

        # anneal KL weight (approach hps["kl_weight"] gradually)
        kl_weight = min(hps["kl_weight"], kl_weight / hps["kl_decay_rate"])
        loss = recon_loss + kl_weight * kl

        # backward
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), hps["grad_clip"])
        optimizer.step()

        total_loss += loss.item()
        total_recon += recon_loss.item()
        total_kl += kl.item()
        n_batches += 1

    print(
        f"Epoch {epoch+1}/{hps['epochs']} | "
        f"Loss: {total_loss/n_batches:.6f} | "
        f"Recon: {total_recon/n_batches:.6f} | "
        f"KL: {total_kl/n_batches:.6f} | "
        f"KL_weight: {kl_weight:.6f}"
    )


                                                            

Epoch 1/25 | Loss: 2.468064 | Recon: 2.463628 | KL: 0.434701 | KL_weight: 0.010529


                                                            

Epoch 2/25 | Loss: 2.023833 | Recon: 2.019531 | KL: 0.398206 | KL_weight: 0.011085


                                                            

Epoch 3/25 | Loss: 1.885995 | Recon: 1.881513 | KL: 0.394125 | KL_weight: 0.011671


                                                            

Epoch 4/25 | Loss: 1.833480 | Recon: 1.829131 | KL: 0.363343 | KL_weight: 0.012288


                                                            

Epoch 5/25 | Loss: 1.805812 | Recon: 1.801808 | KL: 0.317750 | KL_weight: 0.012937


                                                            

Epoch 6/25 | Loss: 1.781940 | Recon: 1.778208 | KL: 0.281305 | KL_weight: 0.013621


                                                            

Epoch 7/25 | Loss: 1.767214 | Recon: 1.763705 | KL: 0.251155 | KL_weight: 0.014341


                                                            

Epoch 8/25 | Loss: 1.759960 | Recon: 1.756479 | KL: 0.236593 | KL_weight: 0.015098


                                                            

Epoch 9/25 | Loss: 1.732675 | Recon: 1.729110 | KL: 0.230039 | KL_weight: 0.015896


                                                             

Epoch 10/25 | Loss: 1.650333 | Recon: 1.646708 | KL: 0.222254 | KL_weight: 0.016737


                                                             

Epoch 11/25 | Loss: 1.503337 | Recon: 1.499573 | KL: 0.219133 | KL_weight: 0.017621


                                                             

Epoch 12/25 | Loss: 1.445793 | Recon: 1.441858 | KL: 0.217661 | KL_weight: 0.018552


                                                             

Epoch 13/25 | Loss: 1.375055 | Recon: 1.370991 | KL: 0.213426 | KL_weight: 0.019533


                                                             

Epoch 14/25 | Loss: 1.328887 | Recon: 1.324621 | KL: 0.212857 | KL_weight: 0.020565


                                                             

Epoch 15/25 | Loss: 1.253135 | Recon: 1.248722 | KL: 0.209120 | KL_weight: 0.021652


                                                             

Epoch 16/25 | Loss: 1.073836 | Recon: 1.069205 | KL: 0.208423 | KL_weight: 0.022796


                                                             

Epoch 17/25 | Loss: 0.924833 | Recon: 0.920010 | KL: 0.206174 | KL_weight: 0.024001


                                                             

Epoch 18/25 | Loss: 0.840641 | Recon: 0.835584 | KL: 0.205339 | KL_weight: 0.025270


                                                             

Epoch 19/25 | Loss: 0.750787 | Recon: 0.745358 | KL: 0.209277 | KL_weight: 0.026605


                                                             

Epoch 20/25 | Loss: 0.757285 | Recon: 0.751748 | KL: 0.202793 | KL_weight: 0.028011


                                                             

Epoch 21/25 | Loss: 0.676159 | Recon: 0.670356 | KL: 0.201879 | KL_weight: 0.029492


Epoch 22:  74%|███████▍  | 765/1030 [00:33<00:11, 22.81it/s]