# CIFAR-10 Experiments

## Model Architecture

The core architecture is a conditioned **U-Net** designed for image generation on CIFAR-10 ($32 \times 32$ resolution). The model employs a symmetric encoder-decoder structure with skip connections, utilizing Residual Blocks and Self-Attention layers to capture both local textures and global context.

### 1. U-Net Architecture & Funnel
The network processes the input through a series of downsampling and upsampling stages. The base channel dimension is set to **128**.

| Stage | Resolution | Channels (In $\to$ Out) | Blocks |
| :--- | :--- | :--- | :--- |
| **Input** | $32 \times 32$ | $3 \to 128$ | Initial Convolution |
| **Down 1** | $32 \times 32 \to 16 \times 16$ | $128 \to 128$ | 2 $\times$ ResBlock + Strided Conv |
| **Down 2** | $16 \times 16 \to 8 \times 8$ | $128 \to 256$ | 2 $\times$ ResBlock + Attention + Strided Conv |
| **Down 3** | $8 \times 8 \to 4 \times 4$ | $256 \to 256$ | 2 $\times$ ResBlock + Attention + Strided Conv |
| **Bottleneck**| $4 \times 4$ | $256$ | ResBlock $\to$ Attention $\to$ ResBlock |
| **Up 1** | $4 \times 4 \to 8 \times 8$ | $512^* \to 256$ | Upsample + Concat $\to$ 2 $\times$ ResBlock + Attn |
| **Up 2** | $8 \times 8 \to 16 \times 16$ | $512^* \to 256$ | Upsample + Concat $\to$ 2 $\times$ ResBlock + Attn |
| **Up 3** | $16 \times 16 \to 32 \times 32$ | $384^* \to 128$ | Upsample + Concat $\to$ 2 $\times$ ResBlock |
| **Output** | $32 \times 32$ | $128 \to 3$ | GroupNorm $\to$ SiLU $\to$ Conv |

*\*Note: Input channels in Up stages are higher due to concatenation with skip connections from the Down path.*
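
The starred input widths can be checked by hand. The sketch below assumes each Up stage concatenates the ResBlock output of the Down stage working at the same resolution (before that stage's strided convolution). This is one reading consistent with the table, not something the table states outright, and the dictionary names are illustrative.

```python
# Output width of each stage, copied from the table above.
down_out = {"down1": 128, "down2": 256, "down3": 256}
bottleneck_out = 256
up_out = {"up1": 256, "up2": 256, "up3": 128}

# Assumed pairing: each Up stage concatenates the skip from the Down stage
# that operates at the same resolution.
skip_source = {"up1": "down3", "up2": "down2", "up3": "down1"}

up_in = {}
incoming = bottleneck_out
for stage in ("up1", "up2", "up3"):
    # Concatenation along the channel axis adds the two widths.
    up_in[stage] = incoming + down_out[skip_source[stage]]
    incoming = up_out[stage]

print(up_in)  # {'up1': 512, 'up2': 512, 'up3': 384}
```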

### 2. Global Conditioning
Conditioning is handled via an **additive embedding** approach that combines temporal information and class labels into a unified global context vector. This vector is computed once per forward pass and shared across all network blocks.

* **Time Embedding:** Time steps $t$ are projected using fixed sinusoidal embeddings followed by a Multi-Layer Perceptron (MLP).
* **Class Embedding:** Class labels $c$ (including a null class for Classifier-Free Guidance) are mapped via a learnable lookup table.
* **Combination:** The final conditioning vector $v_{cond}$ is the element-wise sum of the time and class embeddings:
    $$v_{cond} = \text{MLP}(\text{Sinusoidal}(t)) + \text{Embedding}(c)$$
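
A minimal NumPy sketch of this combination is given here for illustration; it is not the training code. The embedding width of 512, the 128-dimensional sinusoidal input, the bias-free two-layer MLP, and the random weights are all assumptions standing in for learned parameters.

```python
import numpy as np

def sinusoidal_embedding(t, dim):
    """Fixed sinusoidal features for timesteps t: shape (B,) -> (B, dim)."""
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)
    angles = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

def silu(x):
    return x / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
emb_dim = 512        # assumed width of v_cond
num_classes = 10     # CIFAR-10 classes; index 10 is the null class

# Randomly initialised stand-ins for learned parameters (biases omitted).
w1 = rng.normal(scale=0.02, size=(128, emb_dim))
w2 = rng.normal(scale=0.02, size=(emb_dim, emb_dim))
class_table = rng.normal(scale=0.02, size=(num_classes + 1, emb_dim))

def conditioning_vector(t, c):
    time_emb = silu(sinusoidal_embedding(t, 128) @ w1) @ w2   # MLP(Sinusoidal(t))
    class_emb = class_table[np.asarray(c)]                      # Embedding(c)
    return time_emb + class_emb                                 # element-wise sum

v_cond = conditioning_vector(t=[0, 500, 999], c=[3, 10, 7])
print(v_cond.shape)  # (3, 512)
```

Because the class term is a plain table lookup, the null class used for Classifier-Free Guidance is just one extra row in `class_table`.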

### 3. Temporal & Class Injection
The conditioning vector $v_{cond}$ is injected into the network within every **Residual Block**.

1.  **Projection:** Inside the block, the global $v_{cond}$ vector is passed through a SiLU activation and a Linear layer to match the channel dimensions of the current feature map.
2.  **Broadcasting & Addition:** The projected embedding is spatially broadcast (repeated across height and width) and added to the feature map after the first convolution but before the second convolution.

This mechanism ensures that both the "when" (diffusion timestep) and "what" (class label) signals modulate the feature processing at every resolution level.
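
The two steps can be sketched in NumPy as follows. The function name `inject_condition`, the tensor sizes, and the random weights are hypothetical; the sketch only illustrates the projection and the spatial broadcast, not the full Residual Block.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def inject_condition(h, v_cond, weight, bias):
    """Add a projected conditioning vector to a feature map.

    h:      (B, C, H, W) features after the block's first convolution
    v_cond: (B, D) global conditioning vector
    weight: (D, C) and bias: (C,) of the block's own Linear layer
    """
    proj = silu(v_cond) @ weight + bias      # (B, C): SiLU, then Linear
    return h + proj[:, :, None, None]        # broadcast over height and width

rng = np.random.default_rng(0)
h = rng.normal(size=(2, 256, 8, 8))
v_cond = rng.normal(size=(2, 512))
weight = rng.normal(scale=0.02, size=(512, 256))
bias = np.zeros(256)

out = inject_condition(h, v_cond, weight, bias)
print(out.shape)  # (2, 256, 8, 8)
```

Each channel receives a single scalar offset per sample, so the same shift is applied at every spatial position of that channel.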

---
## Training Details

The model is trained using a unified framework that supports both standard Diffusion (DDPM/DDIM) and Rectified Flow objectives. In either mode, training minimizes the Mean Squared Error (MSE) between the model prediction and the regression target defined by that mode's forward process (noise in Diffusion mode, velocity in Flow mode).

### 1. Forward Process & Objectives
The forward process adds noise to the clean image $x_0$ to produce a noisy sample $x_t$ at a random timestep $t \in [0, T]$ (in Flow mode, the interpolation below implies $T = 1$).

* **Diffusion Mode (DDPM)**
    The forward process follows a variance-preserving schedule, in which the squared signal and noise coefficients sum to one:
    $$x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$
    * **Target:** The model is trained to predict the noise $\epsilon$.

* **Flow Mode (Rectified Flow)**
    The forward process follows a straight-line interpolation between data $x_0$ and noise $x_1$:
    $$x_t = (1 - t) x_0 + t x_1, \quad x_1 \sim \mathcal{N}(0, I)$$
    * **Target:** The model is trained to predict the **velocity** $v$ (the direction pointing from data to noise):
        $$v = x_1 - x_0$$
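
The two forward processes and their targets can be sketched side by side. The linear beta schedule with $T = 1000$ is an assumption (the schedule is not stated here), and the function names `diffusion_pair`, `flow_pair`, and `mse` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed linear beta schedule with T = 1000 steps (not stated in the text).
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)

def diffusion_pair(x0, t):
    """Return (x_t, target) for Diffusion mode; t holds integer steps in [0, T)."""
    eps = rng.standard_normal(x0.shape)
    a = np.sqrt(alpha_bar[t])[:, None, None, None]
    s = np.sqrt(1.0 - alpha_bar[t])[:, None, None, None]
    return a * x0 + s * eps, eps              # target: the noise

def flow_pair(x0, t):
    """Return (x_t, target) for Flow mode; t holds floats in [0, 1]."""
    x1 = rng.standard_normal(x0.shape)
    tt = t[:, None, None, None]
    return (1.0 - tt) * x0 + tt * x1, x1 - x0  # target: the velocity

def mse(pred, target):
    return np.mean((pred - target) ** 2)

x0 = rng.uniform(-1.0, 1.0, size=(4, 3, 32, 32))   # stand-in image batch
t_diff = rng.integers(0, T, size=4)
t_flow = rng.uniform(0.0, 1.0, size=4)

x_t, eps = diffusion_pair(x0, t_diff)
z_t, v = flow_pair(x0, t_flow)
```

In both modes the loss is `mse(model_output, target)`; only the construction of the noisy input and the target differs.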

### 2. Classifier-Free Guidance (CFG)
CFG is implemented to enable conditional generation without requiring a separate classifier.

* **Training:**
    During training, each class label $c$ is replaced with a null token ($c_\emptyset$, index 10) with probability $p_{drop}$, where $r \sim \mathcal{U}(0, 1)$ is a uniform random draw:
    $$c_{train} = \begin{cases} c_\emptyset & \text{if } r < p_{drop} \\ c & \text{otherwise} \end{cases}$$
    * **Default Drop Probability:** $p_{drop} = 0.1$ (10%).

* **Inference:**
    Sampling is performed using a linear combination of conditional and unconditional predictions to amplify the guidance signal.
    $$\widehat{\text{pred}}(x_t, c) = \text{pred}(x_t, c_\emptyset) + s \cdot (\text{pred}(x_t, c) - \text{pred}(x_t, c_\emptyset))$$
    * **Default Inference Scale ($s$):** $3.0$.
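
Both halves of CFG fit in a few lines. In this sketch, `drop_labels`, `guided_prediction`, and `toy_model` are hypothetical names, and the toy model exists only to make the arithmetic visible; it is not the U-Net.

```python
import numpy as np

NULL_CLASS = 10   # index of the null token, as stated above

def drop_labels(labels, p_drop=0.1, rng=None):
    """Replace each label with the null class with probability p_drop."""
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    r = rng.uniform(size=labels.shape)
    return np.where(r < p_drop, NULL_CLASS, labels)

def guided_prediction(model, x_t, t, c, scale=3.0):
    """Combine conditional and unconditional predictions with guidance scale s."""
    c = np.asarray(c)
    pred_uncond = model(x_t, t, np.full_like(c, NULL_CLASS))
    pred_cond = model(x_t, t, c)
    return pred_uncond + scale * (pred_cond - pred_uncond)

# Toy stand-in model: predicts 0 for the null class and 1 for any real class.
def toy_model(x_t, t, c):
    return np.where(c == NULL_CLASS, 0.0, 1.0)[:, None] * np.ones_like(x_t)

x_t = np.zeros((2, 4))
out = guided_prediction(toy_model, x_t, t=None, c=[3, 7], scale=3.0)
print(out[0, 0])  # 3.0
```

With $s = 1$ the expression reduces to the plain conditional prediction; values above 1 extrapolate past it, away from the unconditional prediction.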

### 3. Noise Sampling Frequency
* **Once per Epoch:** For every image in the dataset, a random timestep $t$ is sampled independently in every epoch.
* This ensures the model sees every image at various noise levels throughout the complete training duration.

### 4. Optimization
* **Optimizer:** `AdamW` is used for weight updates.
* **Learning Rate:** `3e-4` ($0.0003$).
* **Batch Size:** `128`.
* **Betas:** Defaults are used (typically `(0.9, 0.999)`) as no specific overrides are present in the code.

### 5. Training Duration
* **Epochs:** The model is trained for **100** epochs by default.

---
## Sampling & Inference Details

Sampling is performed periodically during training to monitor progress. The model employs a unified sampler that selects the appropriate Ordinary Differential Equation (ODE) solver based on the training mode.

### 1. Diffusion Mode: DDIM Sampler
For the Diffusion model, the **Denoising Diffusion Implicit Models (DDIM)** sampler is used. This allows for deterministic sampling with fewer steps than the training process.

The update from time $t$ to $t-1$ is defined as follows, where $t-1$ denotes the next timestep in the sampling schedule (`t_next`), which need not be adjacent to $t$ when fewer steps are used:

$$x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \underbrace{\left( \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \cdot \hat{\epsilon}_\theta(x_t)}{\sqrt{\bar{\alpha}_t}} \right)}_{\text{predicted } x_0} + \sqrt{1 - \bar{\alpha}_{t-1}} \cdot \hat{\epsilon}_\theta(x_t)$$

* **Prediction:** The model predicts the noise $\hat{\epsilon}_\theta$ (using CFG).
* **Reconstruction:** The "predicted $x_0$" is derived by removing the predicted noise from the current step.
* **Step:** The update moves to $x_{t-1}$ deterministically, since the noise variance is set to zero ($\sigma = 0$).
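
A single update can be written directly from the formula. The sketch below uses an oracle noise prediction in place of the network and an assumed linear beta schedule, purely to check that the algebra behaves as expected; `ddim_step` is an illustrative name.

```python
import numpy as np

def ddim_step(x_t, eps_hat, alpha_bar_t, alpha_bar_next):
    """One deterministic DDIM update (sigma = 0) from t to t_next."""
    x0_pred = (x_t - np.sqrt(1.0 - alpha_bar_t) * eps_hat) / np.sqrt(alpha_bar_t)
    return np.sqrt(alpha_bar_next) * x0_pred + np.sqrt(1.0 - alpha_bar_next) * eps_hat

# Sanity check with the true noise on an assumed linear beta schedule.
rng = np.random.default_rng(0)
alpha_bar = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
x0 = rng.uniform(-1.0, 1.0, size=(3, 32, 32))
eps = rng.standard_normal(x0.shape)

t, t_next = 500, 400
x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
x_next = ddim_step(x_t, eps, alpha_bar[t], alpha_bar[t_next])

# With the true noise, the step lands on the forward-process sample at t_next.
expected = np.sqrt(alpha_bar[t_next]) * x0 + np.sqrt(1.0 - alpha_bar[t_next]) * eps
print(np.allclose(x_next, expected))  # True
```

During sampling, `eps_hat` would instead be the CFG-combined network output, and the step is repeated over the whole schedule.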

### 2. Flow Mode: Euler Solver
For the Rectified Flow model, a first-order **Euler Method** is used to solve the Neural ODE. This treats the generation process as moving along a straight line trajectory from noise to data.

The update step is calculated using the time delta $dt = t_{next} - t_{now}$ (which is negative during sampling):

$$x_{t_{next}} \approx x_{t_{now}} + v_\theta(x_{t_{now}}) \cdot dt$$

* **Prediction:** The model predicts the velocity field $v_\theta$ (using CFG).
* **Step:** The current state is updated by following the velocity vector for the duration of the time step $dt$.
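
The loop below is a minimal Euler integrator in this spirit. `euler_sample` and `velocity_fn` are hypothetical names, the uniform time grid is an assumption, and the constant velocity in the check replaces the trained network.

```python
import numpy as np

def euler_sample(velocity_fn, x1, num_steps=80):
    """Integrate dx/dt = v(x, t) from t = 1 (noise) down to t = 0 (data)."""
    ts = np.linspace(1.0, 0.0, num_steps + 1)
    x = x1
    for t_now, t_next in zip(ts[:-1], ts[1:]):
        dt = t_next - t_now                 # negative during sampling
        x = x + velocity_fn(x, t_now) * dt
    return x

# Toy check: for a single (x0, x1) pair the ideal velocity is the constant
# x1 - x0, so Euler integration recovers x0 up to floating-point error.
rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(3, 32, 32))
x1 = rng.standard_normal(x0.shape)
sample = euler_sample(lambda x, t: x1 - x0, x1, num_steps=80)
print(np.allclose(sample, x0))  # True
```

When the learned velocity field is close to constant along each trajectory, Euler steps incur little discretization error, which is why a first-order solver is a natural fit for this mode.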


---
## Evaluation Details

Evaluation is performed using the **Fréchet Inception Distance (FID)**, which measures the similarity between the distributions of real and generated images in the feature space of a pre-trained InceptionV3 network.

### 1. Metric: Fréchet Inception Distance (FID)
The evaluation script uses a PyTorch-native implementation of FID optimized for GPU.

* **Feature Extraction:** Images are resized to $299 \times 299$ and passed through an **InceptionV3** network (pre-trained on ImageNet). Features are extracted from the penultimate pooling layer (2048 dimensions).
* **Formula:** The distance between the Gaussian distributions of the real features $(\mu_r, \Sigma_r)$ and generated features $(\mu_g, \Sigma_g)$ is calculated as:
    $$d^2((\mu_r, \Sigma_r), (\mu_g, \Sigma_g)) = ||\mu_r - \mu_g||^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2})$$
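
The formula can be evaluated on pre-extracted features as sketched below. This NumPy version is illustrative and is not the evaluation script: it obtains the trace of the matrix square root from eigenvalues rather than calling a dedicated matrix square root routine, and it uses low-dimensional random features in place of Inception activations.

```python
import numpy as np

def frechet_distance(feats_r, feats_g):
    """FID between two feature sets of shape (N, D), following the formula above."""
    mu_r, mu_g = feats_r.mean(axis=0), feats_g.mean(axis=0)
    sigma_r = np.cov(feats_r, rowvar=False)
    sigma_g = np.cov(feats_g, rowvar=False)

    # Tr((Sigma_r Sigma_g)^{1/2}) equals the sum of the square roots of the
    # eigenvalues of Sigma_r Sigma_g, which are real and non-negative for
    # positive semi-definite inputs; clip small negative round-off.
    eigvals = np.linalg.eigvals(sigma_r @ sigma_g).real
    tr_sqrt = np.sqrt(np.clip(eigvals, 0.0, None)).sum()

    diff = mu_r - mu_g
    return float(diff @ diff + np.trace(sigma_r) + np.trace(sigma_g) - 2.0 * tr_sqrt)

rng = np.random.default_rng(0)
real = rng.standard_normal((2000, 16))   # stand-in for 2048-d Inception features
shifted = real + 0.5                      # same covariance, mean moved by 0.5 per dim

print(abs(frechet_distance(real, real)) < 1e-6)           # True
print(abs(frechet_distance(real, shifted) - 4.0) < 1e-6)  # True
```

In the second check the covariances match, so only the mean term survives: $16 \times 0.5^2 = 4$.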

### 2. Evaluation Scenarios
The evaluation script computes FID across five distinct comparison configurations to assess quality, mode collapse, and class separability:

* **Standard Global FID (All vs. All):**
    * Compares the distribution of the **entire real dataset** against the **entire generated dataset**. This is the standard metric for overall image quality and diversity.

* **Class-Specific Diagonal (Real $C_i$ vs. Gen $C_i$):**
    * Compares **Real Class $i$** against **Generated Class $i$**. This measures how well the model generates a specific class (e.g., "Does the generated Dog look like a real Dog?").

* **Cross-Class Matrix (Real $C_i$ vs. Gen $C_j$):**
    * Compares **Real Class $i$** against **Generated Class $j$** (where $i \neq j$). Low values here might indicate class confusion or mode collapse (e.g., generated Cats looking like real Dogs).

* **Class-to-Global (Real $C_i$ vs. Gen All):**
    * Compares **Real Class $i$** against the **entire generated distribution**. This checks if a specific real mode is represented anywhere in the generated output.

* **Internal Separability (Gen $C_i$ vs. Gen $C_j$):**
    * Compares **Generated Class $i$** against **Generated Class $j$**. This acts as a separability check; high values are desired, indicating that the model's generated classes are distinct from one another.
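
The five pairings can be organised around a single distance function. In the sketch below, `evaluation_scenarios` and `mean_gap` are hypothetical names, the cheap mean-gap distance stands in for FID, and averaging the off-diagonal entries into one separability score is an assumption about how the scenarios are summarised.

```python
import numpy as np

def evaluation_scenarios(real_feats, real_labels, gen_feats, gen_labels,
                         distance_fn, num_classes=10):
    """Apply distance_fn(features_a, features_b) to the five pairings listed above."""
    real_by_class = [real_feats[real_labels == k] for k in range(num_classes)]
    gen_by_class = [gen_feats[gen_labels == k] for k in range(num_classes)]

    # Rows index the real class i, columns the generated class j.
    real_vs_gen = np.array([[distance_fn(r, g) for g in gen_by_class]
                            for r in real_by_class])
    gen_vs_gen = np.array([[distance_fn(a, b) for b in gen_by_class]
                           for a in gen_by_class])
    off_diag = ~np.eye(num_classes, dtype=bool)

    return {
        "global": distance_fn(real_feats, gen_feats),
        "class_diagonal": np.diag(real_vs_gen),
        "cross_class": real_vs_gen,          # read the entries with i != j
        "class_to_global": np.array([distance_fn(r, gen_feats) for r in real_by_class]),
        "separability": gen_vs_gen[off_diag].mean(),
    }

# Toy usage with a cheap stand-in distance (squared gap between feature means).
def mean_gap(a, b):
    return float(np.sum((a.mean(axis=0) - b.mean(axis=0)) ** 2))

rng = np.random.default_rng(0)
labels = np.repeat(np.arange(10), 50)
feats = rng.standard_normal((500, 8)) + labels[:, None]   # class k centred at k
results = evaluation_scenarios(feats, labels, feats, labels, mean_gap)
```

Passing the same features as both "real" and "generated" gives a zero diagonal and a large separability score, which is the pattern a well-behaved conditional model should approach.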


---
## Results

### Visualize a few generated examples

```python
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np

# ==========================================
# 1. Configuration
# ==========================================
# Replace this with your specific file path
FILE_PATH = './generated_data/generated_seed_101007_1000x10_mode_flow_steps_80_cfg_scale_2.4/generated_batch'

# CIFAR-10 class names (for labeling the plot)
CLASS_NAMES = [
    "Airplane", "Automobile", "Bird", "Cat", "Deer",
    "Dog", "Frog", "Horse", "Ship", "Truck"
]

# ==========================================
# 2. Loading Logic
# ==========================================
def unpickle(file):
    """Load a pickled batch stored in the standard CIFAR-10 layout."""
    with open(file, 'rb') as fo:
        # encoding='bytes' keeps the dictionary keys as bytes (b'data', b'labels')
        data_dict = pickle.load(fo, encoding='bytes')
    return data_dict

def load_and_reshape(file_path):
    """Return (images, labels), or (None, None) if the file is missing."""
    if not os.path.exists(file_path):
        print(f"Error: File not found at {file_path}")
        return None, None

    print(f"Loading {file_path}...")
    data_dict = unpickle(file_path)

    raw_data = np.asarray(data_dict[b'data'])   # Shape: (N, 3072)
    labels = data_dict[b'labels']               # List of N integers

    # Reshape logic:
    # 1. Reshape to (Batch, Channels, Height, Width) -> (N, 3, 32, 32)
    # 2. Transpose to (Batch, Height, Width, Channels) -> (N, 32, 32, 3) for Matplotlib
    num_images = raw_data.shape[0]
    images = raw_data.reshape(num_images, 3, 32, 32).transpose(0, 2, 3, 1)

    return images, labels

# ==========================================
# 3. Visualization
# ==========================================
def visualize_grid(images, labels, rows=4, cols=8):
    """Plots a grid of images with labels."""
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5))
    fig.suptitle(f"Samples from: {os.path.basename(os.path.dirname(FILE_PATH))}", fontsize=10)

    # Flatten axes array for easy iteration
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        if i < len(images):
            ax.imshow(images[i])
            ax.set_title(CLASS_NAMES[labels[i]], fontsize=8)
        # Remove ticks, also on cells left empty when there are too few images
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# ==========================================
# Main Execution
# ==========================================
if __name__ == "__main__":
    images, labels = load_and_reshape(FILE_PATH)

    if images is not None and len(images) > 0:
        print(f"Loaded {len(images)} images.")

        # The generation script writes 1000 images of class 0, then 1000 of
        # class 1, and so on, so sample with a stride to see different classes.
        # Capping the count avoids out-of-range indices for small files.
        num_samples = min(32, len(images))
        stride = max(1, len(images) // num_samples)
        indices = [i * stride for i in range(num_samples)]

        # Gather samples
        sample_imgs = images[indices]
        sample_lbls = [labels[i] for i in indices]

        visualize_grid(sample_imgs, sample_lbls)
```



### Draw the graphs

```python
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

def draw_graphs(csv_path, output_dir=None):
    """Read an evaluation CSV and save the FID, diversity, and separability plots."""
    if output_dir is None:
        output_dir = "."
    # makedirs also creates missing parent directories (e.g. ./graphs)
    os.makedirs(output_dir, exist_ok=True)

    # 1. Load Data
    df = pd.read_csv(csv_path)

    # Set visual theme
    sns.set_theme(style="whitegrid", context="talk")

    # ==========================================
    # Graph 1: Global FID Analysis
    # ==========================================
    # 1a. FID vs CFG Scale (Hue = Steps)
    plt.figure(figsize=(12, 7))
    sns.lineplot(
        data=df,
        x="cfg_scale",
        y="global_fid",
        hue="steps",
        style="mode",
        palette="viridis",
        markers=True, dashes=False, linewidth=2.5, markersize=9
    )
    plt.title("Global FID vs. CFG Scale")
    plt.ylabel("Global FID (Lower is Better)")
    plt.xlabel("CFG Scale")
    plt.savefig(f"{output_dir}/graph1a_fid_vs_cfg.png", bbox_inches='tight')
    plt.close()

    # 1b. FID vs Sampling Steps (Hue = CFG)
    plt.figure(figsize=(12, 7))
    sns.lineplot(
        data=df,
        x="steps",
        y="global_fid",
        hue="cfg_scale",
        style="mode",
        palette="rocket_r",  # Light = Low CFG, Dark = High CFG
        markers=True, dashes=False, linewidth=2.5, markersize=9
    )
    plt.title("Global FID vs. Sampling Steps")
    plt.ylabel("Global FID (Lower is Better)")
    plt.xlabel("Sampling Steps")
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', title="CFG Scale")
    plt.savefig(f"{output_dir}/graph1b_fid_vs_steps.png", bbox_inches='tight')
    plt.close()

    # ==========================================
    # Graph 2: Quality vs. Diversity Trade-off
    # ==========================================
    plt.figure(figsize=(12, 7))
    sns.scatterplot(
        data=df,
        x="avg_diversity_ratio",
        y="avg_class_fid",
        hue="cfg_scale",
        style="mode",
        size="steps",
        sizes=(100, 300),
        palette="flare",
        edgecolor="black", alpha=0.8
    )
    # Reference Line for Perfect Diversity
    plt.axvline(1.0, color='green', linestyle='--', linewidth=2, label="Ideal Diversity (1.0)")

    plt.title("Quality (FID) vs. Diversity (Trace Ratio)")
    plt.xlabel("Diversity Ratio (Gen Variance / Real Variance)")
    plt.ylabel("Quality: Avg Class-wise FID (Lower is Better)")
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left')
    plt.savefig(f"{output_dir}/graph2_quality_diversity.png", bbox_inches='tight')
    plt.close()

    # ==========================================
    # Graph 3: Separability
    # ==========================================
    plt.figure(figsize=(12, 7))
    sns.lineplot(
        data=df,
        x="cfg_scale",
        y="avg_separability",
        hue="steps",
        style="mode",
        palette="magma",
        markers=True, dashes=False, linewidth=2.5, markersize=9
    )

    # Add GT Baseline if available
    if "gt_separability" in df.columns:
        gt_val = df["gt_separability"].iloc[0]
        plt.axhline(gt_val, color='red', linestyle='--', linewidth=2, label=f"GT Baseline ({gt_val:.2f})")
        plt.legend()

    plt.title("Class Separability (Cross-Class FID)")
    plt.ylabel("Avg Separability (Higher is Better)")
    plt.xlabel("CFG Scale")
    plt.savefig(f"{output_dir}/graph3_separability.png", bbox_inches='tight')
    plt.close()

    print("Graphs generated successfully.")

if __name__ == "__main__":
    draw_graphs('./eval_results/eval_batch_flow_1000/results.csv', output_dir="./graphs/eval_batch_flow_1000")
    draw_graphs('./eval_results/eval_batch_diffusion_1000/results.csv', output_dir="./graphs/eval_batch_diffusion_1000")
```

