<div style="display: flex; justify-content: center; margin: 0 auto;" align="center">
  <img src="https://myth-ai.com/wp-content/uploads/2023/05/646f153be1e56.png" href="https://myth-ai.com/" width="100px" align="center">
  <h1>Technical Assignment</h1>
</div>

<div align="center">
  <h2>
  Sketch Generation via Diffusion Models using Sequential Strokes
  </h2>
</div>


<div align="center">
  <img src="https://github.com/googlecreativelab/quickdraw-dataset/blob/master/preview.jpg?raw=true">
  <figcaption>
    Collection of 50 million drawings across 345 categories, contributed by players of the game Quick, Draw!. Drawings were captured as timestamped vectors.
    <i>Source: <a href="https://quickdraw.withgoogle.com/data/">Quick, Draw! Dataset</a>.</i>
  </figcaption>
</div>

---

## Objective

In this project, you are expected to implement a **conditional generative diffusion model** that learns to generate hand-drawn sketches in a **stroke-by-stroke** sequential manner. Rather than generating the entire sketch at once, your model should mimic the **sequential nature of human drawing**, producing strokes one after another in a realistic and interpretable way.

You will use the [Quick, Draw!](https://quickdraw.withgoogle.com/data/) dataset released by Google, which provides timestamped vector representations of user-drawn sketches across 345 object categories.

---

## Brief Explanation

You will design and train a **separate conditional diffusion model** for each of the following three categories:

- `cat`
- `bus`
- `rabbit`

Each model must learn to generate sketches of its category from **sequential stroke data**, which means you will build **three separate models** in total, one per category.

Your implementation must be documented in a reproducible Jupyter notebook, including training steps, visualizations, and both qualitative and quantitative evaluations.

- Include comprehensive documentation of your approach and design decisions.
- Provide clear training procedures, model architecture explanations, and inference code.
- Ensure full reproducibility (running all cells should yield consistent results with fixed random seeds).

---

## Data Specification

The Quick, Draw! dataset includes over 50 million sketches in vector format, with each sketch consisting of multiple strokes, where each stroke is a sequence of coordinates (`x`, `y`) along with timing information.

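You can download the `.ndjson` files with the commands in this [section](#cell-id1). They fetch the required categories (`cat`, `bus`, `rabbit`) into the `./data` directory. These are the *simplified* files: each drawing is aligned and uniformly scaled to a 0 to 255 coordinate range, and the timing information is removed.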

**⚠️ Note:** If you're not using Google Colab or Kaggle, make sure you have `gsutil` installed. You can install it via pip:

```bash
pip install gsutil
```

**⚠️ Important:** The dataset files are in [NDJSON](https://github.com/ndjson/ndjson-spec) format. Make sure to install the ndjson Python module before attempting to parse the files.

```bash
pip install ndjson
```
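Each line of an NDJSON file is a standalone JSON object, so the standard-library `json` module can also parse the files line by line, which is what the loading cells in this notebook do. The sketch below shows that pattern. `read_ndjson` is a hypothetical helper, and the two sample records are made up and shortened. Real simplified records also carry fields such as `key_id`, `countrycode` and `timestamp`.

```python
import io
import json


def read_ndjson(fileobj, limit=None):
    """Parse an NDJSON stream: one JSON object per non-empty line."""
    records = []
    for line in fileobj:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines, e.g. a trailing newline
        records.append(json.loads(line))
        if limit is not None and len(records) >= limit:
            break  # handy for quick experiments on a few sketches
    return records


# Made-up sample in the shape of the simplified format:
# "drawing" is a list of strokes, each stroke is [[x0, x1, ...], [y0, y1, ...]].
sample = io.StringIO(
    '{"word": "cat", "recognized": true, "drawing": [[[0, 50, 100], [0, 40, 0]]]}\n'
    '{"word": "cat", "recognized": false, "drawing": [[[10, 20], [5, 5]], [[0, 0], [0, 9]]]}\n'
)
records = read_ndjson(sample)
print(len(records), len(records[1]["drawing"]))  # 2 2
```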

### Train/Test Subsets for Target Categories

After downloading the dataset in the `./data` directory, extract the provided `subset.zip` file. This archive includes the predefined train/test splits for each of the three categories.

```
subset/
├── cat/
│  └── indices.json
├── bus/
│  └── indices.json
└── rabbit/
   └── indices.json
```

Each `indices.json` file contains a JSON object with two keys:

- `"train"`: list of indices for training
- `"test"`: list of indices for testing

**⚠️ Important:** Strictly adhere to these predefined splits for consistent evaluation.
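The split files are easy to misuse silently. For example, the indices can be paired with the wrong `.ndjson` file. The sketch below loads a split and fails loudly on overlap or out-of-range indices. `load_split` is a hypothetical helper name, and the demo uses a throwaway indices file purely for illustration.

```python
import json
import os
import tempfile


def load_split(sketches, indices_path):
    """Select train/test sketches using a predefined indices.json file.

    Raises if the split overlaps or points outside the loaded sketches,
    which usually means the .ndjson file and indices.json do not match.
    """
    with open(indices_path, "r") as f:
        indices = json.load(f)
    train_idx, test_idx = indices["train"], indices["test"]

    overlap = set(train_idx) & set(test_idx)
    if overlap:
        raise ValueError(f"{len(overlap)} indices appear in both train and test")
    largest = max(train_idx + test_idx, default=-1)
    if largest >= len(sketches):
        raise IndexError(f"index {largest} out of range for {len(sketches)} sketches")

    return [sketches[i] for i in train_idx], [sketches[i] for i in test_idx]


# Usage with a throwaway indices file (illustration only).
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
    json.dump({"train": [0, 2], "test": [1]}, tmp)
demo_train, demo_test = load_split(["s0", "s1", "s2"], tmp.name)
os.remove(tmp.name)
print(demo_train, demo_test)  # ['s0', 's2'] ['s1']
```

In the notebook, this would be called as `cat_train, cat_test = load_split(cat_sketches, "./subset/cat/indices.json")`.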


---


## Evaluation

You must evaluate your model both **qualitatively** and **quantitatively**.

### Quantitative Evaluation

Use the following metrics to compare the real test set sketches with those generated by your model:

- **FID (Fréchet Inception Distance)**
- **KID (Kernel Inception Distance)**

These metrics should be computed **separately for each category** using the sketches indexed under the `"test"` key in each category’s `indices.json` file.

> **Final submission must include three FID and three KID scores—one pair per category.**
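FID and KID compare feature statistics rather than pixels. Both sets of sketches are rasterized and embedded with a pretrained network (conventionally Inception-v3's 2048-dimensional pooled features), and then the two feature sets are compared. The NumPy sketch below assumes the feature matrices have already been extracted. `frechet_distance` and `kernel_distance` are hypothetical names. FID is the Fréchet distance between Gaussians fitted to each set. KID is the unbiased squared MMD with the cubic polynomial kernel (xᵀy/d + 1)³. For the reported numbers, a maintained implementation such as `torchmetrics`' `FrechetInceptionDistance` and `KernelInceptionDistance` is likely the safer choice.

```python
import numpy as np


def frechet_distance(feats_real, feats_gen):
    """FID between two (n, d) feature matrices, via Gaussian fits."""
    mu1, mu2 = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    s1 = np.cov(feats_real, rowvar=False)
    s2 = np.cov(feats_gen, rowvar=False)
    # tr(sqrt(S1 S2)) is the sum of square roots of the eigenvalues of S1 S2,
    # which are real and non-negative for PSD S1, S2 (clip numerical noise).
    eigvals = np.linalg.eigvals(s1 @ s2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(s1) + np.trace(s2) - 2.0 * tr_sqrt)


def kernel_distance(feats_real, feats_gen):
    """KID: unbiased MMD^2 with the cubic polynomial kernel (x.y / d + 1)^3."""
    d = feats_real.shape[1]
    k_xx = (feats_real @ feats_real.T / d + 1.0) ** 3
    k_yy = (feats_gen @ feats_gen.T / d + 1.0) ** 3
    k_xy = (feats_real @ feats_gen.T / d + 1.0) ** 3
    m, n = len(feats_real), len(feats_gen)
    # Drop the diagonal (self-similarity) terms for the unbiased estimate.
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    return float(term_xx + term_yy - 2.0 * k_xy.mean())


# Synthetic features only, to sanity-check the two functions.
rng = np.random.default_rng(0)
real = rng.normal(size=(500, 16))
same = rng.normal(size=(500, 16))  # same distribution, different samples
fid_identical = frechet_distance(real, real)
fid_shifted = frechet_distance(real, real + 1.0)  # mean shift of 1 in all 16 dims
```

A useful sanity check follows from the formula. Shifting every feature by 1 leaves the covariance unchanged, so `fid_shifted` should equal the squared length of the shift, 16.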

### Qualitative Evaluation

Provide visual demonstrations including:

- Sample generated sketches for each category.
- Three animated GIFs (one per category) that show the stroke-by-stroke generation process, similar to the provided `example.gif` file; these are required in your submission.
- Comparison between real and generated sketches.
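Both the stroke-by-stroke GIFs and the FID/KID inputs need sketches as images. The NumPy sketch below assumes strokes are lists of (x, y) points in [-1, 1], as produced by this notebook's preprocessing. `rasterize`, `stroke_frames` and `save_gif` are hypothetical helpers, and `save_gif` assumes Pillow is installed. For Inception-based metrics, the grayscale canvas would additionally need to be replicated to three channels and resized.

```python
import numpy as np


def rasterize(strokes, size=64, pad=2):
    """Draw polylines with coordinates in [-1, 1] onto a size x size canvas.

    Returns a uint8 image: 0 = background, 255 = ink (white on black).
    """
    canvas = np.zeros((size, size), dtype=np.uint8)
    scale = (size - 1 - 2 * pad) / 2.0

    def to_px(v):
        px = np.round((np.asarray(v, dtype=float) + 1.0) * scale + pad)
        return np.clip(px, 0, size - 1).astype(int)

    for stroke in strokes:
        pts = np.asarray(stroke, dtype=float).reshape(-1, 2)
        if len(pts) == 1:
            pts = np.vstack([pts, pts])  # draw a lone point as a dot
        for (x0, y0), (x1, y1) in zip(pts[:-1], pts[1:]):
            # Sample enough points along the segment to leave no pixel gaps.
            n = int(np.ceil(max(abs(x1 - x0), abs(y1 - y0)) * scale)) + 1
            xs = to_px(np.linspace(x0, x1, n))
            ys = to_px(np.linspace(y0, y1, n))
            canvas[ys, xs] = 255  # row = y, so +y points down like the inverted plots
    return canvas


def stroke_frames(strokes, size=64):
    """One frame per stroke: frame k shows strokes 0..k."""
    return [rasterize(strokes[: k + 1], size) for k in range(len(strokes))]


def save_gif(frames, path, ms_per_frame=300):
    """Write frames as an animated GIF (requires Pillow)."""
    from PIL import Image  # lazy import: the helpers above need only NumPy

    images = [Image.fromarray(f) for f in frames]
    images[0].save(path, save_all=True, append_images=images[1:],
                   duration=ms_per_frame, loop=0)


demo_strokes = [[(-0.5, -0.5), (0.5, -0.5)], [(0.0, -0.5), (0.0, 0.5)]]
frames = stroke_frames(demo_strokes)
print(len(frames), frames[0].sum() < frames[-1].sum())  # 2 True
```

This only shows stroke-level steps. A GIF that also shows the denoising trajectory would instead render the intermediate `x_t` from the sampling loop.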


---


## Deliverables

Your submission should include the following:

- A well-structured **Jupyter Notebook** that:
  - Explains your approach and design decisions
  - Implements the conditional diffusion model
  - Includes training procedure and inference pipeline code
  - Presents both qualitative and quantitative results
  - Shows visual examples of generated sketches for each of the 3 categories
  - Includes animated GIFs demonstrating progressive sketch generation (similar to the provided `example.gif`)
  - Reports clearly computed FID/KID scores for each category
- Model performance analysis across categories
- Comparison of generated vs. real sketch characteristics
- Discussion of limitations and potential improvements


> 🔒 All visualizations must be based on sketches generated by your own model. Using samples from external sources will be considered **plagiarism** and will result in disqualification.

> 🔁 The notebook must be **fully reproducible**: running all cells from top to bottom should produce the same results (assuming fixed random seed).

---

## Acknowledgements

- [The Quick, Draw! Dataset](https://github.com/googlecreativelab/quickdraw-dataset)
- [Quick, Draw! Kaggle Competition](https://www.kaggle.com/c/quickdraw-doodle-recognition/overview)
- [Diffusion Models Overview (Lil’Log)](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)
- [Ha, D., & Eck, D. (2017). A neural representation of sketch drawings. arXiv preprint arXiv:1704.03477.](https://arxiv.org/pdf/1704.03477)
- Special thanks to M. Sung, KAIST

# Download the Quick, Draw! Dataset

<a name="cell-id1"></a>

In [None]:
# If you're not using Colab or Kaggle, uncomment the following line:
# !pip install gsutil

In [None]:
%pip install ndjson

In [None]:
!mkdir -p data
!gsutil -m cp 'gs://quickdraw_dataset/full/simplified/cat.ndjson' ./data
!gsutil -m cp 'gs://quickdraw_dataset/full/simplified/bus.ndjson' ./data
!gsutil -m cp 'gs://quickdraw_dataset/full/simplified/rabbit.ndjson' ./data
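If `gsutil` is unavailable, the same files can likely be fetched over plain HTTPS. The sketch below assumes the public bucket is served at `https://storage.googleapis.com/quickdraw_dataset/full/simplified/`; that address is an assumption, not something this notebook states. `quickdraw_url` and `download_category` are hypothetical helpers.

```python
import os
import urllib.parse
import urllib.request

# Assumption: the public quickdraw_dataset bucket is readable at this HTTPS address.
BASE_URL = "https://storage.googleapis.com/quickdraw_dataset/full/simplified/"


def quickdraw_url(category):
    """HTTPS URL of a category's simplified .ndjson file (spaces are escaped)."""
    return BASE_URL + urllib.parse.quote(f"{category}.ndjson")


def download_category(category, dest_dir="./data"):
    """Download one category unless it is already present; returns the path."""
    os.makedirs(dest_dir, exist_ok=True)
    path = os.path.join(dest_dir, f"{category}.ndjson")
    if not os.path.exists(path):
        tmp_path = path + ".part"
        urllib.request.urlretrieve(quickdraw_url(category), tmp_path)
        os.replace(tmp_path, path)  # only a complete download gets the final name
    return path


# for category in ["cat", "bus", "rabbit"]:
#     download_category(category)
```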

# Solution

- Briefly explain why you chose the method you did.
- Discuss the drawbacks and advantages of your chosen method.
- Evaluate and discuss the results for each metric.

# Section 1: Data Loading and Preprocessing

## Imports and Setup

In [None]:
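import json
import random

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# Fix every random seed used in this notebook for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)  # silently ignored when CUDA is unavailable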

In [None]:
# Load raw sketches
with open("./data/cat.ndjson", 'r') as f:
    cat_sketches = [json.loads(line) for line in f]

# Load train/test indices
with open("./subset/cat/indices.json", 'r') as f:
    cat_indices = json.load(f)

cat_train = [cat_sketches[i] for i in cat_indices["train"]]
cat_test = [cat_sketches[i] for i in cat_indices["test"]]

len(cat_train), len(cat_test)

In [None]:
# Take first training sketch
sketch = cat_train[0]
drawing = sketch["drawing"]

# Collect all points for normalization
all_points = []
for stroke in drawing:
    points = list(zip(stroke[0], stroke[1]))
    all_points.extend(points)

# Bounding box for normalization to [-1, 1] (x and y are scaled independently, so the aspect ratio is not preserved)
all_x, all_y = zip(*all_points)
min_x, max_x = min(all_x), max(all_x)
min_y, max_y = min(all_y), max(all_y)
width = max_x - min_x if max_x > min_x else 1
height = max_y - min_y if max_y > min_y else 1

# Build sequence: [command, x, y]
# 0: draw, 1: pen_up, 2: pen_down, 3: end
sequence = []

for stroke in drawing:
    sequence.append([2, 0, 0])  # pen_down

    for x, y in zip(stroke[0], stroke[1]):
        norm_x = 2 * (x - min_x) / width - 1
        norm_y = 2 * (y - min_y) / height - 1
        sequence.append([0, norm_x, norm_y])  # draw

    sequence.append([1, 0, 0])  # pen_up

sequence.append([3, 0, 0])  # end
sequence = torch.tensor(sequence, dtype=torch.float32)
sequence.shape

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))

current_stroke = []
pen_down = False

for command, x, y in sequence:
    if command == 2:  # pen_down
        pen_down = True
        current_stroke = []
    elif command == 1:  # pen_up
        if pen_down and len(current_stroke) > 1:
            xs, ys = zip(*current_stroke)
            ax.plot(xs, ys, 'b-', linewidth=2)
        pen_down = False
    elif command == 0 and pen_down:  # draw
        current_stroke.append([x, y])
    elif command == 3:  # end
        break

ax.set_xlim(-1.2, 1.2)
ax.set_ylim(-1.2, 1.2)
ax.set_aspect('equal')
ax.invert_yaxis()
ax.set_title("Preprocessed Cat Sketch")
ax.grid(True, alpha=0.3)
plt.show()
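The pen-state decoding loop in the cell above is repeated in two later plotting cells. The sketch below factors it into two hypothetical helpers. `sequence_to_strokes` decodes the notebook's `[command, x, y]` encoding and also keeps a stroke that is still open when the sequence ends, for example after truncation. Rounding the command value is an assumption aimed at generated sequences, whose command channel is continuous. `plot_strokes` reproduces the plotting style used above and needs only a matplotlib `Axes`.

```python
def sequence_to_strokes(sequence):
    """Decode [command, x, y] rows into a list of strokes (lists of (x, y)).

    Encoding as in this notebook: 0 = draw, 1 = pen_up, 2 = pen_down, 3 = end/pad.
    Accepts a list of rows or anything with .tolist() (e.g. a torch tensor).
    """
    rows = sequence.tolist() if hasattr(sequence, "tolist") else sequence
    strokes, current, pen_down = [], [], False
    for command, x, y in rows:
        command = int(round(command))  # tolerate non-integer generated values
        if command == 2:
            pen_down, current = True, []
        elif command == 1:
            if pen_down and current:
                strokes.append(current)
            pen_down, current = False, []
        elif command == 0 and pen_down:
            current.append((x, y))
        elif command == 3:
            break
    if pen_down and current:  # stroke cut off before its pen_up
        strokes.append(current)
    return strokes


def plot_strokes(ax, strokes, title=None):
    """Draw decoded strokes on a matplotlib Axes in the style used above."""
    for stroke in strokes:
        if len(stroke) > 1:
            xs, ys = zip(*stroke)
            ax.plot(xs, ys, "b-", linewidth=2)
    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_aspect("equal")
    ax.invert_yaxis()
    if title:
        ax.set_title(title)
    ax.grid(True, alpha=0.3)


demo = [[2, 0, 0], [0, -1.0, -1.0], [0, 1.0, 1.0], [1, 0, 0],
        [2, 0, 0], [0, 0.5, 0.5], [3, 0, 0]]
print(sequence_to_strokes(demo))
# [[(-1.0, -1.0), (1.0, 1.0)], [(0.5, 0.5)]]
```

With these helpers, a plotting cell reduces to calls such as `plot_strokes(ax, sequence_to_strokes(cat_sequences[0]), "Cat 1")`.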

In [None]:
cat_sequences = []
max_length = 200

for sketch in cat_train:
    drawing = sketch["drawing"]

    # Collect and normalize points
    all_points = []
    for stroke in drawing:
        all_points.extend(zip(stroke[0], stroke[1]))

    if not all_points:
        continue

    all_x, all_y = zip(*all_points)
    min_x, max_x = min(all_x), max(all_x)
    min_y, max_y = min(all_y), max(all_y)
    width = max_x - min_x if max_x > min_x else 1
    height = max_y - min_y if max_y > min_y else 1

    # Build sequence
    sequence = []
    for stroke in drawing:
        sequence.append([2, 0, 0])
        for x, y in zip(stroke[0], stroke[1]):
            norm_x = 2 * (x - min_x) / width - 1
            norm_y = 2 * (y - min_y) / height - 1
            sequence.append([0, norm_x, norm_y])
        sequence.append([1, 0, 0])
    sequence.append([3, 0, 0])

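    # Pad or truncate to max_length, always ending with an end token
    sequence = torch.tensor(sequence, dtype=torch.float32)
    if len(sequence) > max_length:
        sequence = sequence[:max_length]
        sequence[-1] = torch.tensor([3.0, 0.0, 0.0])  # truncated: force an end token
    else:
        padding = torch.tensor([[3.0, 0.0, 0.0]]).repeat(max_length - len(sequence), 1)
        sequence = torch.cat([sequence, padding])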

    cat_sequences.append(sequence)

len(cat_sequences)
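The same normalize, encode and pad logic appears in three cells: the single-sketch cell, the cell above, and the bus/rabbit loader below. The sketch below collects it into one hypothetical helper, `sketch_to_sequence`. It reproduces the notebook's encoding and always ends with an end token after truncation. The optional `keep_aspect` flag is an assumption and not part of the original pipeline. The notebook scales x and y independently, which stretches wide or tall drawings to fill the square. With `keep_aspect=True`, both axes share one scale factor and the drawing is centered.

```python
def sketch_to_sequence(drawing, max_length=200, keep_aspect=False):
    """Encode a Quick, Draw! drawing as max_length rows of [command, x, y].

    Commands: 0 draw, 1 pen_up, 2 pen_down, 3 end/pad. Coordinates go to [-1, 1].
    Returns None for drawings with no points.
    """
    xs = [x for stroke in drawing for x in stroke[0]]
    ys = [y for stroke in drawing for y in stroke[1]]
    if not xs:
        return None
    min_x, min_y = min(xs), min(ys)
    width = (max(xs) - min_x) or 1   # avoid division by zero for flat drawings
    height = (max(ys) - min_y) or 1
    if keep_aspect:
        side = max(width, height)
        min_x -= (side - width) / 2   # center the shorter axis
        min_y -= (side - height) / 2
        width = height = side

    sequence = []
    for stroke in drawing:
        sequence.append([2.0, 0.0, 0.0])  # pen_down
        for x, y in zip(stroke[0], stroke[1]):
            sequence.append([0.0, 2 * (x - min_x) / width - 1, 2 * (y - min_y) / height - 1])
        sequence.append([1.0, 0.0, 0.0])  # pen_up
    sequence.append([3.0, 0.0, 0.0])  # end

    if len(sequence) > max_length:
        sequence = sequence[: max_length - 1] + [[3.0, 0.0, 0.0]]
    else:
        sequence += [[3.0, 0.0, 0.0] for _ in range(max_length - len(sequence))]
    return sequence


demo_drawing = [[[0, 100], [0, 50]]]  # one stroke from (0, 0) to (100, 50)
seq = sketch_to_sequence(demo_drawing, max_length=6)
print(seq[:4])
# [[2.0, 0.0, 0.0], [0.0, -1.0, -1.0], [0.0, 1.0, 1.0], [1.0, 0.0, 0.0]]
print(sketch_to_sequence(demo_drawing, max_length=6, keep_aspect=True)[1:3])
# [[0.0, -1.0, -0.5], [0.0, 1.0, 0.5]]
```

In the notebook, `torch.tensor(sketch_to_sequence(s["drawing"]), dtype=torch.float32)` would yield the same tensors as the loop above, apart from the end token added on truncation.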

In [None]:
class SimpleDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]

cat_dataset = SimpleDataset(cat_sequences)
cat_loader = DataLoader(cat_dataset, batch_size=32, shuffle=True)

# Test batch
batch = next(iter(cat_loader))
batch.shape
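The notebook seeds the global RNG once, so the shuffle order is reproducible only if the cells run in exactly the same order. The sketch below gives each `DataLoader` its own seeded `torch.Generator`, which fixes each category's shuffle order independently of other cells. `generator` is a real `DataLoader` argument. `make_loader` is a hypothetical helper.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size=32, seed=42):
    """DataLoader whose shuffle order depends only on `seed`."""
    generator = torch.Generator()
    generator.manual_seed(seed)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=generator)


data = TensorDataset(torch.arange(20, dtype=torch.float32).reshape(10, 2))
first_a = next(iter(make_loader(data, batch_size=4)))[0]
first_b = next(iter(make_loader(data, batch_size=4)))[0]
print(torch.equal(first_a, first_b))  # True
```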

# Section 2: Data Analysis and Visualization

In [None]:
# Check sequence lengths distribution
seq_lengths = [len(seq[seq[:, 0] != 3]) for seq in cat_sequences]  # Exclude padding
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.hist(seq_lengths, bins=50, alpha=0.7)
plt.xlabel('Sequence Length')
plt.ylabel('Count')
plt.title('Distribution of Sequence Lengths')

plt.subplot(1, 2, 2)
plt.boxplot(seq_lengths)
plt.ylabel('Sequence Length')
plt.title('Sequence Length Statistics')

plt.tight_layout()
plt.show()

f"Min: {min(seq_lengths)}, Max: {max(seq_lengths)}, Mean: {np.mean(seq_lengths):.1f}"
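Because the sequences above have already been cut at 200 rows, this histogram cannot show how many sketches were truncated. The sketch below measures lengths on the raw drawings instead, so that `max_length` can be chosen from a coverage target. `encoded_length` and `coverage` are hypothetical helpers. They count rows the same way the encoding cell does: one per point, plus a pen_down and a pen_up per stroke, plus one end token.

```python
import numpy as np


def encoded_length(drawing):
    """Rows needed by the notebook's encoding before padding or truncation."""
    return sum(len(stroke[0]) + 2 for stroke in drawing) + 1


def coverage(lengths, max_length):
    """Fraction of sketches that fit in max_length rows without truncation."""
    return float(np.mean(np.asarray(lengths) <= max_length))


demo_drawings = [
    [[[0, 1, 2], [0, 1, 2]]],                   # 1 stroke, 3 points -> 6 rows
    [[[0, 1], [0, 1]], [[5, 6, 7], [5, 6, 7]]],  # 2 strokes, 5 points -> 10 rows
]
lengths = [encoded_length(d) for d in demo_drawings]
print(lengths, coverage(lengths, 8))  # [6, 10] 0.5
```

On real data, `lengths = [encoded_length(s["drawing"]) for s in cat_train]` followed by `np.percentile(lengths, 95)` would suggest a `max_length` that covers 95% of the training sketches.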

In [None]:
# Count strokes per sketch (pen_down commands)
stroke_counts = []
for seq in cat_sequences:
    pen_downs = (seq[:, 0] == 2).sum().item()
    stroke_counts.append(pen_downs)

plt.figure(figsize=(8, 5))
plt.hist(stroke_counts, bins=range(1, max(stroke_counts)+2), alpha=0.7)
plt.xlabel('Number of Strokes')
plt.ylabel('Count')
plt.title('Distribution of Stroke Counts in Cat Sketches')
plt.grid(True, alpha=0.3)
plt.show()

f"Strokes per sketch - Min: {min(stroke_counts)}, Max: {max(stroke_counts)}, Mean: {np.mean(stroke_counts):.1f}"

In [None]:
# Show the first 6 cat sketches
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()

for i, ax in enumerate(axes):
    seq = cat_sequences[i]

    current_stroke = []
    pen_down = False

    for command, x, y in seq:
        if command == 2:  # pen_down
            pen_down = True
            current_stroke = []
        elif command == 1:  # pen_up
            if pen_down and len(current_stroke) > 1:
                xs, ys = zip(*current_stroke)
                ax.plot(xs, ys, 'b-', linewidth=2)
            pen_down = False
        elif command == 0 and pen_down:  # draw
            current_stroke.append([x, y])
        elif command == 3:  # end
            break

    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_aspect('equal')
    ax.invert_yaxis()
    ax.set_title(f'Cat {i+1}')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Extract all drawing coordinates (command=0)
all_coords = []
for seq in cat_sequences:
    drawing_points = seq[seq[:, 0] == 0]  # Only drawing commands
    if len(drawing_points) > 0:
        all_coords.append(drawing_points[:, 1:])  # x, y coordinates

all_coords = torch.cat(all_coords)
x_coords = all_coords[:, 0].numpy()
y_coords = all_coords[:, 1].numpy()

plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.hist(x_coords, bins=50, alpha=0.7, color='red')
plt.xlabel('X Coordinate')
plt.ylabel('Count')
plt.title('X Coordinate Distribution')

plt.subplot(1, 3, 2)
plt.hist(y_coords, bins=50, alpha=0.7, color='green')
plt.xlabel('Y Coordinate')
plt.ylabel('Count')
plt.title('Y Coordinate Distribution')

plt.subplot(1, 3, 3)
plt.scatter(x_coords[::100], y_coords[::100], alpha=0.1, s=1)
plt.xlabel('X Coordinate')
plt.ylabel('Y Coordinate')
plt.title('Coordinate Scatter Plot')
plt.axis('equal')

plt.tight_layout()
plt.show()

In [None]:
# Load bus and rabbit data
categories = ['bus', 'rabbit']
all_data = {'cat': cat_sequences}

for category in categories:
    # Load sketches
    with open(f"./data/{category}.ndjson", 'r') as f:
        sketches = [json.loads(line) for line in f]

    with open(f"./subset/{category}/indices.json", 'r') as f:
        indices = json.load(f)

    train_sketches = [sketches[i] for i in indices["train"]]

    # Process sequences
    sequences = []
    for sketch in train_sketches:
        drawing = sketch["drawing"]

        all_points = []
        for stroke in drawing:
            all_points.extend(zip(stroke[0], stroke[1]))

        if not all_points:
            continue

        all_x, all_y = zip(*all_points)
        min_x, max_x = min(all_x), max(all_x)
        min_y, max_y = min(all_y), max(all_y)
        width = max_x - min_x if max_x > min_x else 1
        height = max_y - min_y if max_y > min_y else 1

        sequence = []
        for stroke in drawing:
            sequence.append([2, 0, 0])
            for x, y in zip(stroke[0], stroke[1]):
                norm_x = 2 * (x - min_x) / width - 1
                norm_y = 2 * (y - min_y) / height - 1
                sequence.append([0, norm_x, norm_y])
            sequence.append([1, 0, 0])
        sequence.append([3, 0, 0])

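        sequence = torch.tensor(sequence, dtype=torch.float32)
        if len(sequence) > max_length:
            sequence = sequence[:max_length]
            sequence[-1] = torch.tensor([3.0, 0.0, 0.0])  # truncated: force an end token
        else:
            padding = torch.tensor([[3.0, 0.0, 0.0]]).repeat(max_length - len(sequence), 1)
            sequence = torch.cat([sequence, padding])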

        sequences.append(sequence)

    all_data[category] = sequences

{k: len(v) for k, v in all_data.items()}

In [None]:
# Compare stroke counts across categories
plt.figure(figsize=(12, 4))

for i, (category, sequences) in enumerate(all_data.items()):
    stroke_counts = []
    for seq in sequences:
        pen_downs = (seq[:, 0] == 2).sum().item()
        stroke_counts.append(pen_downs)

    plt.subplot(1, 3, i+1)
    plt.hist(stroke_counts, bins=range(1, 20), alpha=0.7)
    plt.xlabel('Number of Strokes')
    plt.ylabel('Count')
    plt.title(f'{category.capitalize()} - Stroke Distribution')
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for i, (category, sequences) in enumerate(all_data.items()):
    seq = sequences[0]
    ax = axes[i]

    current_stroke = []
    pen_down = False

    for command, x, y in seq:
        if command == 2:  # pen_down
            pen_down = True
            current_stroke = []
        elif command == 1:  # pen_up
            if pen_down and len(current_stroke) > 1:
                xs, ys = zip(*current_stroke)
                ax.plot(xs, ys, 'b-', linewidth=2)
            pen_down = False
        elif command == 0 and pen_down:  # draw
            current_stroke.append([x, y])
        elif command == 3:  # end
            break

    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_aspect('equal')
    ax.invert_yaxis()
    ax.set_title(f'{category.capitalize()} Sample')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Create summary table
summary = {}
for category, sequences in all_data.items():
    seq_lengths = [len(seq[seq[:, 0] != 3]) for seq in sequences]
    stroke_counts = [(seq[:, 0] == 2).sum().item() for seq in sequences]

    summary[category] = {
        'count': len(sequences),
        'avg_length': np.mean(seq_lengths),
        'avg_strokes': np.mean(stroke_counts),
        'max_length': max(seq_lengths),
        'max_strokes': max(stroke_counts)
    }

import pandas as pd
pd.DataFrame(summary).T.round(2)

# Section 3: Model Architecture Design

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=200):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

In [None]:
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = t[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
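Both encodings above are sinusoidal, but they lay out their frequencies differently. Write $d$ for `d_model` and $h = d/2$. `PositionalEncoding` interleaves sine and cosine. `TimeEmbedding` concatenates all sines followed by all cosines, with a slightly different frequency spacing:

```latex
\begin{aligned}
\mathrm{PE}(p, 2i) &= \sin\!\left(p \cdot 10000^{-2i/d}\right), \qquad
\mathrm{PE}(p, 2i+1) = \cos\!\left(p \cdot 10000^{-2i/d}\right), \qquad i = 0, \dots, h-1,\\
\mathrm{TE}(t) &= \big[\sin(t\,\omega_0), \dots, \sin(t\,\omega_{h-1}),\; \cos(t\,\omega_0), \dots, \cos(t\,\omega_{h-1})\big],
\qquad \omega_k = 10000^{-k/(h-1)}.
\end{aligned}
```

The different ordering is harmless here because the time embedding passes through the learned linear layer `time_proj` before it is added to the tokens.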

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.linear1 = nn.Linear(d_model, d_model * 4)
        self.linear2 = nn.Linear(d_model * 4, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self attention
        x2 = self.norm1(x)
        x2, _ = self.self_attn(x2, x2, x2, attn_mask=mask)
        x = x + self.dropout(x2)

        # Feed forward
        x2 = self.norm2(x)
        x2 = self.linear2(F.relu(self.linear1(x2)))
        x = x + self.dropout(x2)

        return x

In [None]:
class SketchDiffusionModel(nn.Module):
    def __init__(self, d_model=256, nhead=8, num_layers=6, seq_len=200):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len

        # Input projection: [command, x, y] -> d_model
        self.input_proj = nn.Linear(3, d_model)

        # Time embedding
        self.time_embed = TimeEmbedding(d_model)
        self.time_proj = nn.Linear(d_model, d_model)

        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, seq_len)

        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, nhead) for _ in range(num_layers)
        ])

        # Output projection
        self.output_proj = nn.Linear(d_model, 3)

    def forward(self, x, t):
        # x: [batch, seq_len, 3]
        # t: [batch]

        batch_size, seq_len = x.shape[0], x.shape[1]

        # Input projection
        x = self.input_proj(x)  # [batch, seq_len, d_model]

        # Add time embedding
        t_emb = self.time_embed(t)  # [batch, d_model]
        t_emb = self.time_proj(t_emb)  # [batch, d_model]
        t_emb = t_emb.unsqueeze(1).expand(-1, seq_len, -1)  # [batch, seq_len, d_model]
        x = x + t_emb

        # Add positional encoding
        x = self.pos_encoding(x)

        # Transformer layers
        for layer in self.layers:
            x = layer(x)

        # Output projection
        x = self.output_proj(x)  # [batch, seq_len, 3]

        return x

In [None]:
# Create model instance
model = SketchDiffusionModel(d_model=256, nhead=8, num_layers=6)

# Test forward pass
batch_size = 4
seq_len = 200
x = torch.randn(batch_size, seq_len, 3)
t = torch.randint(0, 1000, (batch_size,))

output = model(x, t)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
class NoiseSchedule:
    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.num_timesteps = num_timesteps

        # Linear schedule
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        # Closed-form forward process q(x_t | x_0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)

noise_schedule = NoiseSchedule()
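For reference, the quantities precomputed by `NoiseSchedule` follow the standard DDPM definitions below. Timesteps are written 1-based here. The code indexes from 0, so `alphas_cumprod[t]` corresponds to $\bar\alpha_{t+1}$. The second line is what `forward_diffusion` implements. The last line is the objective computed by `diffusion_loss`.

```latex
\begin{aligned}
\alpha_t &= 1 - \beta_t, \qquad \bar\alpha_t = \prod_{s=1}^{t} \alpha_s, \qquad \bar\alpha_0 = 1,\\
q(x_t \mid x_0) &= \mathcal{N}\!\left(\sqrt{\bar\alpha_t}\, x_0,\; (1 - \bar\alpha_t) I\right)
\;\Longrightarrow\; x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I),\\
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\, \beta_t
\quad \text{(\texttt{posterior\_variance})},\\
\mathcal{L} &= \mathbb{E}_{x_0,\, t,\, \epsilon}\left[\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2\right].
\end{aligned}
```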

In [None]:
def forward_diffusion(x0, t, noise_schedule):
    """Add noise to data according to diffusion schedule"""
    device = x0.device

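    # Get noise schedule values (move the schedule and t to x0's device first,
    # so indexing works whether they start on CPU or GPU)
    t = t.to(device)
    sqrt_alphas_cumprod_t = noise_schedule.sqrt_alphas_cumprod.to(device)[t]
    sqrt_one_minus_alphas_cumprod_t = noise_schedule.sqrt_one_minus_alphas_cumprod.to(device)[t]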

    # Reshape for broadcasting
    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod_t.reshape(-1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod_t.reshape(-1, 1, 1)

    # Sample noise
    noise = torch.randn_like(x0)

    # Add noise
    x_t = sqrt_alphas_cumprod_t * x0 + sqrt_one_minus_alphas_cumprod_t * noise

    return x_t, noise

In [None]:
# Test forward diffusion with cat data
sample_batch = torch.stack(cat_sequences[:8])  # Get 8 samples
timesteps = torch.randint(0, 1000, (8,))

# Add noise
noisy_batch, noise = forward_diffusion(sample_batch, timesteps, noise_schedule)

print(f"Original batch shape: {sample_batch.shape}")
print(f"Noisy batch shape: {noisy_batch.shape}")
print(f"Noise shape: {noise.shape}")
print(f"Timesteps: {timesteps}")

In [None]:
def diffusion_loss(model, x0, noise_schedule, device):
    """Calculate diffusion loss"""
    batch_size = x0.shape[0]

    # Sample timesteps
    t = torch.randint(0, noise_schedule.num_timesteps, (batch_size,), device=device)

    # Add noise
    x_t, noise = forward_diffusion(x0, t, noise_schedule)

    # Predict noise
    predicted_noise = model(x_t, t)

    # Calculate loss (MSE between actual and predicted noise)
    loss = F.mse_loss(predicted_noise, noise)

    return loss

In [None]:
# Move data to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
sample_batch = sample_batch.to(device)

# Calculate loss
loss = diffusion_loss(model, sample_batch, noise_schedule, device)
print(f"Loss: {loss.item():.4f}")
print(f"Using device: {device}")

# Section 4: Training Pipeline

In [None]:
import torch.optim as optim
from tqdm import tqdm
import os

# Training hyperparameters
learning_rate = 1e-4
num_epochs = 50
batch_size = 16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Using device: {device}")

In [None]:
def train_epoch(model, dataloader, optimizer, noise_schedule, device):
    model.train()
    total_loss = 0

    for batch in tqdm(dataloader, desc="Training"):
        batch = batch.to(device)

        # Calculate loss
        loss = diffusion_loss(model, batch, noise_schedule, device)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [None]:
# Create cat dataloader
cat_tensor = torch.stack(cat_sequences)
cat_dataset = SimpleDataset(cat_tensor)
cat_loader = DataLoader(cat_dataset, batch_size=batch_size, shuffle=True)

print(f"Cat dataset size: {len(cat_dataset)}")
print(f"Number of batches: {len(cat_loader)}")

In [None]:
# Create and setup model for cats
cat_model = SketchDiffusionModel(d_model=256, nhead=8, num_layers=6).to(device)
cat_optimizer = optim.Adam(cat_model.parameters(), lr=learning_rate)

# Move noise schedule to device
noise_schedule.betas = noise_schedule.betas.to(device)
noise_schedule.alphas = noise_schedule.alphas.to(device)
noise_schedule.alphas_cumprod = noise_schedule.alphas_cumprod.to(device)
noise_schedule.alphas_cumprod_prev = noise_schedule.alphas_cumprod_prev.to(device)
noise_schedule.sqrt_alphas_cumprod = noise_schedule.sqrt_alphas_cumprod.to(device)
noise_schedule.sqrt_one_minus_alphas_cumprod = noise_schedule.sqrt_one_minus_alphas_cumprod.to(device)
noise_schedule.posterior_variance = noise_schedule.posterior_variance.to(device)

print(f"Cat model parameters: {sum(p.numel() for p in cat_model.parameters()):,}")

In [None]:
# Training loop for cats
cat_losses = []

for epoch in range(num_epochs):
    avg_loss = train_epoch(cat_model, cat_loader, cat_optimizer, noise_schedule, device)
    cat_losses.append(avg_loss)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

# Plot training curve
plt.figure(figsize=(10, 5))
plt.plot(cat_losses)
plt.title('Cat Model Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()

In [None]:
# Create models directory
os.makedirs('./models', exist_ok=True)

# Save cat model
torch.save({
    'model_state_dict': cat_model.state_dict(),
    'optimizer_state_dict': cat_optimizer.state_dict(),
    'loss_history': cat_losses,
    'epoch': num_epochs
}, './models/cat_model.pth')

print("Cat model saved successfully!")

In [None]:
# Prepare bus data
bus_tensor = torch.stack(all_data['bus'])
bus_dataset = SimpleDataset(bus_tensor)
bus_loader = DataLoader(bus_dataset, batch_size=batch_size, shuffle=True)

# Initialize bus model
bus_model = SketchDiffusionModel(d_model=256, nhead=8, num_layers=6).to(device)
bus_optimizer = optim.Adam(bus_model.parameters(), lr=learning_rate)

print(f"Bus dataset size: {len(bus_dataset)}")

In [None]:
# Training loop for bus
bus_losses = []

for epoch in range(num_epochs):
    avg_loss = train_epoch(bus_model, bus_loader, bus_optimizer, noise_schedule, device)
    bus_losses.append(avg_loss)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

# Plot training curve
plt.figure(figsize=(10, 5))
plt.plot(bus_losses)
plt.title('Bus Model Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()

In [None]:
# Save bus model
torch.save({
    'model_state_dict': bus_model.state_dict(),
    'optimizer_state_dict': bus_optimizer.state_dict(),
    'loss_history': bus_losses,
    'epoch': num_epochs
}, './models/bus_model.pth')

print("Bus model saved successfully!")

In [None]:
# Prepare rabbit data
rabbit_tensor = torch.stack(all_data['rabbit'])
rabbit_dataset = SimpleDataset(rabbit_tensor)
rabbit_loader = DataLoader(rabbit_dataset, batch_size=batch_size, shuffle=True)

# Initialize rabbit model
rabbit_model = SketchDiffusionModel(d_model=256, nhead=8, num_layers=6).to(device)
rabbit_optimizer = optim.Adam(rabbit_model.parameters(), lr=learning_rate)

print(f"Rabbit dataset size: {len(rabbit_dataset)}")

In [None]:
# Training loop for rabbit
rabbit_losses = []

for epoch in range(num_epochs):
    avg_loss = train_epoch(rabbit_model, rabbit_loader, rabbit_optimizer, noise_schedule, device)
    rabbit_losses.append(avg_loss)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

# Plot training curve
plt.figure(figsize=(10, 5))
plt.plot(rabbit_losses)
plt.title('Rabbit Model Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()

In [None]:
# Save rabbit model
torch.save({
    'model_state_dict': rabbit_model.state_dict(),
    'optimizer_state_dict': rabbit_optimizer.state_dict(),
    'loss_history': rabbit_losses,
    'epoch': num_epochs
}, './models/rabbit_model.pth')

print("Rabbit model saved successfully!")

In [None]:
# Plot all training curves together
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.plot(cat_losses, label='Cat', color='orange')
plt.plot(bus_losses, label='Bus', color='blue')
plt.plot(rabbit_losses, label='Rabbit', color='green')
plt.title('Training Loss Comparison')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(cat_losses[-20:], label='Cat', color='orange')
plt.plot(bus_losses[-20:], label='Bus', color='blue')
plt.plot(rabbit_losses[-20:], label='Rabbit', color='green')
plt.title('Last 20 Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Print final training statistics
final_losses = {
    'cat': cat_losses[-1],
    'bus': bus_losses[-1],
    'rabbit': rabbit_losses[-1]
}

print("Final Training Losses:")
for category, loss in final_losses.items():
    print(f"  {category.capitalize()}: {loss:.4f}")

print(f"\nTraining length: {num_epochs} epochs per model")
print("All models saved to ./models/ directory")
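The cells above train and save the models but do not yet sample from them. The sketch below shows standard DDPM ancestral sampling for a noise-predicting model, using the schedule attributes `NoiseSchedule` defines. `ddpm_sample` is a hypothetical helper. The demo uses a tiny 10-step schedule and a dummy model that always predicts zero noise, only to show that the loop runs.

```python
from types import SimpleNamespace

import torch


@torch.no_grad()
def ddpm_sample(model, schedule, shape, device="cpu", generator=None):
    """Ancestral DDPM sampling for an epsilon-predicting model.

    `schedule` needs 1-D tensors betas, alphas, alphas_cumprod and
    posterior_variance of length T; `model(x, t)` returns predicted noise.
    A `generator`, if given, must live on `device`.
    """
    if hasattr(model, "eval"):
        model.eval()  # disable dropout during sampling
    betas = schedule.betas.to(device)
    alphas = schedule.alphas.to(device)
    alphas_cumprod = schedule.alphas_cumprod.to(device)
    posterior_variance = schedule.posterior_variance.to(device)

    x = torch.randn(shape, device=device, generator=generator)
    for step in reversed(range(len(betas))):
        t = torch.full((shape[0],), step, device=device, dtype=torch.long)
        eps = model(x, t)
        # Posterior mean: (x - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
        mean = (x - betas[step] / torch.sqrt(1.0 - alphas_cumprod[step]) * eps) / torch.sqrt(alphas[step])
        if step > 0:
            noise = torch.randn(shape, device=device, generator=generator)
            x = mean + torch.sqrt(posterior_variance[step]) * noise
        else:
            x = mean  # no noise is added at the final step
    return x


# Tiny demo schedule and a dummy model (illustration only).
T = 10
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
demo_schedule = SimpleNamespace(betas=betas, alphas=alphas, alphas_cumprod=alphas_cumprod,
                                posterior_variance=betas * (1 - prev) / (1 - alphas_cumprod))
demo_samples = ddpm_sample(lambda x, t: torch.zeros_like(x), demo_schedule, (2, 5, 3),
                           generator=torch.Generator().manual_seed(0))
```

For a trained model, the call would look like `ddpm_sample(cat_model, noise_schedule, (16, 200, 3), device=device)`. After sampling, the command channel is continuous, so it would need rounding and clamping (for example `samples[..., 0].round().clamp(0, 3)`) before the pen-state decoding used in the plotting cells. How well this works depends on whether the model has learned near-integer command values, which is a known weakness of diffusing discrete tokens in the same Gaussian channel as the coordinates.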

# References

❗ Do not forget to include the references you used when filling out the notebook.

- []()