# Doodle Recognition with Siamese Network (Image Matching)

This notebook implements a **Siamese Network** approach for doodle recognition, which is fundamentally different from traditional classification.

## Key Differences from Standard Classification:
- **Learns similarity** instead of direct class labels
- **Creates embeddings** where similar doodles are close in vector space
- **Few-shot learning** - can recognize new categories with just a few examples
- **More flexible** - new categories can be added by supplying reference images, without retraining the network

## How It Works:
1. Two images are passed through the same network (shared weights)
2. The network outputs embeddings for each image
3. A contrastive loss pulls the embeddings of same-category pairs together and pushes embeddings of different-category pairs apart, up to a fixed margin
4. At inference, we compare the query image to reference images from each category


In [2]:
# Import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import os
from PIL import Image
import warnings
import random
from collections import defaultdict
import json
warnings.filterwarnings('ignore')

# PyTorch libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Sampler
import torchvision.transforms as transforms
from torchvision import models

from sklearn.model_selection import train_test_split
from sklearn.manifold import TSNE
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from tqdm import tqdm

# ---- Reproducibility -------------------------------------------------------
# One seed shared by Python's `random`, NumPy and PyTorch.
SEED = 42

np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    # Seed the current CUDA device, then every visible CUDA device.
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

# Disable the cuDNN autotuner (it may pick different algorithms per run) and
# request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# ---- Environment report ----------------------------------------------------
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    print("MPS (Apple Silicon) available: True")
else:
    print("Using CPU")


ModuleNotFoundError: No module named 'matplotlib'

In [None]:
# Setup device (GPU/MPS/CPU)
def setup_device():
    """Select the torch device to train on and report the choice.

    Preference order: CUDA GPU, then Apple-Silicon MPS, then CPU.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        print(f"✓ Using CUDA: {torch.cuda.get_device_name(0)}")
        return torch.device('cuda')

    # Older PyTorch builds have no `torch.backends.mps` at all.
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        print("✓ Using MPS (Apple Silicon)")
        return torch.device('mps')

    print("⚠ Using CPU")
    return torch.device('cpu')


device = setup_device()


## 1. Configuration


In [None]:
# ---- Model / optimisation hyper-parameters ---------------------------------
IMG_SIZE = 105          # Standard size for Siamese networks (from original paper)
EMBEDDING_DIM = 128     # Dimension of the learned embedding
BATCH_SIZE = 32         # Pairs per batch
LEARNING_RATE = 1e-4
NUM_EPOCHS = 30
MARGIN = 1.0            # Margin for contrastive loss

# ---- Data sub-sampling -----------------------------------------------------
MAX_CATEGORIES = 50         # Use subset for faster training (None = all)
IMAGES_PER_CATEGORY = 500   # Images per category
PAIRS_PER_EPOCH = 20000     # Number of pairs to generate per epoch

# ---- Few-shot evaluation ---------------------------------------------------
N_WAY = 5   # Number of classes in few-shot test
K_SHOT = 5  # Number of support examples per class

# Echo the most important settings.
print("Configuration:")
for setting_name, setting_value in (
    ("Image size", f"{IMG_SIZE}x{IMG_SIZE}"),
    ("Embedding dimension", EMBEDDING_DIM),
    ("Batch size", BATCH_SIZE),
    ("Learning rate", LEARNING_RATE),
    ("Margin", MARGIN),
):
    print(f"  {setting_name}: {setting_value}")


## 2. Data Loading & Exploration


In [None]:
# Load data paths organized by category
def load_data_by_category(base_dir, max_categories=None, max_per_category=None):
    """
    Load image paths organized by category for Siamese network training.

    Every immediate sub-directory of ``base_dir`` is one category; its
    ``*.png`` files are that category's images. Each file is checked with
    PIL's ``verify()`` and unreadable files are skipped. Categories left with
    fewer than 2 valid images are dropped, because a positive pair needs two.

    Args:
        base_dir: directory whose sub-directories are the categories.
        max_categories: keep only the first N categories in sorted-name order
            (None = all).
        max_per_category: keep only the first N PNG files of each category in
            sorted filename order (None = all).

    Returns:
        category_images: dict mapping category name -> list of image paths (str)
        categories: list of category names kept (same order as the dict)

    Raises:
        OSError (e.g. FileNotFoundError) if ``base_dir`` cannot be listed.
    """
    base_path = Path(base_dir)
    all_categories = sorted(d.name for d in base_path.iterdir() if d.is_dir())

    # `is not None` so that an explicit 0 means "none", not silently "all".
    if max_categories is not None:
        categories = all_categories[:max_categories]
    else:
        categories = all_categories

    print(f"Loading data from {len(categories)} categories...")

    category_images = {}
    total_images = 0
    skipped_files = 0

    for category in tqdm(categories, desc="Loading categories"):
        category_path = base_path / category
        if not category_path.exists():
            continue

        # glob() order is filesystem-dependent; sort so the truncation below
        # selects the same files on every machine (reproducibility).
        png_files = sorted(category_path.glob('*.png'))
        if max_per_category is not None:
            png_files = png_files[:max_per_category]

        # Verify images are valid
        valid_files = []
        for f in png_files:
            try:
                with Image.open(f) as img:
                    img.verify()
            except Exception:
                # Corrupt/truncated/unreadable image: skip it. Unlike a bare
                # `except:`, this does not swallow KeyboardInterrupt/SystemExit.
                skipped_files += 1
                continue
            valid_files.append(str(f))

        if len(valid_files) >= 2:  # Need at least 2 images per category for pairs
            category_images[category] = valid_files
            total_images += len(valid_files)

    if skipped_files:
        print(f"Skipped {skipped_files} unreadable image file(s)")
    print(f"\nLoaded {total_images} images from {len(category_images)} categories")
    return category_images, list(category_images.keys())

# Build the category -> image-paths mapping from disk, capped by the
# configured number of categories and images per category.
category_images, categories = load_data_by_category(
    'doodles/doodle',
    max_categories=MAX_CATEGORIES,
    max_per_category=IMAGES_PER_CATEGORY,
)

# Quick sanity preview: first ten names, image counts of the first five.
first_counts = [len(category_images[name]) for name in categories[:5]]
print(f"\nCategories loaded: {categories[:10]}...")
print(f"Images per category: {first_counts}")


In [None]:
# Visualize sample pairs (similar and dissimilar)
def visualize_sample_pairs(category_images, categories, num_pairs=4):
    """Plot random similar and dissimilar image pairs side by side.

    Each row shows a similar pair (two distinct images of one category) in
    columns 0-1, and a dissimilar pair (one image from each of two different
    categories) in columns 2-3.

    Args:
        category_images: dict mapping category name -> list of image paths;
            every category must have at least 2 images.
        categories: category names to sample from (at least 2 are needed).
        num_pairs: number of rows (pairs of each kind) to draw.
    """
    # squeeze=False keeps `axes` 2-D, so `axes[i, j]` also works for num_pairs == 1.
    fig, axes = plt.subplots(num_pairs, 4, figsize=(12, num_pairs * 3), squeeze=False)

    def _show(ax, path, title):
        # Load as grayscale ('L'), the same mode the training dataset uses. The
        # `with` block closes the file; convert() returns an in-memory copy.
        with Image.open(path) as img:
            ax.imshow(img.convert('L'), cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    for i in range(num_pairs):
        # Similar pair (same category)
        cat = random.choice(categories)
        img1_path, img2_path = random.sample(category_images[cat], 2)
        _show(axes[i, 0], img1_path, f'Similar: {cat}')
        _show(axes[i, 1], img2_path, f'Similar: {cat}')

        # Dissimilar pair (different categories)
        cat1, cat2 = random.sample(categories, 2)
        img3_path = random.choice(category_images[cat1])
        img4_path = random.choice(category_images[cat2])
        _show(axes[i, 2], img3_path, f'Dissimilar: {cat1}')
        _show(axes[i, 3], img4_path, f'Dissimilar: {cat2}')

    fig.suptitle('Sample Pairs: Similar (left) vs Dissimilar (right)', fontsize=14)
    fig.tight_layout()
    plt.show()

visualize_sample_pairs(category_images, categories, num_pairs=4)  # four rows of example pairs


## 3. Siamese Dataset & Data Transforms


In [None]:
# Image pre-processing pipelines (applied to single-channel PIL images).
#
# NOTE(review): RandomRotation / RandomAffine fill uncovered pixels with 0
# (black) by default. If the doodles are dark strokes on a white background,
# `fill=255` may be more appropriate; confirm against the data.
train_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        # Light geometric augmentation, training only.
        transforms.RandomRotation(10),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Deterministic pipeline for evaluation: same resize + normalisation, no augmentation.
val_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

print("Transforms defined.")


In [None]:
class SiamesePairDataset(Dataset):
    """
    Dataset that generates pairs of images for Siamese network training.

    At construction, exactly ``pairs_per_epoch`` (path1, path2, label) triples
    are drawn with Python's ``random`` module and shuffled:
    ``pairs_per_epoch // 2`` positive pairs (two distinct images of one
    category) and the remainder negative pairs (images from two different
    categories). Call ``regenerate_pairs()`` to draw a fresh set.

    Args:
        category_images: dict mapping category name -> list of image paths.
        categories: category names to draw from (keys of ``category_images``).
        transform: optional callable applied to each grayscale PIL image.
        pairs_per_epoch: total number of pairs to generate.

    Raises:
        ValueError: if positive pairs are requested but no category has at
            least 2 images, or negative pairs are requested with fewer than
            2 categories.
    """
    def __init__(self, category_images, categories, transform=None, pairs_per_epoch=10000):
        self.category_images = category_images
        self.categories = categories
        self.transform = transform
        self.pairs_per_epoch = pairs_per_epoch

        # Pre-generate pairs for this epoch
        self.pairs = self._generate_pairs()

    def _generate_pairs(self):
        """Generate exactly ``pairs_per_epoch`` shuffled positive/negative pairs."""
        n_positive = self.pairs_per_epoch // 2
        n_negative = self.pairs_per_epoch - n_positive

        # Only categories with >= 2 images can form a positive pair. Drawing
        # from this subset, rather than skipping ineligible draws, keeps the
        # dataset length and the 50/50 label balance exactly as requested.
        # When every category is eligible, the draws match drawing from
        # self.categories directly.
        eligible = [c for c in self.categories if len(self.category_images[c]) >= 2]
        if n_positive > 0 and not eligible:
            raise ValueError("Positive pairs need at least one category with >= 2 images")
        if n_negative > 0 and len(self.categories) < 2:
            raise ValueError("Negative pairs need at least 2 categories")

        pairs = []

        # Positive pairs (same category)
        for _ in range(n_positive):
            cat = random.choice(eligible)
            img1, img2 = random.sample(self.category_images[cat], 2)
            pairs.append((img1, img2, 1))  # label=1 for similar

        # Negative pairs (different categories)
        for _ in range(n_negative):
            cat1, cat2 = random.sample(self.categories, 2)
            img1 = random.choice(self.category_images[cat1])
            img2 = random.choice(self.category_images[cat2])
            pairs.append((img1, img2, 0))  # label=0 for dissimilar

        random.shuffle(pairs)
        return pairs

    def regenerate_pairs(self):
        """Regenerate pairs for a new epoch"""
        self.pairs = self._generate_pairs()

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _load_grayscale(path):
        """Open an image, convert it to single-channel 'L' mode, and close the file."""
        with Image.open(path) as img:
            return img.convert('L')

    def __getitem__(self, idx):
        """Return (img1, img2, label): label is a float32 tensor, 1.0 = same class, 0.0 = different."""
        img1_path, img2_path, label = self.pairs[idx]

        img1 = self._load_grayscale(img1_path)
        img2 = self._load_grayscale(img2_path)

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, torch.tensor(label, dtype=torch.float32)

print("SiamesePairDataset defined.")


In [None]:
# Hold out whole categories (not individual images), so validation pairs are
# drawn only from categories never seen during training.
train_categories, val_categories = train_test_split(
    categories,
    test_size=0.2,
    random_state=SEED,
)

print(f"Train categories: {len(train_categories)}")
print(f"Validation categories: {len(val_categories)}")
print("\nNote: We split by category to test few-shot generalization!")

# Restrict the category -> image-path mapping to each split
# (the original key order of `category_images` is kept).
train_category_images = {
    cat: paths for cat, paths in category_images.items() if cat in train_categories
}
val_category_images = {
    cat: paths for cat, paths in category_images.items() if cat in val_categories
}

# Training pairs use the augmenting transform; validation pairs use the
# deterministic one, and there are 5x fewer of them.
train_dataset = SiamesePairDataset(
    train_category_images,
    train_categories,
    transform=train_transform,
    pairs_per_epoch=PAIRS_PER_EPOCH,
)
val_dataset = SiamesePairDataset(
    val_category_images,
    val_categories,
    transform=val_transform,
    pairs_per_epoch=PAIRS_PER_EPOCH // 5,
)

# num_workers=0: images are loaded in the main process.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"\nTrain pairs per epoch: {len(train_dataset)}")
print(f"Validation pairs: {len(val_dataset)}")


## 4. Siamese Network Architecture


In [None]:
class SiameseEmbeddingNet(nn.Module):
    """
    Embedding network that maps a single-channel image to an L2-normalised
    embedding vector.
    Architecture inspired by the original Siamese paper for one-shot learning.

    Args:
        embedding_dim: size of the output embedding.
        img_size: side length of the square input images. When None, the
            module-level ``IMG_SIZE`` is used (the original behaviour).
    """
    def __init__(self, embedding_dim=128, img_size=None):
        super(SiameseEmbeddingNet, self).__init__()
        self.img_size = IMG_SIZE if img_size is None else img_size

        # Convolutional layers
        self.conv = nn.Sequential(
            # Conv block 1: 1 -> 64 channels
            nn.Conv2d(1, 64, kernel_size=10, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # Conv block 2: 64 -> 128 channels
            nn.Conv2d(64, 128, kernel_size=7, stride=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # Conv block 3: 128 -> 128 channels
            nn.Conv2d(128, 128, kernel_size=4, stride=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # Conv block 4: 128 -> 256 channels
            nn.Conv2d(128, 256, kernel_size=4, stride=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

        # Calculate the size after conv layers
        self._initialize_fc_size()

        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.fc_input_size, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, embedding_dim)
        )

    def _initialize_fc_size(self):
        """Set ``self.fc_input_size`` to the flattened conv-feature length of one image.

        A dummy zero image is pushed through ``self.conv`` with the conv stack
        temporarily in eval mode. ``torch.no_grad()`` alone does not stop
        BatchNorm layers in training mode from updating running_mean,
        running_var and num_batches_tracked, which would bake the dummy
        input's statistics into the model before training starts.
        """
        was_training = self.conv.training
        self.conv.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, 1, self.img_size, self.img_size)
                out = self.conv(dummy)
        finally:
            self.conv.train(was_training)
        self.fc_input_size = out.view(1, -1).size(1)

    def forward(self, x):
        """Return unit-length embeddings of shape (N, embedding_dim).

        ``x`` is expected to be (N, 1, img_size, img_size): one input channel
        (first Conv2d) and spatial size matching ``fc_input_size``.
        """
        x = self.conv(x)
        x = self.fc(x)
        # L2 normalize embeddings
        x = F.normalize(x, p=2, dim=1)
        return x


class SiameseNetwork(nn.Module):
    """
    Full Siamese Network: one shared ``SiameseEmbeddingNet`` applied to both
    images of a pair. Returns the two embeddings; distances are computed by
    the caller (e.g. the loss).

    Args:
        embedding_dim: size of each embedding.
        img_size: input side length; None falls back to ``IMG_SIZE``.
    """
    def __init__(self, embedding_dim=128, img_size=None):
        super(SiameseNetwork, self).__init__()
        self.embedding_net = SiameseEmbeddingNet(embedding_dim, img_size=img_size)

    def forward(self, x1, x2):
        """Embed both images of each pair with the same (weight-shared) network."""
        embedding1 = self.embedding_net(x1)
        embedding2 = self.embedding_net(x2)
        return embedding1, embedding2

    def get_embedding(self, x):
        """Get embedding for a single image"""
        return self.embedding_net(x)

# Build the Siamese model and move its parameters to the selected device.
model = SiameseNetwork(embedding_dim=EMBEDDING_DIM).to(device)

# Report the architecture and parameter counts.
print(model)
n_params_total = sum(p.numel() for p in model.parameters())
n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {n_params_total:,}")
print(f"Trainable parameters: {n_params_trainable:,}")


## 5. Contrastive Loss


In [None]:
class ContrastiveLoss(nn.Module):
    """
    Contrastive Loss function.

    For similar pairs (label=1): minimize distance
    For dissimilar pairs (label=0): maximize distance up to margin

    Loss = mean( Y * 0.5 * D^2 + (1-Y) * 0.5 * max(0, margin - D)^2 )

    Where:
        Y = 1 for similar pairs, 0 for dissimilar
        D = euclidean distance between embeddings

    Dissimilar pairs that are already at least ``margin`` apart contribute 0.

    Args:
        margin: distance beyond which dissimilar pairs are no longer penalised.
    """
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, embedding1, embedding2, label):
        """Compute the mean contrastive loss over a batch of pairs.

        Args:
            embedding1, embedding2: embeddings of the two images of each pair
                (assumed (batch, dim), same shape).
            label: float tensor, 1.0 for similar pairs and 0.0 for dissimilar
                pairs, broadcastable against the per-pair distances.

        Returns:
            (loss, euclidean_distance): the scalar mean loss and the per-pair
            distances (useful for accuracy/threshold metrics).
        """
        # Euclidean distance (F.pairwise_distance adds a small eps to the
        # difference, which keeps the gradient finite for identical embeddings).
        euclidean_distance = F.pairwise_distance(embedding1, embedding2)

        # Similar pairs: penalise any distance.
        loss_similar = label * euclidean_distance.pow(2)
        # Dissimilar pairs: penalise only distances below the margin.
        loss_dissimilar = (1 - label) * torch.clamp(
            self.margin - euclidean_distance, min=0.0
        ).pow(2)

        loss = torch.mean(0.5 * (loss_similar + loss_dissimilar))
        return loss, euclidean_distance

# Loss on the pair embeddings, and Adam with a small L2 weight decay.
criterion = ContrastiveLoss(margin=MARGIN)
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-5,
)

# Halve the learning rate once the monitored (minimised) metric has not
# improved for `patience`=3 scheduler steps (presumably one step per epoch;
# confirm in the training loop).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True,
)

print(f"Loss: Contrastive Loss with margin={MARGIN}")
print(f"Optimizer: Adam with lr={LEARNING_RATE}")


## 6. Training
