# Test Custom Loss Function for GINO Model

This notebook tests a custom loss function for the GINO multi-column model that can apply different loss weights or functions to individual target variables (e.g., mass_concentration vs head).

## 1. Import Required Libraries

In [2]:
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from pathlib import Path

# Add the project root (the parent of the notebook's working directory) to the
# front of the import path so the `src` package imported below resolves.
project_root = Path.cwd().parent
# Only insert when it is not already the first entry, so that re-running this
# cell does not keep stacking duplicate entries onto sys.path.
if sys.path[:1] != [str(project_root)]:
    sys.path.insert(0, str(project_root))

from src.data.transform import Normalize
from src.data.patch_dataset_multi_col import GWPatchDatasetMultiCol
from src.data.batch_sampler import PatchBatchSampler
from src.models.neuralop.gino import GINO
from torch.utils.data import DataLoader

# Report the installed PyTorch build and whether a CUDA device is usable.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
PyTorch version: 2.5.1+cu118
CUDA available: True


## 2. Configuration and Paths

In [22]:
# Paths
# The defaults are the original cluster locations. Set the GW_DATA_DIR and/or
# GINO_CHECKPOINT environment variables to run the notebook on another machine
# without editing this cell.
base_data_dir = os.environ.get(
    'GW_DATA_DIR',
    '/srv/scratch/z5370003/projects/data/groundwater/FEFLOW/coastal/variable_density',
)
raw_data_dir = os.path.join(base_data_dir, 'all')
patch_data_dir = os.path.join(base_data_dir, 'filter_patch')

# Model checkpoint path (update this to your trained model)
model_path = os.environ.get(
    'GINO_CHECKPOINT',
    '/srv/scratch/z5370003/projects/results/04_groundwater/variable_density/GINO/multi_col/mass_conc_head/exp_lr4.8e4_exp_bs512/gino_multi_20251010_092951/checkpoints/checkpoint_epoch_0009.pth',
)

# Model parameters
target_cols = ['mass_concentration', 'head']
# Positions of `target_cols` within the dataset's observation columns; this
# list has to be kept in sync with `target_cols` by hand.
target_col_indices = [0, 1]  # Indices in the observation columns
input_window_size = 5    # number of input time steps per sample
output_window_size = 5   # number of predicted time steps per sample
batch_size = 4  # Small batch for testing

# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cuda


## 3. Load Data Transforms and Dataset

In [23]:
# Calculate transforms
# Normalisation statistics are derived from a single raw file ('0000.csv').
stats_frame = pd.read_csv(os.path.join(raw_data_dir, '0000.csv'))

# Columns of the observation block that get normalised
target_obs_cols = ['mass_concentration', 'head', 'pressure']

# Column-wise mean / standard deviation (pandas' default `std`)
coord_mean = stats_frame[['X', 'Y', 'Z']].mean().to_numpy()
coord_std = stats_frame[['X', 'Y', 'Z']].std().to_numpy()
obs_mean = stats_frame[target_obs_cols].mean().to_numpy()
obs_std = stats_frame[target_obs_cols].std().to_numpy()

# The frame was only needed for the statistics; release it.
del stats_frame

# Coordinate transform
coord_transform = Normalize(mean=coord_mean, std=coord_std)

# Observation transform
obs_transform = Normalize(mean=obs_mean, std=obs_std)

print(f"Coordinate mean: {coord_mean}")
print(f"Coordinate std: {coord_std}")
print(f"Observation mean: {obs_mean}")
print(f"Observation std: {obs_std}")

Coordinate mean: [ 3.57225665e+05  6.45774324e+06 -9.27782248e+00]
Coordinate std: [569.1699999  566.35797379  15.26565618]
Observation mean: [1.77942252e+04 3.95881156e-01 9.48469883e+01]
Observation std: [1.55859465e+04 2.13080032e-01 1.51226320e+02]


In [24]:
# Create train dataset
# Builds the 'train' split from the patch directory, passing in the coordinate
# and observation normalisers computed above, the input/output window sizes
# and the indices of the target observation columns.
train_ds = GWPatchDatasetMultiCol(
    data_path=patch_data_dir,
    dataset='train',
    coord_transform=coord_transform,
    obs_transform=obs_transform,
    input_window_size=input_window_size,
    output_window_size=output_window_size,
    target_col_indices=target_col_indices,
)

# Quick summary of what was loaded and which targets are selected.
print(f"Train dataset length: {len(train_ds)}")
print(f"Target columns: {target_cols}")
print(f"Target column indices: {target_col_indices}")

Train dataset length: 13180
Target columns: ['mass_concentration', 'head']
Target column indices: [0, 1]


## 4. Create DataLoader with Collate Function

In [25]:
# Model configuration
# Number of coordinate axes (the X, Y, Z columns used for the coordinate transform).
coord_dim = 3
# Number of latent query grid points along each coordinate axis.
# NOTE(review): presumably has to match the value the checkpoint was trained
# with (printed as "Latent dims" when the model is loaded) - confirm.
latent_query_dims = (32, 32, 24)
# Number of target variables, derived from the configured target column names.
n_target_cols = len(target_cols)

def make_collate_fn(query_dims=None, n_coord_dims=None, grid_device=None):
    """Create the collate function that assembles one batch of patch samples.

    Coordinates are read from the first sample only, i.e. every sample in a
    batch is assumed to share the same point cloud (presumably guaranteed by
    the patch-based batch sampler - confirm when using another sampler).

    Args:
        query_dims: Number of latent grid points per coordinate axis. When
            None, the notebook-level ``latent_query_dims`` is used.
        n_coord_dims: Number of coordinate axes. When None, the notebook-level
            ``coord_dim`` is used.
        grid_device: Device on which the latent query grid is created. When
            None, the notebook-level ``device`` is used. Pass ``'cpu'`` when
            the DataLoader runs worker processes, since creating CUDA tensors
            inside workers is generally not supported.

    Returns:
        A callable mapping a list of sample dicts to a dict with the keys
        'point_coords', 'latent_queries', 'x', 'y', 'core_len' and 'patch_id'.
    """
    def collate_fn(batch_samples):
        # Resolve the settings at call time so that later changes to the
        # notebook-level values are still picked up when no override is given.
        grid_dims = latent_query_dims if query_dims is None else query_dims
        n_axes = coord_dim if n_coord_dims is None else n_coord_dims
        target_device = device if grid_device is None else grid_device

        # Get shared point cloud from first sample
        first = batch_samples[0]
        core_coords = first['core_coords']
        ghost_coords = first['ghost_coords']

        # Combine core and ghost points (core points first)
        point_coords = torch.concat([core_coords, ghost_coords], dim=0).float()

        # Create latent queries grid over the per-batch bounding box
        lower = torch.min(point_coords, dim=0).values.tolist()
        upper = torch.max(point_coords, dim=0).values.tolist()
        axes = [
            torch.linspace(lower[i], upper[i], grid_dims[i], device=target_device)
            for i in range(n_axes)
        ]
        latent_queries = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)

        # Build batched sequences: per sample, core rows followed by ghost rows,
        # stacked along a new leading batch axis.
        x = torch.stack(
            [torch.concat([s['core_in'], s['ghost_in']], dim=0).float() for s in batch_samples],
            dim=0,
        )  # [B, N_points, input_window_size * n_target_cols]
        y = torch.stack(
            [torch.concat([s['core_out'], s['ghost_out']], dim=0).float() for s in batch_samples],
            dim=0,
        )  # [B, N_points, output_window_size * n_target_cols]

        return {
            'point_coords': point_coords,
            'latent_queries': latent_queries,
            'x': x,
            'y': y,
            'core_len': len(core_coords),
            # Taken from the last sample of the batch (the value the original
            # loop variable held after the loop), now looked up explicitly.
            'patch_id': batch_samples[-1]['patch_id'],
        }
    return collate_fn

# Create sampler and dataloader
# The batch sampler yields whole batches of dataset indices; shuffling is
# enabled both across patches and within batches, with a fixed seed.
sampler = PatchBatchSampler(
    train_ds,
    batch_size=batch_size,
    shuffle_within_batches=True,
    shuffle_patches=True,
    seed=42
)

# `batch_sampler` replaces the DataLoader's own batch_size/shuffle handling;
# the custom collate function turns each list of samples into one batch dict.
collate_fn = make_collate_fn()
train_loader = DataLoader(train_ds, batch_sampler=sampler, collate_fn=collate_fn)

print(f"Number of batches: {len(train_loader)}")

Building patch groups (one-time operation)...
Building patch_ids cache...
Cached 13180 patch_ids
Found 20 patches with 13180 total samples
Patch sizes: min=659, max=659, avg=659.0
Pre-built 3300 batches
Number of batches: 3300


## 5. Load GINO Model

In [26]:
# Load checkpoint
print(f"Loading checkpoint from: {model_path}")
# The checkpoint holds more than tensors (the training arguments object under
# 'args' is read below), so it is loaded as a full pickle. `weights_only=False`
# is passed explicitly: it silences the FutureWarning about the changing
# default and keeps this cell working once the default becomes True.
# SECURITY: unpickling can execute arbitrary code - only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(model_path, map_location=device, weights_only=False)

# Extract saved configuration
saved_args = checkpoint['args']
print("\nSaved model configuration:")
print(f"- FNO modes: {saved_args.fno_n_modes}")
print(f"- FNO layers: {saved_args.fno_n_layers}")
print(f"- Hidden channels: {saved_args.fno_hidden_channels}")
print(f"- GNO radius: {saved_args.gno_radius}")
print(f"- Latent dims: {saved_args.latent_query_dims}")
print(f"- Target columns: {saved_args.target_cols}")
print(f"- Number of target columns: {saved_args.n_target_cols}")

# Initialize model with saved configuration so that the architecture matches
# the stored weights.
model = GINO(
    # Input GNO configuration
    in_gno_coord_dim=saved_args.coord_dim,
    in_gno_radius=saved_args.gno_radius,
    in_gno_out_channels=saved_args.in_gno_out_channels,
    in_gno_channel_mlp_layers=saved_args.in_gno_channel_mlp_layers,

    # FNO configuration
    fno_n_layers=saved_args.fno_n_layers,
    fno_n_modes=saved_args.fno_n_modes,
    fno_hidden_channels=saved_args.fno_hidden_channels,
    lifting_channels=saved_args.lifting_channels,

    # Output GNO configuration
    out_gno_coord_dim=saved_args.coord_dim,
    out_gno_radius=saved_args.gno_radius,
    out_gno_channel_mlp_layers=saved_args.out_gno_channel_mlp_layers,
    projection_channel_ratio=saved_args.projection_channel_ratio,
    out_channels=saved_args.out_channels,
).to(device)

# Load model weights and switch to evaluation mode (this notebook only runs
# inference with the model).
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"\nModel loaded successfully!")
# The stored epoch is printed with +1 (checkpoint_epoch_0009 is shown as 10).
print(f"Epoch: {checkpoint['epoch'] + 1}")
print(f"Train loss: {checkpoint['train_losses'][-1]:.6f}")
print(f"Val loss: {checkpoint['val_losses'][-1]:.6f}")

Loading checkpoint from: /srv/scratch/z5370003/projects/results/04_groundwater/variable_density/GINO/multi_col/mass_conc_head/exp_lr4.8e4_exp_bs512/gino_multi_20251010_092951/checkpoints/checkpoint_epoch_0009.pth


  checkpoint = torch.load(model_path, map_location=device)



Saved model configuration:
- FNO modes: (12, 12, 8)
- FNO layers: 4
- Hidden channels: 64
- GNO radius: 0.15
- Latent dims: (32, 32, 24)
- Target columns: ['mass_concentration', 'head']
- Number of target columns: 2

Model loaded successfully!
Epoch: 10
Train loss: 0.330761
Val loss: 0.339951


## 6. Test Forward Pass

In [27]:
# Pull one batch from the loader
batch = next(iter(train_loader))

# Transfer every tensor entry to the target device as float32
point_coords, latent_queries, x, y = (
    batch[name].to(device).float()
    for name in ('point_coords', 'latent_queries', 'x', 'y')
)
# Number of core points; they come first in the point cloud, ahead of the ghost points
core_len = batch['core_len']

print(f"Batch shapes:")
print(f"  point_coords: {point_coords.shape}")
print(f"  latent_queries: {latent_queries.shape}")
print(f"  x (input): {x.shape}")
print(f"  y (target): {y.shape}")
print(f"  core_len: {core_len}")

# Inference only: run the model without tracking gradients. The same point
# cloud serves as input geometry and as output query locations.
with torch.no_grad():
    outputs = model(input_geom=point_coords, latent_queries=latent_queries, x=x, output_queries=point_coords)

print(f"\nOutput shape: {outputs.shape}")
print(f"Expected shape: [batch_size={x.shape[0]}, n_points={point_coords.shape[0]}, output_window_size * n_target_cols={output_window_size * n_target_cols}]")

Batch shapes:
  point_coords: torch.Size([512, 3])
  latent_queries: torch.Size([32, 32, 24, 3])
  x (input): torch.Size([4, 512, 10])
  y (target): torch.Size([4, 512, 10])
  core_len: 405

Output shape: torch.Size([4, 512, 10])
Expected shape: [batch_size=4, n_points=512, output_window_size * n_target_cols=10]


## 7. Define Custom Loss Function

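This custom loss function is intended to allow:
- Different loss weights for different target variables
- Different loss functions per variable (e.g., MSE for mass_concentration, MAE for head)
- Reshaping predictions from the concatenated format `[batch, points, timesteps * n_variables]` to `[batch, points, timesteps, n_variables]`

The implementation below covers the first and third points: it combines a global `LpLoss` term over all variables with an additional, variance-weighted `LpLoss` term for `mass_concentration`, balanced by `lambda_conc_focus`. Per-variable loss functions (e.g., MAE for head) are not implemented in it.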

In [None]:
from src.models.neuralop.losses import LpLoss

def variance_aware_multicol_loss(
    y_pred,
    y_true,
    output_window_size,
    target_cols,
    lambda_conc_focus=0.5,
    alpha=0.3,
    beta=2.0,
    focus_col='mass_concentration',
    eps=1e-6,
):
    """Global Lp loss plus a variance-weighted loss on one focus variable.

    Args:
        y_pred, y_true: tensors of shape [B, N_points, T_out * C]. The last
            axis is interpreted as T_out blocks of C variables (it is viewed
            as [T_out, C]).
        output_window_size: T_out.
        target_cols: list of variable names, e.g. ['mass_concentration', 'head'];
            the position of ``focus_col`` in it selects the channel.
        lambda_conc_focus: mixing weight of the variance-aware term; the
            global term gets ``1 - lambda_conc_focus``.
        alpha: base weight given to every node before renormalisation.
        beta: exponent controlling how sharply high-variance nodes are
            emphasised.
        focus_col: name of the variable the variance-aware term is computed
            on. Defaults to 'mass_concentration' (the previous hard-coded value).
        eps: small constant added to the normalising means to avoid division
            by (near) zero.

    Returns:
        Tuple ``(loss, global_loss, conc_var_loss)`` where
        ``loss = (1 - lambda_conc_focus) * global_loss
        + lambda_conc_focus * conc_var_loss``.

    Raises:
        AssertionError: if the last axis is not a multiple of ``output_window_size``.
        ValueError: if ``focus_col`` is not in ``target_cols``.
        IndexError: if the position of ``focus_col`` exceeds the number of
            variables found in the tensors.
    """

    B, N, TC = y_pred.shape
    C = TC // output_window_size
    assert TC == output_window_size * C, (
        f"last dimension ({TC}) is not a multiple of output_window_size ({output_window_size})"
    )

    # Resolve and validate the focus channel up front so that a bad
    # configuration fails with a readable message.
    if focus_col not in target_cols:
        raise ValueError(f"{focus_col!r} is not in target_cols {list(target_cols)!r}")
    conc_idx = target_cols.index(focus_col)
    if conc_idx >= C:
        raise IndexError(
            f"{focus_col!r} is at position {conc_idx} in target_cols but the tensors "
            f"only contain {C} variable(s)"
        )

    # reshape to [B, N, T_out, C]
    y_pred = y_pred.view(B, N, output_window_size, C)
    y_true = y_true.view(B, N, output_window_size, C)

    # Loss functions (project LpLoss).
    # NOTE(review): as configured, the global one is expected to reduce over
    # the batch and node axes to a scalar and the local one to return one
    # value per node - confirm against LpLoss, and whether it is a relative
    # or absolute norm. Either way it is an Lp norm, not a plain MSE.
    global_loss_fn = LpLoss(d=2, p=2, reduce_dims=[0, 1], reductions='mean')
    local_loss_fn = LpLoss(d=1, p=2, reductions='mean')

    # ----- 1) global Lp loss over all variables -----
    global_loss = global_loss_fn(y_pred, y_true)

    # ----- 2) variance-aware term for the focus variable -----
    conc_pred = y_pred[..., conc_idx]   # [B, N, T]
    conc_true = y_true[..., conc_idx]   # [B, N, T]

    # Node-wise weights are derived from the targets only and treated as
    # constants (no gradient flows through them).
    with torch.no_grad():
        # per-node variance of the targets, pooled over the batch and time axes
        var_t = conc_true.var(dim=[0, 2], unbiased=False)  # [N]

        # normalise by the mean variance across nodes (eps avoids division
        # by a tiny mean)
        var_norm = var_t / (var_t.mean() + eps)

        # w = alpha + (1 - alpha) * var_norm^beta: every node gets at least
        # alpha; var_norm is not bounded above, so high-variance nodes can get
        # weights well above 1 before the renormalisation to mean ~1 below.
        weights = alpha + (1.0 - alpha) * (var_norm ** beta)
        weights = weights / (weights.mean() + eps)   # keep the loss scale stable

    conc_l2 = local_loss_fn(conc_pred, conc_true)        # [N]
    conc_var_loss = (weights * conc_l2).mean()

    # ----- 3) combine -----
    loss = (1.0 - lambda_conc_focus) * global_loss + lambda_conc_focus * conc_var_loss

    return loss, global_loss, conc_var_loss

In [101]:
# Pull one batch from the loader
batch = next(iter(train_loader))

# Transfer every tensor entry to the target device as float32
point_coords, latent_queries, x, y = (
    batch[name].to(device).float()
    for name in ('point_coords', 'latent_queries', 'x', 'y')
)
core_len = batch['core_len']

print(f"Batch shapes:")
print(f"  point_coords: {point_coords.shape}")
print(f"  latent_queries: {latent_queries.shape}")
print(f"  x (input): {x.shape}")
print(f"  y (target): {y.shape}")
print(f"  core_len: {core_len}")

# Evaluate model and loss without tracking gradients (smoke test only)
with torch.no_grad():
    outputs = model(input_geom=point_coords, latent_queries=latent_queries, x=x, output_queries=point_coords)

    # `loss` holds everything the loss function returns (the combined loss
    # together with its individual terms), not just a single scalar.
    loss = variance_aware_multicol_loss(
        y_pred=outputs, y_true=y, output_window_size=output_window_size, target_cols=target_cols
    )

loss

Batch shapes:
  point_coords: torch.Size([512, 3])
  latent_queries: torch.Size([32, 32, 24, 3])
  x (input): torch.Size([4, 512, 10])
  y (target): torch.Size([4, 512, 10])
  core_len: 405
tensor(512., device='cuda:0')


(tensor(0.8502, device='cuda:0'),
 {'global_mse': tensor(0.3516, device='cuda:0'),
  'conc_var_loss': tensor(1.3488, device='cuda:0')})