# Test GFNO Model on 2D Plane Sequences

This notebook loads pre-generated sequence data and tests the GFNO model with proper dataset, sampler, and dataloader setup.

**CUDA Compatible**: Automatically detects and uses GPU if available.

In [1]:
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Add project root to path.
# The checkout location is machine-specific, so it can be overridden with the
# GW_SCIML_ROOT environment variable; the original location stays the default.
project_root = os.environ.get(
    'GW_SCIML_ROOT',
    '/Users/arpitkapoor/Library/CloudStorage/OneDrive-UNSW/Shared/Projects/10_Katana/04_groundwater/GW_SciML/',
)
# Guard against duplicate entries when the cell is re-run in the same kernel.
if project_root not in sys.path:
    sys.path.append(project_root)

from src.data.plane_dataset import GWPlaneDatasetFromFiles
from src.data.batch_sampler import PatchBatchSampler
from src.models import GFNO

# Report the PyTorch build and, when a GPU is usable, the CUDA toolkit
# version and the name of device 0.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.7.1
CUDA available: False


## Configuration

In [2]:
# Data path: machine-specific, so it can be overridden with the GW_DATA_DIR
# environment variable; the original location stays the default.
data_dir = os.environ.get('GW_DATA_DIR', '/Users/arpitkapoor/data/GW/2d_plane_sequences')

# Device configuration - automatically use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Model parameters
coord_dim = 3  # 3D: S, Z, T
n_target_cols = 2  # head and mass_concentration

# GNO parameters
gno_radius = 0.15
gno_out_channels = n_target_cols
gno_channel_mlp_layers = [16, 32, 16]

# FNO parameters
# coordinates + target variables + 2 further channels = 7, which matches the
# last dimension of the dataset's `latent_features` tensor.
# NOTE(review): what the 2 further channels hold is not stated in this
# notebook - confirm against the dataset definition.
fno_input_channels = coord_dim + n_target_cols + 2
fno_n_layers = 4
fno_n_modes = (6, 8, 8)  # 3D modes (S, Z, T)
fno_hidden_channels = 64
lifting_channels = 64
out_channels = n_target_cols

# Training parameters
batch_size = 64  # Adjust based on GPU memory
# num_workers = None  # For parallel data loading

Using device: cpu


## Load Dataset

In [3]:
# Build the dataset from the sequence files stored under `data_dir`
print(f"Loading dataset from {data_dir}...")
dataset = GWPlaneDatasetFromFiles(data_dir=data_dir, fill_nan_value=-999.0)

print(f"\nDataset loaded successfully!")
print(f"Total sequences: {len(dataset)}")

# Inspect the first item: tensors are summarised by shape and dtype,
# anything else is printed as-is.
sample = dataset[0]
print(f"\nSample data shapes:")
for field_name, field_value in sample.items():
    if not isinstance(field_value, torch.Tensor):
        print(f"  {field_name}: {field_value}")
        continue
    print(f"  {field_name}: {field_value.shape} (dtype: {field_value.dtype})")

Loading dataset from /Users/arpitkapoor/data/GW/2d_plane_sequences...
Initialized GWPlaneDatasetFromFiles with 3712 sequences across 32 planes

Dataset loaded successfully!
Total sequences: 3712

Sample data shapes:
  plane_id: 0
  input_geom: torch.Size([5216, 3]) (dtype: torch.float32)
  input_data: torch.Size([5216, 2]) (dtype: torch.float32)
  latent_geom: torch.Size([16, 32, 32, 3]) (dtype: torch.float32)
  latent_features: torch.Size([16, 32, 32, 7]) (dtype: torch.float32)
  output_latent_geom: torch.Size([16, 32, 32, 3]) (dtype: torch.float32)
  output_latent_features: torch.Size([16, 32, 32, 2]) (dtype: torch.float32)


## Create Batch Sampler and DataLoader

In [4]:
# Build the batch sampler (intended to keep samples from the same plane
# together in a batch)
print("Creating batch sampler...")
sampler_options = dict(
    batch_size=batch_size,
    shuffle_within_batches=True,
    shuffle_patches=True,
    seed=42,
)
batch_sampler = PatchBatchSampler(dataset=dataset, **sampler_options)

print("Batch sampler created:")
print(f"  Total batches: {len(batch_sampler)}")
print(f"  Batch size: {batch_size}")

Creating batch sampler...
Building patch groups (one-time operation)...
Building plane_ids cache...
Cached 3712 plane_ids
Found 32 patches with 3712 total samples
Patch sizes: min=116, max=116, avg=116.0
Pre-built 64 batches
Batch sampler created:
  Total batches: 64
  Batch size: 64


In [5]:
# Create DataLoader
print("\nCreating DataLoader...")

# Pinned host memory only speeds up host-to-GPU copies, so request it on CUDA only.
use_pinned_memory = device.type == 'cuda'

dataloader = DataLoader(
    dataset,
    batch_sampler=batch_sampler,  # batches are composed by the sampler, not by DataLoader
    # num_workers=num_workers,
    pin_memory=use_pinned_memory
)

print(f"DataLoader created with {len(dataloader)} batches")
# Describe the actual setting; the note must not claim "enabled" when it is off.
pin_memory_note = (
    "enabled for faster GPU transfer" if dataloader.pin_memory
    else "disabled - not running on CUDA"
)
print(f"Pin memory: {dataloader.pin_memory} ({pin_memory_note})")


Creating DataLoader...
DataLoader created with 64 batches
Pin memory: False (enabled for faster GPU transfer)


## Test DataLoader

In [6]:
# Pull one batch from the loader and report its plane ids and tensor shapes
print("Testing DataLoader - fetching first batch...\n")

batch = next(iter(dataloader))
plane_ids = batch['plane_id']

print(f"Batch details:")
print(f"  Batch size: {len(plane_ids)}")
print(f"  Plane IDs: {plane_ids.numpy()}")
print(f"  Unique planes: {torch.unique(plane_ids).numpy()} (should be single plane)")

print(f"\nBatch tensor shapes:")
for tensor_key in (
    'input_geom',
    'input_data',
    'latent_geom',
    'latent_features',
    'output_latent_geom',
    'output_latent_features',
):
    print(f"  {tensor_key}: {batch[tensor_key].shape}")

Testing DataLoader - fetching first batch...

Batch details:
  Batch size: 64
  Plane IDs: [26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26
 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26
 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26]
  Unique planes: [26] (should be single plane)

Batch tensor shapes:
  input_geom: torch.Size([64, 320, 3])
  input_data: torch.Size([64, 320, 2])
  latent_geom: torch.Size([64, 16, 32, 32, 3])
  latent_features: torch.Size([64, 16, 32, 32, 7])
  output_latent_geom: torch.Size([64, 16, 32, 32, 3])
  output_latent_features: torch.Size([64, 16, 32, 32, 2])


## Initialize GFNO Model

In [7]:
# Build the GFNO model from the configuration values, then place it on `device`
print("Initializing GFNO model...")

model = GFNO(
    gno_coord_dim=coord_dim,
    gno_radius=gno_radius,
    gno_out_channels=gno_out_channels,
    gno_channel_mlp_layers=gno_channel_mlp_layers,
    latent_feature_channels=fno_input_channels,  # coord_dim + n_target_cols + 2 channels
    fno_n_layers=fno_n_layers,
    fno_n_modes=fno_n_modes,
    fno_hidden_channels=fno_hidden_channels,
    lifting_channels=lifting_channels,
    out_channels=out_channels
)
model = model.to(device)

# Collect (size, requires_grad) once and derive both totals from that list
param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_counts)
trainable_params = sum(count for count, needs_grad in param_counts if needs_grad)

print(f"\nModel created successfully!")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model device: {next(model.parameters()).device}")

Initializing GFNO model...

Model created successfully!
Total parameters: 7,910,180
Trainable parameters: 7,910,180
Model device: cpu


## Test Forward Pass on Single Batch

In [8]:
# Run forward pass on a single batch: time it, check the output shape against
# the target tensor, and print basic statistics of the predictions.
print("Running forward pass on a single batch...\n")

model.eval()  # Set to evaluation mode

with torch.no_grad():
    # Get first batch
    batch = next(iter(dataloader))

    # Move batch to device
    input_geom = batch['input_geom'].to(device)
    input_data = batch['input_data'].to(device)
    latent_geom = batch['latent_geom'].to(device)
    latent_features = batch['latent_features'].to(device)
    output_latent_features = batch['output_latent_features'].to(device)

    print(f"Input shapes (on {device}):")
    print(f"  input_geom: {input_geom.shape}")
    print(f"  input_data: {input_data.shape}")
    print(f"  latent_geom: {latent_geom.shape}")
    print(f"  latent_features: {latent_features.shape}")

    # Forward pass
    print(f"\nRunning forward pass...")
    if device.type == 'cuda':
        # Drain work queued earlier so it is not counted in the forward-pass time
        torch.cuda.synchronize()
    # perf_counter is monotonic, so the measurement is immune to wall-clock adjustments
    start_time = time.perf_counter()

    predictions = model(
        input_geom=input_geom,
        x=input_data,
        latent_queries=latent_geom,
        latent_features=latent_features
    )

    if device.type == 'cuda':
        torch.cuda.synchronize()  # Wait for GPU to finish

    elapsed_time = time.perf_counter() - start_time

    print(f"\nForward pass complete!")
    print(f"Time taken: {elapsed_time:.4f} seconds")
    print(f"\nOutput shape: {predictions.shape}")
    print(f"Expected shape: {output_latent_features.shape}")
    # Fail loudly on a mismatch instead of relying on a visual comparison of the two lines above
    assert predictions.shape == output_latent_features.shape, (
        f"Output shape {tuple(predictions.shape)} does not match "
        f"expected shape {tuple(output_latent_features.shape)}"
    )
    print(f"\nOutput statistics:")
    print(f"  Min: {predictions.min().item():.6f}")
    print(f"  Max: {predictions.max().item():.6f}")
    print(f"  Mean: {predictions.mean().item():.6f}")
    print(f"  Std: {predictions.std().item():.6f}")

Running forward pass on a single batch...

Input shapes (on cpu):
  input_geom: torch.Size([64, 320, 3])
  input_data: torch.Size([64, 320, 2])
  latent_geom: torch.Size([64, 16, 32, 32, 3])
  latent_features: torch.Size([64, 16, 32, 32, 7])

Running forward pass...

Forward pass complete!
Time taken: 13.2065 seconds

Output shape: torch.Size([64, 16, 32, 32, 2])
Expected shape: torch.Size([64, 16, 32, 32, 2])

Output statistics:
  Min: -270.095184
  Max: 338.677338
  Mean: -21.460239
  Std: 124.285980


## Test Forward Pass on Multiple Batches

In [9]:
# Test forward pass on multiple batches: time the first `num_test_batches`
# batches of the loader and print per-batch and summary timings.
print("Running forward pass on multiple batches...\n")

model.eval()
num_test_batches = 5

total_time = 0
batch_times = []

with torch.no_grad():
    # zip() consults range() first, so iteration stops after num_test_batches
    # without loading an extra batch that would only be discarded.
    for i, batch in zip(range(num_test_batches), dataloader):
        # Move to device
        input_geom = batch['input_geom'].to(device)
        input_data = batch['input_data'].to(device)
        latent_geom = batch['latent_geom'].to(device)
        latent_features = batch['latent_features'].to(device)

        # Time forward pass
        if device.type == 'cuda':
            # Drain work queued earlier so it is not counted in this batch's time
            torch.cuda.synchronize()
        start_time = time.perf_counter()

        predictions = model(
            input_geom=input_geom,
            x=input_data,
            latent_queries=latent_geom,
            latent_features=latent_features
        )

        if device.type == 'cuda':
            torch.cuda.synchronize()  # Wait for GPU to finish before stopping the timer

        elapsed = time.perf_counter() - start_time
        batch_times.append(elapsed)
        total_time += elapsed

        # .item() requires exactly one unique plane id in the batch
        print(f"Batch {i+1}/{num_test_batches}: {elapsed:.4f}s, "
              f"plane_id={torch.unique(batch['plane_id']).item()}, "
              f"output_shape={predictions.shape}")

print(f"\nSummary:")
if batch_times:
    print(f"  Total time: {total_time:.4f}s")
    print(f"  Average time per batch: {np.mean(batch_times):.4f}s")
    print(f"  Min time: {np.min(batch_times):.4f}s")
    print(f"  Max time: {np.max(batch_times):.4f}s")
else:
    # np.min / np.max raise on an empty sequence, so report the empty case explicitly
    print("  No batches were processed")

Running forward pass on multiple batches...

Batch 1/5: 12.6202s, plane_id=26, output_shape=torch.Size([64, 16, 32, 32, 2])
Batch 2/5: 12.9709s, plane_id=29, output_shape=torch.Size([64, 16, 32, 32, 2])
Batch 3/5: 21.5430s, plane_id=0, output_shape=torch.Size([64, 16, 32, 32, 2])
Batch 4/5: 12.4956s, plane_id=22, output_shape=torch.Size([64, 16, 32, 32, 2])
Batch 5/5: 10.7889s, plane_id=2, output_shape=torch.Size([52, 16, 32, 32, 2])

Summary:
  Total time: 70.4186s
  Average time per batch: 14.0837s
  Min time: 10.7889s
  Max time: 21.5430s


## Memory Usage (GPU)

In [10]:
# Report allocator statistics for GPU 0 in GB; nothing to report off CUDA
if device.type != 'cuda':
    print("CPU mode - no GPU memory tracking")
else:
    print("GPU Memory Usage:")
    bytes_per_gb = 1024**3
    memory_queries = (
        ("Allocated", torch.cuda.memory_allocated),
        ("Cached", torch.cuda.memory_reserved),
        ("Max allocated", torch.cuda.max_memory_allocated),
    )
    for label, query in memory_queries:
        print(f"  {label}: {query(0) / bytes_per_gb:.2f} GB")

CPU mode - no GPU memory tracking


## Model Summary

In [11]:
# Show the module tree of the model
heavy_rule = "=" * 80
print("Model Architecture:")
print(heavy_rule)
print(model)
print(heavy_rule)

# Parameter totals for each direct child module, followed by the overall total
light_rule = "-" * 80
print("\nParameter count by component:")
print(light_rule)
for child_name, child in model.named_children():
    child_params = sum(map(torch.numel, child.parameters()))
    print(f"{child_name:30s}: {child_params:>15,} parameters")
print(light_rule)
print(f"{'Total':30s}: {total_params:>15,} parameters")

Model Architecture:
GFNO(
  (gno): GNOBlock(
    (pos_embedding): SinusoidalEmbedding()
    (neighbor_search): NeighborSearch()
    (integral_transform): IntegralTransform(
      (channel_mlp): LinearChannelMLP(
        (fcs): ModuleList(
          (0): Linear(in_features=384, out_features=16, bias=True)
          (1): Linear(in_features=16, out_features=32, bias=True)
          (2): Linear(in_features=32, out_features=16, bias=True)
          (3): Linear(in_features=16, out_features=2, bias=True)
        )
      )
    )
  )
  (fno_blocks): FNOBlocks(
    (fno_skips): ModuleList(
      (0-3): 4 x Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    )
    (convs): ModuleList(
      (0-3): 4 x SpectralConv(
        (weight): ModuleList(
          (0): ComplexDenseTensor(shape=torch.Size([64, 64, 6, 8, 5]), rank=None)
        )
      )
    )
  )
  (lifting): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(9, 64, kernel_size=(1,), stride=(1,))
      (1): Conv1d(64, 

## Success! ✅

The model successfully performed forward passes on batched data. You can now:

1. **Train the model** by adding a training loop with optimizer and loss function
2. **Adjust batch size** based on available GPU memory
3. **Experiment with model parameters** to optimize performance
4. **Add a validation loop** to monitor model performance

Next steps:
- Create a training script using this notebook as a template
- Implement proper train/validation split
- Add checkpointing and logging
- Monitor metrics and visualize predictions