Skip to content

feat(crash): lazy Zarr loading for multi-GPU DDP training#1703

Open
Thabhelo wants to merge 3 commits into
NVIDIA:mainfrom
Thabhelo:feat/crash-zarr-lazy-ddp-1550
Open

feat(crash): lazy Zarr loading for multi-GPU DDP training#1703
Thabhelo wants to merge 3 commits into
NVIDIA:mainfrom
Thabhelo:feat/crash-zarr-lazy-ddp-1550

Conversation

@Thabhelo
Copy link
Copy Markdown

@Thabhelo Thabhelo commented Jun 6, 2026

Summary

  • Add lazy Zarr loading (lazy_load: true by default) so mesh trajectories and point features materialize on first sample access instead of at dataset construction.
  • Update the crash datapipe to stream stats and defer per-sample tensors/graphs when lazy records are returned.
  • Document DDP memory behavior (DistributedSampler shards indices, not host RAM) in the crash README and log a hint in train.py when multi-GPU + lazy Zarr are active.

closes #1550

Test plan

  • pytest examples/structural_mechanics/crash/tests/test_zarr_reader.py examples/structural_mechanics/crash/tests/test_datapipe_lazy.py
  • Eager-mode smoke test with Reader(lazy_load=False) (all samples materialized at init)
  • pre-commit on changed files

Notes

Defer Zarr mesh and point-data materialization until sample access so each
DDP rank no longer loads the full dataset at construction time. Document
DistributedSampler memory behavior and enable lazy_load by default.

closes NVIDIA#1550

Signed-off-by: Thabhelo <50872400+Thabhelo@users.noreply.github.com>
@Thabhelo Thabhelo requested a review from coreyjadams as a code owner June 6, 2026 04:11
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Jun 6, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Thabhelo
Copy link
Copy Markdown
Author

Thabhelo commented Jun 6, 2026

Follow-up: interrogate hook false positives (separate PR)

What this PR is for: lazy Zarr loading and DDP memory documentation for #1550.

While getting this through pre-commit, the interrogate check failed even though the flagged symbols (SimSample.to, CrashGraphDataset.create_graph, Trainer.train, etc.) were already listed in test/ci_tests/interrogate_baseline.txt.

Root cause: test/ci_tests/check_docstring_coverage.py parses interrogate section headers with a regex that only matches lines surrounded by 4+ equals signs (={4,}). Interrogate 1.7 (pinned in pre-commit) prints headers like:

= Coverage for /abs/path/examples/structural_mechanics/crash/ =

Because that header is not matched, current_dir stays empty and file paths are resolved incorrectly (e.g. ../../../datapipe.py instead of examples/structural_mechanics/crash/datapipe.py). Baseline entries then look like "new" undocumented items whenever those example files are edited.

Temporary workaround in this PR: added small docstrings on those symbols so they are no longer reported as MISSED and the hook passes.

Actual fix: #1704 broadens the header parser so baseline matching works correctly — this means no docstring workarounds needed going forward even for the next person who touches an example file with baseline gaps.

@Thabhelo
Copy link
Copy Markdown
Author

Thabhelo commented Jun 6, 2026

Follow-up parser fix opened in #1704.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 6, 2026

Greptile Summary

This PR adds lazy Zarr loading (lazy_load: true by default) to the crash example so that mesh trajectories and point features are materialized on first __getitem__ access instead of at dataset construction — reducing per-rank host RAM usage in multi-GPU DDP training. It also documents the DDP memory model in the README and adds a log hint in train.py.

  • zarr_reader.py gains validate_zarr_store, load_zarr_edges, and materialize_zarr_record helpers; Reader now always returns a 4-tuple and accepts a lazy_load flag.
  • datapipe.py splits the per-sample build/normalize pipeline into lazy and eager branches, with new lazy stat-computation methods (_compute_autoreg_node_stats_lazy, _compute_feature_stats_lazy, _compute_edge_stats_lazy) and on-demand materialization helpers (_ensure_sample_loaded, _ensure_graph_loaded).
  • Two new test files cover lazy reader round-trips and deferred-materialization behaviour.

Important Files Changed

Filename Overview
examples/structural_mechanics/crash/datapipe.py Adds lazy loading support with two correctness bugs: biased running-mean in _compute_feature_stats_lazy/_compute_edge_stats_lazy, and mismatched normalization state between edge stat computation (unnormalized) and actual graph construction (normalized).
examples/structural_mechanics/crash/zarr_reader.py Cleanly adds lazy-load support: new validate_zarr_store, load_zarr_edges, and materialize_zarr_record helpers; Reader gains lazy_load flag and now always returns a 4-tuple with an empty global_features list.
examples/structural_mechanics/crash/tests/test_datapipe_lazy.py New test verifies deferred materialization behavior; does not exercise multi-sample stat computation, so the biased-mean bugs are not caught.
examples/structural_mechanics/crash/tests/test_zarr_reader.py Updates create_dataset to create_array (zarr v3 API), adds lazy-load reader tests and materialize_zarr_record round-trip test; changes are clean.
examples/structural_mechanics/crash/train.py Minor: adds an info log for multi-GPU + lazy-load path and docstrings; no logic changes.
examples/structural_mechanics/crash/conf/reader/zarr.yaml Adds lazy_load: true to the Zarr reader config; straightforward change.
examples/structural_mechanics/crash/README.md Adds accurate documentation of DDP + Zarr memory behavior and lazy_load configuration; no issues.

Comments Outside Diff (1)

  1. examples/structural_mechanics/crash/datapipe.py, line 646-660 (link)

    P1 Edge stats computed on unnormalized positions, but graphs are built from normalized positions

    _compute_edge_stats_lazy calls _load_sample_tensors(i, retain=False, normalize=False), so the temporary graphs it creates have edge features (displacements and distances) derived from raw, unnormalized node positions. However, _ensure_graph_loaded calls _ensure_sample_loaded first — which stores normalized positions in self.mesh_pos_seq — then uses self.mesh_pos_seq[batch_idx][0] (the normalized position) to build the final graph. The resulting edge features are therefore on a completely different scale from the statistics used to normalize them. In the eager path, graphs are built from already-normalized positions (because _normalize_loaded_samples runs inside the base-class __init__ before graph construction), so its stats and features are consistent. The lazy path breaks that consistency and will produce garbage edge features at training time.

Reviews (1): Last reviewed commit: "feat(crash): add lazy Zarr loading for m..." | Re-trigger Greptile

Comment thread examples/structural_mechanics/crash/datapipe.py Outdated
Comment thread examples/structural_mechanics/crash/datapipe.py Outdated
Use uniform m/num_samples accumulation for lazy feature stats, compute
lazy edge stats from normalized positions to match eager graph construction,
and skip re-normalization when loading eager samples.

Signed-off-by: Thabhelo <50872400+Thabhelo@users.noreply.github.com>
@Thabhelo
Copy link
Copy Markdown
Author

Thabhelo commented Jun 6, 2026

Reviewed PR comments: triage:

  • Greptile P1 (lazy feature/edge running mean): Valid. Fixed in 0b5e96a — lazy stat helpers now use the same zeros + m/num_samples accumulation as the eager path (first sample was previously overweighted).
  • Greptile P1 (edge stats on raw vs normalized positions): Valid for graph + lazy path. Fixed in 0b5e96a_compute_edge_stats_lazy now materializes normalized positions before building temporary graphs, matching eager construction order.
  • Greptile P2 (eager double-normalize landmine): Valid defensive fix. Addressed in 0b5e96a — non-lazy _load_sample_tensors returns already-normalized tensors without re-applying normalization.

Thanks @greptile-apps for the catch.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 6, 2026

Thanks for the detailed triage and for addressing all three issues. The lazy stat accumulation, normalized-position edge stats, and eager double-normalize fixes in 0b5e96a all look correct.

One remaining issue worth flagging: the eager _compute_edge_stats still has the original first-sample overweighting bug that was fixed in the lazy path. The sentinel None pattern leaves m_0 with full weight while m_1…m_{N-1} are each divided by N:

# datapipe.py – _compute_edge_stats (eager)
edge_mean = m if edge_mean is None else edge_mean + m / self.num_samples
#           ↑ full weight for i=0; all others get 1/N weight

This should use the same zeros + m/N pattern as _compute_edge_stats_lazy:

def _compute_edge_stats(self):
    edge_mean = None
    edge_meansqr = None
    edge_dim = None
    for i in range(self.num_samples):
        x_e = self.graphs[i].edge_attr.to(torch.float32)  # [E,De]
        if edge_dim is None:
            edge_dim = x_e.shape[1]
            edge_mean = torch.zeros(edge_dim, dtype=torch.float32)
            edge_meansqr = torch.zeros(edge_dim, dtype=torch.float32)
        m = torch.mean(x_e, dim=0)
        msq = torch.mean(x_e * x_e, dim=0)
        edge_mean += m / self.num_samples
        edge_meansqr += msq / self.num_samples

The same sentinel pattern also appears in _compute_feature_stats, but that one is fine — feat_mean is initialized as zeros(fdim) before the loop and accumulates uniformly. Only _compute_edge_stats has the stale None-sentinel issue.

Tip: You can customize Greptile's behavior for this repo with .greptile/rules.md and .greptile/config.json.

Apply the same zeros + m/num_samples pattern to _compute_edge_stats so
the first graph no longer receives full weight in the running mean.

Signed-off-by: Thabhelo <50872400+Thabhelo@users.noreply.github.com>
@Thabhelo
Copy link
Copy Markdown
Author

Thabhelo commented Jun 6, 2026

Good catch, fixed in 349312a. _compute_edge_stats now uses the same zeros + m/num_samples accumulation as _compute_edge_stats_lazy.

@Thabhelo
Copy link
Copy Markdown
Author

Thabhelo commented Jun 6, 2026

Thanks for the detailed triage and for addressing all three issues. The lazy stat accumulation, normalized-position edge stats, and eager double-normalize fixes in 0b5e96a all look correct.

One remaining issue worth flagging: the eager _compute_edge_stats still has the original first-sample overweighting bug that was fixed in the lazy path. The sentinel None pattern leaves m_0 with full weight while m_1…m_{N-1} are each divided by N:

# datapipe.py – _compute_edge_stats (eager)
edge_mean = m if edge_mean is None else edge_mean + m / self.num_samples
#           ↑ full weight for i=0; all others get 1/N weight

This should use the same zeros + m/N pattern as _compute_edge_stats_lazy:

def _compute_edge_stats(self):
    edge_mean = None
    edge_meansqr = None
    edge_dim = None
    for i in range(self.num_samples):
        x_e = self.graphs[i].edge_attr.to(torch.float32)  # [E,De]
        if edge_dim is None:
            edge_dim = x_e.shape[1]
            edge_mean = torch.zeros(edge_dim, dtype=torch.float32)
            edge_meansqr = torch.zeros(edge_dim, dtype=torch.float32)
        m = torch.mean(x_e, dim=0)
        msq = torch.mean(x_e * x_e, dim=0)
        edge_mean += m / self.num_samples
        edge_meansqr += msq / self.num_samples

The same sentinel pattern also appears in _compute_feature_stats, but that one is fine — feat_mean is initialized as zeros(fdim) before the loop and accumulates uniformly. Only _compute_edge_stats has the stale None-sentinel issue.

Tip: You can customize Greptile's behavior for this repo with .greptile/rules.md and .greptile/config.json.

This is a pre-existing bug (not introduced by lazy loading), but you are right @greptile-apps that it’s still in the file we’re editing.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 6, 2026

Looking at the current PR branch, _compute_edge_stats has already been fixed — it initializes to zeros on first sample, just like the lazy path:

def _compute_edge_stats(self):
    edge_mean = None
    edge_meansqr = None
    edge_dim = None
    for i in range(self.num_samples):
        x_e = self.graphs[i].edge_attr.to(torch.float32)
        if edge_dim is None:
            edge_dim = x_e.shape[1]
            edge_mean = torch.zeros(edge_dim, dtype=torch.float32)   # ✓ zeros, not m
            edge_meansqr = torch.zeros(edge_dim, dtype=torch.float32)
        m = torch.mean(x_e, dim=0)
        msq = torch.mean(x_e * x_e, dim=0)
        edge_mean += m / self.num_samples                             # ✓ uniform weight
        edge_meansqr += msq / self.num_samples

The None-sentinel pattern (edge_mean = m if edge_mean is None else edge_mean + m / self.num_samples) is gone — this was corrected in the same commit you referenced. The eager and lazy paths are now consistent. No further changes needed here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

🚀[FEA]: Clarify DDP dataset memory behavior and recommend lazy loading for large Zarr datasets

1 participant