# DEIS Sampling Method Testing

This notebook tests the **DEIS (Diffusion Exponential Integrator Sampler)** method using `DEISMultistepScheduler` from the diffusers library.

DEIS is a fast, high-order solver for diffusion ODEs. It aims to produce high-quality samples with far fewer network evaluations than step-by-step ancestral sampling over all training timesteps.
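To give a feel for the "exponential integrator" part of the name, here is a toy sketch. It is not the DEIS algorithm itself: the stiff linear ODE `dx/dt = -lam * x + g` and the helpers `euler_step`, `exp_euler_step`, and `integrate` are illustrative assumptions. The idea is that the linear part is integrated exactly and only the remaining term is approximated.

```python
import math


def euler_step(x, h, lam, g):
    """One explicit Euler step for dx/dt = -lam * x + g."""
    return x + h * (-lam * x + g)


def exp_euler_step(x, h, lam, g):
    """One exponential Euler step: the linear part is integrated exactly
    and g is held constant over the step."""
    decay = math.exp(-lam * h)
    return decay * x + (1.0 - decay) / lam * g


def integrate(step, x0, lam, g, t_end, n_steps):
    """Apply `step` n_steps times with a uniform step size."""
    h = t_end / n_steps
    x = x0
    for _ in range(n_steps):
        x = step(x, h, lam, g)
    return x


# Toy problem (all values are illustrative assumptions)
lam, g, x0, t_end, n_steps = 50.0, 1.0, 1.0, 1.0, 10
exact = math.exp(-lam * t_end) * x0 + (1.0 - math.exp(-lam * t_end)) / lam * g

euler_result = integrate(euler_step, x0, lam, g, t_end, n_steps)
exp_result = integrate(exp_euler_step, x0, lam, g, t_end, n_steps)

print(f"exact:             {exact:.6f}")
print(f"explicit Euler:    {euler_result:.6f}")
print(f"exponential Euler: {exp_result:.6f}")
```

With only 10 steps, explicit Euler diverges on this problem, while the exponential variant reproduces the exact solution because `g` is constant. In DEIS the non-linear term is the network's prediction, which is extrapolated from previous steps rather than held constant.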

## Hyperparameters tested:
- **Number of inference steps** (5, 10, 20, 50, 100)
- **Solver order** (1, 2, 3)
- **Beta schedule** (linear, scaled_linear, squaredcos_cap_v2)
- **Solver type** (logrho, midpoint, heun, bh1, bh2)
- **Lower order final** (True, False)

In [5]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
import time

from ema_pytorch import EMA
from torchvision.utils import make_grid, save_image
from diffusers import DEISMultistepScheduler

from diffusion.ddpm import Unet, GaussianDiffusion

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cuda


## Load Pretrained Model

In [6]:
# Model configuration (should match training config)
IMAGE_SIZE = 28
CHANNELS = 3
TIMESTEPS = 1000

# Initialize model architecture
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4),
    flash_attn=False,
    channels=CHANNELS
)

# Create diffusion wrapper (needed for loading checkpoint)
diffusion = GaussianDiffusion(
    model,
    image_size=IMAGE_SIZE,
    timesteps=TIMESTEPS,
    sampling_timesteps=250
)

In [7]:
# Load checkpoint
CKPT_PATH = "./ckpts/model-5.pt"  # Update this path to your checkpoint

ckpt = torch.load(CKPT_PATH, map_location=device, weights_only=False)
diffusion.load_state_dict(ckpt["model"])

# Setup EMA model
ema = EMA(diffusion, beta=0.995, update_every=10).to(device)
ema.load_state_dict(ckpt["ema"])

# Get the EMA model for inference
ema_model = ema.ema_model
ema_model.eval()

# Extract the UNet from the diffusion model
unet = ema_model.model
unet.eval()

print(f"Model loaded from: {CKPT_PATH}")

    Found GPU0 NVIDIA GeForce GTX 1080 Ti which is of cuda capability 6.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (7.0) - (12.0)
    
    Please install PyTorch with a following CUDA
    configurations:  12.6 following instructions at
    https://pytorch.org/get-started/locally/
    
    Found GPU1 NVIDIA GeForce GTX 1080 Ti which is of cuda capability 6.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (7.0) - (12.0)
    
    Found GPU2 NVIDIA GeForce GTX 1080 Ti which is of cuda capability 6.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (7.0) - (12.0)
    
NVIDIA GeForce GTX 1080 Ti with CUDA capability sm_61 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_70 sm_75 sm_80 sm_86 sm_90 sm_100 sm_120.
If you want to use the NVIDIA GeForce GTX 1080 Ti GPU with PyTorch, please check the instructions at 

AcceleratorError: CUDA error: out of memory
Search for `cudaErrorMemoryAllocation' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
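The warnings above say that the installed PyTorch build has no kernels for `sm_61`. A guard along the following lines could fall back to the CPU in that situation. This is a sketch under assumptions: `parse_arch`, `capability_supported`, and `pick_device` are hypothetical helpers, and the compatibility rule (same major version, minor no higher than the GPU's) is a simplification that ignores PTX forward compatibility. In the notebook, the three inputs would come from `torch.cuda.is_available()`, `torch.cuda.get_device_capability()`, and `torch.cuda.get_arch_list()`. Whether the out-of-memory error is related to the capability mismatch is not clear from the log.

```python
import re


def parse_arch(arch):
    """Parse a string such as 'sm_86' into (major, minor).
    Returns None for entries that are not in 'sm_XY' form."""
    match = re.fullmatch(r"sm_(\d+)(\d)[a-z]?", arch)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def capability_supported(capability, arch_list):
    """Heuristic (assumption): usable if the build ships binaries for the same
    major version with a minor version no higher than the GPU's."""
    major, minor = capability
    for arch in arch_list:
        parsed = parse_arch(arch)
        if parsed is None:
            continue
        if parsed[0] == major and parsed[1] <= minor:
            return True
    return False


def pick_device(cuda_available, capability, arch_list):
    """Return 'cuda' only when CUDA is available and the GPU looks supported."""
    if cuda_available and capability_supported(capability, arch_list):
        return "cuda"
    return "cpu"


# Architecture list copied from the warning above
arch_list = ["sm_70", "sm_75", "sm_80", "sm_86", "sm_90", "sm_100", "sm_120"]
print(pick_device(True, (6, 1), arch_list))  # GTX 1080 Ti (sm_61)
print(pick_device(True, (8, 6), arch_list))  # an sm_86 GPU, for contrast
```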


## DEIS Sampling Implementation

We create a sampling function that uses the `DEISMultistepScheduler` with our trained UNet.

In [None]:
def create_deis_scheduler(
    num_train_timesteps: int = 1000,
    beta_schedule: str = "scaled_linear",
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    solver_order: int = 2,
    algorithm_type: str = "deis",
    solver_type: str = "logrho",
    lower_order_final: bool = True,
):
    """
    Create a DEIS scheduler with specified parameters.
    
    Args:
        num_train_timesteps: Number of training timesteps
        beta_schedule: Type of beta schedule ('linear', 'scaled_linear', 'squaredcos_cap_v2')
        beta_start: Starting beta value
        beta_end: Ending beta value
        solver_order: Order of the DEIS solver (1, 2, or 3)
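        algorithm_type: Algorithm type. 'deis' is the native option; to our knowledge
            diffusers remaps 'dpmsolver' and 'dpmsolver++' to 'deis' for config compatibility
        solver_type: Solver type. 'logrho' is the native option; to our knowledge diffusers
            remaps 'midpoint', 'heun', 'bh1', and 'bh2' to 'logrho' for config compatibility
        lower_order_final: Use a lower-order solver for the final steps
            (to our knowledge only applied for fewer than 15 inference steps)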
    """
    scheduler = DEISMultistepScheduler(
        num_train_timesteps=num_train_timesteps,
        beta_schedule=beta_schedule,
        beta_start=beta_start,
        beta_end=beta_end,
        solver_order=solver_order,
        algorithm_type=algorithm_type,
        solver_type=solver_type,
        lower_order_final=lower_order_final,
        prediction_type="epsilon",  # Our model predicts noise
    )
    return scheduler

In [None]:
@torch.inference_mode()
def sample_with_deis(
    unet,
    scheduler,
    num_inference_steps: int = 20,
    batch_size: int = 16,
    image_size: int = 28,
    channels: int = 3,
    device: str = 'cuda',
    return_intermediates: bool = False,
):
    """
    Sample images using DEIS scheduler.
    
    Args:
        unet: The denoising UNet model
        scheduler: DEIS scheduler instance
        num_inference_steps: Number of denoising steps
        batch_size: Number of images to generate
        image_size: Size of generated images
        channels: Number of image channels
        device: Device to use
        return_intermediates: Whether to return intermediate samples
    
    Returns:
        samples: Generated images (batch_size, channels, image_size, image_size)
        intermediates: List of intermediate samples (if return_intermediates=True)
        elapsed_time: Time taken for sampling
    """
    # Set number of inference steps
    scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = scheduler.timesteps
    
    # Initialize with random noise
    latents = torch.randn(
        (batch_size, channels, image_size, image_size),
        device=device,
        dtype=torch.float32
    )
    
    # Scale initial noise by the scheduler's init_noise_sigma
    latents = latents * scheduler.init_noise_sigma
    
    intermediates = [latents.clone()] if return_intermediates else None
    
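    # CUDA kernels run asynchronously, so synchronize before starting the clock
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    
    # Denoising loop
    for t in tqdm(timesteps, desc="DEIS Sampling", leave=False):
        # Create batched timesteps
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        
        # Predict noise
        noise_pred = unet(latents, t_batch)
        
        # DEIS step
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        
        if return_intermediates:
            intermediates.append(latents.clone())
    
    # Wait for all queued GPU work to finish before stopping the clock
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed_time = time.perf_counter() - start_time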
    
    # Unnormalize from [-1, 1] to [0, 1]
    samples = (latents + 1) * 0.5
    samples = samples.clamp(0, 1)
    
    if return_intermediates:
        intermediates = [(x + 1) * 0.5 for x in intermediates]
        return samples, intermediates, elapsed_time
    
    return samples, elapsed_time
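The timesteps visited by the denoising loop come from `scheduler.set_timesteps`. As a rough mental model, the sketch below reproduces what we understand the "linspace" timestep spacing to do: an evenly spaced grid over the training timesteps, highest first, with the final `t = 0` endpoint dropped. `linspace_timesteps` is a hypothetical helper written for illustration; `scheduler.timesteps` remains the authoritative source.

```python
def linspace_timesteps(num_train_timesteps, num_inference_steps):
    """Evenly spaced training timesteps in descending order.

    The grid has num_inference_steps + 1 points from 0 to num_train_timesteps - 1;
    the t = 0 endpoint is dropped so exactly num_inference_steps values remain.
    """
    last = num_train_timesteps - 1
    grid = [round(i * last / num_inference_steps) for i in range(num_inference_steps + 1)]
    return grid[::-1][:-1]


for n in (5, 10):
    print(n, linspace_timesteps(1000, n))
```

With 5 steps the loop jumps roughly 200 training timesteps at a time, which is why low step counts rely on a high-order solver to stay accurate.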

## Helper Functions for Visualization

In [None]:
def show_samples(samples, title="Generated Samples", nrow=4, figsize=(10, 10)):
    """Display a grid of generated samples."""
    grid = make_grid(samples, nrow=nrow, padding=2, normalize=False)
    grid = grid.cpu().permute(1, 2, 0).numpy()
    
    plt.figure(figsize=figsize)
    plt.imshow(grid)
    plt.title(title)
    plt.axis('off')
    plt.tight_layout()
    plt.show()


def show_denoising_process(intermediates, num_steps_to_show=8, nrow=8, figsize=(16, 4)):
    """Visualize the denoising process."""
    n_intermediates = len(intermediates)
    indices = np.linspace(0, n_intermediates - 1, num_steps_to_show, dtype=int)
    
    selected = [intermediates[i][0:1] for i in indices]  # Take first sample
    selected = torch.cat(selected, dim=0)
    
    grid = make_grid(selected, nrow=nrow, padding=2, normalize=False)
    grid = grid.cpu().permute(1, 2, 0).numpy().clip(0, 1)
    
    plt.figure(figsize=figsize)
    plt.imshow(grid)
    plt.title("Denoising Process (Noise -> Image)")
    plt.axis('off')
    plt.tight_layout()
    plt.show()


def compare_results(results_dict, figsize=(16, 12)):
    """Compare samples from different configurations."""
    n_configs = len(results_dict)
    fig, axes = plt.subplots(n_configs, 1, figsize=figsize)
    
    if n_configs == 1:
        axes = [axes]
    
    for ax, (name, data) in zip(axes, results_dict.items()):
        samples = data['samples'][:8]  # Show first 8 samples
        grid = make_grid(samples, nrow=8, padding=2, normalize=False)
        grid = grid.cpu().permute(1, 2, 0).numpy()
        
        ax.imshow(grid)
        ax.set_title(f"{name} | Time: {data['time']:.2f}s")
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()

---
## Experiment 1: Varying Number of Inference Steps

Test how sample quality changes with different numbers of denoising steps.

In [None]:
# Test different numbers of inference steps
step_counts = [5, 10, 20, 50, 100]

results_steps = {}

for num_steps in step_counts:
    print(f"\nSampling with {num_steps} steps...")
    
    scheduler = create_deis_scheduler(
        num_train_timesteps=TIMESTEPS,
        solver_order=2,
        beta_schedule="scaled_linear",
    )
    
    samples, elapsed = sample_with_deis(
        unet=unet,
        scheduler=scheduler,
        num_inference_steps=num_steps,
        batch_size=16,
        image_size=IMAGE_SIZE,
        channels=CHANNELS,
        device=device
    )
    
    results_steps[f"{num_steps} steps"] = {
        'samples': samples,
        'time': elapsed
    }
    print(f"  Time: {elapsed:.2f}s")

In [None]:
# Visualize results
compare_results(results_steps)

---
## Experiment 2: Varying Solver Order

DEIS supports solver orders 1, 2, and 3. Higher orders can be more accurate but may be less stable.
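The accuracy side of this trade-off can be illustrated with a toy multistep integrator. The sketch below is an analogy, not DEIS: it applies Adams-Bashforth rules of order 1 to 3 to `dx/dt = -x`, and `adams_bashforth` is a hypothetical helper. To isolate the effect of the order, the derivative history is seeded with exact values, whereas a real sampler has to warm up with lower-order steps.

```python
import math

# Adams-Bashforth weights, newest derivative first
AB_COEFFS = {
    1: (1.0,),
    2: (3 / 2, -1 / 2),
    3: (23 / 12, -16 / 12, 5 / 12),
}


def adams_bashforth(order, n_steps, t_end=1.0):
    """Integrate dx/dt = -x from x(0) = 1 to t_end with a fixed step size."""
    h = t_end / n_steps
    coeffs = AB_COEFFS[order]
    # Exact derivatives at t = 0, -h, -2h, ... (newest first); f(t) = -exp(-t)
    history = [-math.exp(j * h) for j in range(order)]
    x = 1.0
    for _ in range(n_steps):
        x = x + h * sum(c * d for c, d in zip(coeffs, history))
        # Push the newest derivative f(x) = -x and drop the oldest one
        history = [-x] + history[:-1]
    return x


exact = math.exp(-1.0)
errors = {order: abs(adams_bashforth(order, 20) - exact) for order in (1, 2, 3)}
for order, err in errors.items():
    print(f"order {order}: abs error = {err:.2e}")
```

For the same 20 steps, the error shrinks as the order increases. On stiffer problems or with noisy derivative estimates, higher-order extrapolation can amplify errors instead, which is the stability concern mentioned above.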

In [None]:
# Test different solver orders
solver_orders = [1, 2, 3]
NUM_STEPS = 20

results_orders = {}

for order in solver_orders:
    print(f"\nSampling with solver order {order}...")
    
    scheduler = create_deis_scheduler(
        num_train_timesteps=TIMESTEPS,
        solver_order=order,
        beta_schedule="scaled_linear",
    )
    
    samples, elapsed = sample_with_deis(
        unet=unet,
        scheduler=scheduler,
        num_inference_steps=NUM_STEPS,
        batch_size=16,
        image_size=IMAGE_SIZE,
        channels=CHANNELS,
        device=device
    )
    
    results_orders[f"Order {order}"] = {
        'samples': samples,
        'time': elapsed
    }
    print(f"  Time: {elapsed:.2f}s")

In [None]:
# Visualize results
compare_results(results_orders)

---
## Experiment 3: Varying Beta Schedule

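Test different noise schedules: linear, scaled_linear, and squaredcos_cap_v2. The UNet was trained against one specific schedule, so only the matching schedule is expected to give clean samples; the other runs show how sensitive sampling is to a mismatch.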
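To see how the three schedules differ, the sketch below recomputes the betas in plain Python using the standard definitions: linear interpolation, linear interpolation in square-root space, and the capped cosine schedule. The helper names are hypothetical, and the defaults simply mirror the `beta_start` and `beta_end` values used in this notebook.

```python
import math


def linear_betas(n, beta_start=1e-4, beta_end=0.02):
    """Betas interpolated linearly between beta_start and beta_end."""
    return [beta_start + (beta_end - beta_start) * i / (n - 1) for i in range(n)]


def scaled_linear_betas(n, beta_start=1e-4, beta_end=0.02):
    """Linear interpolation of sqrt(beta), then squared."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    return [(lo + (hi - lo) * i / (n - 1)) ** 2 for i in range(n)]


def cosine_betas(n, max_beta=0.999):
    """Cosine schedule: betas derived from a squared-cosine alpha_bar curve."""
    def alpha_bar(u):
        return math.cos((u + 0.008) / 1.008 * math.pi / 2) ** 2
    return [min(1 - alpha_bar((i + 1) / n) / alpha_bar(i / n), max_beta) for i in range(n)]


def final_alpha_bar(betas):
    """Cumulative product of (1 - beta): the signal fraction left at the last step."""
    prod = 1.0
    for b in betas:
        prod *= 1.0 - b
    return prod


linear = linear_betas(1000)
scaled = scaled_linear_betas(1000)
cosine = cosine_betas(1000)

for name, betas in [("linear", linear), ("scaled_linear", scaled), ("squaredcos_cap_v2", cosine)]:
    print(f"{name:18s} beta[0]={betas[0]:.6f}  beta[-1]={betas[-1]:.6f}  "
          f"final alpha_bar={final_alpha_bar(betas):.3e}")
```

The linear and scaled_linear schedules share their endpoints but differ in between, so the noise level assigned to a given timestep is not the same. That difference is what the UNet sees as a mismatch.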

In [None]:
# Test different beta schedules
beta_schedules = ["linear", "scaled_linear", "squaredcos_cap_v2"]
NUM_STEPS = 20

results_beta = {}

for beta_sched in beta_schedules:
    print(f"\nSampling with beta_schedule='{beta_sched}'...")
    
    scheduler = create_deis_scheduler(
        num_train_timesteps=TIMESTEPS,
        solver_order=2,
        beta_schedule=beta_sched,
    )
    
    samples, elapsed = sample_with_deis(
        unet=unet,
        scheduler=scheduler,
        num_inference_steps=NUM_STEPS,
        batch_size=16,
        image_size=IMAGE_SIZE,
        channels=CHANNELS,
        device=device
    )
    
    results_beta[beta_sched] = {
        'samples': samples,
        'time': elapsed
    }
    print(f"  Time: {elapsed:.2f}s")

In [None]:
# Visualize results
compare_results(results_beta)

---
## Experiment 4: Varying Solver Type

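The scheduler accepts several `solver_type` names. As far as we can tell from the diffusers implementation, only `logrho` is actually implemented for DEIS; `midpoint`, `heun`, `bh1`, and `bh2` are accepted for config compatibility and remapped to `logrho`. If so, the runs below should differ only through their random initial noise.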

In [None]:
# Test different solver types
solver_types = ["logrho", "midpoint", "heun", "bh1", "bh2"]
NUM_STEPS = 20

results_solver_type = {}

for solver_type in solver_types:
    print(f"\nSampling with solver_type='{solver_type}'...")
    
    try:
        scheduler = create_deis_scheduler(
            num_train_timesteps=TIMESTEPS,
            solver_order=2,
            beta_schedule="scaled_linear",
            solver_type=solver_type,
        )
        
        samples, elapsed = sample_with_deis(
            unet=unet,
            scheduler=scheduler,
            num_inference_steps=NUM_STEPS,
            batch_size=16,
            image_size=IMAGE_SIZE,
            channels=CHANNELS,
            device=device
        )
        
        results_solver_type[solver_type] = {
            'samples': samples,
            'time': elapsed
        }
        print(f"  Time: {elapsed:.2f}s")
    except Exception as e:
        print(f"  Error: {e}")

In [None]:
# Visualize results
if results_solver_type:
    compare_results(results_solver_type)

---
## Experiment 5: Lower Order Final Steps

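Test the effect of switching to a lower-order solver for the final denoising steps. As far as we can tell from the diffusers implementation, `lower_order_final` only takes effect when there are fewer than 15 inference steps, so the comparison is only meaningful at low step counts.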

In [None]:
# Test lower_order_final setting
lower_order_options = [True, False]
NUM_STEPS = 10  # lower_order_final appears to apply only below 15 inference steps

results_lower_order = {}

for lower_order in lower_order_options:
    print(f"\nSampling with lower_order_final={lower_order}...")
    
    scheduler = create_deis_scheduler(
        num_train_timesteps=TIMESTEPS,
        solver_order=3,
        beta_schedule="scaled_linear",
        lower_order_final=lower_order,
    )
    
    samples, elapsed = sample_with_deis(
        unet=unet,
        scheduler=scheduler,
        num_inference_steps=NUM_STEPS,
        batch_size=16,
        image_size=IMAGE_SIZE,
        channels=CHANNELS,
        device=device
    )
    
    results_lower_order[f"lower_order_final={lower_order}"] = {
        'samples': samples,
        'time': elapsed
    }
    print(f"  Time: {elapsed:.2f}s")

In [None]:
# Visualize results
compare_results(results_lower_order)
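The per-step solver order behind this experiment can be sketched as follows. This is our reading of the multistep logic in diffusers, not a copy of it: `order_schedule` is a hypothetical helper, and both the warm-up rule and the 15-step threshold are assumptions that should be checked against the installed version.

```python
def order_schedule(solver_order, num_steps, lower_order_final):
    """Solver order assumed to be used at each step of a multistep sampler.

    Assumptions: the first steps warm up with lower orders because no history
    exists yet, and the last steps drop to lower orders only when
    lower_order_final is set and num_steps < 15.
    """
    reduce_at_end = lower_order_final and num_steps < 15
    orders = []
    for i in range(num_steps):
        is_last = i == num_steps - 1
        is_second_last = i == num_steps - 2
        if solver_order == 1 or i < 1 or (reduce_at_end and is_last):
            orders.append(1)
        elif solver_order == 2 or i < 2 or (reduce_at_end and is_second_last):
            orders.append(2)
        else:
            orders.append(3)
    return orders


print(order_schedule(3, 10, True))
print(order_schedule(3, 10, False))
print(order_schedule(3, 20, True) == order_schedule(3, 20, False))
```

Under these assumptions the flag changes only the last two steps of a short run and has no effect at 20 steps.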

---
## Experiment 6: Visualize Denoising Process

In [None]:
# Generate samples with intermediate steps
scheduler = create_deis_scheduler(
    num_train_timesteps=TIMESTEPS,
    solver_order=2,
    beta_schedule="scaled_linear",
)

samples, intermediates, elapsed = sample_with_deis(
    unet=unet,
    scheduler=scheduler,
    num_inference_steps=20,
    batch_size=4,
    image_size=IMAGE_SIZE,
    channels=CHANNELS,
    device=device,
    return_intermediates=True
)

print(f"Sampling time: {elapsed:.2f}s")

In [None]:
# Visualize denoising process
show_denoising_process(intermediates, num_steps_to_show=10)

---
## Experiment 7: Comparison with Baseline DDIM

In [None]:
# Compare DEIS with baseline DDIM sampling from the original model
NUM_STEPS = 20

# DEIS sampling
print("Sampling with DEIS...")
scheduler_deis = create_deis_scheduler(
    num_train_timesteps=TIMESTEPS,
    solver_order=2,
    beta_schedule="scaled_linear",
)

deis_samples, deis_time = sample_with_deis(
    unet=unet,
    scheduler=scheduler_deis,
    num_inference_steps=NUM_STEPS,
    batch_size=16,
    image_size=IMAGE_SIZE,
    channels=CHANNELS,
    device=device
)
print(f"  DEIS time: {deis_time:.2f}s")

# Original DDIM sampling (from diffusion model)
print("Sampling with original DDIM...")
ema_model.sampling_timesteps = NUM_STEPS
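if torch.cuda.is_available():
    torch.cuda.synchronize()  # make sure earlier GPU work is not counted
start_time = time.perf_counter()
with torch.no_grad():
    ddim_samples = ema_model.sample(batch_size=16)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for sampling kernels to finish
ddim_time = time.perf_counter() - start_time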
print(f"  DDIM time: {ddim_time:.2f}s")
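For reference, the deterministic DDIM update that serves as the baseline here can be written in a few lines. The sketch below uses scalars and made-up values, and `ddim_step` is a hypothetical helper; it assumes an epsilon-prediction model and `eta = 0`. DDIM can be viewed as the lowest-order member of the exponential-integrator family that DEIS extends.

```python
import math


def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """Deterministic DDIM update (eta = 0) for an epsilon-prediction model."""
    # Estimate the clean sample implied by the predicted noise
    x0_pred = (x_t - math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)
    # Re-noise that estimate to the previous (less noisy) level
    return math.sqrt(alpha_bar_prev) * x0_pred + math.sqrt(1.0 - alpha_bar_prev) * eps_pred


# Scalar consistency check with illustrative values
x0, eps = 0.3, -1.2
alpha_bar_t, alpha_bar_prev = 0.2, 0.6

x_t = math.sqrt(alpha_bar_t) * x0 + math.sqrt(1.0 - alpha_bar_t) * eps
x_prev = ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev)
expected = math.sqrt(alpha_bar_prev) * x0 + math.sqrt(1.0 - alpha_bar_prev) * eps

print(f"x_prev = {x_prev:.6f}, expected = {expected:.6f}")
```

When the predicted noise equals the true noise, a single step lands exactly on the less noisy sample. Higher-order solvers such as DEIS try to stay accurate when the prediction changes between steps.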

In [None]:
# Visualize comparison
comparison_results = {
    f"DEIS ({NUM_STEPS} steps)": {'samples': deis_samples, 'time': deis_time},
    f"DDIM ({NUM_STEPS} steps)": {'samples': ddim_samples, 'time': ddim_time},
}

compare_results(comparison_results)

---
## Experiment 8: Speed vs Quality Analysis

In [None]:
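# Timing benchmark across solver orders and step counts
# (sample quality is not measured here; compare the grids from the experiments above)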
step_counts_benchmark = [5, 10, 15, 20, 30, 50]
orders_benchmark = [1, 2, 3]

benchmark_results = []

for order in orders_benchmark:
    for num_steps in step_counts_benchmark:
        scheduler = create_deis_scheduler(
            num_train_timesteps=TIMESTEPS,
            solver_order=order,
            beta_schedule="scaled_linear",
        )
        
        # Run multiple times for accurate timing
        times = []
        for _ in range(3):
            _, elapsed = sample_with_deis(
                unet=unet,
                scheduler=scheduler,
                num_inference_steps=num_steps,
                batch_size=16,
                image_size=IMAGE_SIZE,
                channels=CHANNELS,
                device=device
            )
            times.append(elapsed)
        
        avg_time = np.mean(times)
        benchmark_results.append({
            'order': order,
            'steps': num_steps,
            'time': avg_time
        })
        print(f"Order {order}, {num_steps} steps: {avg_time:.3f}s")
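Timing numbers are easier to trust when warm-up runs are excluded and the clock is only read after pending GPU work has finished. A small helper along these lines could wrap the benchmark loop. It is a sketch: `time_call` is a hypothetical name, and the `sync` argument is where something like `torch.cuda.synchronize` would be passed in the notebook.

```python
import time
from statistics import mean


def time_call(fn, repeats=3, warmup=1, sync=None):
    """Mean wall-clock time of fn() over `repeats` runs after `warmup` untimed runs.

    `sync` is an optional callable invoked right before the clock starts and
    right before it stops, so asynchronous work is fully accounted for.
    """
    for _ in range(warmup):
        fn()
    durations = []
    for _ in range(repeats):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        durations.append(time.perf_counter() - start)
    return mean(durations)


# Stand-in workload; in the notebook this would wrap a sample_with_deis(...) call
calls = []
avg = time_call(lambda: calls.append(sum(range(10_000))), repeats=3, warmup=1)
print(f"mean time: {avg:.6f}s over 3 timed runs ({len(calls)} calls including warm-up)")
```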

In [None]:
# Plot benchmark results
import pandas as pd

df = pd.DataFrame(benchmark_results)

fig, ax = plt.subplots(figsize=(10, 6))

for order in orders_benchmark:
    data = df[df['order'] == order]
    ax.plot(data['steps'], data['time'], marker='o', label=f'Order {order}')

ax.set_xlabel('Number of Steps')
ax.set_ylabel('Time (seconds)')
ax.set_title('DEIS Sampling Time vs Number of Steps')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

---
## Save Best Configuration Samples

In [None]:
# Generate high-quality samples with optimal configuration
OUTPUT_DIR = Path("./results/deis_samples")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Best configuration (adjust based on experiments)
scheduler = create_deis_scheduler(
    num_train_timesteps=TIMESTEPS,
    solver_order=2,
    beta_schedule="scaled_linear",
    solver_type="logrho",
    lower_order_final=True,
)

samples, elapsed = sample_with_deis(
    unet=unet,
    scheduler=scheduler,
    num_inference_steps=50,
    batch_size=64,
    image_size=IMAGE_SIZE,
    channels=CHANNELS,
    device=device
)

print(f"Generated {len(samples)} samples in {elapsed:.2f}s")

# Save grid
save_image(samples, OUTPUT_DIR / "deis_samples_grid.png", nrow=8)
print(f"Saved samples to {OUTPUT_DIR}")

# Show samples
show_samples(samples, title="DEIS Samples (Order 2, 50 steps)", nrow=8)

---
## Summary

Key findings from DEIS experiments:

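1. **Number of Steps**: More steps generally improve quality, while sampling time grows roughly linearly with the step count because each step costs one UNet evaluation. DEIS can achieve good quality with 20-50 steps.

2. **Solver Order**: Higher order (2 or 3) usually gives better results for a fixed number of steps. Order 2 is a good balance between accuracy and stability.

3. **Beta Schedule**: The schedule should match the one used during training, because the UNet learned to denoise at the noise levels that schedule assigns to each timestep. `scaled_linear` is a common choice, particularly for latent diffusion models.

4. **Solver Type**: `logrho` is the default. In the diffusers DEIS scheduler it appears to be the only implemented solver type, with the other names remapped to it.

5. **Lower Order Final**: Switching to a lower-order solver for the final steps can improve stability for high-order solvers at low step counts. In diffusers it appears to apply only when there are fewer than 15 inference steps.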