In [None]:
!pip install svgwrite

Collecting svgwrite
  Downloading svgwrite-1.4.3-py3-none-any.whl.metadata (8.8 kB)
Downloading svgwrite-1.4.3-py3-none-any.whl (67 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/67.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.1/67.1 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: svgwrite
Successfully installed svgwrite-1.4.3


In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import MultivariateNormal, OneHotCategorical, Categorical
import requests
import io
import svgwrite
from IPython.display import Image, SVG, display, HTML
from google.colab.output import eval_js

In [None]:
# Seed PyTorch's global RNG: batch indices (torch.randint in get_batch) and
# the model's weight initialisation both draw from it.
torch.manual_seed(1337)

<torch._C.Generator at 0x79863861f130>

In [None]:
# Hyper-parameters
# QuickDraw categories to download; one <name>.npz archive is fetched per entry.
data_classes = ["circle", "square", "triangle", "cat", "star", "cloud", "sun", "tree", "face", "apple"]

batch_size = 128
training_iters = 40000
eval_interval = training_iters // 20  # evaluate and draw samples 20 times per run
learning_rate = 5e-5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 20  # number of random batches averaged per split in estimate_loss

embd = 384  # model (embedding) width shared by every Transformer layer
embd_ffn = 4 * embd  # Transformer FFN dimension
num_heads = 6  # Number of attention heads
n_layers = 6  # Number of Transformer layers
dropout = 0.2  # dropout probability applied at the end of each feed-forward sub-layer

n_components = 20  # Number of Gaussians in the MDN output layer


In [None]:
import svgwrite
import numpy as np

# Function to get bounding box of a drawing
def get_bounds(data, factor):
    """Return ``(min_x, max_x, min_y, max_y)`` of the pen path traced by ``data``.

    ``data`` is indexed as ``data[i, 0]`` / ``data[i, 1]`` and each row is
    treated as an (x, y) offset from the previous pen position. Offsets are
    divided by ``factor`` before being accumulated. The pen starts at the
    origin, so the origin is always inside the returned box.
    """
    lo_x = hi_x = 0
    lo_y = hi_y = 0
    pen_x = pen_y = 0
    for row in range(len(data)):
        pen_x += float(data[row, 0]) / factor
        pen_y += float(data[row, 1]) / factor
        # Grow the box only when the pen actually leaves it.
        if pen_x < lo_x:
            lo_x = pen_x
        if pen_y < lo_y:
            lo_y = pen_y
        if pen_x > hi_x:
            hi_x = pen_x
        if pen_y > hi_y:
            hi_y = pen_y
    return (lo_x, hi_x, lo_y, hi_y)

# Function to create SVG path from strokes
def create_path(data, factor, abs_x, abs_y, lift_pen=1):
    """Build an SVG path string from a sequence of pen offsets.

    The path opens with an absolute move to ``(abs_x, abs_y)``. Each row of
    ``data`` then contributes one relative segment ``(data[i, 0], data[i, 1])``
    scaled by ``1 / factor``: a move (``m``) when the pen was lifted after the
    previous point, otherwise a line. The ``l`` letter is written on the first
    line segment and left out on the one directly after it. ``data[i, 2]`` is
    the lift flag that applies to the *next* segment.

    Returns ``(path_string, end_x, end_y)`` where the end coordinates are the
    start coordinates plus all scaled offsets.
    """
    pieces = [f"M{abs_x},{abs_y} "]
    command = "m"
    for idx in range(len(data)):
        if lift_pen:
            command = "m"
        elif command != "l":
            command = "l"
        else:
            command = ""
        dx = float(data[idx, 0]) / factor
        dy = float(data[idx, 1]) / factor
        abs_x += dx
        abs_y += dy
        lift_pen = data[idx, 2]
        pieces.append(f"{command}{dx},{dy} ")
    return "".join(pieces), abs_x, abs_y

# Function to draw single stroke drawing
def draw_strokes(data, factor=0.2, svg_filename='sample.svg', color="black", stroke_width=0.5):
    """Render one sketch to SVG, save it to ``svg_filename`` and return the markup.

    The canvas is the sketch's bounding box (after dividing offsets by
    ``factor``) plus a 10-unit margin on every side, filled white. The sketch
    is drawn as a single unfilled path in ``color``.
    """
    left, right, top, bottom = get_bounds(data, factor)
    canvas_size = ((right - left) + 20, (bottom - top) + 20)  # Dynamic padding

    drawing = svgwrite.Drawing(svg_filename, size=canvas_size)
    background = drawing.rect(insert=(0, 0), size=canvas_size, fill='white')
    drawing.add(background)

    # Shift the start point so the bounding box sits 10 units inside the canvas.
    start_x = -left + 10
    start_y = -top + 10
    path_d = create_path(data, factor, start_x, start_y)[0]

    outline = drawing.path(path_d).stroke(color, stroke_width).fill("none")
    drawing.add(outline)
    drawing.save()
    return drawing.tostring()

# Function to draw two strokes (e.g., comparison)
def draw_two_strokes(data1, data2, color1="black", color2="brown", factor=0.2, svg_filename="sample.svg", stroke_width=0.5):
    """Render two stroke sequences as one continuous drawing in two colours.

    ``data2`` starts where ``data1`` ends and begins with the pen down
    (``lift_pen=0``). The canvas is sized from the bounding box of both
    sequences concatenated, plus a 10-unit margin on every side. The SVG is
    saved to ``svg_filename`` and its markup is returned.
    """
    combined = np.concatenate([data1, data2], axis=0)
    left, right, top, bottom = get_bounds(combined, factor)
    canvas_size = ((right - left) + 20, (bottom - top) + 20)  # Dynamic padding

    drawing = svgwrite.Drawing(svg_filename, size=canvas_size)
    background = drawing.rect(insert=(0, 0), size=canvas_size, fill='white')
    drawing.add(background)

    # First sequence: starts at the padded origin, pen initially lifted.
    first_d, pen_x, pen_y = create_path(data1, factor, -left + 10, -top + 10)
    drawing.add(drawing.path(first_d).stroke(color1, stroke_width).fill("none"))

    # Second sequence: continues from where the first one stopped.
    second_d = create_path(data2, factor, pen_x, pen_y, lift_pen=0)[0]
    drawing.add(drawing.path(second_d).stroke(color2, stroke_width).fill("none"))

    drawing.save()
    return drawing.tostring()


In [None]:
import requests
import io
import numpy as np

# Download every class archive and pool its splits into three flat lists of sketches.
train_set, valid_set, test_set = [], [], []
for data_class in data_classes:
    data_url = f"https://storage.googleapis.com/quickdraw_dataset/sketchrnn/{data_class}.npz"

    try:
        # Without a timeout a stalled connection would block the cell forever.
        response = requests.get(data_url, timeout=60)
        response.raise_for_status()  # Raises error for failed request
        # NOTE(security): allow_pickle=True unpickles object arrays from the
        # downloaded bytes, which can execute arbitrary code. Acceptable only
        # while the source bucket is trusted.
        load_data = np.load(io.BytesIO(response.content), allow_pickle=True, encoding='latin1')

        # Ensure keys exist before accessing
        if all(key in load_data for key in ['train', 'valid', 'test']):
            # Read each split once: indexing the archive re-reads the array
            # from the zip on every access.
            train_part = load_data['train']
            valid_part = load_data['valid']
            test_part = load_data['test']
            train_set.extend(train_part.tolist())
            valid_set.extend(valid_part.tolist())
            test_set.extend(test_part.tolist())
            print(f"✅ Loaded {data_class}: {len(train_part)} train, {len(valid_part)} valid, {len(test_part)} test")
        else:
            print(f"⚠️ Missing keys in dataset: {data_class}")

    except requests.exceptions.RequestException as e:
        print(f"❌ Failed to download {data_class}: {e}")
    except Exception as e:
        # Best-effort: a corrupt archive for one class should not abort the others.
        print(f"❌ Error processing {data_class}: {e}")

# Final dataset sizes
print(f"\nTotal samples: Train={len(train_set)}, Valid={len(valid_set)}, Test={len(test_set)}")


✅ Loaded circle: 70000 train, 2500 valid, 2500 test
✅ Loaded square: 70000 train, 2500 valid, 2500 test
✅ Loaded triangle: 70000 train, 2500 valid, 2500 test
✅ Loaded cat: 70000 train, 2500 valid, 2500 test
✅ Loaded star: 70000 train, 2500 valid, 2500 test
✅ Loaded cloud: 70000 train, 2500 valid, 2500 test
✅ Loaded sun: 70000 train, 2500 valid, 2500 test
✅ Loaded tree: 70000 train, 2500 valid, 2500 test
✅ Loaded face: 70000 train, 2500 valid, 2500 test
✅ Loaded apple: 70000 train, 2500 valid, 2500 test

Total samples: Train=700000, Valid=25000, Test=25000


In [None]:
# Length (in points) of the longest training sketch; 0 if the set is empty.
max_len = max(map(len, train_set), default=0)
print(max_len)  # Print the longest sequence length

# The model's context window is sized to fit the longest sketch.
block_size = max_len
assert block_size <= 250  # Ensure it does not exceed 250


176


In [None]:
# Hard-coded copy of the max_len printed by the previous cell; it overrides
# that cell's `block_size = max_len` assignment.
# NOTE(review): this goes stale if data_classes (and therefore max_len) changes.
block_size = 176  # Set block_size to max_len


In [None]:
import numpy as np

# Initialize variables for maximum width, height, and statistics
max_w, max_h = 0, 0
x_sum, y_sum = 0., 0.
x_sq_sum, y_sq_sum = 0., 0.
N = 0

# Iterate through the dataset
for x in train_set:
    min_x, max_x, min_y, max_y = get_bounds(x, factor=1)  # Get sketch boundaries

    # Update max width and height
    max_w = max(max_w, max_x - min_x)
    max_h = max(max_h, max_y - min_y)

    # Widen to float64 before any arithmetic. Squaring in the arrays' stored
    # dtype appears to wrap around: this cell previously printed std 34.03 /
    # 32.32 while np.std over the same data gives 47.25 / 41.88 with
    # identical means, which is what integer overflow in the squares produces.
    dx = np.asarray(x[:, 0], dtype=np.float64)
    dy = np.asarray(x[:, 1], dtype=np.float64)

    # Accumulate sums for mean calculation
    x_sum += dx.sum()
    y_sum += dy.sum()

    # Accumulate squared sums for variance calculation
    x_sq_sum += (dx ** 2).sum()
    y_sq_sum += (dy ** 2).sum()

    # Keep track of total number of points
    N += len(x)

# Compute mean values for x and y
x_mean = x_sum / N
y_mean = y_sum / N

# Compute standard deviation (STD) as sqrt(E[v^2] - E[v]^2)
x_std = np.sqrt(x_sq_sum / N - x_mean ** 2)
y_std = np.sqrt(y_sq_sum / N - y_mean ** 2)

# Print results
print(f"Max Width: {max_w}, Max Height: {max_h}")
print(f"x_mean: {x_mean}, y_mean: {y_mean}")
print(f"x_std: {x_std}, y_std: {y_std}")


Max Width: 2586.0, Max Height: 2092.0
x_mean: 0.4718347548054783, y_mean: 0.3746371362317503
x_std: 34.031243174459675, y_std: 32.316420139180245


In [None]:
# Recompute max_w and max_h to ensure consistency
filtered_train_set = [x for x in train_set if len(x) > 5]  # Ignore very short sequences

max_w, max_h = 0, 0
for x in filtered_train_set:
    min_x, max_x, min_y, max_y = get_bounds(x, factor=1)
    max_w = max(max_w, max_x - min_x)
    max_h = max(max_h, max_y - min_y)

# Gather all offsets as two flat arrays. Extending Python lists element by
# element would create one Python object per point across the whole dataset.
if filtered_train_set:
    x_vals = np.concatenate([x[:, 0] for x in filtered_train_set])
    y_vals = np.concatenate([x[:, 1] for x in filtered_train_set])
else:
    x_vals, y_vals = np.empty(0), np.empty(0)

# Compute mean and std (accumulated in float64 regardless of the stored dtype)
x_mean, y_mean = np.mean(x_vals, dtype=np.float64), np.mean(y_vals, dtype=np.float64)
x_std, y_std = np.std(x_vals, dtype=np.float64), np.std(y_vals, dtype=np.float64)

print(f"✅ Adjusted Max Width: {max_w}, Max Height: {max_h}")
print(f"✅ Adjusted x_mean: {x_mean}, y_mean: {y_mean}")
print(f"✅ Adjusted x_std: {x_std}, y_std: {y_std}")


✅ Adjusted Max Width: 2586.0, Max Height: 2092.0
✅ Adjusted x_mean: 0.4718347548054783, y_mean: 0.3746371362317503
✅ Adjusted x_std: 47.24529126961086, y_std: 41.88252713624297


In [None]:
import torch

# Data loading function
def get_batch(split):
    """
    Generates a random batch of padded stroke sequences from one dataset split.

    Args:
    - split (str): "train", "valid" (alias "val") or "test"

    Returns:
    - xs: Input sequences (batch_size, block_size, 3)
    - ys: Target sequences, i.e. the inputs shifted left by one step
          (batch_size, block_size - 1, 3)
    - mask: Boolean mask (batch_size, block_size), True for the first
            len(sketch) positions of each row

    Raises:
    - ValueError: if `split` is not one of the names above. Unknown names
      used to fall through to the test set silently, which made a caller
      asking for "val" evaluate on test data.
    """

    # Select appropriate dataset
    if split == "train":
        data = train_set
    elif split in ("valid", "val"):
        data = valid_set
    elif split == "test":
        data = test_set
    else:
        raise ValueError(f"Unknown split {split!r}; expected 'train', 'valid'/'val' or 'test'")

    # Randomly select batch_size samples (with replacement)
    ix = torch.randint(len(data), (batch_size,))
    xs, ys, lengths = [], [], []

    for i in ix:
        # Extract strokes (x, y, pen state); torch.tensor copies, so the
        # in-place pen edit below does not touch the stored dataset.
        x, y, p = torch.tensor(data[i]).T

        # Centre the offsets and scale them by half of the largest sketch
        # width / height seen in the training set.
        x = (x - x_mean) / (max_w / 2)
        y = (y - y_mean) / (max_h / 2)

        # Mark the last point of the sketch with pen state 2 (end of drawing)
        p[-1] = 2.0

        # Stack normalized strokes
        d = torch.stack([x, y, p], -1)

        # Pad with (0, 0, 2) until block_size
        padding = torch.tensor([0., 0., 2.]).repeat((block_size - len(d), 1))
        xs.append(torch.cat([d, padding]))  # Pad input sequences
        ys.append(torch.cat([d[1:], padding]))  # Shift targets by one step
        lengths.append(len(data[i]))

    # Convert lists to tensors
    xs, ys, lengths = torch.stack(xs), torch.stack(ys), torch.tensor(lengths)

    # Mask of valid input positions.
    # NOTE(review): ys is one step shorter than the mask, and the target at
    # index len(sketch) - 1 is already a padding row although the mask still
    # marks that position True. Confirm this is intended.
    mask = torch.arange(block_size).expand(batch_size, block_size) < lengths.unsqueeze(1)

    # Move to GPU if available
    xs, ys, mask = xs.to(device), ys.to(device), mask.to(device)

    return xs, ys, mask


In [None]:
import svgwrite
from IPython.display import HTML, display

# Sanity check: fetch one training batch and render its first few sketches.
xs, ys, mask = get_batch("train")
print("Batch shape:", xs.shape)

n_samples = 3  # Only draw 3 samples
# A smaller factor enlarges the drawing (offsets are divided by it).
svg_samples = [
    draw_strokes(sketch.cpu().numpy(), factor=0.002, stroke_width=0.5)
    for sketch in xs[:n_samples]
]

# One '{}' slot per SVG inside a non-wrapping container, so they sit side by side.
no_wrap_div = '<div style="white-space: nowrap; font-size:0;">' + '{}' * n_samples + '</div>'
display(HTML(no_wrap_div.format(*svg_samples)))


Batch shape: torch.Size([128, 176, 3])


Model: Autoregressive Transformer (Decoder) + MDN

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import OneHotCategorical, Categorical, MultivariateNormal

# --- MDN and Supporting Networks ---

class MDN(nn.Module):
    """
    Mixture Density Network head, optionally conditioned on a class index.

    Two sub-networks read the same features: ``pi_net`` yields the mixture
    weights and ``normal_net`` yields the Gaussian components. When
    ``num_classes`` is given, a learnable class embedding is appended to the
    features of every time step before both sub-networks are applied.

    Parameters
    ----------
    dim_in : int
        Width of the incoming features.
    dim_out : int
        Dimensionality of each Gaussian (e.g. 2 for a (dx, dy) offset).
    n_components : int
        Number of mixture components.
    full_cov : bool
        Passed through to ``NormalNetwork``; selects full vs. diagonal covariance.
    num_classes : int, optional
        Number of conditioning classes. ``None`` disables conditioning.
    class_emb_dim : int, optional
        Width of the class embedding (default: 16).
    """
    def __init__(self, dim_in, dim_out, n_components, full_cov=True, num_classes=None, class_emb_dim=16):
        super().__init__()
        self.num_classes = num_classes
        conditional = num_classes is not None
        self.class_emb = nn.Embedding(num_classes, class_emb_dim) if conditional else None
        # The heads see the features plus, when conditioning, the class embedding.
        head_width = dim_in + class_emb_dim if conditional else dim_in

        self.pi_net = OneHotCategoricalNetwork(head_width, n_components)
        self.normal_net = NormalNetwork(head_width, dim_out, n_components, full_cov)

    def forward(self, x, tau=1.0, class_idx=None):
        """Return ``(mixture_weights, components)`` for features ``x`` of shape (B, T, dim_in)."""
        if self.class_emb is None:
            return self.pi_net(x, tau), self.normal_net(x, tau)

        if class_idx is None:
            raise ValueError("Class index must be provided for conditional MDN")

        seq_len = x.shape[1]
        class_vec = self.class_emb(class_idx)  # (B, class_emb_dim)
        # Repeat the per-sample embedding along the time axis: (B, T, class_emb_dim)
        class_vec = class_vec.unsqueeze(1).expand(-1, seq_len, -1)
        conditioned = torch.cat([x, class_vec], dim=-1)
        return self.pi_net(conditioned, tau), self.normal_net(conditioned, tau)


class NormalNetwork(nn.Module):
    """
    Maps features to ``n_components`` multivariate normals per time step.

    Parameters
    ----------
    in_dim : int
        Width of the incoming features.
    out_dim : int
        Dimensionality of each Gaussian.
    n_components : int
        Number of Gaussians produced per position.
    full_cov : bool
        If True, predict a full lower-triangular Cholesky factor; otherwise
        predict only its diagonal.
    """
    def __init__(self, in_dim, out_dim, n_components, full_cov=True):
        super().__init__()
        self.n_components = n_components
        self.out_dim = out_dim
        self.full_cov = full_cov
        # Non-persistent buffer: it follows the module across .to(device)
        # calls but stays out of state_dict, so existing checkpoints still load.
        self.register_buffer(
            "tril_indices",
            torch.tril_indices(row=out_dim, col=out_dim, offset=0),
            persistent=False,
        )
        self.mean_net = nn.Linear(in_dim, out_dim * n_components)
        if full_cov:
            # out_dim * (out_dim + 1) / 2 lower-triangular entries per component
            self.tril_net = nn.Linear(in_dim, out_dim * (out_dim + 1) // 2 * n_components)
        else:
            self.tril_net = nn.Linear(in_dim, out_dim * n_components)

    def forward(self, x, tau=1.0):
        """
        Build the component distributions for ``x`` of shape (B, T, in_dim).

        Returns a ``MultivariateNormal`` with ``loc`` of shape
        (B, T, n_components, out_dim). The Cholesky factor is multiplied by
        ``tau``, so ``tau < 1`` narrows every component.
        """
        B, T, _ = x.shape
        mean = self.mean_net(x).reshape(B, T, self.n_components, self.out_dim)  # (B, T, M, d)
        if self.full_cov:
            tril_values = self.tril_net(x).reshape(B, T, self.n_components, -1)  # (B, T, M, d*(d+1)/2)
            # Allocate with the activations' dtype and device so the factor
            # matches `mean` when the module is not running in float32.
            tril = tril_values.new_zeros(B, T, self.n_components, self.out_dim, self.out_dim)
            rows, cols = self.tril_indices[0], self.tril_indices[1]
            tril[:, :, :, rows, cols] = tril_values
            # Ensure positivity on the diagonal
            tril.diagonal(dim1=-2, dim2=-1)[:] = tril.diagonal(dim1=-2, dim2=-1).exp()
        else:
            tril = self.tril_net(x).reshape(B, T, self.n_components, self.out_dim)
            tril = torch.diag_embed(tril.exp())
        tril = tril * tau
        return MultivariateNormal(mean, scale_tril=tril)


class OneHotCategoricalNetwork(nn.Module):
    """Linear layer whose output parameterises a one-hot categorical distribution."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.network = nn.Linear(in_dim, out_dim)

    def forward(self, x, tau=1.0):
        """Return ``OneHotCategorical`` over ``out_dim`` outcomes; logits are divided by ``tau``."""
        raw_scores = self.network(x)
        tempered = raw_scores / tau
        return OneHotCategorical(logits=tempered)


class CategoricalNetwork(nn.Module):
    """
    Linear head producing a ``Categorical`` distribution (used for the pen state).

    With ``num_classes`` set, a learnable class embedding of width
    ``class_emb_dim`` is concatenated to the features of every time step
    before the linear layer; ``forward`` then requires ``class_idx``.
    """
    def __init__(self, in_dim, out_dim, num_classes=None, class_emb_dim=16):
        super().__init__()
        conditional = num_classes is not None
        self.class_emb = nn.Embedding(num_classes, class_emb_dim) if conditional else None
        head_width = in_dim + class_emb_dim if conditional else in_dim
        self.network = nn.Linear(head_width, out_dim)

    def forward(self, x, tau=1.0, class_idx=None):
        """Return ``Categorical`` over ``out_dim`` outcomes for ``x`` of shape (B, T, in_dim)."""
        features = x
        if self.class_emb is not None:
            if class_idx is None:
                raise ValueError("class_idx must be provided for conditional CategoricalNetwork")
            class_vec = self.class_emb(class_idx)  # (B, class_emb_dim)
            # Broadcast the per-sample embedding over the time axis.
            class_vec = class_vec.unsqueeze(1).expand(-1, x.shape[1], -1)
            features = torch.cat([x, class_vec], dim=-1)
        return Categorical(logits=self.network(features) / tau)


# --- Transformer Model with Conditional MDN and Pen Head ---

class Head(nn.Module):
    """
    One head of causal (masked) self-attention.

    Projects ``embd``-wide inputs to ``head_size``-wide queries, keys and
    values. Each position attends only to itself and earlier positions.
    """
    def __init__(self, head_size):
        super().__init__()
        self.query = nn.Linear(embd, head_size, bias=False)
        self.key = nn.Linear(embd, head_size, bias=False)
        self.value = nn.Linear(embd, head_size, bias=False)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, embd); returns (B, T, head_size)."""
        B, T, C = x.shape
        q = self.query(x)  # (B, T, head_size)
        k = self.key(x)    # (B, T, head_size)
        # Scaled dot-product attention divides by sqrt(d_k), the key/query
        # width (head_size here). Using the model width C damps the scores
        # more than intended whenever there is more than one head.
        wei = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)  # (B, T, T)
        # Lower-triangular mask: position t may not look at positions > t.
        causal = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
        wei = wei.masked_fill(~causal, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v      # (B, T, head_size)
        return out

class MultiHead(nn.Module):
    """
    Several attention heads run side by side, followed by a linear projection.

    The head outputs are concatenated on the feature axis
    (``num_heads * head_size`` wide) and projected back to ``embd``.
    """
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # Size the projection from what the heads actually emit, so the layer
        # still works when num_heads * head_size differs from embd (e.g. embd
        # not divisible by num_heads). Identical to Linear(embd, embd) otherwise.
        self.proj = nn.Linear(num_heads * head_size, embd)

    def forward(self, x):
        """Apply every head to ``x`` (B, T, embd) and project the concatenation to (B, T, embd)."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

class FeedForward(nn.Module):
    """Position-wise MLP: widen to ``embd_ffn``, ReLU, project back to ``embd``, then dropout."""

    def __init__(self, embd):
        super().__init__()
        widen = nn.Linear(embd, embd_ffn)
        narrow = nn.Linear(embd_ffn, embd)
        self.net = nn.Sequential(widen, nn.ReLU(), narrow, nn.Dropout(dropout))

    def forward(self, x):
        """Apply the MLP independently at every position of ``x``."""
        hidden = self.net(x)
        return hidden

class Block(nn.Module):
    """
    Pre-norm Transformer block.

    Each sub-layer (multi-head self-attention, then feed-forward) reads a
    layer-normalised copy of the stream and adds its output back onto it.
    """
    def __init__(self, embd, num_heads):
        super().__init__()
        per_head = embd // num_heads
        self.sa_heads = MultiHead(num_heads, per_head)
        self.ffwd = FeedForward(embd)
        self.ln1 = nn.LayerNorm(embd)
        self.ln2 = nn.LayerNorm(embd)

    def forward(self, x):
        """Run attention and feed-forward with residual connections; shape is preserved."""
        attended = self.sa_heads(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed

class TransformerModel(nn.Module):
    """
    Transformer-based autoregressive model for stroke generation conditioned on class.

    It includes:
      - Stroke embedding (for dx, dy)
      - Pen embedding (for pen state)
      - Positional embedding
      - Stacked Transformer blocks
      - Layer normalization
      - Conditional MDN head for predicting (dx, dy)
      - Conditional Categorical head for predicting pen state

    Pass the class index (one per sample) via the `class_idx` argument in forward, loss, sample, and generate.

    Parameters
    ----------
    num_classes : int, optional
        Number of conditioning classes handed to both output heads.
        ``None`` builds unconditional heads.
    max_length : int, optional
        Size of the positional-embedding table, i.e. the longest sequence
        ``forward`` accepts. Defaults to ``block_size + 100``.
    """
    def __init__(self, num_classes=None,max_length=None):
        super().__init__()
        # Embedding layers
        if max_length is None:
            max_length = block_size + 100
        self.stroke_embedding_proj = nn.Linear(2, embd, bias=False)
        self.pen_embedding_table = nn.Embedding(3, embd)
        self.position_embedding_table = nn.Embedding(max_length, embd)
        # Transformer blocks
        self.blocks = nn.Sequential(*[Block(embd, num_heads) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(embd)
        # Conditional MDN head and Pen head
        self.mdn_head = MDN(embd, 2, n_components, full_cov=True, num_classes=num_classes, class_emb_dim=16)
        self.pen_head = CategoricalNetwork(embd, 3, num_classes=num_classes, class_emb_dim=16)
        self.num_classes = num_classes

    def forward(self, x, tau=1.0, class_idx=None):
        """
        Run the network on ``x`` of shape (B, T, 3), last axis [dx, dy, pen].

        Returns ``(pi, normal, q)``: the mixture weights, the Gaussian
        components and the pen-state distribution for every position.
        ``T`` must not exceed the positional-embedding table size.
        """
        B, T, C = x.shape
        stroke_emb = self.stroke_embedding_proj(x[:, :, :2])  # (B, T, embd)
        # Clamp pen values to ensure they are within the range [0, 2]
        pen_indices = torch.clamp(x[:, :, 2].long(), min=0, max=2)
        pen_emb = self.pen_embedding_table(pen_indices)   # (B, T, embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=x.device))  # (T, embd)
        x = stroke_emb + pen_emb + pos_emb  # (B, T, embd)
        x = self.blocks(x)
        x = self.ln_f(x)
        # Conditional MDN and Pen head (pass class_idx)
        pi_net, normal_net = self.mdn_head(x, tau=tau, class_idx=class_idx)
        q_net = self.pen_head(x, tau=tau, class_idx=class_idx)
        return pi_net, normal_net, q_net

    def loss(self, x, targets, mask, tau=1.0, class_idx=None):
        """
        Per-position negative log-likelihood, shape (B, T_target).

        ``x`` and ``mask`` are cut to the length of ``targets``. The stroke
        term is multiplied by ``mask`` so masked positions contribute zero;
        the pen term is not masked.
        """
        T_target = targets.shape[1]
        pi, normal, q = self.forward(x[:, :T_target, :], tau=tau, class_idx=class_idx)
        ys = targets[:, :T_target, :2]  # (B, T, 2)
        # Log-density of the target offset under every mixture component
        loglik = normal.log_prob(ys.unsqueeze(-2).expand_as(normal.loc))
        # Use the distribution's normalised log-weights directly. Taking
        # log(pi.probs) round-trips through softmax and yields -inf (with
        # NaN gradients) once a weight underflows.
        Ls = -torch.logsumexp(pi.logits + loglik, dim=-1)
        Ls = Ls * mask[:, :T_target]  # Mask the padded positions

        yp = targets[:, :T_target, 2].round().long()  # Pen target
        Lp = -q.log_prob(yp)
        return Ls + Lp

    def sample(self, x, tau=1.0, class_idx=None):
        """
        Draw one next-step sample for every position of ``x`` (B, T, 3).

        Returns a (B, T, 3) tensor of [dx, dy, pen]: a mixture component is
        picked from ``pi`` and its Gaussian sample is kept; the pen state is
        sampled from ``q``.
        """
        pi, normal, q = self.forward(x, tau=tau, class_idx=class_idx)
        # The one-hot component choice zeroes out all but one Gaussian sample
        s_samples = torch.sum(pi.sample().unsqueeze(-1) * normal.sample(), dim=-2)
        p_samples = q.sample()
        return torch.cat([s_samples, p_samples.unsqueeze(-1)], dim=-1)

    @torch.no_grad()
    def generate(self, x, max_new_tokens, tau=1.0, class_idx=None, break_eos=True):
        """
        Extend ``x`` (B, T, 3) autoregressively by up to ``max_new_tokens`` steps.

        With ``break_eos`` the loop stops as soon as the first sequence in
        the batch emits pen state 2. Returns the seed followed by the
        generated steps.
        """
        # forward() cannot index positions beyond the embedding table, so
        # condition on the most recent `max_context` steps only. Sequences
        # within the limit are passed through unchanged.
        max_context = self.position_embedding_table.num_embeddings
        for _ in range(max_new_tokens):
            context = x[:, -max_context:, :]
            # Keep only the prediction made at the last position
            samples_next = self.sample(context, tau=tau, class_idx=class_idx)[:, -1:, :]
            x = torch.cat([x, samples_next], dim=1)
            if break_eos and samples_next[0, 0, 2] == 2:
                return x
        return x


In [None]:
# Smoke test on random data: builds a model and checks that the forward pass,
# the loss and generation all run. It relies on the hyper-parameters
# (embd, block_size, n_layers, num_heads, dropout, n_components) and on
# data_classes being defined by the earlier cells.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_classes = len(data_classes)  # one conditioning index per category

# Instantiate the model with conditioning
max_length = block_size + 100
model = TransformerModel(num_classes=num_classes,max_length=max_length).to(device)

# Create dummy input data
B, T = 64, block_size
# Generate dx and dy as random floats
dx_dy = torch.randn(B, T, 2)
# Generate pen values as integers in [0, 2] (do not cast to float here)
pen = torch.randint(0, 3, (B, T, 1))
# Combine to form x: (B, T, 3); the pen column is cast to float for the concatenation
x = torch.cat([dx_dy, pen.float()], dim=-1).to(device)

# Create dummy targets similarly
dx_dy_t = torch.randn(B, T, 2)
pen_t = torch.randint(0, 3, (B, T, 1))
targets = torch.cat([dx_dy_t, pen_t.float()], dim=-1).to(device)

# Dummy mask: all ones (assuming full-length sequences)
mask = torch.ones(B, T).to(device)

# Dummy class indices: one per sample, in range [0, num_classes-1]
class_idx = torch.randint(0, num_classes, (B,)).to(device)

# Forward pass on its own
pi, normal, q = model(x, tau=1.0, class_idx=class_idx)

def safe_loss(model, x, targets, mask, tau, class_idx):
    """Per-position loss for same-length ``x`` and ``targets`` (no slicing)."""
    pi, normal, q = model(x, tau=tau, class_idx=class_idx)
    ys = targets[:, :, :2]  # Stroke (dx, dy)
    loglik = normal.log_prob(ys.unsqueeze(-2).expand_as(normal.loc))
    # pi.logits are the normalised log-weights; log(pi.probs) would give
    # -inf once a weight underflows.
    Ls = -torch.logsumexp(pi.logits + loglik, dim=-1)
    Ls = Ls * mask
    # Extract pen channel and round if needed, then convert to long
    yp = targets[:, :, 2].round().long()  # Ensure integer values in {0,1,2}
    Lp = -q.log_prob(yp)
    return Ls + Lp

loss = safe_loss(model, x, targets, mask, tau=1.0, class_idx=class_idx)
print("Loss:", loss.mean().item())

# Generating a sequence (starting with an initial seed)
x_seed = x[:1]  # starting with one sample
# break_eos=False forces all 50 new steps so the output shape is predictable.
generated = model.generate(x_seed, max_new_tokens=50, tau=1.0, class_idx=class_idx[:1], break_eos=False)
print("Generated sequence shape:", generated.shape)


Loss: 4.260533332824707
Generated sequence shape: torch.Size([1, 226, 3])


In [None]:
# Fresh model for training (default positional table size), moved to the target device.
model = TransformerModel(num_classes=num_classes).to(device)
print(f"Model has {sum(param.numel() for param in model.parameters())} parameters")


Model has 10798187 parameters


In [None]:
# Shape check: draw one real training batch and sample a prediction for every position.
X, Y, mask = get_batch("train")
# Generate dummy class indices for each sample in the batch
# NOTE(review): these indices are drawn at random; they are not the categories
# the sampled sketches actually belong to.
class_idx = torch.randint(0, num_classes, (X.shape[0],)).to(device)
# Sample predictions conditioned on the class indices
Y_pred = model.sample(X, tau=1.0, class_idx=class_idx)
print(X.shape, Y.shape, Y_pred.shape)


torch.Size([128, 176, 3]) torch.Size([128, 175, 3]) torch.Size([128, 176, 3])


Training loop with periodic evaluation and sample display

In [None]:
@torch.no_grad()
def estimate_loss():
    """
    Average the model loss over ``eval_iters`` random batches per split.

    Returns a dict with keys ``'train'`` and ``'val'`` holding scalar tensors.
    The caller is expected to have put the model in eval mode.
    """
    out = {}
    # (report key, get_batch split name). The validation split is called
    # "valid" in get_batch; the report key stays 'val' for the training log.
    for key, split in (('train', 'train'), ('val', 'valid')):
        losses = torch.zeros(eval_iters, device=device)
        for k in range(eval_iters):
            X, Y, mask = get_batch(split)
            # NOTE(review): class indices are drawn at random rather than
            # taken from the sketches' real categories.
            class_idx = torch.randint(0, num_classes, (X.shape[0],), device=device)
            loss = model.loss(X, Y, mask, tau=1.0, class_idx=class_idx)
            losses[k] = loss.mean()
        out[key] = losses.mean()
    return out

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Optionally: lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=training_iters)

# `step` rather than `iter`, so the builtin iter() is not shadowed for the
# rest of the notebook session.
for step in range(training_iters):
    if step % eval_interval == 0:
        model.eval()
        losses = estimate_loss()
        print(f'step {step}: lr {optimizer.param_groups[0]["lr"]:.6f}, train loss {losses["train"]:.4f}, val loss {losses["val"]:.4f}')

        # Display generated samples:
        n_samples = 9  # 9 samples gives 3 per row
        svg_samples = []
        for _ in range(n_samples):
            # Use a zero seed of shape (1,1,3)
            seed = torch.zeros(1, 1, 3, device=device)
            # Pick a random class to condition on and look up its label
            gen_class_idx = torch.randint(0, num_classes, (1,), device=device)
            class_label = data_classes[gen_class_idx.item()]
            # Generate until end-of-drawing or until the sequence reaches block_size points
            generated = model.generate(seed, max_new_tokens=block_size - 1, tau=0.4, class_idx=gen_class_idx, break_eos=True)[0]
            # Small factor and thicker stroke produce a larger, more legible image
            svg = draw_strokes(generated.cpu().numpy(), factor=0.02, stroke_width=1)
            # Wrap each SVG in a container with the class label
            svg_samples.append(f'<div style="display:inline-block; margin:10px; text-align:center;"><div>{class_label}</div>{svg}</div>')
        html = '<div style="display:flex; flex-wrap:wrap;">' + ''.join(svg_samples) + '</div>'
        display(HTML(html))

        model.train()

    xb, yb, mask = get_batch('train')
    # NOTE(review): class_idx is random and unrelated to the sketches in xb,
    # so the class conditioning receives no real label signal here.
    class_idx = torch.randint(0, num_classes, (xb.shape[0],), device=device)
    loss = model.loss(xb, yb, mask, tau=1.0, class_idx=class_idx).mean()
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    # NOTE(review): the logged loss jumps from about -1.2 to +1.9 at step
    # 24000; consider gradient clipping and/or the scheduler above.
    optimizer.step()
    # Optionally: lr_scheduler.step()

# Save the weights first, so they are on disk even if the Colab-only
# download helper below is unavailable.
model_filename = f'model_{"_".join(data_classes)}.pth'
torch.save(model.state_dict(), model_filename)
from google.colab import files
files.download(model_filename)


step 0: lr 0.000050, train loss 1.5864, val loss 1.5822


step 2000: lr 0.000050, train loss -1.0444, val loss -1.0360


step 4000: lr 0.000050, train loss -1.0541, val loss -1.0722


step 6000: lr 0.000050, train loss -0.9980, val loss -0.9959


step 8000: lr 0.000050, train loss -0.8437, val loss -0.8406


step 10000: lr 0.000050, train loss -1.0575, val loss -1.0568


step 12000: lr 0.000050, train loss -1.0159, val loss -1.0136


step 14000: lr 0.000050, train loss -1.1546, val loss -1.1338


step 16000: lr 0.000050, train loss -0.8279, val loss -0.8398


step 18000: lr 0.000050, train loss -0.9599, val loss -0.9565


step 20000: lr 0.000050, train loss -0.6194, val loss -0.6156


step 22000: lr 0.000050, train loss -1.1955, val loss -1.1787


step 24000: lr 0.000050, train loss 1.8894, val loss 1.8875
