## Conv Output Learning: State → Derived Quantities

### Architecture Overview
This notebook trains a **CNN-based model** to map **state(t) → derived(t)**, where:
- **State**: (density, potential) - shape (2, H, W)
- **Derived**: (gamma_n, gamma_c) - one scalar each per timestep, shape (2,)

### Comparison with FNO Output Learning
- **FNO**: Learns in frequency domain, good for smooth/global patterns
- **CNN**: Learns local spatial patterns, potentially faster and more parameter-efficient

### Combined Inference Pipeline
1. **State Model** (from `fno_test.ipynb`): state(t) → state(t+1) (autoregressive)
2. **Output Model** (this notebook): state(t) → derived(t) (CNN-based direct mapping)

In [None]:
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as anim
from IPython.display import HTML, display
from neuralop.models import FNO
from neuralop.training import AdamW
from neuralop.utils import count_model_params
from torch.utils.data import DataLoader

from utils import (
    load_h5_data, process_state_data, process_derived_data,
    StateDerivedDataset, TrajectoryWithDerivedDataset,
    save_model, load_model,
    train_epoch_direct, validate_direct,
    rollout_combined, compute_mse_over_time, print_summary
)

# Runtime configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
show_animations = True   # gates the state-reconstruction animation cell
data_truncation = 0.5    # forwarded to load_h5_data(truncation=...)

# Raise the size cap for inline (jshtml) animations; matplotlib reads this in MB
plt.rcParams.update({'animation.embed_limit': 500})

print(f"\033[1mUsing Device: {device}")
print(f"\033[1mShowing animations: {show_animations}")

In [None]:
# Data paths: directory of raw HDF5 simulation outputs and the file lists.
# NOTE(review): test_files is not referenced later in this notebook —
# evaluation reads a split of the first processed training file instead.
data_dir = "/work/10407/anthony50102/frontera/data/hw2d_sim/t600_d256x256_raw/"

train_files = [
    "hw2d_sim_step0.025_end1_pts512_c11_k015_N3_nu5e-8_20250315142044_11702_0.h5",
    "hw2d_sim_step0.025_end1_pts512_c11_k015_N3_nu5e-8_20250315142045_4677_2.h5",
]
test_files = [
    "hw2d_sim_step0.025_end1_pts512_c11_k015_N3_nu5e-8_20250316215751_19984_3.h5",
]

# Checkpoint names handed to save_model / load_model
STATE_MODEL_NAME, CNN_OUTPUT_MODEL_NAME = "state_model", "derived_cnn"

### Data Processing

In [None]:
# Process and save data with derived quantities.
# Each raw HDF5 training file is converted once into a local .npz cache holding
# `state` and `derived` arrays; files whose cache already exists are skipped.
import os

processed_files = []

for file in train_files:
    # Cache name: "output_train_" + the source filename with its extension dropped
    # and all remaining '.' characters removed, + ".npz", in the current working
    # directory. Kept as-is so previously written caches are still found.
    save_name = "output_train_" + "".join(file.split(".")[:-1]) + ".npz"

    if os.path.exists(save_name):
        print(f"File already processed: {save_name}")
        processed_files.append(save_name)
        continue

    # os.path.join works whether or not data_dir ends with a path separator
    # (plain string concatenation silently relied on the trailing '/').
    data = load_h5_data(os.path.join(data_dir, file), truncation=data_truncation)
    state = process_state_data(data['density'], data['potential'])
    derived = process_derived_data(data['gamma_n'], data['gamma_c'])

    np.savez(save_name, state=state, derived=derived)
    processed_files.append(save_name)
    print(f"Saved: {save_name}")

print(f"\nProcessed {len(processed_files)} files")

### Dataset and Loaders

In [None]:
# Build train/val/test loaders over the first processed file; which slice each
# gets is decided by StateDerivedDataset's `mode` argument.
data_path = processed_files[0]

train_dataset, val_dataset, test_dataset = (
    StateDerivedDataset(data_path, mode=split) for split in ('train', 'val', 'test')
)

common_loader_args = dict(batch_size=64, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **common_loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **common_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **common_loader_args)

# Whole-trajectory loader for the (commented-out) Phase 2 rollout fine-tuning.
# NOTE(review): "." is passed as the data location — presumably the directory
# holding the .npz caches written above; confirm against utils.
traj_loader = DataLoader(TrajectoryWithDerivedDataset(".", mode='train'), batch_size=1, shuffle=True)

# Sanity-check tensor shapes on a single training batch
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    print(f"State shape: {first_batch['state'].shape}")    # (batch, 2, H, W)
    print(f"Derived shape: {first_batch['derived'].shape}")  # (batch, 2)

print(f"\nTrain: {len(train_loader.dataset)} | Val: {len(val_loader.dataset)} | Test: {len(test_loader.dataset)}")

### Define Models

In [None]:
# The pre-trained state model is only instantiated and loaded later, in the
# rollout-evaluation cell; it has to be trained in fno_test.ipynb beforehand.
print(
    f"State model will be loaded from: {STATE_MODEL_NAME}_latest.pt",
    "Train state model in fno_test.ipynb first!",
    sep="\n",
)

In [None]:
# ============== CNN-based Output Model ==============

class ResidualBlock(nn.Module):
    """Shape-preserving residual unit.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)`` with two 3x3
    convolutions (padding 1), so the output has the same shape as the input.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # padding=1 keeps H and W unchanged, which the identity skip requires
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        self.conv1 = conv3x3()
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = conv3x3()
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(hidden))
        return F.relu(branch + x)


class ConvEncoder(nn.Module):
    """CNN regressor: strided-conv downsampling, global average pool, MLP head.

    A 7x7 conv lifts ``in_channels`` to ``hidden_channels``. Five stride-2
    stages then each halve the spatial size (256 -> 8 for a 256x256 input)
    while doubling the channel count up to a cap of 512; each stage is
    followed by ``num_res_blocks`` ResidualBlocks. The globally pooled feature
    vector is mapped by a small MLP to 2 outputs.

    Args:
        in_channels: number of input channels.
        hidden_channels: channel width after the initial conv.
        num_res_blocks: residual blocks appended to every downsampling stage.
    """

    def __init__(self, in_channels=2, hidden_channels=64, num_res_blocks=2):
        super().__init__()

        self.init_conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
        )

        # 256 → 128 → 64 → 32 → 16 → 8
        self.down_blocks = nn.ModuleList()
        width = hidden_channels
        for _stage in range(5):
            next_width = min(2 * width, 512)
            # kernel 4 / stride 2 / padding 1 halves even spatial sizes
            layers = [
                nn.Conv2d(width, next_width, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(next_width),
                nn.ReLU(),
            ]
            layers.extend(ResidualBlock(next_width) for _ in range(num_res_blocks))
            self.down_blocks.append(nn.Sequential(*layers))
            width = next_width

        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(width, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2),
        )

    def forward(self, x):
        features = self.init_conv(x)
        for stage in self.down_blocks:
            features = stage(features)
        pooled = torch.flatten(self.global_pool(features), start_dim=1)
        return self.fc(pooled)


class LightweightConvNet(nn.Module):
    """Lighter CNN regressor: five stride-2 conv/BN/ReLU stages, global pool, MLP head.

    Channel widths go ``base -> 2*base -> 4*base -> 8*base -> 8*base``. There
    are no residual blocks, so it is a cheaper alternative to ConvEncoder.
    The head maps the pooled features to 2 outputs.

    Args:
        in_channels: number of input channels.
        base_channels: width of the first conv stage.
    """

    def __init__(self, in_channels=2, base_channels=32):
        super().__init__()

        widths = [
            in_channels,
            base_channels,
            base_channels * 2,
            base_channels * 4,
            base_channels * 8,
            base_channels * 8,
        ]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            # kernel 4, stride 2, padding 1: halves even spatial sizes
            stages += [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU()]
        self.features = nn.Sequential(*stages)

        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Sequential(
            nn.Linear(widths[-1], 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2),
        )

    def forward(self, x):
        feats = self.features(x)
        pooled = self.global_pool(feats).flatten(1)
        return self.head(pooled)


# Choose model: flip to True for the smaller LightweightConvNet
USE_LIGHTWEIGHT = False

if USE_LIGHTWEIGHT:
    conv_output_model = LightweightConvNet(in_channels=2, base_channels=32)
    banner = "Using LightweightConvNet"
else:
    conv_output_model = ConvEncoder(in_channels=2, hidden_channels=64, num_res_blocks=2)
    banner = "Using ConvEncoder"
conv_output_model = conv_output_model.to(device)
print(banner)

# Total parameter count (trainable or not)
n_params = sum(param.numel() for param in conv_output_model.parameters())
print(f"CNN output model: {n_params:,} parameters")

### Training Loop

In [None]:
import random

def train_with_rollout(derived_model, state_model, traj_loader, optimizer, device, rollout_len, num_epochs):
    """Fine-tune ``derived_model`` on states produced by rolling out ``state_model``.

    For each trajectory batch a random start index is drawn. The derived model
    is evaluated on the ground-truth state at ``start`` and on the next
    ``rollout_len - 1`` states predicted autoregressively by the (frozen)
    state model. Each prediction is compared with the ground-truth derived
    quantities at the same time index. Only ``derived_model`` is updated.

    Args:
        derived_model: model mapping a state batch to derived quantities.
        state_model: model mapping state(t) -> state(t+1); run without gradients.
        traj_loader: iterable of dicts with 'state' (indexed as [:, t]) and
            'derived' (indexed as [:, t]) trajectory tensors.
        optimizer: optimizer over ``derived_model``'s parameters.
        device: device the batches are moved to.
        rollout_len: number of derived predictions per trajectory (>= 1).
        num_epochs: passes over ``traj_loader``.

    Raises:
        ValueError: if ``rollout_len`` < 1 (the loss average would divide by zero).
    """
    if rollout_len < 1:
        raise ValueError(f"rollout_len must be >= 1, got {rollout_len}")

    derived_model.train()
    state_model.eval()

    for epoch in range(num_epochs):
        total_loss, n_batches = 0.0, 0

        for batch in traj_loader:
            state_traj = batch['state'].to(device)      # (1, T, 2, H, W)
            derived_traj = batch['derived'].to(device)  # (1, T, 2)

            T = state_traj.shape[1]
            # Targets are read at start .. start + rollout_len - 1, so the
            # largest valid start is T - rollout_len (the old "- 1" rejected it).
            max_start = T - rollout_len
            if max_start < 0:
                continue  # trajectory too short for this rollout length

            start = random.randint(0, max_start)
            optimizer.zero_grad()

            state = state_traj[:, start]
            loss = 0.0

            for step in range(rollout_len):
                loss = loss + F.mse_loss(derived_model(state), derived_traj[:, start + step])
                # Advance the state only if another prediction follows; the
                # state after the final step would never be used.
                if step < rollout_len - 1:
                    with torch.no_grad():
                        state = state_model(state)

            loss = loss / rollout_len
            loss.backward()
            torch.nn.utils.clip_grad_norm_(derived_model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        if (epoch + 1) % 5 == 0:
            print(f"[Rollout={rollout_len}] Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/max(n_batches,1):.6f}")

In [None]:
# ============== Phase 1: Direct Mapping ==============
# Train the CNN on ground-truth (state, derived) pairs, checkpoint whenever the
# validation loss improves, then restore the best checkpoint from this run.
print("=" * 50)
print("Phase 1: Direct state → derived mapping (CNN)")
print("=" * 50)

optimizer = AdamW(conv_output_model.parameters(), lr=1e-3, weight_decay=1e-4)
# Halve the learning rate after 10 epochs without validation improvement
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

train_losses, val_losses = [], []
best_val_loss = float('inf')
saved_this_run = False
num_epochs = 100

for epoch in range(num_epochs):
    train_loss = train_epoch_direct(conv_output_model, train_loader, optimizer, device)
    val_loss = validate_direct(conv_output_model, val_loader, device)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    scheduler.step(val_loss)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_model(conv_output_model, CNN_OUTPUT_MODEL_NAME, metadata={'epoch': epoch, 'val_loss': val_loss})
        saved_this_run = True

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs} | Train: {train_loss:.6f} | Val: {val_loss:.6f}")

if saved_this_run:
    # Load best
    load_model(conv_output_model, CNN_OUTPUT_MODEL_NAME)
    print(f"\nBest validation loss: {best_val_loss:.6f}")
else:
    # `val_loss < inf` is False for NaN, so a run whose validation loss is
    # always NaN never saves. Loading here could then restore a stale
    # checkpoint from an earlier run (or fail), so keep in-memory weights.
    print("\nWARNING: validation loss never improved (NaN?); no checkpoint saved, keeping final weights")

In [None]:
# Plot per-epoch train/validation MSE on a log scale
fig, ax = plt.subplots(figsize=(8, 4))
for curve, curve_label in ((train_losses, 'Train'), (val_losses, 'Val')):
    ax.plot(curve, label=curve_label)
ax.set(
    xlabel='Epoch',
    ylabel='MSE Loss',
    title='CNN Output Model: Training Curves',
    yscale='log',
)
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Phase 2: Rollout training (uncomment after training state model)
# print("\n" + "=" * 50)
# print("Phase 2: Rollout fine-tuning")
# print("=" * 50)

# state_model = FNO(n_modes=(64, 64), in_channels=2, out_channels=2, hidden_channels=512).to(device)
# load_model(state_model, STATE_MODEL_NAME)
# state_model.eval()

# optimizer = AdamW(conv_output_model.parameters(), lr=1e-4, weight_decay=1e-4)
# for rollout_len in [5, 10, 20]:
#     train_with_rollout(conv_output_model, state_model, traj_loader, optimizer, device, rollout_len, num_epochs=15)
# save_model(conv_output_model, CNN_OUTPUT_MODEL_NAME, metadata={'phase': 'rollout'})

### Combined Rollout and Evaluation

In [None]:
# Load state model for rollout evaluation
state_model = FNO(n_modes=(64, 64), in_channels=2, out_channels=2, hidden_channels=512).to(device)
try:
    load_model(state_model, STATE_MODEL_NAME)
except FileNotFoundError:
    print("State model not found - using untrained model for demo")
# Inference mode for both models on every path, including the untrained
# fallback, so BatchNorm/Dropout layers (if any) do not use training behavior.
state_model.eval()
conv_output_model.eval()

# Load test data: the last 10% of the first processed file.
# NOTE(review): presumably this mirrors StateDerivedDataset(mode='test');
# confirm the 0.9 split against utils. `test_files` is not used here.
# The context manager closes the .npz handle; the sliced arrays remain valid.
with np.load(data_path) as test_data:
    all_state = test_data['state']
    test_start = int(len(all_state) * 0.9)
    test_state = all_state[test_start:]
    test_derived = test_data['derived'][test_start:]

print(f"Test state shape: {test_state.shape}")
print(f"Test derived shape: {test_derived.shape}")

# Run combined rollout from the first test state (at most 200 steps)
num_steps = min(len(test_state), 200)
state_recon, derived_recon = rollout_combined(
    state_model, conv_output_model, test_state[0], num_steps, device
)
print(f"Rollout complete: {state_recon.shape}, {derived_recon.shape}")

In [None]:
# Evaluation plots: top row compares predicted vs. ground-truth derived
# quantities over the rollout; bottom row shows per-timestep MSE curves.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

derived_panels = (
    ('gamma_n', 'Gamma_n: Ground Truth vs Predicted'),
    ('gamma_c', 'Gamma_c: Ground Truth vs Predicted'),
)
for col, (qty_label, panel_title) in enumerate(derived_panels):
    panel = axes[0, col]
    panel.plot(test_derived[:num_steps, col], 'b-', label='Ground Truth', alpha=0.8)
    panel.plot(derived_recon[:, col], 'r--', label='CNN Predicted', alpha=0.8)
    panel.set_xlabel('Timestep')
    panel.set_ylabel(qty_label)
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True, alpha=0.3)

# Derived quantity errors (column slices keep a 2-D (T, 1) shape)
gamma_n_mse, gamma_c_mse = (
    compute_mse_over_time(test_derived[:num_steps, k:k + 1], derived_recon[:, k:k + 1])
    for k in (0, 1)
)
err_panel = axes[1, 0]
err_panel.plot(gamma_n_mse, label='gamma_n')
err_panel.plot(gamma_c_mse, label='gamma_c')
err_panel.set_xlabel('Timestep')
err_panel.set_ylabel('MSE')
err_panel.set_title('Derived Quantity MSE Over Time')
err_panel.legend()
err_panel.grid(True, alpha=0.3)

# State MSE of the autoregressive rollout
state_mse = compute_mse_over_time(test_state[:num_steps], state_recon)
state_panel = axes[1, 1]
state_panel.plot(state_mse)
state_panel.set_xlabel('Timestep')
state_panel.set_ylabel('MSE')
state_panel.set_title('State Prediction MSE')
state_panel.set_yscale('log')
state_panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary
print("\n" + "=" * 50)
print("CNN Model Summary")
print("=" * 50)
for summary_label, series in (
    ("State MSE", state_mse),
    ("Gamma_n MSE", gamma_n_mse),
    ("Gamma_c MSE", gamma_c_mse),
):
    print_summary(summary_label, series)

In [None]:
# Animation: ground truth vs. surrogate for state channel 0 (presumably density,
# given process_state_data(density, potential)), every `step`-th timestep, on a
# shared color scale taken from the ground truth.
if show_animations:
    print("\033[1mState reconstruction animation (CNN)")

    step = 5
    gt = test_state[:num_steps:step]
    pred = state_recon[::step]
    # Guard against a rollout shorter than the ground-truth window
    n_frames = min(len(gt), len(pred))

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    vmin, vmax = gt[:, 0].min(), gt[:, 0].max()

    img1 = axes[0].imshow(gt[0, 0], vmin=vmin, vmax=vmax, cmap='viridis')
    img2 = axes[1].imshow(pred[0, 0], vmin=vmin, vmax=vmax, cmap='viridis')
    axes[0].set_title("Ground Truth")
    axes[1].set_title("CNN Surrogate")
    plt.colorbar(img1, ax=axes[0], fraction=0.046)
    plt.colorbar(img2, ax=axes[1], fraction=0.046)

    def animate(frame):
        img1.set_data(gt[frame, 0])
        img2.set_data(pred[frame, 0])
        return [img1, img2]

    animation = anim.FuncAnimation(fig, animate, frames=n_frames, interval=50, blit=True)
    html = animation.to_jshtml()
    # Close the figure after rendering so the inline backend does not also
    # emit it as a duplicate static image at the end of the cell.
    plt.close(fig)
    display(HTML(html))
else:
    print("Animations disabled")

### Model Comparison: CNN vs FNO
Load both output models (CNN and FNO) and compare their absolute errors on the same ground-truth test states (no autoregressive rollout).

In [None]:
# ============== Model Comparison: CNN vs FNO ==============
# Run after training both notebooks

def compare_models(cnn_model, fno_model, loader, device):
    """Compare CNN and FNO on ground truth states.

    Both models are switched to eval mode and run without gradients on every
    batch of ``loader`` (dicts with 'state' and 'derived' tensors).

    Returns:
        Tuple ``(cnn_abs_errors, fno_abs_errors)`` of numpy arrays holding the
        element-wise ``|prediction - derived|``, concatenated over batches.
    """
    models = (cnn_model, fno_model)
    for model in models:
        model.eval()

    error_buckets = ([], [])
    with torch.no_grad():
        for batch in loader:
            inputs = batch['state'].to(device)
            targets = batch['derived'].to(device)
            for bucket, model in zip(error_buckets, models):
                bucket.append((model(inputs) - targets).abs().cpu().numpy())

    return tuple(np.concatenate(bucket) for bucket in error_buckets)

# Uncomment after training both models:
# from fno_output_learning import DerivedQuantityModel  # or recreate the FNO model here
# fno_output_model = ...
# load_model(fno_output_model, 'derived_fno')
# 
# cnn_err, fno_err = compare_models(conv_output_model, fno_output_model, test_loader, device)
# print("Model Comparison (on ground truth):")
# print(f"CNN - gamma_n MAE: {cnn_err[:, 0].mean():.6f}, gamma_c: {cnn_err[:, 1].mean():.6f}")
# print(f"FNO - gamma_n MAE: {fno_err[:, 0].mean():.6f}, gamma_c: {fno_err[:, 1].mean():.6f}")

status_message = "Model comparison ready - uncomment after training both notebooks"
print(status_message)