
Mixture Score-Based Diffusion Models with Multiple Choice Learning

Implementation of a score-based diffusion model for MNIST image generation, extended with Multiple Choice Learning (MCL) that trains $K$ expert score networks. The project addresses the expert collapse problem inherent to hard Winner-Takes-All training by implementing three mitigation strategies (Annealed WTA, Relaxed WTA, Resilient MCL), compares five multi-expert inference routing strategies, and studies the quality–diversity trade-off.

Installation

pip install -r requirements.txt

Requires Python 3.9+ and PyTorch 2.0+. MNIST is downloaded automatically on first run.

Quick Start — Reproduce All Experiments

To run the full pipeline (train → gating → sample → evaluate → analyze) in a single command:

python run_pipeline.py

Or the in-memory variant:

python run_all.py

Both write checkpoints to checkpoints/ and figures and metrics to outputs/ and outputs/analysis/. A GPU (CUDA or MPS) is auto-detected, with CPU fallback.

Using GPUs on MesoNet (Juliet)

1. Access MesoNet

  • Create a MesoNet account and request access to the project.
  • Add your SSH key to MesoNet.

2. Set up your environment

python -m venv venv
source venv/bin/activate
pip install -r requirements.txt

3. Run training on GPU

nvidia-smi  # verify GPU visibility

# Baseline
python -m src.train --mode baseline --epochs 200 --out_dir checkpoints --device cuda

# MCL with K=4 experts (annealed WTA)
python -m src.train --mode mcl --K 4 --mcl_variant annealed_wta --epochs 200 --device cuda

Project Structure

├── run_pipeline.py    # Subprocess-based full pipeline
├── run_all.py         # Single-process pipeline (same stages, in-memory)
├── notebook.ipynb     # Colab-oriented reproduction notebook
├── report.tex         # LaTeX source for the academic report
├── report.pdf         # Compiled report
├── requirements.txt   # Dependencies
├── src/
│   ├── model.py       # ScoreNet (U-Net), GatingNet, and ScoringHead architectures
│   ├── diffusion.py   # Noise schedule, forward process, Euler & Heun ODE samplers
│   ├── train.py       # Training: baseline and MCL (4 variants: hard/annealed/relaxed/resilient)
│   ├── sample.py      # Sampling with 5 multi-expert strategies
│   ├── gating.py      # Train a learned gating network for expert routing
│   ├── evaluate.py    # FID, Precision, and Recall metrics
│   ├── analyze.py     # Specialization analysis and visualization
│   └── utils.py       # Data loading, image grids, seeding, EMA
├── checkpoints/       # Saved model weights (gitignored)
└── outputs/           # Generated images, metrics, analysis plots (gitignored)
    ├── metrics.json
    ├── training_curves.png
    ├── expert_usage_training.png
    ├── metrics_comparison.png
    └── analysis/
        ├── expert_vs_digit.png
        ├── expert_vs_sigma.png
        ├── multi_expert_grid.png
        ├── trajectory.png
        └── strategy_comparison.png

Step-by-Step Reproduction

Each script auto-detects CUDA/MPS. Use --device cpu to force CPU.
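The device auto-detection can be sketched as follows (a hypothetical helper for illustration; the project's actual logic lives in src/utils.py and may differ):

```python
import torch

def auto_device(requested=None):
    """Return the requested torch.device, else the best available one."""
    if requested:  # e.g. an explicit --device cpu
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```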

1. Train the baseline diffusion model

python -m src.train --mode baseline --epochs 200 --base_ch 32 --ch_mult 1 2 \
    --time_dim 128 --lr 3e-4 --ema_decay 0.999 --out_dir checkpoints

2. Train MCL with K=4 experts

Choose a variant with --mcl_variant:

# Annealed WTA (recommended — prevents expert collapse)
python -m src.train --mode mcl --K 4 --mcl_variant annealed_wta --epochs 200 \
    --base_ch 32 --ch_mult 1 2 --time_dim 128 --lr 3e-4 --ema_decay 0.999

# Hard WTA (original, prone to expert collapse)
python -m src.train --mode mcl --K 4 --mcl_variant hard_wta --epochs 200 ...

# Relaxed WTA (partial gradient to losers, α=0.1)
python -m src.train --mode mcl --K 4 --mcl_variant relaxed_wta --relaxed_alpha 0.1 --epochs 200 ...

# Resilient MCL (learned scoring heads)
python -m src.train --mode mcl --K 4 --mcl_variant resilient_mcl --epochs 200 ...

3. Train the gating network

python -m src.gating --mcl_ckpt checkpoints/mcl_K4_final.pt \
    --epochs 25 --collect_batches 80

4. Generate samples

# Baseline (Euler and Heun solvers)
python -m src.sample --checkpoint checkpoints/baseline_final.pt \
    --mode baseline --solver euler --num_samples 2048
python -m src.sample --checkpoint checkpoints/baseline_final.pt \
    --mode baseline --solver heun --num_samples 2048

# MCL — single expert
python -m src.sample --checkpoint checkpoints/mcl_K4_final.pt \
    --mode mcl --strategy single_expert --expert_id 0 --num_samples 2048

# MCL — learned gating
python -m src.sample --checkpoint checkpoints/mcl_K4_final.pt \
    --mode mcl --strategy gated --gating_ckpt checkpoints/gating_K4.pt \
    --num_samples 2048

# Other strategies: random_expert, best_expert, mixture_score

5. Evaluate (FID, Precision, Recall)

python -m src.evaluate --samples_pt outputs/mcl_K4_gated_euler_n2048.pt

6. Analyze expert specialization

python -m src.analyze --mcl_ckpt checkpoints/mcl_K4_final.pt \
    --out_dir outputs/analysis

Method

Score-Based Diffusion (Baseline)

A U-Net (ScoreNet) is trained to predict the noise $\epsilon$ added to data via a variance-exploding forward process $x_t = x_0 + \sigma\epsilon$, with $\sigma$ sampled log-uniformly from $[\sigma_{\min}, \sigma_{\max}]$. The loss is denoising score matching (MSE on noise prediction). Sampling solves the probability-flow ODE $dx/d\sigma = \epsilon_\theta(x, \sigma)$ from $\sigma_{\max}$ to 0, discretized with Euler or Heun's method.
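The objective and sampler above can be sketched in a few lines. This is an illustrative stand-alone version using the README's $\sigma$ range; the function names and the `model(x, sigma)` signature are assumptions, not the exact API of src/diffusion.py:

```python
import torch

sigma_min, sigma_max = 0.01, 80.0

def sample_sigma(n):
    # log-uniform sigma in [sigma_min, sigma_max]
    u = torch.rand(n)
    return sigma_min * (sigma_max / sigma_min) ** u

def dsm_loss(model, x0):
    # denoising score matching: MSE on the added noise
    sigma = sample_sigma(x0.shape[0]).view(-1, 1)
    eps = torch.randn_like(x0)
    xt = x0 + sigma * eps  # variance-exploding forward process
    return ((model(xt, sigma) - eps) ** 2).mean()

@torch.no_grad()
def euler_sample(model, shape, num_steps=200):
    # integrate dx/dsigma = eps_theta(x, sigma) from sigma_max down to 0
    sigmas = torch.linspace(sigma_max, 0.0, num_steps + 1)
    x = sigmas[0] * torch.randn(shape)
    for i in range(num_steps):
        d = model(x, sigmas[i].expand(shape[0], 1))
        x = x + (sigmas[i + 1] - sigmas[i]) * d
    return x
```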

Multiple Choice Learning

$K=4$ independent expert networks are trained with a Winner-Takes-All rule. Four training variants are supported:

| Variant | Flag | Description |
|---|---|---|
| Hard WTA | `hard_wta` | Only the winner receives gradients. Prone to expert collapse. |
| Annealed WTA | `annealed_wta` | Soft-to-hard annealing (τ: 10 → 0.01). All experts train early; competition sharpens gradually. |
| Relaxed WTA | `relaxed_wta` | Winner gets weight 1; losers get weight α (default 0.1). |
| Resilient MCL | `resilient_mcl` | Learned scoring heads predict expert competence, preventing dead experts. |

With hard WTA, only 1 of 4 experts survives (hypothesis collapse). With annealed WTA, 3 of 4 experts learn meaningful score functions.
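A minimal sketch of the annealed WTA weighting (hypothetical code, not the project's exact implementation): per-expert losses are combined with a softmax over negative loss, and the temperature τ decays geometrically from 10 to 0.01 so training hardens from soft assignment into winner-takes-all:

```python
import torch

def annealed_wta_loss(expert_losses, tau):
    # expert_losses: (batch, K) per-expert DSM losses
    weights = torch.softmax(-expert_losses / tau, dim=1)  # winner weight -> 1 as tau -> 0
    return (weights.detach() * expert_losses).sum(dim=1).mean()

def tau_schedule(epoch, epochs, tau_max=10.0, tau_min=0.01):
    # geometric decay from tau_max to tau_min over training
    t = epoch / max(epochs - 1, 1)
    return tau_max * (tau_min / tau_max) ** t
```

At τ = tau_min the weights are effectively one-hot, recovering hard WTA; detaching the weights keeps the assignment from being gamed by the gradient.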

Inference Routing Strategies

| Strategy | Description |
|---|---|
| `single_expert` | One fixed expert for the entire ODE trajectory |
| `random_expert` | Uniformly random expert at each ODE step |
| `best_expert` | All $K$ experts evaluated; smallest prediction norm wins |
| `mixture_score` | Average the predictions of all $K$ experts |
| `gated` | Learned gating network selects an expert per step and sample |
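Two of these strategies can be sketched as follows (illustrative code, assuming each expert shares the baseline `model(x, sigma)` signature; not the project's exact src/sample.py API):

```python
import torch

def mixture_score(experts, x, sigma):
    # mixture_score: average the K expert predictions
    return torch.stack([e(x, sigma) for e in experts]).mean(dim=0)

def best_expert(experts, x, sigma):
    # best_expert: smallest prediction norm wins, per batch element
    preds = torch.stack([e(x, sigma) for e in experts])  # (K, B, ...)
    norms = preds.flatten(2).norm(dim=2)                 # (K, B)
    winner = norms.argmin(dim=0)                         # (B,)
    return preds[winner, torch.arange(x.shape[0])]
```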

Results

Quantitative evaluation on 2,048 generated samples (Annealed WTA, Euler solver, 200 ODE steps):

| Strategy | FID ↓ | Precision ↑ | Recall ↑ |
|---|---|---|---|
| Baseline (Euler) | 59.7 | 0.474 | 0.869 |
| Baseline (Heun) | 58.3 | 0.472 | 0.858 |
| Single Expert | 239.3 | 0.687 | 0.903 |
| Random Expert | 2467.6 | 0.000 | 0.000 |
| Best Expert | 2611.1 | 0.000 | 0.000 |
| Mixture Score | 2342.6 | 0.000 | 0.000 |
| Learned Gating | 171.8 | 0.692 | 0.896 |

Key findings:

  • Annealed WTA prevents expert collapse: 3 of 4 experts learn meaningful score functions (usage: 40%/31%/28%/0%), vs. only 1 with hard WTA.
  • Baseline achieves best FID (58.3 Heun), as expected for a single model without routing overhead.
  • Learned gating is the best MCL strategy (FID 171.8), approaching baseline quality while leveraging multi-expert specialization.
  • Single expert achieves highest recall (0.903), showing the dominant expert covers a broad portion of the manifold.
  • Per-step heuristic strategies (random, best-norm, mixture) fail catastrophically: switching experts mid-trajectory without coherence across ODE steps produces noise.

Hyperparameters (as used in run_pipeline.py)

| Parameter | Value | Description |
|---|---|---|
| base_ch | 32 | Base channel count of the U-Net |
| ch_mult | (1, 2) | Channel multipliers per resolution level |
| time_dim | 128 | Time embedding dimension |
| dropout | 0.05 | Dropout rate in residual blocks |
| sigma_min | 0.01 | Minimum noise level |
| sigma_max | 80.0 | Maximum noise level |
| lr | 3e-4 | Learning rate (Adam) |
| ema_decay | 0.999 | EMA decay for inference weights |
| K | 4 | Number of MCL experts |
| mcl_variant | annealed_wta | MCL training variant |
| anneal_tau_max | 10.0 | Initial temperature (soft assignment) |
| anneal_tau_min | 0.01 | Final temperature (hard WTA) |
| epochs | 200 | Training epochs (baseline & MCL) |
| batch_size | 256 | Training batch size |
| num_steps | 200 | ODE integration steps at inference |

Report

The full academic report is available as report.pdf (LaTeX source: report.tex). It covers the mathematical formulation, implementation details, expert collapse analysis with mitigation strategies, all quantitative and qualitative results, and a discussion on inter-class vs. intra-class diversity.
