# FNO Training with Energy Regularization

Train a Fourier Neural Operator (FNO) on Darcy Flow with energy-based regularization.

**Loss Function**: `0.8 × MSE + 0.2 × Energy`

The energy regularization uses your trained EQM model to keep predictions physically plausible and in-distribution.
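
The two terms are combined as a fixed weighted sum. Below is a minimal sketch of that arithmetic. It assumes each term has already been reduced to a scalar. The helper `combined_loss` is hypothetical, since `train_fno_with_energy.py` operates on tensors.

```python
def combined_loss(mse: float, energy_term: float,
                  mse_weight: float = 0.8, energy_weight: float = 0.2) -> float:
    """Weighted sum of the data-fit term and the energy regularizer (hypothetical helper)."""
    return mse_weight * mse + energy_weight * energy_term


# Example: a batch with MSE = 0.01 and an energy penalty of 0.5
loss = combined_loss(0.01, 0.5)  # ≈ 0.108 = 0.8 * 0.01 + 0.2 * 0.5
# energy_weight=0.0 leaves only the (scaled) MSE term, i.e. the MSE-only baseline
baseline = combined_loss(0.01, 0.5, energy_weight=0.0)  # ≈ 0.008
```

The weights multiply raw values. The 0.8/0.2 split therefore reflects the intended balance only if both terms are on comparable scales.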

## 1. Setup

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone repositories
!git clone https://github.com/MehdiMHeydari/EQM-Training.git
!git clone https://github.com/thuml/Neural-Solver-Library.git

# Move Neural-Solver-Library to expected location
!mv Neural-Solver-Library /content/EQM-Training/Neural-Solver-Library-main

In [None]:
# Install dependencies
!pip install -q torch torchvision einops omegaconf h5py matplotlib scipy tqdm timm

In [None]:
# Verify paths
import os

# EQM checkpoint (UPDATE THIS PATH IF NEEDED)
EQM_CHECKPOINT = "/content/drive/MyDrive/EQM_Checkpoints5/checkpoint_90.pth"

# Config and data paths
EQM_CONFIG = "/content/EQM-Training/configs/darcy_flow_eqm.yaml"
DATA_PATH = "/content/EQM-Training/data/2D_DarcyFlow_beta1.0_Train.hdf5"

# Check if files exist
print("Checking paths...")
print(f"EQM Checkpoint: {'✓' if os.path.exists(EQM_CHECKPOINT) else '✗ NOT FOUND'} {EQM_CHECKPOINT}")
print(f"EQM Config: {'✓' if os.path.exists(EQM_CONFIG) else '✗ NOT FOUND'} {EQM_CONFIG}")
print(f"Data: {'✓' if os.path.exists(DATA_PATH) else '✗ NOT FOUND'} {DATA_PATH}")
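
The check above only prints a status, so a missing file can go unnoticed until the training cell fails. The sketch below collects missing paths and raises before a long run starts. Both helpers are hypothetical and not part of the repository.

```python
import os
import tempfile


def missing_paths(paths: dict) -> list:
    """Return the labels whose path does not exist on disk."""
    return [label for label, path in paths.items() if not os.path.exists(path)]


def require_paths(paths: dict) -> None:
    """Raise FileNotFoundError listing every missing path (hypothetical helper)."""
    missing = missing_paths(paths)
    if missing:
        details = ", ".join(f"{label}: {paths[label]}" for label in missing)
        raise FileNotFoundError(f"Missing required files -> {details}")


# Self-contained demo: one file that exists, one that does not.
with tempfile.TemporaryDirectory() as tmp:
    config = os.path.join(tmp, "darcy_flow_eqm.yaml")
    open(config, "w").close()
    print(missing_paths({"EQM Config": config,
                         "EQM Checkpoint": os.path.join(tmp, "checkpoint_90.pth")}))
    # ['EQM Checkpoint']

# In the notebook:
# require_paths({"EQM Checkpoint": EQM_CHECKPOINT, "EQM Config": EQM_CONFIG, "Data": DATA_PATH})
```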

## 2. Training Configuration

Adjust these parameters as needed:

In [None]:
# Training configuration
CONFIG = {
    # Data
    "train_samples": 800,
    "val_samples": 200,
    "batch_size": 4,
    
    # Training
    "epochs": 100,
    "lr": 1e-3,
    
    # Loss weights
    "mse_weight": 0.8,
    "energy_weight": 0.2,
    "energy_loss_mode": "relative",  # Options: 'relative', 'threshold', 'normalized'
    "energy_temperature": 1.0,
    
    # Saving
    "checkpoint_save_path": "/content/drive/MyDrive/fno_with_energy.pth",
    "output_plot": "/content/drive/MyDrive/fno_training_curves.png",
}

print("Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")
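
A typo in `CONFIG` only surfaces once the script starts, or worse, it silently changes behavior. Below is a sketch of a quick pre-flight check. The allowed modes come from the comment on `energy_loss_mode`. Requiring a positive `energy_temperature` is an assumption about how the script uses it, and `validate_config` is a hypothetical helper.

```python
VALID_ENERGY_MODES = {"relative", "threshold", "normalized"}


def validate_config(cfg: dict) -> None:
    """Fail fast on settings that would waste a long training run (hypothetical helper)."""
    mode = cfg["energy_loss_mode"]
    if mode not in VALID_ENERGY_MODES:
        raise ValueError(f"energy_loss_mode={mode!r}; expected one of {sorted(VALID_ENERGY_MODES)}")
    if cfg["mse_weight"] < 0 or cfg["energy_weight"] < 0:
        raise ValueError("mse_weight and energy_weight must be non-negative")
    if cfg["mse_weight"] + cfg["energy_weight"] == 0:
        raise ValueError("at least one loss weight must be positive")
    # Assumption: the temperature divides the energy term, so it must be > 0.
    if cfg["energy_temperature"] <= 0:
        raise ValueError("energy_temperature must be positive")
    if min(cfg["train_samples"], cfg["val_samples"], cfg["batch_size"], cfg["epochs"]) < 1:
        raise ValueError("sample counts, batch_size and epochs must be >= 1")


example = {"train_samples": 800, "val_samples": 200, "batch_size": 4, "epochs": 100,
           "mse_weight": 0.8, "energy_weight": 0.2,
           "energy_loss_mode": "relative", "energy_temperature": 1.0}
validate_config(example)  # passes silently
# In the notebook: validate_config(CONFIG)
```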

## 3. Run Training

In [None]:
%cd /content/EQM-Training

!python train_fno_with_energy.py \
    --data_path {DATA_PATH} \
    --eqm_checkpoint {EQM_CHECKPOINT} \
    --eqm_config {EQM_CONFIG} \
    --train_samples {CONFIG["train_samples"]} \
    --val_samples {CONFIG["val_samples"]} \
    --batch_size {CONFIG["batch_size"]} \
    --epochs {CONFIG["epochs"]} \
    --lr {CONFIG["lr"]} \
    --mse_weight {CONFIG["mse_weight"]} \
    --energy_weight {CONFIG["energy_weight"]} \
    --energy_loss_mode {CONFIG["energy_loss_mode"]} \
    --energy_temperature {CONFIG["energy_temperature"]} \
    --checkpoint_save_path {CONFIG["checkpoint_save_path"]} \
    --output_plot {CONFIG["output_plot"]} \
    --save_every 25 \
    --device cuda
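
IPython expands the `{...}` placeholders in the `!python` cell into a single shell string. A path containing spaces is therefore split into several arguments. The sketch below builds the same argument list in Python and runs it with `subprocess.run`, which passes each item to the script unchanged. The helper name is hypothetical. The flags are exactly those used above, and the demo paths are illustrative.

```python
import shlex
import sys

# Flags whose values come straight from CONFIG in the command above.
CONFIG_FLAGS = ["train_samples", "val_samples", "batch_size", "epochs", "lr",
                "mse_weight", "energy_weight", "energy_loss_mode",
                "energy_temperature", "checkpoint_save_path", "output_plot"]


def build_train_command(cfg, data_path, eqm_checkpoint, eqm_config,
                        save_every=25, device="cuda"):
    """Argument list equivalent to the !python cell (hypothetical helper)."""
    cmd = [sys.executable, "train_fno_with_energy.py",
           "--data_path", str(data_path),
           "--eqm_checkpoint", str(eqm_checkpoint),
           "--eqm_config", str(eqm_config)]
    for key in CONFIG_FLAGS:
        cmd += [f"--{key}", str(cfg[key])]
    cmd += ["--save_every", str(save_every), "--device", device]
    return cmd


# Illustrative config with a space in the output paths.
demo_cfg = {k: 1 for k in CONFIG_FLAGS}
demo_cfg.update(lr=1e-3, energy_loss_mode="relative",
                checkpoint_save_path="/content/drive/My Drive/fno.pth",
                output_plot="/content/drive/My Drive/curves.png")
print(shlex.join(build_train_command(demo_cfg, "data.hdf5", "eqm.pth", "eqm.yaml")))

# In the notebook, from /content/EQM-Training:
# import subprocess
# subprocess.run(build_train_command(CONFIG, DATA_PATH, EQM_CHECKPOINT, EQM_CONFIG), check=True)
```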

## 4. View Training Curves

In [None]:
from IPython.display import Image, display

if os.path.exists(CONFIG["output_plot"]):
    display(Image(filename=CONFIG["output_plot"]))
else:
    print("Training curves not found. Training may still be in progress.")

## 5. Test Trained Model

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import h5py
import sys
sys.path.insert(0, '/content/EQM-Training')
sys.path.insert(0, '/content/EQM-Training/Neural-Solver-Library-main')

from models.FNO import Model as FNO
from energy_regularization import EnergyRegularizationLoss

# Load trained FNO model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

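# weights_only=False (matching load_fno_model below): PyTorch >= 2.6 defaults to weights_only=True,
# which can reject non-tensor metadata such as NumPy values. Only load checkpoints you trust.
checkpoint = torch.load(CONFIG["checkpoint_save_path"].replace('.pth', '_best.pth'), map_location=device, weights_only=False)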
print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
print(f"Validation loss: {checkpoint['val_loss']:.4f}")
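
`str.replace('.pth', '_best.pth')` rewrites every occurrence of `.pth` in the path, including any in directory names. The sketch below uses `pathlib` so that only the file name changes. It assumes the script names the best checkpoint by appending `_best` to the file stem, as the paths in Section 7 suggest. The helper name is hypothetical.

```python
from pathlib import Path


def best_checkpoint_path(save_path: str) -> str:
    """'/dir/fno_with_energy.pth' -> '/dir/fno_with_energy_best.pth' (hypothetical helper)."""
    p = Path(save_path)
    return str(p.with_name(f"{p.stem}_best{p.suffix}"))


path = "/content/drive/MyDrive/runs.pth/fno_with_energy.pth"
print(path.replace(".pth", "_best.pth"))  # /content/drive/MyDrive/runs_best.pth/fno_with_energy_best.pth
print(best_checkpoint_path(path))         # /content/drive/MyDrive/runs.pth/fno_with_energy_best.pth
```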

In [None]:
# Load a few test samples and normalize them (model predictions are evaluated in Section 7)
with h5py.File(DATA_PATH, 'r') as f:
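    # Indices 1000+ lie outside the first 800 train + 200 val samples (assuming a sequential split)
    test_data = np.array(f['tensor'][1000:1005]).astype(np.float32)  # 5 held-out samples
    # Drop a singleton channel axis if present (e.g. (N, 1, H, W) -> (N, H, W)) so imshow gets 2-D arrays
    test_data = test_data.reshape(len(test_data), *test_data.shape[-2:])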

# Normalize
norm_stats = checkpoint['normalization_stats']
test_normalized = 2 * (test_data - norm_stats['output_min']) / (norm_stats['output_max'] - norm_stats['output_min']) - 1

print(f"Test samples shape: {test_data.shape}")
print(f"Normalized range: [{test_normalized.min():.2f}, {test_normalized.max():.2f}]")
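
The cell above min-max scales u to [-1, 1] using `output_min`/`output_max` from the checkpoint. FNO predictions live in the same normalized space, so reporting errors in physical units requires the inverse map. Below is a sketch of both directions. The function names are illustrative, and the statistics are assumed to be scalars. If those statistics were computed on the training split only, held-out samples can land slightly outside [-1, 1]. The printed range helps to spot this.

```python
import numpy as np


def to_unit_range(x, x_min, x_max):
    """Min-max scale from [x_min, x_max] to [-1, 1] (same formula as the cell above)."""
    return 2.0 * (x - x_min) / (x_max - x_min) - 1.0


def from_unit_range(y, x_min, x_max):
    """Inverse map: bring normalized predictions back to physical units."""
    return (y + 1.0) / 2.0 * (x_max - x_min) + x_min


u = np.array([0.0, 0.25, 0.5])
y = to_unit_range(u, 0.0, 0.5)                   # [-1.  0.  1.]
assert np.allclose(from_unit_range(y, 0.0, 0.5), u)

# A value above the training max maps above 1
print(to_unit_range(np.array([0.6]), 0.0, 0.5))  # > 1.0
```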

In [None]:
# Visualize the test samples: raw ground truth (top row) vs. normalized (bottom row)
fig, axes = plt.subplots(2, 5, figsize=(20, 8))

for i in range(5):
    # Ground truth
    axes[0, i].imshow(test_data[i], cmap='viridis')
    axes[0, i].set_title(f'Ground Truth {i+1}')
    axes[0, i].axis('off')
    
    # Show normalized version
    axes[1, i].imshow(test_normalized[i], cmap='viridis')
    axes[1, i].set_title(f'Normalized {i+1}')
    axes[1, i].axis('off')

plt.suptitle('Test Samples', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 6. Compare Energy: Test Ground Truth vs Training Data

In [None]:
# Initialize energy model
energy_loss = EnergyRegularizationLoss(
    checkpoint_path=EQM_CHECKPOINT,
    config_path=EQM_CONFIG,
    device=device,
    training_data_path=DATA_PATH,
    num_calibration_samples=100,
    loss_mode='relative'
)

# Compute energy for test samples
test_tensor = torch.from_numpy(test_normalized[:, np.newaxis, :, :]).float().to(device)
energy_stats = energy_loss.get_energy_stats(test_tensor)

print(f"\nTest Sample Energy Statistics:")
print(f"  Mean: {energy_stats['mean']:.2f}")
print(f"  Std:  {energy_stats['std']:.2f}")
print(f"  Min:  {energy_stats['min']:.2f}")
print(f"  Max:  {energy_stats['max']:.2f}")
print(f"\nTraining Energy Statistics (reference):")
print(f"  Mean: {energy_loss.energy_mean:.2f}")
print(f"  Std:  {energy_loss.energy_std:.2f}")

---

## Notes

**Energy Loss Modes:**
- `relative`: Penalizes deviation from training mean (recommended)
- `threshold`: Only penalizes outliers beyond 2σ
- `normalized`: Scales energy to [0, 1] range
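
This notebook does not reproduce the script's exact formulas. The sketch below is one plausible reading of the three modes, under several assumptions. μ and σ are taken to be the calibration statistics the notebook reads from `energy_loss.energy_mean` and `energy_loss.energy_std`. `threshold` is assumed to use k = 2. `energy_temperature` is assumed to divide the standardized deviation, so a lower value gives a sharper penalty. `normalized` is assumed to use a sigmoid. The function name is hypothetical.

```python
import numpy as np


def energy_penalty(energy, mu, sigma, mode="relative", temperature=1.0, k=2.0):
    """One possible reading of the energy_loss_mode options (assumed, not the script's code).

    energy:    per-sample EQM energies of a batch of predictions
    mu, sigma: mean and std of EQM energies on calibration (training) samples
    """
    z = (np.asarray(energy, dtype=float) - mu) / sigma  # deviation in units of sigma
    if mode == "relative":
        # Any deviation from the training mean is penalized quadratically.
        return float(np.mean((z / temperature) ** 2))
    if mode == "threshold":
        # Zero inside mu ± k*sigma; quadratic beyond it.
        excess = np.maximum(np.abs(z) - k, 0.0)
        return float(np.mean((excess / temperature) ** 2))
    if mode == "normalized":
        # Sigmoid squashes the energy into (0, 1); higher energy -> larger penalty.
        return float(np.mean(1.0 / (1.0 + np.exp(-z / temperature))))
    raise ValueError(f"unknown mode: {mode!r}")


energies = np.array([100.0, 103.0, 110.0])  # toy values with mu = 100, sigma = 2
for m in ("relative", "threshold", "normalized"):
    print(m, energy_penalty(energies, mu=100.0, sigma=2.0, mode=m))
```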

**Tuning Tips:**
- If MSE isn't improving: decrease `energy_weight` to 0.1
- If predictions look unphysical: increase `energy_weight` to 0.3
- Lower `energy_temperature` = sharper penalty for out-of-distribution (OOD) predictions

## 7. Compare Models: MSE+Energy vs MSE-Only

Run both trainings first, then compare the results:
1. **MSE+Energy Model**: Trained with `energy_weight=0.2`
2. **MSE-Only Model**: Trained with `energy_weight=0.0`
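
The baseline needs `energy_weight` set to 0.0. Its output paths must also differ, or the second run overwrites the first. Below is a sketch that derives the baseline settings without modifying `CONFIG`. It assumes the script saves its best checkpoint as `<name>_best.pth`, which is consistent with `MSE_ONLY_CHECKPOINT` below. `fno_mse_only_curves.png` and the helper name are hypothetical.

```python
def mse_only_config(cfg: dict, checkpoint_save_path: str, output_plot: str) -> dict:
    """Copy of cfg for the MSE-only baseline (hypothetical helper; cfg is not modified)."""
    baseline = dict(cfg)                      # shallow copy
    baseline["energy_weight"] = 0.0           # switch the energy term off
    baseline["checkpoint_save_path"] = checkpoint_save_path
    baseline["output_plot"] = output_plot
    return baseline


example_cfg = {"mse_weight": 0.8, "energy_weight": 0.2,
               "checkpoint_save_path": "/content/drive/MyDrive/fno_with_energy.pth",
               "output_plot": "/content/drive/MyDrive/fno_training_curves.png"}
baseline_cfg = mse_only_config(example_cfg,
                               "/content/drive/MyDrive/fno_mse_only.pth",
                               "/content/drive/MyDrive/fno_mse_only_curves.png")
print(baseline_cfg["energy_weight"], example_cfg["energy_weight"])  # 0.0 0.2
```

To train the baseline, re-run the Section 3 cell with the returned dict in place of `CONFIG`. `mse_weight` is left at 0.8, so the baseline minimizes 0.8 × MSE.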

In [None]:
# Configuration for comparison
# UPDATE THESE PATHS to your trained model checkpoints

MSE_ENERGY_CHECKPOINT = "/content/drive/MyDrive/fno_with_energy_best.pth"  # MSE+Energy model
MSE_ONLY_CHECKPOINT = "/content/drive/MyDrive/fno_mse_only_best.pth"       # MSE-only model

# Verify both checkpoints exist
print("Checking model checkpoints...")
print(f"MSE+Energy: {'✓' if os.path.exists(MSE_ENERGY_CHECKPOINT) else '✗ NOT FOUND'}")
print(f"MSE-Only:   {'✓' if os.path.exists(MSE_ONLY_CHECKPOINT) else '✗ NOT FOUND'}")

In [None]:
# Load both models
from models.FNO import Model as FNO

def load_fno_model(checkpoint_path, device):
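    """Load a trained FNO model from checkpoint.

    The constructor arguments below must match those used by train_fno_with_energy.py;
    otherwise load_state_dict fails with missing/unexpected keys or shape mismatches.
    """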
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    
    model = FNO(
        img_size=(128, 128),
        patch_size=1,
        in_channels=1,
        out_channels=1,
        embed_dim=256,
        depth=12,
        modes=32,
        num_blocks=8
    ).to(device)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model, checkpoint

# Load models
print("Loading MSE+Energy model...")
model_energy, ckpt_energy = load_fno_model(MSE_ENERGY_CHECKPOINT, device)
print(f"  Best epoch: {ckpt_energy['epoch']}, Val loss: {ckpt_energy['val_loss']:.4f}")

print("\nLoading MSE-only model...")
model_mse, ckpt_mse = load_fno_model(MSE_ONLY_CHECKPOINT, device)
print(f"  Best epoch: {ckpt_mse['epoch']}, Val loss: {ckpt_mse['val_loss']:.4f}")

In [None]:
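# Load held-out test data: samples 1000+ lie outside the 800 train + 200 val samples,
# assuming the training script takes the first samples in file order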
NUM_TEST_SAMPLES = 100

with h5py.File(DATA_PATH, 'r') as f:
    total_samples = f['tensor'].shape[0]
    start_idx = min(1000, total_samples - NUM_TEST_SAMPLES)
    
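    # Load outputs u ('tensor') and inputs a ('nu', the coefficient field).
    # In PDEBench-style Darcy files, 'x-coordinate' is only the 1-D grid, not per-sample inputs.
    test_outputs = np.array(f['tensor'][start_idx:start_idx + NUM_TEST_SAMPLES]).astype(np.float32)
    test_inputs = np.array(f['nu'][start_idx:start_idx + NUM_TEST_SAMPLES]).astype(np.float32)
    # Drop a singleton channel axis if present, so every sample is (H, W)
    test_outputs = test_outputs.reshape(len(test_outputs), *test_outputs.shape[-2:])
    test_inputs = test_inputs.reshape(len(test_inputs), *test_inputs.shape[-2:])

print(f"Loaded {NUM_TEST_SAMPLES} test samples (indices {start_idx} to {start_idx + NUM_TEST_SAMPLES - 1})")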

# Normalize using saved stats
norm_stats = ckpt_energy['normalization_stats']

inputs_norm = 2 * (test_inputs - norm_stats['input_min']) / (norm_stats['input_max'] - norm_stats['input_min']) - 1
outputs_norm = 2 * (test_outputs - norm_stats['output_min']) / (norm_stats['output_max'] - norm_stats['output_min']) - 1

# Convert to tensors
inputs_tensor = torch.from_numpy(inputs_norm[:, np.newaxis, :, :]).float().to(device)
targets_tensor = torch.from_numpy(outputs_norm[:, np.newaxis, :, :]).float().to(device)

print(f"Input shape: {inputs_tensor.shape}")
print(f"Target shape: {targets_tensor.shape}")
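
`start_idx = min(1000, total_samples - NUM_TEST_SAMPLES)` quietly slides the window back into the first 1000 samples when the file holds fewer than 1000 + `NUM_TEST_SAMPLES` samples. Below is a sketch of the same rule that warns about the overlap. The helper is hypothetical. The 1000 comes from `train_samples + val_samples` in Section 2 and assumes the script uses the first samples in file order.

```python
import warnings


def held_out_range(total_samples: int, num_test: int, num_used: int = 1000):
    """Return (start, stop) for evaluation samples, preferring indices >= num_used."""
    if num_test > total_samples:
        raise ValueError("num_test exceeds the number of samples in the file")
    start = min(num_used, total_samples - num_test)
    if start < num_used:
        warnings.warn(f"indices {start}..{num_used - 1} overlap the training/validation samples")
    return start, start + num_test


print(held_out_range(10_000, 100))  # (1000, 1100)
```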

In [None]:
# Initialize energy computation
energy_loss = EnergyRegularizationLoss(
    checkpoint_path=EQM_CHECKPOINT,
    config_path=EQM_CONFIG,
    device=device,
    training_data_path=DATA_PATH,
    num_calibration_samples=100,
    loss_mode='relative'
)

training_energy_mean = energy_loss.energy_mean
training_energy_std = energy_loss.energy_std

print(f"Training data energy reference:")
print(f"  Mean: {training_energy_mean:.2f}")
print(f"  Std:  {training_energy_std:.2f}")

In [None]:
# Evaluate both models
print("Evaluating models on test set...")

with torch.no_grad():
    # MSE+Energy model predictions
    preds_energy = model_energy(inputs_tensor)
    mse_energy = ((preds_energy - targets_tensor) ** 2).mean(dim=(1, 2, 3))
    energies_energy = energy_loss.compute_energy(preds_energy)
    
    # MSE-only model predictions  
    preds_mse = model_mse(inputs_tensor)
    mse_mse = ((preds_mse - targets_tensor) ** 2).mean(dim=(1, 2, 3))
    energies_mse = energy_loss.compute_energy(preds_mse)
    
    # Ground truth energy
    gt_energies = energy_loss.compute_energy(targets_tensor)

# Convert to numpy
mse_energy_np = mse_energy.cpu().numpy()
mse_mse_np = mse_mse.cpu().numpy()
energies_energy_np = energies_energy.cpu().numpy()
energies_mse_np = energies_mse.cpu().numpy()
gt_energies_np = gt_energies.cpu().numpy()

# Print comparison table
print("\n" + "="*60)
print("COMPARISON RESULTS")
print("="*60)
print(f"\n{'Metric':<25} {'MSE+Energy':>15} {'MSE-Only':>15}")
print("-" * 55)
print(f"{'MSE (mean)':<25} {mse_energy_np.mean():>15.6f} {mse_mse_np.mean():>15.6f}")
print(f"{'MSE (std)':<25} {mse_energy_np.std():>15.6f} {mse_mse_np.std():>15.6f}")
print(f"{'Energy (mean)':<25} {energies_energy_np.mean():>15.2f} {energies_mse_np.mean():>15.2f}")
print(f"{'Energy (std)':<25} {energies_energy_np.std():>15.2f} {energies_mse_np.std():>15.2f}")

# Energy deviation from training mean
dev_energy = abs(energies_energy_np.mean() - training_energy_mean)
dev_mse = abs(energies_mse_np.mean() - training_energy_mean)

print(f"\n{'Energy dev from train μ':<25} {dev_energy:>15.2f} {dev_mse:>15.2f}")
print(f"{'Deviation (in σ units)':<25} {dev_energy/training_energy_std:>15.2f} {dev_mse/training_energy_std:>15.2f}")
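
The deviation above compares only the *mean* energy with the training mean. Predictions that scatter widely on both sides of μ can still score close to 0σ. The sketch below gives a per-sample view that also reports the fraction of predictions inside the ±2σ band shaded in the plots. The helper name is illustrative. If `training_energy_mean` and `training_energy_std` are tensors, convert them with `float(...)` first.

```python
import numpy as np


def energy_band_summary(energies, mu, sigma, k=2.0):
    """Per-sample view of how far predicted energies sit from the training distribution."""
    z = (np.asarray(energies, dtype=float) - mu) / sigma
    return {
        "mean_dev_sigma": float(abs(z.mean())),                  # the quantity in the table above
        "mean_abs_z": float(np.mean(np.abs(z))),                 # average per-sample deviation
        "frac_within_k_sigma": float(np.mean(np.abs(z) <= k)),   # share inside the ±kσ band
    }


# Toy case: energies split evenly at μ ± 3σ. The mean sits exactly on μ,
# yet no sample lies inside the ±2σ band.
print(energy_band_summary([94.0, 106.0], mu=100.0, sigma=2.0))
# {'mean_dev_sigma': 0.0, 'mean_abs_z': 3.0, 'frac_within_k_sigma': 0.0}

# In the notebook:
# energy_band_summary(energies_energy_np, float(training_energy_mean), float(training_energy_std))
# energy_band_summary(energies_mse_np, float(training_energy_mean), float(training_energy_std))
```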

In [None]:
# Visualization: Energy Distribution Comparison
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 1. Energy Distribution
ax1 = axes[0]
bins = np.linspace(
    min(gt_energies_np.min(), energies_energy_np.min(), energies_mse_np.min()),
    max(gt_energies_np.max(), energies_energy_np.max(), energies_mse_np.max()),
    40
)

ax1.hist(gt_energies_np, bins=bins, alpha=0.5, label=f'Ground Truth (μ={gt_energies_np.mean():.0f})', 
         color='green', density=True)
ax1.hist(energies_energy_np, bins=bins, alpha=0.5, 
         label=f'MSE+Energy (μ={energies_energy_np.mean():.0f})', color='blue', density=True)
ax1.hist(energies_mse_np, bins=bins, alpha=0.5, 
         label=f'MSE-Only (μ={energies_mse_np.mean():.0f})', color='red', density=True)

ax1.axvline(training_energy_mean, color='black', linestyle='--', linewidth=2, 
            label=f'Training μ={training_energy_mean:.0f}')
ax1.axvspan(training_energy_mean - 2*training_energy_std, 
            training_energy_mean + 2*training_energy_std, 
            alpha=0.1, color='gray', label='Training ±2σ')

ax1.set_xlabel('Energy', fontsize=12)
ax1.set_ylabel('Density', fontsize=12)
ax1.set_title('Energy Distribution', fontsize=14, fontweight='bold')
ax1.legend(fontsize=9)

# 2. MSE Distribution
ax2 = axes[1]
mse_bins = np.linspace(0, max(mse_energy_np.max(), mse_mse_np.max()), 30)

ax2.hist(mse_energy_np, bins=mse_bins, alpha=0.6, 
         label=f'MSE+Energy (μ={mse_energy_np.mean():.4f})', color='blue')
ax2.hist(mse_mse_np, bins=mse_bins, alpha=0.6, 
         label=f'MSE-Only (μ={mse_mse_np.mean():.4f})', color='red')

ax2.set_xlabel('MSE per Sample', fontsize=12)
ax2.set_ylabel('Count', fontsize=12)
ax2.set_title('MSE Distribution', fontsize=14, fontweight='bold')
ax2.legend(fontsize=10)

# 3. MSE vs Energy Scatter
ax3 = axes[2]
ax3.scatter(mse_energy_np, energies_energy_np, alpha=0.6, label='MSE+Energy', color='blue', s=30)
ax3.scatter(mse_mse_np, energies_mse_np, alpha=0.6, label='MSE-Only', color='red', s=30)

ax3.axhline(training_energy_mean, color='black', linestyle='--', alpha=0.7, label='Training Energy Mean')
ax3.axhspan(training_energy_mean - 2*training_energy_std, 
            training_energy_mean + 2*training_energy_std, alpha=0.1, color='gray')

ax3.set_xlabel('MSE', fontsize=12)
ax3.set_ylabel('Energy', fontsize=12)
ax3.set_title('MSE vs Energy Trade-off', fontsize=14, fontweight='bold')
ax3.legend(fontsize=10)

plt.suptitle('Model Comparison: MSE+Energy vs MSE-Only', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/model_comparison_distributions.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Visualization: Sample Predictions Comparison
preds_energy_np = preds_energy.cpu().numpy()
preds_mse_np = preds_mse.cpu().numpy()

fig, axes = plt.subplots(4, 5, figsize=(20, 16))

sample_indices = [0, 1, 2, 3, 4]  # Show 5 samples

for col, idx in enumerate(sample_indices):
    # Row 0: Input a(x)
    axes[0, col].imshow(inputs_norm[idx], cmap='viridis')
    axes[0, col].set_title(f'Input a(x) #{idx+1}', fontsize=11)
    axes[0, col].axis('off')
    
    # Row 1: Ground Truth u(x)
    axes[1, col].imshow(outputs_norm[idx], cmap='viridis')
    axes[1, col].set_title(f'Ground Truth', fontsize=11)
    axes[1, col].axis('off')
    
    # Row 2: MSE+Energy prediction
    axes[2, col].imshow(preds_energy_np[idx, 0], cmap='viridis')
    mse_e = mse_energy_np[idx]
    energy_e = energies_energy_np[idx]
    axes[2, col].set_title(f'MSE+Energy\nMSE={mse_e:.4f}, E={energy_e:.0f}', fontsize=10)
    axes[2, col].axis('off')
    
    # Row 3: MSE-Only prediction
    axes[3, col].imshow(preds_mse_np[idx, 0], cmap='viridis')
    mse_m = mse_mse_np[idx]
    energy_m = energies_mse_np[idx]
    axes[3, col].set_title(f'MSE-Only\nMSE={mse_m:.4f}, E={energy_m:.0f}', fontsize=10)
    axes[3, col].axis('off')

# Row labels
fig.text(0.02, 0.88, 'Input', fontsize=12, fontweight='bold', rotation=90, va='center')
fig.text(0.02, 0.65, 'Ground\nTruth', fontsize=12, fontweight='bold', rotation=90, va='center')
fig.text(0.02, 0.42, 'MSE+\nEnergy', fontsize=12, fontweight='bold', rotation=90, va='center')
fig.text(0.02, 0.18, 'MSE\nOnly', fontsize=12, fontweight='bold', rotation=90, va='center')

plt.suptitle('Sample Predictions Comparison', fontsize=16, fontweight='bold')
plt.tight_layout(rect=[0.04, 0, 1, 0.96])
plt.savefig('/content/drive/MyDrive/model_comparison_samples.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Print final analysis and interpretation
print("="*70)
print("FINAL ANALYSIS")
print("="*70)

mse_winner = "MSE-Only" if mse_mse_np.mean() < mse_energy_np.mean() else "MSE+Energy"
energy_winner = "MSE+Energy" if dev_energy < dev_mse else "MSE-Only"

print(f"\n✓ MSE Winner: {mse_winner}")
print(f"  MSE+Energy: {mse_energy_np.mean():.6f}")
print(f"  MSE-Only:   {mse_mse_np.mean():.6f}")

print(f"\n✓ Energy (Physical Plausibility) Winner: {energy_winner}")
print(f"  MSE+Energy deviation: {dev_energy:.2f} ({dev_energy/training_energy_std:.2f}σ from training mean)")
print(f"  MSE-Only deviation:   {dev_mse:.2f} ({dev_mse/training_energy_std:.2f}σ from training mean)")

print("\n" + "-"*70)
print("INTERPRETATION")
print("-"*70)

if mse_winner == "MSE-Only" and energy_winner == "MSE+Energy":
    print("""
✓ TRADE-OFF DETECTED (Expected Behavior)

The results show an accuracy vs. plausibility trade-off:

• MSE-Only achieves LOWER reconstruction error
  → It matches the target values more closely

• BUT MSE-Only predictions sit FARTHER from the training energy distribution
  → Their mean energy deviates more from the training mean
  → Under the EQM model, they look less in-distribution

• MSE+Energy has HIGHER reconstruction error
  → But its mean energy stays closer to the training mean
  → Under the EQM model, its predictions look more in-distribution

CONCLUSION: Energy regularization pulls FNO predictions toward the
training energy distribution, at the cost of some reconstruction
accuracy. This is the intended behavior; check the σ deviations above
to judge whether the gap is large enough to matter.
""")
elif mse_winner == "MSE+Energy" and energy_winner == "MSE+Energy":
    print("""
✓ MSE+ENERGY WINS BOTH!

Surprising result: on this test set, energy regularization coincided with
both lower reconstruction error and energies closer to the training mean.

This could mean:
• The energy function captures useful physical structure
• Regularization prevented overfitting
• The energy-based constraint guided optimization to better minima
""")
else:
    print("""
The energy term did not bring predictions closer to the training energy
distribution than the MSE-only baseline. Consider adjusting energy_weight
or energy_temperature, or investigating the energy function calibration.
""")
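
The winner labels compare two means over 100 samples without an uncertainty estimate, so a small gap may just be sampling noise. Both models are evaluated on the same inputs, which makes a paired percentile bootstrap of the per-sample MSE difference a simple check. This is a generic sketch and not part of the original analysis. An interval that excludes 0 suggests the difference is not just resampling noise.

```python
import numpy as np


def paired_bootstrap_ci(a, b, n_boot=10_000, alpha=0.05, seed=0):
    """Mean of (a - b) with a (1 - alpha) percentile-bootstrap interval over paired samples."""
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    rng = np.random.default_rng(seed)
    # Resample sample indices with replacement; each row is one bootstrap replicate.
    idx = rng.integers(0, len(diff), size=(n_boot, len(diff)))
    boot_means = diff[idx].mean(axis=1)
    lo, hi = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return float(diff.mean()), float(lo), float(hi)


# Toy demo: a constant shift of 0.01 gives a mean and interval of about 0.01.
base = np.random.default_rng(1).random(100)
print(paired_bootstrap_ci(base + 0.01, base))

# In the notebook (positive values mean MSE+Energy has higher error):
# mean_diff, lo, hi = paired_bootstrap_ci(mse_energy_np, mse_mse_np)
# print(f"ΔMSE = {mean_diff:.6f}, 95% CI [{lo:.6f}, {hi:.6f}]")
```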