# I. Set ups for Data
- Create data loader


In [1]:
import torch
from pathlib import Path

#device agnostic
device = "cuda" if torch.cuda.is_available() else "cpu"

#Hyper Parameters
RAND_SEED = 420
torch.manual_seed(RAND_SEED)

<torch._C.Generator at 0x7f4f1ebd5990>

In [2]:
print(device)

cuda


## 1. Data Handling Process

### 1.1 Set up for data handling process

In [3]:
import h5py
#Data paths
data_folder_path = Path("dataset")
data_path = data_folder_path / "cfd_bench/cylinder_flow"
bc_path = data_path / "bc"
geo_path = data_path / "geo"
prop_path = data_path / "prop"

CHANNELS =2 #(vx,vy)
GRID_SIZE =64

### 1.2 Datasets classes

In [5]:
import numpy as np
from torch.utils.data import Dataset
import torch
import os
import json
import glob

# Constants
GRID_SIZE = 64
CHANNELS = 2  # u, v

class CylinderFlow_Autoregressive_Sparse(Dataset):
    """
    Phase 2 Dataset for CFDBench Cylinder Flow.
    - Loads u.npy and v.npy
    - Input: Velocity History ONLY (Batch, 64, 64, 20)
    - Output: Future Velocity (Batch, 64, 64, 10)
    - Returns: x, y, params_map (Dictionary with Physics Metadata)
    """
    def __init__(self, data_dir, t_in=10, t_out=5, windows_per_sim=20):
        self.t_in = t_in
        self.t_out = t_out
        self.total_steps = t_in + t_out 
        self.data_dir = data_dir 
        self.windows_per_sim = windows_per_sim
        
        # 1. Find valid case folders
        self.file_paths = sorted([
            os.path.join(data_dir, f) for f in os.listdir(data_dir) 
            if os.path.isdir(os.path.join(data_dir, f))
        ])
        
        if not self.file_paths:
            raise ValueError(f"No case folders found in {data_dir}")

        # 2. Inspect first case for dimensions
        first_case = self.file_paths[0]
        self.ext = '.npy' if os.path.exists(os.path.join(first_case, 'u.npy')) else '.np'
        
        # Load shape
        sample_u = np.load(os.path.join(first_case, f'u{self.ext}'), mmap_mode='r')
        
        self.num_timesteps = sample_u.shape[0]
        self.max_start_time = self.num_timesteps - self.total_steps
        
        self.sims_in_dataset = len(self.file_paths)
        self.length = self.sims_in_dataset * self.windows_per_sim
        
        print(f"âœ… Dataset Ready: {self.length} samples. (Physics params separated)")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # 1. Identify Simulation
        sim_idx = idx // self.windows_per_sim
        case_path = self.file_paths[sim_idx]
        json_path = os.path.join(case_path, "params.json")
        
        # 2. Random Start Time
        start_time_index = np.random.randint(0, self.max_start_time + 1)
        
        # 3. Indices
        end_in_index = start_time_index + self.t_in
        end_out_index = end_in_index + self.t_out

        # 4. Load Velocity Data
        u_mmap = np.load(os.path.join(case_path, f'u{self.ext}'), mmap_mode='r')
        v_mmap = np.load(os.path.join(case_path, f'v{self.ext}'), mmap_mode='r')
        
        # Slice
        u_in = u_mmap[start_time_index : end_in_index]
        v_in = v_mmap[start_time_index : end_in_index]
        u_out = u_mmap[end_in_index : end_out_index]
        v_out = v_mmap[end_in_index : end_out_index]
        
        # Stack
        x_np = np.stack([u_in, v_in], axis=-1).copy()
        y_np = np.stack([u_out, v_out], axis=-1).copy()
        
        # 5. Prepare Tensors
        x_tensor = torch.from_numpy(x_np).permute(1, 2, 0, 3).reshape(GRID_SIZE, GRID_SIZE, -1)
        y_tensor = torch.from_numpy(y_np).permute(1, 2, 0, 3).reshape(GRID_SIZE, GRID_SIZE, -1)
        
        # 6. Retrieve Physics Parameters
        # Fallback for filename
        if not os.path.exists(json_path):
             json_path = os.path.join(case_path, "case.json")

        # Default Map
        params_map = {"viscosity": 0.001, "density": 1.0, "x_max": 0.16, "x_min": -0.06, "y_max": 0.06, "y_min": -0.06}

        if os.path.exists(json_path):
            with open(json_path, "r") as f:
                data = json.load(f)
                # Update defaults with real values
                params_map.update(data)

        return x_tensor.float(), y_tensor.float(), params_map

print("âœ… CylinderFlow_Autoregressive_Sparse (Physics-Informed) defined.")

BurgerDataset_Autoregressive_Sparse (FASTER dataset) defined.


# II. Set up for model
- Declare up model class
- Set up loss functions
- Train - test function


## 1.Models set ups

### 1.1 FNO2D layer class

In [6]:
import torch
from torch import nn
import torch.nn.functional as F
class FNO_Layer2D(nn.Module):
  """ A class act as one Fourier Layer as described in the original paper"""
  def __init__(self,in_chanels, out_chanels,mode_x,mode_y):
    super(FNO_Layer2D, self).__init__()
    self.in_chanels = in_chanels
    self.out_chanels = out_chanels
    self.mode_x = mode_x
    self.mode_y = mode_y
    #Scaler to scale the parameters
    self.scalar = (1/(in_chanels * out_chanels))
    self.weight_x = nn.Parameter(
        self.scalar * torch.rand(in_chanels,out_chanels,mode_x, mode_y, dtype = torch.cfloat)
    )
    self.weight_y = nn.Parameter(
        self.scalar * torch.rand(in_chanels,out_chanels,mode_x, mode_y,dtype = torch.cfloat)
    )

  def forward(self, x):
    #Size 
    batch,_,H,W = x.shape
      
    # FFT (2D fourier transform)
    x_ft = torch.fft.rfft2(x)
      
    ## Fourier layer
    out_ft = torch.zeros(
        batch, self.out_chanels,H, W //2 +1,
        dtype = torch.cfloat,
        device = device
    )

    out_ft[:,:,:self.mode_x,:self.mode_y] = torch.einsum(
        "bixy,ioxy->boxy",
        x_ft[:,:,:self.mode_x,:self.mode_y],
        self.weight_x
    )

    out_ft[:,:,-self.mode_x:,:self.mode_y] = torch.einsum(
        "bixy,ioxy->boxy",
        x_ft[:,:,-self.mode_x:, :self.mode_y],
        self.weight_y
    )
    ## Inverse FFT
    x = torch.fft.irfft2(out_ft , s=(H,W))
    return x
print("FNO_Layer2D class defined.")

FNO_Layer2D class defined.


### 1.2 FNO2D class

In [7]:
class FNO2D(nn.Module):
  def __init__(self ,width, mode_x, mode_y,in_chanels =1, out_chanels =200) :
    """ Default is 1 time step -> 200 timestep rest"""
    super(FNO2D,self).__init__()
    self.in_chanels = in_chanels + 2 # add one more for the grid
    self.out_chanels = out_chanels
    self.width = width
    self.mode_x = mode_x
    self.mode_y = mode_y
      
    # P layer ( lift the input chanel up):
    self.fc0 = nn.Linear(self.in_chanels, self.width)

    # T block: 4 Fourier layers
    self.conv0 = FNO_Layer2D(width,width,mode_x,mode_y)
    self.w0 = nn.Conv2d(self.width,self.width,1)

    self.conv1 = FNO_Layer2D(width,width,mode_x,mode_y)
    self.w1 = nn.Conv2d(self.width,self.width,1)

    self.conv2 = FNO_Layer2D(width,width,mode_x,mode_y)
    self.w2 = nn.Conv2d(self.width,self.width,1)

    self.conv3 = FNO_Layer2D(width,width,mode_x,mode_y)
    self.w3 = nn.Conv2d(self.width,self.width,1)

    # Q layer (project down to output_chanels)
    self.fc1 = nn.Linear(width,width*2)
    self.fc2 = nn.Linear(width*2, self.out_chanels)

  def forward(self,x):
    #Get the grid information
    grid = get_grid(shape = x.shape)

    # concat to the input
    x = torch.cat((x,grid),dim=-1) # (batch,H,W,chanels)

    ##Through P layer:
    x = self.fc0(x)

    x = x.permute(0,3,1,2) #(batch, chanel , H , W)

    ##Through T
    #1
    x1 = self.conv0(x)
    x2 = self.w0(x)
    x = x1+x2
    x = F.gelu(x)

    #2
    x1 = self.conv1(x)
    x2 = self.w1(x)
    x = x1+x2
    x = F.gelu(x)

    #3
    x1 = self.conv2(x)
    x2 = self.w2(x)
    x = x1+x2
    x = F.gelu(x)

    #4
    x1 = self.conv3(x)
    x2 = self.w3(x)
    x = x1+x2

    #Switch back
    x = x.permute(0, 2,3, 1) # (batch,H,W,chanels)

    ## Through Q
    x = self.fc1(x)
    x = F.gelu(x)
    x = self.fc2(x)
    return x

def get_grid(shape): 
    batch, H,W = shape[0], shape[1],shape[2]

    grid_x = torch.tensor(np.linspace(-1,1,H), dtype= torch.float)
    grid_y = torch.tensor(np.linspace(-1,1,W), dtype = torch.float)
    
    grid_x = grid_x.reshape(1, H, 1, 1).repeat([batch, 1, W, 1])
    grid_y = grid_y.reshape(1, 1, W, 1).repeat([batch, H, 1, 1])
    grid = torch.cat((grid_x, grid_y), dim =-1).to(device) 
    return grid 

print("FNO2D class defined.")

FNO2D class defined.


## 2.Loss functions


### 2.1 LPLoss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LpLoss(object):
    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()
        self.d = d
        self.p = p
        self.size_average = size_average
        self.reduction = reduction

    def __call__(self, x, y):
        num_examples = x.size()[0]
        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)
        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms/y_norms)
            else:
                return torch.sum(diff_norms/y_norms)
        return diff_norms/y_norms

### 2.2 Physic loss for NavierStoke2D

In [8]:
class NavierStoke2DLoss(nn.Module):
    def __init__(self, dt=0.001, size_average=True):
        super(NavierStoke2DLoss, self).__init__()
        
        self.dt = dt
        self.data_loss_fn = LpLoss(size_average=size_average)
        print(f"ðŸ“‰ Loss Initialized")

    def compute_derivatives(self, u, v):
        """Vectorized derivatives for sequence (B, H, W, T)"""
        # Reshape for 2D padding: (Batch*Time, 1, H, W)
        b, h, w, t = u.shape
        u_reshaped = u.permute(0, 3, 1, 2).reshape(b * t, 1, h, w)
        v_reshaped = v.permute(0, 3, 1, 2).reshape(b * t, 1, h, w)

        # Pad & Calculate Gradients
        u_pad = F.pad(u_reshaped, (1, 1, 1, 1), mode='replicate')
        v_pad = F.pad(v_reshaped, (1, 1, 1, 1), mode='replicate')

        du_dx = (u_pad[..., 1:-1, 2:] - u_pad[..., 1:-1, :-2]) / (2 * self.dx)
        du_dy = (u_pad[..., 2:, 1:-1] - u_pad[..., :-2, 1:-1]) / (2 * self.dy)
        dv_dx = (v_pad[..., 1:-1, 2:] - v_pad[..., 1:-1, :-2]) / (2 * self.dx)
        dv_dy = (v_pad[..., 2:, 1:-1] - v_pad[..., :-2, 1:-1]) / (2 * self.dy)

        # Reshape back to (B, H, W, T)
        return (
            du_dx.view(b, t, h, w).permute(0, 2, 3, 1),
            dv_dx.view(b, t, h, w).permute(0, 2, 3, 1),
            du_dy.view(b, t, h, w).permute(0, 2, 3, 1),
            dv_dy.view(b, t, h, w).permute(0, 2, 3, 1)
        )

    def laplacian(self, w):
        b, h, w_dim, t = w.shape
        w_reshaped = w.permute(0, 3, 1, 2).reshape(b * t, 1, h, w_dim)
        w_pad = F.pad(w_reshaped, (1, 1, 1, 1), mode='replicate')
        d2x = (w_pad[..., 1:-1, 2:] - 2 * w_pad[..., 1:-1, 1:-1] + w_pad[..., 1:-1, :-2]) / (self.dx ** 2)
        d2y = (w_pad[..., 2:, 1:-1] - 2 * w_pad[..., 1:-1, 1:-1] + w_pad[..., :-2, 1:-1]) / (self.dy ** 2)
        return (d2x + d2y).view(b, t, h, w_dim).permute(0, 2, 3, 1)

    def cal_physics_loss(self, u, v, phys_params):
        """
        u, v shape: (B, H, W, T_total)
        """
        # 1. Get Viscosity (Batch, 1, 1, 1)
        self.nu = self.nu.view(-1, 1, 1, 1)

        # 2. Compute Derivatives
        du_dx, dv_dx, du_dy, dv_dy = self.compute_derivatives(u, v)

        # --- Continuity Loss ---
        loss_con = torch.mean((du_dx + dv_dy) ** 2)

        # --- Vorticity Loss ---
        omega = dv_dx - du_dy
        
        # Time Derivative (Central Difference)
        # Uses t-1 and t+1 to calculate slope at t. 
        # We lose first and last frame of the sequence.
        omega_next = omega[..., 2:]
        omega_prev = omega[..., :-2]
        dw_dt = (omega_next - omega_prev) / (2 * self.dt)
        
        # Align Spatial Terms (Slice to center)
        u_center = u[..., 1:-1]
        v_center = v[..., 1:-1]
        omega_center = omega[..., 1:-1]
        
        # Advection terms (u * grad(w))
        dw_dx, _, dw_dy, _ = self.compute_derivatives(omega_center, torch.zeros_like(omega_center))
        advection = u_center * dw_dx + v_center * dw_dy
        
        # Diffusion term
        diffusion = self.nu * self.laplacian(omega_center)
        
        # Residual
        residual = dw_dt + advection - diffusion
        loss_vort = torch.mean(residual ** 2)

        return loss_vort, loss_con

    def forward(self, x_in, y_pred, y_true, phys_params):
        """
        x_in: (Batch, H, W, 20 ) -> Contains History
        y_pred: (Batch, H, W, T_out*2) -> Contains Predictions
        """
        # 1. Setup Grid (Physical Units)
        x_range = params_map.get("x_max", 1.0) - params_map.get("x_min", 0.0)
        y_range = params_map.get("y_max", 1.0) - params_map.get("y_min", 0.0)
        self.nu = params_map.get("viscosity", 0.001)
        
        self.dx = x_range / 64.0
        self.dy = y_range / 64.0

        # 2. Data Loss
        data_loss = self.data_loss_fn(y_pred, y_true)

        # 3. Stitch History + Prediction for Physics
        
        # Assuming 2 physics params at the end:
        vel_traj = torch.cat((x_in,y_pred), dim=-1) #(B,H,W,(T+10)*2)
        
        # Unpack Prediction (Interleaved u, v)
        b, h, w, c = vel_traj.shape
        t_steps = c // 2
        vel_traj = vel_traj.view(b, h, w, t_steps, 2)
        u_full = vel_traj[..., 0]
        v_full = vel_traj[..., 1]
        

        # 4. Calculate Physics
        loss_vort, loss_con = self.cal_physics_loss(u_full, v_full, phys_params)

        return data_loss, loss_vort, loss_con

### 2.3 Data Loss function for autoregressive

In [10]:
class DataLossOnly(nn.Module):
    def __init__(self,size_average = True):
        super(DataLossOnly,self).__init__()
        self.size_average = size_average 
        self.data_loss_fn = LpLoss(size_average)
    def forward(self,x_current,y_pred,y_target): 
        data_loss = self.data_loss_fn(y_pred,y_target) 

        return data_loss, torch.tensor(0.0).to(device)
    

## 3 Training & Evaluating process

### 3.1 Training and testing loops + printing time method

In [11]:
def training_loop(model, data_loader, phys_weight, loss_fn, optimizer, device, noise_level, rollout_steps, sampling_prob):
    model.train()
    total_data_loss = 0.0
    total_phys_loss = 0.0
    total_samples = 0
    
    # Weights for the two physics components
    W_VORT = 1.0
    W_CONT = 1.0
    CHANNELS = 2

    for batch, (x, y, phys_params) in enumerate(data_loader):
        x, y, phys_params = x.to(device), y.to(device), phys_params.to(device)
        
        if noise_level > 0.0:
            x += torch.randn_like(x) * x.std() * noise_level
            
        x_current = x
        loss_accumulated = 0.0
        
        # --- ROLLOUT LOOP ---
        for step in range(rollout_steps):
            
            # 1. Forward Pass
            y_pred = model(x_current) # (B, H, W, 2)
            
            # 2. Prepare Target
            start_c = step * CHANNELS
            end_c = (step + 1) * CHANNELS
            y_target = y[..., start_c:end_c]
            
            
            # 4. Calculate Loss
            data_l, vort_l, con_l = loss_fn(x_current,y_pred, y_target, phys_params)
            
            # Accumulate
            step_loss = data_l + phys_weight * (W_VORT * vort_l + W_CONT * con_l)
            loss_accumulated += step_loss
            
            # 5. Update Window (Autoregressive)
            if torch.rand(1) < sampling_prob:
                next_in = y_target # Teacher Forcing
            else:
                next_in = y_pred   # Autoregressive
                
            # Slide window: Drop first 2 
            x_current = torch.cat([
                x_current[..., CHANNELS:], 
                next_in], dim=-1)

            #Loss for report
            total_data_loss += data_l.item() # Log last step for simplicity
            total_phys_loss += (vort_l.item() + con_l.item())
            total_samples += 1

        # --- Backward Pass ---
        # We backpropagate the SUM of losses from all steps
        optimizer.zero_grad()
        loss_accumulated.backward()
        
        # Clip gradients to prevent explosion during physics calculation
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        optimizer.step()

    return total_data_loss / total_samples, total_phys_loss / total_samples

In [12]:
#Testing loop
def testing_loop(model:torch.nn.Module,
                 data_loader: torch.utils.data.DataLoader,
                 loss_fn:torch.nn.Module,
                 device):

  model.eval()
  val_data_loss = 0
  with torch.inference_mode(): # Disable gradient calculation
    for batch,(x_batch, y_batch,phys_params) in enumerate(data_loader):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        y_pred = model(x_batch)

        # We only care about the data loss for validation
        data_loss, _ ,_= loss_fn(x_batch, y_pred, y_batch[:,:,:,0:2],phys_params) #First step

        #Accumulating
        val_data_loss += data_loss.item()


  #Scaling
  val_data_loss /= len(data_loader)

  return val_data_loss


In [13]:
def printing_time(start, end, model_name):
    print(f"It takes {(end-start):.2f} s to train {model_name}")

### 3.2 Traing & validating for epochs



In [14]:
import torch
from tqdm.auto import tqdm
from timeit import default_timer as timer

def train_and_test(lr_scheduler, model: torch.nn.Module,
                   loss_fn:torch.nn.Module, optimizer: torch.optim.Optimizer,phys_weight :float, train_loader,
                   dev_loader, epochs = 100, device = device,noise_level =0.0, rollout_steps =5, sampling_prob =0.9):
  #Setups
  model.to(device)

  results = {
      "train_data_loss":[],
      "train_phys_loss":[],
      "val_data_loss":[],
  }

  #Start timing
  start_time = timer()

  #Start
  for epoch in tqdm(range(epochs)):
    if epoch %8 == 0:
      print("--------------------------")
      print(f"Epoch: {epoch}\n--------------")
        
      sampling_prob -=0.1

    #Training
    train_data_loss, train_phys_loss = training_loop(model,train_loader, phys_weight, loss_fn, optimizer,device,noise_level,rollout_steps,sampling_prob)

    #Testing
    val_data_loss = testing_loop(model,dev_loader,loss_fn,device)

    #Decay lr based on val_loss
    lr_scheduler.step(val_data_loss)

    #Printing
    if epoch % 8 ==0:
      print(f"Train data Loss:{train_data_loss:.3f} || Train physic Loss:{train_phys_loss:.3f} || Test Loss:{val_data_loss:.3f}")

    #Restore values
    results["train_data_loss"].append(train_data_loss.item() if isinstance(train_data_loss, torch.Tensor) else train_data_loss)
    results["train_phys_loss"].append(train_phys_loss.item() if isinstance(train_phys_loss, torch.Tensor) else train_phys_loss)
    results["val_data_loss"].append(val_data_loss.item() if isinstance(val_data_loss, torch.Tensor) else val_data_loss)

  #Printing time
  print("--------------------------")
  end_time = timer()
  printing_time(start_time,end_time,type(model).__name__)

  return results


### 3.3 Evaluating model function

# III. Models 1D

## 1.DATA

### 1.1 T10 data

In [17]:
import numpy as np 
from numpy import random

#Attributes of the data set
T_in = 10
T_out = 5
BATCH_SIZE = 32
WINDOW_PER_SIM_TRAIN = 20 # 20 random windows per sim
WINDOW_PER_SIM_DEV = 50  


total_size = len(full_dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

# Set seed for reproducibility so test set stays same
generator = torch.Generator().manual_seed(42)

train_ds, val_ds, test_ds = random_split(
    full_dataset, 
    [train_size, val_size, test_size],
    generator=generator
)

print(f"âœ… Dataset Split:")
print(f"   Train: {len(train_ds)} samples")
print(f"   Val:   {len(val_ds)} samples")
print(f"   Test:  {len(test_ds)} samples")

# 3. Create DataLoaders
# Use num_workers=0 for safety, or 2-4 if on Linux/Mac
train_loader_t10 = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader_t10 = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader_t10 = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print("ðŸŽ‰ DataLoaders are ready!")


Total training samples (windows): 20000
Total dev samples (windows): 5000


## 4. Model 4
**(Plautaue at .. epoch after .. hour)

** Model 4 (Autoregressive, 3-in) Test Loss: 

- use data_t10 
- FNO2D
- DatalossOnly 

### 4.1 Loss function and optimizer

In [26]:
## HYPER Parameters from the paper
WIDTH =64
T_IN =10
T_OUT =1
MODE_X = 12
MODE_Y = 12

#Define model
model5 = FNO2D(in_chanels=T_IN*CHANNELS, 
                 out_chanels=T_OUT*CHANNELS, 
                 mode_x=MODE_X, 
                 mode_y=MODE_Y, 
                 width=WIDTH).to(device)

In [27]:
dt =0.001 #(Please modify if doesnt match)
loss_fn5 = NavierStoke2DLoss(dt = dt, size_average =True)

In [28]:
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

#Optimizer (Adam)
LR =0.001
optimizer5 = torch.optim.Adam(model4.parameters(),lr = LR)

# We'll use 'ReduceLROnPlateau'.
# This scheduler will automatically "reduce" the learning rate
# when it detects that our validation loss has "plateaued" (stopped improving).
scheduler5 = ReduceLROnPlateau(
    optimizer4,
    mode='min',      # It will look at the validation loss, where 'min' is better
    factor=0.1,      # Reduce LR by a factor of 10 (e.g., 0.001 -> 0.0001)
    patience=5,      # Wait 5 epochs for improvement before reducing
)

### 4.2 Training and testing model 5

In [29]:
# --- HYPERPARAMETERS ---
# ROLLOUT_STEPS must match the number of output steps provided by your DataLoader's y_batch.

ROLLOUT_STEPS = 5  
SAMPLING_PROB = 0.9 
EPOCHS =16
NOISE_LEVEL =0.05
PHYS_WEIGHT =0.001

result5 = train_and_test(scheduler5, model = model5,
              loss_fn = loss_fn5, optimizer = optimizer5,phys_weight = PHYS_WEIGHT, train_loader = train_loader_t10,
              dev_loader = dev_loader_t10, epochs = EPOCHS,device = device, noise_level = NOISE_LEVEL, rollout_steps =ROLLOUT_STEPS , sampling_prob =SAMPLING_PROB)

  0%|          | 0/16 [00:00<?, ?it/s]

--------------------------
Epoch: 0
--------------
Train data Loss:0.056 || Train physic Loss:0.000 || Test Loss:0.028
--------------------------
Epoch: 8
--------------
Train data Loss:0.014 || Train physic Loss:0.000 || Test Loss:0.009
--------------------------
It takes 3783.50 s to train FNO2D


In [22]:
    print(f"Result 5: {result5}")

Result 4: {'train_data_loss': [0.05087211075901985, 0.019582114146053792, 0.01718435420244932, 0.01622707962244749, 0.013885308791100979, 0.012675398913770914, 0.014305818380713462, 0.01390168257534504, 0.013852717347741126, 0.010688792301565409, 0.003958933025971055, 0.003676559438854456, 0.0035291603688895703, 0.0034527564647048713, 0.0033684210056811573, 0.003355002436041832], 'train_phys_loss': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'val_data_loss': [0.014770651550572009, 0.011961442513308327, 0.01349966850274118, 0.008752051269291503, 0.02810782082616144, 0.013955968230440739, 0.02735861696919818, 0.013814288179633344, 0.01048888521755387, 0.009351865329133098, 0.003875018720319317, 0.0037171958306519565, 0.0036459622057866616, 0.0034536947466575416, 0.0033496209067310307, 0.003377579756164152]}


## 6. Saving and comparing models

### 6.1 Saving model

In [None]:
import torch
import os

print("Saving initial untrained models...")

# 1. Define a path to save your models in your Drive
MODELS_PATH = Path("models")
os.makedirs(MODELS_PATH, exist_ok=True)

# 2. Define the save paths for each model
MODEL4_SAVE_PATH = MODELS_PATH / "model4_initial_state.pth"


# 3. Save the models' state dictionaries
# .state_dict() saves only the learnable parameters (weights and biases)
try:
    torch.save(model4.state_dict(), MODEL4_SAVE_PATH)

    print(f"Successfully saved model4 to: {MODEL4_SAVE_PATH}")

except Exception as e:
    print(f"Error saving models: {e}")

# To load them back later (for inference or to resume training), you would:
# 1. Re-create the model instance: model0 = FNO1D(WIDTH, K_MAX, x_grid_tensor).to(device)
# 2. Load the weights: model0.load_state_dict(torch.load(MODEL0_SAVE_PATH))

### 6.2 Compare models

In [30]:
import torch
import torch.nn.functional as F
import numpy as np

# This must match your model and dataset
GRID_SIZE = 64
CHANNELS = 2 # vx, vy
T_IN = 10    # 10 steps in

def evaluate_autoregressive_rollout(model, dataset, loss_fn,t_in, device):
    """
    Performs a full autoregressive rollout on the 2D grid dataset
    by loading each sample file.
    
    Args:
        model: Your trained 2D FNO model
        dataset: The CylinderFlowGridDataset object (e.g., dev_dataset)
        loss_fn: An instance of your 2D LpLoss (NOT DataLossOnly)
        device: Your "cuda" or "cpu" device
    """
    model.eval()
    total_rollout_loss = 0.0
    
    with torch.no_grad():
        # Loop through each sample file in the dataset
        for i in tqdm(range(len(dataset)), desc="Evaluating Rollout"):
            
            # 1. Load the full 600-step trajectory for one sample
            data = torch.load(dataset.file_paths[i])
            # Shape: (600, 64, 64, 2)
            velocity_traj = data['velocity'].to(device).to(torch.float32)

            # --- 2. Create the first input window ---
            # Get t=0...9
            x_window_t = velocity_traj[:t_in] # (10, 64, 64, 2)
            
            # Reshape to (1, 64, 64, 20) for the model
            current_window = x_window_t.permute(1, 2, 0, 3) 
            current_window = current_window.reshape(GRID_SIZE, GRID_SIZE, t_in * CHANNELS)
            current_window = current_window.unsqueeze(0) # Add batch dim

            # --- 3. Create the ground truth for comparison ---
            # Get t=10...599
            y_true_all_steps = velocity_traj[t_in:] # (590, 64, 64, 2)
            
            # Reshape to (1, 64, 64, 1180)
            y_true_rollout = y_true_all_steps.permute(1, 2, 0, 3)
            y_true_rollout = y_true_rollout.reshape(GRID_SIZE, GRID_SIZE, -1)
            y_true_rollout = y_true_rollout.unsqueeze(0)

            
            predictions = [] # List to store all 590 predictions
            n_steps = y_true_all_steps.shape[0] # 590 steps
            
            # --- 4. Start the rollout loop ---
            for _ in range(n_steps):
                # model input: (1, 64, 64, 20)
                # model output: (1, 64, 64, 2)
                y_pred_step = model(current_window)
                
                if torch.any(torch.isinf(y_pred_step)) or torch.any(torch.isnan(y_pred_step)):
                    print(f"!!! Rollout became unstable at step {i} !!!")
                    break 
                
                predictions.append(y_pred_step)
                
                # --- 5. Feed back in ---
                # Drop oldest 2 channels, add new 2 channels
                current_window = torch.cat(
                    (current_window[:, :, :, CHANNELS:], y_pred_step), 
                    dim=3 # Concatenate on the channel dimension
                )
            
            # --- 6. Stack all predictions ---
            # y_pred_full_rollout shape: (1, 64, 64, 1180)
            if len(predictions) != n_steps: # Handle unstable rollout
                print("Rollout failed, skipping this sample.")
                continue
                
            y_pred_full_rollout = torch.cat(predictions, dim=3)
            
            # --- 7. Calculate final loss ---
            # Use the simple LpLoss, not the DataLossOnly wrapper
            rollout_loss = loss_fn(x_window_t,y_pred_full_rollout, y_true_rollout)
            total_rollout_loss += rollout_loss[0]

    # Return the average loss over all test samples
    return total_rollout_loss / len(dataset.file_paths)

print("evaluate_autoregressive_rollout (2D version) function defined.")

evaluate_autoregressive_rollout (2D version) function defined.


In [31]:
# (Run this in a new cell after training model2)

print("Starting FAIR comparison rollout for Model 2 and 3...")

# Use test_loader_t1 (from cell 17) for a fair test
# Use loss_fn0 (from cell 55) for a fair data-only loss

rollout_loss_m4 = evaluate_autoregressive_rollout(
    model=model4, 
    dataset=test_data_t10, 
    loss_fn=loss_fn4,           
    device=device,
    t_in=10# <-- This correctly matches your model2
)

print("\n--- Fair Comparison Test Results ---")
print(f"Model 0 (Direct, Data-Only) Test Loss:     0.0617") # From cell 105
print(f"Model 1 (Direct, PINO, w=0.1) Test Loss:   0.063") # From cell 111
print(f"Model 4 (Autoregressive, 3-in) Test Loss: {rollout_loss_m4:.4f}")

Starting FAIR comparison rollout for Model 2 and 3...


Evaluating Rollout:   0%|          | 0/100 [00:00<?, ?it/s]


--- Fair Comparison Test Results ---
Model 0 (Direct, Data-Only) Test Loss:     0.0617
Model 1 (Direct, PINO, w=0.1) Test Loss:   0.063
Model 4 (Autoregressive, 3-in) Test Loss: 0.3364


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from results import result1, result0

# --- IMPORTANT ---
# You must copy the dictionary outputs from your notebook cells
# (cell 50 for model0 and cell 67 for model1) and paste them here.
# I have used your completed output from cell 50 as an example.
# You will need to replace results_1 with your full 30-epoch output.

# --- Create Epoch Array ---
# Assumes both models were trained for the same number of epochs
result0 = {'train_data_loss': [0.4110739827156067, 0.23925882577896118, 0.20458190143108368, 0.19567212462425232, 0.17208527028560638, 0.16243018209934235, 0.16161102056503296, 0.1497381031513214, 0.1424730122089386, 0.13952957093715668, 0.13167569041252136, 0.1314629167318344, 0.13347485661506653, 0.1252565085887909, 0.12282412499189377, 0.12299676984548569, 0.11490682512521744, 0.11774254590272903, 0.10972394794225693, 0.10895457118749619, 0.11255241930484772, 0.10872527211904526, 0.10500839352607727, 0.10531488806009293, 0.10197656601667404, 0.10225849598646164, 0.10453744977712631, 0.10029015690088272, 0.09715524315834045, 0.10300806164741516, 0.09627469629049301, 0.09561239928007126, 0.09352393448352814, 0.09420380741357803, 0.09108910709619522, 0.08958891034126282, 0.08952014893293381, 0.0945991724729538, 0.08834867179393768, 0.08973327279090881, 0.08752944320440292, 0.08816472440958023, 0.08264376223087311, 0.08628880232572556, 0.08571340143680573, 0.08524730801582336, 0.08302949368953705, 0.08613189309835434, 0.08262985199689865, 0.0860193744301796, 0.08372187614440918, 0.08175082504749298, 0.0784018486738205, 0.07959341257810593, 0.0778963565826416, 0.0784391239285469, 0.07724633812904358, 0.07945655286312103, 0.07577656954526901, 0.08160007745027542, 0.08002155274152756, 0.07753711193799973, 0.07859975844621658, 0.0764986202120781, 0.07406435161828995, 0.0723976269364357, 0.0739876925945282, 0.07461126893758774, 0.07415719330310822, 0.07250193506479263, 0.07158296555280685, 0.0725533589720726, 0.07088251411914825, 0.0713513195514679, 0.07630784064531326, 0.07213081419467926, 0.06057237461209297, 0.05909671634435654, 0.05885971337556839, 0.058606330305337906, 0.05847145989537239, 0.058534953743219376, 0.05826769769191742, 0.05820904299616814, 0.05816391482949257, 0.05802098661661148, 0.058009110391139984, 0.057899314910173416, 0.0577441044151783, 0.05807175859808922, 0.057901881635189056, 0.05766398832201958, 0.05750136449933052, 0.05765574425458908, 0.05767909809947014, 0.05718502774834633, 0.0571594312787056, 0.057440076023340225, 0.057252444326877594, 0.05690488591790199, 0.05690895393490791, 0.05691118910908699, 0.056899115443229675, 0.05672742426395416, 0.05663994699716568, 0.05657578259706497, 0.05658656731247902, 0.05656237527728081, 0.05650690570473671, 0.05641086399555206, 0.056626658886671066, 0.05617628991603851, 0.05610703304409981, 0.05624675750732422, 0.0564168281853199, 0.05596563592553139, 0.05595419928431511, 0.055922769010066986, 0.05574888736009598, 0.05567246302962303, 0.05585700646042824, 0.055710289627313614, 0.05597971752285957, 0.05580592527985573, 0.05584094300866127, 0.05574943125247955, 0.05546464025974274, 0.055211909115314484, 0.05526223033666611, 0.05522475764155388, 0.05564416944980621, 0.05541204661130905, 0.05530177429318428, 0.05499405786395073, 0.05497822165489197, 0.055076297372579575, 0.05496148020029068, 0.0547984354197979, 0.05514082685112953, 0.055032216012477875, 0.05497255548834801, 0.05481015145778656, 0.05463377758860588, 0.054746970534324646, 0.05454118177294731, 0.05435951054096222, 0.05476135388016701, 0.054547540843486786, 0.05460606887936592, 0.05433722585439682], 'train_phys_loss': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'val_data_loss': [0.2572142621354451, 0.21145666733620658, 0.19115776675088064, 0.1802690809681302, 0.17118003964424133, 0.15824853404173775, 0.17612133447139983, 0.14190959067098677, 0.13687605919345977, 0.1335370416442553, 0.13127026198402283, 0.1297357457261237, 0.12929621764591762, 0.12160444259643555, 0.15860963624621194, 0.12193896668770957, 0.13977357768823231, 0.11167937008634446, 0.11051414717757513, 0.11135667432395238, 0.11710662323804129, 0.10518188218748759, 0.10574372337451057, 0.11249114868659822, 0.10047837618797545, 0.10142433477772607, 0.09888798826270634, 0.09998914327413316, 0.10522396519543632, 0.10073239829332109, 0.09524016141418427, 0.09673324570296303, 0.0917700073785252, 0.09177173449406548, 0.09537116840245231, 0.0898829839295811, 0.09122747289282936, 0.09619156460440348, 0.09295563683623359, 0.09442581981420517, 0.0865695761546256, 0.08646621212126716, 0.08684119321997204, 0.089255215156646, 0.0872151656519799, 0.08562316390730086, 0.0906154742789647, 0.08481654382887341, 0.08532685396217164, 0.09498826461651969, 0.08751835993358068, 0.08428332919166201, 0.10068383375330577, 0.08187451710303624, 0.08577221821224879, 0.07869435054442239, 0.08390380430316167, 0.08048265462829954, 0.07998732867695037, 0.07940924309548877, 0.08420824247693258, 0.07788392381062584, 0.08739961293481645, 0.07972035036673622, 0.07721601213727679, 0.07708905519001068, 0.08067990259991752, 0.07722300066361351, 0.09537790822131294, 0.07313832954045325, 0.07719190785336116, 0.0766766376438595, 0.07478425805530851, 0.0733860337308475, 0.07394332036612526, 0.07324050420096942, 0.06501247840268272, 0.06459305080629531, 0.06465581023976916, 0.06449638975281564, 0.06492461772665145, 0.06398267458592143, 0.0644972435538731, 0.06389726597874884, 0.06378511274381289, 0.06382066389871022, 0.06389406625004042, 0.06369595667199483, 0.06350627009357725, 0.0635057757534678, 0.06327962041610763, 0.06329736746256313, 0.06380144559911319, 0.06318371796182223, 0.06303894472500635, 0.06305838967599565, 0.06306172854134015, 0.06316362587468964, 0.06280702772358107, 0.06297555837839369, 0.06427314351238901, 0.06272708405814474, 0.06253687392861124, 0.06268517643449799, 0.062455795646186855, 0.06254122944341765, 0.06238306195489944, 0.06241820912275996, 0.062523636493891, 0.06233104856477843, 0.0623762799752137, 0.06204700215704857, 0.06217261888678112, 0.06274812715867209, 0.0618926150694726, 0.06204395309563667, 0.06187254799500344, 0.061938941537860844, 0.06266835385135242, 0.06309285830883753, 0.06154843940148278, 0.06175081975876339, 0.06152673846199399, 0.06216449757653569, 0.06183923437954888, 0.06193630456451386, 0.06164784404256987, 0.06338022170322281, 0.06135745565333064, 0.06773435566869992, 0.06207913467808375, 0.0613168551926575, 0.061097964880958436, 0.06094158074212453, 0.06099840265417856, 0.06099581180347337, 0.06111503788639629, 0.06091841511310093, 0.06254472116392756, 0.06069481703970167, 0.060798205198749664, 0.0615965946917496, 0.060730867383499, 0.06066012004065135, 0.06074823173029082, 0.0605203460842844, 0.060685241033160496, 0.06406244617842492, 0.06041179910775215, 0.06047245848273474]}

epochs = range(len(result0['train_data_loss']))


# --- Create Plots ---
plt.figure(figsize=(18, 5))

# Plot 1: Training Data Loss Comparison
plt.subplot(1, 3, 1)
plt.plot(epochs, result0['train_data_loss'], label='Model 0 (Data Only)', color='blue')
plt.plot(epochs, result1['train_data_loss'], label='Model 1 (PINO)', color='red')
plt.title('Training Data Loss')
plt.xlabel('Epochs')
plt.ylabel('Relative L2 Loss (L_data)')
plt.legend()
plt.grid(True, linestyle=':')
plt.yscale('log') # Use log scale if losses are very different

# Plot 2: Validation Data Loss Comparison
plt.subplot(1, 3, 2)
plt.plot(epochs, result0['val_data_loss'], label='Model 0 (Data Only)', color='blue')
plt.plot(epochs, result1['val_data_loss'], label='Model 1 (PINO)', color='red')
plt.title('Validation Data Loss')
plt.xlabel('Epochs')
plt.ylabel('Relative L2 Loss (L_data)')
plt.legend()
plt.grid(True, linestyle=':')
plt.yscale('log')

# Plot 3: Model 1 (PINO) Loss Breakdown
plt.subplot(1, 3, 3)
plt.plot(epochs, result1['train_data_loss'], label='Data Loss', color='red')
plt.plot(epochs, result1['train_phys_loss'], label='Physics Loss', color='green')
plt.title('Model 1 (PINO) Loss Components')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, linestyle=':')
plt.yscale('log') # Log scale is essential here since physics loss is large



plt.tight_layout()
plt.show()