Add HealDA dataloader protocols and init recipe#1555

Merged
pzharrington merged 14 commits into NVIDIA:main from pzharrington:healda-data
Apr 30, 2026

Conversation


@pzharrington pzharrington commented Apr 9, 2026

PhysicsNeMo Pull Request

Description

Adds the HealDA data loader system to physicsnemo.experimental.datapipes. The initial focus is on reproducibility, preserving the performance features NVR developed, and establishing clear interfaces for users interested in extending the pipeline with custom data.

Also brings in unit tests for these components, which currently live in the recipe folder.

Documentation is mostly in the recipe README; some of it is copied here for reference:

The physicsnemo.experimental.datapipes.healda package provides a composable data loading pipeline with clear extension points. The architecture separates components into loaders, transforms, datasets, and sampling infrastructure.

Architecture

ObsERA5Dataset(era5_data, obs_loader, transform)
  |  Temporal windowing via FrameIndexGenerator
  |  __getitems__ -> get() per index -> transform.transform()
  v
RestartableDistributedSampler (stateful distributed sampling with checkpointing)
  |
DataLoader (pin_memory, persistent_workers)
  |
prefetch_map(loader, transform.device_transform)
  |
Training loop (GPU-ready batch)
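The composition above can be sketched in plain Python. This is a toy stand-in, not the package implementation: the real ObsERA5Dataset, RestartableDistributedSampler, DataLoader, and prefetch_map come from the healda package, and the real prefetch_map overlaps device transfer with compute via CUDA streams. Only the call shape `prefetch_map(loader, device_transform)` is taken from the diagram.

```python
# Toy sketch of the pipeline composition (assumed shapes, not the real API).
def prefetch_map(loader, device_transform):
    """Apply the GPU-side transform to each batch as it is yielded.

    The real prefetch_map additionally overlaps host-to-device copies with
    compute using CUDA streams; this version only shows the data flow.
    """
    for batch in loader:
        yield device_transform(batch)


# Stand-ins for the CPU side: batches as produced by the DataLoader, and a
# dummy "device_transform" that just converts lists to tuples.
cpu_batches = [{"state": [1.0, 2.0]}, {"state": [3.0, 4.0]}]
to_device = lambda batch: {k: tuple(v) for k, v in batch.items()}

gpu_batches = list(prefetch_map(iter(cpu_batches), to_device))
print(gpu_batches[0])  # → {'state': (1.0, 2.0)}
```

The training loop then iterates `gpu_batches` directly, receiving batches that have already passed through the device-side transform.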

Key Protocols

Custom data sources and transforms plug in via these protocols
(see physicsnemo.experimental.datapipes.healda.protocols):

ObsLoader — the observation loading interface:

class MyObsLoader:
    async def sel_time(self, times):
        """Return {"obs": [pa.Table, ...]}"""
        ...

Transform / DeviceTransform — two-stage batch
processing:

class MyTransform:
    def transform(self, times, frames):
        """CPU-side: normalize, encode obs, time features."""
        ...

    def device_transform(self, batch, device):
        """GPU-side: move to device, compute obs features."""
        ...
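Since the protocols module is described elsewhere in this PR as using runtime_checkable, the interfaces above can be declared with typing.Protocol so that any structurally matching class satisfies them. The signatures below are illustrative sketches, not the actual definitions in physicsnemo.experimental.datapipes.healda.protocols:

```python
# Hedged sketch: how ObsLoader / Transform / DeviceTransform might be
# declared as runtime-checkable structural protocols.  Signatures are
# assumptions based on the snippets above, not the package source.
from __future__ import annotations

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ObsLoader(Protocol):
    async def sel_time(self, times: list) -> dict[str, list]:
        """Return observation tables keyed by group, one per time."""
        ...


@runtime_checkable
class Transform(Protocol):
    def transform(self, times: list, frames: list) -> dict[str, Any]:
        """CPU-side batch assembly (runs in DataLoader workers)."""
        ...


@runtime_checkable
class DeviceTransform(Protocol):
    def device_transform(self, batch: dict, device: Any) -> dict:
        """GPU-side featurization (runs under a CUDA stream)."""
        ...


# Structural typing: no inheritance needed, and @runtime_checkable makes
# isinstance() verify that the required methods are present.
class DummyTransform:
    def transform(self, times, frames):
        return {"times": times, "n_frames": len(frames)}


print(isinstance(DummyTransform(), Transform))  # → True
```

The benefit of this pattern is that custom loaders and transforms plug in without importing or subclassing anything from the package; they only need the right method names.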

Provided Implementations

| Component | Module | Description |
| --- | --- | --- |
| ObsERA5Dataset | dataset | ERA5 state + observations |
| UFSUnifiedLoader | loaders.ufs_obs | Parquet obs loader |
| ERA5Loader | loaders.era5 | Async ERA5 zarr loader |
| ERA5ObsTransform | transforms.era5_obs | Two-stage transform |
| RestartableDistributedSampler | samplers | Stateful distributed sampler |
| prefetch_map | prefetch | CUDA stream prefetching |

All modules above are under
physicsnemo.experimental.datapipes.healda.

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@pzharrington pzharrington self-assigned this Apr 9, 2026

pzharrington commented Apr 9, 2026

Testing script comparing against reference loader:

#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare outputs between the ported data loader and the reference implementation.

Modes:
    smoketest  — 3 sample indices, fast sanity check (~1 min)
    full       — 50 indices spread across train split (~10 min)

Usage:
    # From examples/weather/healda/
    python scripts/compare_loaders.py smoketest
    python scripts/compare_loaders.py full
    python scripts/compare_loaders.py full --indices 0 100 500 1000

Requires:
    - The reference codebase importable (healda-reference/src on PYTHONPATH, or
      the healda package installed).
    - Environment variables from .env (ERA5_74VAR, UFS_OBS_PATH, etc.).
"""

from __future__ import annotations

import argparse
import os
import sys
import time

import numpy as np
import pandas as pd
import torch

# ---------------------------------------------------------------------------
# Path setup
# ---------------------------------------------------------------------------
RECIPE_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
REFERENCE_ROOT = os.path.join(os.path.dirname(RECIPE_ROOT), "healda-reference")

# Load .env
from dotenv import load_dotenv

load_dotenv(os.path.join(RECIPE_ROOT, ".env"))


# ============================================================================
# Ported loader construction
# ============================================================================


def build_ported_dataset(split="train", sensors=None):
    """Construct ObsERA5Dataset from the ported data/ package."""
    from physicsnemo.experimental.datapipes.healda.configs.variable_configs import VARIABLE_CONFIGS
    from physicsnemo.experimental.datapipes.healda.dataset import ObsERA5Dataset
    from physicsnemo.experimental.datapipes.healda.loaders.ufs_obs import UFSUnifiedLoader
    from physicsnemo.experimental.datapipes.healda.transforms.era5_obs import ERA5ObsTransform

    variable_config = VARIABLE_CONFIGS["era5"]

    if sensors is None:
        sensors = ["atms", "mhs", "amsua", "amsub"]

    obs_path = os.environ["UFS_OBS_PATH"]
    obs_loader = UFSUnifiedLoader(
        data_path=obs_path,
        sensors=sensors,
        normalization="zscore",
        obs_context_hours=(-21, 3),
    )

    transform = ERA5ObsTransform(variable_config=variable_config, sensors=sensors)

    import xarray

    era5_path = os.environ["ERA5_74VAR"]
    era5_ds = xarray.open_zarr(era5_path, chunks=None)
    era5_data = era5_ds["data"]

    dataset = ObsERA5Dataset(
        era5_data=era5_data,
        obs_loader=obs_loader,
        transform=transform,
        variable_config=variable_config,
        split=split,
    )
    return dataset


# ============================================================================
# Reference loader construction
# ============================================================================


def build_reference_dataset(split="train", sensors=None):
    """Construct ObsERA5Dataset from the reference healda-reference codebase.

    Requires healda-reference/src and healda-reference/ on sys.path.
    """
    # Add reference paths
    ref_src = os.path.join(REFERENCE_ROOT, "src")
    ref_private = REFERENCE_ROOT
    for p in [ref_src, ref_private]:
        if p not in sys.path:
            sys.path.insert(0, p)

    import dotenv as _dotenv

    _dotenv.load_dotenv(os.path.join(RECIPE_ROOT, ".env"))

    from healda.config.models import ObsConfig
    from private.fcn3_dataset import ObsERA5Dataset as RefObsERA5Dataset

    # Build obs_config matching default sensors
    use_conv = sensors is not None and "conv" in sensors
    obs_config = ObsConfig(
        use_obs=True,
        innovation_type="none",
        context_start=-21,
        context_end=3,
        use_conv=use_conv,
    )

    dataset = RefObsERA5Dataset(
        split=split,
        time_length=1,
        frame_step=1,
        model_rank=0,
        model_world_size=1,
        obs_config=obs_config,
    )
    return dataset


# ============================================================================
# Comparison logic
# ============================================================================


def compare_single_sample(ported_ds, ref_ds, idx: int, verbose: bool = True):
    """Compare a single sample between ported and reference datasets.

    Returns a dict of comparison results.
    """
    results = {"idx": idx, "pass": True, "errors": []}

    # --- Raw data comparison (before transform) ---
    try:
        t0 = time.time()
        ported_times, ported_objs = ported_ds.get(idx)
        ported_elapsed = time.time() - t0

        t0 = time.time()
        ref_times, ref_objs = ref_ds.get(idx)
        ref_elapsed = time.time() - t0

        results["ported_time_s"] = ported_elapsed
        results["ref_time_s"] = ref_elapsed

    except Exception as e:
        results["pass"] = False
        results["errors"].append(f"Loading failed: {e}")
        return results

    # Compare timestamps
    for i, (pt, rt) in enumerate(zip(ported_times, ref_times)):
        if str(pt) != str(rt):
            results["pass"] = False
            results["errors"].append(f"Time mismatch at frame {i}: {pt} vs {rt}")

    # Compare state arrays
    for i, (po, ro) in enumerate(zip(ported_objs, ref_objs)):
        p_state = po["state"]
        r_state = ro["state"]

        if p_state.shape != r_state.shape:
            results["pass"] = False
            results["errors"].append(
                f"State shape mismatch at frame {i}: {p_state.shape} vs {r_state.shape}"
            )
            continue

        max_diff = np.max(np.abs(p_state - r_state))
        results[f"state_frame{i}_maxdiff"] = float(max_diff)

        if max_diff > 1e-6:
            results["pass"] = False
            results["errors"].append(
                f"State value mismatch at frame {i}: max_diff={max_diff:.2e}"
            )

    # Compare observation tables
    # Note: ported and reference may produce rows in different order (due to
    # platform grouping within parquet row-groups).  This is benign — the
    # downstream transform processes all obs in a window together.  We sort
    # both tables by a canonical key before value comparison.
    import pyarrow.compute as pc

    def _sort_obs_table(table):
        """Sort by (Global_Channel_ID, Latitude, Longitude, Absolute_Obs_Time)
        to produce a deterministic row order for comparison."""
        sort_keys = [
            ("Global_Channel_ID", "ascending"),
            ("Latitude", "ascending"),
            ("Longitude", "ascending"),
            ("Absolute_Obs_Time", "ascending"),
        ]
        indices = pc.sort_indices(table, sort_keys=sort_keys)
        return table.take(indices)

    for i, (po, ro) in enumerate(zip(ported_objs, ref_objs)):
        p_obs = po.get("obs")
        r_obs = ro.get("obs") or ro.get("obs_v2")  # reference uses legacy key

        if p_obs is None and r_obs is None:
            continue
        if (p_obs is None) != (r_obs is None):
            results["pass"] = False
            results["errors"].append(f"Obs presence mismatch at frame {i}")
            continue

        p_nrows = p_obs.num_rows
        r_nrows = r_obs.num_rows
        results[f"obs_frame{i}_nrows_ported"] = p_nrows
        results[f"obs_frame{i}_nrows_ref"] = r_nrows

        if p_nrows != r_nrows:
            results["pass"] = False
            results["errors"].append(
                f"Obs row count mismatch at frame {i}: {p_nrows} vs {r_nrows}"
            )
            continue

        if p_nrows > 0:
            # Compare schemas
            p_cols = set(p_obs.schema.names)
            r_cols = set(r_obs.schema.names)
            if p_cols != r_cols:
                results["pass"] = False
                results["errors"].append(
                    f"Obs schema mismatch at frame {i}: "
                    f"ported_only={p_cols - r_cols}, ref_only={r_cols - p_cols}"
                )
                continue

            # Sort both tables to canonical order before comparison
            p_sorted = _sort_obs_table(p_obs)
            r_sorted = _sort_obs_table(r_obs)

            # Compare observation values
            p_vals = p_sorted["Observation"].to_numpy()
            r_vals = r_sorted["Observation"].to_numpy()
            obs_max_diff = np.nanmax(np.abs(p_vals - r_vals))
            results[f"obs_frame{i}_val_maxdiff"] = float(obs_max_diff)
            if obs_max_diff > 1e-5:
                results["pass"] = False
                results["errors"].append(
                    f"Obs value mismatch at frame {i}: max_diff={obs_max_diff:.2e}"
                )

            # Also verify Global_Channel_ID sets match
            p_gcids = set(p_obs["Global_Channel_ID"].to_pylist())
            r_gcids = set(r_obs["Global_Channel_ID"].to_pylist())
            if p_gcids != r_gcids:
                results["pass"] = False
                results["errors"].append(
                    f"Obs GCID set mismatch at frame {i}: "
                    f"ported_only={p_gcids - r_gcids}, ref_only={r_gcids - p_gcids}"
                )

    if verbose:
        status = "PASS" if results["pass"] else "FAIL"
        timing = (
            f"ported={results.get('ported_time_s', 0):.2f}s "
            f"ref={results.get('ref_time_s', 0):.2f}s"
        )
        print(f"  [{status}] idx={idx:6d}  {timing}")
        for err in results["errors"]:
            print(f"         {err}")

    return results


def compare_transformed_sample(ported_ds, ref_ds, idx: int, verbose: bool = True):
    """Compare transformed (batched) output between ported and reference.

    Uses __getitems__ to exercise the full transform pipeline.
    """
    results = {"idx": idx, "pass": True, "errors": []}

    try:
        t0 = time.time()
        ported_batch = ported_ds.__getitems__([idx])
        ported_elapsed = time.time() - t0

        t0 = time.time()
        ref_batch = ref_ds.__getitems__([idx])
        ref_elapsed = time.time() - t0

        results["ported_transform_s"] = ported_elapsed
        results["ref_transform_s"] = ref_elapsed

    except Exception as e:
        results["pass"] = False
        results["errors"].append(f"Transform failed: {e}")
        if verbose:
            print(f"  [FAIL] idx={idx:6d} Transform error: {e}")
        return results

    # Compare batch dict keys
    p_keys = set(ported_batch.keys())
    r_keys = set(ref_batch.keys())
    if p_keys != r_keys:
        results["errors"].append(
            f"Batch key mismatch: ported_only={p_keys - r_keys}, ref_only={r_keys - p_keys}"
        )
        # Don't fail — extra/missing keys may be intentional

    # Compare tensor fields
    for key in sorted(p_keys & r_keys):
        pv = ported_batch[key]
        rv = ref_batch[key]

        if isinstance(pv, torch.Tensor) and isinstance(rv, torch.Tensor):
            if pv.shape != rv.shape:
                results["pass"] = False
                results["errors"].append(
                    f"Shape mismatch for '{key}': {pv.shape} vs {rv.shape}"
                )
                continue

            if pv.numel() == 0:
                continue
            max_diff = (pv.float() - rv.float()).abs().max().item()
            results[f"{key}_maxdiff"] = max_diff

            # Use loose tolerance for float transforms
            tol = 1e-4 if pv.is_floating_point() else 0
            if max_diff > tol:
                results["pass"] = False
                results["errors"].append(
                    f"Value mismatch for '{key}': max_diff={max_diff:.2e}"
                )

        elif isinstance(pv, tuple) and isinstance(rv, tuple):
            # unified_obs is a tuple (obs_tensors, lengths_3d)
            # Row ordering may differ between ported and reference (benign —
            # within each sensor group, platforms can appear in different order
            # depending on parquet row-group layout).  We sort both by
            # (global_channel_id, absolute_obs_time, latitude, longitude,
            # observation) before comparing values.
            if len(pv) != len(rv):
                results["pass"] = False
                results["errors"].append(
                    f"Tuple length mismatch for '{key}': {len(pv)} vs {len(rv)}"
                )
                continue

            if isinstance(pv[0], dict) and isinstance(rv[0], dict):
                p_obs_keys = set(pv[0].keys())
                r_obs_keys = set(rv[0].keys())
                if p_obs_keys != r_obs_keys:
                    results["errors"].append(
                        f"Obs tensor key mismatch: "
                        f"ported_only={p_obs_keys - r_obs_keys}, "
                        f"ref_only={r_obs_keys - p_obs_keys}"
                    )

                # Build a stable sort index using torch.lexsort-style
                # multi-key sorting: (gcid, abs_time, lat, lon, observation)
                def _sort_idx(obs_dict):
                    gcid = obs_dict.get("global_channel_id")
                    lat = obs_dict.get("latitude")
                    lon = obs_dict.get("longitude")
                    obs_time = obs_dict.get("absolute_obs_time")
                    obs_val = obs_dict.get("observation")
                    if gcid is None or gcid.numel() == 0:
                        return None
                    # Stack columns as (N, K) float64 for lexicographic sort.
                    # torch.lexsort isn't available, so we use numpy.
                    cols = [gcid.double().cpu().numpy()]
                    if obs_time is not None:
                        cols.append(obs_time.double().cpu().numpy())
                    if lat is not None:
                        cols.append(lat.double().cpu().numpy())
                    if lon is not None:
                        cols.append(lon.double().cpu().numpy())
                    if obs_val is not None:
                        cols.append(obs_val.double().cpu().numpy())
                    # np.lexsort sorts by last key first, so reverse
                    order = np.lexsort(cols[::-1])
                    return torch.from_numpy(order).long()

                p_order = _sort_idx(pv[0])
                r_order = _sort_idx(rv[0])

                for obs_key in sorted(p_obs_keys & r_obs_keys):
                    pt = pv[0][obs_key]
                    rt = rv[0][obs_key]
                    if pt.shape != rt.shape:
                        results["pass"] = False
                        results["errors"].append(
                            f"Obs tensor shape mismatch for '{obs_key}': "
                            f"{pt.shape} vs {rt.shape}"
                        )
                    elif pt.numel() > 0:
                        # Apply sort order before comparison
                        ps = pt[p_order] if p_order is not None else pt
                        rs = rt[r_order] if r_order is not None else rt
                        d = (ps.float() - rs.float()).abs().max().item()
                        results[f"obs_{obs_key}_maxdiff"] = d
                        if d > 1e-4:
                            results["pass"] = False
                            results["errors"].append(
                                f"Obs tensor mismatch for '{obs_key}': "
                                f"max_diff={d:.2e}"
                            )

            # Compare lengths_3d (sensor, batch, time) — these count obs per
            # sensor/window and are order-independent as long as sensor_id
            # mapping matches.
            for ti, name in [(1, "lengths")]:
                if ti < len(pv) and ti < len(rv):
                    pt, rt = pv[ti], rv[ti]
                    if isinstance(pt, torch.Tensor) and isinstance(rt, torch.Tensor):
                        if pt.shape != rt.shape:
                            results["pass"] = False
                            results["errors"].append(
                                f"{name} shape mismatch: {pt.shape} vs {rt.shape}"
                            )
                        elif not torch.equal(pt, rt):
                            results["pass"] = False
                            results["errors"].append(f"{name} value mismatch")

    if verbose:
        status = "PASS" if results["pass"] else "FAIL"
        timing = (
            f"ported={results.get('ported_transform_s', 0):.2f}s "
            f"ref={results.get('ref_transform_s', 0):.2f}s"
        )
        print(f"  [{status}] idx={idx:6d}  {timing}")
        for err in results["errors"]:
            print(f"         {err}")

    return results


# ============================================================================
# Main driver
# ============================================================================


def get_indices(mode: str, ds_len: int, custom_indices=None):
    """Return sample indices based on mode."""
    if custom_indices:
        return [i for i in custom_indices if i < ds_len]

    if mode == "smoketest":
        # 3 indices: start, middle, near end
        return [0, ds_len // 2, ds_len - 1]

    elif mode == "full":
        # 50 indices spread across the dataset
        n = min(50, ds_len)
        step = max(1, ds_len // n)
        return list(range(0, ds_len, step))[:n]

    else:
        raise ValueError(f"Unknown mode: {mode}")


def main():
    parser = argparse.ArgumentParser(
        description="Compare ported vs reference data loader outputs."
    )
    parser.add_argument(
        "mode",
        choices=["smoketest", "full"],
        help="smoketest: 3 indices, fast. full: 50 indices.",
    )
    parser.add_argument(
        "--indices",
        type=int,
        nargs="*",
        default=None,
        help="Override indices to compare.",
    )
    parser.add_argument(
        "--split", default="train", help="Dataset split (default: train)."
    )
    parser.add_argument(
        "--sensors",
        nargs="*",
        default=None,
        help="Sensor list (default: atms mhs amsua amsub).",
    )
    parser.add_argument(
        "--transform",
        action="store_true",
        help="Also compare transformed (__getitems__) output.",
    )
    parser.add_argument(
        "--no-raw",
        action="store_true",
        help="Skip raw (get) comparison, only do transform.",
    )
    args = parser.parse_args()

    print("=" * 70)
    print(f"Loader comparison — mode={args.mode}, split={args.split}")
    print("=" * 70)

    # Build datasets
    print("\nBuilding ported dataset...")
    t0 = time.time()
    ported_ds = build_ported_dataset(split=args.split, sensors=args.sensors)
    print(f"  Done in {time.time() - t0:.1f}s  (len={len(ported_ds)})")

    print("Building reference dataset...")
    t0 = time.time()
    ref_ds = build_reference_dataset(split=args.split, sensors=args.sensors)
    print(f"  Done in {time.time() - t0:.1f}s  (len={len(ref_ds)})")

    # Verify lengths match
    if len(ported_ds) != len(ref_ds):
        print(
            f"\nWARNING: Dataset lengths differ! "
            f"ported={len(ported_ds)} vs ref={len(ref_ds)}"
        )

    indices = get_indices(
        args.mode, min(len(ported_ds), len(ref_ds)), args.indices
    )
    print(f"\nComparing {len(indices)} samples: {indices[:10]}{'...' if len(indices) > 10 else ''}")

    # --- Raw comparison ---
    if not args.no_raw:
        print(f"\n--- Raw comparison (get) ---")
        raw_results = []
        for idx in indices:
            r = compare_single_sample(ported_ds, ref_ds, idx)
            raw_results.append(r)

        n_pass = sum(1 for r in raw_results if r["pass"])
        n_fail = len(raw_results) - n_pass
        print(f"\nRaw: {n_pass}/{len(raw_results)} passed, {n_fail} failed")

    # --- Transform comparison ---
    if args.transform:
        print(f"\n--- Transform comparison (__getitems__) ---")
        xform_results = []
        for idx in indices:
            r = compare_transformed_sample(ported_ds, ref_ds, idx)
            xform_results.append(r)

        n_pass = sum(1 for r in xform_results if r["pass"])
        n_fail = len(xform_results) - n_pass
        print(f"\nTransform: {n_pass}/{len(xform_results)} passed, {n_fail} failed")

    # --- Summary ---
    print("\n" + "=" * 70)
    all_results = []
    if not args.no_raw:
        all_results.extend(raw_results)
    if args.transform:
        all_results.extend(xform_results)

    n_total = len(all_results)
    n_pass = sum(1 for r in all_results if r["pass"])
    if n_total == n_pass:
        print(f"ALL {n_total} CHECKS PASSED")
    else:
        print(f"{n_total - n_pass}/{n_total} CHECKS FAILED")
        sys.exit(1)


if __name__ == "__main__":
    main()


@aayushg55 aayushg55 left a comment


Took a final pass. Looks good, I should be able to extend the obs transform easily. Just need to update the requirements and rename the platform field.

@pzharrington pzharrington marked this pull request as ready for review April 17, 2026 00:12

greptile-apps Bot commented Apr 17, 2026

Greptile Summary

This PR introduces the physicsnemo.experimental.datapipes.healda package — a composable ERA5 + satellite/conventional observation data pipeline for HealDA training, including async zarr/parquet loaders, two-stage CPU/GPU transforms, a stateful distributed sampler, and CUDA-stream prefetching.

  • P1 — split_array_contiguous crashes on single-element arrays (indexing.py:44): d = x[1] - x[0] raises IndexError whenever the time array has exactly one element; only size == 0 is guarded.
  • P1 — Multi-window row-group data loss in _iterate_parquet_da_windows (ufs_obs.py:207–229): when a parquet row group spans multiple DA windows, this_window is overwritten for each match and only the last one is yielded; observations from earlier windows end up stored under the wrong time key or silently dropped.
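The first P1 above can be illustrated with a toy version of the function. The real split_array_contiguous lives in healda/indexing.py and operates on time arrays; this plain-list sketch only shows the missing size == 1 guard and assumes the function's job is to split a sorted sequence into constant-step runs:

```python
# Hedged sketch of the guard fix for the single-element IndexError.
# Toy list-based version; the real implementation works on numpy arrays.
def split_array_contiguous(x):
    """Split a sorted sequence into runs with a constant step."""
    if len(x) == 0:
        return []
    if len(x) == 1:        # previously unguarded: x[1] raised IndexError
        return [list(x)]
    d = x[1] - x[0]        # step assumed constant within a contiguous run
    runs, current = [], [x[0]]
    for prev, cur in zip(x, x[1:]):
        if cur - prev == d:
            current.append(cur)
        else:              # step break: close the current run
            runs.append(current)
            current = [cur]
    runs.append(current)
    return runs


print(split_array_contiguous([5]))              # → [[5]]
print(split_array_contiguous([1, 2, 3, 7, 8]))  # → [[1, 2, 3], [7, 8]]
```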

Important Files Changed

Filename Overview
physicsnemo/experimental/datapipes/healda/indexing.py New temporal indexing module. Two bugs: split_array_contiguous crashes with IndexError on single-element arrays; _map_logical_to_physical uses total_samples instead of valid_length for the bounds check.
physicsnemo/experimental/datapipes/healda/loaders/ufs_obs.py New observation loader. Three issues: (1) _iterate_parquet_da_windows silently drops/misattributes observations when a parquet row group spans multiple DA windows; (2) channel_table local_channel_id computation uses Python wrap-around at i=0; (3) fixed_range normalization hardcodes [0, 400] instead of per-channel min_valid/max_valid.
physicsnemo/experimental/datapipes/healda/samplers.py New stateful distributed sampler. Rank-specific RNG seeds produce independent (not partitioned) permutations per rank, so uniform dataset coverage per epoch is not guaranteed; this is intentional but undocumented.
physicsnemo/experimental/datapipes/healda/dataset.py New map-style dataset combining ERA5 state with observations; clean implementation with good docstrings and test coverage.
physicsnemo/experimental/datapipes/healda/prefetch.py New background CUDA prefetch iterator; well-structured with proper error propagation. Minor: _stop() may leave the worker thread stuck on queue.put() when the queue is full and the consumer is gone, but daemon=True limits the blast radius.
physicsnemo/experimental/datapipes/healda/protocols.py Defines ObsLoader, Transform, and DeviceTransform protocols; well-documented and correctly uses runtime_checkable.
physicsnemo/experimental/datapipes/healda/loaders/era5.py ERA5 zarr loader with normalization stats; adds renamed keys to data without removing originals (harmless since _collect_fields uses index-based lookup), but slightly wasteful.
physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py Two-stage CPU+GPU transform; correctly separates DataLoader-worker CPU processing from CUDA-stream GPU featurization.
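
The sampler subtlety flagged for samplers.py can be shown with a small stdlib-only sketch. Nothing here is from RestartableDistributedSampler itself; it just contrasts the two seeding strategies the review describes, with toy sizes:

```python
# Hedged sketch: independent per-rank RNG seeds vs. a partitioned shuffle.
import random

n, world_size = 8, 2

# Independent: each rank seeds its own RNG and samples from the FULL index
# range, so ranks can draw overlapping indices and per-epoch coverage of
# the dataset is not guaranteed.
independent = [
    random.Random(1234 + rank).sample(range(n), n // world_size)
    for rank in range(world_size)
]
print(set(independent[0]) & set(independent[1]))  # may be non-empty

# Partitioned: one shared-seed permutation, strided split by rank, so the
# union of all ranks covers every index exactly once per epoch.
perm = random.Random(1234).sample(range(n), n)
partitioned = [perm[rank::world_size] for rank in range(world_size)]
assert sorted(partitioned[0] + partitioned[1]) == list(range(n))
```

Both strategies are valid designs; the review's point is only that the independent-seed behavior should be documented since it differs from PyTorch's DistributedSampler convention.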

Last reviewed commit: "Revert precommit change"

@pzharrington
Collaborator Author

/blossom-ci

Collaborator

@ktangsali ktangsali left a comment


Changes to the TOML file and the Codeowners file look good.

Collaborator

@coreyjadams coreyjadams left a comment


I have two main concerns with this PR.

First, the updates to pyproject.toml are model- or example-specific; they should probably target a specific example's requirements.txt rather than pyproject.toml.

Second, and this is the bigger concern: the functionality built here in the experimental datapipes replicates significant functionality built for physicsnemo datapipes in v2.0, and we should not merge it as it stands: we would effectively be supporting the same thing twice, at twice the engineering burden.

Can we see if the HealDA datapipes can align with the physicsnemo datapipes and not introduce tech debt?

@pzharrington
Collaborator Author

Addressing your comments @coreyjadams:

  1. I can migrate the deps to be in the requirements.txt of the recipe for now. I had put them in the toml to start paving the way for full integration but it's not necessary at the moment, recipe-specific is fine for now.
  2. I agree the integration into physicsnemo.datapipes should not happen until these pieces are aligned with that API (and/or extensions are made to the datapipes interface to support some of the patterns here, if need be). Until then, we deemed physicsnemo.experimental a good compromise: it gets a known working configuration in place, which will make that integration/refactoring easier.

To list concretely some of the reasons behind the departure from datapipes conventions:

  • Loading ragged obs data from PyArrow tables (messy, non-uniform in time and across sensors per batch) doesn't map cleanly to TensorDict; among other things, it requires custom indexing and collation logic
  • Some aspects of these data transforms are per-batch and CPU-only; datapipes currently assumes transforms are per-sample and GPU-only
  • Slightly different parallelization design, using the PyTorch CPU-side workers for the pyarrow.compute work rather than the ThreadPoolExecutor of datapipes.

These are not insurmountable but I think will be easier to attack once the initial pieces that are present in this PR, along with the test suite for sanity, are in place. Basically asking to not let perfection be the enemy of the good here 🙂
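The first bullet above can be made concrete with a toy collation sketch. This is not the package's collation code: plain lists stand in for PyArrow tables and tensors, and the function name is hypothetical. It only illustrates the pattern of packing ragged per-sensor observations as flattened values plus per-(sensor, sample) lengths, which a uniform TensorDict batch does not express directly:

```python
# Hedged sketch of ragged obs collation (toy data, hypothetical helper).
def collate_ragged(batch):
    """batch: list of {sensor: [obs values]} dicts, one per sample.

    Returns flattened per-sensor values plus the lengths needed to
    recover each sample's slice downstream.
    """
    sensors = sorted({s for sample in batch for s in sample})
    values = {s: [] for s in sensors}
    lengths = {s: [] for s in sensors}
    for sample in batch:
        for s in sensors:
            obs = sample.get(s, [])   # a sensor may be absent for a sample
            values[s].extend(obs)
            lengths[s].append(len(obs))
    return values, lengths


values, lengths = collate_ragged([
    {"atms": [0.1, 0.2], "mhs": [1.0]},
    {"atms": [0.3]},                    # no mhs obs in this sample
])
print(lengths)  # → {'atms': [2, 1], 'mhs': [1, 0]}
```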

Collaborator

@coreyjadams coreyjadams left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving to unblock this for release, but the plan of record was:

  • The requirements move out of pyproject.toml and into the model example requirements, so the package requirements stay untouched.
  • If possible, the normalization CSVs will move to the example so that they do not pollute the pip wheel. I acknowledge they are not overly large, but I don't think we want to set this precedent in general.
  • Before any merge of this datapipe from experimental into the core code base, it will be heavily refactored into alignment with physicsnemo datapipes to remove code duplication, etc.

@copy-pr-bot

copy-pr-bot Bot commented Apr 28, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.


@pzharrington
Collaborator Author

/ok to test 1883017

@NickGeneva
Collaborator
Approving for code owners.

@pzharrington pzharrington enabled auto-merge April 28, 2026 23:46
@pzharrington
Collaborator Author

/ok to test 0368d4d

@pzharrington
Collaborator Author

/blossom-ci

@pzharrington pzharrington added this pull request to the merge queue Apr 30, 2026
Merged via the queue into NVIDIA:main with commit 845906f Apr 30, 2026
6 checks passed
peterdsharpe added a commit to peterdsharpe/physicsnemo that referenced this pull request May 4, 2026
commit 91a942b
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 21:12:32 2026 -0400

    Adds Greptile minor fixes

commit b24f9b6
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:56:24 2026 -0400

    Back-merges dataset interrogate fix

commit 6ddfb5a
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:43:01 2026 -0400

    Removes accidentally-commited benchmarks; these will come later

commit 9fa0b5d
Merge: 3e67057 4c52a45
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:37:57 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-drivaerml-standalone

commit 3e67057
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:35:32 2026 -0400

    Partial merge from add-GLOBE-3D-BarnesHut

commit 4c52a45
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 16:38:28 2026 -0400

    Synchronizes `GLOBE` model progress for 26.05 (NVIDIA#1595)

    * Migrate cached_dataset.py

    * verified model arch new features (self_regularization_beta)

    * minor formatting syncs

    * Adds nonregression testing

    * Adds compile_logging utilities and prefetching utilities

    * Adds self to pade.py codeowners

    * Syncs AirFRANS updates

    * corrects a docstring

    * Strips out broken ram caching

    * Adds helpful error messages

    * Adds helpful error messages

    * docs

    * Refactor compile logging in training script

    - Removed the CompileDiagnosticsCollector and replaced it with a new utility function, silence_compile_logs_on_non_zero_ranks, to suppress non-error logs from torch.compile on all ranks except rank 0.
    - Updated the training script to call this new function, improving log clarity during distributed training.
    - Adjusted logging levels for the globe logger to ensure proper diagnostics are captured only during the first launch.
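The actual `silence_compile_logs_on_non_zero_ranks` utility lives in the recipe; here is a minimal stdlib-only sketch of the idea, assuming the global rank is exposed via the `RANK` environment variable (as with torchrun) — the real function's signature and logger selection may differ.

```python
import logging
import os

def silence_compile_logs_on_non_zero_ranks():
    """Sketch: raise torch.compile-related logger thresholds to ERROR on
    every rank except rank 0, so only rank 0 prints compile diagnostics.

    Assumes the rank is in the RANK env var; the recipe's utility may
    read it differently or touch a different set of loggers.
    """
    rank = int(os.environ.get("RANK", "0"))
    if rank == 0:
        return
    for name in ("torch._dynamo", "torch._inductor", "torch.fx"):
        logging.getLogger(name).setLevel(logging.ERROR)

os.environ["RANK"] = "1"  # pretend to be a non-zero rank
silence_compile_logs_on_non_zero_ranks()
```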

    * Enhance DataLoader worker configuration for distributed training

    - Updated the logic for auto-computing `num_workers` in the `AirFRANSDataSet` class to consider CPU affinity and local world size, improving efficiency in distributed environments.
    - Adjusted logging to provide detailed information about the computed `num_workers`, including CPU count and GPU visibility.
    - Modified the run script comments to reflect the new method of calculating `num_workers`, ensuring clarity on process-level parallelism.
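The `num_workers` heuristic described above could be sketched as follows (hypothetical function, not the exact AirFRANSDataSet logic): split the CPUs visible to this process evenly across the ranks on the node, respecting CPU affinity where the platform exposes it.

```python
import os

def auto_num_workers(local_world_size: int) -> int:
    """Sketch: divide the CPUs visible to this process across the GPUs
    (local ranks) on the node, leaving at least one worker per rank.

    sched_getaffinity (Linux) respects CPU pinning; fall back to
    cpu_count elsewhere. The recipe's actual heuristic may differ.
    """
    try:
        n_cpus = len(os.sched_getaffinity(0))
    except AttributeError:  # e.g. macOS has no sched_getaffinity
        n_cpus = os.cpu_count() or 1
    return max(1, n_cpus // max(1, local_world_size))
```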

    * Partial merge from add-GLOBE-3D-BarnesHut

commit 4cb586a
Merge: 645701f ed855da
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 15:23:50 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-model-progress

commit 645701f
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 15:06:55 2026 -0400

    Partial merge from add-GLOBE-3D-BarnesHut

commit 15d7913
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 13:45:06 2026 -0400

    Enhance DataLoader worker configuration for distributed training

    - Updated the logic for auto-computing `num_workers` in the `AirFRANSDataSet` class to consider CPU affinity and local world size, improving efficiency in distributed environments.
    - Adjusted logging to provide detailed information about the computed `num_workers`, including CPU count and GPU visibility.
    - Modified the run script comments to reflect the new method of calculating `num_workers`, ensuring clarity on process-level parallelism.

commit 65675a4
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri May 1 17:53:47 2026 -0400

    Refactor compile logging in training script

    - Removed the CompileDiagnosticsCollector and replaced it with a new utility function, silence_compile_logs_on_non_zero_ranks, to suppress non-error logs from torch.compile on all ranks except rank 0.
    - Updated the training script to call this new function, improving log clarity during distributed training.
    - Adjusted logging levels for the globe logger to ensure proper diagnostics are captured only during the first launch.

commit 948da86
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri May 1 15:40:31 2026 -0400

    docs

commit ed855da
Author: Charlelie Laurent <84199758+CharlelieLrt@users.noreply.github.com>
Date:   Thu Apr 30 20:44:37 2026 -0700

    Implements Predictor specialization for multi-diffusion (NVIDIA#1573)

    * Implements Predictor specialization for multi-diffusion

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Compile denoiser in multi-diffusion sampling compile tests

    Compiling the predictor instance directly was producing divergent results
    under torch 2.10 in the sample() loop (euler cases only). Follow the same
    pattern as test_samplers.py::TestSampleCompile and compile the denoiser
    closure instead — tracing through it still verifies that the predictor's
    __call__ path is compile-compatible.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Avoid fullgraph compile in multi-diffusion sampling test

    torch 2.10 Dynamo crashes with Fatal Python error: Aborted when tracing
    the nested MultiDiffusionPredictor -> MultiDiffusionModel2D call chain
    inside sample() with fullgraph=True. Allow graph breaks here; the
    predictor compile contract is still tested in isolation by
    test_multi_diffusion_predictor.py::TestCompile.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Flatten MultiDiffusionPredictor hot path for torch.compile

    Dispatch on pos_embd presence and model_kwargs is now resolved once at
    __init__ into a specialized closure, so __call__ is branch-free and the
    no-kwargs path avoids ** expansion. This keeps fullgraph=True compile
    cleanly traceable under torch 2.10 (which was hitting a Dynamo abort on
    the nested MultiDiffusionPredictor -> MultiDiffusionModel2D call chain
    when the denoiser closure was compiled in the sample() loop).

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
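The dispatch-at-`__init__` pattern described in that commit can be sketched generically (names hypothetical, not the MultiDiffusionPredictor code): resolve the optional branches once at construction into a specialized closure, so the hot `__call__` is branch-free and the no-kwargs path avoids `**` expansion.

```python
# Sketch: specialize the call path once at construction time so that
# __call__ contains no branching, which keeps fullgraph tracing simple.
class SpecializedPredictor:
    def __init__(self, model, pos_embd=None, **model_kwargs):
        if pos_embd is None and not model_kwargs:
            self._fn = model  # fast path: no wrapper, no ** expansion
        elif pos_embd is None:
            self._fn = lambda x: model(x, **model_kwargs)
        else:
            self._fn = lambda x: model(x + pos_embd, **model_kwargs)

    def __call__(self, x):
        return self._fn(x)  # branch-free hot path

p = SpecializedPredictor(lambda x: x * 2)
```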

    * Loosen torch.compile euler check in multi-diffusion sampling tests

    Reverts the two earlier CI-fix attempts (compile-denoiser switch, predictor
    hot-path flatten) since neither actually fixed the divergence. The
    underlying issue is an upstream torch>=2.10 Dynamo bug: euler + compiled
    MultiDiffusionPredictor produces numerically divergent results. Heun works,
    predictor compiles correctly in isolation. For euler we now assert only
    shape + isfinite until the upstream bug is resolved.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Force contiguous t_cur/t_next in Euler solvers

    sample() passes t.expand(B) (a stride-0 non-contiguous tensor) into
    solver.step(). HeunSolver already forces .contiguous() on both tensors to
    prevent torch.compile from specializing on the stride pattern of the first
    call and then either mis-firing guards or silently recompiling on
    subsequent calls with different underlying storage.

    EulerSolver and EDMStochasticEulerSolver had no such guard, which was a
    latent bug exposed by torch 2.10 (stricter stride tracking) in the
    multi-diffusion compiled sample loop — producing 90%+ element divergence
    vs eager on the first call and a Dynamo abort on the second call. Apply
    the same fix uniformly across all four solver steps.

    Also revert the temporary loosened euler assertion in
    test_multi_diffusion_sampling.py now that the real fix is in place.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Drop dead is_compiling guard and inherit from Predictor in MultiDiffusionPredictor

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Narrow _patching type and tighten multi-diffusion tests

    Move the _patching None check out of the is_compiling guard in
    MultiDiffusionModel2D so the type checker narrows self._patching
    to RandomPatching2D | GridPatching2D for the rest of each method,
    and route fuse/reset_patch_indices through isinstance.

    Streamline TestConstructor to only exercise the public contract
    (.fuse, .model, setter round-trip) and drop assertions on private
    caches. Compile the denoiser instead of the predictor in
    TestMultiDiffusionSampleCompile and add TestMultiDiffusionFullSamplerCompile
    mirroring test_samplers.py.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Force contiguous pos_embd before patching

    pos_embd.unsqueeze(0).expand(B, -1, -1, -1) produces a stride-0 view
    (all B copies share storage). Passing this through nn.ReflectionPad2d
    and F.unfold inside image_batching triggers a glibc heap corruption
    on torch 2.10 (CI, not locally on torch 2.8) when the first non-regression
    posembd_sin test runs. Same class of fix as the earlier euler solver.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Use functional F.pad in image_batching

    Instantiating torch.nn.ReflectionPad2d inside image_batching on every
    call creates a fresh nn.Module each time, which torch.compile / AOT
    autograd struggles to trace cleanly under fullgraph=True on torch 2.10.
    Switch to torch.nn.functional.pad which is a plain functional call and
    traces without allocating a module. Same result semantically.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Replace einops.rearrange with native torch reshape+permute

    einops.rearrange goes through a pattern-matched lowering path that
    torch.compile / inductor on torch 2.10 handles fragilely in the
    image_batching / image_fuse hot paths. The underlying transform is a
    plain view + permute + view, so express it directly: this gives inductor
    a straightforward sequence of ops to trace, and drops the einops
    dependency from this module.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
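The "plain view + permute + view" claim in that commit can be illustrated with a NumPy stand-in for the einops pattern `"b c (nh ph) (nw pw) -> (b nh nw) c ph pw"` (the `to_patches` name is invented for this sketch, not the library's function):

```python
import numpy as np

def to_patches(x, ph, pw):
    """Split an image batch into non-overlapping patches without einops:
    reshape to split H/W, transpose to reorder axes, reshape to merge."""
    b, c, H, W = x.shape
    nh, nw = H // ph, W // pw
    x = x.reshape(b, c, nh, ph, nw, pw)       # view: split H and W
    x = x.transpose(0, 2, 4, 1, 3, 5)         # permute: b nh nw c ph pw
    return x.reshape(b * nh * nw, c, ph, pw)  # view: merge batch dims

x = np.arange(2 * 3 * 4 * 4, dtype=np.float32).reshape(2, 3, 4, 4)
patches = to_patches(x, 2, 2)
```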

    * Materialise returned tensors in multi-diffusion fuse path

    Under torch.compile / inductor on torch 2.10, a compiled sample() call
    through MultiDiffusionPredictor was returning a tensor whose metadata
    was valid but whose data pointer was dangling (use-after-free) — the
    caller SIGABRTed on the first read of the tensor data. Add .contiguous()
    at the two boundaries that returned a view: image_fuse returns
    x_folded[...] / overlap_count[...], and MultiDiffusionModel2D.forward
    returns the (possibly fused) inner-model output. Forcing fresh storage
    on each boundary prevents the returned tensor from aliasing a buffer
    whose lifetime ends with the compiled frame.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Use clone instead of contiguous at fuse boundary

    The second torch.compile call of a fused MultiDiffusionPredictor was
    segfaulting (SIGSEGV) while the first succeeded. .contiguous() is a
    no-op when the tensor is already contiguous, so inductor could still
    see the returned tensor as aliasing an internal buffer across calls.
    .clone() always allocates fresh storage, so successive compiled calls
    get independent outputs. Also drop the redundant .contiguous() added
    earlier in MultiDiffusionModel2D.forward now that image_fuse owns that
    boundary.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
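The `.contiguous()` vs `.clone()` distinction above also has a NumPy analogue, sketched here: `np.ascontiguousarray` is a no-op on an already-contiguous array (the result can alias the input), while `.copy()` always allocates fresh storage.

```python
import numpy as np

a = np.ones((4, 4), dtype=np.float32)  # already contiguous

alias = np.ascontiguousarray(a)  # no copy needed: may alias the input
fresh = a.copy()                 # always fresh, independent storage

# Mirrors the fix: a contiguous-or-noop call can still alias internal
# buffers, so only a clone-style copy guarantees independent output.
```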

    * Revert speculative fuse-boundary copies and xfail full-sampler compile on torch>=2.10

    Revert commits 3dfcdb5, 746518f and a007c46 (native-torch rearrange
    in image_batching/image_fuse, .contiguous() on returned tensors, .clone()
    at fuse boundary) since they did not resolve the torch 2.10 inductor
    codegen segfault in TestMultiDiffusionFullSamplerCompile. Keep commits
    7e1db11 (pos_embd .contiguous() for the glibc heap corruption in
    posembd_sin non-regression tests) and feb0d9e (ReflectionPad2d → F.pad).

    Gate TestMultiDiffusionFullSamplerCompile with xfail(run=False) when
    torch>=2.10 so the SIGSEGV does not bring down the pytest process.
    TestMultiDiffusionSampleCompile (per-step denoiser compile) still runs.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Minor updates to predictor.py

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Drop redundant _patching_type and add test-time-only docstring warning

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    ---------

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

commit 16a336f
Author: Peter Harrington <48932392+pzharrington@users.noreply.github.com>
Date:   Thu Apr 30 00:03:24 2026 -0700

    FSDP optimizer state channels last fix (NVIDIA#1597)

    * Fix channels last FSDP optimizer state load bug

    * lint

    * Catch use_orig_params=True case

commit 845906f
Author: Peter Harrington <48932392+pzharrington@users.noreply.github.com>
Date:   Wed Apr 29 23:59:29 2026 -0700

    Add HealDA dataloader protocols and init recipe (NVIDIA#1555)

    * Add healda protocols and loaders to experimental

    * Cleanup and address imports

    * Update precommit for examples tests

    * integrate restartable sampler, other updates, migrate tests

    * move imports, cleanup

    * ruff check fix

    * skip prefetch on CPU

    * Rename to local_platform

    * Revert precommit change

    * greptile feedback

    * Migrate CSVs and deps to example

    * lockfile fix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
