Add n_edges memory metric to auto-batching #481
Adds a third memory scaling metric, "n_edges", that computes the actual neighbor list edge count via torchsim_nl. This is more accurate than n_atoms or n_atoms_x_density for molecular systems where neighbor counts vary significantly across structures. A cutoff parameter (default 7.0 Å) is threaded through both BinningAutoBatcher and InFlightAutoBatcher into calculate_memory_scalers.
- Move torchsim_nl import from inside _n_edges_scalers to module level
- Cast bincount result to float so _n_edges_scalers returns list[float], consistent with the declared return type
- Add tests for _n_edges_scalers covering periodic, non-periodic, and batched states
|
nice work! glad to see this! when doing relaxations, are we building neighbor lists twice for each system when setting
maybe worth thinking of a clean API that allows reusing the external NL? also would be great to allow swapping |
|
Yes, we are building the neighbor lists one extra time. For relaxation/integrations I don't see this being a problem but for I agree it'd be a nice-to-have but it'd be a more significant amount of work than I wanted to include at first pass. Especially since all of the models currently handle the neighbor list differently. |
|
|
@janosh I thought the same when I saw it, but it's probably a headache at the moment because you need to make sure that the prebuilt neighbor list is passed in at the right place in each MLIP graph builder. By the way, graph building is mainly a bottleneck when inference is sequential, as in MD. For trivially parallel inference you can overlap it: build the graphs on CPU in the dataloader while the GPU is running the inference. |
would this still be faster than doing everything with, say, alchemi on GPU? sounds like a benchmark worth running. is what you describe offloading NL building to CPU ahead of time easy to achieve with
@orionarcher most models use 6 Å, maybe a safer default? by occurrence count 6 > 5 > 10 > 12 > 7 i think? Allegro is the only 7 Å model that comes to mind. even Allegro more often has 10 now AFAIK |
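To give a rough sense of how much the default matters: under the simplifying assumption of uniform atomic density (which molecular systems generally do not satisfy), the number of neighbors within a cutoff grows with the cube of the cutoff. The helper below is hypothetical and only illustrates that arithmetic; it is not part of the PR.

```python
def edge_count_ratio(cutoff_a: float, cutoff_b: float) -> float:
    """Ratio of expected neighbor counts at two cutoffs.

    Assumes uniform density, so the count scales with the sphere volume (r^3).
    """
    return (cutoff_a / cutoff_b) ** 3


# A 7.0 Å cutoff versus a 6.0 Å cutoff: (7/6)^3 = 343/216, roughly 1.59
ratio = edge_count_ratio(7.0, 6.0)
```

Under that assumption, an n_edges estimate made with 7 Å would be about 1.6 times larger than one made with 6 Å for the same structure.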
|
@janosh Good remark. It's not clear what the best approach is; profiling is needed.
Draft for the implementation:
The idea is that you want to be GPU-bound, so running neighbor list construction on the GPU costs you time; instead you try to do as much as possible on the CPU while the GPU is busy. The headache is that graph construction lives in the forward function of the TorchSim models, which makes sense for MD. If you want to apply the workflows above, you need to take it out of forward. You then need to distinguish between two forward functions, one taking a ts.State and the other taking the model's normal input. This approach is only valuable when you know the structures ahead of time, yes, as in the workflows you mentioned |
Wait how is this possible? Doesn't the MLIP inference require an already constructed neighbor list?
This worries me most. I'm not sure what the limit is in practice, but it's not ideal to have a maximum size.
I can update it in the next PR I put up, or we can sneak it in elsewhere. |
|
You can construct the graphs of a set of systems while the GPU is running on another set of systems |
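A minimal sketch of that overlap, using only the standard library. build_graph and run_model are hypothetical stand-ins for CPU-side neighbor list construction and GPU-side MLIP inference; neither is a TorchSim function. A background thread prepares the graph for the next set of systems while the current set is being evaluated.

```python
from concurrent.futures import ThreadPoolExecutor


def build_graph(batch):
    """Hypothetical CPU-side step standing in for neighbor list construction."""
    return {"batch": batch, "n_edges": len(batch) * 2}


def run_model(graph):
    """Hypothetical accelerator-side step standing in for MLIP inference."""
    return sum(graph["batch"]) + graph["n_edges"]


def pipelined_inference(batches):
    """Run the model on batch k while the graph for batch k+1 is being built."""
    results = []
    if not batches:
        return results
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending = pool.submit(build_graph, batches[0])
        for next_batch in batches[1:]:
            graph = pending.result()  # wait for the current graph
            pending = pool.submit(build_graph, next_batch)  # start building the next one
            results.append(run_model(graph))  # runs while the next graph is built
        results.append(run_model(pending.result()))
    return results


outputs = pipelined_inference([[1, 2], [3], [4, 5, 6]])
```

In a PyTorch setting the same pattern is usually obtained from torch.utils.data.DataLoader with worker processes. As noted above, it only helps when the structures are known ahead of time, not for sequential workloads such as MD.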
Summary
- Adds "n_edges" as a third option for memory_scales_with in both BinningAutoBatcher and InFlightAutoBatcher
- Runs torchsim_nl on the full batched state and counts edges per system via system_mapping.bincount(), giving the actual neighbor list size rather than a density-based approximation
- A cutoff parameter (default 7.0 Å) is added to both batcher __init__ methods and threaded through to calculate_memory_scalers; it has no effect when using "n_atoms" or "n_atoms_x_density"
- The n_edges logic is extracted into a private _n_edges_scalers helper to keep calculate_memory_scalers within ruff's complexity limit
Motivation
Neither n_atoms nor n_atoms_x_density correlates well with memory usage for molecular systems, where neighbor counts vary significantly across structures regardless of density. n_edges is the ground-truth metric: it reflects exactly what the model's neighbor list will contain. Identified in #412 and #416.
While perhaps n_edges should be the default in the future, for now I am just adding it as a third option.
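The per-system edge count described in the summary can be illustrated with a small pure-Python sketch. It is not the PR's implementation: a brute-force, non-periodic pair search stands in for torchsim_nl, collections.Counter plays the role of bincount, and the function name and the choice to count directed pairs are assumptions made for illustration.

```python
import math
from collections import Counter


def n_edges_per_system(positions, system_idx, n_systems, cutoff=7.0):
    """Count directed neighbor pairs within `cutoff` for each system.

    Brute-force O(N^2) search without periodic boundaries; a stand-in for a
    real neighbor list routine, used only to illustrate the metric.
    """
    # One entry per edge, holding the index of the system the edge belongs to.
    system_mapping = [
        system_idx[i]
        for i in range(len(positions))
        for j in range(len(positions))
        if i != j
        and system_idx[i] == system_idx[j]
        and math.dist(positions[i], positions[j]) < cutoff
    ]
    counts = Counter(system_mapping)  # plays the role of bincount
    return [float(counts.get(s, 0)) for s in range(n_systems)]


# Two systems with 4 atoms each: a compact square and a stretched chain.
square = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (1.0, 1.0, 0.0)]
chain = [(50.0 + 5.0 * k, 0.0, 0.0) for k in range(4)]
scalers = n_edges_per_system(
    square + chain, [0] * 4 + [1] * 4, n_systems=2, cutoff=6.0
)
```

Both systems have the same n_atoms, yet the square yields 12 edges and the chain 6, which is the kind of difference an atom-count metric cannot capture.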