diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py
index f0da41c0..17d80567 100644
--- a/tests/test_autobatching.py
+++ b/tests/test_autobatching.py
@@ -361,12 +361,11 @@ def mock_measure(*_args: Any, **_kwargs: Any) -> float:
     )

     # Test with a small max_atoms value to limit the sequence
-    max_size = determine_max_batch_size(si_sim_state, lj_model, max_atoms=10)
-
+    max_size = determine_max_batch_size(si_sim_state, lj_model, max_atoms=16)
     # The Fibonacci sequence up to 10 is [1, 2, 3, 5, 8, 13]
-    # Since we're not triggering OOM errors with our mock, it should
-    # return the largest value < max_atoms
-    assert max_size == 8
+    # Since we're not triggering OOM errors with our mock, it should return the
+    # largest value that fits within max_atoms (simstate has 8 atoms, so 2 batches)
+    assert max_size == 2


 @pytest.mark.parametrize("scale_factor", [1.1, 1.4])
@@ -388,7 +387,9 @@ def test_determine_max_batch_size_small_scale_factor_no_infinite_loop(

     # Verify sequence is strictly increasing (prevents infinite loop)
     sizes = [1]
-    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < 20:
+    while (
+        next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)
+    ) * si_sim_state.n_atoms <= 20:
         sizes.append(next_size)

     assert all(sizes[idx] > sizes[idx - 1] for idx in range(1, len(sizes)))
diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py
index 3a45b267..b0040f95 100644
--- a/torch_sim/autobatching.py
+++ b/torch_sim/autobatching.py
@@ -287,7 +287,9 @@ def determine_max_batch_size(
     """
     # Create a geometric sequence of batch sizes
     sizes = [start_size]
-    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < max_atoms:
+    while (
+        next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)
+    ) * state.n_atoms <= max_atoms:
         sizes.append(next_size)

     for sys_idx in range(len(sizes)):
diff --git a/torch_sim/runners.py b/torch_sim/runners.py
index 220fb12a..1353432b 100644
--- a/torch_sim/runners.py
+++ b/torch_sim/runners.py
@@ -449,6 +449,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915
             init_kwargs=dict(**init_kwargs or {}),
             max_memory_scaler=autobatcher.max_memory_scaler,
             memory_scales_with=autobatcher.memory_scales_with,
+            max_atoms_to_try=autobatcher.max_atoms_to_try,
        )
        autobatcher.load_states(state)
        if trajectory_reporter is not None:
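For reference, a standalone sketch (not part of the patch) of how the revised sizing loop behaves: the candidate batch count is now multiplied by the atoms per batch before comparing against `max_atoms`, so the cap applies to the total atom count rather than the batch count itself. The `candidate_batch_sizes` helper and the `start_size=1` / `scale_factor=1.6` defaults are assumptions for illustration; the real `determine_max_batch_size` additionally probes each candidate size for OOM before returning.

```python
# Hypothetical helper mirroring the new loop in determine_max_batch_size;
# start_size=1 and scale_factor=1.6 are assumed defaults for illustration.
def candidate_batch_sizes(
    n_atoms: int, max_atoms: int, start_size: int = 1, scale_factor: float = 1.6
) -> list[int]:
    sizes = [start_size]
    # Each candidate counts batches, so multiply by atoms per batch: the cap
    # now limits the total number of atoms rather than the number of batches.
    while (
        next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)
    ) * n_atoms <= max_atoms:
        sizes.append(next_size)
    return sizes


# With the 8-atom si_sim_state and max_atoms=16, the candidates are [1, 2],
# so the mocked never-OOM test returns 2, matching the updated assertion.
print(candidate_batch_sizes(n_atoms=8, max_atoms=16))  # [1, 2]
```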