In [1]:
import os
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader

# Put the parent directory on the import path so the `src` package resolves.
sys.path.append('..')

from src.models.neuralop.fno import FNOInterpolate
from src.data.patch_dataset_multi_col import GWPatchDatasetMultiCol
from src.data.batch_sampler import PatchBatchSampler
from src.data.data_utils import (
    calculate_coord_transform,
    calculate_obs_transform,
    create_patch_datasets,
    make_collate_fn,
)

print("All imports successful!")

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
All imports successful!


## Configuration

In [8]:
# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Data configuration
# The dataset root defaults to the original cluster location. Set the
# GW_DATA_DIR environment variable to run the notebook against a different
# copy of the data without editing this cell.
base_data_dir = os.environ.get('GW_DATA_DIR') or (
    '/srv/scratch/z5370003/projects/data/groundwater/FEFLOW/coastal/variable_density'
)
raw_data_dir = f'{base_data_dir}/all'
patch_data_dir = f'{base_data_dir}/filter_patch'

# Dataset parameters
target_cols = ['mass_concentration', 'head']
target_col_indices = [0, 1]  # Corresponding to mass_concentration and head
input_window_size = 3
output_window_size = 1
batch_size = 2

# Model hyperparameters (matching GINO training config)
coord_dim = 3
latent_grid_size = (32, 32, 24)  # Matching latent_query_dims from GINO training
n_target_cols = len(target_cols)
in_channels = input_window_size * n_target_cols  # 3 * 2 = 6
out_channels = output_window_size * n_target_cols  # 1 * 2 = 2
latent_feature_channels = None
fno_hidden_channels = 64
fno_n_layers = 4
fno_n_modes = (12, 12, 8)
lifting_channels = 64

# Fail fast when hand-maintained values drift out of sync with each other.
assert len(target_col_indices) == n_target_cols, (
    "target_col_indices must contain exactly one index per entry in target_cols"
)
assert len(fno_n_modes) == len(latent_grid_size), (
    "fno_n_modes needs one entry per latent grid dimension"
)

print(f"Target columns: {target_cols}")
print(f"Input/Output channels: {in_channels}/{out_channels}")
print(f"Latent grid size: {latent_grid_size}")

Using device: cuda
Target columns: ['mass_concentration', 'head']
Input/Output channels: 6/2
Latent grid size: (32, 32, 24)


## Load Patch Dataset

In [9]:
# Calculate data transforms from the raw data directory.
print("Calculating coordinate transform...")
coord_transform = calculate_coord_transform(raw_data_dir)

print("Calculating observation transform...")
# Observation columns passed to the transform helper.
# NOTE(review): target_col_indices presumably index into this list - confirm
# against GWPatchDatasetMultiCol before reordering it.
obs_cols = ['mass_concentration', 'head', 'pressure']
obs_transform = calculate_obs_transform(
    raw_data_dir,
    target_obs_cols=obs_cols
)

print("Creating datasets...")
train_ds, val_ds = create_patch_datasets(
    dataset_class=GWPatchDatasetMultiCol,
    patch_data_dir=patch_data_dir,
    coord_transform=coord_transform,
    obs_transform=obs_transform,
    target_col_indices=target_col_indices,
    input_window_size=input_window_size,
    output_window_size=output_window_size,
)

print(f"Dataset sizes - Train: {len(train_ds)}, Val: {len(val_ds)}")

# Stop here with a clear message instead of hitting a bare StopIteration later,
# when the first batch is requested from an empty loader.
assert len(train_ds) > 0, f"No training samples found under {patch_data_dir}"
assert len(val_ds) > 0, f"No validation samples found under {patch_data_dir}"

Calculating coordinate transform...
Coordinate mean: [ 3.57225665e+05  6.45774324e+06 -9.27782248e+00]
Coordinate std: [569.1699999  566.35797379  15.26565618]
Calculating observation transform...
Output mean: [1.77942252e+04 3.95881156e-01 9.48469883e+01]
Output std: [1.55859465e+04 2.13080032e-01 1.51226320e+02]
Creating datasets...
Computed variance-aware weights for 20 patches
Dataset variance range: [0.000000, 405223819.275790]
Dataset mean variance: 3213522.908705
Weight range: [0.4446, 3.4922]
Weight std: 0.2615
Computed variance-aware weights for 20 patches
Dataset variance range: [0.000000, 405223819.275790]
Dataset mean variance: 3213522.908705
Weight range: [0.4446, 3.4922]
Weight std: 0.2615
Dataset sizes - Train: 13300, Val: 5680


## Create Data Loader

In [10]:
class SimpleArgs:
    """Bare attribute container handed to ``make_collate_fn`` as its ``args``.

    Every attribute is copied from the notebook-level configuration at the
    moment the object is instantiated.
    """

    def __init__(self):
        self.input_window_size = input_window_size
        self.output_window_size = output_window_size
        self.latent_query_dims = latent_grid_size
        self.use_open3d = False
        self.device = device


# Validation sampler: all shuffling switched off and no seed supplied.
sampler_options = dict(
    batch_size=batch_size,
    shuffle_within_batches=False,
    shuffle_patches=False,
    seed=None,
)
val_sampler = PatchBatchSampler(val_ds, **sampler_options)

# Collate function built from the simplified argument object.
args = SimpleArgs()
collate_fn = make_collate_fn(args, coord_dim=coord_dim)

# Batching is delegated entirely to the sampler above.
val_loader = DataLoader(
    dataset=val_ds,
    batch_sampler=val_sampler,
    collate_fn=collate_fn,
)

print(f"Data loader created with {len(val_loader)} batches")
print("Getting first batch...")

Building patch groups (one-time operation)...
Building patch_ids cache...
Cached 5680 patch_ids
Found 20 patches with 5680 total samples
Patch sizes: min=284, max=284, avg=284.0
Pre-built 2840 batches
Data loader created with 2840 batches
Getting first batch...


## Initialize FNOInterpolate Model

In [11]:
# Seed PyTorch's random number generators before building the model so the
# randomly initialized weights - and therefore every prediction statistic
# computed from them - are the same on each Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = FNOInterpolate(
    latent_grid_size=latent_grid_size,
    coord_dim=coord_dim,
    in_channels=in_channels,
    out_channels=out_channels,
    latent_feature_channels=latent_feature_channels,
    fno_n_layers=fno_n_layers,
    fno_n_modes=fno_n_modes,
    fno_hidden_channels=fno_hidden_channels,
    lifting_channels=lifting_channels,
)

model = model.to(device)
model.eval()  # evaluation mode: this notebook only runs a forward pass

print(f"Model initialized on device: {device}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

Model initialized on device: cuda
Total parameters: 23,631,362


## Extract and Inspect Batch Data

In [12]:
# Pull a single batch out of the validation loader
batch = next(iter(val_loader))

# Show which entries the collate function placed in the batch
print("Available batch keys:")
print(batch.keys())
print()

# Move the required batch entries onto the compute device
input_geom, latent_queries, x_input, y_target = (
    batch[key].to(device)
    for key in ('point_coords', 'latent_queries', 'x', 'y')
)

# For this test the model is queried at the same points it receives as input
output_queries = batch['point_coords'].to(device)

# 'latent_features' is optional: it stays None when the batch does not carry it
latent_features = batch.get('latent_features')
latent_features = (
    latent_features.to(device) if latent_features is not None else None
)

print("Batch data shapes:")
print(f"  input_geom: {input_geom.shape}")
print(f"  latent_queries: {latent_queries.shape}")
print(f"  x_input: {x_input.shape}")
print(f"  y_target: {y_target.shape}")
print(f"  latent_features: {latent_features.shape if latent_features is not None else None}")
print(f"  output_queries: {output_queries.shape}")

Available batch keys:
dict_keys(['point_coords', 'latent_queries', 'x', 'y', 'core_len', 'weights'])

Batch data shapes:
  input_geom: torch.Size([512, 3])
  latent_queries: torch.Size([32, 32, 24, 3])
  x_input: torch.Size([2, 512, 6])
  y_target: torch.Size([2, 512, 2])
  latent_features: None
  output_queries: torch.Size([512, 3])


## Forward Pass Test

In [13]:
# Run forward pass through FNOInterpolate model
print("Running forward pass...")
with torch.no_grad():  # smoke test only, so no gradients are tracked
    y_pred = model(
        input_geom=input_geom,
        latent_queries=latent_queries,
        output_queries=output_queries,
        x=x_input,
        latent_features=latent_features
    )

print("\nForward pass successful!")
print(f"Prediction output shape: {y_pred.shape}")
print(f"Target output shape: {y_target.shape}")

# Verify shapes match
shapes_match = y_pred.shape == y_target.shape
print(f"\n{'Shapes match: ✓' if shapes_match else 'Shapes match: ✗'}")

if shapes_match:
    print(f"\nBoth tensors have shape: (batch_size={y_pred.shape[0]}, n_points={y_pred.shape[1]}, channels={y_pred.shape[2]})")

# Make the checks binding: a failed smoke test must stop Run All rather than
# just print a cross that is easy to miss.
assert shapes_match, (
    f"Prediction shape {tuple(y_pred.shape)} does not match "
    f"target shape {tuple(y_target.shape)}"
)
assert torch.isfinite(y_pred).all(), "Model output contains NaN or Inf values"

Running forward pass...

Forward pass successful!
Prediction output shape: torch.Size([2, 512, 2])
Target output shape: torch.Size([2, 512, 2])

Shapes match: ✓

Both tensors have shape: (batch_size=2, n_points=512, channels=2)


## Prediction Statistics

In [14]:
def print_tensor_stats(values):
    """Print the mean, standard deviation, minimum and maximum of a tensor.

    The statistics are taken over all elements of ``values`` at once (no
    per-channel breakdown) and printed with six decimal places.

    Args:
        values: Tensor to summarise.
    """
    print(f"  Mean: {values.mean().item():.6f}")
    print(f"  Std: {values.std().item():.6f}")
    print(f"  Min: {values.min().item():.6f}")
    print(f"  Max: {values.max().item():.6f}")


# Analyze predictions and targets
print(f"Prediction shape: {y_pred.shape}")
print(f"  Batch size: {y_pred.shape[0]}")
print(f"  Num points: {y_pred.shape[1]}")
print(f"  Output channels: {y_pred.shape[2]} (expected {out_channels})")

# Enforce the expectation printed above instead of relying on a visual check.
assert y_pred.shape[2] == out_channels, (
    f"Model produced {y_pred.shape[2]} output channels, expected {out_channels}"
)

print("\nPrediction statistics:")
print_tensor_stats(y_pred)

print("\nTarget statistics:")
print_tensor_stats(y_target)

# Calculate MSE loss (averaged over every element of the batch)
mse = torch.nn.functional.mse_loss(y_pred, y_target)
print(f"\nMSE Loss: {mse.item():.6f}")

Prediction shape: torch.Size([2, 512, 2])
  Batch size: 2
  Num points: 512
  Output channels: 2 (expected 2)

Prediction statistics:
  Mean: 0.004121
  Std: 0.014030
  Min: -0.016462
  Max: 0.037722

Target statistics:
  Mean: -0.169145
  Std: 0.824083
  Min: -2.297711
  Max: 1.786044

MSE Loss: 0.698116


## Summary

Successfully tested the FNOInterpolate model with the patch dataset! The model:
- ✓ Loaded real patch dataset with multi-column support
- ✓ Processed batch through forward pass
- ✓ Generated predictions matching target shapes
- ✓ Maintained API compatibility with the GINO training pipeline
- ✓ Works with the dictionary of tensors produced by the patch collate function

The FNOInterpolate model passed this forward-pass smoke test with untrained (randomly initialized) weights, so it is ready to be trained on the groundwater dataset.