<a href="https://colab.research.google.com/github/Tanyabharara/draw_and_learn_sketchrnn/blob/main/Untitled5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# =====================================================
# 1️⃣ Imports
# =====================================================
import os
import urllib.request
import numpy as np
from tqdm import tqdm
from torch import nn, optim
import torch
from torch.utils.data import Dataset, DataLoader, random_split

# =====================================================
# 2️⃣ Configurations
# =====================================================
DATA_DIR = "quickdraw_npy"      # local cache for the per-category .npy bitmaps
SAMPLES_PER_CLASS = 2000        # drawings sampled from each category
IMG_SIZE = 28                   # bitmaps are reshaped to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 64
EPOCHS = 50
PATIENCE = 7                    # epochs without val improvement before stopping
TARGET_LOSS = 0.01              # autoencoder stops once val loss reaches this

# ✅ Categories (51) — underscores stand in for the spaces used in the
# QuickDraw file names (the downloader converts them back).
categories = """
    airplane ant apple backpack banana bed bicycle bird book bread
    bus cake camera car cat chair clock computer cup dog door
    duck envelope eye face fish flower fork frog guitar hat horse
    house lollipop pencil pig pizza rabbit shoe snake spider spoon
    star sun train tree truck umbrella alarm_clock birthday_cake butterfly
""".split()
print("📦 Total categories:", len(categories))

os.makedirs(DATA_DIR, exist_ok=True)

# =====================================================
# 3️⃣ Download .npy Files with Error Handling
# =====================================================
def download_drawings():
    """Download the QuickDraw bitmap ``.npy`` file for every entry in ``categories``.

    Files that already exist in DATA_DIR are skipped. Each file is fetched to a
    temporary ``<name>.npy.part`` path first and only renamed to its final name
    once the transfer finished. An interrupted or failed download therefore
    never leaves a truncated ``.npy`` behind that a later run would mistake for
    a complete file (the ``exists`` check below would otherwise skip it forever).

    Returns:
        list[str]: categories whose download failed (empty when all succeeded).
    """
    base_url = "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap"
    failed = []
    for cat in tqdm(categories, desc="⏬ Downloading categories"):
        file_path = os.path.join(DATA_DIR, f"{cat}.npy")
        if os.path.exists(file_path):
            continue
        # Local names use underscores; the bucket's file names use spaces.
        url = f"{base_url}/{cat.replace('_', '%20')}.npy"
        tmp_path = file_path + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_path)
            # Atomic rename: the final path only ever holds a complete file.
            os.replace(tmp_path, file_path)
        except Exception as e:  # best-effort: report and continue with the next category
            print(f"❌ Failed to download {cat}: {e}")
            failed.append(cat)
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    print("✅ Download complete!")
    if failed:
        print(f"⚠️ Failed downloads: {failed}")
    return failed

download_drawings()

# =====================================================
# 4️⃣ Dataset Class
# =====================================================
class SketchRNNDataset(Dataset):
    """In-memory dataset of QuickDraw bitmaps, sampled evenly per category.

    For each category whose ``<root_dir>/<category>.npy`` file is readable and
    holds at least ``samples_per_class`` drawings, that many drawings are
    sampled without replacement. A drawing's label is the index of its
    category in ``categories``. Skipped categories do not shift the labels of
    later ones, so labels always index into the full ``categories`` list.

    Pixel values are stored as float32 scaled by 1/255. Items are returned as
    ``(tensor of shape (1, img_size, img_size), int label)``.

    Args:
        root_dir: directory holding the per-category ``.npy`` files.
        categories: category names, in label order.
        samples_per_class: drawings to sample from each category.
        img_size: side length of the square bitmaps; ``None`` uses IMG_SIZE.
        seed: optional sampling seed. ``None`` keeps using NumPy's global RNG,
            so ``np.random.seed`` still controls sampling as before.

    Raises:
        ValueError: if no category could be loaded at all.
    """

    def __init__(self, root_dir, categories, samples_per_class, img_size=None, seed=None):
        if img_size is None:
            img_size = IMG_SIZE
        rng = np.random if seed is None else np.random.default_rng(seed)
        chunks, self.labels = [], []
        for label, cat in enumerate(categories):
            file_path = os.path.join(root_dir, f"{cat}.npy")
            if not os.path.exists(file_path):
                print(f"⚠️ Skipping {cat}: {file_path} not found")
                continue
            try:
                # mmap: only the sampled rows are read from disk, not the whole file.
                drawings = np.load(file_path, mmap_mode="r").reshape(-1, img_size, img_size)
            except (OSError, ValueError) as e:  # truncated/corrupt download, wrong size
                print(f"⚠️ Skipping {cat}: could not read {file_path} ({e})")
                continue
            if len(drawings) < samples_per_class:
                print(f"⚠️ Skipping {cat}: only {len(drawings)} drawings")
                continue
            idx = rng.choice(len(drawings), samples_per_class, replace=False)
            chunks.append(np.asarray(drawings[idx]))
            self.labels.extend([label] * samples_per_class)
        if not chunks:
            raise ValueError(f"No usable category files found in {root_dir!r}")
        self.data = np.concatenate(chunks).astype(np.float32) / 255.0

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx]).unsqueeze(0)  # (1, H, W) channel-first
        y = self.labels[idx]
        return x, y

# =====================================================
# 5️⃣ Prepare Dataset
# =====================================================
SEED = 42  # fixes per-class sampling and the train/val/test partition

# Seed NumPy's global RNG so the per-class sampling inside the dataset is
# repeatable across runs.
np.random.seed(SEED)
dataset = SketchRNNDataset(DATA_DIR, categories, SAMPLES_PER_CLASS)
train_size = int(0.8 * len(dataset))
val_size   = int(0.1 * len(dataset))
test_size  = len(dataset) - train_size - val_size

# A seeded generator makes the 80/10/10 split reproducible. Unseeded, each
# re-run would move different drawings into the held-out sets.
split_generator = torch.Generator().manual_seed(SEED)
train_data, val_data, test_data = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_data, batch_size=BATCH_SIZE)
test_loader  = DataLoader(test_data, batch_size=BATCH_SIZE)

print(f"✅ Dataset ready! Train: {len(train_data)} | Val: {len(val_data)} | Test: {len(test_data)}")

# =====================================================
# 6️⃣ Sketch-RNN Autoencoder
# =====================================================
class SketchRNN(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 sketches.

    Despite the name there is no recurrent layer. Two stride-2 convolutions
    and a linear layer encode the image. A linear bottleneck produces a
    128-d code. A mirrored linear + transposed-convolution stack decodes it
    back to a 1x28x28 image squashed into (0, 1) by a Sigmoid.

    Submodule names and Sequential indices are part of the saved state_dict
    layout, so they are kept fixed.
    """

    def __init__(self):
        super().__init__()
        feat_shape = (64, 7, 7)
        feat_dim = feat_shape[0] * feat_shape[1] * feat_shape[2]
        hidden, code = 256, 128

        # Encoder: 1x28x28 -> 32x14x14 -> 64x7x7 -> hidden features
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, feat_shape[0], 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(feat_dim, hidden),
            nn.ReLU(),
        )
        # Linear bottleneck (no activation) -> latent code
        self.latent = nn.Linear(hidden, code)
        # Latent code -> flattened 64x7x7 feature map
        self.decoder_fc = nn.Sequential(
            nn.Linear(code, hidden),
            nn.ReLU(),
            nn.Linear(hidden, feat_dim),
            nn.ReLU(),
        )
        # Decoder: 64x7x7 -> 32x14x14 -> 1x28x28
        self.decoder_conv = nn.Sequential(
            nn.Unflatten(1, feat_shape),
            nn.ConvTranspose2d(feat_shape[0], 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Reconstruct a (N, 1, 28, 28) batch through the latent bottleneck."""
        h = self.encoder(x)
        h = self.latent(h)
        h = self.decoder_fc(h)
        return self.decoder_conv(h)

# =====================================================
# 7️⃣ Train Setup
# =====================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SketchRNN().to(device)
criterion = nn.MSELoss()  # reconstruction error against the input itself
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# =====================================================
# 8️⃣ Training Loop with Early Stopping
# =====================================================
AE_CHECKPOINT = "best_sketch_rnn.pth"
best_val_loss = float('inf')
patience_counter = 0
# Only reload a checkpoint this run actually wrote (never a stale file from an
# earlier run, e.g. if every val loss were NaN and nothing was saved).
saved_checkpoint = False

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    for x, _ in train_loader:
        x = x.to(device)
        optimizer.zero_grad()
        recon = model(x)
        loss = criterion(recon, x)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.size(0)  # sum per-sample loss for an exact epoch mean
    avg_train_loss = train_loss / len(train_loader.dataset)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, _ in val_loader:
            x = x.to(device)
            recon = model(x)
            loss = criterion(recon, x)
            val_loss += loss.item() * x.size(0)
    avg_val_loss = val_loss / len(val_loader.dataset)

    print(f"📊 Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), AE_CHECKPOINT)
        saved_checkpoint = True
        print("💾 Saved new best model.")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("⏹️ Early stopping triggered.")
            break

    if avg_val_loss <= TARGET_LOSS:
        print(f"🏁 Target loss {TARGET_LOSS} reached.")
        break

# Early stopping breaks PATIENCE epochs after the best epoch, and a full run
# ends on whatever the last epoch produced. Either way the in-memory weights
# may not be the best ones, so restore the best checkpoint.
if saved_checkpoint:
    model.load_state_dict(torch.load(AE_CHECKPOINT, map_location=device))
    print(f"↩️ Restored best model (val loss {best_val_loss:.4f}).")

print("✅ Training complete.")


📦 Total categories: 51


⏬ Downloading categories:  47%|████▋     | 24/51 [02:45<03:03,  6.81s/it]

In [None]:
# =====================================================
# 6️⃣ SketchRNN Classifier (instead of Autoencoder)
# =====================================================
class SketchRNNClassifier(nn.Module):
    """CNN classifier for single-channel 28x28 sketches (no recurrent layers).

    The encoder has the same layer layout as the autoencoder's encoder. A
    single linear head maps its 256-d features to ``num_classes`` raw logits,
    which are meant to be fed to CrossEntropyLoss.

    Args:
        num_classes: number of output logits. The default ``len(categories)``
            is evaluated once, when this class is defined.
    """

    def __init__(self, num_classes=len(categories)):
        super().__init__()
        layers = [
            nn.Conv2d(1, 32, 3, stride=2, padding=1),   # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 14x14 -> 7x7
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 256),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*layers)
        self.fc = nn.Linear(256, num_classes)  # classification head

    def forward(self, x):
        """Return unnormalized class logits of shape (N, num_classes)."""
        return self.fc(self.encoder(x))


In [None]:
# Swap in the classifier: rebinds `model`, `criterion` and `optimizer` for the cells below.
model = SketchRNNClassifier().to(device)
criterion = nn.CrossEntropyLoss()  # takes raw logits + integer class labels
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


In [None]:
# Fresh early-stopping state for the classifier. `best_val_loss` left over from
# the autoencoder cell is an MSE value on a different scale, so comparing
# cross-entropy against it would be meaningless.
CLS_CHECKPOINT = "best_sketch_classifier.pth"  # separate file: don't clobber the autoencoder's
best_val_loss = float("inf")
patience_counter = 0
saved_checkpoint = False  # only reload a checkpoint written by this run

for epoch in range(EPOCHS):
    # ---- Train ----
    model.train()
    train_loss, correct, total = 0, 0, 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.size(0)

        _, predicted = outputs.max(1)
        correct += (predicted == y).sum().item()
        total += y.size(0)

    avg_train_loss = train_loss / len(train_loader.dataset)
    train_acc = correct / total

    # ---- Validation ----
    model.eval()
    val_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            loss = criterion(outputs, y)
            val_loss += loss.item() * x.size(0)

            _, predicted = outputs.max(1)
            correct += (predicted == y).sum().item()
            total += y.size(0)

    avg_val_loss = val_loss / len(val_loader.dataset)
    val_acc = correct / total

    print(f"📊 Epoch {epoch+1}/{EPOCHS} | "
          f"Train Loss: {avg_train_loss:.4f}, Acc: {train_acc:.4f} | "
          f"Val Loss: {avg_val_loss:.4f}, Acc: {val_acc:.4f}")

    # ---- Checkpoint / early stopping (same policy as the autoencoder loop) ----
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), CLS_CHECKPOINT)
        saved_checkpoint = True
        print("💾 Saved new best model.")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("⏹️ Early stopping triggered.")
            break

# Evaluate the best-by-validation weights, not whatever the last epoch produced.
if saved_checkpoint:
    model.load_state_dict(torch.load(CLS_CHECKPOINT, map_location=device))
    print(f"↩️ Restored best model (val loss {best_val_loss:.4f}).")


In [None]:
# =====================================================
# 9️⃣ Evaluate on Test Set
# =====================================================
# Top-1 accuracy of the current `model` over the held-out test split.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        y = y.to(device)
        predicted = model(x).argmax(dim=1)  # index of the highest logit per sample
        correct += (predicted == y).sum().item()
        total += y.size(0)

test_acc = correct / total
print(f"✅ Test Accuracy: {test_acc:.4f}")
