## An FNO network to learn the output
### Ideas on how to do this
- Naive guess: freeze the state FNO and use a simple conv net to learn the output.
- More apples-to-apples: freeze the state FNO and use a second FNO for the output?
- More advanced training strategies: alternately freeze one model while training the other, then train both jointly?

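### This notebook's implementation
- Process the data so that each time step has 4 channels: density, potential, and the scalar quantities gamma_n and gamma_c broadcast to spatially constant maps.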

In [39]:
%matplotlib inline
import torch
import h5py
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as anim
from IPython.display import HTML, display
from neuralop.models import FNO
from neuralop import Trainer
from neuralop.training import AdamW
from neuralop.utils import count_model_params
from neuralop import LpLoss, H1Loss
from neuralop.data.datasets import load_darcy_flow_small
import torch.nn as nn
import random
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# Force CPU. To pick a GPU automatically, use the commented line instead:
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
show_animations = True
# Fraction of each raw trajectory, counted from the first step, kept during preprocessing.
data_truncation = 0.5

plt.rcParams['animation.embed_limit'] = 500

for _label, _value in (("Using Device", device), ("Showing animations", show_animations)):
    print(f"\033[1m{_label}: {_value}")

[1mUsing Device: cuda
[1mShowing animations: True


In [40]:
# Location of the raw HDF5 simulation files. Keep the trailing slash: file names
# are appended to data_dir with plain string concatenation when loading.
data_dir = "/work/10407/anthony50102/frontera/data/hw2d_sim/t600_d256x256_raw/"

# All runs share the same simulation parameters; only the timestamp/id suffix differs.
_run_prefix = "hw2d_sim_step0.025_end1_pts512_c11_k015_N3_nu5e-8_"
train_files = [
    _run_prefix + run_tag + ".h5"
    for run_tag in ("20250315142044_11702_0", "20250315142045_4677_2")
]
test_files = [_run_prefix + "20250316215751_19984_3.h5"]

In [41]:
def process_derived(density, potential, gamma_n, gamma_c):
    """Combine state fields and per-step scalars into one channel-stacked array.

    Args:
        density: array of shape (T, H, W).
        potential: array of shape (T, H, W).
        gamma_n: array of shape (T,), one scalar per time step.
        gamma_c: array of shape (T,), one scalar per time step.

    Returns:
        Array of shape (T, 4, H, W) with channels
        [density, potential, gamma_n, gamma_c]. The last two channels are the
        scalars broadcast to constant spatial maps.
    """
    derived = np.stack((gamma_n, gamma_c), axis=1)  # (T, 2)
    # Take the grid size from the state itself rather than hard-coding 256x256,
    # so other simulation resolutions work too.
    spatial_shape = tuple(np.shape(density)[-2:])
    derived_maps = np.broadcast_to(
        derived[:, :, None, None], derived.shape + spatial_shape
    )
    return np.concatenate(
        (np.expand_dims(density, 1), np.expand_dims(potential, 1), derived_maps),
        axis=1,
    )


# Names of the processed .npz files, one per entry in train_files (written to the CWD).
processed_train_files = []

for file in train_files:
    # NOTE(review): "".join drops every "." in the stem (e.g. "step0.025" -> "step0025").
    # Kept as-is so previously cached .npz files are still found.
    save_name = "train_" + "".join(file.split(".")[:-1]) + "_derived.npz"

    if os.path.exists(save_name):
        print(f"File already processed: {save_name}")
    else:
        with h5py.File(data_dir + file, 'r') as f:
            # Keep only the leading `data_truncation` fraction of time steps.
            end_index = int(f['density'].shape[0] * data_truncation)
            density = f['density'][:end_index]
            potential = f['phi'][:end_index]
            gamma_n = f['gamma_n'][:end_index]
            gamma_c = f['gamma_c'][:end_index]
        # The slices above are in-memory arrays, so the HDF5 file can be closed first.
        data = process_derived(density, potential, gamma_n, gamma_c)
        np.savez(save_name, data=data)

    # Record cached and freshly written files alike; previously a cache hit skipped
    # this, leaving the list empty on re-runs (IndexError on processed_train_files[0]).
    processed_train_files.append(save_name)

In [42]:
class TrajDataset(Dataset):
    """One item per processed trajectory file (an .npz holding a "data" array).

    All ".npz" files directly inside `data_path` are sorted by name and split
    deterministically: the first `train_split` fraction is train, `val_split`
    of the remainder is val, and the rest is test. `mode` selects which subset
    this dataset serves.
    """

    # TODO: Look into LRU caching or memory mapping to speed this up.
    def __init__(self, data_path,
                 train_split=0.8,
                 val_split=0.5,
                 mode='train'):
        self.data_path = data_path
        # Sort so the split does not depend on os.listdir's arbitrary order;
        # otherwise files could move between train/val/test across machines.
        all_files = sorted(
            os.path.join(data_path, name)
            for name in os.listdir(data_path)
            if name.endswith(".npz")
        )

        self.num_files = len(all_files)

        train_end = int(self.num_files * train_split)
        val_end = train_end + int((self.num_files - train_end) * val_split)

        self.train_files = all_files[:train_end]
        self.val_files = all_files[train_end:val_end]
        self.test_files = all_files[val_end:]

        if mode == 'train':
            self.avail_files = self.train_files
        elif mode == 'val':
            self.avail_files = self.val_files
        elif mode == 'test':
            self.avail_files = self.test_files
        else:
            raise ValueError(f"mode must be 'train', 'val', or 'test', got {mode}")

    def __len__(self):
        """Number of files in the selected split."""
        return len(self.avail_files)

    def __getitem__(self, idx):
        """Load the "data" array of the idx-th file in the selected split."""
        # Use the NpzFile as a context manager so the archive is closed right away
        # instead of whenever it happens to be garbage-collected.
        with np.load(self.avail_files[idx]) as archive:
            data = archive["data"]
        return {'data': data}

def create_traj_loader(data_path, batch_size=1, num_workers=1):
    """Return a shuffled DataLoader over the training split of `data_path`."""
    return DataLoader(
        TrajDataset(data_path),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

# Smoke test: build the loader over the .npz files in the working directory
# (where the preprocessing cell saves them) and print the first batch's shape.
data_path = processed_train_files[0]  # NOTE(review): not used by the loader below
traj_loader = create_traj_loader(".")

for traj in traj_loader:
    print(traj['data'].shape)  # one batch is enough to sanity-check shapes
    break

torch.Size([1, 8000, 4, 256, 256])


In [None]:
# Second FNO: maps the 2 predicted state channels (density, potential) to the
# 2 derived channels (gamma_n, gamma_c) used as targets in training.
output_model = FNO(
    n_modes=(64, 64),
    in_channels=2,
    out_channels=2,
    hidden_channels=512,
    # projection_channel_ratio=2,
).to(device)

n_params = count_model_params(output_model)
print(f"\nOur output model has {n_params} parameters.")

# TODO: load the state-prediction model here; the training cell calls `state_model`.
# state_model = load("state_model.pth")
sys.stdout.flush()

In [None]:
# Only the output model's parameters are optimized.
optimizer = AdamW(output_model.parameters(), lr=1e-2, weight_decay=1e-4)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

l2loss = LpLoss(d=2, p=2)  # L2 loss on function values
h1loss = H1Loss(d=2)  # H1 loss also includes gradient information

# NOTE(review): the training loop below uses MSE directly, not these losses.
train_loss = h1loss
eval_losses = dict(h1=h1loss, l2=l2loss)

In [None]:
import torch
import torch.nn as nn
import random

# Curriculum hyperparameters
max_epochs = 1000
max_rollout_length = 90
stabilization_epochs = 10  # Minimum epochs at a rollout length before it may grow
stability_window = 10  # Number of recent epoch losses used for the stability check
max_loss_variance = 0.001  # Loss variance must be below this to count as stable
clip_norm = 1.0
# NOTE(review): not applied anywhere; the optimizer is created with lr=1e-2.
learning_rate = 1e-3
n_single_step = 5  # Random single-step samples per batch

# Training state
current_rollout = 1
epochs_since_increase = 0
recent_losses = []


def _predict_state(state):
    """Run the frozen state model without recording an autograd graph."""
    # Without no_grad, every loss.backward() would also backpropagate through the
    # state model, wasting memory and accumulating gradients it never uses.
    with torch.no_grad():
        return state_model(state)


# The state model is frozen (not in the optimizer); keep it in inference mode.
if isinstance(state_model, nn.Module):
    state_model.eval()
output_model.train()

for epoch in range(max_epochs):
    epoch_loss = 0
    n_batches = 0
    n_loss_terms = 0  # Number of loss values actually summed into epoch_loss

    for batch in traj_loader:
        traj = batch['data'].to(device).float()  # (batch, time, channel, h, w)
        batch_size, traj_len, c, h, w = traj.shape

        optimizer.zero_grad()

        # Always do some single-step training: true state at idx -> derived at idx + 1
        for _ in range(n_single_step):
            idx = random.randint(0, traj_len - 2)
            state_pred = _predict_state(traj[:, idx, :2])
            derived_pred = output_model(state_pred)  # b, c, h, w
            loss = nn.functional.mse_loss(derived_pred, traj[:, idx + 1, 2:])
            loss.backward()
            epoch_loss += loss.item()
            n_loss_terms += 1

        # Rollout training: feed the state model its own predictions
        if current_rollout > 1:
            start = random.randint(0, traj_len - current_rollout - 1)
            state = traj[:, start, :2].clone()  # Start with actual state channels

            for step in range(current_rollout):
                state_pred = _predict_state(state)
                derived_pred = output_model(state_pred)

                target = traj[:, start + step + 1, 2:]
                loss = nn.functional.mse_loss(derived_pred, target)
                loss.backward()
                epoch_loss += loss.item()
                n_loss_terms += 1

                # Autoregressive update
                state = state_pred.detach()

        torch.nn.utils.clip_grad_norm_(output_model.parameters(), max_norm=clip_norm)
        optimizer.step()
        n_batches += 1

    if n_loss_terms == 0:
        raise RuntimeError("traj_loader yielded no batches; nothing to train on")
    # Average over the losses actually computed. The previous denominator,
    # 5 + max(1, current_rollout), counted a phantom term when current_rollout == 1.
    avg_loss = epoch_loss / n_loss_terms
    epochs_since_increase += 1

    # Track recent losses
    recent_losses.append(avg_loss)
    if len(recent_losses) > stability_window:
        recent_losses.pop(0)

    # Loss is "stable" when its variance over the full window is small
    is_stable = False
    if len(recent_losses) >= stability_window:
        mean_loss = sum(recent_losses) / len(recent_losses)
        variance = sum((l - mean_loss) ** 2 for l in recent_losses) / len(recent_losses)
        is_stable = variance < max_loss_variance

    # Increase rollout if stable and waited long enough
    should_increase = (
        epochs_since_increase >= stabilization_epochs and
        is_stable and
        current_rollout < max_rollout_length
    )

    if should_increase:
        current_rollout += 1
        epochs_since_increase = 0
        recent_losses = []  # Reset after increase
        print(f"  → Increasing rollout to {current_rollout}")

    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Rollout: {current_rollout}, "
          f"Stable: {is_stable}, Epochs: {epochs_since_increase}")

    # NOTE(review): this stops in the same epoch the max length is reached, so no
    # epoch is ever trained at max_rollout_length itself.
    if current_rollout >= max_rollout_length:
        print(f"Reached max rollout of {max_rollout_length}")
        break

In [None]:
def rollout(state_model, data, device=None):
    """Autoregressively roll a one-step model forward from data[0].

    Args:
        state_model: torch module mapping a batched state (1, C, X, Y) to the
            next state with the same shape.
        data: numpy array, typically of shape (T, C, X, Y). Only data[0] is used
            as the initial condition; the rest fixes T and the output shape/dtype.
        device: device to run on. Defaults to the device of the model's first
            parameter, or "cpu" if the model has no parameters.

    Returns:
        numpy array shaped like `data` with recon[0] = data[0] and recon[t] the
        model's prediction from recon[t - 1]. The model's train/eval mode is
        restored before returning.
    """
    T = data.shape[0]
    recon = np.zeros_like(data)
    if T == 0:
        return recon

    if device is None:
        first_param = next(state_model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    was_training = state_model.training
    state_model.eval()
    try:
        # Initial condition (t=0)
        current = torch.from_numpy(data[0]).float().to(device)
        recon[0] = data[0]

        with torch.no_grad():
            for t in range(1, T):
                pred = state_model(current.unsqueeze(0))  # (1, C, X, Y)
                # Feed the prediction back in as the next input
                current = pred.squeeze(0)
                recon[t] = current.cpu().numpy()
    finally:
        # Don't leave the caller's model stuck in eval mode.
        state_model.train(was_training)

    return recon