# I. Set ups for Data
- Create data loader


In [1]:
import torch
from pathlib import Path

#device agnostic
device = "cuda" if torch.cuda.is_available() else "cpu"

#Hyper Parameters
RAND_SEED = 420
torch.manual_seed(RAND_SEED)

<torch._C.Generator at 0x7f61fc12d990>

In [2]:
print(device)

cuda


## 1. Data Handling Process

### 1.1 Set up for data handling process

In [3]:
#Data paths
data_folder_path = Path("../dataset")
data_path = data_folder_path / "cfd_bench/cylinder"
bc_path = data_path / "bc"
geo_path = data_path / "geo"
prop_path = data_path / "prop"

CHANNELS =2 #(vx,vy)
GRID_SIZE =64

### 1.2 Compute  u and v attributes 

In [4]:
import numpy as np
import os
from tqdm import tqdm

def compute_global_stats(data_dir):
    """
    Computes global mean and std for u and v incrementally 
    to avoid Out-Of-Memory errors.
    """
    print(f"üîç Scanning {data_dir}...")
    
    # 1. Get list of valid case directories
    file_paths = sorted([
        os.path.join(data_dir, f) for f in os.listdir(data_dir) 
        if os.path.isdir(os.path.join(data_dir, f))
    ])
    
    if not file_paths:
        print("‚ùå No data found.")
        return None

    # 2. Initialize accumulators
    # We use float64 to prevent overflow during accumulation
    u_sum = 0.0
    u_sq_sum = 0.0
    v_sum = 0.0
    v_sq_sum = 0.0
    total_elements = 0

    print(f"üöÄ Computing stats over {len(file_paths)} simulation files...")

    # 3. Iterate file by file (Incremental Pass)
    for case_path in tqdm(file_paths):
        u_path = os.path.join(case_path, "u.npy")
        v_path = os.path.join(case_path, "v.npy")
        
        # Handle potential missing files
        if not os.path.exists(u_path) or not os.path.exists(v_path):
            continue
            
        # Load one file into memory
        # We don't strictly need mmap here since we process one by one, 
        # but it helps if individual files are massive.
        u = np.load(u_path).astype(np.float64)
        v = np.load(v_path).astype(np.float64)
        
        # Accumulate stats
        # N = total pixels per file (Time * H * W)
        n_pixels = u.size 
        total_elements += n_pixels
        
        u_sum += np.sum(u)
        u_sq_sum += np.sum(u ** 2)
        
        v_sum += np.sum(v)
        v_sq_sum += np.sum(v ** 2)

    # 4. Final Calculation
    # Mean = Sum / N
    u_mean = u_sum / total_elements
    v_mean = v_sum / total_elements
    
    # Std = sqrt( E[X^2] - (E[X])^2 )
    # Variance = (Sum_Sq / N) - (Mean^2)
    u_var = (u_sq_sum / total_elements) - (u_mean ** 2)
    v_var = (v_sq_sum / total_elements) - (v_mean ** 2)
    
    u_std = np.sqrt(u_var)
    v_std = np.sqrt(v_var)

    print("\n‚úÖ Global Stats Computed:")
    print(f"u_mean: {u_mean:.4f}, u_std: {u_std:.4f}")
    print(f"v_mean: {v_mean:.4f}, v_std: {v_std:.4f}")
    
    return u_mean, u_std, v_mean, v_std

# Usage
# stats = compute_global_stats("path/to/dataset")

In [5]:
u_mean,v_mean,u_std,v_std  = compute_global_stats(prop_path)

üîç Scanning ../dataset/cfd_bench/cylinder/prop...
üöÄ Computing stats over 116 simulation files...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 116/116 [00:06<00:00, 18.81it/s]


‚úÖ Global Stats Computed:
u_mean: 0.9953, u_std: 0.3867
v_mean: 0.9739, v_std: 0.4182





### 1.2 Datasets classes

In [6]:
import numpy as np
from torch.utils.data import Dataset
import torch
import os
import json

# Constants
GRID_SIZE = 64
CHANNELS = 2  # u, v

class CylinderFlow_SlidingWindow(Dataset):
    """
    Sliding Window Dataset for CFDBench Cylinder Flow.
    - Extracts ALL valid overlapping windows from the timeline.
    - Drastically increases dataset size (Data Augmentation by Time-Shifting).
    """
    def __init__(self, file_paths, u_mean, v_mean, u_std, v_std, t_in=10, t_out=5, stride=1):
        """
        stride (int): How many steps to shift the window. 
                      stride=1 means maximum overlap (maximum data).
                      stride=10 means no overlap (like your previous version).
        """
        self.t_in = t_in
        self.t_out = t_out
        self.total_sequence_len = t_in + t_out 
        self.stride = stride
        
        # Normalization Stats
        self.u_mean, self.v_mean = u_mean, v_mean
        self.u_std, self.v_std = u_std, v_std
        
        self.file_paths = file_paths
        if not self.file_paths:
            raise ValueError("No file paths provided.")

        # 1. Inspect first case to calculate dataset size
        first_case = self.file_paths[0]
        self.ext = '.npy' if os.path.exists(os.path.join(first_case, 'u.npy')) else '.np'
        
        # Load shape just once to calculate indices
        sample_u = np.load(os.path.join(first_case, f'u{self.ext}'), mmap_mode='r')
        self.total_timesteps_per_sim = sample_u.shape[0] # e.g., 600
        
        # Calculate how many valid windows fit in one simulation
        # Example: 600 steps, need 15 contiguous. Valid starts are 0 to 585.
        # Formula: (Total - Sequence) // Stride + 1
        self.windows_per_sim = (self.total_timesteps_per_sim - self.total_sequence_len) // self.stride + 1
        
        self.length = len(self.file_paths) * self.windows_per_sim
        
        # 2. Cache Physics Parameters (Optimization)
        # This prevents opening JSON files 50,000 times per epoch
        self.case_params = []
        for path in self.file_paths:
            self.case_params.append(self._load_params(path))
            
        print(f"‚úÖ Sliding Window Dataset Ready:")
        print(f"   - Files: {len(self.file_paths)}")
        print(f"   - Windows per file: {self.windows_per_sim} (Stride={stride})")
        print(f"   - Total Samples: {self.length}")

    def _load_params(self, case_path):
        json_path = os.path.join(case_path, "params.json")
        if not os.path.exists(json_path):
             json_path = os.path.join(case_path, "case.json")
        
        # Defaults
        params = {"viscosity": 0.001, "density": 1.0, "x_max": 0.16, "x_min": -0.06, "y_max": 0.06, "y_min": -0.06}
        if os.path.exists(json_path):
            with open(json_path, "r") as f:
                params.update(json.load(f))
        return params

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # 1. Map global index to (Simulation Index, Window Index)
        sim_idx = idx // self.windows_per_sim
        window_idx = idx % self.windows_per_sim
        
        case_path = self.file_paths[sim_idx]
        params_map = self.case_params[sim_idx] # Use cached params
        
        # 2. Calculate Time Indices
        start_time = window_idx * self.stride
        end_in = start_time + self.t_in
        end_out = end_in + self.t_out

        # 3. Load Data (mmap is fast for slicing)
        u_mmap = np.load(os.path.join(case_path, f'u{self.ext}'), mmap_mode='r')
        v_mmap = np.load(os.path.join(case_path, f'v{self.ext}'), mmap_mode='r')
        
        # 4. Slice
        u_in = u_mmap[start_time : end_in]
        v_in = v_mmap[start_time : end_in]
        u_out = u_mmap[end_in : end_out]
        v_out = v_mmap[end_in : end_out]

        # 5. Normalize (On-the-fly)
        u_in = (u_in - self.u_mean) / self.u_std        
        v_in = (v_in - self.v_mean) / self.v_std        
        u_out = (u_out - self.u_mean) / self.u_std        
        v_out = (v_out - self.u_mean) / self.u_std        

        # 6. Format to Tensor (C, H, W) for FNO
        x_np = np.stack([u_in, v_in], axis=-1)
        y_np = np.stack([u_out, v_out], axis=-1)
        
        # Permute: (Time, H, W, C) -> (H, W, Time*C)
        # Note: Standard FNO usually takes (x, y, c), so we flatten time into channels
        x_tensor = torch.from_numpy(x_np).permute(1, 2, 0, 3).reshape(GRID_SIZE, GRID_SIZE, -1)
        y_tensor = torch.from_numpy(y_np).permute(1, 2, 0, 3).reshape(GRID_SIZE, GRID_SIZE, -1)
        
        return x_tensor.float(), y_tensor.float(), params_map

print("‚úÖ CylinderFlow_SlidingWindow class defined.")

‚úÖ CylinderFlow_SlidingWindow class defined.


# II. Set up for model
- Declare up model class
- Set up loss functions
- Train - test function


## 1.Models set ups

### 1.1 FNO2D layer class

In [7]:
import torch
from torch import nn
import torch.nn.functional as F
class FNO_Layer2D(nn.Module):
  """ A class act as one Fourier Layer as described in the original paper"""
  def __init__(self,in_chanels, out_chanels,mode_x,mode_y):
    super(FNO_Layer2D, self).__init__()
    self.in_chanels = in_chanels
    self.out_chanels = out_chanels
    self.mode_x = mode_x
    self.mode_y = mode_y
    #Scaler to scale the parameters
    self.scalar = (1/(in_chanels * out_chanels))
    self.weight_x = nn.Parameter(
        self.scalar * torch.rand(in_chanels,out_chanels,mode_x, mode_y, dtype = torch.cfloat)
    )
    self.weight_y = nn.Parameter(
        self.scalar * torch.rand(in_chanels,out_chanels,mode_x, mode_y,dtype = torch.cfloat)
    )

  def forward(self, x):
    #Size 
    batch,_,H,W = x.shape
      
    # FFT (2D fourier transform)
    x_ft = torch.fft.rfft2(x)
      
    ## Fourier layer
    out_ft = torch.zeros(
        batch, self.out_chanels,H, W //2 +1,
        dtype = torch.cfloat,
        device = device
    )

    out_ft[:,:,:self.mode_x,:self.mode_y] = torch.einsum(
        "bixy,ioxy->boxy",
        x_ft[:,:,:self.mode_x,:self.mode_y],
        self.weight_x
    )

    out_ft[:,:,-self.mode_x:,:self.mode_y] = torch.einsum(
        "bixy,ioxy->boxy",
        x_ft[:,:,-self.mode_x:, :self.mode_y],
        self.weight_y
    )
    ## Inverse FFT
    x = torch.fft.irfft2(out_ft , s=(H,W))
    return x
print("FNO_Layer2D class defined.")

FNO_Layer2D class defined.


### 1.2 FNO2D class

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FNO2D(nn.Module):
    def __init__(self, width, mode_x, mode_y, in_chanels=1, out_chanels=200):
        """ Default is 1 time step -> 200 timestep rest"""
        super(FNO2D, self).__init__()
        self.in_chanels = in_chanels + 2 # add one more for the grid
        self.out_chanels = out_chanels
        self.width = width
        self.mode_x = mode_x
        self.mode_y = mode_y
        
        # P layer (lift the input channel up):
        self.fc0 = nn.Linear(self.in_chanels, self.width)

        # T block: 4 Fourier layers
        self.conv0 = FNO_Layer2D(width, width, mode_x, mode_y)
        self.w0 = nn.Conv2d(self.width, self.width, 1)

        self.conv1 = FNO_Layer2D(width, width, mode_x, mode_y)
        self.w1 = nn.Conv2d(self.width, self.width, 1)

        self.conv2 = FNO_Layer2D(width, width, mode_x, mode_y)
        self.w2 = nn.Conv2d(self.width, self.width, 1)

        self.conv3 = FNO_Layer2D(width, width, mode_x, mode_y)
        self.w3 = nn.Conv2d(self.width, self.width, 1)

        # Q layer (project down to output_chanels)
        self.fc1 = nn.Linear(width, width * 2)
        self.fc2 = nn.Linear(width * 2, self.out_chanels)

    def forward(self, x, params_map=None):
        # Get the grid information
        # FIX: Removed syntax error (positional arg after keyword arg)
        grid = get_grid(x.shape, params_map, x.device) 

        # concat to the input
        x = torch.cat((x, grid), dim=-1) # (batch,H,W,chanels)

        ## Through P layer:
        x = self.fc0(x)

        x = x.permute(0, 3, 1, 2) # (batch, chanel , H , W)

        ## Through T
        # 1
        x1 = self.conv0(x)
        x2 = self.w0(x)
        x = x1 + x2
        x = F.gelu(x)

        # 2
        x1 = self.conv1(x)
        x2 = self.w1(x)
        x = x1 + x2
        x = F.gelu(x)

        # 3
        x1 = self.conv2(x)
        x2 = self.w2(x)
        x = x1 + x2
        x = F.gelu(x)

        # 4
        x1 = self.conv3(x)
        x2 = self.w3(x)
        x = x1 + x2

        # Switch back
        x = x.permute(0, 2, 3, 1) # (batch,H,W,chanels)

        ## Through Q
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

def get_grid(shape, params_map, device): 
    batch, H, W = shape[0], shape[1], shape[2]
    
    # --- ROBUST PARAMETER EXTRACTION ---
    # The DataLoader returns Batched Tensors (e.g., size 32).
    # We need scalars for linspace. We assume the domain is constant for the batch.
    if params_map and 'x_min' in params_map:
        # Take the first element and convert to python float
        # This prevents "RuntimeError: The truth value of an array..."
        x_min = params_map['x_min'][0].item() if torch.is_tensor(params_map['x_min']) else params_map['x_min']
        x_max = params_map['x_max'][0].item() if torch.is_tensor(params_map['x_max']) else params_map['x_max']
        y_min = params_map['y_min'][0].item() if torch.is_tensor(params_map['y_min']) else params_map['y_min']
        y_max = params_map['y_max'][0].item() if torch.is_tensor(params_map['y_max']) else params_map['y_max']
    else:
        # Fallback defaults
        x_min, x_max, y_min, y_max = -1.0, 1.0, -1.0, 1.0

    # --- ASPECT RATIO PRESERVATION STRATEGY ---
    # Raw coordinates (e.g., 0.16) are very small for neural networks.
    # It is better to normalize the LONGEST dimension to [0, 1] 
    # and scale the other dimension to preserve aspect ratio.
    
    width_x = x_max - x_min
    height_y = y_max - y_min
    
    # Prevent division by zero
    if width_x == 0: width_x = 1.0
    
    aspect_ratio = height_y / width_x

    # X goes from 0 to 1 (Normalized)
    grid_x = torch.linspace(0, 1, H, device=device, dtype=torch.float)
    # Y goes from 0 to aspect_ratio (e.g., 0.375)
    grid_y = torch.linspace(0, aspect_ratio, W, device=device, dtype=torch.float)

    # Reshape for broadcasting
    grid_x = grid_x.view(1, H, 1, 1).expand(batch, -1, W, -1)
    grid_y = grid_y.view(1, 1, W, 1).expand(batch, H, -1, -1)

    # Concatenate
    grid = torch.cat((grid_x, grid_y), dim=-1)
    return grid 

print("FNO2D class defined.")

FNO2D class defined.


## 2.Loss functions


### 2.2 Physic loss for NavierStoke2D

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class NavierStoke2DLoss(nn.Module):
    def __init__(self, dt=0.001, size_average=True,u_mean = u_mean , v_mean = v_mean , u_std = u_std, v_std =v_std):
        super(NavierStoke2DLoss, self).__init__()
        
        self.dt = dt
        self.u_mean = u_mean
        self.v_mean = v_mean
        self.u_std = u_std
        self.v_std = v_std
        
        self.data_loss_fn = torch.nn.MSELoss(reduction='mean')
        print(f"üìâ Loss Initialized")

    def compute_derivatives(self, u, v):
        """
        Vectorized derivatives for sequence (B, H, W, T)
        Maintains (B, T, H, W) structure to allow broadcasting with (B) parameters.
        """
        # Input shape: (B, H, W, T)
        b, h, w, t = u.shape
        
        # Permute to (B, T, H, W) - DO NOT FLATTEN B and T
        u_perm = u.permute(0, 3, 1, 2)
        v_perm = v.permute(0, 3, 1, 2)

        # Pad spatial dims (H and W). F.pad works on the last dimensions.
        # Padding (left, right, top, bottom)
        u_pad = F.pad(u_perm, (1, 1, 1, 1), mode='replicate')
        v_pad = F.pad(v_perm, (1, 1, 1, 1), mode='replicate')

        # Reshape dx/dy for broadcasting: (B) -> (B, 1, 1, 1)
        dx_view = self.dx.view(b, 1, 1, 1)
        dy_view = self.dy.view(b, 1, 1, 1)

        # Calculate Gradients. Result shape: (B, T, H, W)
        du_dx = (u_pad[..., 1:-1, 2:] - u_pad[..., 1:-1, :-2]) / (2 * dx_view)
        du_dy = (u_pad[..., 2:, 1:-1] - u_pad[..., :-2, 1:-1]) / (2 * dy_view)
        dv_dx = (v_pad[..., 1:-1, 2:] - v_pad[..., 1:-1, :-2]) / (2 * dx_view)
        dv_dy = (v_pad[..., 2:, 1:-1] - v_pad[..., :-2, 1:-1]) / (2 * dy_view)

        # Reshape back to (B, H, W, T)
        return (
            du_dx.permute(0, 2, 3, 1),
            dv_dx.permute(0, 2, 3, 1),
            du_dy.permute(0, 2, 3, 1),
            dv_dy.permute(0, 2, 3, 1)
        )

    def laplacian(self, w):
        # Input w shape: (B, H, W, T)
        b, h, w_dim, t = w.shape
        
        # Permute to (B, T, H, W)
        w_perm = w.permute(0, 3, 1, 2)
        w_pad = F.pad(w_perm, (1, 1, 1, 1), mode='replicate')
        
        # Reshape dx/dy for broadcasting: (B, 1, 1, 1)
        dx_view = self.dx.view(b, 1, 1, 1)
        dy_view = self.dy.view(b, 1, 1, 1)

        d2x = (w_pad[..., 1:-1, 2:] - 2 * w_pad[..., 1:-1, 1:-1] + w_pad[..., 1:-1, :-2]) / (dx_view ** 2)
        d2y = (w_pad[..., 2:, 1:-1] - 2 * w_pad[..., 1:-1, 1:-1] + w_pad[..., :-2, 1:-1]) / (dy_view ** 2)
        
        return (d2x + d2y).permute(0, 2, 3, 1)

    def cal_physics_loss(self, u, v):
        """
        u, v shape: (B, H, W, T_total)
        """
        # 1. Get Viscosity (Batch, 1, 1, 1) for broadcasting against (B, H, W, T)
        # Use a local variable to avoid modifying self.nu in place repeatedly
        nu_view = self.nu.view(-1, 1, 1, 1)

        # 2. Compute Derivatives
        du_dx, dv_dx, du_dy, dv_dy = self.compute_derivatives(u, v)

        # --- Continuity Loss ---
        loss_con = torch.mean((du_dx + dv_dy) ** 2)

        # --- Vorticity Loss ---
        omega = dv_dx - du_dy
        
        # Time Derivative (Central Difference)
        omega_next = omega[..., 2:]
        omega_prev = omega[..., :-2]
        dw_dt = (omega_next - omega_prev) / (2 * self.dt)
        
        # Align Spatial Terms (Slice to center to match time slicing)
        u_center = u[..., 1:-1]
        v_center = v[..., 1:-1]
        omega_center = omega[..., 1:-1]
        
        # Advection terms (u * grad(w))
        dw_dx, _, dw_dy, _ = self.compute_derivatives(omega_center, torch.zeros_like(omega_center))
        advection = u_center * dw_dx + v_center * dw_dy
        
        # Diffusion term
        diffusion = nu_view * self.laplacian(omega_center)
        
        # Residual
        residual = dw_dt + advection - diffusion
        loss_vort = torch.mean(residual ** 2)

        return loss_vort, loss_con

    def forward(self, x_in, y_pred, y_true, params_map):
        """
        x_in: (Batch, H, W, 20 ) -> Contains History
        y_pred: (Batch, H, W, T_out*2) -> Contains Predictions
        """
        # 1. Setup Grid (Physical Units)
        # Ensure these are on the correct device
        x_max = params_map.get("x_max", 1.0).to(y_pred.device)
        x_min = params_map.get("x_min", 0.0).to(y_pred.device)
        y_max = params_map.get("y_max", 1.0).to(y_pred.device)
        y_min = params_map.get("y_min", 0.0).to(y_pred.device)
        
        x_range = x_max - x_min
        y_range = y_max - y_min
        
        self.nu = (params_map.get("viscosity", 0.001)/ params_map.get("density")).to(y_pred.device) # dataset give dynamic vis coe so we make it kinetic nu
        self.dx = x_range / 64.0
        self.dy = y_range / 64.0

        # 2. Data Loss
        data_loss = self.data_loss_fn(y_pred, y_true)

        # 3. Stitch History + Prediction for Physics
        x_last = x_in[...,-2:]
        vel_traj = torch.cat((x_last, y_pred), dim=-1) 
        
        # Unpack Prediction (Interleaved u, v)
        b, h, w, c = vel_traj.shape
        # Ensure steps are integer division
        t_steps = c // 2
        vel_traj = vel_traj.view(b, h, w, t_steps, 2)
        u_full = vel_traj[..., 0]
        v_full = vel_traj[..., 1]

        #De-normalized
        u_og = u_full * self.u_std + self.u_mean
        v_og = v_full * self.v_std + self.v_mean
       
        # 4. Calculate Physics
        loss_vort, loss_con = self.cal_physics_loss(u_og, v_og)

        return data_loss, loss_vort, loss_con

### 2.3 Data Loss function for autoregressive

In [10]:
class DataLossOnly(nn.Module):
    def __init__(self,size_average = True):
        super(DataLossOnly,self).__init__()
        self.size_average = size_average 
        self.data_loss_fn = torch.nn.MSELoss(reduction='mean')
        
    def forward(self,y_pred,y_target): 
        data_loss = self.data_loss_fn(y_pred,y_target) 
        return data_loss
    

## 3 Training & Evaluating process

### 3.1 Training and testing loops + printing time method

In [11]:
def training_loop(model, data_loader, phys_weight, loss_fn, optimizer, device, noise_level, rollout_steps, sampling_prob):
    model.train()
    total_data_loss = 0.0
    total_phys_loss = 0.0
    total_samples = 0
    
    # Weights for the two physics components
    W_VORT = 1.0
    W_CONT = 1.0
    CHANNELS = 2

    for x, y, params_map in data_loader:
        x, y = x.to(device), y.to(device)
        
        # Noise injection for stability
        if noise_level > 0.0:
            x_noisy = x + torch.randn_like(x) * x.std() * noise_level
            x_current = x_noisy
        else:
            x_current = x
            
        predictions = [] # List to store steps
        
        # --- ROLLOUT LOOP ---
        for step in range(rollout_steps):
            
            # 1. Forward Pass
            y_pred = model(x_current) # (B, H, W, 2)
            predictions.append(y_pred) # Store it
            
            # 2. Prepare Input for NEXT step
            start_c = step * CHANNELS
            end_c = (step + 1) * CHANNELS
            y_target_step = y[..., start_c:end_c]
            
            # 3. Update Window (Autoregressive Logic)
            if torch.rand(1) < sampling_prob:
                next_in = y_target_step # Teacher Forcing
            else:
                next_in = y_pred        # Autoregressive
                
            # Slide window: Drop first 2 channels, add new 2
            x_current = torch.cat([
                x_current[..., CHANNELS:], 
                next_in], dim=-1)

        # --- END OF LOOP ---

        # 4. Prepare Tensors for Loss
        # Stack predictions along channel dim: (B, H, W, Rollout*2)
        y_pred_full = torch.cat(predictions, dim=-1)
        
        # Slice ground truth to match rollout length
        target_channels = rollout_steps * CHANNELS
        y_target_full = y[..., :target_channels]

        # 5. Calculate Loss ONCE on the full sequence
        # Note: loss_fn expects (x_in, y_pred, y_target, params)
        data_loss, vort_loss, con_loss = loss_fn(x, y_pred_full, y_target_full, params_map)
        
        # 6. Accumulate
        phys_term = phys_weight * (W_VORT * vort_loss + W_CONT * con_loss)
        loss_accumulated = data_loss + phys_term
        
        total_data_loss += data_loss.item()
        total_phys_loss += phys_term.item() 
        total_samples += 1

        # 7. Backward Pass
        optimizer.zero_grad()
        loss_accumulated.backward()
        
        # Clip gradients (Essential for AR training)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        optimizer.step()
    

    return total_data_loss / total_samples, total_phys_loss / total_samples

In [12]:
def testing_loop(model, data_loader, loss_fn, rollout_steps,device):
    model.eval()
    val_data_loss = 0
    CHANNELS = 2

    with torch.inference_mode():
        for x_batch, y_batch, params_map in data_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            predictions =[]
            x_current = x_batch
            
            for step in range(rollout_steps): 
                y_pred = model(x_current) # Output: (B, H, W, C)
                predictions.append(y_pred)
                x_current =torch.cat([x_current[...,CHANNELS:], y_pred],dim =-1)


            # --- ROBUST LOSS CALCULATION ---
            # DataLossOnly returns 1 value. Physics loss returns 3.
            # We try to unpack; if it fails, we assume it's a single value.
            y_pred_tensor = torch.cat(predictions, dim=-1) 
            loss_output = loss_fn(y_pred_tensor, y_batch[...,:CHANNELS*rollout_steps]) 
            
            if isinstance(loss_output, tuple):
                data_loss = loss_output[0] # Extract data loss from (data, vort, con)
            else:
                data_loss = loss_output    # It's just the scalar data loss

            val_data_loss += data_loss.item()

    return val_data_loss / len(data_loader)

In [13]:
def printing_time(start, end, model_name):
    print(f"It takes {(end-start):.2f} s to train {model_name}")

### 3.2 Traing & validating for epochs



In [14]:
import torch
from tqdm.auto import tqdm
from timeit import default_timer as timer

def train_and_test(lr_scheduler, model, phys_loss, data_loss, optimizer, phys_weight, train_loader,
                   dev_loader, epochs, device, noise_level, rollout_steps, sampling_prob):
    model.to(device)
    
    # --- CONFIGURATION ---
    WARMUP_EPOCHS = 5  
    RAMP_EPOCHS = 10   # Epochs 5-14: Linearly increase weight
    # ---------------------

    results = {"train_data_loss": [], "train_phys_loss": [], "val_data_loss": []}
    start_time = timer()

    for epoch in tqdm(range(epochs)):
        
        # --- 1. DYNAMIC WEIGHT SCHEDULE ---
        if epoch < WARMUP_EPOCHS:
            curr_phys_weight = 0.0
            phase = "Data Only"
        elif epoch < (WARMUP_EPOCHS + RAMP_EPOCHS):
            # Linear ramp: 0.0 -> target_phys_weight
            progress = (epoch - WARMUP_EPOCHS) / RAMP_EPOCHS
            curr_phys_weight = phys_weight * progress
            phase = f"Ramping ({progress:.0%})"
        else:
            curr_phys_weight = phys_weight
            phase = "Full Phys"
            
        # --- 2. TEACHER FORCING DECAY ---
        # Slowly reduce reliance on ground truth history
        if epoch > 0 and epoch % 8 == 0:
            sampling_prob = max(0.0, sampling_prob - 0.1)

        # --- 3. TRAINING ---
        train_data_loss, train_phys_loss = training_loop(
            model, train_loader, curr_phys_weight, phys_loss, optimizer, 
            device, noise_level, rollout_steps, sampling_prob
        )

        # --- 4. VALIDATION ---
        val_data_loss = testing_loop(model, dev_loader, data_loss, rollout_steps, device)
        
        # Update Scheduler based on Validation Data Loss
        lr_scheduler.step(val_data_loss)

        # --- 5. LOGGING ---
        # We log the RAW physics loss to monitor the billions, 
        # but the training uses the 'curr_phys_weight' to keep gradients safe.
        if epoch % 1 == 0:
            print(f"[{epoch}/{epochs}] {phase} | Train Data Loss: {train_data_loss:.4e} | Train Phys Loss: {train_phys_loss:.2e} | Val: {val_data_loss:.4e}")

        results["train_data_loss"].append(train_data_loss)
        results["train_phys_loss"].append(train_phys_loss)
        results["val_data_loss"].append(val_data_loss)

    end_time = timer()
    printing_time(start_time, end_time, type(model).__name__)
    return results

### 3.3 Evaluating model function

# III. Models 

## 1.DATA

### 1.1 T10 data

In [15]:
import random
from torch.utils.data import DataLoader


T_IN =10
T_OUT =5
# 1. Get all available simulation paths
all_files = sorted([
    os.path.join(prop_path, f) for f in os.listdir(prop_path) 
    if os.path.isdir(os.path.join(prop_path, f))
])

# 2. Shuffle and Split Files (80% Train, 10% Val, 10% Test)
# Set seed for reproducibility!
random.seed(42) 
random.shuffle(all_files)

n_total = len(all_files)
n_train = int(0.8 * n_total)
n_val = int(0.1 * n_total)

train_files = all_files[:n_train]
val_files   = all_files[n_train:n_train+n_val]
test_files  = all_files[n_train+n_val:]

print(f"üìÇ Split: {len(train_files)} Train, {len(val_files)} Val, {len(test_files)} Test files.")

# 3. Instantiate Separate Datasets
train_ds = CylinderFlow_SlidingWindow(
    file_paths=train_files,
    t_in=T_IN, 
    t_out=T_OUT, 
    stride = 6,
    u_mean = u_mean,
    v_mean = v_mean,
    u_std = u_std,
    v_std = v_std
)

val_ds = CylinderFlow_SlidingWindow(
    file_paths=val_files,
    t_in=T_IN, 
    t_out=T_OUT, 
    u_mean = u_mean,
    v_mean = v_mean,
    u_std = u_std,
    v_std = v_std,
    stride = T_IN + T_OUT,
)


# 4. Create DataLoaders
train_loader_t10 = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=0)
val_loader_t10   = DataLoader(val_ds,   batch_size=32, shuffle=False, num_workers=0)

üìÇ Split: 92 Train, 11 Val, 13 Test files.
‚úÖ Sliding Window Dataset Ready:
   - Files: 92
   - Windows per file: 165 (Stride=6)
   - Total Samples: 15180
‚úÖ Sliding Window Dataset Ready:
   - Files: 11
   - Windows per file: 66 (Stride=15)
   - Total Samples: 726


## 4. Model 5
**(Plautaue at .. epoch after .. hour)

** Model 5 (Autoregressive, 3-in) Test Loss: 

- use data_t10 
- FNO2D
- NavierStoke2D loss 

### 4.1 Loss function and optimizer

In [16]:
## HYPER Parameters from the paper
WIDTH =64
T_IN =10
T_OUT =1
MODE_X = 12
MODE_Y = 12

#Define model
model5 = FNO2D(in_chanels=T_IN*CHANNELS, 
                 out_chanels=T_OUT*CHANNELS, 
                 mode_x=MODE_X, 
                 mode_y=MODE_Y, 
                 width=WIDTH).to(device)

In [17]:
dt =0.001 #(Please modify if doesnt match)
phys_loss_fn5 = NavierStoke2DLoss(dt = dt, size_average =True)
data_loss_fn5 = DataLossOnly(size_average = True)


üìâ Loss Initialized


In [18]:
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

#Optimizer (Adam)
LR =0.001
optimizer5 = torch.optim.Adam(model5.parameters(),lr = LR)

# We'll use 'ReduceLROnPlateau'.
# This scheduler will automatically "reduce" the learning rate
# when it detects that our validation loss has "plateaued" (stopped improving).
scheduler5 = ReduceLROnPlateau(
    optimizer5,
    mode='min',      # It will look at the validation loss, where 'min' is better
    factor=0.1,      # Reduce LR by a factor of 10 (e.g., 0.001 -> 0.0001)
    patience=5,      # Wait 5 epochs for improvement before reducing
)

### 4.2 Training and testing model 5

In [19]:
# --- HYPERPARAMETERS ---
# ROLLOUT_STEPS must match the number of output steps provided by your DataLoader's y_batch.

ROLLOUT_STEPS = 5  
SAMPLING_PROB = 0.6
EPOCHS =32
NOISE_LEVEL =0.005
PHYS_WEIGHT =1e-13

result5 = train_and_test(scheduler5, model = model5,
              phys_loss = phys_loss_fn5, data_loss = data_loss_fn5, optimizer = optimizer5,phys_weight = PHYS_WEIGHT, train_loader = train_loader_t10,
              dev_loader = val_loader_t10, epochs = EPOCHS,device = device, noise_level = NOISE_LEVEL, rollout_steps =ROLLOUT_STEPS , sampling_prob =SAMPLING_PROB)

  0%|          | 0/32 [00:00<?, ?it/s]

[0/32] Data Only | Train Data Loss: 3.3339e-03 | Train Phys Loss: 0.00e+00 | Val: 1.0168e-04
[1/32] Data Only | Train Data Loss: 5.9073e-05 | Train Phys Loss: 0.00e+00 | Val: 5.0410e-05
[2/32] Data Only | Train Data Loss: 3.8161e-05 | Train Phys Loss: 0.00e+00 | Val: 5.2447e-05
[3/32] Data Only | Train Data Loss: 3.2852e-05 | Train Phys Loss: 0.00e+00 | Val: 5.1533e-05
[4/32] Data Only | Train Data Loss: 3.5041e-05 | Train Phys Loss: 0.00e+00 | Val: 3.1953e-05
[5/32] Ramping (0%) | Train Data Loss: 3.4710e-05 | Train Phys Loss: 0.00e+00 | Val: 2.0823e-05
[6/32] Ramping (10%) | Train Data Loss: 2.4238e-05 | Train Phys Loss: 6.76e-07 | Val: 1.6381e-05
[7/32] Ramping (20%) | Train Data Loss: 3.5515e-05 | Train Phys Loss: 1.35e-06 | Val: 1.6305e-05
[8/32] Ramping (30%) | Train Data Loss: 2.2332e-05 | Train Phys Loss: 2.00e-06 | Val: 2.1266e-05
[9/32] Ramping (40%) | Train Data Loss: 2.6217e-05 | Train Phys Loss: 2.64e-06 | Val: 1.1851e-05
[10/32] Ramping (50%) | Train Data Loss: 1.7581e-05

In [20]:
print(f"Result 5: {result5}")

Result 5: {'train_data_loss': [0.003333898678821769, 5.907324953065989e-05, 3.816136999950303e-05, 3.285162012999583e-05, 3.504103225982123e-05, 3.470976001010154e-05, 2.423757678211826e-05, 3.551462441273064e-05, 2.233151404461272e-05, 2.621670441409438e-05, 1.7580840156274223e-05, 2.594936950883589e-05, 2.3111144668961827e-05, 1.4531499756757464e-05, 1.584483480703631e-05, 1.9711058236058115e-05, 1.5268003582222216e-05, 1.654756656199951e-05, 1.374202524091647e-05, 1.4289032811862661e-05, 1.5120776086309649e-05, 1.2490942741638527e-05, 1.1820120312940593e-05, 1.2136309433117276e-05, 1.2468175494799844e-05, 1.1817789914187558e-05, 1.1145375039727936e-05, 1.0975835014170643e-05, 1.0360224565183192e-05, 3.242282291466836e-06, 3.1099977023155878e-06, 3.0773104088822088e-06], 'train_phys_loss': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 6.757761424450391e-07, 1.347019388143125e-06, 2.0013928062882194e-06, 2.6439735391159934e-06, 3.2764086307093895e-06, 3.8976526809806645e-06, 4.509176935445644e-06, 5

## 6. Saving and comparing models

### 6.1 Saving model

In [None]:
import torch
import os

print("Saving initial untrained models...")

# 1. Define a path to save your models in your Drive
MODELS_PATH = Path("models")
os.makedirs(MODELS_PATH, exist_ok=True)

# 2. Define the save paths for each model
MODEL5_SAVE_PATH = MODELS_PATH / "model5_initial_state.pth"


# 3. Save the models' state dictionaries
# .state_dict() saves only the learnable parameters (weights and biases)
try:
    torch.save(model5.state_dict(), MODEL5_SAVE_PATH)

    print(f"Successfully saved model4 to: {MODEL5_SAVE_PATH}")

except Exception as e:
    print(f"Error saving models: {e}")

# To load them back later (for inference or to resume training), you would:
# 1. Re-create the model instance: model0 = FNO1D(WIDTH, K_MAX, x_grid_tensor).to(device)
# 2. Load the weights: model0.load_state_dict(torch.load(MODEL0_SAVE_PATH))

### 6.2 Compare models

In [21]:
def evaluate_autoregressive_rollout(model, file_paths, loss_fn, t_in, device, u_mean, u_std, v_mean, v_std):
    """
    Args:
        u_mean, u_std, etc.: Floats. MUST MATCH TRAINING STATS.
    """
    model.eval()
    total_rollout_loss = 0.0
    valid_samples = 0
    
    print(f"Starting Rollout Evaluation on {len(file_paths)} simulation files...")

    with torch.no_grad():
        for case_path in tqdm(file_paths, desc="Evaluating Rollout"):
            
            # 1. Load Raw Data
            try:
                u = np.load(os.path.join(case_path, 'u.npy')) 
                v = np.load(os.path.join(case_path, 'v.npy')) 
            except FileNotFoundError:
                continue

            # 2. Normalize using Training Stats
            u_norm = (u - u_mean) / u_std
            v_norm = (v - v_mean) / v_std

            # Stack
            vel = np.stack([u_norm, v_norm], axis=-1)
            velocity_traj = torch.from_numpy(vel).float().to(device)

            # Check length
            total_time = velocity_traj.shape[0]
            if total_time <= t_in: continue

            # Prepare Input (First t_in steps)
            initial_input = velocity_traj[:t_in] 
            current_window = initial_input.unsqueeze(0)
            # Flatten batch & time for FNO input: (1, 64, 64, 20)
            current_window = current_window.permute(0, 2, 3, 1,4).reshape(1, 64, 64, -1)

            # Prepare Ground Truth (From t_in to End)
            y_true_full = velocity_traj[t_in:] 
            
            predictions = []
            n_steps_to_predict = y_true_full.shape[0]
            
            # Autoregressive Loop
            stable = True
            for i in range(n_steps_to_predict):
                y_next_pred = model(current_window) # Output: (1, 10, 64, 64, 2)
                
                if torch.any(torch.isinf(y_next_pred)) or torch.any(torch.isnan(y_next_pred)):
                    print(f"!!! Rollout unstable at step {i} !!!")
                    stable = False
                    break
                
                predictions.append(y_next_pred)
                
                # Update Window: Shift left, append new prediction
                # current_window: (1, 64, 64, 20)
                # We want to drop the first 2 channels (oldest t) and append y_next_pred (newest t)
                current_window = torch.cat((current_window[..., 2:], y_next_pred), dim=-1)

            if not stable: continue

            # Calculate Loss
            # Stack predictions: (1, Time, 64, 64, 2)
            y_pred_tensor = torch.stack(predictions, dim=1) 
            
            # Reshape for MSE Loss (Batch * Time * Pixels * Channels)
            loss = loss_fn(y_pred_tensor, y_true_full.unsqueeze(0))
            total_rollout_loss += loss.item()
            valid_samples += 1

    if valid_samples == 0: return float('inf')
    return total_rollout_loss / valid_samples

In [23]:
# (Run this in a new cell after training model2)

print("Starting test...")

# Use test_loader_t1 (from cell 17) for a fair test
# Use loss_fn0 (from cell 55) for a fair data-only loss

rollout_loss_m5 = evaluate_autoregressive_rollout(
    model=model5, 
    file_paths=test_files, 
    loss_fn=data_loss_fn5,           
    device=device,
    t_in=10,# <-- This correctly matches your model2
    u_mean = u_mean,
    v_mean = v_mean,
    u_std = u_std,
    v_std = v_std
)

print("\n--- Test Results ---")
print(f"Model 5 (Autoregressive, 3-in) Test Loss: {rollout_loss_m5:.4f}")

Starting test...
Starting Rollout Evaluation on 13 simulation files...


Evaluating Rollout:   0%|          | 0/13 [00:00<?, ?it/s]


--- Test Results ---
Model 5 (Autoregressive, 3-in) Test Loss: 1.6344


In [None]:
import matplotlib.pyplot as pltW
import numpy as np
from results import result1, result0

# --- IMPORTANT ---
# You must copy the dictionary outputs from your notebook cells
# (cell 50 for model0 and cell 67 for model1) and paste them here.
# I have used your completed output from cell 50 as an example.
# You will need to replace results_1 with your full 30-epoch output.

# --- Create Epoch Array ---
# Assumes both models were trained for the same number of epochs
result0 = {'train_data_loss': [0.4110739827156067, 0.23925882577896118, 0.20458190143108368, 0.19567212462425232, 0.17208527028560638, 0.16243018209934235, 0.16161102056503296, 0.1497381031513214, 0.1424730122089386, 0.13952957093715668, 0.13167569041252136, 0.1314629167318344, 0.13347485661506653, 0.1252565085887909, 0.12282412499189377, 0.12299676984548569, 0.11490682512521744, 0.11774254590272903, 0.10972394794225693, 0.10895457118749619, 0.11255241930484772, 0.10872527211904526, 0.10500839352607727, 0.10531488806009293, 0.10197656601667404, 0.10225849598646164, 0.10453744977712631, 0.10029015690088272, 0.09715524315834045, 0.10300806164741516, 0.09627469629049301, 0.09561239928007126, 0.09352393448352814, 0.09420380741357803, 0.09108910709619522, 0.08958891034126282, 0.08952014893293381, 0.0945991724729538, 0.08834867179393768, 0.08973327279090881, 0.08752944320440292, 0.08816472440958023, 0.08264376223087311, 0.08628880232572556, 0.08571340143680573, 0.08524730801582336, 0.08302949368953705, 0.08613189309835434, 0.08262985199689865, 0.0860193744301796, 0.08372187614440918, 0.08175082504749298, 0.0784018486738205, 0.07959341257810593, 0.0778963565826416, 0.0784391239285469, 0.07724633812904358, 0.07945655286312103, 0.07577656954526901, 0.08160007745027542, 0.08002155274152756, 0.07753711193799973, 0.07859975844621658, 0.0764986202120781, 0.07406435161828995, 0.0723976269364357, 0.0739876925945282, 0.07461126893758774, 0.07415719330310822, 0.07250193506479263, 0.07158296555280685, 0.0725533589720726, 0.07088251411914825, 0.0713513195514679, 0.07630784064531326, 0.07213081419467926, 0.06057237461209297, 0.05909671634435654, 0.05885971337556839, 0.058606330305337906, 0.05847145989537239, 0.058534953743219376, 0.05826769769191742, 0.05820904299616814, 0.05816391482949257, 0.05802098661661148, 0.058009110391139984, 0.057899314910173416, 0.0577441044151783, 0.05807175859808922, 0.057901881635189056, 0.05766398832201958, 0.05750136449933052, 0.05765574425458908, 0.05767909809947014, 0.05718502774834633, 0.0571594312787056, 0.057440076023340225, 0.057252444326877594, 0.05690488591790199, 0.05690895393490791, 0.05691118910908699, 0.056899115443229675, 0.05672742426395416, 0.05663994699716568, 0.05657578259706497, 0.05658656731247902, 0.05656237527728081, 0.05650690570473671, 0.05641086399555206, 0.056626658886671066, 0.05617628991603851, 0.05610703304409981, 0.05624675750732422, 0.0564168281853199, 0.05596563592553139, 0.05595419928431511, 0.055922769010066986, 0.05574888736009598, 0.05567246302962303, 0.05585700646042824, 0.055710289627313614, 0.05597971752285957, 0.05580592527985573, 0.05584094300866127, 0.05574943125247955, 0.05546464025974274, 0.055211909115314484, 0.05526223033666611, 0.05522475764155388, 0.05564416944980621, 0.05541204661130905, 0.05530177429318428, 0.05499405786395073, 0.05497822165489197, 0.055076297372579575, 0.05496148020029068, 0.0547984354197979, 0.05514082685112953, 0.055032216012477875, 0.05497255548834801, 0.05481015145778656, 0.05463377758860588, 0.054746970534324646, 0.05454118177294731, 0.05435951054096222, 0.05476135388016701, 0.054547540843486786, 0.05460606887936592, 0.05433722585439682], 'train_phys_loss': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'val_data_loss': [0.2572142621354451, 0.21145666733620658, 0.19115776675088064, 0.1802690809681302, 0.17118003964424133, 0.15824853404173775, 0.17612133447139983, 0.14190959067098677, 0.13687605919345977, 0.1335370416442553, 0.13127026198402283, 0.1297357457261237, 0.12929621764591762, 0.12160444259643555, 0.15860963624621194, 0.12193896668770957, 0.13977357768823231, 0.11167937008634446, 0.11051414717757513, 0.11135667432395238, 0.11710662323804129, 0.10518188218748759, 0.10574372337451057, 0.11249114868659822, 0.10047837618797545, 0.10142433477772607, 0.09888798826270634, 0.09998914327413316, 0.10522396519543632, 0.10073239829332109, 0.09524016141418427, 0.09673324570296303, 0.0917700073785252, 0.09177173449406548, 0.09537116840245231, 0.0898829839295811, 0.09122747289282936, 0.09619156460440348, 0.09295563683623359, 0.09442581981420517, 0.0865695761546256, 0.08646621212126716, 0.08684119321997204, 0.089255215156646, 0.0872151656519799, 0.08562316390730086, 0.0906154742789647, 0.08481654382887341, 0.08532685396217164, 0.09498826461651969, 0.08751835993358068, 0.08428332919166201, 0.10068383375330577, 0.08187451710303624, 0.08577221821224879, 0.07869435054442239, 0.08390380430316167, 0.08048265462829954, 0.07998732867695037, 0.07940924309548877, 0.08420824247693258, 0.07788392381062584, 0.08739961293481645, 0.07972035036673622, 0.07721601213727679, 0.07708905519001068, 0.08067990259991752, 0.07722300066361351, 0.09537790822131294, 0.07313832954045325, 0.07719190785336116, 0.0766766376438595, 0.07478425805530851, 0.0733860337308475, 0.07394332036612526, 0.07324050420096942, 0.06501247840268272, 0.06459305080629531, 0.06465581023976916, 0.06449638975281564, 0.06492461772665145, 0.06398267458592143, 0.0644972435538731, 0.06389726597874884, 0.06378511274381289, 0.06382066389871022, 0.06389406625004042, 0.06369595667199483, 0.06350627009357725, 0.0635057757534678, 0.06327962041610763, 0.06329736746256313, 0.06380144559911319, 0.06318371796182223, 0.06303894472500635, 0.06305838967599565, 0.06306172854134015, 0.06316362587468964, 0.06280702772358107, 0.06297555837839369, 0.06427314351238901, 0.06272708405814474, 0.06253687392861124, 0.06268517643449799, 0.062455795646186855, 0.06254122944341765, 0.06238306195489944, 0.06241820912275996, 0.062523636493891, 0.06233104856477843, 0.0623762799752137, 0.06204700215704857, 0.06217261888678112, 0.06274812715867209, 0.0618926150694726, 0.06204395309563667, 0.06187254799500344, 0.061938941537860844, 0.06266835385135242, 0.06309285830883753, 0.06154843940148278, 0.06175081975876339, 0.06152673846199399, 0.06216449757653569, 0.06183923437954888, 0.06193630456451386, 0.06164784404256987, 0.06338022170322281, 0.06135745565333064, 0.06773435566869992, 0.06207913467808375, 0.0613168551926575, 0.061097964880958436, 0.06094158074212453, 0.06099840265417856, 0.06099581180347337, 0.06111503788639629, 0.06091841511310093, 0.06254472116392756, 0.06069481703970167, 0.060798205198749664, 0.0615965946917496, 0.060730867383499, 0.06066012004065135, 0.06074823173029082, 0.0605203460842844, 0.060685241033160496, 0.06406244617842492, 0.06041179910775215, 0.06047245848273474]}

epochs = range(len(result0['train_data_loss']))


# --- Create Plots ---
plt.figure(figsize=(18, 5))

# Plot 1: Training Data Loss Comparison
plt.subplot(1, 3, 1)
plt.plot(epochs, result0['train_data_loss'], label='Model 0 (Data Only)', color='blue')
plt.plot(epochs, result1['train_data_loss'], label='Model 1 (PINO)', color='red')
plt.title('Training Data Loss')
plt.xlabel('Epochs')
plt.ylabel('Relative L2 Loss (L_data)')
plt.legend()
plt.grid(True, linestyle=':')
plt.yscale('log') # Use log scale if losses are very different

# Plot 2: Validation Data Loss Comparison
plt.subplot(1, 3, 2)
plt.plot(epochs, result0['val_data_loss'], label='Model 0 (Data Only)', color='blue')
plt.plot(epochs, result1['val_data_loss'], label='Model 1 (PINO)', color='red')
plt.title('Validation Data Loss')
plt.xlabel('Epochs')
plt.ylabel('Relative L2 Loss (L_data)')
plt.legend()
plt.grid(True, linestyle=':')
plt.yscale('log')

# Plot 3: Model 1 (PINO) Loss Breakdown
plt.subplot(1, 3, 3)
plt.plot(epochs, result1['train_data_loss'], label='Data Loss', color='red')
plt.plot(epochs, result1['train_phys_loss'], label='Physics Loss', color='green')
plt.title('Model 1 (PINO) Loss Components')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, linestyle=':')
plt.yscale('log') # Log scale is essential here since physics loss is large



plt.tight_layout()
plt.show()