In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# Import from refactored modules
from sample import MNISTSampler, IsotropicGaussian
from scheduler import LinearAlpha, LinearBeta
from path import GaussianConditionalProbabilityPath, CFGVectorFieldODE
from sde import EulerSimulator
from Unet import MNISTUNet
from trainer import CFGTrainer

# Run on the GPU when available; models and tensors below are moved to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Fix the RNG seed so random dataset indices (torch.randperm), noise samples and
# weight initialisation repeat across Restart & Run All. torch.manual_seed also
# seeds every CUDA device; some GPU kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Gaussian conditional probability path over MNIST (32x32): p_data draws real
# images, the simple (noise) distribution has shape (1, 32, 32), and alpha/beta
# use linear schedules. The whole path module is moved to `device`.
path = GaussianConditionalProbabilityPath(
    p_data=MNISTSampler(),
    p_simple_shape=[1, 32, 32],
    alpha=LinearAlpha(),
    beta=LinearBeta(),
).to(device)

In [None]:
# Initialize the class-conditional UNet; these hyperparameters are passed
# straight through to MNISTUNet.
unet = MNISTUNet(
    channels=[32, 64, 128],
    num_residual_layers=2,
    t_embed_dim=40,
    y_embed_dim=40,
)

# Initialize the CFG trainer on the 32x32 path. `eta` is forwarded to CFGTrainer
# (presumably the probability of dropping the class label during training for
# classifier-free guidance — TODO confirm in trainer.py).
trainer = CFGTrainer(path=path, model=unet, eta=0.1)

In [None]:
# Train! All hyperparameters are forwarded to CFGTrainer.train.
trainer.train(
    num_epochs=500,
    device=device,
    lr=1e-3,
    batch_size=50,
)

In [None]:
# Play with these!
samples_per_class = 10
num_timesteps = 100
guidance_scales = [1.0, 3.0, 5.0]

# Graph. squeeze=False keeps `axes` 2-D even when only one guidance scale is
# listed (plt.subplots would otherwise return a bare Axes that cannot be indexed).
fig, axes = plt.subplots(
    1, len(guidance_scales), figsize=(10 * len(guidance_scales), 10), squeeze=False
)

# Inference mode for sampling (matters for any dropout / batch-norm layers).
unet.eval()

# Conditioning labels: the ten MNIST digits 0-9, `samples_per_class` copies of
# each, so every grid row below shows a single digit. Neither the labels nor
# the time grid depend on the guidance scale, so they are built once.
y = torch.arange(10, dtype=torch.int64).repeat_interleave(samples_per_class).to(device)
num_samples = y.shape[0]
ts = torch.linspace(0, 1, num_timesteps).view(1, -1, 1, 1, 1).expand(num_samples, -1, 1, 1, 1).to(device)

for idx, w in enumerate(guidance_scales):
    # Setup ode and simulator
    ode = CFGVectorFieldODE(unet, guidance_scale=w)
    simulator = EulerSimulator(ode)

    # Fresh initial noise for each guidance scale
    x0, _ = path.p_simple.sample(num_samples)  # (num_samples, 1, 32, 32)

    # Simulate from t=0 to t=1. No gradients are needed, and a grad-tracking
    # result could not be converted to numpy by imshow below.
    with torch.no_grad():
        x1 = simulator.simulate(x0, ts, y=y)

    # Plot
    ax = axes[0, idx]
    grid = make_grid(x1, nrow=samples_per_class, normalize=True, value_range=(-1, 1))
    ax.imshow(grid.permute(1, 2, 0).cpu(), cmap="gray")
    ax.axis("off")
    ax.set_title(f"Guidance: $w={w:.1f}$", fontsize=25)
plt.show()

In [None]:
# Visualize samples from conditional probability path
num_rows = 3
num_cols = 3
num_timesteps_vis = 5

# Sample data points z ~ p_data and view them as a (num_samples, 1, 32, 32) batch
num_samples = num_rows * num_cols
z, _ = path.p_data.sample(num_samples)
z = z.view(-1, 1, 32, 32)

# Setup plot. squeeze=False keeps `axes` indexable even if num_timesteps_vis == 1.
fig, axes = plt.subplots(
    1, num_timesteps_vis,
    figsize=(6 * num_cols * num_timesteps_vis, 6 * num_rows),
    squeeze=False,
)

# Sample from conditional probability paths and graph, one panel per time t
ts = torch.linspace(0, 1, num_timesteps_vis).to(device)
for tidx, t in enumerate(ts):
    tt = t.view(1, 1, 1, 1).expand(num_samples, 1, 1, 1)  # (num_samples, 1, 1, 1)
    xt = path.sample_conditional_path(z, tt)  # (num_samples, 1, 32, 32)
    grid = make_grid(xt, nrow=num_cols, normalize=True, value_range=(-1, 1))
    ax = axes[0, tidx]
    ax.imshow(grid.permute(1, 2, 0).cpu(), cmap="gray")
    ax.axis("off")
    # Label each panel with its time so the figure reads on its own
    ax.set_title(f"$t={t.item():.2f}$", fontsize=25)
plt.show()

# Using Diffusers UNet2DModel

Below we repeat the experiment with the `diffusers` library's `UNet2DModel`, wrapped so that it plugs into the same `CFGTrainer` and `CFGVectorFieldODE` used above.

In [None]:
# Check for diffusers *before* importing the wrappers: if `Unet` needs diffusers
# at import time (TODO confirm), importing the wrappers first would raise a bare
# ImportError and the install hint below would never be printed.
try:
    from diffusers import UNet2DModel
    DIFFUSERS_AVAILABLE = True
    print("✓ diffusers library is available")
except ImportError:
    DIFFUSERS_AVAILABLE = False
    print("✗ diffusers not installed. Install with: pip install diffusers")

# Import the diffusers-backed UNet wrappers
from Unet import DiffusersUNet2DWrapperLite, DiffusersUNet2DWrapper

## Option 1: Lightweight Diffusers UNet (Pre-configured for MNIST 28x28)

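This uses a lightweight `UNet2DModel` configuration sized for 28x28 MNIST images, so we first build a probability path over 28x28 data instead of the 32x32 path used above.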

In [None]:
# Create a sampler for native 28x28 MNIST images: the Lite diffusers wrapper is set up for 28x28 inputs, whereas the path above uses 32x32.
from torchvision import datasets, transforms

class MNISTSampler28(torch.nn.Module):
    """Draw random (image, label) pairs from the 28x28 MNIST training set.

    Images go through ``ToTensor`` and ``Normalize((0.5,), (0.5,))``, so pixel
    values lie in [-1, 1]. Subclassing ``nn.Module`` lets ``.to(device)`` move
    the ``dummy`` buffer; returned tensors are placed on that buffer's device.
    """

    def __init__(self):
        super().__init__()
        self.dataset = datasets.MNIST(
            root='./data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ])
        )
        # Only used to track which device the module was moved to.
        # register_buffer works on every PyTorch version, whereas torch.nn.Buffer
        # exists only in recent releases; both register a buffer named "dummy".
        self.register_buffer("dummy", torch.zeros(1))

    def sample(self, num_samples: int):
        """Draw ``num_samples`` distinct training examples (without replacement).

        Args:
            num_samples: number of examples to draw, in ``[1, len(self.dataset)]``.

        Returns:
            ``(samples, labels)``: the stacked transformed images and an int64
            label tensor of shape ``(num_samples,)``, both on the buffer's device.

        Raises:
            ValueError: if ``num_samples`` is outside ``[1, len(self.dataset)]``.
        """
        dataset_size = len(self.dataset)
        # Also reject 0 and negatives: a negative count would silently slice
        # randperm(...)[:-k] and return almost the whole dataset.
        if not 1 <= num_samples <= dataset_size:
            raise ValueError(
                f"num_samples must be in [1, {dataset_size}], got {num_samples}"
            )
        # Plain Python ints for dataset indexing
        indices = torch.randperm(dataset_size)[:num_samples].tolist()
        samples, labels = zip(*(self.dataset[i] for i in indices))
        device = self.dummy.device
        samples = torch.stack(samples).to(device)
        labels = torch.tensor(labels, dtype=torch.int64, device=device)
        return samples, labels

# Probability path over native 28x28 MNIST. It matches `path` above except for
# the data sampler and the (1, 28, 28) shape of the simple distribution.
path_diffusers = GaussianConditionalProbabilityPath(
    p_data=MNISTSampler28(),
    p_simple_shape=[1, 28, 28],
    alpha=LinearAlpha(),
    beta=LinearBeta(),
).to(device)

print("✓ Created path for 28x28 images")

In [None]:
# Lightweight diffusers UNet. num_class_embeds=11: one embedding per MNIST digit
# (0-9) plus, presumably, an extra "null" label used for the unconditional branch
# of classifier-free guidance (TODO confirm against CFGTrainer / CFGVectorFieldODE).
unet_diffusers_lite = DiffusersUNet2DWrapperLite(num_class_embeds=11)

# Same trainer class and eta as for the original MNISTUNet, but on the 28x28 path
trainer_diffusers_lite = CFGTrainer(
    path=path_diffusers,
    model=unet_diffusers_lite,
    eta=0.1,
)

print("✓ Created lightweight diffusers UNet model")

In [None]:
# Train the diffusers UNet with the same hyperparameters as the original MNISTUNet.
trainer_diffusers_lite.train(
    num_epochs=500,
    device=device,
    lr=1e-3,
    batch_size=50,
)

In [None]:
# Sample and visualize results from diffusers UNet
samples_per_class = 10
num_timesteps = 100
guidance_scales = [1.0, 3.0, 5.0]

# Graph. squeeze=False keeps `axes` 2-D even when only one guidance scale is
# listed (plt.subplots would otherwise return a bare Axes that cannot be indexed).
fig, axes = plt.subplots(
    1, len(guidance_scales), figsize=(10 * len(guidance_scales), 10), squeeze=False
)

# Inference mode for sampling (matters for any dropout / normalization layers).
unet_diffusers_lite.eval()

# Conditioning labels: digits 0-9, `samples_per_class` copies each, so every grid
# row shows one digit. Labels and time grid are independent of w; build them once.
y = torch.arange(10, dtype=torch.int64).repeat_interleave(samples_per_class).to(device)
num_samples = y.shape[0]
ts = torch.linspace(0, 1, num_timesteps).view(1, -1, 1, 1, 1).expand(num_samples, -1, 1, 1, 1).to(device)

for idx, w in enumerate(guidance_scales):
    # Setup ode and simulator
    ode = CFGVectorFieldODE(unet_diffusers_lite, guidance_scale=w)
    simulator = EulerSimulator(ode)

    # Fresh initial noise for each guidance scale
    x0, _ = path_diffusers.p_simple.sample(num_samples)  # (num_samples, 1, 28, 28)

    # Simulate from t=0 to t=1. No gradients are needed, and a grad-tracking
    # result could not be converted to numpy by imshow below.
    with torch.no_grad():
        x1 = simulator.simulate(x0, ts, y=y)

    # Plot
    ax = axes[0, idx]
    grid = make_grid(x1, nrow=samples_per_class, normalize=True, value_range=(-1, 1))
    ax.imshow(grid.permute(1, 2, 0).cpu(), cmap="gray")
    ax.axis("off")
    ax.set_title(f"Diffusers UNet - Guidance: $w={w:.1f}$", fontsize=25)
plt.show()

In [None]:
# Compare model sizes and trainable-parameter counts of the two UNets
from trainer import model_size_b

MiB = 1024 ** 2


def count_parameters(model):
    """Return the total number of elements in `model`'s trainable (requires_grad) parameters."""
    return sum(param.numel() for param in model.parameters() if param.requires_grad)


# Row labels are pre-padded so both report columns line up.
models_to_compare = [
    ("Original MNISTUNet:           ", unet),
    ("Diffusers UNet (Lite):        ", unet_diffusers_lite),
]
separator = "-" * 50

print("Model Size Comparison:")
print(separator)
for row_label, row_model in models_to_compare:
    print(f"{row_label}{model_size_b(row_model) / MiB:.2f} MiB")
print(separator)

print("\nParameter Count:")
print(separator)
for row_label, row_model in models_to_compare:
    print(f"{row_label}{count_parameters(row_model):,}")
print(separator)