# Energy-Guided Flow Matching Sampling

Compare base flow matching vs energy-guided sampling for Darcy Flow.

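**Key difference from FNO:**
- FNO is trained with a supervised MSE loss against ground-truth solutions → an extra energy regularizer is largely redundant
- Flow matching is generative: there is no ground truth at inference, and samples come from iterative integration → energy guidance may help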

This notebook tests whether energy guidance during sampling improves physical plausibility.

## 1. Setup

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone the repository and install dependencies
!git clone https://github.com/MehdiMHeydari/EQM-Training.git
!pip install -q torch torchvision einops omegaconf h5py matplotlib scipy tqdm torchdiffeq

In [None]:
%cd /content/EQM-Training

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchdiffeq import odeint
from omegaconf import OmegaConf
import h5py
import sys
sys.path.insert(0, '/content/EQM-Training')
from physics_flow_matching.unet.unet_bb import UNetModelWrapper as UNetModel

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## 2. Configuration

Update these paths to match your setup:

In [None]:
# UPDATE THESE PATHS
FLOW_MODEL_CHECKPOINT = "/content/drive/MyDrive/EQM_Checkpoints5/checkpoint_90.pth"
EQM_CONFIG = "/content/EQM-Training/configs/darcy_flow_eqm.yaml"
DATA_PATH = "/content/drive/MyDrive/2D_DarcyFlow_beta1.0_Train.hdf5"

# Verify paths exist
import os
print(f"Checkpoint: {'OK' if os.path.exists(FLOW_MODEL_CHECKPOINT) else 'NOT FOUND'}")
print(f"Config: {'OK' if os.path.exists(EQM_CONFIG) else 'NOT FOUND'}")
print(f"Data: {'OK' if os.path.exists(DATA_PATH) else 'NOT FOUND'}")
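The cell above only prints `NOT FOUND`, so a bad path surfaces later as a less obvious error in the loading cells. Here is a small sketch that fails fast instead. The helper name `missing_paths` is our own and does not come from the repository.

```python
from pathlib import Path

def missing_paths(paths):
    """Return the names of entries whose path does not exist on disk."""
    return [name for name, p in paths.items() if not Path(p).exists()]

# Hypothetical usage with the notebook's variables:
# missing = missing_paths({"checkpoint": FLOW_MODEL_CHECKPOINT,
#                          "config": EQM_CONFIG,
#                          "data": DATA_PATH})
# if missing:
#     raise FileNotFoundError(f"Missing paths: {', '.join(missing)}")
```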

## 3. Load Model

In [None]:
# Load config
config = OmegaConf.load(EQM_CONFIG)

# Load flow matching model (this is also our energy model)
print("Loading flow matching model...")
model = UNetModel(
    dim=config.unet.dim,
    out_channels=config.unet.out_channels,
    num_channels=config.unet.num_channels,
    num_res_blocks=config.unet.res_blocks,
    channel_mult=config.unet.channel_mult,
    num_head_channels=config.unet.head_chans,
    attention_resolutions=config.unet.attn_res,
    dropout=config.unet.dropout,
    use_new_attention_order=config.unet.new_attn,
    use_scale_shift_norm=config.unet.film,
)

checkpoint = torch.load(FLOW_MODEL_CHECKPOINT, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
print("Model loaded!")

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")

## 4. Define Energy Functions

In [None]:
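def compute_energy(x):
    """EQM dot-product energy E(x) = sum(x * model(x)), evaluated on x clamped to [-1, 1]"""
    x_clamped = torch.clamp(x, -1.0, 1.0)
    with torch.no_grad():
        pred = model(x_clamped)
    return torch.sum(x_clamped * pred, dim=(1, 2, 3))  # one energy value per sample

def compute_energy_gradient(x):
    """Gradient dE/dx of the (batch-summed) energy, used for guidance.

    torch.clamp has zero gradient outside [-1, 1], so pixels outside that
    range receive no guidance.
    """
    x = x.detach().requires_grad_(True)
    x_clamped = torch.clamp(x, -1.0, 1.0)
    pred = model(x_clamped)
    energy = torch.sum(x_clamped * pred)
    # autograd.grad returns dE/dx without accumulating .grad on the model parameters
    (grad,) = torch.autograd.grad(energy, x)
    return grad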

print("Energy functions defined!")
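For the dot-product energy, the gradient has two terms: ∇E(x) = f(x) + J_f(x)ᵀ x. This is why the code differentiates through the model instead of reusing the model output as the gradient. The clamp also matters, because coordinates outside [-1, 1] receive zero gradient.

The NumPy sketch below checks both points against finite differences. It uses a toy linear map f(x) = W x as a stand-in for the UNet, which is an assumption made only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
W = rng.normal(size=(n, n))  # toy linear "model" f(x) = W @ x, standing in for the UNet

def energy(x):
    """Dot-product energy E(x) = sum(x_c * f(x_c)) with x_c = clip(x, -1, 1)."""
    xc = np.clip(x, -1.0, 1.0)
    return float(xc @ (W @ xc))

def energy_grad(x):
    """Analytic gradient (W + W^T) x_c, zero where the clamp is active (outside [-1, 1])."""
    xc = np.clip(x, -1.0, 1.0)
    inside = (x >= -1.0) & (x <= 1.0)
    return np.where(inside, (W + W.T) @ xc, 0.0)

def finite_diff_grad(x, eps=1e-6):
    """Central finite differences of energy(), one coordinate at a time."""
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        g[i] = (energy(x + e) - energy(x - e)) / (2 * eps)
    return g

x = np.array([0.3, -0.7, 1.8, 0.1, -2.5])  # entries 2 and 4 lie outside [-1, 1]
print(np.allclose(energy_grad(x), finite_diff_grad(x), atol=1e-5))  # True
print(energy_grad(x)[[2, 4]])  # clamped coordinates get zero gradient
```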

## 5. Define Samplers

In [None]:
class BaseFlowMatchingSampler:
    """Standard flow matching sampler using ODE integration"""
    
    def __init__(self, model, device):
        self.model = model
        self.device = device
    
    def velocity_field(self, t, x):
        """v(x) from the flow matching model; t is unused because the model is called on x alone"""
        with torch.no_grad():
            v = self.model(x)
        return v
    
    def sample(self, batch_size=4, num_steps=100, shape=(1, 128, 128)):
        """Sample by integrating the ODE with forward Euler from t=0 (noise) to t=1 (data)"""
        # Start from noise
        x = torch.randn(batch_size, *shape, device=self.device)
        
        dt = 1.0 / num_steps
        
        for step in range(num_steps):
            t = step / num_steps
            v = self.velocity_field(t, x)
            x = x + v * dt
        
        return x


class EnergyGuidedFlowMatchingSampler:
    """Flow matching sampler with energy guidance"""
    
    def __init__(self, model, device, guidance_scale=0.1):
        self.model = model
        self.device = device
        self.guidance_scale = guidance_scale
    
    def velocity_field(self, t, x):
        """v(x) from the flow matching model; t is unused because the model is called on x alone"""
        with torch.no_grad():
            v = self.model(x)
        return v
    
    def energy_gradient(self, x):
        """Compute dE/dx for E(x) = sum(x * model(x)) on x clamped to [-1, 1]"""
        x = x.detach().requires_grad_(True)
        x_clamped = torch.clamp(x, -1.0, 1.0)
        pred = self.model(x_clamped)
        energy = torch.sum(x_clamped * pred)
        # autograd.grad avoids accumulating .grad on the model parameters;
        # pixels outside [-1, 1] get zero gradient because of the clamp
        (grad,) = torch.autograd.grad(energy, x)
        return grad
    
    def sample(self, batch_size=4, num_steps=100, shape=(1, 128, 128)):
        """Sample with energy guidance (forward Euler from t=0 noise to t=1 data)"""
        # Start from noise
        x = torch.randn(batch_size, *shape, device=self.device)
        
        dt = 1.0 / num_steps
        
        for step in range(num_steps):
            t = step / num_steps
            
            # Standard flow matching velocity
            v = self.velocity_field(t, x)
            
            # Energy guidance weighted by t^2: zero at the first step, strongest
            # near the end when the sample is more formed
            guidance_weight = self.guidance_scale * (t ** 2)
            
            if guidance_weight > 0:
                energy_grad = self.energy_gradient(x)
                # Step against the gradient to move toward lower energy
                # (more negative = better for EQM)
                v = v - guidance_weight * energy_grad
            
            # Euler step
            x = x + v * dt
        
        return x

print("Samplers defined!")
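Each guided Euler step performs `x <- x + dt * (v(x) - w(t) * grad_E(x))` with `w(t) = guidance_scale * t**2`. Because the gradient points uphill, it must be subtracted for the step to descend E.

The sketch below runs this update in NumPy on a toy 2D problem. The velocity field and the quadratic energy are stand-ins we chose, not the trained model. On this problem the final energy drops as the guidance scale grows.

```python
import numpy as np

def guided_euler(x0, velocity, energy_grad, guidance_scale, num_steps=100):
    """Forward Euler from t=0 to t=1 with guidance weight guidance_scale * t**2.

    The energy gradient is subtracted, so the guidance term descends E.
    """
    x = x0.copy()
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = step / num_steps
        v = velocity(x)
        w = guidance_scale * t ** 2  # 0 at the first step, close to guidance_scale at the end
        if w > 0:
            v = v - w * energy_grad(x)
        x = x + v * dt
    return x

# Toy problem (assumption for illustration): the "model" velocity pulls samples
# toward mu, while the energy E(x) = 0.5 * ||x - c||^2 prefers a nearby point c.
mu = np.array([1.0, 1.0])
c = np.array([0.0, 1.0])

def velocity(x):
    return mu - x

def energy(x):
    return 0.5 * np.sum((x - c) ** 2)

def energy_grad(x):
    return x - c

x0 = np.random.default_rng(0).normal(size=2)
for gs in [0.0, 0.5, 2.0]:
    x1 = guided_euler(x0, velocity, energy_grad, gs)
    print(f"scale={gs}: E(x1)={energy(x1):.4f}")
```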

## 6. Load Real Data for Reference

In [None]:
# Load real samples for comparison
num_samples = 8

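with h5py.File(DATA_PATH, 'r') as f:
    real_data = np.array(f['tensor'][:num_samples]).astype(np.float32)

# Drop a singleton channel axis if the file stores (N, 1, H, W)
if real_data.ndim == 4 and real_data.shape[1] == 1:
    real_data = real_data[:, 0]

# Normalize to [-1, 1]. min/max come from the loaded subset only; for a faithful
# comparison they should match the statistics used when training the model.
real_min, real_max = real_data.min(), real_data.max()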
real_normalized = 2 * (real_data - real_min) / (real_max - real_min) - 1
real_tensor = torch.from_numpy(real_normalized[:, np.newaxis, :, :]).float().to(device)

# Compute real data energy
energies_real = compute_energy(real_tensor)
print(f"Real data energy: mean={energies_real.mean().item():.1f}, std={energies_real.std().item():.1f}")
print(f"Real data range: [{real_normalized.min():.2f}, {real_normalized.max():.2f}]")
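The cell above maps the data range onto [-1, 1] with `y = 2 * (x - min) / (max - min) - 1`. Since `real_min` and `real_max` are computed from the loaded subset, they may differ from the statistics used at training time.

The sketch below keeps the statistics as explicit arguments, so training-time values can be passed in if they are available. It also provides the inverse map for converting generated samples back to physical units. The function names are our own.

```python
import numpy as np

def minmax_normalize(x, lo, hi):
    """Map the interval [lo, hi] linearly onto [-1, 1]."""
    return 2.0 * (x - lo) / (hi - lo) - 1.0

def minmax_denormalize(y, lo, hi):
    """Inverse of minmax_normalize: map [-1, 1] back onto [lo, hi]."""
    return (y + 1.0) * (hi - lo) / 2.0 + lo

# Hypothetical usage: generated = minmax_denormalize(samples_base.cpu().numpy(), real_min, real_max)
```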

## 7. Generate Samples

In [None]:
num_steps = 100
seed = 0  # same initial noise for every sampler, so differences come from guidance alone

# Base sampler
print("Generating samples with BASE flow matching...")
base_sampler = BaseFlowMatchingSampler(model, device)
torch.manual_seed(seed)
samples_base = base_sampler.sample(batch_size=num_samples, num_steps=num_steps)
energies_base = compute_energy(samples_base)
print(f"Base: energy mean={energies_base.mean().item():.1f}, range=[{samples_base.min().item():.2f}, {samples_base.max().item():.2f}]")

# Energy-guided sampler with different scales
print("\nGenerating samples with ENERGY-GUIDED flow matching...")
guidance_results = {}

for gs in [0.01, 0.05, 0.1, 0.2]:
    energy_sampler = EnergyGuidedFlowMatchingSampler(model, device, guidance_scale=gs)
    torch.manual_seed(seed)
    samples = energy_sampler.sample(batch_size=num_samples, num_steps=num_steps)
    energies = compute_energy(samples)
    guidance_results[gs] = {'samples': samples, 'energies': energies}
    print(f"Guidance {gs}: energy mean={energies.mean().item():.1f}, range=[{samples.min().item():.2f}, {samples.max().item():.2f}]")

## 8. Compare Results

In [None]:
print("="*60)
print("ENERGY COMPARISON")
print("="*60)

real_mean = energies_real.mean().item()
print(f"\n{'Source':<25} {'Mean Energy':>15} {'Std':>10} {'vs Real':>12}")
print("-"*65)
print(f"{'Real Data':<25} {real_mean:>15.1f} {energies_real.std().item():>10.1f} {'(reference)':>12}")
print(f"{'Base Flow':<25} {energies_base.mean().item():>15.1f} {energies_base.std().item():>10.1f} {abs(energies_base.mean().item() - real_mean):>12.1f}")

for gs, results in guidance_results.items():
    e_mean = results['energies'].mean().item()
    e_std = results['energies'].std().item()
    diff = abs(e_mean - real_mean)
    print(f"{'Energy-Guided ' + str(gs):<25} {e_mean:>15.1f} {e_std:>10.1f} {diff:>12.1f}")

# Find best
base_diff = abs(energies_base.mean().item() - real_mean)
best_gs = min(guidance_results.keys(), key=lambda gs: abs(guidance_results[gs]['energies'].mean().item() - real_mean))
best_diff = abs(guidance_results[best_gs]['energies'].mean().item() - real_mean)

print(f"\nBest guided scale: {best_gs}")
print(f"Improvement over base: {base_diff:.1f} -> {best_diff:.1f} ({(1 - best_diff/base_diff)*100:.1f}% closer to real)")
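With only eight samples per sampler, the gap `|mean(E_generated) - mean(E_real)|` is noisy.

The sketch below is a percentile bootstrap that we added; it is not part of the original analysis. It resamples both energy sets to produce an interval for the gap. In this notebook the inputs would be `.cpu().numpy()` of the energy tensors.

```python
import numpy as np

def bootstrap_energy_gap(e_gen, e_real, num_boot=2000, alpha=0.05, seed=0):
    """Percentile bootstrap interval for |mean(e_gen) - mean(e_real)|."""
    rng = np.random.default_rng(seed)
    e_gen = np.asarray(e_gen, dtype=float)
    e_real = np.asarray(e_real, dtype=float)
    gaps = np.empty(num_boot)
    for b in range(num_boot):
        # Resample each set with replacement, keeping its original size
        g = rng.choice(e_gen, size=e_gen.size, replace=True)
        r = rng.choice(e_real, size=e_real.size, replace=True)
        gaps[b] = abs(g.mean() - r.mean())
    point = abs(e_gen.mean() - e_real.mean())
    lo, hi = np.quantile(gaps, [alpha / 2, 1 - alpha / 2])
    return point, (lo, hi)

# Hypothetical usage:
# point, (lo, hi) = bootstrap_energy_gap(energies_base.cpu().numpy(),
#                                         energies_real.cpu().numpy())
```

If the intervals for the base and guided samplers overlap heavily, the observed improvement may be within sampling noise.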

## 9. Visualize Samples

In [None]:
# Use best guidance scale for visualization
samples_energy = guidance_results[best_gs]['samples']
energies_energy = guidance_results[best_gs]['energies']

fig, axes = plt.subplots(3, num_samples, figsize=(2.5*num_samples, 7))

for i in range(num_samples):
    # Real
    axes[0, i].imshow(real_normalized[i], cmap='viridis')
    axes[0, i].set_title(f'E={energies_real[i].item():.0f}', fontsize=9)
    
    # Base
    axes[1, i].imshow(samples_base[i, 0].cpu().numpy(), cmap='viridis')
    axes[1, i].set_title(f'E={energies_base[i].item():.0f}', fontsize=9)
    
    # Energy-guided
    axes[2, i].imshow(samples_energy[i, 0].cpu().numpy(), cmap='viridis')
    axes[2, i].set_title(f'E={energies_energy[i].item():.0f}', fontsize=9)

# Hide ticks but keep the axes on: axis('off') would also hide the row labels set below
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
axes[0, 0].set_ylabel('Real Data', fontsize=12)
axes[1, 0].set_ylabel('Base Flow', fontsize=12)
axes[2, 0].set_ylabel(f'Energy-Guided\n(scale={best_gs})', fontsize=12)

plt.suptitle(f'Flow Matching: Base vs Energy-Guided Sampling\nReal Energy: {real_mean:.0f}', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/flow_matching_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nSaved to /content/drive/MyDrive/flow_matching_comparison.png")

## 10. Guidance Scale Sweep

In [None]:
# More comprehensive sweep
guidance_scales = [0, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]
sweep_results = []

print("Running guidance scale sweep...")
for gs in guidance_scales:
    if gs == 0:
        sampler = BaseFlowMatchingSampler(model, device)
    else:
        sampler = EnergyGuidedFlowMatchingSampler(model, device, guidance_scale=gs)
    
    torch.manual_seed(0)  # same initial noise for every scale
    samples = sampler.sample(batch_size=16, num_steps=100)
    energies = compute_energy(samples)
    
    sweep_results.append({
        'guidance': gs,
        'energy_mean': energies.mean().item(),
        'energy_std': energies.std().item(),
        'range_min': samples.min().item(),
        'range_max': samples.max().item(),
    })
    print(f"  Scale {gs:.3f}: Energy mean={energies.mean().item():.1f}")

print("Done!")

In [None]:
# Plot sweep results
fig, ax = plt.subplots(figsize=(10, 6))

gs_vals = [r['guidance'] for r in sweep_results]
energy_means = [r['energy_mean'] for r in sweep_results]
energy_stds = [r['energy_std'] for r in sweep_results]

ax.errorbar(gs_vals, energy_means, yerr=energy_stds, marker='o', capsize=5, 
            linewidth=2, markersize=8, color='blue', label='Generated')
ax.axhline(real_mean, color='green', linestyle='--', linewidth=2, 
           label=f'Real Data ({real_mean:.0f})')
ax.axhspan(real_mean - energies_real.std().item(), 
           real_mean + energies_real.std().item(), 
           alpha=0.2, color='green', label='Real ± 1σ')

ax.set_xlabel('Guidance Scale', fontsize=12)
ax.set_ylabel('Mean Energy', fontsize=12)
ax.set_title('Effect of Energy Guidance on Generated Sample Energy', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('/content/drive/MyDrive/guidance_scale_sweep.png', dpi=150)
plt.show()
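The sweep plot is read by eye. The sketch below encodes one possible selection rule, which is our assumption and mirrors the "start small" advice in the conclusions. It picks the smallest guidance scale whose mean energy lies inside the real-data ±1σ band. If no scale does, it falls back to the scale whose mean energy is closest to the real mean. It consumes dictionaries shaped like `sweep_results`.

```python
def pick_guidance_scale(sweep_results, real_mean, real_std):
    """Smallest scale with mean energy inside real_mean ± real_std,
    otherwise the scale whose mean energy is closest to real_mean."""
    in_band = [r for r in sweep_results
               if abs(r['energy_mean'] - real_mean) <= real_std]
    if in_band:
        return min(in_band, key=lambda r: r['guidance'])['guidance']
    return min(sweep_results,
               key=lambda r: abs(r['energy_mean'] - real_mean))['guidance']

# Hypothetical usage:
# chosen = pick_guidance_scale(sweep_results, real_mean, energies_real.std().item())
```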

## 11. Conclusions

### Key Findings:

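1. **Base flow matching** produces samples whose EQM energy may drift from that of the real data

2. **Energy guidance** can pull the sample energy toward the real-data value; here, EQM energy serves as a proxy for physical plausibility, not a direct PDE-residual check

3. **The optimal guidance scale** depends on the specific model and data; use the sweep in Section 10 to choose it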

### Why Energy Guidance Works Here (but not for FNO):

| Aspect | FNO (Supervised) | Flow Matching (Generative) |
|--------|------------------|---------------------------|
| Ground truth at inference | No, but the input-to-output map is trained directly with MSE | **No, and iterative sampling can drift** |
| Energy guidance useful | No (redundant with the MSE objective) | **Yes (guides generation)** |

### Recommended Use:
- Start with guidance_scale = 0.05
- Increase if samples are unphysical
- Decrease if samples look over-smoothed