In [None]:
import time
import numpy as np
from torch_geometric.loader import DataLoader
from torch.utils.data import Subset
from euler import EulerPeriodicDataset

h5_path = "/work/imos/datasets/euler_multi_quadrants_periodicBC/data/train/euler_multi_quadrants_periodicBC_gamma_1.76_Ar_-180.hdf5"
stats_path = "/work/imos/datasets/euler_multi_quadrants_periodicBC/stats.yaml"

### 1) Full-grid dataset checks

In [2]:
print("=== Full-grid dataset ===")
ds_full = EulerPeriodicDataset(h5_path, stats_path=stats_path, time_window=1, patch_size=None, normalize=True)
print("Dataset created:", len(ds_full))

# basic metadata checks
print("len(dataset) =", len(ds_full))
print("n_sims, n_t, H, W =", ds_full.n_sims, ds_full.n_t, ds_full.H, ds_full.W)
print("n_per_sim =", ds_full.n_per_sim, "total_samples =", ds_full.total_samples)

# static cache checks
print("static cache keys:", list(ds_full._static_cache.keys()))
print("x_periodic, y_periodic:", ds_full._static_cache.get("x_periodic"), ds_full._static_cache.get("y_periodic"))
print("gamma:", ds_full._static_cache.get("gamma"))
print("pos_template shape:", ds_full._static_cache.get("pos_template").shape)

# fast I/O test: load arrays only (no edge build)
arrs = ds_full._load_time_window(sim_idx=0, t_idx=0)
print("density shape:", arrs["density"].shape)   # expect (time_window, H, W)
print("momentum shape:", arrs["momentum"].shape) # expect (time_window, H, W, 2)

# build the graph for a single sample (this triggers edge construction, may take a few seconds and memory)
t0 = time.time()
data = ds_full[0]   # full-grid graph
t1 = time.time()
print("\nBuilt Data object in {:.1f}s".format(t1-t0))
print("data.x shape (N, C):", data.x.shape)
print("data.pos shape:", data.pos.shape)
print("edge_index shape:", data.edge_index.shape)
print("edge_attr shape:", data.edge_attr.shape)
print("global features u:", data.u)

# simple assert to ensure original behaviour returns full-grid-size nodes
N_full = ds_full.H * ds_full.W
assert data.x.shape[0] == N_full, "Full-grid Data node count mismatch"

=== Full-grid dataset ===
EulerPeriodicDataset: using full-grid samples (512x512), this is large (262144 nodes). Consider patching.
Dataset created: 40000
len(dataset) = 40000
n_sims, n_t, H, W = 400 101 512 512
n_per_sim = 100 total_samples = 40000
static cache keys: ['pos_template', 'x_coords', 'y_coords', 'gamma', 'x_periodic_mask', 'y_periodic_mask', 'x_periodic', 'y_periodic']
x_periodic, y_periodic: True True
gamma: 1.7599999904632568
pos_template shape: (262144, 2)
density shape: (1, 512, 512)
momentum shape: (1, 512, 512, 2)

Built Data object in 5.1s
data.x shape (N, C): torch.Size([262144, 5])
data.pos shape: torch.Size([262144, 2])
edge_index shape: torch.Size([2, 1048576])
edge_attr shape: torch.Size([1048576, 4])
global features u: tensor([[1.7600, 0.0000]])


### 2) Patched dataset checks

In [6]:
print("\n=== Patch-enabled dataset: mapping & decode tests ===")
# choose patch size 64 (cells)
patch_size = (64, 64)
ds_patch = EulerPeriodicDataset(
    h5_path,
    stats_path=stats_path,
    time_window=1,
    patch_size=patch_size,
    patch_stride=32,
    normalize=True,
)

# dataset should have computed patches_per_row/col and patches_per_timestep (patches_per_sim)
print("patch_size:", ds_patch.patch_size)
print("patch_stride (h,w):", ds_patch.patch_stride, ds_patch.patch_stride_h, ds_patch.patch_stride_w)
print("patches_per_row, patches_per_col:", ds_patch.patches_per_row, ds_patch.patches_per_col)
print("patches_per_timestep:", ds_patch.patches_per_timestep)
print("n_sims, n_per_sim:", ds_patch.n_sims, ds_patch.n_per_sim)

# expected counts
expected_patches_per_row = (ds_patch.W - patch_size[1]) // ds_patch.patch_stride_w + 1
expected_patches_per_col = (ds_patch.H - patch_size[0]) // ds_patch.patch_stride_h + 1
expected_patches_per_timestep = expected_patches_per_row * expected_patches_per_col
expected_total_samples = ds_patch.n_sims * ds_patch.n_per_sim * expected_patches_per_timestep

print("expected patches_per_row:", expected_patches_per_row, "col:", expected_patches_per_col)
print("expected patches_per_timestep:", expected_patches_per_timestep)
print("dataset.total_samples:", ds_patch.total_samples, "expected:", expected_total_samples)
print("len(dataset):", len(ds_patch))

# assertions
assert ds_patch.patches_per_row == expected_patches_per_row
assert ds_patch.patches_per_col == expected_patches_per_col
assert ds_patch.patches_per_timestep == expected_patches_per_timestep
assert ds_patch.total_samples == expected_total_samples
assert len(ds_patch) == expected_total_samples

# test _decode_index for several important indices
per_timestep = ds_patch.patches_per_timestep
per_sim = ds_patch.n_per_sim * per_timestep
print("per_timestep (patches per timestep):", per_timestep, "per_sim:", per_sim)

# first index
sim0, t0, i0, j0 = ds_patch._decode_index(0)
print("idx 0 ->", sim0, t0, i0, j0)
assert sim0 == 0 and t0 == 0

# last index in first sim
idx_last_first_sim = per_sim - 1
sim_l, t_l, i_l, j_l = ds_patch._decode_index(idx_last_first_sim)
print(f"idx {idx_last_first_sim} ->", sim_l, t_l, i_l, j_l)
assert sim_l == 0 and (0 <= t_l < ds_patch.n_per_sim)

# first index of second sim
idx_per_sim = per_sim
sim2, t2, i2, j2 = ds_patch._decode_index(idx_per_sim)
print(f"idx {idx_per_sim} ->", sim2, t2, i2, j2)
assert sim2 == 1 and t2 == 0

# last global index
idx_last = ds_patch.total_samples - 1
sim_last, t_last, i_last, j_last = ds_patch._decode_index(idx_last)
print(f"last idx {idx_last} -> sim {sim_last}, t {t_last}, i0 {i_last}, j0 {j_last}")
assert sim_last == ds_patch.n_sims - 1
assert 0 <= t_last < ds_patch.n_per_sim

print("\n✅ Mapping & decode tests passed.\n")

# -------------------------------------------------------------------------
# A) CACHE TESTS
# -------------------------------------------------------------------------
print("=== Cache behaviour test ===")

p_h, p_w = ds_patch.patch_size
x_periodic = int(bool(ds_patch._static_cache.get("x_periodic", False)))
y_periodic = int(bool(ds_patch._static_cache.get("y_periodic", False)))
edge_cache_key = f"edge_patch_{p_h}_{p_w}_{x_periodic}_{y_periodic}"

print("Expected patch cache key:", edge_cache_key)
print("initially in static_cache:", [k for k in ds_patch._static_cache.keys() if 'edge' in k])

# Trigger one patch build (calls cached builder)
t0 = time.time()
data0 = ds_patch[0]
t1 = time.time()
print(f"Built first patch in {t1-t0:.3f}s")

print("after building one patch, cached keys:", [k for k in ds_patch._static_cache.keys() if 'edge' in k])
assert edge_cache_key in ds_patch._static_cache, "edge cache key not found after building a patch"

cached = ds_patch._static_cache[edge_cache_key]
edge_index_cached = cached["edge_index"]
edge_attr_cached = cached["edge_attr"]

# call builder again with same cache key -> should return same objects
edge_index2, edge_attr2 = ds_patch._build_grid_edges(
    p_h, p_w,
    data0.pos.reshape(p_h, p_w, 2),
    x_periodic=bool(x_periodic),
    y_periodic=bool(y_periodic),
    cache_key=edge_cache_key,
)
print("edge_index same object:", edge_index2 is edge_index_cached)
print("edge_attr  same object:", edge_attr2 is edge_attr_cached)
assert edge_index2 is edge_index_cached and edge_attr2 is edge_attr_cached

# rebuild without cache key to confirm difference
t0 = time.time()
edge_index_new, edge_attr_new = ds_patch._build_grid_edges(
    p_h, p_w,
    data0.pos.reshape(p_h, p_w, 2),
    x_periodic=bool(x_periodic),
    y_periodic=bool(y_periodic),
    cache_key=None,
)
t1 = time.time()
print(f"Rebuild (no cache key) took {t1-t0:.3f}s")
print("rebuild returned different edge_index object:", edge_index_new is not edge_index_cached)

print("\n✅ Cache tests passed.\n")

# -------------------------------------------------------------------------
# B) SIMPLE DATALOADER TEST
# -------------------------------------------------------------------------
print("=== DataLoader batching test ===")

num_check = 8
idxs = list(range(min(num_check, len(ds_patch))))
subset = Subset(ds_patch, idxs)

loader = DataLoader(subset, batch_size=2, shuffle=False, num_workers=0)

batch = next(iter(loader))
print("Batch.x shape:", batch.x.shape)
print("Batch.pos shape:", batch.pos.shape)
print("Batch.batch shape:", batch.batch.shape)

expected_nodes = 2 * (p_h * p_w)
assert batch.x.shape[0] == expected_nodes, f"expected {expected_nodes} nodes in batch, saw {batch.x.shape[0]}"

print("\n✅ DataLoader batching test passed.\n")
print("All patch-related tests completed successfully.")



=== Patch-enabled dataset: mapping & decode tests ===
patch_size: (64, 64)
patch_stride (h,w): 32 32 32
patches_per_row, patches_per_col: 15 15
patches_per_timestep: 225
n_sims, n_per_sim: 400 100
expected patches_per_row: 15 col: 15
expected patches_per_timestep: 225
dataset.total_samples: 9000000 expected: 9000000
len(dataset): 9000000
per_timestep (patches per timestep): 225 per_sim: 22500
idx 0 -> 0 0 0 0
idx 22499 -> 0 99 448 448
idx 22500 -> 1 0 0 0
last idx 8999999 -> sim 399, t 99, i0 448, j0 448

✅ Mapping & decode tests passed.

=== Cache behaviour test ===
Expected patch cache key: edge_patch_64_64_1_1
initially in static_cache: []
Built first patch in 0.069s
after building one patch, cached keys: ['edge_patch_64_64_1_1']
edge_index same object: True
edge_attr  same object: True
Rebuild (no cache key) took 0.052s
rebuild returned different edge_index object: True

✅ Cache tests passed.

=== DataLoader batching test ===
Batch.x shape: torch.Size([8192, 5])
Batch.pos shape: t

### 3) Mean fields checks

In [4]:
def compute_mean_fields(ds):
    """
    Computes the mean fields values over the entire dataset (all simulations, all timesteps).
    """
    mean_density = []
    mean_pressure = []
    mean_energy = []
    mean_momentum = []  # magnitude

    for i in range(ds.n_sims):
        # load all time possible steps for simulation i
        arrs = ds._load_time_window(sim_idx=i, t_idx=0)
        mean_density.append(arrs["density"].mean())
        mean_pressure.append(arrs["pressure"].mean())
        mean_energy.append(arrs["energy"].mean())

        # for momentum, compute magnitude first
        arrs["momentum"] = np.sqrt((arrs["momentum"]**2).sum(axis=-1))
        mean_momentum.append(arrs["momentum"].mean())

    # now compute global means
    mean_density = sum(mean_density) / ds.n_sims
    mean_pressure = sum(mean_pressure) / ds.n_sims
    mean_energy = sum(mean_energy) / ds.n_sims
    mean_momentum = sum(mean_momentum) / ds.n_sims

    return mean_density, mean_pressure, mean_energy, mean_momentum

In [5]:
train_path = "/work/imos/datasets/euler_multi_quadrants_periodicBC/data/train/euler_multi_quadrants_periodicBC_gamma_1.4_Dry_air_20.hdf5"
valid_path = "/work/imos/datasets/euler_multi_quadrants_periodicBC/data/valid/euler_multi_quadrants_periodicBC_gamma_1.4_Dry_air_20.hdf5"
test_path = "/work/imos/datasets/euler_multi_quadrants_periodicBC/data/test/euler_multi_quadrants_periodicBC_gamma_1.4_Dry_air_20.hdf5"

train_ds = EulerPeriodicDataset(train_path, stats_path=stats_path, time_window=1, patch_size=None, normalize=True)
valid_ds = EulerPeriodicDataset(valid_path, stats_path=stats_path, time_window=1, patch_size=None, normalize=True)
test_ds = EulerPeriodicDataset(test_path, stats_path=stats_path, time_window=1, patch_size=None, normalize=True)

# compute mean fields for each split
train_means = compute_mean_fields(train_ds)
valid_means = compute_mean_fields(valid_ds)
test_means = compute_mean_fields(test_ds)

print("Train means (density, pressure, energy, momentum):", train_means)
print("Valid means (density, pressure, energy, momentum):", valid_means)
print("Test means (density, pressure, energy, momentum):", test_means)

EulerPeriodicDataset: using full-grid samples (512x512), this is large (262144 nodes). Consider patching.
EulerPeriodicDataset: using full-grid samples (512x512), this is large (262144 nodes). Consider patching.
EulerPeriodicDataset: using full-grid samples (512x512), this is large (262144 nodes). Consider patching.
Train means (density, pressure, energy, momentum): (0.9299198387563229, 0.9002689383178949, 2.459686482846737, 0.5267369353398681)
Valid means (density, pressure, energy, momentum): (0.9133071875572205, 0.8708313792943955, 2.3710775804519653, 0.4901703608036041)
Test means (density, pressure, energy, momentum): (0.8470769155025483, 0.7551606976985932, 2.0858439445495605, 0.48677064090967176)
