In [2]:
import os
import pickle
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

import sys
sys.path.append('/srv/scratch/z5370003/projects/src/04_groundwater/variable_density')

from src.data.transform import Normalize
from src.data.patch_dataset_multi_col import GWPatchDatasetMultiCol
from src.data.batch_sampler import PatchBatchSampler
from src.models.neuralop.gino import GINO

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


# GINO Model Predictions Analysis

This notebook loads a trained GINO multi-column model and datasets for analysis.

## Configuration

In [3]:
# Paths
base_data_dir = '/srv/scratch/z5370003/projects/data/groundwater/FEFLOW/coastal/variable_density/'
raw_data_dir = os.path.join(base_data_dir, 'all')
patch_data_dir = os.path.join(base_data_dir, 'filter_patch')
model_path = '/srv/scratch/z5370003/projects/results/04_groundwater/variable_density/GINO/multi_col/mass_conc_head/exp_lr4.8e4_exp_bs512/gino_multi_20251010_092951/checkpoints/latest_checkpoint.pth'  # UPDATE THIS

# Model configuration
target_cols = ['mass_concentration', 'head']  # UPDATE THIS
input_window_size = 5
output_window_size = 5
batch_size = 32
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Target column indices
names_to_idx = {
    'mass_concentration': 0,
    'head': 1,
    'pressure': 2
}
target_col_indices = [names_to_idx[col] for col in target_cols]
n_target_cols = len(target_cols)

print(f"Device: {device}")
print(f"Target columns: {target_cols}")
print(f"Target column indices: {target_col_indices}")

Device: cpu
Target columns: ['mass_concentration', 'head']
Target column indices: [0, 1]


## Data Transforms

In [4]:
# Calculate coordinate normalization
df = pd.read_csv(os.path.join(raw_data_dir, '0000.csv'))
coord_mean = df[['X', 'Y', 'Z']].mean().values
coord_std = df[['X', 'Y', 'Z']].std().values
coord_transform = Normalize(mean=coord_mean, std=coord_std)
print(f"Coordinate mean: {coord_mean}")
print(f"Coordinate std: {coord_std}")

# Calculate observation normalization
target_obs_cols = ['mass_concentration', 'head', 'pressure']
obs_mean = df[target_obs_cols].mean().values
obs_std = df[target_obs_cols].std().values
obs_transform = Normalize(mean=obs_mean, std=obs_std)
print(f"Observation mean: {obs_mean}")
print(f"Observation std: {obs_std}")

del df

Coordinate mean: [ 3.57225665e+05  6.45774324e+06 -9.27782248e+00]
Coordinate std: [569.1699999  566.35797379  15.26565618]
Observation mean: [1.77942252e+04 3.95881156e-01 9.48469883e+01]
Observation std: [1.55859465e+04 2.13080032e-01 1.51226320e+02]


## Load Datasets

In [5]:
# Create train dataset
train_ds = GWPatchDatasetMultiCol(
    data_path=patch_data_dir,
    dataset='train',
    coord_transform=coord_transform,
    obs_transform=obs_transform,
    input_window_size=input_window_size,
    output_window_size=output_window_size,
    target_col_indices=target_col_indices,
)

# Create validation dataset
val_ds = GWPatchDatasetMultiCol(
    data_path=patch_data_dir,
    dataset='val',
    coord_transform=coord_transform,
    obs_transform=obs_transform,
    input_window_size=input_window_size,
    output_window_size=output_window_size,
    target_col_indices=target_col_indices,
)

print(f"Train dataset size: {len(train_ds)}")
print(f"Val dataset size: {len(val_ds)}")

Train dataset size: 13180
Val dataset size: 5560


## Load Model

In [6]:
# Load checkpoint
checkpoint = torch.load(model_path, map_location=device)
saved_args = checkpoint['args']

print("Saved model configuration:")
print(f"- FNO modes: {saved_args.fno_n_modes}")
print(f"- FNO layers: {saved_args.fno_n_layers}")
print(f"- Hidden channels: {saved_args.fno_hidden_channels}")
print(f"- GNO radius: {saved_args.gno_radius}")
print(f"- Latent dims: {saved_args.latent_query_dims}")
print(f"- Target columns: {saved_args.target_cols}")
print(f"- Number of target columns: {saved_args.n_target_cols}")

# Initialize model
model = GINO(
    in_gno_coord_dim=saved_args.coord_dim,
    in_gno_radius=saved_args.gno_radius,
    in_gno_out_channels=saved_args.in_gno_out_channels,
    in_gno_channel_mlp_layers=saved_args.in_gno_channel_mlp_layers,
    fno_n_layers=saved_args.fno_n_layers,
    fno_n_modes=saved_args.fno_n_modes,
    fno_hidden_channels=saved_args.fno_hidden_channels,
    lifting_channels=saved_args.lifting_channels,
    out_gno_coord_dim=saved_args.coord_dim,
    out_gno_radius=saved_args.gno_radius,
    out_gno_channel_mlp_layers=saved_args.out_gno_channel_mlp_layers,
    projection_channel_ratio=saved_args.projection_channel_ratio,
    out_channels=saved_args.out_channels,
).to(device)

# Load weights
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"\nCheckpoint training progress:")
print(f"- Epoch: {checkpoint['epoch'] + 1}")
print(f"- Train loss: {checkpoint['train_losses'][-1]:.4f}")
print(f"- Val loss: {checkpoint['val_losses'][-1]:.4f}")
print(f"\nModel loaded successfully!")

  checkpoint = torch.load(model_path, map_location=device)


Saved model configuration:
- FNO modes: (12, 12, 8)
- FNO layers: 4
- Hidden channels: 64
- GNO radius: 0.15
- Latent dims: (32, 32, 24)
- Target columns: ['mass_concentration', 'head']
- Number of target columns: 2

Checkpoint training progress:
- Epoch: 250
- Train loss: 0.2006
- Val loss: 0.2706

Model loaded successfully!


## Helper Functions

In [7]:
def make_collate_fn(latent_query_dims=(32, 32, 24), coord_dim=3, device='cuda'):
    """Create collate function for batch processing (same as in generate_gino_predictions_multi_col.py)."""
    def collate_fn(batch_samples):
        # Get shared point cloud from first sample
        core_coords = batch_samples[0]['core_coords']
        ghost_coords = batch_samples[0]['ghost_coords']
        
        # Combine core and ghost points
        point_coords = torch.concat([core_coords, ghost_coords], dim=0).float()
        
        # Create latent queries grid over the per-batch bounding box
        coords_min = torch.min(point_coords, dim=0).values
        coords_max = torch.max(point_coords, dim=0).values
        latent_query_arr = [
            torch.linspace(coords_min[i], coords_max[i], latent_query_dims[i], device=device)
            for i in range(coord_dim)
        ]
        latent_queries = torch.stack(torch.meshgrid(*latent_query_arr, indexing='ij'), dim=-1)
        
        # Build batched sequences
        x_list, y_list = [], []
        for sample in batch_samples:
            sample_input = torch.concat([sample['core_in'], sample['ghost_in']], dim=0).float().unsqueeze(0)
            sample_output = torch.concat([sample['core_out'], sample['ghost_out']], dim=0).float().unsqueeze(0)
            x_list.append(sample_input)
            y_list.append(sample_output)
        
        x = torch.cat(x_list, dim=0)  # [B, N_points, input_window_size * n_target_cols]
        y = torch.cat(y_list, dim=0)  # [B, N_points, output_window_size * n_target_cols]
        
        return {
            'point_coords': point_coords,
            'latent_queries': latent_queries,
            'x': x,
            'y': y,
            'core_len': len(core_coords),
            'patch_id': sample['patch_id']
        }
    return collate_fn


def reshape_multi_col_predictions(predictions, output_window_size, n_target_cols):
    """
    Reshape concatenated predictions to separate target columns.
    
    The dataset concatenates data as: [t0_var0, t0_var1, t1_var0, t1_var1, t2_var0, t2_var1, ...]
    This is because _concat_sequence does: seq.reshape(n_points, -1) on [n_points, window_size, n_target_cols]
    which flattens in row-major order, interleaving timesteps and variables.
    
    Args:
        predictions: Array of shape [N_samples, N_points, output_window_size * n_target_cols]
        output_window_size: Number of timesteps
        n_target_cols: Number of target columns
        
    Returns:
        Array of shape [N_samples, N_points, output_window_size, n_target_cols]
    """
    n_samples, n_points, total_size = predictions.shape
    # From [N_samples, N_points, T*C] to [N_samples, N_points, T, C]
    # where T = output_window_size and C = n_target_cols
    if total_size != output_window_size * n_target_cols:
        raise ValueError(f"Expected total size {output_window_size * n_target_cols} (output_window_size={output_window_size} * n_target_cols={n_target_cols}), but got {total_size}")
    
    # The data is stored as [t0_v0, t0_v1, t1_v0, t1_v1, ...] for each point
    # So we reshape to [N_samples, N_points, output_window_size, n_target_cols] directly
    # This naturally deinterleaves the timesteps and variables
    reshaped = predictions.reshape(n_samples, n_points, output_window_size, n_target_cols)
    return reshaped

print("Helper functions defined.")

Helper functions defined.


## Example: Get a single batch for analysis

In [8]:
import time
# Create a sampler and loader
sampler = PatchBatchSampler(
    train_ds, 
    batch_size=batch_size,
    shuffle_within_batches=False,
    shuffle_patches=True,
    seed=int(time.time())
)

# Get latent query dimensions from the saved model config
latent_query_dims = saved_args.latent_query_dims
collate_fn = make_collate_fn(latent_query_dims=latent_query_dims, coord_dim=3, device=device)
val_loader = DataLoader(train_ds, batch_sampler=sampler, collate_fn=collate_fn)

# Get a single batch
batch = next(iter(val_loader))

print(f"Point coords shape: {batch['point_coords'].shape}")
print(f"Latent queries shape: {batch['latent_queries'].shape}")
print(f"Input x shape: {batch['x'].shape}")
print(f"Target y shape: {batch['y'].shape}")
print(f"Core length: {batch['core_len']}")
print(f"Patch ID: {batch['patch_id']}")

Building patch groups (one-time operation)...
Building patch_ids cache...
Cached 13180 patch_ids
Found 20 patches with 13180 total samples
Patch sizes: min=659, max=659, avg=659.0
Pre-built 420 batches
Point coords shape: torch.Size([540, 3])
Latent queries shape: torch.Size([32, 32, 24, 3])
Input x shape: torch.Size([32, 540, 10])
Target y shape: torch.Size([32, 540, 10])
Core length: 453
Patch ID: 20


## Example: Generate predictions for a batch

In [9]:
# Move batch to device
point_coords = batch['point_coords'].to(device).float()
latent_queries = batch['latent_queries'].to(device).float()
x = batch['x'].to(device).float()
y = batch['y'].to(device).float()
core_len = batch['core_len']

# Generate predictions
with torch.no_grad():
    outputs = model(
        input_geom=point_coords,
        latent_queries=latent_queries,
        x=x,
        output_queries=point_coords,
    )

# Extract core points only (exclude ghost points)
pred_obs = outputs[:, :core_len].cpu().numpy()
target_obs = y[:, :core_len].cpu().numpy()
coords = point_coords[:core_len].cpu().numpy()
input_obs = x[:, :core_len].cpu().numpy()

print(f"Predictions shape: {pred_obs.shape}")
print(f"Target shape: {target_obs.shape}")
print(f"Coords shape: {coords.shape}")
print(f"Inputs shape: {input_obs.shape}")


# Reshape to separate target columns
pred_reshaped = reshape_multi_col_predictions(pred_obs, output_window_size, n_target_cols)
target_reshaped = reshape_multi_col_predictions(target_obs, output_window_size, n_target_cols)
input_reshaped = reshape_multi_col_predictions(input_obs, input_window_size, n_target_cols)

print(f"\nPredictions reshaped: {pred_reshaped.shape}")
print(f"Target reshaped: {target_reshaped.shape}")
print(f"Shape format: [batch_size, n_points, output_window_size, n_target_cols]")

Predictions shape: (32, 453, 10)
Target shape: (32, 453, 10)
Coords shape: (453, 3)
Inputs shape: (32, 453, 10)

Predictions reshaped: (32, 453, 5, 2)
Target reshaped: (32, 453, 5, 2)
Shape format: [batch_size, n_points, output_window_size, n_target_cols]


## Your Analysis Code

Add your analysis cells below.

In [10]:
from sklearn.metrics import mean_squared_error

In [11]:
for i in range(output_window_size):
    print(f"At {i+1} ts")
    for j in range(1, n_target_cols):
        # print(f"  for variable {target_cols[j]}")
        pred_at_ts = pred_reshaped[:, :, i, j]
        target_at_ts = target_reshaped[:, :, i, j]
        input_at_last_ts = input_reshaped[:, :, -1, j]

        err_w_target = mean_squared_error(pred_at_ts.reshape(-1, 1), target_at_ts.reshape(-1, 1))
        err_w_input = mean_squared_error(pred_at_ts.reshape(-1, 1), input_at_last_ts.reshape(-1, 1))


        print(f"    Error with target: {err_w_target:.4f}")
        print(f"    Error with input: {err_w_input:.4f}")

At 1 ts
    Error with target: 0.0248
    Error with input: 0.0240
At 2 ts
    Error with target: 0.0604
    Error with input: 0.0252
At 3 ts
    Error with target: 0.0790
    Error with input: 0.0338
At 4 ts
    Error with target: 0.0773
    Error with input: 0.0404
At 5 ts
    Error with target: 0.0789
    Error with input: 0.0426


In [12]:
for i in range(output_window_size):
    print(f"At {i+1} ts")
    for j in range(0, n_target_cols-1):
        # print(f"  for variable {target_cols[j]}")
        pred_at_ts = pred_reshaped[:, :, i, j]
        target_at_ts = target_reshaped[:, :, i, j]
        input_at_last_ts = input_reshaped[:, :, -1, j]

        err_w_target = mean_squared_error(pred_at_ts.reshape(-1, 1), target_at_ts.reshape(-1, 1))
        err_w_input = mean_squared_error(pred_at_ts.reshape(-1, 1), input_at_last_ts.reshape(-1, 1))


        print(f"    Error with target: {err_w_target:.4f}")
        print(f"    Error with input: {err_w_input:.4f}")
        

At 1 ts
    Error with target: 0.0006
    Error with input: 0.0006
At 2 ts
    Error with target: 0.0006
    Error with input: 0.0006
At 3 ts
    Error with target: 0.0006
    Error with input: 0.0006
At 4 ts
    Error with target: 0.0006
    Error with input: 0.0006
At 5 ts
    Error with target: 0.0006
    Error with input: 0.0006
