# E²-CRF Caching Ablation Study and Results

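This notebook runs ablation studies on E²-CRF caching — varying the number of cached low-frequency tokens `K` and the cache refresh interval `R` — and benchmarks the method's sampling speed against an uncached baseline.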

## Note: Generating results.yaml

The `results.yaml` file in each `lightning_logs/{model_id}/` folder is generated by running the sampling script:

```bash
python cmd/sample.py model_id=XYZ num_samples=10000 num_diffusion_steps=1000
```

This script:
1. Loads the trained model from checkpoint
2. Generates samples using the diffusion process
3. Computes metrics (Wasserstein distances, etc.) comparing generated samples with training data
4. Saves results to `lightning_logs/{model_id}/results.yaml`
5. Saves samples to `lightning_logs/{model_id}/samples.pt`

The `results.yaml` contains:
- `time_sliced_wasserstein_*`: Sliced Wasserstein distances in time domain
- `freq_sliced_wasserstein_*`: Sliced Wasserstein distances in frequency domain
- `time_marginal_wasserstein_*`: Marginal Wasserstein distances in time domain
- `freq_marginal_wasserstein_*`: Marginal Wasserstein distances in frequency domain
- `spectral_marginal_wasserstein_*`: Spectral density Wasserstein distances
- Baseline metrics (self, dummy) for comparison


In [9]:
import time
from pathlib import Path
from typing import Optional, Dict, List
import warnings

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots
from hydra import compose, initialize
from hydra.utils import instantiate
from omegaconf import OmegaConf

from fdiff.models.score_models import ScoreModule
from fdiff.sampling.sampler import DiffusionSampler
from fdiff.utils.extraction import get_best_checkpoint

# Notebook-wide presentation settings: the "science" matplotlib style (made
# available by the `scienceplots` import above) and a blanket warning filter.
# NOTE(review): ignoring *all* warnings also hides pandas SettingWithCopyWarning.
plt.style.use("science")
warnings.simplefilter("ignore")


In [10]:
# Configuration — everything a reader might want to tune lives in this cell.
runs_dir = Path.cwd() / "../lightning_logs"  # root of the Lightning runs (NOTE(review): not passed to load_model below)
save_dir = Path.cwd() / "../outputs"  # figures/ and tables/ are written under here

# Create output directories
(save_dir / "figures").mkdir(parents=True, exist_ok=True)
(save_dir / "tables").mkdir(parents=True, exist_ok=True)

# Model ID - change this to your trained model
model_id = "latest"  # or specify a specific model ID like "03wb0ssr"

# Benchmark parameters
num_samples = 20
num_diffusion_steps = 100

# Diffusion sampling is stochastic: seed torch's RNG so that re-running the
# notebook draws the same samples (wall-clock timings will still vary).
random_seed = 42
torch.manual_seed(random_seed)


In [11]:
def load_model(
    model_id: str,
    log_dir: Optional[Path] = None,
) -> "tuple[ScoreModule, str]":
    """Load a trained score model from its Lightning checkpoint directory.

    Args:
        model_id: Run folder name under ``log_dir``, or ``"latest"`` to use the
            run that owns the most recently modified ``*.ckpt`` file.
        log_dir: Root directory of the Lightning runs. Defaults to
            ``../lightning_logs`` (relative to the current working directory).

    Returns:
        ``(score_model, model_id)``: the model in eval mode, moved to CUDA or MPS
        when available, and the concrete run id (resolved when ``"latest"``).

    Raises:
        ValueError: If no checkpoint can be found for the requested run.
    """
    log_dir = Path("../lightning_logs") if log_dir is None else Path(log_dir)

    if model_id == "latest":
        checkpoints = list(log_dir.glob("*/checkpoints/*.ckpt"))
        if not checkpoints:
            raise ValueError("No checkpoints found")
        checkpoint_path = max(checkpoints, key=lambda p: p.stat().st_mtime)
        model_id = checkpoint_path.parent.parent.name
        print(f"Using latest model: {model_id}")
    else:
        checkpoint_dir = log_dir / model_id / "checkpoints"
        if not any(checkpoint_dir.glob("*.ckpt")):
            raise ValueError(f"No checkpoint found for model_id: {model_id}")
        # glob() order is filesystem-dependent, so taking the first match picked an
        # arbitrary file when a run holds several checkpoints. Use the same selection
        # helper that generate_results_yaml relies on.
        checkpoint_path = get_best_checkpoint(checkpoint_dir)

    # Load model
    score_model = ScoreModule.load_from_checkpoint(
        str(checkpoint_path),
        weights_only=False
    )
    score_model.eval()

    # Prefer CUDA, then Apple MPS, else stay on CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    score_model = score_model.to(device)

    print(f"Model loaded on device: {device}")
    return score_model, model_id


In [12]:
def benchmark_sampling(
    score_model: ScoreModule,
    num_samples: int = 10,
    num_diffusion_steps: int = 100,
    use_cache: bool = False,
    cache_kwargs: Optional[dict] = None,
    config_name: str = "baseline",
    num_repeats: int = 1,
) -> Dict:
    """Benchmark sampling with or without caching.

    A short warm-up run (1 sample, 10 steps) is executed first and not timed.
    The cache (if any) is reset before the warm-up and before every timed run.

    Args:
        score_model: The score model to use
        num_samples: Number of samples to generate
        num_diffusion_steps: Number of diffusion steps
        use_cache: Whether to use caching
        cache_kwargs: Optional cache configuration
        config_name: Name of the configuration
        num_repeats: Number of timed runs; ``elapsed_time`` is their median.
            The default of 1 reproduces the original single-run measurement.

    Returns:
        Dictionary with ``config_name``, ``elapsed_time`` (median seconds),
        ``elapsed_times`` (every timed run), ``samples`` (from the last run),
        ``cache_stats`` (from the last run; empty when caching is off),
        ``num_samples`` and ``num_diffusion_steps``.

    Raises:
        ValueError: If ``num_repeats`` is smaller than 1.
    """
    import statistics

    if num_repeats < 1:
        raise ValueError(f"num_repeats must be >= 1, got {num_repeats}")

    sampler = DiffusionSampler(
        score_model=score_model,
        sample_batch_size=1,
        use_cache=use_cache,
        cache_kwargs=cache_kwargs or {},
    )

    def _active_cache():
        # getattr: tolerate models that expose no `cache` attribute at all.
        cache = getattr(score_model, "cache", None)
        return cache if use_cache else None

    def _reset_cache() -> None:
        # Start each run from an empty cache so its statistics are its own.
        cache = _active_cache()
        if cache is not None:
            cache.reset()

    def _synchronize() -> None:
        # CUDA/MPS kernels run asynchronously; wait for them so the timer measures
        # finished work rather than just kernel launches.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif (
            hasattr(torch.backends, "mps")
            and torch.backends.mps.is_available()
            and hasattr(torch, "mps")
            and hasattr(torch.mps, "synchronize")
        ):
            torch.mps.synchronize()

    # Warmup (untimed)
    _reset_cache()
    _ = sampler.sample(num_samples=1, num_diffusion_steps=10)
    _synchronize()

    elapsed_times = []
    samples = None
    for _ in range(num_repeats):
        _reset_cache()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        samples = sampler.sample(
            num_samples=num_samples,
            num_diffusion_steps=num_diffusion_steps
        )
        _synchronize()
        elapsed_times.append(time.perf_counter() - start_time)

    # Get cache statistics if available (they describe the last timed run)
    cache = _active_cache()
    cache_stats = cache.get_cache_stats() if cache is not None else {}

    return {
        "config_name": config_name,
        "elapsed_time": statistics.median(elapsed_times),
        "elapsed_times": elapsed_times,
        "samples": samples,
        "cache_stats": cache_stats,
        "num_samples": num_samples,
        "num_diffusion_steps": num_diffusion_steps,
    }


## Load Model


In [13]:
# Resolve the configured checkpoint ("latest" = most recently written .ckpt) and
# keep the concrete run id so downstream outputs can be labelled with it.
score_model, actual_model_id = load_model(model_id)
print("Loaded model:", actual_model_id)


Using latest model: slxribxk
Model loaded on device: mps
Loaded model: slxribxk


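## Ablation Study: Different Caching Configurations

> **Timing caveat:** each configuration is timed once, after a short warm-up, and every speedup is computed relative to the baseline run, which executes first. In the recorded outputs, `K=0` (0.0% cache hits) runs as fast as the default E²-CRF configuration (99.9% hits), both at about 1.1× the baseline. The measured speedup should therefore not be attributed to caching without repeated timings in varied order.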


In [14]:
# Every benchmark run is collected here; the summary and plotting cells read from it.
results = []

banner = "=" * 80
print(banner)
print("E²-CRF Caching Ablation Study")
print(banner)

# 1. Reference run with caching disabled; every speedup below is relative to it.
print("\n1. Baseline (no caching)...")
baseline_run = benchmark_sampling(
    score_model=score_model,
    num_samples=num_samples,
    num_diffusion_steps=num_diffusion_steps,
    use_cache=False,
    config_name="Baseline (No Cache)",
)
results.append(baseline_run)
baseline_time = baseline_run["elapsed_time"]
print(f"   Time: {baseline_time:.2f}s")

# 2. Full E²-CRF method with the cache's own defaults (empty cache_kwargs).
print("\n2. E²-CRF (full method, default settings)...")
result = benchmark_sampling(
    score_model=score_model,
    num_samples=num_samples,
    num_diffusion_steps=num_diffusion_steps,
    use_cache=True,
    cache_kwargs={},  # library defaults (original note: K=5, R=10)
    config_name="E²-CRF (Default)",
)
results.append(result)
speedup = baseline_time / result["elapsed_time"]
cache_hit = result["cache_stats"].get("cache_hit_ratio", 0.0)
for report_line in (
    f"   Time: {result['elapsed_time']:.2f}s",
    f"   Speedup: {speedup:.2f}x",
    f"   Cache Hit Ratio: {cache_hit:.1%}",
):
    print(report_line)


E²-CRF Caching Ablation Study

1. Baseline (no caching)...


                                                            4batch/s]

   Time: 17.70s

2. E²-CRF (full method, default settings)...


                                                            8batch/s]

   Time: 15.78s
   Speedup: 1.12x
   Cache Hit Ratio: 99.9%




In [None]:
# 3. Ablation: Varying K (low-frequency tokens)
print("\n3. Ablation: Varying K (low-frequency tokens)...")
k_values = [0, 3, 5, 10, 15]

# Drop rows left by an earlier execution of this cell so that re-running it
# replaces, rather than duplicates, the K-ablation entries in `results`.
k_config_names = {f"E²-CRF (K={K})" for K in k_values}
results = [r for r in results if r["config_name"] not in k_config_names]

for K in k_values:
    print(f"   K={K}: ", end="", flush=True)
    result = benchmark_sampling(
        score_model=score_model,
        num_samples=num_samples,
        num_diffusion_steps=num_diffusion_steps,
        use_cache=True,
        cache_kwargs={"K": K},
        config_name=f"E²-CRF (K={K})",
    )
    results.append(result)
    speedup = baseline_time / result["elapsed_time"]
    cache_hit = result["cache_stats"].get("cache_hit_ratio", 0.0)
    print(f"Time: {result['elapsed_time']:.2f}s, Speedup: {speedup:.2f}x, Hit: {cache_hit:.1%}")



3. Ablation: Varying K (low-frequency tokens)...
   K=0: 

                                                            8batch/s]

Time: 15.69s, Speedup: 1.13x, Hit: 0.0%
   K=3: 

Sampling:  40%|[34m████      [0m| 8/20 [00:06<00:09,  1.26batch/s]

In [None]:
# 4. Ablation: Varying R (refresh interval)
# NOTE(review): in the recorded run every R setting reported a 0.0% cache-hit ratio,
# while the default configuration reported 99.9%. Check that cache_kwargs={"R": R}
# keeps the other cache defaults and that the cache is active before trusting these timings.
print("\n4. Ablation: Varying R (refresh interval)...")
r_values = [5, 10, 20, 50, 100, 500]

# Drop rows left by an earlier execution of this cell so that re-running it
# replaces, rather than duplicates, the R-ablation entries in `results`.
r_config_names = {f"E²-CRF (R={R})" for R in r_values}
results = [r for r in results if r["config_name"] not in r_config_names]

for R in r_values:
    print(f"   R={R}: ", end="", flush=True)
    result = benchmark_sampling(
        score_model=score_model,
        num_samples=num_samples,
        num_diffusion_steps=num_diffusion_steps,
        use_cache=True,
        cache_kwargs={"R": R},
        config_name=f"E²-CRF (R={R})",
    )
    results.append(result)
    speedup = baseline_time / result["elapsed_time"]
    cache_hit = result["cache_stats"].get("cache_hit_ratio", 0.0)
    print(f"Time: {result['elapsed_time']:.2f}s, Speedup: {speedup:.2f}x, Hit: {cache_hit:.1%}")



4. Ablation: Varying R (refresh interval)...
   R=5: 

                                                            9batch/s]

Time: 15.60s, Speedup: 1.14x, Hit: 0.0%
   R=10: 

                                                            9batch/s]

Time: 15.62s, Speedup: 1.14x, Hit: 0.0%
   R=20: 

                                                            8batch/s]

Time: 15.63s, Speedup: 1.14x, Hit: 0.0%
   R=50: 

                                                            8batch/s]

Time: 15.79s, Speedup: 1.13x, Hit: 0.0%
   R=100: 

                                                            5batch/s]

Time: 16.05s, Speedup: 1.11x, Hit: 0.0%
   R=500: 

                                                            9batch/s]

Time: 15.50s, Speedup: 1.15x, Hit: 0.0%




In [None]:
# 5. Ablation: Pure cache mode (no recomputation except step 0)
print("\n5. Ablation: Pure cache mode (R=inf)...")
pure_cache_name = "E²-CRF (Pure Cache)"

# Remove a previous result of this cell so re-running it does not duplicate the row.
results = [r for r in results if r["config_name"] != pure_cache_name]

result = benchmark_sampling(
    score_model=score_model,
    num_samples=num_samples,
    num_diffusion_steps=num_diffusion_steps,
    use_cache=True,
    cache_kwargs={"R": 999999},  # Effectively no refresh (far above num_diffusion_steps)
    config_name=pure_cache_name,
)
results.append(result)
speedup = baseline_time / result["elapsed_time"]
cache_hit = result["cache_stats"].get("cache_hit_ratio", 0.0)
print(f"   Time: {result['elapsed_time']:.2f}s")
print(f"   Speedup: {speedup:.2f}x")
print(f"   Cache Hit Ratio: {cache_hit:.1%}")


## Results Summary


In [None]:
# Create summary DataFrame: one row per benchmark run.
# `df_summary` keeps full precision because it is plotted and exported below;
# rounding is applied only to the printed view.
summary_data = []
for r in results:
    cache_stats = r.get("cache_stats") or {}  # tolerate a missing or None entry
    total_steps = r["num_samples"] * r["num_diffusion_steps"]
    summary_data.append({
        "Configuration": r["config_name"],
        "Time (s)": r["elapsed_time"],
        "Speedup": baseline_time / r["elapsed_time"],
        "Time per Sample (s)": r["elapsed_time"] / r["num_samples"],
        "Time per Step (s)": r["elapsed_time"] / total_steps,
        "Cache Hit Ratio": cache_stats.get("cache_hit_ratio", 0.0),
        "Cache Ratio": cache_stats.get("cache_ratio", 0.0),
        "Recompute Count": cache_stats.get("recompute_count", 0),
        "Cache Hit Count": cache_stats.get("cache_hit_count", 0),
    })

df_summary = pd.DataFrame(summary_data)
print("\n" + "=" * 80)
print("Ablation Study Results Summary")
print("=" * 80)
print(df_summary.round(3).to_string(index=False))


## Visualizations


In [None]:
# Plot 1: speedup of every cached configuration relative to the uncached baseline
cached_runs = (
    df_summary.loc[df_summary["Configuration"] != "Baseline (No Cache)"]
    .sort_values("Speedup", ascending=False)
)

fig, ax = plt.subplots(figsize=(12, 6))
ax.barh(cached_runs["Configuration"], cached_runs["Speedup"])
ax.axvline(x=1.0, color="r", linestyle="--", label="Baseline (1.0x)")
ax.set_xlabel("Speedup (×)", fontsize=12)
ax.set_title("E²-CRF Caching Speedup Comparison", fontsize=14)
ax.legend()
ax.grid(axis="x", alpha=0.3)
fig.tight_layout()
fig.savefig(save_dir / "figures/ablation_speedup.pdf")
plt.show()


In [None]:
# Plot 2: does a higher cache-hit ratio translate into a higher speedup?
# Only runs with a non-zero hit ratio are shown (the baseline has none).
hit_runs = df_summary.loc[df_summary["Cache Hit Ratio"] > 0]

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(hit_runs["Cache Hit Ratio"], hit_runs["Speedup"], s=100, alpha=0.6)
for label, hit_ratio, speedup_value in zip(
    hit_runs["Configuration"], hit_runs["Cache Hit Ratio"], hit_runs["Speedup"]
):
    ax.annotate(label, (hit_ratio, speedup_value), fontsize=8, alpha=0.7)
ax.set_xlabel("Cache Hit Ratio", fontsize=12)
ax.set_ylabel("Speedup (×)", fontsize=12)
ax.set_title("Cache Hit Ratio vs Speedup", fontsize=14)
ax.grid(alpha=0.3)
fig.tight_layout()
fig.savefig(save_dir / "figures/ablation_cache_hit_vs_speedup.pdf")
plt.show()


In [None]:
# Plot 3: K parameter sensitivity
# Build a new frame with a numeric K column instead of assigning into a filtered
# slice of df_summary (chained assignment triggers SettingWithCopyWarning, which
# the blanket warnings filter at the top of the notebook hides). expand=False makes
# str.extract return a Series rather than a one-column DataFrame.
is_k_run = df_summary["Configuration"].str.contains("K=", regex=False)
df_k = (
    df_summary.loc[is_k_run]
    .assign(K=lambda d: d["Configuration"].str.extract(r"K=(\d+)", expand=False).astype(float))
    .sort_values("K")
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

ax1.plot(df_k["K"], df_k["Speedup"], marker='o', linewidth=2, markersize=8)
ax1.set_xlabel("K (Low-frequency tokens)", fontsize=12)
ax1.set_ylabel("Speedup (×)", fontsize=12)
ax1.set_title("Speedup vs K Parameter", fontsize=14)
ax1.grid(alpha=0.3)

ax2.plot(df_k["K"], df_k["Cache Hit Ratio"], marker='s', linewidth=2, markersize=8, color='orange')
ax2.set_xlabel("K (Low-frequency tokens)", fontsize=12)
ax2.set_ylabel("Cache Hit Ratio", fontsize=12)
ax2.set_title("Cache Hit Ratio vs K Parameter", fontsize=14)
ax2.grid(alpha=0.3)

fig.tight_layout()
fig.savefig(save_dir / "figures/ablation_k_sensitivity.pdf")
plt.show()


In [None]:
# Plot 4: R parameter sensitivity
# Same pattern as the K plot: derive a numeric R column on a fresh frame via
# .assign rather than assigning into a filtered slice of df_summary. The
# "Pure Cache" run has no "R=" in its name and is therefore not included.
is_r_run = df_summary["Configuration"].str.contains("R=", regex=False)
df_r = (
    df_summary.loc[is_r_run]
    .assign(R=lambda d: d["Configuration"].str.extract(r"R=(\d+)", expand=False).astype(float))
    .sort_values("R")
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

ax1.plot(df_r["R"], df_r["Speedup"], marker='o', linewidth=2, markersize=8)
ax1.set_xlabel("R (Refresh interval)", fontsize=12)
ax1.set_ylabel("Speedup (×)", fontsize=12)
ax1.set_title("Speedup vs R Parameter", fontsize=14)
ax1.set_xscale('log')
ax1.grid(alpha=0.3)

ax2.plot(df_r["R"], df_r["Cache Hit Ratio"], marker='s', linewidth=2, markersize=8, color='orange')
ax2.set_xlabel("R (Refresh interval)", fontsize=12)
ax2.set_ylabel("Cache Hit Ratio", fontsize=12)
ax2.set_title("Cache Hit Ratio vs R Parameter", fontsize=14)
ax2.set_xscale('log')
ax2.grid(alpha=0.3)

fig.tight_layout()
fig.savefig(save_dir / "figures/ablation_r_sensitivity.pdf")
plt.show()


## New Method (E²-CRF) vs Baseline Comparison


In [None]:
# Compare E²-CRF (default settings) with the uncached baseline.
# Runs are looked up by name rather than by list position (results[0], results[1]),
# so re-executing or reordering the ablation cells cannot silently swap them.
runs_by_name = {r["config_name"]: r for r in results}
baseline_result = runs_by_name["Baseline (No Cache)"]
e2crf_result = runs_by_name["E²-CRF (Default)"]

comparison = {
    "Metric": [
        "Total Time (s)",
        "Time per Sample (s)",
        "Time per Step (ms)",
        "Speedup",
        "Cache Hit Ratio",
        "Cache Ratio",
    ],
    "Baseline": [
        baseline_result["elapsed_time"],
        baseline_result["elapsed_time"] / baseline_result["num_samples"],
        baseline_result["elapsed_time"] / (baseline_result["num_samples"] * baseline_result["num_diffusion_steps"]) * 1000,
        1.0,
        0.0,
        0.0,
    ],
    "E²-CRF": [
        e2crf_result["elapsed_time"],
        e2crf_result["elapsed_time"] / e2crf_result["num_samples"],
        e2crf_result["elapsed_time"] / (e2crf_result["num_samples"] * e2crf_result["num_diffusion_steps"]) * 1000,
        baseline_result["elapsed_time"] / e2crf_result["elapsed_time"],
        e2crf_result["cache_stats"].get("cache_hit_ratio", 0.0),
        e2crf_result["cache_stats"].get("cache_ratio", 0.0),
    ],
}

df_comparison = pd.DataFrame(comparison)


def _format_improvement(row: pd.Series) -> str:
    """Describe how E²-CRF compares with the baseline for one metric row.

    - "Speedup": the E²-CRF ratio itself, e.g. "1.12x".
    - Rows whose baseline value is 0 (the cache ratios): a relative change is
      undefined, and dividing by the zero baseline raised ZeroDivisionError, so
      report "n/a".
    - Otherwise (timings): percent reduction relative to the baseline.
    """
    if row["Metric"] == "Speedup":
        return f"{row['E²-CRF']:.2f}x"
    if row["Baseline"] == 0:
        return "n/a"
    return f"{((row['Baseline'] - row['E²-CRF']) / row['Baseline'] * 100):.1f}%"


df_comparison["Improvement"] = df_comparison.apply(_format_improvement, axis=1)

print("\n" + "=" * 80)
print("E²-CRF vs Baseline Comparison")
print("=" * 80)
print(df_comparison.to_string(index=False))


In [None]:
# Visualize comparison: 2x2 panel with timings, speedup, cache statistics and time reduction.
def _comparison_value(metric: str, column: str):
    """Return the `column` ("Baseline" or "E²-CRF") value of `metric` in df_comparison."""
    return df_comparison[df_comparison["Metric"] == metric][column].values[0]


fig, axes = plt.subplots(2, 2, figsize=(14, 10))
(ax_times, ax_speedup), (ax_cache, ax_reduction) = axes

# Top-left: grouped bars (baseline vs. E²-CRF) for each timing metric.
time_metrics = ["Total Time (s)", "Time per Sample (s)", "Time per Step (ms)"]
baseline_times = [_comparison_value(m, "Baseline") for m in time_metrics]
e2crf_times = [_comparison_value(m, "E²-CRF") for m in time_metrics]
bar_pos = np.arange(len(time_metrics))
bar_w = 0.35
ax_times.bar(bar_pos - bar_w/2, baseline_times, bar_w, label='Baseline', alpha=0.8)
ax_times.bar(bar_pos + bar_w/2, e2crf_times, bar_w, label='E²-CRF', alpha=0.8)
ax_times.set_ylabel('Time', fontsize=12)
ax_times.set_title('Time Comparison', fontsize=14)
ax_times.set_xticks(bar_pos)
ax_times.set_xticklabels(time_metrics, rotation=45, ha='right')
ax_times.legend()
ax_times.grid(axis='y', alpha=0.3)

# Top-right: E²-CRF speedup against the 1.0x baseline reference line.
ax_speedup.bar(["E²-CRF"], [_comparison_value("Speedup", "E²-CRF")], alpha=0.8, color='green')
ax_speedup.axhline(y=1.0, color='r', linestyle='--', label='Baseline (1.0x)')
ax_speedup.set_ylabel('Speedup (×)', fontsize=12)
ax_speedup.set_title('Speedup Achieved', fontsize=14)
ax_speedup.legend()
ax_speedup.grid(axis='y', alpha=0.3)

# Bottom-left: cache statistics of the E²-CRF run (ratios in [0, 1]).
cache_metric_names = ["Cache Hit Ratio", "Cache Ratio"]
ax_cache.bar(
    cache_metric_names,
    [_comparison_value(m, "E²-CRF") for m in cache_metric_names],
    alpha=0.8,
    color='orange',
)
ax_cache.set_ylabel('Ratio', fontsize=12)
ax_cache.set_title('Cache Statistics', fontsize=14)
ax_cache.set_ylim([0, 1])
ax_cache.grid(axis='y', alpha=0.3)

# Bottom-right: relative wall-clock reduction in percent.
time_reduction = (baseline_result["elapsed_time"] - e2crf_result["elapsed_time"]) / baseline_result["elapsed_time"] * 100
ax_reduction.bar(["Time Reduction"], [time_reduction], alpha=0.8, color='blue')
ax_reduction.set_ylabel('Reduction (%)', fontsize=12)
ax_reduction.set_title('Time Reduction', fontsize=14)
ax_reduction.grid(axis='y', alpha=0.3)

fig.tight_layout()
fig.savefig(save_dir / "figures/e2crf_vs_baseline.pdf")
plt.show()


In [None]:
# Save results: full-precision CSVs plus a LaTeX version of the ablation table.
ablation_csv = save_dir / "tables/ablation_results.csv"
df_summary.to_csv(ablation_csv, index=False)
print(f"Results saved to {ablation_csv}")

# Save comparison table
comparison_csv = save_dir / "tables/e2crf_vs_baseline.csv"
df_comparison.to_csv(comparison_csv, index=False)
print(f"Comparison saved to {comparison_csv}")

# Save LaTeX table. Write UTF-8 explicitly: configuration names contain non-ASCII
# characters ("E²-CRF") that the platform's default encoding may not represent.
latex_path = save_dir / "tables/ablation_results.tex"
latex_path.write_text(df_summary.to_latex(index=False, float_format="%.3f"), encoding="utf-8")
print(f"LaTeX table saved to {latex_path}")


## Generate results.yaml for a Model (Optional)

If you need to generate `results.yaml` for a model that doesn't have it yet, you can run the sampling script. Here's a helper function to do it programmatically:


In [None]:
def generate_results_yaml(
    model_id: str,
    num_samples: int = 10000,
    num_diffusion_steps: int = 1000,
    use_cache: bool = False,
    cache_kwargs: Optional[dict] = None,
    random_seed: int = 42,
) -> dict:
    """Generate results.yaml for a trained model by running sampling and computing metrics.

    This replicates what `cmd/sample.py` does, but can be called from notebook.
    Writes ``results.yaml`` and ``samples.pt`` into ``../lightning_logs/{model_id}/``.

    Args:
        model_id: Model ID (folder name in lightning_logs)
        num_samples: Number of samples to generate
        num_diffusion_steps: Number of diffusion steps
        use_cache: Whether to use caching
        cache_kwargs: Optional cache configuration
        random_seed: Seed for torch's RNG (sampling) and for the metrics' random
            projections, for reproducibility

    Returns:
        Dictionary containing the computed metrics

    Raises:
        ValueError: If the model directory does not exist.
    """
    log_dir = Path("../lightning_logs")
    # Per-run directory; named `run_dir` so it does not shadow the notebook-level `save_dir`.
    run_dir = log_dir / model_id

    # Fail fast, before the comparatively slow project imports and data loading.
    if not run_dir.exists():
        raise ValueError(f"Model directory not found: {run_dir}")

    from functools import partial

    import yaml
    from omegaconf import OmegaConf
    from hydra.utils import instantiate

    from fdiff.dataloaders.datamodules import Datamodule
    from fdiff.sampling.metrics import MetricCollection, SlicedWasserstein, MarginalWasserstein
    from fdiff.sampling.sampler import DiffusionSampler
    from fdiff.utils.extraction import get_best_checkpoint, get_model_type
    from fdiff.utils.fourier import idft

    # Load training config
    train_cfg = OmegaConf.load(run_dir / "train_config.yaml")

    # Instantiate datamodule; stage="fit" is used to get the training data
    # (X_train), which the metrics below compare against.
    datamodule: Datamodule = instantiate(train_cfg.datamodule)
    datamodule.prepare_data()
    datamodule.setup(stage="fit")

    # Load model
    best_checkpoint_path = get_best_checkpoint(run_dir / "checkpoints")
    # Convert to dict for get_model_type (it accepts DictConfig | dict)
    train_cfg_dict = OmegaConf.to_container(train_cfg, resolve=True)
    assert isinstance(train_cfg_dict, dict)
    model_type = get_model_type(train_cfg_dict)
    score_model = model_type.load_from_checkpoint(
        checkpoint_path=best_checkpoint_path,
        weights_only=False
    )
    score_model.eval()

    # Prefer CUDA, then Apple MPS, else stay on CPU.
    if torch.cuda.is_available():
        score_model = score_model.cuda()
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        score_model = score_model.to('mps')

    # Instantiate sampler
    sampler = DiffusionSampler(
        score_model=score_model,
        sample_batch_size=1,
        use_cache=use_cache,
        cache_kwargs=cache_kwargs or {},
    )

    # Instantiate metrics (use default metrics if not in config)
    metrics_list = [
        partial(SlicedWasserstein, random_seed=random_seed, num_directions=1000, save_all_distances=True),
        partial(MarginalWasserstein, random_seed=random_seed, save_all_distances=True),
    ]

    metrics = MetricCollection(
        metrics=metrics_list,
        original_samples=datamodule.X_train,
        include_baselines=True,
        include_spectral_density=True,
    )

    # Seed torch's RNG so the generated samples, and not only the metric
    # projections, are reproducible for a given random_seed.
    torch.manual_seed(random_seed)

    # Generate samples
    print(f"Generating {num_samples} samples with {num_diffusion_steps} steps...")
    X = sampler.sample(
        num_samples=num_samples,
        num_diffusion_steps=num_diffusion_steps
    )

    # Map to original scale if standardized
    if datamodule.standardize:
        feature_mean, feature_std = datamodule.feature_mean_and_std
        X = X * feature_std + feature_mean

    # Convert to time domain if frequency domain
    if datamodule.fourier_transform:
        X = idft(X)

    # Compute metrics
    print("Computing metrics...")
    results = metrics(X)

    # Save results
    print(f"Saving results to {run_dir / 'results.yaml'}...")
    with open(run_dir / "results.yaml", "w") as f:
        yaml.dump(data=results, stream=f)
    torch.save(X, run_dir / "samples.pt")

    print("Done! results.yaml and samples.pt have been saved.")
    return results

# Example usage (uncomment to run):
# results = generate_results_yaml(
#     model_id="03wb0ssr",
#     num_samples=10000,
#     num_diffusion_steps=1000,
#     use_cache=False,  # Set to True to test with caching
# )


## Training with Diffusion Method Comparison

This section shows how to train a model for 1 epoch and compare different diffusion methods during training.


In [None]:
# Training with diffusion-method comparison is launched from the command line:
#
#   python cmd/train_diffusion_comparison.py                        # use the config file
#   python cmd/train_diffusion_comparison.py trainer.max_epochs=1   # override epochs
#
# Per the original notes, this will:
#   1. train for 1 epoch,
#   2. compare different diffusion methods at the end of the epoch,
#   3. save results to lightning_logs/{run_id}/diffusion_comparison_results.csv.
for instruction in (
    "To train with comparison, run:",
    "  python cmd/train_diffusion_comparison.py",
    "\nOr override epochs:",
    "  python cmd/train_diffusion_comparison.py trainer.max_epochs=1",
):
    print(instruction)


## Load and Visualize Training Comparison Results


In [None]:
def load_training_comparison_results(model_id: str) -> pd.DataFrame:
    """Read the diffusion-comparison CSV written by a comparison training run.

    Args:
        model_id: Run folder name under ``../lightning_logs``.

    Returns:
        The contents of ``diffusion_comparison_results.csv`` as a DataFrame, or an
        empty DataFrame (after printing how to produce the file) if it is missing.
    """
    csv_path = Path("../lightning_logs") / model_id / "diffusion_comparison_results.csv"

    if csv_path.exists():
        return pd.read_csv(csv_path)

    for message in (
        f"Results file not found: {csv_path}",
        "Please run training with comparison first:",
        "  python cmd/train_diffusion_comparison.py",
    ):
        print(message)
    return pd.DataFrame()

# Example: Load results for a model
# comparison_df = load_training_comparison_results("your_model_id")
# if not comparison_df.empty:
#     print(comparison_df)


In [None]:
def _plot_method_bars(
    values: pd.Series,
    path: Path,
    ylabel: str,
    title: str,
    figsize: tuple = (12, 6),
    color: Optional[str] = None,
    ylim: Optional[tuple] = None,
    baseline_line: bool = False,
) -> None:
    """Draw one bar per method (the index of `values`), then save and show the figure."""
    fig, ax = plt.subplots(figsize=figsize)
    positions = np.arange(len(values))
    ax.bar(positions, values.to_numpy(), alpha=0.8, color=color)
    if baseline_line:
        ax.axhline(y=1.0, color='r', linestyle='--', label='Baseline (1.0x)')
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    # Fix tick positions before labelling them; set_xticklabels alone depends on
    # whatever locator the categorical axis chose.
    ax.set_xticks(positions)
    ax.set_xticklabels(values.index, rotation=45, ha='right')
    if ylim is not None:
        ax.set_ylim(ylim)
    if baseline_line:
        ax.legend()
    ax.grid(axis='y', alpha=0.3)
    fig.tight_layout()
    fig.savefig(path)
    plt.show()


def visualize_training_comparison(df: pd.DataFrame, save_dir: Path) -> None:
    """Visualize training comparison results.

    Produces up to three bar charts in ``save_dir/figures``: mean time per method,
    mean cache-hit ratio for methods with hits (only if a ``cache_hit_ratio``
    column exists), and speedup relative to the ``'baseline'`` method (only if
    that method is present with a positive mean time).

    Args:
        df: DataFrame with comparison results; expects ``method`` and ``time``
            columns, ``cache_hit_ratio`` is optional.
        save_dir: Directory to save figures (``figures/`` is created if missing).
    """
    if df.empty:
        return

    figures_dir = Path(save_dir) / "figures"
    figures_dir.mkdir(parents=True, exist_ok=True)

    # Mean time per method, in order of first appearance (as df['method'].unique()).
    mean_time = df.groupby('method', sort=False)['time'].mean()

    # Plot 1: Time comparison by method
    _plot_method_bars(
        mean_time,
        figures_dir / "training_comparison_time.pdf",
        ylabel='Time (seconds)',
        title='Diffusion Method Time Comparison (Training)',
    )

    # Plot 2: Cache hit ratio for cached methods (rows with a positive hit ratio)
    if 'cache_hit_ratio' in df.columns:
        cached_rows = df[df['cache_hit_ratio'] > 0]
        if not cached_rows.empty:
            hit_ratios = cached_rows.groupby('method', sort=False)['cache_hit_ratio'].mean()
            _plot_method_bars(
                hit_ratios,
                figures_dir / "training_comparison_cache_hit.pdf",
                ylabel='Cache Hit Ratio',
                title='Cache Hit Ratio by Method (Training)',
                figsize=(10, 6),
                color='orange',
                ylim=(0, 1),
            )

    # Plot 3: Speedup comparison; NaN (no 'baseline' rows) fails the > 0 check.
    baseline_time = mean_time.get('baseline', np.nan)
    if baseline_time > 0:
        _plot_method_bars(
            baseline_time / mean_time,
            figures_dir / "training_comparison_speedup.pdf",
            ylabel='Speedup (×)',
            title='Speedup Comparison (Training)',
            color='green',
            baseline_line=True,
        )

# Example usage:
# comparison_df = load_training_comparison_results("your_model_id")
# if not comparison_df.empty:
#     visualize_training_comparison(comparison_df, save_dir)
