Implementation of a score-based diffusion model for MNIST image generation, extended with Multiple Choice Learning (MCL), which trains K expert score networks under winner-take-all (WTA) objectives.
```bash
pip install -r requirements.txt
```

Requires Python 3.9+ and PyTorch 2.0+. MNIST is downloaded automatically on first run.
To run the full pipeline (train → gating → sample → evaluate → analyze) in a single command:

```bash
python run_pipeline.py
```

Or the in-memory variant:

```bash
python run_all.py
```

Both produce checkpoints under `checkpoints/`, and figures and metrics under `outputs/` and `outputs/analysis/`. GPU (CUDA or MPS) is auto-detected, with CPU fallback.
To run on the MesoNet cluster instead:

- Create a MesoNet account and request access to the project.
- Add your SSH key to MesoNet.
- Set up the environment and check GPU visibility:

```bash
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
nvidia-smi  # verify GPU visibility
```
```bash
# Baseline
python -m src.train --mode baseline --epochs 200 --out_dir checkpoints --device cuda

# MCL with K=4 experts (annealed WTA)
python -m src.train --mode mcl --K 4 --mcl_variant annealed_wta --epochs 200 --device cuda
```

Repository layout:

```
├── run_pipeline.py        # Subprocess-based full pipeline
├── run_all.py             # Single-process pipeline (same stages, in-memory)
├── notebook.ipynb         # Colab-oriented reproduction notebook
├── report.tex             # LaTeX source for the academic report
├── report.pdf             # Compiled report
├── requirements.txt       # Dependencies
├── src/
│   ├── model.py           # ScoreNet (U-Net), GatingNet, and ScoringHead architectures
│   ├── diffusion.py       # Noise schedule, forward process, Euler & Heun ODE samplers
│   ├── train.py           # Training: baseline and MCL (4 variants: hard/annealed/relaxed/resilient)
│   ├── sample.py          # Sampling with 5 multi-expert strategies
│   ├── gating.py          # Train a learned gating network for expert routing
│   ├── evaluate.py        # FID, Precision, and Recall metrics
│   ├── analyze.py         # Specialization analysis and visualization
│   └── utils.py           # Data loading, image grids, seeding, EMA
├── checkpoints/           # Saved model weights (gitignored)
└── outputs/               # Generated images, metrics, analysis plots (gitignored)
    ├── metrics.json
    ├── training_curves.png
    ├── expert_usage_training.png
    ├── metrics_comparison.png
    └── analysis/
        ├── expert_vs_digit.png
        ├── expert_vs_sigma.png
        ├── multi_expert_grid.png
        ├── trajectory.png
        └── strategy_comparison.png
```
Each script auto-detects CUDA/MPS. Use `--device cpu` to force CPU.
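The detection logic can be sketched as follows; `get_device` is a hypothetical helper, not necessarily the exact function in `src/utils.py`:

```python
from typing import Optional

import torch

def get_device(requested: Optional[str] = None) -> torch.device:
    """Honor an explicit --device override, else prefer CUDA, then Apple MPS, then CPU."""
    if requested is not None:
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```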
Full baseline training command:

```bash
python -m src.train --mode baseline --epochs 200 --base_ch 32 --ch_mult 1 2 \
    --time_dim 128 --lr 3e-4 --ema_decay 0.999 --out_dir checkpoints
```

For MCL, choose a variant with `--mcl_variant`:
```bash
# Annealed WTA (recommended — prevents expert collapse)
python -m src.train --mode mcl --K 4 --mcl_variant annealed_wta --epochs 200 \
    --base_ch 32 --ch_mult 1 2 --time_dim 128 --lr 3e-4 --ema_decay 0.999

# Hard WTA (original, prone to expert collapse)
python -m src.train --mode mcl --K 4 --mcl_variant hard_wta --epochs 200 ...

# Relaxed WTA (partial gradient to losers, α=0.1)
python -m src.train --mode mcl --K 4 --mcl_variant relaxed_wta --relaxed_alpha 0.1 --epochs 200 ...

# Resilient MCL (learned scoring heads)
python -m src.train --mode mcl --K 4 --mcl_variant resilient_mcl --epochs 200 ...
```

Train a gating network on top of a trained MCL checkpoint:

```bash
python -m src.gating --mcl_ckpt checkpoints/mcl_K4_final.pt \
    --epochs 25 --collect_batches 80
```

To sample from the trained models:

```bash
# Baseline (Euler and Heun solvers)
python -m src.sample --checkpoint checkpoints/baseline_final.pt \
    --mode baseline --solver euler --num_samples 2048
python -m src.sample --checkpoint checkpoints/baseline_final.pt \
    --mode baseline --solver heun --num_samples 2048
```
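The two solvers differ only in how one step of the sampling ODE is integrated. A minimal sketch under a variance-exploding parameterization (step functions hypothetical, not the exact `src/diffusion.py` API):

```python
import torch

def euler_step(x, sigma, sigma_next, score_fn):
    # VE probability-flow ODE: dx/dsigma = -sigma * score(x, sigma)
    d = -sigma * score_fn(x, sigma)
    return x + (sigma_next - sigma) * d

def heun_step(x, sigma, sigma_next, score_fn):
    # Heun (2nd order): Euler predictor, then average the two slope estimates
    d = -sigma * score_fn(x, sigma)
    x_pred = x + (sigma_next - sigma) * d
    d_next = -sigma_next * score_fn(x_pred, sigma_next)
    return x + (sigma_next - sigma) * 0.5 * (d + d_next)
```

Heun doubles the score evaluations per step in exchange for lower discretization error at a fixed step count.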
```bash
# MCL — single expert
python -m src.sample --checkpoint checkpoints/mcl_K4_final.pt \
    --mode mcl --strategy single_expert --expert_id 0 --num_samples 2048

# MCL — learned gating
python -m src.sample --checkpoint checkpoints/mcl_K4_final.pt \
    --mode mcl --strategy gated --gating_ckpt checkpoints/gating_K4.pt \
    --num_samples 2048

# Other strategies: random_expert, best_expert, mixture_score
```

Evaluate the generated samples and analyze expert specialization:

```bash
python -m src.evaluate --samples_pt outputs/mcl_K4_gated_euler_n2048.pt
python -m src.analyze --mcl_ckpt checkpoints/mcl_K4_final.pt \
    --out_dir outputs/analysis
```

A U-Net (ScoreNet) is trained to predict the noise added by the forward process; samples are then generated by integrating the resulting ODE from pure noise back to images with the Euler or Heun solver.
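That objective can be illustrated with a simplified denoising loss, a sketch assuming the variance-exploding forward process `x_noisy = x0 + sigma * eps` (the actual loss lives in `src/train.py`):

```python
import torch

def denoising_loss(model, x0, sigma_min=0.01, sigma_max=80.0):
    # One log-uniform noise level per image, shaped to broadcast over (C, H, W)
    u = torch.rand(x0.shape[0], *([1] * (x0.dim() - 1)))
    sigma = sigma_min * (sigma_max / sigma_min) ** u
    eps = torch.randn_like(x0)
    x_noisy = x0 + sigma * eps        # forward process: add scaled Gaussian noise
    eps_hat = model(x_noisy, sigma)   # the network predicts the added noise
    return ((eps_hat - eps) ** 2).mean()
```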
| Variant | Flag | Description |
|---|---|---|
| Hard WTA | `hard_wta` | Only the winner gets gradients. Prone to expert collapse. |
| Annealed WTA | `annealed_wta` | Soft-to-hard annealing (τ: 10 → 0.01). All experts train early; competition sharpens gradually. |
| Relaxed WTA | `relaxed_wta` | Winner gets weight 1, losers get weight α (default 0.1). |
| Resilient MCL | `resilient_mcl` | Learned scoring heads predict expert competence, preventing dead experts. |
With hard WTA, only 1 of 4 experts survives (hypothesis collapse). With annealed WTA, 3 of 4 experts learn meaningful score functions.
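A minimal sketch of the annealed winner-take-all weighting, simplified relative to `src/train.py` (the geometric temperature schedule is an assumption):

```python
import torch

def annealed_wta_loss(per_expert_losses, tau):
    # per_expert_losses: (batch, K) per-sample loss of each expert.
    # Soft assignment: lower loss -> higher weight; tau -> 0 recovers hard WTA.
    weights = torch.softmax(-per_expert_losses / tau, dim=1)
    return (weights.detach() * per_expert_losses).sum(dim=1).mean()

def tau_schedule(epoch, total_epochs, tau_max=10.0, tau_min=0.01):
    # Geometric interpolation from tau_max down to tau_min over training
    t = epoch / max(total_epochs - 1, 1)
    return tau_max * (tau_min / tau_max) ** t
```

Early in training (large τ) every expert receives a nearly uniform share of the gradient; by the end (small τ) only the winner is trained, as in hard WTA.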
| Strategy | Description |
|---|---|
| `single_expert` | One fixed expert for the entire ODE trajectory |
| `random_expert` | Uniformly random expert at each ODE step |
| `best_expert` | All K experts are run and the best is selected (selection criterion in `src/sample.py`) |
| `mixture_score` | Average predictions of all K experts at each step |
| `gated` | Learned gating network selects an expert per step and sample |
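For example, `mixture_score` amounts to averaging the K expert outputs at each solver step (a sketch; the function signature is hypothetical):

```python
import torch

def mixture_score(experts, x, sigma):
    # experts: list of K callables; each maps (x, sigma) to a score prediction
    preds = torch.stack([expert(x, sigma) for expert in experts], dim=0)
    return preds.mean(dim=0)
```

Averaging score fields from experts that specialize on different modes can steer the trajectory toward a compromise that belongs to no mode, consistent with the poor `mixture_score` FID reported below.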
Quantitative evaluation on 2,048 generated samples (Annealed WTA, Euler solver, 200 ODE steps):
| Strategy | FID ↓ | Precision ↑ | Recall ↑ |
|---|---|---|---|
| Baseline (Euler) | 59.7 | 0.474 | 0.869 |
| Baseline (Heun) | 58.3 | 0.472 | 0.858 |
| Single Expert | 239.3 | 0.687 | 0.903 |
| Random Expert | 2467.6 | 0.000 | 0.000 |
| Best Expert | 2611.1 | 0.000 | 0.000 |
| Mixture Score | 2342.6 | 0.000 | 0.000 |
| Learned Gating | 171.8 | 0.692 | 0.896 |
Key findings:
- Annealed WTA prevents expert collapse: 3 of 4 experts learn meaningful score functions (usage: 40%/31%/28%/0%), vs. only 1 with hard WTA.
- Baseline achieves best FID (58.3 Heun), as expected for a single model without routing overhead.
- Learned gating is the best MCL strategy (FID 171.8), approaching baseline quality while leveraging multi-expert specialization.
- Single expert achieves highest recall (0.903), showing the dominant expert covers a broad portion of the manifold.
- Per-step heuristic strategies fail catastrophically — switching experts mid-trajectory without coherence produces noise.
| Parameter | Value | Description |
|---|---|---|
| `base_ch` | 32 | Base channel count of the U-Net |
| `ch_mult` | (1, 2) | Channel multipliers per resolution level |
| `time_dim` | 128 | Time embedding dimension |
| `dropout` | 0.05 | Dropout rate in residual blocks |
| `sigma_min` | 0.01 | Minimum noise level |
| `sigma_max` | 80.0 | Maximum noise level |
| `lr` | 3e-4 | Learning rate (Adam) |
| `ema_decay` | 0.999 | EMA decay for inference weights |
| `K` | 4 | Number of MCL experts |
| `mcl_variant` | `annealed_wta` | MCL training variant |
| `anneal_tau_max` | 10.0 | Initial temperature (soft assignment) |
| `anneal_tau_min` | 0.01 | Final temperature (hard WTA) |
| `epochs` | 200 | Training epochs (baseline & MCL) |
| `batch_size` | 256 | Training batch size |
| `num_steps` | 200 | ODE integration steps at inference |
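Of these, `ema_decay` controls the exponential moving average of weights used at inference. A generic sketch (not necessarily the exact `src/utils.py` implementation):

```python
import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    # w_ema <- decay * w_ema + (1 - decay) * w, applied after each optimizer step
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1.0 - decay)
```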
The full academic report is available as `report.pdf` (LaTeX source: `report.tex`). It covers the mathematical formulation, implementation details, expert collapse analysis with mitigation strategies, all quantitative and qualitative results, and a discussion of inter-class vs. intra-class diversity.