## Imports and Setup

In [1]:
import torch
import h5py
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as anim
from IPython.display import HTML, display
from neuralop.models import FNO
from neuralop import Trainer
from neuralop.training import AdamW
from neuralop.utils import count_model_params
from neuralop import LpLoss, H1Loss
from neuralop.data.datasets import load_darcy_flow_small
import torch.nn as nn
import random
import os
import datetime

device = "cuda" if torch.cuda.is_available() else "cpu"
show_animations = False
data_truncation = .5  # fraction of each raw trajectory (leading time steps) kept when processing
# Implement this after trying the conv net and fno output learning
output_learning = False

# Seed every RNG the notebook relies on (python `random` picks rollout start points and
# teacher-forcing decisions; torch drives weight init and DataLoader shuffling) so that a
# Restart & Run All reproduces the same run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

plt.rcParams['animation.embed_limit'] = 500  # size limit (MB) for embedded jshtml animations

# \033[1m switches bold on and \033[0m resets it, so later output is not left bold.
print(f"\033[1mUsing Device: {device}\033[0m")
print(f"\033[1mShowing animations: {show_animations}\033[0m")

[1mUsing Device: cuda
[1mShowing animations: False


### Data Processing

In [2]:
# Raw HW2D simulation outputs (HDF5). The next cell converts the training files to .npz.
data_dir = "/work/10407/anthony50102/frontera/data/hw2d_sim/t600_d256x256_raw/"

# Trajectory files, relative to data_dir
train_files = [
    "hw2d_sim_step0.025_end1_pts512_c11_k015_N3_nu5e-8_20250315142044_11702_0.h5",
    "hw2d_sim_step0.025_end1_pts512_c11_k015_N3_nu5e-8_20250315142045_4677_2.h5",
]
test_files = [
    "hw2d_sim_step0.025_end1_pts512_c11_k015_N3_nu5e-8_20250316215751_19984_3.h5",
]

In [3]:
def process(density, potential, gamma_n, gamma_c):
    """
    Stack per-timestep fields into channel-first arrays.

    Args:
        density, potential: arrays of identical shape (T, ...) holding the primary fields.
        gamma_n, gamma_c: arrays of identical shape (T, ...) holding derived quantities.

    Returns:
        (data, derived_data):
            data         -- shape (T, 2, ...); channel 0 = density, channel 1 = potential
            derived_data -- shape (T, 2, ...); channel 0 = gamma_n, channel 1 = gamma_c
    """
    fields = np.stack((density, potential), axis=1)
    derived = np.stack((gamma_n, gamma_c), axis=1)
    return fields, derived


# Convert each raw HDF5 training trajectory into a local .npz with
#   data         -- (T, 2, ...) stacked [density, phi]
#   derived_data -- (T, 2, ...) stacked [gamma_n, gamma_c]
processed_train_files = []

for file in train_files:
    with h5py.File(os.path.join(data_dir, file), 'r') as f:
        # Keep only the leading `data_truncation` fraction of time steps
        end_index = int(f['density'].shape[0] * data_truncation)
        density = f['density'][:end_index]
        potential = f['phi'][:end_index]
        gamma_n = f['gamma_n'][:end_index]
        gamma_c = f['gamma_c'][:end_index]

    # The 4th argument must be gamma_c; passing gamma_n twice silently stored gamma_n in
    # the gamma_c channel. (Any .npz files written before this fix need regenerating.)
    data, derived_data = process(density, potential, gamma_n, gamma_c)

    # Output name: "train_" + source name with its extension and every '.' removed,
    # e.g. "..._step0.025_..._0.h5" -> "train_..._step0025_..._0.npz"
    save_name = "train_" + "".join(file.split(".")[:-1]) + ".npz"
    processed_train_files.append(save_name)

    np.savez(
        save_name,
        data=data,
        derived_data=derived_data
    )

KeyboardInterrupt: 

In [None]:
# Animation of the training data: channel 0 (density), every 25th stored frame
if show_animations:
    with np.load(processed_train_files[0]) as archive:
        train_data = archive["data"][::25]

    fig, ax = plt.subplots()

    # Fixed colour limits over the whole clip so frames are comparable
    vmin = train_data[:, 0, ...].min()
    vmax = train_data[:, 0, ...].max()

    img = ax.imshow(train_data[0, 0, ...], vmin=vmin, vmax=vmax)

    def animate(frame):
        img.set_data(train_data[frame, 0, ...])
        return [img]

    animation = anim.FuncAnimation(fig, animate, frames=train_data.shape[0], interval=20, blit=True)

    # Close the figure so the notebook shows only the animation, not an extra static frame
    plt.close(fig)
    display(HTML(animation.to_jshtml()))
else:
    print("Animation is turned off")

### Dataset and Loaders

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class ForwardPredictionDataset(Dataset):
    """
    Dataset for learning forward prediction operators from a single trajectory file.

    The trajectory is split temporally into contiguous train / val / test segments.
    (input window, target window) pairs are drawn from the selected segment only.

    Args:
        data_path: Path to an ``.npz`` file whose ``"data"`` array has shape
            (t, channel, x_dim, y_dim)
        input_steps: Number of past timesteps to use as input
        output_steps: Number of future timesteps to predict
        stride: Stride between consecutive samples (default: 1); must be >= 1
        train_split: Fraction of timesteps used for training (default: 0.8)
        val_split: Fraction of the remaining timesteps used for validation; the
            rest forms the test segment (default: 0.5)
        mode: 'train', 'val', or 'test'

    Raises:
        ValueError: if ``mode`` is not one of the above or ``stride`` < 1.

    Each item is a dict:
        'x':    float32 input window. ``squeeze(0)`` drops the time axis only when
                input_steps == 1, giving (channel, x_dim, y_dim); otherwise the
                shape is (input_steps, channel, x_dim, y_dim).
        'y':    float32 target window, shaped analogously with output_steps.
        't':    index of the first input step, relative to the selected segment.
        'tend': index of the last target step, relative to the selected segment.
    """

    def __init__(self, data_path, input_steps=1, output_steps=1, stride=1,
                 train_split=0.8, val_split=0.5, mode='train'):
        # Validate arguments before paying for the (potentially large) file load
        if mode not in ('train', 'val', 'test'):
            raise ValueError(f"mode must be 'train', 'val', or 'test', got {mode}")
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")

        self.input_steps = input_steps
        self.output_steps = output_steps
        self.stride = stride

        # Load data: (t, channel, x_dim, y_dim). Indexing the archive reads the array
        # into memory, so the .npz file can be closed immediately.
        with np.load(data_path) as archive:
            full_data = archive["data"]

        total_steps = full_data.shape[0]
        sequence_length = input_steps + output_steps

        # Split data temporally
        train_end = int(total_steps * train_split)
        val_end = train_end + int((total_steps - train_end) * val_split)
        segment = {
            'train': slice(None, train_end),
            'val': slice(train_end, val_end),
            'test': slice(val_end, None),
        }[mode]
        self.data = full_data[segment]

        # Number of windows that fit in the segment. Clamp at 0: a segment shorter than
        # one window would otherwise give a negative length, and len() raises on that.
        self.num_sequences = max(0, (len(self.data) - sequence_length) // stride + 1)

    def __len__(self):
        return self.num_sequences

    def __getitem__(self, idx):
        # Calculate start index
        start_idx = idx * self.stride

        # Extract input and output sequences
        input_seq = self.data[start_idx:start_idx + self.input_steps]
        output_seq = self.data[start_idx + self.input_steps:
                                start_idx + self.input_steps + self.output_steps]

        # Convert to torch tensors.
        # TODO: squeeze(0) removes the time axis only for 1-step windows, so the sample
        # rank depends on input_steps/output_steps; the model currently relies on this.
        input_tensor = torch.from_numpy(input_seq).float().squeeze(0)
        output_tensor = torch.from_numpy(output_seq).float().squeeze(0)

        return {'x': input_tensor,
                'y': output_tensor,
                't': start_idx,
                'tend': start_idx + self.input_steps + self.output_steps - 1}


class TrajDataset(Dataset):
    """
    Dataset whose items are whole trajectories, one ``.npz`` file (key ``"data"``) each.

    The ``.npz`` files found directly in ``data_path`` are sorted by path and split into
    train / val / test. Sorting makes the split reproducible, because ``os.listdir``
    order is arbitrary.

    Args:
        data_path: directory scanned (non-recursively) for ``.npz`` files.
        train_split: fraction of files used for training (default: 0.8)
        val_split: fraction of the remaining files used for validation (default: 0.5)
        mode: 'train', 'val', or 'test'

    Raises:
        ValueError: if ``mode`` is not one of the above.
    """
    # TODO: Look into LRU caching or memory mapping to speed this up
    def __init__(self, data_path,
                 train_split=0.8,
                 val_split=0.5,
                 mode='train'):
        self.data_path = data_path
        all_files = sorted(
            os.path.join(data_path, name)
            for name in os.listdir(data_path)
            if name.endswith(".npz")
        )

        # TODO: Remove this filtering out of derived files.
        # Match on the file name only, so a directory containing "derived" is not excluded.
        all_files = [path for path in all_files
                     if 'derived' not in os.path.basename(path)]

        self.num_files = len(all_files)

        # Split files
        train_end = int(self.num_files * train_split)
        val_end = train_end + int((self.num_files - train_end) * val_split)

        self.train_files = all_files[:train_end]
        self.val_files = all_files[train_end:val_end]
        self.test_files = all_files[val_end:]

        if mode == 'train':
            self.avail_files = self.train_files
        elif mode == 'val':
            self.avail_files = self.val_files
        elif mode == 'test':
            self.avail_files = self.test_files
        else:
            raise ValueError(f"mode must be 'train', 'val', or 'test', got {mode}")

    def __len__(self):
        # Only the files of the selected split are visible
        return len(self.avail_files)

    def __getitem__(self, idx):
        # Indexing the archive reads the array into memory, so the file can be closed
        with np.load(self.avail_files[idx]) as archive:
            return {'data': archive["data"]}


def create_dataloaders(data_path, input_steps=4, output_steps=1, 
                       batch_size=32, num_workers=4):
    """
    Build train, validation, and test DataLoaders over the temporal splits of one
    trajectory file (see ForwardPredictionDataset).

    Only the training loader shuffles; validation and test keep temporal order.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    loaders = []
    for split in ('train', 'val', 'test'):
        split_dataset = ForwardPredictionDataset(
            data_path, input_steps=input_steps, output_steps=output_steps, mode=split
        )
        loaders.append(DataLoader(split_dataset, batch_size=batch_size,
                                  shuffle=(split == 'train'), num_workers=num_workers))
    return tuple(loaders)


def create_traj_loader(data_path, batch_size=1, num_workers=1, mode='train'):
    """
    DataLoader over whole trajectories (one ``.npz`` file per item) found in ``data_path``.

    Args:
        data_path: directory scanned for ``.npz`` trajectory files (see TrajDataset).
        batch_size: trajectories per batch; values > 1 require all trajectories to have
            the same shape so the default collate can stack them.
        num_workers: number of DataLoader worker processes.
        mode: which TrajDataset file split to load ('train', 'val' or 'test').
            Only the training split is shuffled.
    """
    dataset = TrajDataset(data_path, mode=mode)
    return DataLoader(dataset, batch_size=batch_size,
                      shuffle=(mode == 'train'), num_workers=num_workers)


# Usage example: forward-prediction loaders over the first processed training file
data_path = processed_train_files[0]

train_loader, val_loader, test_loader = create_dataloaders(
    data_path,
    input_steps=1,    # Use 1 past timestep as input
    output_steps=1,   # Predict 1 future timestep
    batch_size=32
)

# Whole-trajectory loader over the .npz files in the working directory.
# NOTE: TrajDataset defaults to mode='train', so only its training share of the
# (non-derived) files found there is used (train_split=0.8, rounded down).
traj_loader = create_traj_loader(".")

# Test the loaders. With input_steps == output_steps == 1 the dataset squeezes the
# time axis, so both shapes print as (batch, channels, x_dim, y_dim).
for sample in train_loader:
    print(f"Input shape: {sample['x'].shape}")
    print(f"Target shape: {sample['y'].shape}")
    break

for traj in traj_loader:
    print(traj['data'].shape)  # (batch, time, channels, x_dim, y_dim)
    break
print(len(traj_loader))

### Define Model

In [None]:
# Fourier Neural Operator with 2 input channels (density, potential) and 2 output
# channels, trained below as a one-step map from the state at t to the state at t+1.
model = FNO(
    in_channels=2,
    out_channels=2,
    n_modes=(64, 64),       # Fourier modes kept per spatial dimension
    hidden_channels=512,
    # projection_channel_ratio=2,
).to(device)

# Count and display the number of parameters
n_params = count_model_params(model)
print(f"\nOur model has {n_params} parameters.")
sys.stdout.flush()

### Define optimizer, scheduler, and loss functions

In [None]:
# neuralop losses: L2 on function values, H1 additionally penalizes gradient mismatch.
# NOTE(review): the custom training loop uses MSE and never steps `scheduler`; these
# objects are only consumed by the (disabled) neuralop Trainer cell.
l2loss = LpLoss(d=2, p=2)
h1loss = H1Loss(d=2)
eval_losses = {"h1": h1loss, "l2": l2loss}
train_loss = h1loss

optimizer = AdamW(model.parameters(), lr=1e-2, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

### Training Loop

In [None]:
import random
from pathlib import Path

import torch
import torch.nn as nn

# Hyperparameters
max_epochs = 1000
max_rollout = 90            # longest autoregressive rollout (steps) reached by the curriculum
learning_rate = 1e-4        # Start lower based on your initial loss
clip_norm = 10.0            # Increased - be less aggressive
rollout_growth_every = 50   # grow the rollout every N epochs ...
rollout_growth_step = 5     # ... by this many steps
teacher_ratio_floor = 0.3   # teacher-forcing probability never drops below this
teacher_decay_epochs = 500  # epochs over which teacher forcing decays linearly to the floor

# Checkpoint
checkpoint_dir = Path("/scratch/10407/anthony50102/checkpoints")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# The optimizer cell built AdamW with lr=1e-2. Push `learning_rate` into it, otherwise
# the value above would never take effect.
# NOTE(review): `scheduler` is not stepped in this loop.
for param_group in optimizer.param_groups:
    param_group["lr"] = learning_rate

# State
current_rollout = 1
best_loss = float('inf')

# rollout() may have switched the model to eval mode in an earlier run; train in train mode.
model.train()

print("Starting training...")

for epoch in range(max_epochs):
    epoch_loss = 0.0
    n_steps = 0

    # Probability of feeding the ground truth (instead of the model's own prediction) as
    # the next input. It depends only on the epoch, so compute it once per epoch.
    teacher_ratio = max(teacher_ratio_floor, 1.0 - epoch / teacher_decay_epochs)

    for batch in traj_loader:
        traj = batch['data'].to(device).float()
        batch_size, traj_len, c, h, w = traj.shape

        # Sample a random start so that start + current_rollout is still a valid index
        max_start = traj_len - current_rollout - 1
        if max_start < 0:
            continue
        start = random.randint(0, max_start)

        # Rollout with teacher forcing. The next input is always detached, so gradients
        # flow only through each single-step prediction (no backprop through time).
        state = traj[:, start].clone()
        total_loss = 0.0

        for step in range(current_rollout):
            pred = model(state)
            target = traj[:, start + step + 1]
            total_loss += nn.functional.mse_loss(pred, target)

            if random.random() < teacher_ratio:
                state = target.detach()
            else:
                state = pred.detach()

        # Backprop (on the loss summed over the rollout)
        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()

        # Track the mean per-step loss so the reported value does not grow merely
        # because the rollout got longer.
        epoch_loss += total_loss.item() / current_rollout
        n_steps += 1

    if n_steps == 0:
        continue

    # Average loss per trajectory and per rollout step
    avg_loss = epoch_loss / n_steps

    # Save best checkpoint (best within the current rollout length; 'loss' is per step)
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'rollout': current_rollout,
            'epoch': epoch,
            'loss': best_loss
        }, checkpoint_dir / "best.pt")

    # Gradually increase rollout
    if epoch > 0 and epoch % rollout_growth_every == 0 and current_rollout < max_rollout:
        current_rollout = min(current_rollout + rollout_growth_step, max_rollout)
        # Errors accumulate over longer rollouts, so losses are not comparable across
        # rollout lengths. Restart best-tracking so best.pt follows the current stage
        # instead of staying frozen at the shortest-rollout model.
        best_loss = float('inf')
        print(f"→ Rollout increased to {current_rollout}")

    # Status
    if epoch % 10 == 0:
        print(f"Epoch {epoch+1:4d} | Loss: {avg_loss:.4g} | Rollout: {current_rollout:2d} | Best: {best_loss:.4g}")

print(f"\n✅ Done! Best loss: {best_loss:.4g}, Final rollout: {current_rollout}")

In [None]:
# Save the final model weights (the training loop separately keeps its best checkpoint).
# The timestamp avoids ':' (invalid in Windows filenames and read as a host separator by
# scp/rsync) and includes seconds, so two saves in the same minute do not overwrite.
now = datetime.datetime.now()
formatted_now = now.strftime("%m_%d_%Y-%H-%M-%S")
torch.save(model.state_dict(), f'state_model_{formatted_now}.pth')

In [None]:
# DISABLED alternative: train with neuralop's `Trainer` instead of the custom rollout
# loop. Kept for reference only; nothing in this cell executes.
# NOTE(review): if re-enabled as written, the `load_darcy_flow_small(...)` call rebinds
# `train_loader` to the Darcy-flow example data (replacing the HW2D loader), while
# `test_loaders={256:test_loader}` still passes the HW2D `test_loader`.
# trainer = Trainer(
#     model=model,
#     n_epochs=15,
#     device=device,
#     wandb_log=False,  # Disable Weights & Biases logging for this tutorial
#     eval_interval=5,  # Evaluate every 5 epochs
#     use_distributed=False,  # Single GPU/CPU training
#     verbose=True,  # Print training progress
# )

# train_loader, test_loaders, data_processor = load_darcy_flow_small(
#     n_train=1000,
#     batch_size=64,
#     n_tests=[100, 50],
#     test_resolutions=[16, 32],
#     test_batch_sizes=[32, 32],
# )

# trainer.train(
#     train_loader=train_loader,
#     test_loaders={256:test_loader},
#     optimizer=optimizer,
#     scheduler=scheduler,
#     regularizer=False,
#     training_loss=train_loss,
#     eval_losses=eval_losses,
# )

In [None]:
def rollout(model, data, device="cuda"):
    """
    Autoregressive rollout of a one-step forward model.

    Starting from the ground-truth initial condition ``data[0]``, the model is applied
    repeatedly and each prediction is fed back in as the next input.

    Args:
        model: forward prediction operator; called on a (1, C, X, Y) tensor and
            expected to return a tensor of the same shape.
        data: numpy array of shape (T, C, X, Y). Only ``data[0]`` is fed to the model;
            T, shape and dtype determine the output.
        device: device the input tensor is moved to (the model must already be there).

    Returns:
        numpy array with the shape and dtype of ``data``. Index 0 is ``data[0]``;
        index t > 0 is the model's t-step prediction. T == 0 gives an empty result.

    Raises:
        ValueError: if ``data`` is not 4-dimensional.

    The model runs in eval mode without gradients. Its previous train/eval mode is
    restored before returning, so calling this does not affect later training.
    """
    if data.ndim != 4:
        raise ValueError(f"data must have shape (T, C, X, Y), got {data.shape}")

    # Storage for reconstruction
    recon = np.zeros_like(data)
    n_steps = data.shape[0]
    if n_steps == 0:
        return recon

    # Initial condition (t=0)
    recon[0] = data[0]

    was_training = model.training
    model.eval()
    try:
        # Keep the batch dimension the model expects for the whole rollout
        current = torch.from_numpy(data[0]).float().to(device).unsqueeze(0)
        with torch.no_grad():
            for t in range(1, n_steps):
                current = model(current)  # (1, C, X, Y); fed back in on the next step
                recon[t] = current.squeeze(0).cpu().numpy()
    finally:
        model.train(was_training)

    return recon

In [None]:
# Load the full trajectory (unsplit) that the forward-prediction loaders were built from
with np.load(data_path) as archive:
    full_data = archive["data"]
print(f"Full data shape: {full_data.shape}")

# Recreate the temporal splits used by ForwardPredictionDataset's defaults (0.8 / 0.5)
total_steps = full_data.shape[0]
train_end = int(total_steps * 0.8)
val_end = train_end + int((total_steps - train_end) * 0.5)
train_data, test_data = full_data[:train_end], full_data[val_end:]

model = model.to(device)

# Autoregressive rollouts, each seeded with the first frame of its split
train_recon = rollout(model, train_data, device=device)
test_recon = rollout(model, test_data, device=device)

print("Train recon shape:", train_recon.shape)
print("Test  recon shape:", test_recon.shape)


In [None]:
# Side-by-side animation of ground truth vs. autoregressive reconstruction (channel 0)
if show_animations:
    print("\033[1mShowing animation\033[0m")
    # Subsample every 5th frame to keep the embedded animation small
    train_data_subbed = train_data[::5]
    train_recon_subbed = train_recon[::5]

    # Figure + axes
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))

    # Use the ground-truth range for BOTH panels. imshow fixes its colour limits when
    # the image is created, so autoscaling the reconstruction from its first frame
    # (identical to ground-truth frame 0) would clip later frames. Separate scales would
    # also hide amplitude drift in the reconstruction.
    t_vmin = train_data_subbed[:, 0].min()
    t_vmax = train_data_subbed[:, 0].max()

    img = ax[0].imshow(train_data_subbed[0, 0], vmin=t_vmin, vmax=t_vmax)
    img2 = ax[1].imshow(train_recon_subbed[0, 0], vmin=t_vmin, vmax=t_vmax)

    ax[0].set_title("Ground Truth")
    ax[1].set_title("Reconstruction")

    # Animation function
    def animate(frame):
        img.set_data(train_data_subbed[frame, 0])
        img2.set_data(train_recon_subbed[frame, 0])
        return [img, img2]

    animation = anim.FuncAnimation(
        fig,
        animate,
        frames=train_data_subbed.shape[0],
        interval=20,
        blit=True
    )

    # Close the figure so the notebook shows only the animation, not an extra static frame
    plt.close(fig)
    display(HTML(animation.to_jshtml()))
else:
    print("Not showing animations")