# Modular Diffusion Continual Learning Experiment

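This notebook runs the complete diffusion continual-learning experiment pipeline through the modular functions in `src/`. It trains and evaluates three configurations: MNIST with the normal and small models, and CIFAR with the normal model. It then compares their best FID scores, optimal coefficients $c^*$, and Fisher-approximation errors. The final cells exercise the individual setup, data-loading, and model-creation steps for debugging.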

In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import os, math, time, random
from pathlib import Path
import torch
from torch import optim
import matplotlib.pyplot as plt
import numpy as np

# Import the comprehensive experiment runner
from src.run_experiment import run_full_experiment

In [None]:
# MNIST / normal-size model: full pipeline, 200 epochs at batch size 128.
results_mnist = run_full_experiment(dataset='mnist', model_size='normal',
                                    n_epochs=200, batch_size=128)

In [None]:
# MNIST / small model. Per the original notes, this configuration also runs the
# full Fisher analysis. It uses half the epochs of the normal-size run.
results_mnist_small = run_full_experiment(dataset='mnist', model_size='small',
                                          n_epochs=100, batch_size=128)

In [None]:
# CIFAR / normal-size model: full pipeline, 200 epochs at batch size 128.
results_cifar = run_full_experiment(dataset='cifar', model_size='normal',
                                    n_epochs=200, batch_size=128)

In [None]:
# Compare results across experiments.
# In each run's `results` dict, FID entries are the keys ending in "_fid".
# Their values may be None (presumably stages where FID was not computed;
# confirm in src.run_experiment). A run with no FID values is reported as
# "n/a" instead of crashing the cell on `min([])`.
def _best_fid(run):
    """Return the lowest non-None ``*_fid`` value in ``run['results']``, or None if there is none."""
    return min(
        (v for k, v in run['results'].items() if k.endswith('_fid') and v is not None),
        default=None,
    )


def _fmt(value, spec):
    """Format ``value`` with ``spec``, or return "n/a" when it is None."""
    return "n/a" if value is None else format(value, spec)


runs_to_compare = [
    ('MNIST Normal', results_mnist),
    ('MNIST Small', results_mnist_small),
    ('CIFAR Normal', results_cifar),
]

print("EXPERIMENT COMPARISON")
print("=" * 50)
for run_name, run in runs_to_compare:
    print(f"{run_name} - Best FID: {_fmt(_best_fid(run), '.3f')}")

print("\nOptimal coefficients:")
for run_name, run in runs_to_compare:
    print(f"{run_name} c*: {_fmt(run['c_optimal'], '.6f')}")

In [None]:
# Plot Fisher approximation errors per timestep level, one panel per run.
# Errors are multiplied by ERROR_SCALE before plotting. On the log y-axis this
# only shifts the tick values; the shapes of the curves are unchanged.
ERROR_SCALE = 10_000
runs_to_plot = [
    ('MNIST Normal', results_mnist),
    ('MNIST Small', results_mnist_small),
    ('CIFAR Normal', results_cifar),
]
# (key in error_analysis, legend label, marker)
error_series = [
    ('diag_errors', 'Diagonal', 'o'),
    ('rank1_errors', 'Rank-1', 's'),
    ('rank1_optimal_errors', 'Optimal Rank-1', '^'),
]

# Panel count follows the run list; squeeze=False keeps `axes` 2-D even for one run.
fig, axes = plt.subplots(1, len(runs_to_plot), figsize=(5 * len(runs_to_plot), 5), squeeze=False)

for ax, (run_name, run) in zip(axes[0], runs_to_plot):
    error_analysis = run['error_analysis']
    t_levels = error_analysis['t_levels']
    for key, label, marker in error_series:
        ax.plot(t_levels, np.asarray(error_analysis[key]) * ERROR_SCALE, label=label, marker=marker)
    ax.set_yscale('log')
    ax.set_xlabel('Timestep Level')
    # Plotted values are error × 10⁴ (not "in units of 10⁴").
    ax.set_ylabel('Error Norm × 10⁴')
    ax.set_title(f'{run_name}\nFisher Approximation Errors')
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Stage-by-stage entry points, for debugging individual pipeline components.
from src.experiment_runner import (
    create_models,
    load_datasets,
    setup_experiment,
)

# Seeded setup; the result is unpacked as (device, ROOT).
device, ROOT = setup_experiment(seed=123)

Device: cpu


In [None]:
datasets = load_datasets(batch_size=128)
# Continual-learning loader collections. Presumably these are keyed per class/task
# (see the key printout below); confirm in src.experiment_runner.
cl_mnist_train_loaders = datasets['cl_mnist_train']
cl_mnist_test_loaders = datasets['cl_mnist_test']
cl_cifar_train_loaders = datasets['cl_cifar_train']
cl_cifar_test_loaders = datasets['cl_cifar_test']
print("Available test loader keys:", cl_mnist_test_loaders.keys())

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to archive/data\cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:19<00:00, 8.70MB/s] 



Extracting archive/data\cifar-10-python.tar.gz to archive/data
Files already downloaded and verified
Files already downloaded and verified
Building DataLoaders for each class in train dataset...
Building DataLoaders for each class in train dataset...


100%|██████████| 50000/50000 [00:24<00:00, 2057.28it/s]
100%|██████████| 50000/50000 [00:24<00:00, 2057.28it/s]


Building DataLoaders for each class in MNIST test dataset...


100%|██████████| 10000/10000 [00:05<00:00, 1706.17it/s]



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to archive/data\MNIST\raw\train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to archive/data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:01<00:00, 8.04MB/s]



Extracting archive/data\MNIST\raw\train-images-idx3-ubyte.gz to archive/data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to archive/data\MNIST\raw\train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to archive/data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 3.66MB/s]

Extracting archive/data\MNIST\raw\train-labels-idx1-ubyte.gz to archive/data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to archive/data\MNIST\raw\t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to archive/data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 7.81MB/s]



Extracting archive/data\MNIST\raw\t10k-images-idx3-ubyte.gz to archive/data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to archive/data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to archive/data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<?, ?B/s]



Extracting archive/data\MNIST\raw\t10k-labels-idx1-ubyte.gz to archive/data\MNIST\raw

Building DataLoaders for each class in train dataset...


100%|██████████| 60000/60000 [00:31<00:00, 1880.56it/s]
100%|██████████| 60000/60000 [00:31<00:00, 1880.56it/s]


Building DataLoaders for each class in MNIST test dataset...


100%|██████████| 10000/10000 [00:04<00:00, 2057.76it/s]
100%|██████████| 10000/10000 [00:04<00:00, 2057.76it/s]


In [None]:
# Build both models and their optimizers, then report trainable-parameter counts.
mnist_model, cifar_model, mnist_opt, cifar_opt = create_models(device, model_size='normal')
for model_label, net in (("MNIST", mnist_model), ("CIFAR", cifar_model)):
    n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f"{model_label} model parameters: {n_trainable}")