## FNO Output Learning: State → Derived Quantities

### Architecture Overview
This notebook trains a model to map **state(t) → derived(t)**, where:
- **State**: (density, potential) - shape (2, H, W)
- **Derived**: (gamma_n, gamma_c) - two scalar values per timestep, shape (2,); the model's spatial FNO output (2, H, W) is reduced to these two scalars by a global-average-pooling head

### Combined Inference Pipeline
1. **State Model** (from `fno_test.ipynb`): state(t) → state(t+1) (autoregressive)
2. **Output Model** (this notebook): state(t) → derived(t) (direct mapping)

Given only an initial condition state(0), we can predict the full trajectory AND derived quantities!

In [None]:
%matplotlib inline
import torch
import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as anim
from IPython.display import HTML, display
from neuralop.models import FNO
from neuralop.training import AdamW
from neuralop.utils import count_model_params
import torch.nn as nn
import os
from torch.utils.data import DataLoader

# Local utilities
from utils import (
    StateDerivedDataset, TrajectoryWithDerivedDataset,
    train_epoch_direct, validate_direct,
    rollout_combined, compute_mse_over_time, compute_mae_over_time, print_summary,
    save_model, load_model
)

# Run-wide settings read by the cells below.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# show_animations gates the animation cell at the end of the notebook;
# data_truncation is the leading fraction of each raw file's timesteps
# that the preprocessing cell keeps.
show_animations = True
data_truncation = 0.5

# Raise matplotlib's cap on embedded animation size so the HTML animation fits.
plt.rcParams['animation.embed_limit'] = 500

print(f"\033[1mUsing Device: {device}")
print(f"\033[1mShowing animations: {show_animations}")

Using Device: cuda
Showing animations: True


In [None]:
# Data paths
# Directory holding the raw HDF5 simulation files; the preprocessing cell
# prepends it to each training file name, so the trailing slash matters.
data_dir = "/work/10407/anthony50102/frontera/data/hw2d_sim/t600_d256x256_raw/"

# File names only (no directory component).
train_files = [
    "hw2d_sim_step0.025_end1_pts512_c11_k015_N3_nu5e-8_20250315142044_11702_0.h5",
    "hw2d_sim_step0.025_end1_pts512_c11_k015_N3_nu5e-8_20250315142045_4677_2.h5",
]
test_files = [
    "hw2d_sim_step0.025_end1_pts512_c11_k015_N3_nu5e-8_20250316215751_19984_3.h5",
]

# Path to pre-trained state model (update this after training fno_test.ipynb)
STATE_MODEL_PATH = "state_model_12_16_2025-12:00.pth"  # Update with your actual path

In [None]:
def process_with_derived(density, potential, gamma_n, gamma_c):
    """
    Bundle raw simulation arrays into one state array and one derived array.

    Each pair of inputs is joined along a new axis at position 1, so
    (T, H, W) fields and (T,) scalar series give the shapes listed below.

    Returns:
        state: (T, 2, H, W) - density at index 0, potential at index 1
        derived: (T, 2) - gamma_n at index 0, gamma_c at index 1
    """
    fields = (density, potential)
    scalars = (gamma_n, gamma_c)
    return np.stack(fields, axis=1), np.stack(scalars, axis=1)


# Convert each raw training file into a cached .npz in the working directory,
# reusing any cache file that already exists.
processed_files = []

for file in train_files:
    # Cache name: the file name without its last extension. Because the pieces
    # are joined with an empty string, the remaining dots are dropped as well.
    stem_pieces = file.split(".")[:-1]
    save_name = "output_train_" + "".join(stem_pieces) + ".npz"

    if not os.path.exists(save_name):
        with h5py.File(data_dir + file, 'r') as f:
            # Keep only the leading `data_truncation` fraction of the timesteps.
            end_index = int(f['density'].shape[0] * data_truncation)

            density = f['density'][:end_index]
            potential = f['phi'][:end_index]
            gamma_n = f['gamma_n'][:end_index]
            gamma_c = f['gamma_c'][:end_index]

            state, derived = process_with_derived(density, potential, gamma_n, gamma_c)
            np.savez(save_name, state=state, derived=derived)

        processed_files.append(save_name)
        print(f"Saved: {save_name}")
    else:
        print(f"File already processed: {save_name}")
        processed_files.append(save_name)

print(f"\nProcessed {len(processed_files)} files")

In [None]:
# Create dataloaders using shared Dataset classes
# All three splits are drawn from the first processed file; the `mode`
# argument selects the split inside StateDerivedDataset.
data_path = processed_files[0]

split_loader_args = dict(batch_size=64, num_workers=4)

train_loader = DataLoader(
    StateDerivedDataset(data_path, mode='train'), shuffle=True, **split_loader_args
)
val_loader = DataLoader(
    StateDerivedDataset(data_path, mode='val'), shuffle=False, **split_loader_args
)
test_loader = DataLoader(
    StateDerivedDataset(data_path, mode='test'), shuffle=False, **split_loader_args
)

# Trajectory loader for rollout training
traj_loader = DataLoader(
    TrajectoryWithDerivedDataset(".", mode='train'),
    batch_size=1,
    shuffle=True,
)

# Sanity check: report the tensor shapes of the first training batch only.
for batch in train_loader:
    state_shape = batch['state'].shape
    derived_shape = batch['derived'].shape
    print(f"State shape: {state_shape}")
    print(f"Derived shape: {derived_shape}")
    break

print(f"\nTrain: {len(train_loader.dataset)}, Val: {len(val_loader.dataset)}, Test: {len(test_loader.dataset)}")

torch.Size([1, 8000, 4, 256, 256])


In [None]:
# ============== Load Pre-trained State Model ==============
# The state model is only constructed here; until the loading lines below are
# uncommented it keeps its freshly initialised weights.
state_model = FNO(
    n_modes=(64, 64),
    in_channels=2,
    out_channels=2,
    hidden_channels=512,
).to(device)

# Uncomment to load trained weights:
# load_model(state_model, 'state_model')
# state_model.eval()
# for p in state_model.parameters():
#     p.requires_grad = False

print(f"State model: {count_model_params(state_model)} parameters")
print("NOTE: Uncomment load_model after training fno_test.ipynb!")

# ============== Output Model (FNO-based) ==============
# Smaller FNO used as the feature extractor for the derived-quantity model.
output_model = FNO(
    n_modes=(32, 32),
    in_channels=2,
    out_channels=2,
    hidden_channels=128,
    n_layers=3,
).to(device)

print(f"Output model: {count_model_params(output_model)} parameters")

In [None]:
# Since derived quantities are scalars, we need a pooling layer to map from spatial output
class OutputHead(nn.Module):
    """Maps spatial features (B, in_channels, H, W) → scalar predictions (B, out_channels).

    Each channel is averaged over its spatial dimensions, and the resulting
    per-channel means are passed through a two-layer MLP.

    Args:
        hidden_dim: Width of the MLP's hidden layer.
        in_channels: Number of feature channels fed to the head. Defaults to
            2, matching the two-channel FNO output used in this notebook.
        out_channels: Number of scalars predicted per sample. Defaults to 2
            (gamma_n, gamma_c).
    """
    def __init__(self, hidden_dim=64, in_channels=2, out_channels=2):
        super().__init__()
        # Submodule names (`pool`, `fc`) are kept stable so state_dict keys of
        # previously saved checkpoints still match.
        self.pool = nn.AdaptiveAvgPool2d(1)  # Global average pooling
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_channels),
        )

    def forward(self, x):
        """Return (B, out_channels) predictions for x of shape (B, in_channels, H, W)."""
        x = self.pool(x)  # (B, in_channels, 1, 1)
        x = x.flatten(start_dim=1)  # (B, in_channels)
        return self.fc(x)  # (B, out_channels)


# Combined model: FNO features + pooling head
class DerivedQuantityModel(nn.Module):
    """Feature network followed by an output head: computes head(fno(x)).

    Both modules are registered as submodules (`fno`, `head`), so their
    parameters are trained and saved together.
    """

    def __init__(self, fno_model, output_head):
        super().__init__()
        self.fno = fno_model
        self.head = output_head

    def forward(self, x):
        """Run x through the feature network, then reduce with the head.

        In this notebook the feature map is (B, 2, H, W) and the result (B, 2).
        """
        return self.head(self.fno(x))


# Assemble the derived-quantity model: output FNO as feature extractor, pooling head on top.
output_head = OutputHead(hidden_dim=64)
output_head.to(device)

derived_model = DerivedQuantityModel(output_model, output_head)
derived_model.to(device)

print(f"Total derived model parameters: {sum(param.numel() for param in derived_model.parameters())}")

In [None]:
# ============== Training ==============
print("=" * 50)
print("Training: Direct state → derived mapping (FNO)")
print("=" * 50)

# Single source of truth for the schedule; the progress message reads the same
# value as the loop bound so the two cannot drift apart.
num_epochs = 100
log_every = 10

optimizer = AdamW(derived_model.parameters(), lr=1e-3, weight_decay=1e-4)
best_val_loss = float('inf')
# Tracks whether THIS run wrote a checkpoint. If validation never improves
# (for example the loss is NaN from the first epoch) nothing is saved, and
# reloading would pick up a stale checkpoint from an earlier run or fail.
checkpoint_saved = False

for epoch in range(num_epochs):
    train_loss = train_epoch_direct(derived_model, train_loader, optimizer, device)
    val_loss = validate_direct(derived_model, val_loader, device)

    # Checkpoint only on strict improvement of the validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_model(derived_model, 'derived_fno', metadata={'best_val_loss': best_val_loss})
        checkpoint_saved = True

    if (epoch + 1) % log_every == 0:
        print(f"Epoch {epoch+1}/{num_epochs} | Train: {train_loss:.6f} | Val: {val_loss:.6f}")

# Load best model
if checkpoint_saved:
    load_model(derived_model, 'derived_fno')
    print(f"\nBest validation loss: {best_val_loss:.6f}")
else:
    print("\nWARNING: validation loss never improved; no checkpoint was written, keeping final-epoch weights")

In [None]:
# Model already saved during training with best_val_loss
# (the training cell calls save_model(derived_model, 'derived_fno', ...) each
# time the validation loss improves, so nothing is written here).
# NOTE(review): the file name in the message below is hard-coded; confirm it
# matches the name save_model actually writes.
print("Best model saved as: derived_fno_latest.pt")

### Combined Rollout: Full Surrogate Inference
Given only an initial condition, predict both state trajectory AND derived quantities.

In [None]:
# Load test data
# np.load on an .npz returns a lazy archive that keeps the file open; the
# context manager closes it once both arrays have been read into memory.
with np.load(data_path) as test_data:
    test_state = test_data['state']
    test_derived = test_data['derived']

# Use last 10% as test
# NOTE(review): data_path is the same processed file the train/val/test
# dataloaders read; confirm this slice does not overlap the samples that
# StateDerivedDataset uses for training.
test_start = int(len(test_state) * 0.9)
test_state = test_state[test_start:]
test_derived = test_derived[test_start:]

print(f"Test state shape: {test_state.shape}")
print(f"Test derived shape: {test_derived.shape}")

# Run combined rollout using shared utility
# Only the first test state is passed in as the initial condition; the rollout
# length is capped at 200 steps or the size of the test slice, whichever is smaller.
num_steps = min(len(test_state), 200)
state_recon, derived_recon = rollout_combined(
    state_model, derived_model, test_state[0], num_steps, device
)

print(f"\nState recon: {state_recon.shape}, Derived recon: {derived_recon.shape}")

In [None]:
# Evaluation
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Top row: ground truth vs. prediction, one panel per derived quantity.
# The column index doubles as the channel index (0 = gamma_n, 1 = gamma_c).
for col, (ylabel, title) in enumerate((('gamma_n', 'Gamma_n'), ('gamma_c', 'Gamma_c'))):
    panel = axes[0, col]
    panel.plot(test_derived[:num_steps, col], 'b-', label='Ground Truth', alpha=0.8)
    panel.plot(derived_recon[:, col], 'r--', label='FNO Predicted', alpha=0.8)
    panel.set_xlabel('Timestep')
    panel.set_ylabel(ylabel)
    panel.set_title(title)
    panel.legend()
    panel.grid(True, alpha=0.3)

# Derived error over time
gamma_n_err = np.abs(test_derived[:num_steps, 0] - derived_recon[:, 0])
gamma_c_err = np.abs(test_derived[:num_steps, 1] - derived_recon[:, 1])

err_panel = axes[1, 0]
err_panel.plot(gamma_n_err, label='gamma_n')
err_panel.plot(gamma_c_err, label='gamma_c')
err_panel.set_xlabel('Timestep')
err_panel.set_ylabel('Abs Error')
err_panel.set_title('Derived Quantity Error')
err_panel.legend()
err_panel.grid(True, alpha=0.3)

# State MSE over time
state_mse = compute_mse_over_time(test_state[:num_steps], state_recon)

mse_panel = axes[1, 1]
mse_panel.plot(state_mse)
mse_panel.set_xlabel('Timestep')
mse_panel.set_ylabel('MSE')
mse_panel.set_title('State Prediction Error')
mse_panel.set_yscale('log')
mse_panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary
print("\n" + "=" * 50)
for summary_label, summary_values in (
    ("Gamma_n MAE", gamma_n_err),
    ("Gamma_c MAE", gamma_c_err),
    ("State MSE", state_mse),
):
    print_summary(summary_label, summary_values)

In [None]:
# Animation: State reconstruction (if animations enabled)
# Shows channel 0 (density) of the ground truth next to the surrogate rollout.
if show_animations:
    print("\033[1mShowing state reconstruction animation")

    # Subsample for animation
    step = 5
    gt_subbed = test_state[:num_steps:step]
    recon_subbed = state_recon[::step]

    # Animate only frames present in BOTH sequences, so a rollout whose length
    # differs from the ground-truth slice cannot index out of range.
    n_frames = min(len(gt_subbed), len(recon_subbed))

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))

    # Both panels share the ground truth's colour range so they are directly comparable.
    vmin = gt_subbed[:, 0].min()
    vmax = gt_subbed[:, 0].max()

    img1 = axes[0].imshow(gt_subbed[0, 0], vmin=vmin, vmax=vmax, cmap='viridis')
    img2 = axes[1].imshow(recon_subbed[0, 0], vmin=vmin, vmax=vmax, cmap='viridis')

    axes[0].set_title("Ground Truth (density)")
    axes[1].set_title("Surrogate Prediction")

    plt.colorbar(img1, ax=axes[0], fraction=0.046)
    plt.colorbar(img2, ax=axes[1], fraction=0.046)

    def animate(frame):
        """Swap both images to the given subsampled frame; return the artists for blitting."""
        img1.set_data(gt_subbed[frame, 0])
        img2.set_data(recon_subbed[frame, 0])
        return [img1, img2]

    animation = anim.FuncAnimation(fig, animate, frames=n_frames, interval=50, blit=True)
    display(HTML(animation.to_jshtml()))
    # Close the figure once the animation is embedded; otherwise the inline
    # backend also renders a static copy of it below the animation.
    plt.close(fig)
else:
    print("Animations disabled")