
Add n_edges memory metric to auto-batching #481

Merged
CompRhys merged 4 commits into main from feat/n-edges-memory-metric on Mar 2, 2026

Conversation

Collaborator

@orionarcher orionarcher commented Mar 1, 2026

Summary

  • Adds "n_edges" as a third option for memory_scales_with in both BinningAutoBatcher and InFlightAutoBatcher
  • The metric calls torchsim_nl on the full batched state and counts edges per system via system_mapping.bincount() — giving the actual neighbor list size rather than a density-based approximation
  • A cutoff parameter (default 7.0 Å) is added to both batcher __init__ methods and threaded through to calculate_memory_scalers; it has no effect when using "n_atoms" or "n_atoms_x_density"
  • The n_edges logic is extracted into a private _n_edges_scalers helper to keep calculate_memory_scalers within ruff's complexity limit
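The per-system edge count via `system_mapping.bincount()` can be sketched with NumPy (torch's `Tensor.bincount` behaves analogously); the edge list and mapping below are illustrative stand-ins for what `torchsim_nl` and the batched state would actually provide:

```python
import numpy as np

# Illustrative neighbor list for a batched state of 3 systems:
# the source atom index of each edge, and a mapping atom index -> system index.
edge_src = np.array([0, 0, 1, 2, 3, 3, 3, 5])   # source atom of each edge
system_mapping = np.array([0, 0, 0, 1, 1, 2])   # atom i belongs to system system_mapping[i]

# Map each edge to its system, then count edges per system with bincount.
edges_per_system = np.bincount(system_mapping[edge_src],
                               minlength=system_mapping.max() + 1)

# Cast to float so the scalers match the declared list[float] return type.
scalers = edges_per_system.astype(float).tolist()
print(scalers)  # [4.0, 3.0, 1.0]
```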

Motivation

Neither n_atoms nor n_atoms_x_density correlates well with memory usage for molecular systems, where neighbor counts vary significantly across structures regardless of density. n_edges is the ground-truth metric: it reflects exactly what the model's neighbor list will contain. Identified in #412 and #416.

While perhaps n_edges should be the default in the future, for now I am just adding it as a third option.

Adds a third memory scaling metric, "n_edges", that computes the actual
neighbor list edge count via torchsim_nl. This is more accurate than
n_atoms or n_atoms_x_density for molecular systems where neighbor counts
vary significantly across structures.

A cutoff parameter (default 7.0 Å) is threaded through both
BinningAutoBatcher and InFlightAutoBatcher into calculate_memory_scalers.
- Move torchsim_nl import from inside _n_edges_scalers to module level
- Cast bincount result to float so _n_edges_scalers returns list[float]
  consistently with the declared return type
- Add tests for _n_edges_scalers covering periodic, non-periodic, and
  batched states
@orionarcher orionarcher marked this pull request as ready for review March 1, 2026 18:25
@orionarcher orionarcher requested a review from CompRhys March 1, 2026 18:27
Member

@CompRhys CompRhys left a comment


lgtm

@CompRhys CompRhys merged commit 87ff47e into main Mar 2, 2026
68 checks passed
@CompRhys CompRhys deleted the feat/n-edges-memory-metric branch March 2, 2026 11:54
Collaborator

janosh commented Mar 2, 2026

nice work! glad to see this!

when doing relaxations, are we building neighbor lists twice for each system when setting memory_scales_with="n_edges"?

  1. once for batching metric
  2. again inside the model/inference path during relax steps (model-specific NL/graph build)

maybe worth thinking of a clean API that allows reusing the external NL? it would also be great to allow swapping torchsim_nl for the alchemi NL when on GPU

Collaborator Author

orionarcher commented Mar 2, 2026

Yes, we are building the neighbor lists one extra time. For relaxations/integrations I don't see this being a problem, but for static calculations it's a legitimate overhead.

I agree it'd be a nice-to-have, but it'd be significantly more work than I wanted to include in a first pass, especially since the models all currently handle the neighbor list differently.

Member

CompRhys commented Mar 2, 2026

would be great to allow swapping torchsim_nl for alchemi NL when on GPU

torchsim_nl uses alchemi first if available then vesin then that native version (torch_nl).
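That fallback order can be expressed with a generic first-available-import pattern; the helper name below is hypothetical, and only the try/except structure mirrors what torchsim_nl does:

```python
def load_first_available(candidates):
    """Return (name, module) for the first importable candidate.

    Mirrors the preference order described above for torchsim_nl
    (alchemi -> vesin -> torch_nl); the helper itself is illustrative.
    """
    for name in candidates:
        try:
            return name, __import__(name)
        except ImportError:
            continue  # backend not installed, try the next one
    raise ImportError(f"none of {candidates} is installed")

# torchsim_nl's stated preference order:
NL_BACKENDS = ("alchemi", "vesin", "torch_nl")
```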

@thomasloux
Collaborator

@janosh I thought the same when I saw it, but it's probably a headache at the moment because you need to make sure the pre-built neighbor list is injected at the right place in each MLIP's graph builder.

By the way, the main bottleneck with graph building is when your inference is sequential, like MD, whereas for embarrassingly parallel inference you can overlap the two: build graphs on CPU in the dataloader while the GPU is running inference.

Collaborator

janosh commented Mar 3, 2026

By the way, the main bottleneck with graph building is when your inference is sequential, like MD, whereas for embarrassingly parallel inference you can overlap the two: build graphs on CPU in the dataloader while the GPU is running inference.

would this still be faster than doing everything with, say, alchemi on GPU? sounds like a benchmark worth running. is what you describe (offloading NL building to the CPU ahead of time) easy to achieve with torch-sim? relaxation is also sequential, so the main workflows that would benefit from this are elastic, frozen phonons, EOS?

A cutoff parameter (default 7.0 Å)

@orionarcher most models use 6 Å, maybe a safer default? by occurrence count 6 > 5 > 10 > 12 > 7 i think? Allegro is the only 7 Å model that comes to mind. even Allegro more often has 10 now AFAIK

Collaborator

thomasloux commented Mar 3, 2026

@janosh Good remark; it's not clear what the best approach is, a profiler is needed.
Multiple strategies:

  • Run NL on GPU
    pros: easy, and no CPU<>GPU communication
    cons: technically, with many systems you could run into VRAM issues.
  • Run NL on many CPUs (multithreading/multiprocessing)
    pros: technically it's possible to parallelize GPU MLIP inference with CPU neighbor list construction.
    cons: as you said, it's probably not super easy to write.

Draft for the implementation:

  • Have a buffer of graphs (state -> graph using the MLIP graph builder). Its size can probably be estimated by a heuristic on the number of atoms.
  • Extract the NL to evaluate the number of edges
  • Batch on the fly
  • Move to GPU and run MLIP inference
  • In the meantime, have a dataloader fetching the needed number of systems to fill up the buffer
  • Continue until the dataset is empty
    Note: this would naturally extend to huge datasets that don't even fit in RAM. That is actually something lacking in torch-sim: large-scale inference over a dataset. Currently you need to provide a state with all systems already loaded in RAM.

The idea is that you want to be GPU-bound, so running neighbor list construction on the GPU loses you time; you try to do as much as possible on the CPU while the GPU is busy.

The headache is that graph construction lives in TorchSim Models.forward, which makes sense for MD. If you want to apply the previous workflows, you need to pull it out of forward. Then you have to distinguish between two forward variants: one taking a ts.State and the other taking the model's normal input.

This approach is valuable only when you know the structures ahead of time, yes, like the workflows you mentioned.
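The buffered producer/consumer pipeline sketched above can be illustrated in plain Python; `build_graph` and `infer` below are hypothetical placeholders for an MLIP graph builder and a model forward pass, and a background thread stands in for a real dataloader:

```python
import queue
import threading

def prefetch_pipeline(systems, build_graph, infer, buffer_size=4):
    """Sketch of the buffered pipeline: a CPU worker builds graphs
    (neighbor lists) ahead of time while the consumer runs inference.
    build_graph and infer are placeholders for an MLIP graph builder
    and a model forward pass, respectively."""
    buf = queue.Queue(maxsize=buffer_size)  # bounded buffer of prebuilt graphs
    sentinel = object()

    def producer():
        for s in systems:
            buf.put(build_graph(s))  # CPU-side neighbor-list/graph construction
        buf.put(sentinel)            # signal end of dataset

    threading.Thread(target=producer, daemon=True).start()

    results = []
    while (g := buf.get()) is not sentinel:
        results.append(infer(g))     # GPU-side inference would overlap the producer
    return results

# Toy stand-ins: "graph building" doubles, "inference" adds one.
out = prefetch_pipeline(range(5), build_graph=lambda s: s * 2, infer=lambda g: g + 1)
print(out)  # [1, 3, 5, 7, 9]
```

The bounded queue is what keeps VRAM/RAM in check: the producer blocks once `buffer_size` graphs are waiting, so the buffer never grows beyond a fixed number of prebuilt graphs.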

@orionarcher
Collaborator Author

pros: technically it's possible to parallelize GPU MLIP inference and CPU Neighbor list construction.

Wait how is this possible? Doesn't the MLIP inference require an already constructed neighbor list?

cons: technically with many systems you could have VRAM issue.

This worries me most. I'm not sure what the limit is in practice, but it's not ideal that batches would have a hard maximum size.

@orionarcher most models use 6 Å, maybe a safer default? by occurrence count 6 > 5 > 10 > 12 > 7 i think? Allegro is the only 7 Å model that comes to mind. even Allegro more often has 10 now AFAIK

I can update it in the next PR I put up, or we can sneak it in elsewhere.

Collaborator

thomasloux commented Mar 3, 2026

You can construct the graphs of a set of systems while the GPU is running on another set of systems
