2 changes: 0 additions & 2 deletions README.md
@@ -125,8 +125,6 @@ To understand how TorchSim works, start with the [comprehensive tutorials](https

TorchSim's structure is summarized in the [API reference](https://radical-ai.github.io/torch-sim/reference/index.html) documentation.

-> `torch-sim` module graph. Each node represents a Python module. Arrows indicate imports between modules. Node color indicates connectedness: blue nodes have fewer dependents, red nodes have more (up to 16). The number in parentheses is the number of lines of code in the module.

## License

TorchSim is released under an [MIT license](LICENSE).
8 changes: 7 additions & 1 deletion docs/reference/index.rst
@@ -3,7 +3,7 @@
API reference
=============

-Overview of the torch_sim API.
+Overview of the TorchSim API.

.. currentmodule:: torch_sim

@@ -28,6 +28,12 @@ Overview of the torch_sim API.
    transforms
    units


+TorchSim module graph. Each node represents a Python module. Arrows indicate
+imports between modules. Node color indicates connectedness: blue nodes have fewer
+dependents, red nodes have more (up to 16). The number in parentheses is the number of
+lines of code in the module. Click on nodes to navigate to the file.

.. image:: /_static/torch-sim-module-graph.svg
    :alt: torch-sim Module Graph
    :width: 100%
10 changes: 5 additions & 5 deletions examples/scripts/4_High_level_api/4.2_auto_batching_api.py
@@ -18,8 +18,8 @@
from mace.calculators.foundations_models import mace_mp

from torch_sim.autobatching import (
-    ChunkingAutoBatcher,
-    HotSwappingAutoBatcher,
+    BinningAutoBatcher,
+    InFlightAutoBatcher,
    calculate_memory_scaler,
)
from torch_sim.integrators import nvt_langevin
@@ -65,7 +65,7 @@
# %% TODO: add max steps
converge_max_force = generate_force_convergence_fn(force_tol=1e-1)
single_system_memory = calculate_memory_scaler(fire_states[0])
-batcher = HotSwappingAutoBatcher(
+batcher = InFlightAutoBatcher(
    model=mace_model,
    memory_scales_with="n_atoms_x_density",
    max_memory_scaler=single_system_memory * 2.5 if os.getenv("CI") else None,
@@ -86,7 +86,7 @@
print("Total number of completed states", len(all_completed_states))


-# %% run chunking autobatcher
+# %% run binning autobatcher
nvt_init, nvt_update = nvt_langevin(
    model=mace_model, dt=0.001, kT=300 * MetalUnits.temperature
)
@@ -105,7 +105,7 @@


single_system_memory = calculate_memory_scaler(fire_states[0])
-batcher = ChunkingAutoBatcher(
+batcher = BinningAutoBatcher(
    model=mace_model,
    memory_scales_with="n_atoms_x_density",
    max_memory_scaler=single_system_memory * 2.5 if os.getenv("CI") else None,
@@ -68,7 +68,7 @@
    ts.io.atoms_to_state(atoms=ase_atoms_list, device=device, dtype=dtype)
)

-batcher = ts.autobatching.HotSwappingAutoBatcher(
+batcher = ts.autobatching.InFlightAutoBatcher(
    model=mace_model,
    memory_scales_with="n_atoms_x_density",
    max_memory_scaler=1000 if os.getenv("CI") else None,
28 changes: 14 additions & 14 deletions examples/tutorials/autobatching_tutorial.py
@@ -29,7 +29,7 @@
atoms exceeds available GPU memory. The `torch_sim.autobatching` module solves this by:

1. Automatically determining optimal batch sizes based on GPU memory constraints
-2. Providing two complementary strategies: chunking and hot-swapping
+2. Providing two complementary strategies: binning and in-flight
3. Efficiently managing memory resources during large-scale simulations

Let's explore how to use these powerful features!
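As a rough illustration of the bin-packing idea behind strategy 2, here is a hypothetical greedy sketch. It uses `calculate_memory_scaler` only as it is called elsewhere in this PR (one positional state argument); the real batchers are more sophisticated than this.

```python
from torch_sim.autobatching import calculate_memory_scaler

def pack_states(states, max_scaler):
    """Greedy sketch: fill a batch until the summed memory metric would
    exceed the budget, then open a new one (largest states first)."""
    batches, current, load = [], [], 0.0
    for state in sorted(states, key=calculate_memory_scaler, reverse=True):
        metric = calculate_memory_scaler(state)
        if current and load + metric > max_scaler:
            batches.append(current)
            current, load = [], 0.0
        current.append(state)
        load += metric
    return batches + ([current] if current else [])
```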
@@ -120,9 +120,9 @@ def mock_determine_max_batch_size(*args, **kwargs):
This is a verbose way to determine the max memory metric; we'll see a simpler way
shortly.

-## ChunkingAutoBatcher: Fixed Batching Strategy
+## BinningAutoBatcher: Fixed Batching Strategy

-Now on to the exciting part, autobatching! The `ChunkingAutoBatcher` groups states into
+Now on to the exciting part, autobatching! The `BinningAutoBatcher` groups states into
batches with a binpacking algorithm, ensuring that we minimize the total number of
batches while maximizing the GPU utilization of each batch. This approach is ideal for
scenarios where all states need to be processed the same number of times, such as
@@ -132,7 +132,7 @@ def mock_determine_max_batch_size(*args, **kwargs):
"""

# %% Initialize the batcher; the max memory scaler will be computed automatically
-batcher = ts.ChunkingAutoBatcher(
+batcher = ts.BinningAutoBatcher(
    model=mace_model,
    memory_scales_with="n_atoms",
)
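Once constructed, the batcher is consumed as a plain iterator. A minimal hedged sketch, assuming the `load_states` and `restore_original_order` method names suggested by the surrounding cells, and a `process_batch` helper like the one defined later in this tutorial:

```python
batcher.load_states(states)        # assumed API: queue every state up front
finished = []
for batch in batcher:              # each batch respects the memory budget
    finished.append(process_batch(batch))

# Assumed helper: map results back to the order states were submitted in.
ordered_results = batcher.restore_original_order(finished)
```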
@@ -167,11 +167,11 @@ def process_batch(batch):
maximum safe batch size through test runs on your GPU. However, the max memory scaler
is typically fixed for a given model and simulation setup. To avoid calculating it
every time, which is a bit slow, you can calculate it once and then include it in the
-`ChunkingAutoBatcher` constructor.
+`BinningAutoBatcher` constructor.
"""

# %%
-batcher = ts.ChunkingAutoBatcher(
+batcher = ts.BinningAutoBatcher(
    model=mace_model,
    memory_scales_with="n_atoms",
    max_memory_scaler=max_memory_scaler,
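One way to make that reuse concrete, under the assumption (not verified against this PR) that the batcher exposes its automatically estimated value as a `max_memory_scaler` attribute:

```python
# Pay the slow automatic estimation once...
auto_batcher = ts.BinningAutoBatcher(model=mace_model, memory_scales_with="n_atoms")
auto_batcher.load_states(states)                    # assumed API, as above
max_memory_scaler = auto_batcher.max_memory_scaler  # assumed attribute name

# ...then reuse the cached value in later constructors, as in the cell above.
```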
@@ -192,7 +192,7 @@ def process_batch(batch):
nvt_state = nvt_init(state)

# Initialize the batcher
-batcher = ts.ChunkingAutoBatcher(
+batcher = ts.BinningAutoBatcher(
    model=mace_model,
    memory_scales_with="n_atoms",
)
@@ -217,13 +217,13 @@ def process_batch(batch):

# %% [markdown]
"""
-## HotSwappingAutoBatcher: Dynamic Batching Strategy
+## InFlightAutoBatcher: Dynamic Batching Strategy

-The `HotSwappingAutoBatcher` optimizes GPU utilization by dynamically removing
+The `InFlightAutoBatcher` optimizes GPU utilization by dynamically removing
converged states and adding new ones. This is ideal for processes like geometry
optimization where different states may converge at different rates.

-The `HotSwappingAutoBatcher` is more complex than the `ChunkingAutoBatcher` because
+The `InFlightAutoBatcher` is more complex than the `BinningAutoBatcher` because
it requires the batch to be dynamically updated. The swapping logic is handled internally,
but the user must regularly provide a convergence tensor indicating which batches in
the state have converged.
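Concretely, the contract looks roughly like the loop below. The `next_batch` name, its `(state, completed_states)` return shape, and the `fire_update` / `converge_max_force` helpers are assumptions pieced together from this description and the script changes earlier in this PR, not a verified API:

```python
all_completed_states = []
state, completed = batcher.next_batch(None, None)  # assumed: first GPU fill
while state is not None:
    for _ in range(10):                 # advance the whole batch a few steps
        state = fire_update(state)
    # One boolean per sub-state: True once that system has converged.
    convergence = converge_max_force(state, None)
    state, completed = batcher.next_batch(state, convergence)
    all_completed_states.extend(completed)
```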
@@ -236,7 +236,7 @@ def process_batch(batch):
fire_state = fire_init(state)

# Initialize the batcher
-batcher = ts.HotSwappingAutoBatcher(
+batcher = ts.InFlightAutoBatcher(
    model=mace_model,
    memory_scales_with="n_atoms",
    max_memory_scaler=1000,
@@ -296,7 +296,7 @@ def process_batch(batch):
"""

# %% Initialize with return_indices=True
-batcher = ts.ChunkingAutoBatcher(
+batcher = ts.BinningAutoBatcher(
    model=mace_model,
    memory_scales_with="n_atoms",
    max_memory_scaler=80,
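A hedged sketch of consuming those indices, assuming `return_indices=True` makes iteration yield `(batch, indices)` pairs and that a batched state can be `split()` into single systems (both assumptions):

```python
results = [None] * len(states)
for batch, indices in batcher:        # indices refer to the submitted states
    batch = process_batch(batch)
    for sub_state, idx in zip(batch.split(), indices):
        results[idx] = sub_state      # results end up in submission order
```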
@@ -317,8 +317,8 @@ def process_batch(batch):
TorchSim's autobatching provides powerful tools for GPU-efficient simulation of
multiple systems:

-1. Use `ChunkingAutoBatcher` for simpler workflows with fixed iteration counts
-2. Use `HotSwappingAutoBatcher` for optimization problems with varying convergence
+1. Use `BinningAutoBatcher` for simpler workflows with fixed iteration counts
+2. Use `InFlightAutoBatcher` for optimization problems with varying convergence
rates
3. Let the library handle memory management automatically, or specify limits manually
