## Imports and Setup

In [None]:
import os
import random
import sys

import h5py
import matplotlib.animation as anim
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from IPython.display import HTML, display
from neuralop.models import FNO
from neuralop.training import AdamW
from neuralop.utils import count_model_params
from torch.utils.data import DataLoader

# Local utilities
from utils import (
    StateTrajectoryDataset, 
    train_epoch_single_step, train_epoch_rollout,
    rollout_state, compute_mse_over_time, print_summary,
    save_model, load_model
)

# Run configuration: the knobs a reader is expected to tune live in this cell.
seed = 0                # fixed seed so weight init and random sampling repeat across runs
show_animations = True  # set False to skip the slow, output-heavy JS animations
data_truncation = 0.5   # fraction of each raw trajectory (taken from its start) that is kept

# Seed the random number generators this notebook can reach directly.
# torch.manual_seed covers CPU and CUDA generators.
# NOTE(review): helpers imported from `utils` may use other RNG sources - confirm.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Raise matplotlib's size cap (in MB) for animations embedded via to_jshtml(),
# otherwise long clips are cut short.
plt.rcParams['animation.embed_limit'] = 500

# "\033[1m" turns bold on; "\033[0m" resets it so later output is not left bold.
print(f"\033[1mUsing Device: {device}\033[0m")
print(f"\033[1mShowing animations: {show_animations}\033[0m")

### Data Processing

In [None]:
# Locations of the raw HW2D simulation output (.h5 files).
#
# The directory can be overridden per machine through the HW2D_DATA_DIR
# environment variable, so the notebook does not have to be edited to run
# somewhere else; the default is the original cluster path.
# os.path.join(..., "") guarantees a trailing separator, because the directory
# is later combined with a file name by plain string concatenation.
data_dir = os.path.join(
    os.environ.get(
        "HW2D_DATA_DIR",
        "/work/10407/anthony50102/frontera/data/hw2d_sim/t600_d256x256_raw/",
    ),
    "",
)

# Raw trajectories that the preprocessing cell converts to .npz.
train_files = [
    "hw2d_sim_step0.025_end1_pts512_c11_k015_N3_nu5e-8_20250315142044_11702_0.h5",
    "hw2d_sim_step0.025_end1_pts512_c11_k015_N3_nu5e-8_20250315142045_4677_2.h5",
]

# Held-out trajectory. NOTE(review): not read by any of the cells shown here.
test_files = [
    "hw2d_sim_step0.025_end1_pts512_c11_k015_N3_nu5e-8_20250316215751_19984_3.h5",
]

In [None]:
def process(density, potential, gamma_n, gamma_c):
    """Bundle the state fields and the derived quantities into two arrays.

    A new axis is inserted at position 1 and the inputs are placed along it.

    Args:
        density, potential: arrays of identical shape; they become index 0
            and index 1 of axis 1 of the first result.
        gamma_n, gamma_c: arrays of identical shape; they become index 0
            and index 1 of axis 1 of the second result.

    Returns:
        Tuple ``(data, derived_data)``.
    """
    # Stacking on a fresh axis 1 is the same as expanding each input at
    # axis 1 and concatenating there.
    state_fields = np.stack((density, potential), axis=1)
    derived_fields = np.stack((gamma_n, gamma_c), axis=1)
    return state_fields, derived_fields


# Convert every raw training trajectory into an .npz file in the working
# directory and remember the file names for the cells further down.
processed_train_files = []

for file in train_files:
    with h5py.File(data_dir + file, 'r') as f:
        # Keep only the leading `data_truncation` fraction along axis 0.
        end_index = int(f['density'].shape[0] * data_truncation)
        # Slicing an h5py dataset copies it into memory, so the arrays stay
        # valid after the file is closed.
        density = f['density'][:end_index]
        potential = f['phi'][:end_index]
        gamma_n = f['gamma_n'][:end_index]
        gamma_c = f['gamma_c'][:end_index]

    # Bug fix: gamma_c was read but gamma_n was passed twice, so the saved
    # `derived_data` held two copies of gamma_n and no gamma_c at all.
    data, derived_data = process(density, potential, gamma_n, gamma_c)

    # Note: this drops *every* "." of the file stem, not only the extension
    # (e.g. "step0.025" becomes "step0025"). Kept as-is so that files written
    # by earlier runs keep the same names.
    save_name = "train_" + "".join(file.split(".")[:-1]) + ".npz"

    np.savez(
        save_name,
        data=data,
        derived_data=derived_data
    )
    # Record the name only once the file has actually been written.
    processed_train_files.append(save_name)

In [None]:
# Animation of the training data: channel 0 (density) of the first processed
# trajectory, showing every 25th stored step.
if show_animations:
    # Open the archive as a context manager so its file handle is released as
    # soon as the array has been read. A distinct name is used so that this
    # preview does not share the name `train_data` with the evaluation split.
    with np.load(processed_train_files[0]) as npz:
        preview_frames = npz["data"][::25]

    fig, ax = plt.subplots()

    # Colour limits are fixed over the whole clip so frames are comparable.
    vmin = preview_frames[:, 0, ...].min()
    vmax = preview_frames[:, 0, ...].max()

    # Draw on the explicit axes instead of relying on pyplot's "current axes".
    img = ax.imshow(preview_frames[0, 0, ...], vmin=vmin, vmax=vmax)

    def animate(frame):
        """Show preview frame number `frame`; returns the artists to redraw."""
        img.set_data(preview_frames[frame, 0, ...])
        return [img]

    # The embed limit for to_jshtml() is already raised in the setup cell.
    animation = anim.FuncAnimation(
        fig,
        animate,
        frames=preview_frames.shape[0],
        interval=20,
        blit=True,
    )

    display(HTML(animation.to_jshtml()))
    # Close the figure so the notebook does not also render a static copy of
    # it underneath the animation, and so its memory is freed.
    plt.close(fig)
else:
    print("Animation is turned off")

### Dataset and Loaders

In [None]:
# Wrap the first processed trajectory file in the shared Dataset class and
# serve it through a DataLoader, one item per batch.
traj_dataset = StateTrajectoryDataset(processed_train_files[0], mode='train')
traj_loader = DataLoader(
    traj_dataset,
    shuffle=True,
    batch_size=1,
)

# Smoke test: fetch a single batch (if there is one) and report its shape.
batch = next(iter(traj_loader), None)
if batch is not None:
    print(f"Trajectory shape: {batch['data'].shape}")  # (1, T, C, H, W)

### Define Model

In [None]:
# Fourier Neural Operator hyperparameters. Two input and two output channels
# match the two stacked state fields (density, potential).
fno_kwargs = dict(
    n_modes=(64, 64),
    in_channels=2,
    out_channels=2,
    hidden_channels=512,
    # projection_channel_ratio=2,
)

# Build the network and move it to the selected device in one step.
model = FNO(**fno_kwargs).to(device)

# Report the model size; flush so the line shows up immediately in batch logs.
n_params = count_model_params(model)
print(f"\nOur model has {n_params} parameters.")
sys.stdout.flush()

### Define the Optimizer

In [None]:
# AdamW over all model parameters. No learning-rate scheduler is created in
# this notebook; the training cell builds its own optimizer for each phase.
optimizer = AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

### Training Loop

In [None]:
def _announce_phase(title, blank_line_first=False):
    """Print `title` framed by two 50-character '=' rules."""
    rule = "=" * 50
    print("\n" + rule if blank_line_first else rule)
    print(title)
    print(rule)


# ============== Phase 1: Single-step training ==============
_announce_phase("Phase 1: Single-step training")

# Phase 1 starts from a fresh optimizer.
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

for epoch in range(100):
    loss = train_epoch_single_step(model, traj_loader, optimizer, device)
    if epoch % 10 == 9:  # report on every 10th epoch
        print(f"Epoch {epoch+1}/100, Loss: {loss:.6f}")

# ============== Phase 2: Curriculum rollout training ==============
_announce_phase("Phase 2: Curriculum rollout training", blank_line_first=True)

# One entry per curriculum stage:
#   (rollout_len, scheduled_sampling_prob, epochs)
rollout_schedule = [
    (5, 0.0, 20),
    (10, 0.2, 20),
    (20, 0.4, 20),
    (40, 0.6, 20),
    (80, 0.8, 30),
]

# Phase 2 restarts AdamW with a ten times smaller learning rate.
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

for rollout_len, ss_prob, n_epochs in rollout_schedule:
    print(f"\n--- Rollout: {rollout_len}, Scheduled Sampling: {ss_prob} ---")
    for epoch in range(n_epochs):
        loss = train_epoch_rollout(
            model, traj_loader, optimizer, device, rollout_len, ss_prob
        )
        if epoch % 5 == 4:  # report on every 5th epoch of the stage
            print(f"  Epoch {epoch+1}/{n_epochs}, Loss: {loss:.6f}")

print("\nTraining complete!")

In [None]:
# Persist the trained model under the name 'state_model', together with a
# small dictionary describing how it was configured and trained.
model_metadata = {
    'architecture': 'FNO',
    'n_modes': (64, 64),
    'hidden_channels': 512,
    'final_rollout': 80,
}
save_model(model, 'state_model', metadata=model_metadata)

In [None]:
# trainer = Trainer(
#     model=model,
#     n_epochs=15,
#     device=device,
#     wandb_log=False,  # Disable Weights & Biases logging for this tutorial
#     eval_interval=5,  # Evaluate every 5 epochs
#     use_distributed=False,  # Single GPU/CPU training
#     verbose=True,  # Print training progress
# )

# train_loader, test_loaders, data_processor = load_darcy_flow_small(
#     n_train=1000,
#     batch_size=64,
#     n_tests=[100, 50],
#     test_resolutions=[16, 32],
#     test_batch_sizes=[32, 32],
# )

# trainer.train(
#     train_loader=train_loader,
#     test_loaders={256:test_loader},
#     optimizer=optimizer,
#     scheduler=scheduler,
#     regularizer=False,
#     training_loss=train_loss,
#     eval_losses=eval_losses,
# )

In [None]:
# Load the first processed trajectory and run rollout evaluation on it.
# The archive is opened as a context manager so its file handle is closed
# once the array has been read into memory.
with np.load(processed_train_files[0]) as npz:
    full_data = npz["data"]
print(f"Full data shape: {full_data.shape}")

# Split along axis 0 in order: first 80% -> train, the remainder is halved.
# The slice [train_end:val_end] is not used in this cell; the second half is
# used as the test segment.
total_steps = full_data.shape[0]
train_end = int(total_steps * 0.8)
val_end = train_end + int((total_steps - train_end) * 0.5)

train_data = full_data[:train_end]
test_data = full_data[val_end:]

# Run rollouts using the shared utility: each one is started from the first
# frame of its segment and asked for as many steps as the segment has.
train_recon = rollout_state(model, train_data[0], len(train_data), device)
test_recon = rollout_state(model, test_data[0], len(test_data), device)

print(f"Train recon shape: {train_recon.shape}")
print(f"Test recon shape: {test_recon.shape}")

In [None]:
# Side-by-side animation of ground truth and model rollout (channel 0).
if show_animations:
    # "\033[0m" resets the bold switched on by "\033[1m".
    print("\033[1mShowing animation\033[0m")

    # Subsample: keep every 5th step of both sequences.
    train_data_subbed = train_data[::5]
    train_recon_subbed = train_recon[::5]

    # Only animate frames that exist in both sequences, so the frame index
    # can never run past the end of the shorter one.
    n_frames = min(len(train_data_subbed), len(train_recon_subbed))

    # Figure + axes
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))

    # One colour scale, taken from the ground truth, shared by both panels.
    # Without explicit limits the reconstruction panel would be scaled from
    # its first frame only (set_data does not rescale), and the two panels
    # would not be visually comparable.
    t_vmin = train_data_subbed[:, 0].min()
    t_vmax = train_data_subbed[:, 0].max()

    # Create the images on the correct axes
    img = ax[0].imshow(train_data_subbed[0, 0], vmin=t_vmin, vmax=t_vmax)
    img2 = ax[1].imshow(train_recon_subbed[0, 0], vmin=t_vmin, vmax=t_vmax)

    ax[0].set_title("Ground Truth")
    ax[1].set_title("Reconstruction")

    def animate(frame):
        """Update both panels to frame `frame`; returns the artists to redraw."""
        img.set_data(train_data_subbed[frame, 0])
        img2.set_data(train_recon_subbed[frame, 0])
        return [img, img2]

    animation = anim.FuncAnimation(
        fig,
        animate,
        frames=n_frames,
        interval=20,
        blit=True
    )

    display(HTML(animation.to_jshtml()))
    # Close the figure so no static duplicate is rendered below the animation.
    plt.close(fig)
else:
    print("Not showing animations")

In [None]:
# Rollout error for both segments, computed with the shared utility.
train_losses = compute_mse_over_time(train_data, train_recon)
test_losses = compute_mse_over_time(test_data, test_recon)

# One panel per segment; both panels are formatted the same way, so the
# per-panel differences (data, title, line colour) are listed once here.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

panels = (
    (axes[0], train_losses, 'Training Data: Rollout Error Over Time', {}),
    (axes[1], test_losses, 'Test Data: Rollout Error Over Time', {'color': 'orange'}),
)

for axis, losses, title, line_style in panels:
    axis.plot(losses, **line_style)
    axis.set_xlabel('Timestep')
    axis.set_ylabel('MSE Loss')
    axis.set_title(title)
    axis.set_yscale('log')  # log scale on the error axis
    axis.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary stats
print("\n" + "=" * 40)
print_summary("Train", train_losses)
print_summary("Test", test_losses)