In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from pfc_data_generate import generate_data
import os
import time
import matplotlib.pyplot as plt
import numpy as np
# Use the GPU when one is present; everything below falls back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # torch.cuda.init() raises on machines without a usable CUDA runtime,
    # so only touch CUDA when it is actually available.
    torch.cuda.init()
    torch.cuda.empty_cache()

In [None]:
# Helper functions for printing and saving
def print_epoch_summary(epoch, epochs, data_loss, physics_loss, time_used):
    """Print a one-line progress report for a finished training epoch.

    Args:
        epoch: Zero-based index of the epoch that just finished (printed 1-based).
        epochs: Total number of epochs in the run.
        data_loss: Supervised (data-fit) loss value to report.
        physics_loss: Physics-consistency loss value to report.
        time_used: Wall-clock duration of the epoch in seconds.
    """
    fields = [
        f"Epoch [{epoch+1}/{epochs}]",
        f"Data Loss: {data_loss:.4e}",
        f"Physics Loss: {physics_loss:.4e}",
        f"Time: {time_used:.1f}s",
    ]
    print(" | ".join(fields))

def save_best_model(model, loss, save_dir="fno_model"):
    """Write ``model``'s state dict to ``<save_dir>/best_model.pth`` and log it.

    The save is unconditional: deciding whether ``loss`` is an improvement is
    the caller's responsibility. An existing file at that path is overwritten.

    Args:
        model: ``nn.Module`` whose ``state_dict()`` is saved.
        loss: Loss value shown in the log message.
        save_dir: Target directory; created (with parents) if missing.
    """
    target = os.path.join(save_dir, "best_model.pth")
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), target)
    print(f"Saved best model with loss {loss:.4e} to {target}")

# Data preparation pipeline
def prepare_data(gamma, grid_dim, L, dt, T, Nskip):
    """Run the external PDE solver and package consecutive snapshots as training pairs.

    Args:
        gamma: Model parameter forwarded to ``generate_data``.
        grid_dim: Number of grid points per spatial dimension.
        L: Domain side length (domain is L x L).
        dt: Solver time step.
        T: Total simulated time.
        Nskip: Solver steps between stored snapshots.

    Returns:
        Tuple ``(inputs, targets, k2, dealias, dt)``:
        - inputs: snapshots 0..N-2 as float32 on ``device``, shape [N-1, 1, H, W]
          (assuming ``generate_data`` returns snapshots as [N, H, W]).
        - targets: snapshots 1..N-1, i.e. ``targets[i]`` is the snapshot after ``inputs[i]``.
        - k2: solver's ``k2`` array (used as the Fourier-space Laplacian factor) as float32 on ``device``.
        - dealias: solver's dealiasing mask as float32 on ``device``.
        - dt: the ``dt`` argument, returned unchanged.
        The time array returned by ``generate_data`` is discarded.

    NOTE(review): consecutive snapshots are presumably ``Nskip`` solver steps
    apart while ``dt`` is returned unchanged; confirm this is the spacing the
    physics loss should use when ``Nskip > 1``.
    """
    snapshots, k2_arr, dealias_arr, _ = generate_data(gamma, grid_dim, L, dt, T, Nskip)

    def as_device_float(arr):
        # Copy into a float32 tensor on the module-level `device`.
        return torch.tensor(arr).float().to(device)

    # Add a singleton channel axis: [N, H, W] -> [N, 1, H, W]
    frames = as_device_float(snapshots).unsqueeze(1)
    k2_t = as_device_float(k2_arr)
    dealias_t = as_device_float(dealias_arr)

    # Pair each snapshot with its successor (one-step-ahead targets).
    inputs, targets = frames[:-1], frames[1:]
    return inputs, targets, k2_t, dealias_t, dt


In [None]:
class SpectralConv2d(nn.Module):
    """2D spectral convolution (the Fourier layer of an FNO).

    The input is transformed with a real 2D FFT. A learnable complex linear map
    (``in_channels`` -> ``out_channels``) is applied independently to each
    retained Fourier mode, all other modes are zeroed, and the result is
    transformed back with an inverse real FFT at the original spatial size.

    Retained modes:
    - along the height axis (a full FFT axis), the first ``modes1`` rows
      (non-negative frequencies) and the last ``modes1`` rows (negative
      frequencies);
    - along the width axis (``rfft`` half spectrum), the first ``modes2`` columns.
    """
    def __init__(self, in_channels, out_channels, modes1, modes2):
        super().__init__()
        self.modes1 = modes1  # modes kept at EACH end of the height (full-FFT) axis
        self.modes2 = modes2  # modes kept along the width (rfft half-spectrum) axis

        # Weights are torch.rand samples shrunk by 1 / (in_channels * out_channels)
        # so the initial layer output stays small.
        self.scale = 1 / (in_channels * out_channels)

        # weights1 acts on the low non-negative height frequencies,
        # weights2 on the matching negative ones (the tail of the FFT axis).
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels,
                                    modes1, modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels,
                                    modes1, modes2, dtype=torch.cfloat))

    def forward(self, x):
        """Apply the spectral convolution.

        Args:
            x: Real tensor of shape [B, in_channels, H, W].

        Returns:
            Real tensor of shape [B, out_channels, H, W].

        Raises:
            ValueError: if ``modes1 < 1``, if ``2 * modes1 > H`` (the positive- and
                negative-frequency blocks would overlap and one would silently
                overwrite the other), or if ``modes2 > W // 2 + 1`` (more modes
                than the half spectrum contains).
        """
        B, _, H, W = x.shape
        m1, m2 = self.modes1, self.modes2
        if m1 < 1 or 2 * m1 > H or m2 > W // 2 + 1:
            raise ValueError(
                f"SpectralConv2d with modes=({m1}, {m2}) needs 1 <= modes1 <= H // 2 "
                f"and modes2 <= W // 2 + 1; got input H={H}, W={W}"
            )

        # Real FFT over the two spatial dims: [B, C_in, H, W // 2 + 1]
        x_ft = torch.fft.rfft2(x)

        # Modes that are not explicitly written below stay zero (truncation).
        out_ft = torch.zeros(B, self.weights1.shape[1], H, W // 2 + 1,
                             dtype=torch.cfloat, device=x.device)

        # Non-negative height frequencies
        out_ft[:, :, :m1, :m2] = torch.einsum(
            "bixy,ioxy->boxy", x_ft[:, :, :m1, :m2], self.weights1)

        # Negative height frequencies (last m1 rows of the full-FFT axis)
        out_ft[:, :, -m1:, :m2] = torch.einsum(
            "bixy,ioxy->boxy", x_ft[:, :, -m1:, :m2], self.weights2)

        # Back to physical space at the original resolution
        return torch.fft.irfft2(out_ft, s=(H, W))

class FNO2d(nn.Module):
    """2D Fourier Neural Operator mapping one scalar field to another.

    Pipeline:
    1. a pointwise ``nn.Linear`` lifts the single input channel to ``width`` channels;
    2. four blocks, each computing ``silu(spectral_conv(h) + conv1x1(h))``, where
       the spectral conv is global and the 1x1 conv mixes channels locally;
    3. a pointwise two-layer head (``width`` -> 256 -> 1) with SiLU in between.

    Expects input of shape [B, 1, H, W] and returns [B, 1, H, W].
    """
    def __init__(self, modes1=24, modes2=24, width=64):
        super().__init__()
        self.modes1 = modes1  # Fourier modes kept along the first spatial dimension
        self.modes2 = modes2  # Fourier modes kept along the second spatial dimension
        self.width = width    # channel width of the hidden representation

        # Lift: 1 input channel -> `width` features, applied at every grid point
        self.fc0 = nn.Linear(1, self.width)

        # Attribute names (conv0..conv3, w0..w3) are the state_dict keys of saved
        # checkpoints, so they are assigned explicitly rather than via a ModuleList.
        self.conv0, self.conv1, self.conv2, self.conv3 = [
            SpectralConv2d(width, width, modes1, modes2) for _ in range(4)
        ]
        self.w0, self.w1, self.w2, self.w3 = [
            nn.Conv2d(width, width, 1) for _ in range(4)
        ]

        # Projection head back to a single physical channel
        self.fc1 = nn.Linear(width, 256)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        # Channels-last so the nn.Linear lift acts on the channel axis:
        # [B, 1, H, W] -> [B, H, W, 1] -> [B, H, W, width] -> [B, width, H, W]
        h = self.fc0(x.movedim(1, -1)).movedim(-1, 1)

        spectral_layers = (self.conv0, self.conv1, self.conv2, self.conv3)
        pointwise_layers = (self.w0, self.w1, self.w2, self.w3)
        for spectral, pointwise in zip(spectral_layers, pointwise_layers):
            h = F.silu(spectral(h) + pointwise(h))

        # Head: [B, width, H, W] -> [B, H, W, width] -> [B, H, W, 1] -> [B, 1, H, W]
        h = h.movedim(1, -1)
        h = self.fc2(F.silu(self.fc1(h)))
        return h.movedim(-1, 1)

def load_model_weights(model, weight_path, map_location=None):
    """
    Loads a saved state dict into ``model`` and sets the model to eval mode.

    Checkpoints saved from an ``nn.DataParallel``/DDP-wrapped model have keys
    starting with ``'module.'``; that leading prefix is stripped so they load
    into the bare model. Only the leading prefix is removed, so keys that merely
    contain ``'module.'`` (e.g. ``'submodule.weight'``) are left intact.

    Args:
        model (nn.Module): The model to load weights into (modified in place).
        weight_path (str): Path to the saved state dict (.pth).
        map_location: Device/mapping passed to ``torch.load``. Defaults to the
            module-level ``device``.

    Returns:
        nn.Module: The same ``model`` instance, with loaded weights, in eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` when keys or shapes do not match.
    """
    prefix = 'module.'
    target = device if map_location is None else map_location
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (torch>=1.13 also accepts weights_only=True for plain state dicts).
    state_dict = torch.load(weight_path, map_location=target)
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [None]:
def Noperator_func(u, k2, dealias):
    """
    Fourier-space nonlinear term ``-k2 * FFT(u**3)``, dealiased.

    With ``k2`` the squared wavenumber magnitude this is the spectral form of
    ``laplacian(u**3)``. Together with the linear operator built in ``train_fno``
    it appears to be the conserved phase-field-crystal equation rather than
    Kuramoto–Sivashinsky (TODO confirm against the data generator).

    Args:
        u: Real field(s); the 2D FFT is taken over the last two dims.
        k2: Array broadcastable to ``u``'s spectral shape.
        dealias: Mask multiplied into the result (zeroes aliased modes).

    Returns:
        Complex tensor with ``u``'s spatial shape.
    """
    cubic_hat = torch.fft.fft2(u**3)
    return -(k2 * cubic_hat) * dealias

def compute_physics_loss(u_prev, u_pred, dt, k2, dealias, lineardenominator_hat):
    """
    MSE between the model's prediction and one semi-implicit solver step.

    The reference next state is
        u_next = Re IFFT[(FFT(u_prev) + dt * N(u_prev)) * lineardenominator_hat]
    where ``N`` is ``Noperator_func`` (nonlinear term, treated explicitly) and
    ``lineardenominator_hat`` is the implicit linear factor
    (``1 / (1 - dt * L_operator)`` as built in ``train_fno``).

    Args:
        u_prev: Current fields; FFTs are taken over the last two dims
            (called with shape [B, H, W] from ``train_fno``).
        u_pred: Model prediction for the next step, same shape as ``u_prev``.
        dt: Time step of the reference solver step.
        k2, dealias: Forwarded to ``Noperator_func``.
        lineardenominator_hat: Fourier-space factor of the implicit linear update.

    Returns:
        Scalar tensor: ``F.mse_loss`` between the reference step and ``u_pred``.
    """
    explicit_hat = torch.fft.fft2(u_prev) + dt * Noperator_func(u_prev, k2, dealias)
    u_physics = torch.fft.ifft2(explicit_hat * lineardenominator_hat).real
    return F.mse_loss(u_physics, u_pred)

def train_fno(
    gamma=0.25, grid_dim=64, L=16*np.pi, dt=0.1, T=1200,
    Nskip=1, epochs=50, batch_size=24, modes=24, width=64,
    physics_weight=0.0, lr=1e-3, lr_patience=80
):
    """
    Trains a Fourier Neural Operator to map each PDE snapshot to the next one.

    The optimised objective is the MSE against the next stored snapshot, plus
    ``physics_weight`` times a physics-consistency loss that compares the
    prediction with one semi-implicit solver step (``compute_physics_loss``).
    The physics loss is always computed and reported. When ``physics_weight`` is
    0 (the default) it is computed without gradient tracking and does not
    affect training.

    After every epoch, ``ReduceLROnPlateau`` is stepped on the mean training
    loss. The weights are checkpointed via ``save_best_model`` whenever that
    loss improves. There is no separate validation split.

    Args:
        gamma, grid_dim, L, dt, T, Nskip: Forwarded to ``prepare_data``. ``gamma``
            and the returned ``dt`` also define the linear operator of the
            physics step.
        epochs: Number of passes over the data.
        batch_size: Mini-batch size (samples are reshuffled each epoch).
        modes: Fourier modes kept per spatial dimension in each spectral layer.
        width: Channel width of the FNO.
        physics_weight: Weight of the physics loss in the optimised objective.
        lr: Initial AdamW learning rate.
        lr_patience: ``ReduceLROnPlateau`` patience in epochs. NOTE: the default
            (80) exceeds the default ``epochs`` (50), so by default the learning
            rate never decays.

    Returns:
        Tuple ``(model, inputs, targets, loader)``. ``model`` holds the weights
        from the LAST epoch; the best-loss weights are on disk, not returned.

    Raises:
        ValueError: if the prepared dataset yields no training pairs.
    """

    # === Step 1: Prepare training data ===
    inputs, targets, k2, dealias, dt = prepare_data(gamma, grid_dim, L, dt, T, Nskip)
    loader = DataLoader(TensorDataset(inputs, targets), batch_size=batch_size, shuffle=True)
    n_batches = len(loader)
    if n_batches == 0:
        raise ValueError(
            "prepare_data produced no (input, target) pairs; "
            "increase T or reduce dt/Nskip so at least two snapshots are stored."
        )

    # Fourier-space linear operator of the PDE and the implicit-step factor
    # 1 / (1 - dt * L) used by the physics-consistency loss.
    # NOTE(review): uses dt as returned by prepare_data; confirm it matches the
    # snapshot spacing when Nskip > 1.
    L_operator = -k2 * (k2**2 - 2*k2 + 1 - gamma)
    lineardenominator_hat = 1 / (1 - dt * L_operator)

    # === Step 2: Initialize model and optimizer ===
    model = FNO2d(modes1=modes, modes2=modes, width=width).to(device)
    model.train()

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=lr_patience)
    best_loss = float('inf')
    use_physics = physics_weight != 0

    # === Step 3: Training loop ===
    for epoch in range(epochs):
        epoch_loss = 0.0
        data_loss = 0.0
        physics_loss = 0.0
        start_time = time.time()

        for x, y in loader:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            pred = model(x)

            # Supervised MSE against the next snapshot
            loss_data = F.mse_loss(pred, y)

            # Physics-consistency loss. When it is not part of the objective it
            # is only monitored, so no autograd graph is built for it.
            with torch.set_grad_enabled(use_physics):
                loss_physics = compute_physics_loss(
                    x.squeeze(1), pred.squeeze(1), dt, k2, dealias, lineardenominator_hat
                )

            loss = loss_data + physics_weight * loss_physics if use_physics else loss_data

            loss.backward()
            optimizer.step()

            # Track losses
            epoch_loss += loss.item()
            data_loss += loss_data.item()
            physics_loss += loss_physics.item()

        # The mean TRAINING loss drives both LR scheduling and checkpointing.
        avg_loss = epoch_loss / n_batches
        scheduler.step(avg_loss)

        if avg_loss < best_loss:
            best_loss = avg_loss
            save_best_model(model, best_loss)

        print_epoch_summary(
            epoch, epochs,
            data_loss / n_batches,
            physics_loss / n_batches,
            time.time() - start_time
        )

    return model, inputs, targets, loader


In [None]:
# Short training run. `modes` and `width` must match the FNO2d built in the
# inference cell (modes1=modes2=16, width=32) so the saved checkpoint loads.
trained_model, inputs, targets, loader = train_fno(
    epochs=10,
    batch_size=80,
    modes=16,
    width=32,
)

In [None]:
def autoregressive_prediction(model, initial_condition, steps=10):
    """
    Rolls the model forward ``steps`` times, feeding each output back as input.

    Args:
        model (torch.nn.Module): One-step model mapping [B, 1, H, W] -> [B, 1, H, W].
            It is called as-is, so put it in eval mode first if that matters.
        initial_condition (ndarray or tensor): Starting fields, shape [B, H, W].
        steps (int): Number of autoregressive steps to take.

    Returns:
        Tensor: Predictions stacked along a new leading axis, shape
        [steps, B, 1, H, W]; entry ``k`` is the state after ``k + 1`` steps.
        For ``steps <= 0`` the result is empty along dim 0.
    """
    # Run on the model's own device. Parameter-less models fall back to the
    # module-level `device`.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    # as_tensor accepts both ndarrays and tensors without the copy-construct warning.
    current_state = torch.as_tensor(
        initial_condition, dtype=torch.float32, device=target_device
    ).unsqueeze(1)  # [B, 1, H, W]

    predictions = []
    with torch.no_grad():
        for _ in range(steps):
            current_state = model(current_state)  # output becomes next input
            predictions.append(current_state)

    if not predictions:
        return current_state.new_empty((0, *current_state.shape))
    # stack (not concatenate) keeps the documented [steps, B, 1, H, W] layout
    return torch.stack(predictions, dim=0)

def efficient_nth_step_prediction(model, initial_condition, n_steps):
    """
    Predicts the state after ``n_steps`` model applications, keeping only the latest state.

    Args:
        model (torch.nn.Module): One-step model mapping [B, 1, H, W] -> [B, 1, H, W].
            It is called as-is, so put it in eval mode first if that matters.
        initial_condition (ndarray or tensor): Starting fields, shape [B, H, W].
        n_steps (int): Number of steps to evolve the system (0 returns the input).

    Returns:
        ndarray: State of the FIRST batch sample after ``n_steps``, shape [H, W].
        The array never aliases ``initial_condition``.
    """
    # Run on the model's own device. Parameter-less models fall back to the
    # module-level `device`.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    current_state = torch.as_tensor(
        initial_condition, dtype=torch.float32, device=target_device
    ).unsqueeze(1)  # [B, 1, H, W]

    with torch.no_grad():
        for _ in range(n_steps):
            current_state = model(current_state)

    # as_tensor may share memory with a float32 ndarray input; copy so that
    # the n_steps == 0 result is independent of the caller's array.
    return current_state.squeeze(1)[0].cpu().numpy().copy()

def compute_mse_per_sample(tensor1, tensor2):
    """
    Mean squared error of each batch element, averaged over channel and spatial axes.

    Args:
        tensor1, tensor2 (Tensor): Rank-4 tensors [B, C, H, W] (broadcastable to each other).

    Returns:
        Tensor: Shape [B], one MSE value per sample.
    """
    return (tensor1 - tensor2).pow(2).mean(dim=(1, 2, 3))

def plot_mse_progression(mse_tensor, dt=None, save_path=None):
    """Plot per-step MSE as a line chart and show it.

    Args:
        mse_tensor: Tensor of MSE values, one per prediction step (expected 1-D).
        dt: Optional step size. If truthy, the x-axis is ``dt * step`` labelled
            'Physical Time'; otherwise it is the raw step index.
        save_path: If given, the figure is also saved there at 300 dpi.
    """
    values = mse_tensor.detach().cpu().numpy()
    step_index = np.arange(len(values))
    use_time_axis = bool(dt)
    x_axis = dt * step_index if use_time_axis else step_index

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(x_axis, values, color='blue', linewidth=1)
    ax.set_xlabel('Physical Time' if use_time_axis else 'Prediction Step', fontsize=12)
    ax.set_ylabel('Mean Squared Error', fontsize=12)
    ax.set_title("Temporal Error Progression", fontsize=14)
    ax.grid(True, linestyle='--', alpha=0.6)

    if save_path:
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
        print(f"Error plot saved to {save_path}")
    plt.show()

In [None]:
# Load the trained model for inference. Architecture arguments must match the
# training call (modes=16, width=32) or the state dict will not load.
test_model = FNO2d(modes1=16, modes2=16, width=32).to(device)
test_model = load_model_weights(test_model, "fno_model/best_model.pth")  # also sets eval mode

# One-step (teacher-forced) predictions: each ground-truth frame inputs[i]
# is mapped to a prediction of the following frame.
initial_data = inputs[0]  # kept for inspection / as a rollout seed
num_frames = len(inputs)
preds = torch.empty_like(inputs)  # [N, 1, H, W], same device/dtype as inputs

# Batched inference: far fewer forward passes than frame-by-frame and no
# GPU->CPU->GPU round-trip, since results are written straight into `preds`.
inference_batch_size = 64
with torch.no_grad():
    for start in range(0, num_frames, inference_batch_size):
        stop = min(start + inference_batch_size, num_frames)
        batch = inputs[start:stop].to(device)
        preds[start:stop] = test_model(batch).to(preds.device)
