# FNO Training with Energy Regularization

Train a Fourier Neural Operator (FNO) on Darcy Flow with energy-based regularization.

**Loss Function**: `0.8 × MSE + 0.2 × Energy`

The energy term uses your trained EQM model to score each FNO prediction, which keeps predictions physically plausible and in-distribution.
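
As a rough illustration of how the two terms combine, here is a minimal NumPy sketch. The function name `combined_loss` and the `energy_penalty` value are placeholders chosen for this example; the training script works on PyTorch tensors and obtains the energy term from the EQM model.

```python
import numpy as np

def combined_loss(pred, target, energy_penalty, mse_weight=0.8, energy_weight=0.2):
    """Weighted sum of the data-fit term (MSE) and an energy regularizer."""
    mse = float(np.mean((pred - target) ** 2))
    return mse_weight * mse + energy_weight * float(energy_penalty)

pred = np.array([[0.0, 1.0], [2.0, 3.0]])
target = np.array([[0.0, 1.0], [2.0, 5.0]])

# MSE = (0 + 0 + 0 + 4) / 4 = 1.0; the energy penalty of 0.5 is a made-up value.
total = combined_loss(pred, target, energy_penalty=0.5)
print(f"total loss: {total:.3f}")  # 0.8 * 1.0 + 0.2 * 0.5 = 0.900
```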

## 1. Setup

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone repositories
!git clone https://github.com/MehdiMHeydari/EQM-Training.git
!git clone https://github.com/thuml/Neural-Solver-Library.git

# Move Neural-Solver-Library to expected location
!mv Neural-Solver-Library /content/EQM-Training/Neural-Solver-Library-main

In [None]:
# Install dependencies
!pip install -q torch torchvision einops omegaconf h5py matplotlib scipy tqdm timm

In [None]:
# Verify paths
import os

# EQM checkpoint (UPDATE THIS PATH IF NEEDED)
EQM_CHECKPOINT = "/content/drive/MyDrive/EQM_Checkpoints5/checkpoint_90.pth"

# Config and data paths
EQM_CONFIG = "/content/EQM-Training/configs/darcy_flow_eqm.yaml"
DATA_PATH = "/content/EQM-Training/data/2D_DarcyFlow_beta1.0_Train.hdf5"

# Check if files exist
print("Checking paths...")
print(f"EQM Checkpoint: {'✓' if os.path.exists(EQM_CHECKPOINT) else '✗ NOT FOUND'} {EQM_CHECKPOINT}")
print(f"EQM Config: {'✓' if os.path.exists(EQM_CONFIG) else '✗ NOT FOUND'} {EQM_CONFIG}")
print(f"Data: {'✓' if os.path.exists(DATA_PATH) else '✗ NOT FOUND'} {DATA_PATH}")

## 2. Training Configuration

Adjust these parameters as needed:

In [None]:
# Training configuration
CONFIG = {
    # Data
    "train_samples": 800,
    "val_samples": 200,
    "batch_size": 4,
    
    # Training
    "epochs": 100,
    "lr": 1e-3,
    
    # Loss weights
    "mse_weight": 0.8,
    "energy_weight": 0.2,
    "energy_loss_mode": "relative",  # Options: 'relative', 'threshold', 'normalized'
    "energy_temperature": 1.0,
    
    # Saving
    "checkpoint_save_path": "/content/drive/MyDrive/fno_with_energy.pth",
    "output_plot": "/content/drive/MyDrive/fno_training_curves.png",
}

print("Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")
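
A quick sanity check on the configuration can catch mistakes before a long run. The sketch below is hypothetical: `check_config` and `dataset_size` are not part of the repository, the positivity requirement on `energy_temperature` is an assumption, and the size of 1000 is only a placeholder (in the notebook it could be read from `f['tensor'].shape[0]`).

```python
def check_config(cfg, dataset_size):
    """Return a list of problems found in cfg; an empty list means it looks sane."""
    problems = []
    if cfg["train_samples"] + cfg["val_samples"] > dataset_size:
        problems.append("train_samples + val_samples exceeds the dataset size")
    if cfg["mse_weight"] < 0 or cfg["energy_weight"] < 0:
        problems.append("loss weights must be non-negative")
    if cfg["energy_loss_mode"] not in {"relative", "threshold", "normalized"}:
        problems.append(f"unknown energy_loss_mode: {cfg['energy_loss_mode']}")
    if cfg["energy_temperature"] <= 0:  # assumption: temperature acts as a positive scale
        problems.append("energy_temperature must be positive")
    return problems

# Same values as the CONFIG cell above, restricted to the keys being checked.
example_cfg = {
    "train_samples": 800,
    "val_samples": 200,
    "mse_weight": 0.8,
    "energy_weight": 0.2,
    "energy_loss_mode": "relative",
    "energy_temperature": 1.0,
}

problems = check_config(example_cfg, dataset_size=1000)  # placeholder dataset size
print(problems or "config looks OK")
```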

## 3. Run Training

In [None]:
%cd /content/EQM-Training

!python train_fno_with_energy.py \
    --data_path {DATA_PATH} \
    --eqm_checkpoint {EQM_CHECKPOINT} \
    --eqm_config {EQM_CONFIG} \
    --train_samples {CONFIG["train_samples"]} \
    --val_samples {CONFIG["val_samples"]} \
    --batch_size {CONFIG["batch_size"]} \
    --epochs {CONFIG["epochs"]} \
    --lr {CONFIG["lr"]} \
    --mse_weight {CONFIG["mse_weight"]} \
    --energy_weight {CONFIG["energy_weight"]} \
    --energy_loss_mode {CONFIG["energy_loss_mode"]} \
    --energy_temperature {CONFIG["energy_temperature"]} \
    --checkpoint_save_path {CONFIG["checkpoint_save_path"]} \
    --output_plot {CONFIG["output_plot"]} \
    --save_every 25 \
    --device cuda

## 4. View Training Curves

In [None]:
from IPython.display import Image, display

if os.path.exists(CONFIG["output_plot"]):
    display(Image(filename=CONFIG["output_plot"]))
else:
    print("Training curves not found. Training may still be in progress.")

## 5. Test Trained Model

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import h5py
import sys
sys.path.insert(0, '/content/EQM-Training')
sys.path.insert(0, '/content/EQM-Training/Neural-Solver-Library-main')

from models.FNO import Model as FNO
from energy_regularization import EnergyRegularizationLoss

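# Load the best FNO checkpoint (expected at the `_best.pth` variant of the save path)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

best_checkpoint_path = CONFIG["checkpoint_save_path"].replace('.pth', '_best.pth')
# The checkpoint stores metadata (epoch, val_loss, normalization_stats) besides weights.
# weights_only=False is an assumption for PyTorch >= 2.6, where the default became True;
# only use it with checkpoints you trust.
checkpoint = torch.load(best_checkpoint_path, map_location=device, weights_only=False)
print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
print(f"Validation loss: {checkpoint['val_loss']:.4f}")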

In [None]:
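# Load a few samples for inspection and normalize them
# NOTE: indices 900-904 are only unseen if the training script did not use them;
# with 800 train + 200 val samples taken in order, they would fall in the validation split.
with h5py.File(DATA_PATH, 'r') as f:
    test_data = np.array(f['tensor'][900:905]).astype(np.float32)  # 5 samples

# Min-max normalize to [-1, 1] with the statistics stored in the checkpoint
norm_stats = checkpoint['normalization_stats']
out_min, out_max = norm_stats['output_min'], norm_stats['output_max']
test_normalized = 2 * (test_data - out_min) / (out_max - out_min) - 1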

print(f"Test samples shape: {test_data.shape}")
print(f"Normalized range: [{test_normalized.min():.2f}, {test_normalized.max():.2f}]")
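
The normalization above is a min-max map to [-1, 1]. A minimal sketch of the forward map and its inverse follows; the helper names `normalize` and `denormalize` are introduced here for illustration and are not taken from the repository. The inverse is what would be needed to bring model outputs back to physical units.

```python
import numpy as np

def normalize(x, x_min, x_max):
    """Map values from [x_min, x_max] to [-1, 1]."""
    return 2.0 * (x - x_min) / (x_max - x_min) - 1.0

def denormalize(x_norm, x_min, x_max):
    """Inverse of normalize: map values from [-1, 1] back to [x_min, x_max]."""
    return (x_norm + 1.0) / 2.0 * (x_max - x_min) + x_min

x = np.array([0.0, 0.25, 1.0], dtype=np.float32)
x_norm = normalize(x, x_min=0.0, x_max=1.0)      # [-1.0, -0.5, 1.0]
x_back = denormalize(x_norm, x_min=0.0, x_max=1.0)
print(x_norm, x_back)
```

Because the minimum and maximum come from the training run, held-out samples can land slightly outside [-1, 1]; the printed range in the cell above shows whether that happens.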

In [None]:
# Visualize the test samples: raw (top row) vs. normalized (bottom row)
fig, axes = plt.subplots(2, 5, figsize=(20, 8))

for i in range(5):
    # Ground truth
    axes[0, i].imshow(test_data[i], cmap='viridis')
    axes[0, i].set_title(f'Ground Truth {i+1}')
    axes[0, i].axis('off')
    
    # Show normalized version
    axes[1, i].imshow(test_normalized[i], cmap='viridis')
    axes[1, i].set_title(f'Normalized {i+1}')
    axes[1, i].axis('off')

plt.suptitle('Test Samples', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 6. Compare Energy: Test Samples vs Training Reference

In [None]:
# Initialize energy model
energy_loss = EnergyRegularizationLoss(
    checkpoint_path=EQM_CHECKPOINT,
    config_path=EQM_CONFIG,
    device=device,
    training_data_path=DATA_PATH,
    num_calibration_samples=100,
    loss_mode=CONFIG["energy_loss_mode"]  # use the same mode as in training
)

# Compute energy for test samples
test_tensor = torch.from_numpy(test_normalized[:, np.newaxis, :, :]).float().to(device)
energy_stats = energy_loss.get_energy_stats(test_tensor)

print("\nTest Sample Energy Statistics:")
print(f"  Mean: {energy_stats['mean']:.2f}")
print(f"  Std:  {energy_stats['std']:.2f}")
print(f"  Min:  {energy_stats['min']:.2f}")
print(f"  Max:  {energy_stats['max']:.2f}")
print("\nTraining Energy Statistics (reference):")
print(f"  Mean: {energy_loss.energy_mean:.2f}")
print(f"  Std:  {energy_loss.energy_std:.2f}")
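
One simple way to read these numbers is to express the test-sample mean energy as a z-score against the training reference. The sketch below uses made-up values; in the notebook the inputs would be `energy_stats['mean']`, `energy_loss.energy_mean`, and `energy_loss.energy_std`. The helper `energy_z_score` is hypothetical.

```python
def energy_z_score(sample_mean, train_mean, train_std):
    """Number of training standard deviations between sample and training mean energy."""
    return (sample_mean - train_mean) / train_std

# Illustrative numbers only, not results from a real run.
z = energy_z_score(sample_mean=12.0, train_mean=10.0, train_std=4.0)
print(f"z-score: {z:+.2f}")  # (12 - 10) / 4 = +0.50
```

A z-score close to zero suggests the samples sit near the training energy distribution; large magnitudes suggest they are out-of-distribution.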

---

## Notes

**Energy Loss Modes:**
- `relative`: Penalizes deviation from training mean (recommended)
- `threshold`: Only penalizes outliers beyond 2σ
- `normalized`: Scales energy to [0, 1] range
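
To make the three modes concrete, here is one possible reading of the descriptions above, written as a NumPy sketch. The exact formulas live in `energy_regularization.py` and may differ; the function `energy_penalty`, the squared penalties, and the `train_min`/`train_max` arguments are assumptions made for this illustration.

```python
import numpy as np

def energy_penalty(energy, train_mean, train_std, mode="relative",
                   train_min=None, train_max=None):
    """Illustrative penalty for a batch of energies under each mode."""
    z = (energy - train_mean) / train_std
    if mode == "relative":
        # Penalize any deviation from the training mean.
        return float(np.mean(z ** 2))
    if mode == "threshold":
        # Penalize only the part of the deviation beyond 2 standard deviations.
        return float(np.mean(np.maximum(np.abs(z) - 2.0, 0.0) ** 2))
    if mode == "normalized":
        # Rescale energies to [0, 1] using the training range.
        scaled = (energy - train_min) / (train_max - train_min)
        return float(np.mean(np.clip(scaled, 0.0, 1.0)))
    raise ValueError(f"unknown mode: {mode}")

energy = np.array([10.0, 14.0, 22.0])  # made-up energies; z-scores are 0, 1, 3
rel = energy_penalty(energy, 10.0, 4.0, mode="relative")
thr = energy_penalty(energy, 10.0, 4.0, mode="threshold")
nrm = energy_penalty(energy, 10.0, 4.0, mode="normalized", train_min=0.0, train_max=20.0)
print(rel, thr, nrm)
```

In this toy batch only the third sample (3σ from the mean) contributes in `threshold` mode, while all deviations contribute in `relative` mode.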

**Tuning Tips:**
- If MSE isn't improving: decrease `energy_weight` (e.g. from 0.2 to 0.1)
- If predictions look unphysical: increase `energy_weight` (e.g. to 0.3)
- Lower `energy_temperature` = sharper penalty for out-of-distribution (OOD) predictions
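
The effect of the temperature can be pictured with a small sketch. It assumes the temperature divides the standardized energy deviation before squaring, which is one common convention and not necessarily what the repository does; `tempered_penalty` is a hypothetical helper.

```python
def tempered_penalty(energy, train_mean, train_std, temperature=1.0):
    """Squared energy deviation, scaled by an assumed temperature."""
    z = (energy - train_mean) / (train_std * temperature)
    return z ** 2

# Same made-up energy, two temperatures.
p_default = tempered_penalty(14.0, 10.0, 4.0, temperature=1.0)  # z = 1 -> 1.0
p_sharp = tempered_penalty(14.0, 10.0, 4.0, temperature=0.5)    # z = 2 -> 4.0
print(p_default, p_sharp)
```

Under this assumption, halving the temperature quadruples the penalty for the same deviation.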