In [1]:
import os
import numpy as np
import pandas as pd

from tqdm import tqdm

In [2]:
data_path = '/Users/arpitkapoor/Library/CloudStorage/OneDrive-UNSW/Shared/Data/FEFLOW/processed/2d_plane_data'

sea_level_csv = '/Users/arpitkapoor/Library/CloudStorage/OneDrive-UNSW/Shared/Data/FEFLOW/simulation_files/SeaLevelDataPeaksHL.csv'
# sea_level_csv = '/Users/arpitkapoor/Downloads/SeaLevelDataPeaksHL.csv'


skip_factor = 2  # Process every 2nd timestep for faster processing

times = pd.read_csv(sea_level_csv, header=None, names=['time', 'sealevel'])['time'].values[::skip_factor]
times.shape

(955,)

In [3]:
sorted_data_dirs = sorted(os.listdir(data_path))
sorted_data_dirs = [d for d in sorted_data_dirs if d.startswith('timestep_')]
sorted_data_dirs  = sorted_data_dirs[::skip_factor]  # Skip timesteps for faster processing
len(sorted_data_dirs)

955

In [4]:
N_planes = 32

plane_data_all_timesteps = {p: [] for p in range(N_planes)}
bc_data_all_timesteps = {p: [] for p in range(N_planes)}

bc_nodes_for_plane = {}


for t  in tqdm(range(len(sorted_data_dirs))):
    d = sorted_data_dirs[t]
    ts_dir = os.path.join(data_path, d)
    sorted_planes_dir = sorted(os.listdir(ts_dir))
    sorted_planes_dir = [os.path.join(ts_dir, f) for f in sorted_planes_dir]

    for p, f in enumerate(sorted_planes_dir):
        
        plane_data = np.load(os.path.join(f, 'plane_data.npz'))
        S = plane_data['S']
        Z = plane_data['Z']
        X = plane_data['X']
        Y = plane_data['Y']
        T = np.ones_like(X) * times[t]
        head_m = plane_data['head_m']
        mass_conc_mg_l = plane_data['mass_conc_mg_l']
        # print(f" Processing plane {p} at timestep {d.split('_')[-1]} with shapes: X-{X.shape}, Y-{Y.shape}, Z-{Z.shape}, T-{T.shape}, head_m-{head_m.shape}, mass_conc_mg_l-{mass_conc_mg_l.shape}")
        stacked_plane = np.stack([S, Z, T, X, Y, head_m, mass_conc_mg_l], axis=-1)

        plane_data_all_timesteps[p].append(stacked_plane)

        bc_data = np.load(os.path.join(f, 'bc_data.npz'))
        bc_node_idx = bc_data['node_idx']
        bc_S = bc_data['S']
        bc_Z = bc_data['Z']
        bc_T = np.ones_like(bc_node_idx) * times[t]
        bc_head = bc_data['head']
        bc_mass_conc = bc_data['mass_conc']
        
        # Assertion to ensure bc arrays have the same length
        if not len(bc_S) == len(bc_Z) == len(bc_head) == len(bc_mass_conc):
            print(f" Plane {p} failed assertion! Boundary condition arrays have mismatched lengths")
            continue

        if t == 0:
            bc_nodes_for_plane[p] = bc_node_idx.astype(np.int32)
        else:
            bc_nodes_for_plane[p] = np.intersect1d(bc_nodes_for_plane[p], bc_node_idx.astype(np.int32))

        stacked_bc = np.stack([bc_node_idx, bc_S, bc_Z, bc_T, 
                               bc_head, bc_mass_conc], axis=-1)
        bc_data_all_timesteps[p].append(stacked_bc)


100%|██████████| 955/955 [00:18<00:00, 52.49it/s]


In [5]:
for p in tqdm(range(N_planes)):
    plane_data_p = plane_data_all_timesteps[p]
    plane_data_all_timesteps[p] = np.stack(plane_data_p, axis=0)
    # print(f" Plane {p} data shape across all timesteps: {plane_data_all_timesteps[p].shape}")

100%|██████████| 32/32 [00:00<00:00, 63.64it/s] 


In [6]:


for p in tqdm(range(N_planes)):
    bc_list = bc_data_all_timesteps[p]
    n_bc = len(bc_list)

    bc_data_for_selected_nodes = []

    for t in range(n_bc):
        bc_data_for_p_at_t = bc_data_all_timesteps[p][t]
        bc_nodes_for_p = bc_nodes_for_plane[p]

        # Filter bc_data_for_p_at_t to only include rows where node_idx is in bc_nodes_for_p
        mask = np.isin(bc_data_for_p_at_t[:, 0], bc_nodes_for_p)
        bc_data_for_selected_nodes.append(bc_data_for_p_at_t[mask])
    
    bc_data_all_timesteps[p] = np.stack(bc_data_for_selected_nodes, axis=0)
    # print(f" Plane {p}: Cleaned BC data shape across all timesteps: {bc_data_all_timesteps[p].shape}")
    


100%|██████████| 32/32 [00:00<00:00, 66.97it/s]


Write cleaned up data to file.

In [7]:
# out_dir
outdir = '/Users/arpitkapoor/Library/CloudStorage/OneDrive-UNSW/Shared/Data/FEFLOW/processed/2d_plane_data_cleaned'
os.makedirs(outdir, exist_ok=True)


In [8]:
for p in range(N_planes):
    plane_dir = os.path.join(outdir, f'plane_{p:02d}')
    os.makedirs(plane_dir, exist_ok=True)

    # BC data
    # Input Geom - S, Z, T
    input_geom = bc_data_all_timesteps[p][0,..., 1:4]
    input_data = bc_data_all_timesteps[p][..., 4:]
    lantent_geom = plane_data_all_timesteps[p][0,..., :3]

    # Latent inputs
    latent_inputs = plane_data_all_timesteps[p][..., 3:]

    print(f"\nWriting data for plane {p} to {plane_dir} ...")
    print(f"  Input Geom shape: {input_geom.shape}, Input data shape: {input_data.shape}")
    print(f"  Latent Geom shape: {lantent_geom.shape}, Latent inputs shape: {latent_inputs.shape}")
    break




Writing data for plane 0 to /Users/arpitkapoor/Library/CloudStorage/OneDrive-UNSW/Shared/Data/FEFLOW/processed/2d_plane_data_cleaned/plane_00 ...
  Input Geom shape: (326, 3), Input data shape: (955, 326, 2)
  Latent Geom shape: (32, 32, 3), Latent inputs shape: (955, 32, 32, 4)


In [9]:
# Generate sequences for all planes
input_sequences = {p: {'input_geom': [], 'input_data': [], 'latent_geom': [], 'latent_features': []} for p in range(N_planes)}
output_sequences = {p: {'latent_geom': [], 'latent_features': []} for p in range(N_planes)}

alpha = 16 # sequence length for input and output

# Generate time sequences
n_timesteps = len(sorted_data_dirs)
ts = [(i, i + alpha, i + 2 * alpha) for i in range(0, n_timesteps - 2 * alpha + 1, alpha//2)]
print(f"Generated {len(ts)} time sequences with alpha={alpha}")
print(f"\n\nFirst few sequences: {ts[:3]} \n times: {times[[ts[i] for i in range(min(3, len(ts)))]]}")
print(f"\n\nLast few sequences: {ts[-3:]} \n times: {times[[ts[i] for i in range(max(0, len(ts)-3), len(ts))]]}")

for p in tqdm(range(N_planes)):
    # BC data for this plane
    bc_data = bc_data_all_timesteps[p]  # shape: (n_timesteps, n_bc_nodes, 6)
    # Plane data for this plane
    plane_data = plane_data_all_timesteps[p]  # shape: (n_timesteps, n_nodes, 7)
    
    for t_start, t_mid, t_end in ts:
        # Input sequence: [t_start:t_mid]
        input_geom_seq = bc_data[t_start:t_mid, ..., 1:4].reshape(-1, 3)  # (alpha, n_bc_nodes, 3) - S, Z, T
        input_data_seq = bc_data[t_start:t_mid, ..., 4:].reshape(-1, 2)  # (alpha, n_bc_nodes, 2) - head, mass_conc
        
        latent_geom_seq = plane_data[t_start:t_mid, ..., :3]  # (alpha, n_nodes, 3) - S, Z, T
        latent_features_seq = plane_data[t_start:t_mid, ..., 3:]  # (alpha, n_nodes, 4) - X, Y, head, mass_conc
        
        # Output sequence: [t_mid:t_end]
        output_latent_geom_seq = plane_data[t_mid:t_end, ..., :3]  # (alpha, n_nodes, 3)
        output_latent_features_seq = plane_data[t_mid:t_end, ..., 3:]  # (alpha, n_nodes, 4)
        
        # Store sequences
        input_sequences[p]['input_geom'].append(input_geom_seq)
        input_sequences[p]['input_data'].append(input_data_seq)
        input_sequences[p]['latent_geom'].append(latent_geom_seq)
        input_sequences[p]['latent_features'].append(latent_features_seq)
        
        output_sequences[p]['latent_geom'].append(output_latent_geom_seq)
        output_sequences[p]['latent_features'].append(output_latent_features_seq)

# Convert lists to numpy arrays for each plane and field
for p in range(N_planes):
    for key in input_sequences[p]:
        input_sequences[p][key] = np.array(input_sequences[p][key])
    for key in output_sequences[p]:
        output_sequences[p][key] = np.array(output_sequences[p][key])

print(f"\nTotal sequences per plane: {len(ts)}")
print(f"Total planes: {N_planes}")


Generated 116 time sequences with alpha=16


First few sequences: [(0, 16, 32), (8, 24, 40), (16, 32, 48)] 
 times: [[ 0.         13.91666667 26.91666667]
 [ 7.91666667 20.75       34.75      ]
 [13.91666667 26.91666667 42.41666667]]


Last few sequences: [(904, 920, 936), (912, 928, 944), (920, 936, 952)] 
 times: [[859.5416667 875.5833333 891.4583333]
 [867.4166667 883.5       899.4583333]
 [875.5833333 891.4583333 906.3333333]]


100%|██████████| 32/32 [00:00<00:00, 5649.85it/s]



Total sequences per plane: 116
Total planes: 32


In [10]:
input_sequences[0]['input_geom'].shape, input_sequences[0]['input_data'].shape, output_sequences[0]['latent_geom'].shape, output_sequences[0]['latent_features'].shape

((116, 5216, 3), (116, 5216, 2), (116, 16, 32, 32, 3), (116, 16, 32, 32, 4))

In [11]:
import torch
import torch.nn as nn
import sys
sys.path.append('/Users/arpitkapoor/Library/CloudStorage/OneDrive-UNSW/Shared/Projects/10_Katana/04_groundwater/GW_SciML/')
from src.models.neuralop.gno import GNOBlock
from src.models.neuralop.fno import FNOBlocks
from src.models.neuralop.channel_mlp import ChannelMLP

# Model configuration parameters
coord_dim = 3  # 3D: S, Z, T (time as dimension)
n_target_cols = 2  # head and mass_concentration

# GNO and FNO parameters
gno_radius = 0.15
gno_out_channels = n_target_cols  # 10 * 2 = 20
gno_channel_mlp_layers = [16, 32, 16]

fno_n_layers = 4
fno_n_modes = (6, 8, 8)  # 3D modes (S, Z, T)
fno_hidden_channels = 64
lifting_channels = 64
out_channels = n_target_cols  # 10 * 2 = 20

latent_query_dims = (alpha, 32, 32)  # 3D latent grid with time

# Device configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cpu


In [12]:
import torch.nn.functional as F

class GFNO(nn.Module):

    def __init__(
        self,
        # Input GNO
        gno_coord_dim=3,
        gno_radius=0.03,
        gno_pos_embed_type='transformer',
        gno_pos_embed_channels=32,
        gno_pos_embed_max_positions=10000,
        gno_channel_mlp_layers=[64, 64, 64],
        gno_channel_mlp_non_linearity=F.gelu,
        gno_out_channels=3,
        # Latent features
        latent_feature_channels=None,
        # FNO
        fno_n_layers=4,
        fno_n_modes=(16, 16, 16),
        fno_hidden_channels=128,
        fno_skip_fno_bias=False,
        fno_fft_norm="forward",
        fno_rank=1.0,
        fno_max_n_modes=None,
        fno_non_linearity=F.gelu,
        # Lifting
        lifting_channels=128,
        # Projection
        projection_channel_ratio=4,
        out_channels=1,
        # Neighbor search settings
        use_open3d_neighbor_search=None,
    ):
        super(GFNO, self).__init__()

        # Determine whether to use open3d neighbor search based on coordinate dimension
        if use_open3d_neighbor_search is None:
            # Only use open3d for 3D coordinates
            use_open3d_neighbor_search = (gno_coord_dim == 3)

        self.gno_coord_dim = gno_coord_dim
        self.gno_radius = gno_radius
        self.gno_pos_embed_type = gno_pos_embed_type
        self.gno_pos_embed_channels = gno_pos_embed_channels
        self.gno_pos_embed_max_positions = gno_pos_embed_max_positions
        self.gno_channel_mlp_layers = gno_channel_mlp_layers
        self.gno_channel_mlp_non_linearity = gno_channel_mlp_non_linearity
        self.gno_out_channels = gno_out_channels
        self.latent_feature_channels = latent_feature_channels

        self.gno = GNOBlock(
            in_channels=0,
            out_channels=gno_out_channels,
            coord_dim=gno_coord_dim,
            radius=gno_radius,
            pos_embedding_type=gno_pos_embed_type,
            pos_embedding_channels=gno_pos_embed_channels,
            pos_embedding_max_positions=gno_pos_embed_max_positions,
            reduction='mean',
            weighting_fn=None,
            channel_mlp_layers=gno_channel_mlp_layers,
            channel_mlp_non_linearity=gno_channel_mlp_non_linearity,
            transform_type='linear',
            use_open3d_neighbor_search=use_open3d_neighbor_search,
            use_torch_scatter_reduce=False
        )
        
        # Store additional attributes needed for forward pass
        self.fno_hidden_channels = fno_hidden_channels
        self.in_coord_dim_reverse_order = list(range(2, gno_coord_dim + 2))  # For permute operation
        self.adain_pos_embed = None  # Placeholder for adaptive instance norm embedding
        self.fno_norm = None  # Placeholder for FNO normalization
        self.out_gno_tanh = None  # Placeholder for output GNO tanh activation

        if latent_feature_channels is not None:
            self.fno_in_channels = gno_out_channels + latent_feature_channels
        else:
            self.fno_in_channels = gno_out_channels

        # Define FNO blocks
        self.fno_blocks = FNOBlocks(
            n_layers=fno_n_layers,
            n_modes=fno_n_modes,
            hidden_channels=fno_hidden_channels,
            skip_fno_bias=fno_skip_fno_bias,
            fft_norm=fno_fft_norm,
            rank=fno_rank,
            max_n_modes=fno_max_n_modes,
            non_linearity=fno_non_linearity,
        )

        # Define lifting layer
        self.lifting_channels = lifting_channels
        self.lifting = ChannelMLP(
            in_channels=self.fno_in_channels,
            hidden_channels=self.lifting_channels,
            out_channels=fno_hidden_channels,
            n_layers=2
        )


        # Define projection layer
        self.projection_channel_ratio = projection_channel_ratio
        self.projection_channels = projection_channel_ratio * fno_hidden_channels
        self.out_channels = out_channels
        self.projection = ChannelMLP(
            in_channels=fno_hidden_channels,
            out_channels=self.out_channels,
            hidden_channels=self.projection_channels,
            n_layers=2,
            # n_dim=1,
            non_linearity=fno_non_linearity
        )

    def latent_embedding(self, in_p, ada_in=None):

        # in_p : (batch, n_1 , ... , n_k, in_channels + k)
        # ada_in : (fno_ada_in_dim, )

        # permute (b, n_1, ..., n_k, c) -> (b,c, n_1,...n_k)
        in_p = in_p.permute(0, len(in_p.shape)-1, *list(range(1,len(in_p.shape)-1)))
        #Update Ada IN embedding    
        if ada_in is not None:
            if ada_in.ndim == 2:
                ada_in = ada_in.squeeze(0)
            if self.adain_pos_embed is not None:
                ada_in_embed = self.adain_pos_embed(ada_in.unsqueeze(0)).squeeze(0)
            else:
                ada_in_embed = ada_in
            if self.fno_norm == "ada_in":
                self.fno_blocks.set_ada_in_embeddings(ada_in_embed)

        #Apply FNO blocks
        in_p = self.lifting(in_p)

        # for idx in range(self.fno_blocks.n_layers):
        in_p = self.fno_blocks(in_p)

        return in_p 

    def forward(self, input_geom, latent_queries, x=None, latent_features=None, ada_in=None, **kwargs):

        if x is None:
            batch_size = 1
        else:
            batch_size = x.shape[0]
        
        if latent_features is not None:
            assert self.latent_feature_channels is not None,\
                  "if passing latent features, latent_feature_channels must be set."
            assert latent_features.shape[-1] == self.latent_feature_channels

            # batch, n_gridpts_1, .... n_gridpts_n, gno_coord_dim
            assert latent_features.ndim == self.gno_coord_dim + 2,\
                f"Latent features must be of shape (batch, n_gridpts_1, ...n_gridpts_n, gno_coord_dim), got {latent_features.shape}"
            # latent features must have the same shape (except channels) as latent_queries 
            if latent_features.shape[0] != batch_size:
                if latent_features.shape[0] == 1:
                    latent_features = latent_features.repeat(batch_size, *[1]*(latent_features.ndim-1))


        # Squeeze batch dim if batch_size == 1
        if (input_geom.shape[0] == 1 and input_geom.ndim == 3) or input_geom.ndim == 2:
            input_geom = input_geom.squeeze(0) 
            latent_queries = latent_queries.squeeze(0)

            # Pass through input GNOBlock 
            in_p = self.gno(y=input_geom,
                            x=latent_queries.view((-1, 
                                                   latent_queries.shape[-1])),
                            f_y=x)
        elif input_geom.shape[0] == batch_size:
            in_p_list = []
            for b in range(batch_size):
                in_p_b = self.gno(y=input_geom[b],
                                 x=latent_queries[b].view((-1, 
                                                           latent_queries.shape[-1])),
                                 f_y=x[b] if x is not None else None)
                in_p_list.append(in_p_b)
            in_p = torch.stack(in_p_list, dim=0)
        else:
            raise ValueError(f"input_geom batch size {input_geom.shape[0]} does not match x batch size {batch_size}")
        
        print(in_p.shape)
        
        grid_shape = latent_queries.shape[:-1] # disregard positional encoding dim
        
        # shape (batch_size, grid1, ...gridn, -1)
        in_p = in_p.view((batch_size, *grid_shape, -1))
        
        if latent_features is not None:
            in_p = torch.cat((in_p, latent_features), dim=-1)
        # take apply fno in latent space
        latent_embed = self.latent_embedding(in_p=in_p, 
                                             ada_in=ada_in)

        # Integrate latent space to output queries
        #latent_embed shape (b, c, n_1, n_2, ..., n_k)
        batch_size = latent_embed.shape[0]
        # permute to (b, n_1, n_2, ...n_k, c)
        # then reshape to (b, n_1 * n_2 * ...n_k, out_channels)
        latent_embed = latent_embed.permute(0, *self.in_coord_dim_reverse_order, 1)
        
        if self.out_gno_tanh in ['latent_embed', 'both']:
            latent_embed = torch.tanh(latent_embed)

        # Project pointwise to out channels
        latent_embed = latent_embed.permute(0, 4, 1, 2, 3)
        out = self.projection(latent_embed)
        out = out.permute(0, *self.in_coord_dim_reverse_order, 1)
        
        return out

In [13]:
gfno = GFNO(
    gno_coord_dim=coord_dim,
    gno_radius=gno_radius,
    gno_out_channels=gno_out_channels,
    gno_channel_mlp_layers=gno_channel_mlp_layers,
    latent_feature_channels=4,  # X, Y, head, mass_conc
    fno_n_layers=fno_n_layers,
    fno_n_modes=fno_n_modes,
    fno_hidden_channels=fno_hidden_channels,
    lifting_channels=lifting_channels,
    out_channels=out_channels
).to(device)

### Use GFNO from Module

Now you can import GFNO directly from the models module instead of defining it in the notebook.

In [14]:
# Import GFNO from the models module
from src.models import GFNO

# Create GFNO instance using the imported class
gfno_from_module = GFNO(
    gno_coord_dim=coord_dim,
    gno_radius=gno_radius,
    gno_out_channels=gno_out_channels,
    gno_channel_mlp_layers=gno_channel_mlp_layers,
    latent_feature_channels=4,  # X, Y, head, mass_conc
    fno_n_layers=fno_n_layers,
    fno_n_modes=fno_n_modes,
    fno_hidden_channels=fno_hidden_channels,
    lifting_channels=lifting_channels,
    out_channels=out_channels
).to(device)

print(f"GFNO model created with {sum(p.numel() for p in gfno_from_module.parameters())} parameters")

GFNO model created with 7909988 parameters


In [15]:
input_geom_sample = torch.tensor(input_sequences[0]['input_geom'][:64], dtype=torch.float32)
input_data_sample = torch.tensor(input_sequences[0]['input_data'][:64], dtype=torch.float32)
latent_geom_sample = torch.tensor(input_sequences[0]['latent_geom'][:64], dtype=torch.float32)
latent_features_sample = torch.tensor(input_sequences[0]['latent_features'][:64], dtype=torch.float32)

# Fill nan values with -999
input_geom_sample = torch.nan_to_num(input_geom_sample, nan=-999.0)
input_data_sample = torch.nan_to_num(input_data_sample, nan=-999.0)
latent_geom_sample = torch.nan_to_num(latent_geom_sample, nan=-999.0)
latent_features_sample = torch.nan_to_num(latent_features_sample, nan=-999.0)

print(f"Input Geom sample shape: {input_geom_sample.shape}")  # (alpha, n_bc_nodes, 3)
print(f"Input Data sample shape: {input_data_sample.shape}")  # (alpha, n_bc_nodes, 2)
print(f"Latent Geom sample shape: {latent_geom_sample.shape}")  # (alpha, n_nodes, 3)
print(f"Latent Features sample shape: {latent_features_sample.shape}")  # (alpha, n_nodes, 4)

Input Geom sample shape: torch.Size([64, 5216, 3])
Input Data sample shape: torch.Size([64, 5216, 2])
Latent Geom sample shape: torch.Size([64, 16, 32, 32, 3])
Latent Features sample shape: torch.Size([64, 16, 32, 32, 4])


In [16]:
%%time
gfno(input_geom=input_geom_sample.to(device)[0],
     x = input_data_sample.to(device),
     latent_queries=latent_geom_sample.to(device)[0],
     latent_features=latent_features_sample.to(device)).shape

torch.Size([64, 16384, 2])
CPU times: user 20.5 s, sys: 4.09 s, total: 24.6 s
Wall time: 19.1 s


torch.Size([64, 16, 32, 32, 2])

## Create Dataset and DataLoader

Now let's create a dataset and dataloader using the plane data we've prepared.

In [152]:
from torch.utils.data import DataLoader
from src.data.plane_dataset import GWPlaneDataset
from src.data.batch_sampler import PatchBatchSampler  # Works for both patches and planes!

# Create the dataset
dataset = GWPlaneDataset(
    input_sequences=input_sequences,
    output_sequences=output_sequences,
    fill_nan_value=-999.0
)

print(f"\nDataset created with {len(dataset)} total sequences")
print(f"Sample shapes:")
sample = dataset[0]
for key, value in sample.items():
    if isinstance(value, torch.Tensor):
        print(f"  {key}: {value.shape}")
    else:
        print(f"  {key}: {value}")

Initialized GWPlaneDataset with 3712 sequences across 32 planes

Dataset created with 3712 total sequences
Sample shapes:
  plane_id: 0
  input_geom: torch.Size([5216, 3])
  input_data: torch.Size([5216, 2])
  latent_geom: torch.Size([16, 32, 32, 3])
  latent_features: torch.Size([16, 32, 32, 4])
  output_latent_geom: torch.Size([16, 32, 32, 3])
  output_latent_features: torch.Size([16, 32, 32, 4])


In [154]:
# Create the batch sampler
# This ensures all samples from the same plane are in the same batch
batch_size = 64
batch_sampler = PatchBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    shuffle_within_batches=True,  # Shuffle sequences within each batch
    shuffle_patches=True,  # Shuffle the order of planes
    seed=42
)

print(f"\nBatch sampler created:")
print(f"  Total batches: {len(batch_sampler)}")
print(f"  Batch size: {batch_size}")

# Create the DataLoader
dataloader = DataLoader(
    dataset,
    batch_sampler=batch_sampler,
    num_workers=0  # Set to > 0 for parallel data loading
)

print(f"\nDataLoader created with {len(dataloader)} batches")

Building patch groups (one-time operation)...
Found 32 patches with 3712 total samples
Patch sizes: min=116, max=116, avg=116.0
Pre-built 64 batches

Batch sampler created:
  Total batches: 64
  Batch size: 64

DataLoader created with 64 batches


In [155]:
# Test the dataloader by fetching a few batches
print("Testing dataloader - fetching first 3 batches:\n")

for i, batch in enumerate(dataloader):
    if i >= 3:
        break
    
    print(f"Batch {i+1}:")
    print(f"  Batch size (actual): {len(batch['plane_id'])}")
    print(f"  Plane IDs in batch: {batch['plane_id'].numpy()}")
    print(f"  Unique planes: {torch.unique(batch['plane_id']).numpy()} (should be single plane)")
    print(f"  Input geom shape: {batch['input_geom'].shape}")
    print(f"  Input data shape: {batch['input_data'].shape}")
    print(f"  Latent geom shape: {batch['latent_geom'].shape}")
    print(f"  Latent features shape: {batch['latent_features'].shape}")
    print(f"  Output latent geom shape: {batch['output_latent_geom'].shape}")
    print(f"  Output latent features shape: {batch['output_latent_features'].shape}")
    print()

Testing dataloader - fetching first 3 batches:

Batch 1:
  Batch size (actual): 64
  Plane IDs in batch: [26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26
 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26
 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26]
  Unique planes: [26] (should be single plane)
  Input geom shape: torch.Size([64, 320, 3])
  Input data shape: torch.Size([64, 320, 2])
  Latent geom shape: torch.Size([64, 16, 32, 32, 3])
  Latent features shape: torch.Size([64, 16, 32, 32, 4])
  Output latent geom shape: torch.Size([64, 16, 32, 32, 3])
  Output latent features shape: torch.Size([64, 16, 32, 32, 4])

Batch 1:
  Batch size (actual): 64
  Plane IDs in batch: [26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26
 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26
 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26]
  Unique planes: [26] (should be single plane)
  Input geom shape: torch.S

## Save Data to Disk (Optional)

If you want to use `GWPlaneDatasetFromFiles` for on-disk loading, save the data in the required format.

In [None]:
# Save data to disk in format expected by GWPlaneDatasetFromFiles
save_to_disk = False  # Set to True to save

if save_to_disk:
    output_data_dir = '/Users/arpitkapoor/Library/CloudStorage/OneDrive-UNSW/Shared/Data/FEFLOW/processed/2d_plane_sequences'
    os.makedirs(output_data_dir, exist_ok=True)
    
    print(f"Saving plane data to {output_data_dir}...")
    
    for p in tqdm(range(N_planes)):
        plane_dir = os.path.join(output_data_dir, f'plane_{p:02d}')
        os.makedirs(plane_dir, exist_ok=True)
        
        # Save all data for this plane
        np.save(os.path.join(plane_dir, 'input_geom.npy'), input_sequences[p]['input_geom'])
        np.save(os.path.join(plane_dir, 'input_data.npy'), input_sequences[p]['input_data'])
        np.save(os.path.join(plane_dir, 'latent_geom.npy'), input_sequences[p]['latent_geom'])
        np.save(os.path.join(plane_dir, 'latent_features.npy'), input_sequences[p]['latent_features'])
        np.save(os.path.join(plane_dir, 'output_latent_geom.npy'), output_sequences[p]['latent_geom'])
        np.save(os.path.join(plane_dir, 'output_latent_features.npy'), output_sequences[p]['latent_features'])
    
    print(f"Data saved successfully!")
    
    # Now you can load using GWPlaneDatasetFromFiles
    # from src.data.plane_dataset import GWPlaneDatasetFromFiles
    # dataset_from_disk = GWPlaneDatasetFromFiles(output_data_dir)
else:
    print("Skipping save to disk (set save_to_disk=True to save)")