# 🎨 Tutorial: Deep Learning for Color Quantization

![Color Quantization](https://upload.wikimedia.org/wikipedia/commons/thumb/6/6a/Dithering_example_dithered_color_palette.png/320px-Dithering_example_dithered_color_palette.png)

## Welcome to the Fascinating World of Color Quantization! 🌈

In this comprehensive tutorial, you'll learn:
- 🎨 What is color quantization and why it matters
- 🧮 The mathematics behind color reduction algorithms
- 🤖 How to use deep learning for intelligent color selection
- 💻 Hands-on implementation with PyTorch
- 🎯 Visualization of quantization effects
- 🧪 Interactive exercises to build your skills

By the end, you'll be ready to implement sophisticated color quantization using neural networks!


## 📚 Table of Contents

1. [🎓 Understanding Color Quantization](#1--understanding-color-quantization)
2. [🧮 Mathematical Foundation](#2--mathematical-foundation)
3. [🔧 Setting Up the Environment](#3--setting-up-the-environment)
4. [🏗️ Classical Approaches](#4--classical-approaches)
5. [🤖 Deep Learning Approach](#5--deep-learning-approach)
6. [🧠 Building the Neural Network](#6--building-the-neural-network)
7. [🎯 Custom Loss Function](#7--custom-loss-function)
8. [💼 Complete Solution](#8--complete-solution)
9. [🎨 Visualizing Results](#9--visualizing-results)
10. [🎮 Interactive Exercises](#10--interactive-exercises)
11. [🚀 Advanced Techniques](#11--advanced-techniques)
12. [📖 Summary and Next Steps](#12--summary-and-next-steps)


## 1. 🎓 Understanding Color Quantization

### What is Color Quantization?

Imagine you have a beautiful photograph with millions of colors 🌈. Color quantization is the process of reducing this rich palette to a much smaller set of colors while preserving the visual quality as much as possible.

Think of it like this:
- **Original image**: 16.7 million possible colors (24-bit RGB)
- **Quantized image**: Only 37 carefully chosen colors
- **Goal**: Make the 37-color version look as close to the original as possible
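
To make these numbers concrete, here is a rough back-of-the-envelope sketch of what a 37-color palette saves. It assumes a plain indexed format (one palette index per pixel plus the palette itself, no further compression); the helper name `indexed_size_bits` is hypothetical and only used for illustration.

```python
import math


def indexed_size_bits(height, width, n_colors):
    """Bits for an uncompressed indexed image: per-pixel indices plus the palette."""
    bits_per_index = math.ceil(math.log2(n_colors))  # 37 colors need 6 bits
    palette_bits = n_colors * 24  # each palette entry is a 24-bit RGB color
    return height * width * bits_per_index + palette_bits


full_colors = 2 ** 24                      # 16,777,216 possible 24-bit colors
raw_bits = 128 * 128 * 24                  # uncompressed 128x128 RGB image
indexed_bits = indexed_size_bits(128, 128, 37)

print(f"Possible 24-bit colors: {full_colors}")
print(f"Raw size: {raw_bits} bits, indexed size: {indexed_bits} bits")
print(f"Ratio: {raw_bits / indexed_bits:.2f}x")
```

Under these assumptions a 128×128 image shrinks by roughly a factor of four, before any entropy coding is applied.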

### Key Concepts:

- **Color Palette**: The limited set of colors we can use (37 in our case)
- **Color Assignment**: Mapping each original pixel to the closest palette color
- **Quality Metrics**: How we measure how "good" our quantization is
- **Color Space**: The mathematical representation of colors (RGB)

### Mathematical Definition:

Given an image $I$ with pixels $p_{i,j} \in \mathbb{R}^3$ and a palette $P = \{c_1, c_2, ..., c_k\}$ where $k=37$:

$$Q(I) = \{q_{i,j} | q_{i,j} = \arg\min_{c \in P} ||p_{i,j} - c||_2\}$$

Where $Q(I)$ is our quantized image and $||\cdot||_2$ is the Euclidean distance.
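
The definition above can be sketched in a few lines of NumPy. This is a minimal illustration on a toy 2×2 image with a 4-color palette (not the 37-color case); the names `assign_to_palette`, `toy_image`, and `toy_palette` are made up for this example.

```python
import numpy as np


def assign_to_palette(image, palette):
    """Replace every pixel with its nearest palette color (Euclidean distance in RGB)."""
    pixels = image.reshape(-1, 3).astype(np.float64)
    colors = palette.astype(np.float64)
    # Pairwise squared distances, shape (num_pixels, num_colors)
    d2 = ((pixels[:, None, :] - colors[None, :, :]) ** 2).sum(axis=2)
    nearest = d2.argmin(axis=1)  # index of the closest palette color per pixel
    return palette[nearest].reshape(image.shape)


toy_image = np.array(
    [[[250, 10, 10], [20, 30, 240]],
     [[5, 5, 5], [200, 220, 210]]],
    dtype=np.uint8,
)
toy_palette = np.array(
    [[0, 0, 0], [255, 0, 0], [0, 0, 255], [255, 255, 255]], dtype=np.uint8
)

toy_result = assign_to_palette(toy_image, toy_palette)
print(toy_result)
```

Minimizing the squared distance gives the same assignment as minimizing the distance itself, so the square root can be skipped.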

### Why Should We Care?

🖼️ **Image Compression**: Reduce file sizes while maintaining quality  
🎮 **Game Graphics**: Limited color palettes for retro aesthetics  
🖨️ **Printing**: Optimize colors for specific printing processes  
📱 **Mobile Displays**: Adapt images for devices with limited color ranges  
🎨 **Artistic Effects**: Create stylized, poster-like images


## 2. 🧮 Mathematical Foundation

Before diving into implementation, let's understand the math behind color quantization.

### The Optimization Problem

Color quantization is fundamentally an optimization problem. We want to find the best palette $P$ and assignment function $f$ such that:

$$\min_{P,f} \sum_{i,j} ||p_{i,j} - f(p_{i,j}, P)||_2^2$$

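Dividing this sum by the number of pixel values gives the **Mean Squared Error (MSE)**, our primary quality metric. The normalization is a constant factor, so it does not change which palette is optimal.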

### Color Cost Function

In our specific problem, we have an additional constraint - the **color cost**. Some colors are "more expensive" than others:

$$\text{Color Cost}(c) = \min_{v \in V} ||c - v||_2$$

Where $V$ is the set of the 8 vertices of the RGB cube, i.e. every combination of 0 and 255 per channel: $\{(0,0,0), (0,0,255), (0,255,0), ..., (255,255,255)\}$
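
As a quick worked example of this formula, the sketch below evaluates the cost of three hand-picked colors using only the standard library. The helper name `color_cost_single` is hypothetical.

```python
import itertools
import math

# The 8 vertices of the RGB cube: every combination of 0 and 255 per channel
VERTICES = list(itertools.product((0, 255), repeat=3))


def color_cost_single(color):
    """Euclidean distance from one RGB color to its nearest cube vertex."""
    return min(math.dist(color, v) for v in VERTICES)


for c in [(255, 0, 0), (255, 100, 100), (128, 128, 128)]:
    print(c, round(color_cost_single(c), 2))
```

Pure red sits on a vertex and costs nothing, the pinkish red `(255, 100, 100)` costs about 141.42, and mid-gray is among the most expensive colors at about 219.97.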

### Complete Objective Function

Our final objective combines both image quality and color preference:

$$L = 2 \cdot MSE + 21 \cdot \max(\text{color costs}) + 42 \cdot \text{mean}(\text{color costs})$$
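
To see how the three terms trade off, here is a small sketch that evaluates the objective on a two-pixel toy image. It assumes, as the scoring functions later in this notebook do, that the mean and max color costs are taken over the pixels of the quantized image; `toy_score` and the toy arrays are illustrative names only.

```python
import numpy as np

VERTICES = np.array(
    [[r, g, b] for r in (0, 255) for g in (0, 255) for b in (0, 255)], dtype=float
)


def toy_score(original, quantized):
    """2*MSE + 21*max(color cost) + 42*mean(color cost), costs taken per pixel."""
    original = original.astype(float)
    quantized = quantized.astype(float)
    mse = np.mean((original - quantized) ** 2)
    pixels = quantized.reshape(-1, 3)
    dists = np.linalg.norm(pixels[:, None, :] - VERTICES[None, :, :], axis=2)
    costs = dists.min(axis=1)  # distance to the nearest cube vertex per pixel
    return 2 * mse + 21 * costs.max() + 42 * costs.mean()


original = np.array([[[255, 100, 100], [100, 255, 100]]], dtype=np.uint8)
snapped = np.array([[[255, 0, 0], [0, 255, 0]]], dtype=np.uint8)  # pure vertices

print(f"Keep original colors: {toy_score(original, original):.2f}")
print(f"Snap to vertices:     {toy_score(original, snapped):.2f}")
```

Keeping the original colors gives zero MSE but pays the full color cost (about 8909.55), while snapping to vertices removes the color cost but pays in MSE (about 13333.33). A good palette has to land somewhere between these extremes.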

### The Challenge

This is a **non-convex, combinatorial optimization problem**. Traditional approaches:
- K-means clustering 📊
- Median cut algorithm ✂️
- Octree quantization 🌳
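
For a feel of how these classical methods operate, here is a minimal sketch of median cut in NumPy. It is a simplified variant written for illustration: it repeatedly splits the box with the widest channel range at its median index and uses the box mean as the palette color. The function name `median_cut_palette` is hypothetical, and the sketch may return duplicate colors when a box contains identical pixels.

```python
import numpy as np


def median_cut_palette(pixels, n_colors):
    """Split the pixel set along its widest channel until n_colors boxes exist."""
    boxes = [pixels.astype(np.float64)]
    while len(boxes) < n_colors:
        # Only boxes with more than one pixel can be split
        candidates = [i for i, b in enumerate(boxes) if len(b) > 1]
        if not candidates:
            break  # fewer pixels than requested colors
        # Pick the box with the largest range in any single channel
        idx = max(candidates, key=lambda i: np.ptp(boxes[i], axis=0).max())
        box = boxes.pop(idx)
        channel = int(np.ptp(box, axis=0).argmax())
        box = box[box[:, channel].argsort()]  # sort along the widest channel
        mid = len(box) // 2
        boxes.extend([box[:mid], box[mid:]])
    # One palette color per box: the mean of its pixels
    return np.array([b.mean(axis=0) for b in boxes]).round().astype(np.uint8)


two_tone = np.array([[0, 0, 0]] * 4 + [[255, 255, 255]] * 4)
print(median_cut_palette(two_tone, 2))
```

Like K-means, this only looks at how pixels are distributed in RGB space, so it has no notion of the color cost term.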

**Our Innovation**: Use deep learning to directly optimize this complex objective! 🤖


### 💡 Key Insight

**Why is this problem so hard?**

Traditional clustering algorithms like K-means optimize for minimal within-cluster variance, but our objective function is much more complex! We need to balance:

1. **Image Quality**: How similar the quantized image looks to the original
2. **Color Preferences**: Some colors are "cheaper" than others (RGB cube vertices)
3. **Exact Count**: We must use exactly 37 colors, no more, no less

Think of it like choosing a color palette for a painting - you want colors that represent your image well, but you also prefer certain "primary" colors that are easier to work with!
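
The exact-count requirement is easy to check on any result. A small sketch, using a hypothetical helper `count_unique_colors` and a toy image:

```python
import numpy as np


def count_unique_colors(image):
    """Number of distinct RGB triples in an (H, W, 3) image."""
    return len(np.unique(image.reshape(-1, 3), axis=0))


demo = np.zeros((4, 4, 3), dtype=np.uint8)  # starts all black
demo[:2] = [255, 0, 0]                      # top half red
demo[2:, :2] = [0, 255, 0]                  # bottom-left quarter green

n_demo_colors = count_unique_colors(demo)
print(n_demo_colors)
```

Note that a palette with 37 distinct entries does not by itself guarantee 37 colors in the output: a palette color that is nearest to no pixel never appears in the quantized image, so it is worth counting on the image rather than on the palette.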


## 3. 🔧 Setting Up the Environment

Let's start by importing all the necessary libraries and setting up our environment. We'll be working with PyTorch for neural networks and various other libraries for image processing and visualization.


In [None]:
# Essential imports for our color quantization tutorial
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import os
from PIL import Image

# PyTorch for neural networks
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# For image processing and visualization
import cv2
from sklearn.cluster import KMeans
from sklearn.utils import shuffle
import warnings

warnings.filterwarnings("ignore")

# Set up device - GPU greatly speeds up neural network training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("✅ Environment setup complete!")
print(f"📦 PyTorch version: {torch.__version__}")
print(f"🖥️  CUDA available: {torch.cuda.is_available()}")

In [None]:
# Let's create a sample image for our experiments
def create_sample_image(size=64):
    """Create a colorful sample image for experimentation"""
    img = np.zeros((size, size, 3), dtype=np.uint8)

    # Create different colored regions
    # Red region
    img[: size // 2, : size // 2] = [255, 100, 100]

    # Green region
    img[: size // 2, size // 2 :] = [100, 255, 100]

    # Blue region
    img[size // 2 :, : size // 2] = [100, 100, 255]

    # Mixed region with gradient
    for i in range(size // 2, size):
        for j in range(size // 2, size):
            img[i, j] = [
                int(255 * (i - size // 2) / (size // 2)),
                int(255 * (j - size // 2) / (size // 2)),
                int(255 * ((i + j - size) / size)),
            ]

    return img


# Create and visualize our sample image
sample_img = create_sample_image(128)
print(f"📸 Created sample image with shape: {sample_img.shape}")
print(f"🎨 Color range: [{sample_img.min()}, {sample_img.max()}]")
print(
    f"🔢 Unique colors in sample: {len(np.unique(sample_img.reshape(-1, 3), axis=0))}"
)

# Display the image
plt.figure(figsize=(8, 6))
plt.imshow(sample_img)
plt.title("🎨 Sample Image for Color Quantization", fontsize=14, fontweight="bold")
plt.axis("off")
plt.show()

## 4. 🏗️ Classical Approaches

Before we dive into the deep learning solution, let's understand how traditional methods work. This will help us appreciate why neural networks are so powerful for this problem.

### K-Means Clustering: The Classic Approach

K-means is the most common method for color quantization. The idea is simple:
1. Treat each pixel as a point in 3D RGB space
2. Find k=37 cluster centers that minimize within-cluster variance
3. Assign each pixel to its nearest cluster center

Let's implement this and see how it performs!


In [None]:
def mse(img1, img2):
    """Calculate Mean Squared Error between two images"""
    return np.mean((img1.astype(float) - img2.astype(float)) ** 2)


def color_cost(img):
    """
    Calculate color cost according to the problem definition.

    Color cost is the distance to the nearest RGB cube vertex.
    RGB cube vertices are combinations of 0 and 255 for each channel.
    """
    # The 8 vertices of the RGB cube
    vertices = np.array(
        [
            [0, 0, 0],  # black
            [0, 0, 255],  # blue
            [0, 255, 0],  # green
            [0, 255, 255],  # cyan
            [255, 0, 0],  # red
            [255, 0, 255],  # magenta
            [255, 255, 0],  # yellow
            [255, 255, 255],  # white
        ]
    )

    # Flatten image to list of pixels
    pixels = img.reshape(-1, 3)

    # Calculate distance from each pixel to each vertex
    distances = np.sqrt(
        np.sum((pixels[:, None, :] - vertices[None, :, :]) ** 2, axis=2)
    )

    # Cost of each pixel is distance to nearest vertex
    costs = np.min(distances, axis=1)

    return np.mean(costs), np.max(costs)


def quantization_score(original, quantized):
    """
    Calculate the complete quantization score according to the problem.

    Score = 2*MSE + 21*max_color_cost + 42*mean_color_cost
    """
    mse_value = mse(original, quantized)
    mean_cost, max_cost = color_cost(quantized)

    return 2 * mse_value + 21 * max_cost + 42 * mean_cost


print("📊 Scoring functions implemented!")
print("   - mse(): Mean Squared Error")
print("   - color_cost(): Distance to RGB cube vertices")
print("   - quantization_score(): Complete objective function")

In [None]:
def kmeans_quantization(image, n_colors=37):
    """
    Quantize image colors using K-means clustering.

    This is the traditional approach - let's see how it performs!
    """
    # Reshape image to be a list of pixels
    pixels = image.reshape(-1, 3)

    # Apply K-means clustering to find color centers
    kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=10)
    kmeans.fit(pixels)

    # Get the color palette (cluster centers)
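    # (round instead of truncating when converting the float centers to uint8)
    palette = np.clip(np.round(kmeans.cluster_centers_), 0, 255).astype(np.uint8)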

    # Assign each pixel to nearest cluster center
    labels = kmeans.labels_
    quantized_pixels = palette[labels]

    # Reshape back to image format
    quantized_image = quantized_pixels.reshape(image.shape)

    return quantized_image, palette


# Test K-means quantization on our sample image
print("🔬 Testing K-means quantization...")
kmeans_result, kmeans_palette = kmeans_quantization(sample_img, n_colors=37)

# Calculate the score
score = quantization_score(sample_img, kmeans_result)
mse_val = mse(sample_img, kmeans_result)
mean_cost, max_cost = color_cost(kmeans_result)

print(f"📊 K-means Results:")
print(f"   MSE: {mse_val:.2f}")
print(f"   Mean color cost: {mean_cost:.2f}")
print(f"   Max color cost: {max_cost:.2f}")
print(f"   🎯 Total score: {score:.2f}")
print(f"   🎨 Colors in palette: {len(kmeans_palette)}")

# Visualize the results
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(sample_img)
axes[0].set_title("🖼️ Original Image", fontweight="bold")
axes[0].axis("off")

axes[1].imshow(kmeans_result)
axes[1].set_title("🎨 K-means Quantized", fontweight="bold")
axes[1].axis("off")

# Show the color palette
palette_display = kmeans_palette.reshape(1, -1, 3)
axes[2].imshow(palette_display, aspect="auto")
axes[2].set_title("🌈 Color Palette (37 colors)", fontweight="bold")
axes[2].set_xticks([])
axes[2].set_yticks([])

plt.tight_layout()
plt.show()

### 🤔 Problems with K-means

While K-means is simple and fast, it has several limitations for our specific problem:

1. **Wrong Objective**: K-means minimizes within-cluster variance, but our objective function is much more complex!
2. **No Color Preferences**: K-means doesn't know that some colors (RGB vertices) are "cheaper"
3. **One Fixed Strategy**: K-means applies the same procedure to every image and only converges to a local optimum of its own objective, while different images might need different strategies
4. **No Learning**: K-means starts from scratch for each image

**The Solution**: Deep Learning! 🤖

Let's see how we can use neural networks to directly optimize our complex objective function.


## 5. 🤖 Deep Learning Approach

Now comes the exciting part! Instead of using traditional algorithms, we'll train a neural network to generate optimal color palettes.

### The Big Idea 💡

What if we could train a neural network that:
1. **Takes an image as input**
2. **Outputs exactly 37 colors** that form the optimal palette
3. **Is trained to minimize our exact objective function**
4. **Learns to balance image quality vs. color costs**

This is exactly what we'll build!

### Why Neural Networks?

🎯 **Direct Optimization**: We can train on our exact objective function  
🧠 **Adaptive**: Different images get different strategies  
🚀 **Powerful**: CNNs can extract complex visual features  
⚡ **End-to-End**: One network handles the entire pipeline  

### The Architecture Strategy

Our neural network will be a **Convolutional Neural Network (CNN)** with:
- **Input**: RGB image (H × W × 3)
- **Feature Extraction**: Convolutional layers to understand image content
- **Global Understanding**: Pooling to reduce spatial dimensions
- **Output**: Exactly 37 RGB colors (37 × 3 = 111 values)

The key insight: **We train a separate network for each image** (overfitting is good here!).


## 6. 🧠 Building the Neural Network

Let's implement our Convolutional Neural Network! This is the heart of our solution.

### Architecture Design

Our CNN will have the following structure:
1. **Conv Layer 1**: 3 → 32 channels, extract basic features
2. **Conv Layer 2**: 32 → 64 channels, extract more complex features  
3. **Conv Layer 3**: 64 → 128 channels, extract high-level features
4. **Max Pooling**: A 2×2 max-pool after each conv layer halves the height and width (128 → 64 → 32 → 16)
5. **Fully Connected**: 128×16×16 → 512 → 256 → 111 neurons
6. **Sigmoid Output**: Ensure colors are in [0,1] range
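
The input size of the first fully connected layer depends on the image resolution. The sketch below works through that arithmetic, assuming size-preserving 3×3 convolutions (padding 1) and three 2×2 max-pool layers with stride 2, as in the class that follows. The helper name `flattened_features` is hypothetical.

```python
def flattened_features(height, width, channels=128, n_pools=3):
    """Size of the flattened tensor after n_pools 2x2 max-pool layers (stride 2)."""
    for _ in range(n_pools):
        height //= 2  # MaxPool2d(2, 2) floors odd sizes by default
        width //= 2
    return channels * height * width


print(flattened_features(128, 128))  # 128 * 16 * 16
print(flattened_features(64, 64))    # 128 * 8 * 8
```

For a 128×128 image this gives 32768 features. A different resolution changes this number, so the first linear layer would need to be resized accordingly.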


In [None]:
class ColorQuantizationCNN(nn.Module):
    """
    Convolutional Neural Network for generating optimal color palettes.

    This network takes an image and outputs exactly 37 RGB colors that
    form the optimal palette for quantizing that specific image.
    """

    def __init__(self, n_colors=37):
        super(ColorQuantizationCNN, self).__init__()
        self.n_colors = n_colors

        # Convolutional layers for feature extraction
        # Each layer extracts increasingly complex features
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        # MaxPooling for dimensionality reduction
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fully connected layers for color generation
        # Note: fc1's input size is tied to the image size after pooling,
        # so this network only accepts 128x128 inputs:
        # 128 channels * (128/8) * (128/8) = 128 * 16 * 16
        self.fc1 = nn.Linear(128 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, n_colors * 3)  # 37 colors × 3 channels = 111

        # Dropout for regularization
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x: Input image tensor (B, H, W, C) in range [0, 255]

        Returns:
            Tensor of shape (B, n_colors, 3) with colors in range [0, 1]
        """
        # Convert from (B, H, W, C) to (B, C, H, W) - PyTorch format
        x = x.permute(0, 3, 1, 2)

        # Normalize to [0, 1] range
        x = x.float() / 255.0

        # Convolutional feature extraction
        # Each conv-relu-pool step halves the height and width
        x = self.pool(F.relu(self.conv1(x)))  # 128x128 -> 64x64
        x = self.pool(F.relu(self.conv2(x)))  # 64x64 -> 32x32
        x = self.pool(F.relu(self.conv3(x)))  # 32x32 -> 16x16

        # Flatten for fully connected layers
        x = x.view(x.size(0), -1)  # Flatten to (B, 128*16*16)

        # Fully connected layers with dropout
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)

        # Output layer with sigmoid to ensure [0,1] range
        x = torch.sigmoid(self.fc3(x))

        # Reshape to (B, n_colors, 3) for easier handling
        x = x.view(x.size(0), self.n_colors, 3)

        return x


# Create our model
model = ColorQuantizationCNN(n_colors=37).to(device)
print(f"🧠 Neural network created!")
print(f"📊 Model has {sum(p.numel() for p in model.parameters())} parameters")
print(f"🔥 Model is on device: {next(model.parameters()).device}")

# Test the model with a sample input
with torch.no_grad():
    # Create a dummy input (batch_size=1, height=128, width=128, channels=3)
    dummy_input = torch.randint(0, 256, (1, 128, 128, 3)).to(device)
    output = model(dummy_input)
    print(f"✅ Model test successful!")
    print(f"   Input shape: {dummy_input.shape}")
    print(f"   Output shape: {output.shape}")
    print(f"   Output range: [{output.min():.3f}, {output.max():.3f}]")

## 7. 🎯 Custom Loss Function

The magic happens in our loss function! We need to implement the exact same objective that will be used for evaluation.

### Key Components:
1. **Quantization Process**: Map each pixel to its nearest palette color
2. **MSE Calculation**: Compare original vs quantized image
3. **Color Cost**: Distance to RGB cube vertices
4. **Unique Colors**: Ensure exactly 37 different colors


In [None]:
def quantize_with_palette(image, palette):
    """
    Quantize image using the given color palette.

    Args:
        image: Original image tensor (B, H, W, C) in range [0, 1]
        palette: Color palette tensor (B, n_colors, C) in range [0, 1]

    Returns:
        Quantized image tensor with same shape as input
    """
    B, H, W, C = image.shape
    _, n_colors, _ = palette.shape

    # Reshape image to (B, H*W, C) for easier computation
    image_flat = image.view(B, -1, C)

    # Compute distance from each pixel to each palette color
    # Using broadcasting: (B, H*W, 1, C) - (B, 1, n_colors, C)
    distances = torch.norm(image_flat.unsqueeze(2) - palette.unsqueeze(1), dim=3)

    # Find closest palette color for each pixel
    closest_indices = torch.argmin(distances, dim=2)

    # Replace each pixel with its closest palette color
    quantized_flat = palette.gather(1, closest_indices.unsqueeze(-1).expand(-1, -1, C))

    # Reshape back to original image shape
    quantized = quantized_flat.view(B, H, W, C)

    return quantized


def ensure_unique_colors(palette):
    """
    Ensure all colors in palette are unique.

    This is crucial - we need exactly 37 different colors!
    """
    B, n_colors, C = palette.shape

    for b in range(B):
        # Convert to 8-bit integers for exact comparison
        colors_int = (palette[b] * 255).round().int()

        seen_colors = set()
        for i in range(n_colors):
            color_tuple = tuple(colors_int[i].tolist())

            # If we've seen this color before, modify it slightly
            while color_tuple in seen_colors:
                # Modify red channel slightly
                colors_int[i, 0] = (colors_int[i, 0] + 1) % 256
                color_tuple = tuple(colors_int[i].tolist())

            seen_colors.add(color_tuple)

        # Convert back to [0, 1] range
        palette[b] = colors_int.float() / 255.0

    return palette


def pytorch_loss_function(original_image, palette):
    """
    Differentiable PyTorch surrogate of the evaluation objective:
    Loss = 2*MSE + 21*max_color_cost + 42*mean_color_cost

    Rounding, the uniqueness fix, and torch.unique would block gradients,
    so they are left to the evaluation step. As a simplification, color
    costs are taken over the palette entries rather than the quantized pixels.
    """
    # Quantize the image; gradients reach the palette through the selected colors
    quantized_image = quantize_with_palette(original_image, palette)

    # MSE on the [0, 255] scale (no rounding, to stay differentiable)
    mse_loss = torch.mean(((quantized_image - original_image) * 255.0) ** 2)

    # RGB cube vertices, created on the same device as the palette
    vertices = torch.tensor(
        [
            [0, 0, 0],
            [0, 0, 255],
            [0, 255, 0],
            [0, 255, 255],
            [255, 0, 0],
            [255, 0, 255],
            [255, 255, 0],
            [255, 255, 255],
        ],
        dtype=torch.float32,
        device=palette.device,
    )

    # Distance from each palette color (on the [0, 255] scale) to its nearest vertex
    palette_255 = palette.reshape(-1, 3) * 255.0
    distances = torch.norm(palette_255.unsqueeze(1) - vertices.unsqueeze(0), dim=2)
    min_distances = torch.min(distances, dim=1)[0]

    mean_color_cost = torch.mean(min_distances)
    max_color_cost = torch.max(min_distances)

    # Final loss (same weights as the evaluation function)
    total_loss = 2 * mse_loss + 21 * max_color_cost + 42 * mean_color_cost

    return total_loss, mse_loss, mean_color_cost, max_color_cost


print("🎯 Loss functions implemented!")
print("   - quantize_with_palette(): Convert image using palette")
print("   - ensure_unique_colors(): Guarantee 37 unique colors")
print("   - pytorch_loss_function(): Training objective")

## 8. 💼 Complete Solution

Now let's put it all together! This is the main function that will solve the color quantization problem.

### The Training Strategy

For each image, we:
1. **Create a fresh neural network** (overfitting is good!)
2. **Train for many epochs** to find optimal colors
3. **Track the best result** during training
4. **Return the best quantized image**

This approach works because each image has unique characteristics, and we want a personalized solution for each one.


In [None]:
def train_for_image(image, num_epochs=50, learning_rate=0.001):
    """
    Train a neural network specifically for one image.

    This is the core of our approach - individual training for each image!

    Args:
        image: Input image (numpy array, uint8, HxWx3)
        num_epochs: Number of training epochs
        learning_rate: Learning rate for optimizer

    Returns:
        Best quantized image (numpy array, uint8, HxWx3)
    """
    # Prepare the image
    img_tensor = (
        torch.tensor(image, dtype=torch.float32).unsqueeze(0).to(device) / 255.0
    )

    # Create a fresh model for this image
    model = ColorQuantizationCNN(n_colors=37).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Track the best result
    best_loss = float("inf")
    best_image = None

    print(f"🚀 Training neural network for image...")

    for epoch in range(num_epochs):
        model.train()

        # Forward pass: generate palette
        # (forward() divides by 255 itself, so pass the image on the [0, 255] scale)
        palette = model(img_tensor * 255.0)

        # Calculate loss
        total_loss, mse_loss, mean_cost, max_cost = pytorch_loss_function(
            img_tensor, palette
        )

        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Convert to numpy for evaluation
        with torch.no_grad():
            palette_clean = ensure_unique_colors(palette.clone())
            quantized = quantize_with_palette(img_tensor, palette_clean)
            # Round before casting: truncation would turn e.g. 253.99998 into 253
            quantized_np = np.round(quantized.squeeze(0).cpu().numpy() * 255).astype(
                np.uint8
            )

            # Calculate actual score using our numpy functions
            actual_score = quantization_score(image, quantized_np)

            # Keep track of best result
            if actual_score < best_loss:
                best_loss = actual_score
                best_image = quantized_np.copy()

        # Progress updates
        if (epoch + 1) % 10 == 0:
            print(f"   Epoch {epoch+1}/{num_epochs}: Loss = {actual_score:.2f}")

    print(f"✅ Training complete! Best score: {best_loss:.2f}")
    return best_image


# Test our complete solution on the sample image
print("🔬 Testing complete solution...")
neural_result = train_for_image(sample_img, num_epochs=30, learning_rate=0.001)

# Compare with K-means
neural_score = quantization_score(sample_img, neural_result)
kmeans_score = quantization_score(sample_img, kmeans_result)

print(f"\n📊 Final Comparison:")
print(f"   K-means score: {kmeans_score:.2f}")
print(f"   Neural Net score: {neural_score:.2f}")
print(f"   🎉 K-means minus Neural Net: {kmeans_score - neural_score:.2f} (positive means the network scored better)")

# Visualize the comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(sample_img)
axes[0].set_title("🖼️ Original", fontweight="bold")
axes[0].axis("off")

axes[1].imshow(kmeans_result)
axes[1].set_title(f"🎨 K-means\nScore: {kmeans_score:.1f}", fontweight="bold")
axes[1].axis("off")

axes[2].imshow(neural_result)
axes[2].set_title(f"🤖 Neural Net\nScore: {neural_score:.1f}", fontweight="bold")
axes[2].axis("off")

plt.tight_layout()
plt.show()

## 10. 🎮 Interactive Exercises

Now it's your turn to experiment and learn! Try these challenges to deepen your understanding.

### 🎯 Exercise 1: Experiment with Hyperparameters

Try modifying the training parameters and see how they affect the results:

1. **Learning Rate**: Try values like 0.01, 0.001, 0.0001
2. **Number of Epochs**: Try 10, 50, 100 epochs
3. **Network Architecture**: Add/remove layers or change layer sizes

Use the cell below to experiment!


In [None]:
# 🧪 Experiment Playground - Try different hyperparameters!


# Create a different sample image for experimentation
def create_gradient_image(size=128):
    """Create a smooth gradient image"""
    img = np.zeros((size, size, 3), dtype=np.uint8)
    for i in range(size):
        for j in range(size):
            img[i, j] = [
                int(255 * i / size),  # Red gradient
                int(255 * j / size),  # Green gradient
                int(255 * (i + j) / (2 * size)),  # Blue gradient
            ]
    return img


gradient_img = create_gradient_image(128)

# Try different hyperparameters
hyperparams_to_try = [
    {"lr": 0.01, "epochs": 20, "name": "High LR, Few Epochs"},
    {"lr": 0.001, "epochs": 50, "name": "Medium LR, Medium Epochs"},
    {"lr": 0.0001, "epochs": 30, "name": "Low LR, Few Epochs"},
]

results = []
print("🔬 Testing different hyperparameters...")

for params in hyperparams_to_try:
    print(f"\n📊 Testing: {params['name']}")
    result = train_for_image(
        gradient_img, num_epochs=params["epochs"], learning_rate=params["lr"]
    )
    score = quantization_score(gradient_img, result)
    results.append((result, score, params["name"]))
    print(f"   Final score: {score:.2f}")

# Visualize all results
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes = axes.flatten()

# Original image
axes[0].imshow(gradient_img)
axes[0].set_title("🖼️ Original Gradient Image", fontweight="bold")
axes[0].axis("off")

# Results from different hyperparameters
for i, (result, score, name) in enumerate(results):
    axes[i + 1].imshow(result)
    axes[i + 1].set_title(f"🤖 {name}\nScore: {score:.1f}", fontweight="bold")
    axes[i + 1].axis("off")

plt.tight_layout()
plt.show()

print("\n🎉 Experiment complete! Which hyperparameters worked best?")

## 12. 📖 Summary and Next Steps

Congratulations! 🎉 You've learned how to use deep learning for color quantization!

### What You've Learned:

1. **🎨 Color Quantization Fundamentals**:
   - Reducing millions of colors to just 37
   - Balancing image quality vs. color preferences
   - Complex objective functions with multiple terms

2. **🤖 Deep Learning Approach**:
   - CNN architecture for color palette generation
   - Individual training per image (overfitting as a feature!)
   - Direct optimization of the evaluation objective

3. **💻 Implementation Skills**:
   - PyTorch CNN design and training
   - Custom loss functions for specialized objectives
   - Color space manipulations and quantization algorithms

4. **🔬 Advanced Techniques**:
   - Ensuring unique color constraints
   - Hyperparameter tuning for optimization
   - Comparing neural vs. traditional methods

### For the Solution Implementation:

You now have all the knowledge to implement the complete solution! The key components are:

```python
def your_quantization_algorithm(img, n_clusters=37, num_epochs=100, learning_rate=0.0001):
    # 1. Create CNN model for this specific image
    model = ColorQuantizationCNN(n_clusters).to(device)
    
    # 2. Train with custom loss function
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # 3. Track best result during training
    best_score = float('inf')
    best_image = None
    
    # 4. Training loop with our exact objective function
    # ... (similar to train_for_image function above)
    
    return best_image
```

### 🚀 Advanced Topics to Explore:

- **Different Network Architectures**: ResNet, DenseNet, Vision Transformers
- **Advanced Optimizers**: AdamW, SGD with momentum, learning rate scheduling
- **Multi-Image Training**: Learning shared features across image types
- **Perceptual Loss Functions**: Using pre-trained networks for better visual quality
- **Real-Time Quantization**: Optimizing for speed vs. quality trade-offs

### 📚 Useful Resources:

- 📖 [Deep Learning (Goodfellow, Bengio, and Courville)](https://www.deeplearningbook.org/)
- 🛠️ [PyTorch Tutorials](https://pytorch.org/tutorials/)
- 📑 [Article about the median cut method for color quantization](https://gowtham000.hashnode.dev/median-cut-a-popular-colour-quantization-strategy)
- 🎨 [Computer Graphics: Color Theory](https://en.wikipedia.org/wiki/Color_theory)

**Good luck with your implementation!** 🌟

Remember: The key insight is that neural networks can directly optimize complex, non-standard objective functions that traditional algorithms struggle with. This makes them perfect for problems like ours where the evaluation criteria are sophisticated and multi-faceted!
