Merged
tests/test_runners.py (6 changes: 2 additions & 4 deletions)

```diff
@@ -515,11 +515,9 @@ def test_optimize_with_default_autobatcher(
"""Test optimize with autobatcher."""

def mock_estimate(*args, **kwargs) -> float: # noqa: ARG001
return 10_000.0
return 200

monkeypatch.setattr(
"torch_sim.autobatching.estimate_max_memory_scaler", mock_estimate
)
monkeypatch.setattr("torch_sim.autobatching.determine_max_batch_size", mock_estimate)

states = [ar_supercell_sim_state, fe_supercell_sim_state, ar_supercell_sim_state]
triple_state = initialize_state(
Expand Down
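The stub now replaces `determine_max_batch_size`, which returns a batch-size count (hence the plain `200`), where the old `estimate_max_memory_scaler` returned a memory-scaler value (hence `10_000.0`). A sketch of how the stubbed count flows into the new estimate in `_get_first_batch`; the numbers are illustrative only, and `first_metric` is assumed to be the memory metric of the first state:

```python
# Not library code: a numeric walk-through of the new estimation path.
n_batches = 200      # what the patched determine_max_batch_size returns
first_metric = 50.0  # assumed metric of the first state (e.g. atom count)

# Per the autobatching.py diff below: count * metric, derated by 0.8.
max_memory_scaler = n_batches * first_metric * 0.8  # 8000.0
```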
torch_sim/autobatching.py (10 changes: 4 additions & 6 deletions)

```diff
@@ -853,6 +853,7 @@ def load_states(

         self.first_batch_returned = False
         self._first_batch = self._get_first_batch()
+        return self.max_memory_scaler
 
     def _get_next_states(self) -> list[SimState]:
         """Add states from the iterator until max_memory_scaler is reached.
```
```diff
@@ -928,19 +929,17 @@ def _get_first_batch(self) -> SimState:
             self.current_idx += [0]
             self.swap_attempts.append(0)  # Initialize attempt counter for first state
             self.iterator_idx += 1
-            # self.total_metric += first_metric
 
         # if max_metric is not set, estimate it
         has_max_metric = bool(self.max_memory_scaler)
         if not has_max_metric:
-            self.max_memory_scaler = estimate_max_memory_scaler(
+            n_batches = determine_max_batch_size(
+                first_state,
                 self.model,
-                [first_state],
-                [first_metric],
+                max_atoms=self.max_atoms_to_try,
                 scale_factor=self.memory_scaling_factor,
             )
-            self.max_memory_scaler = self.max_memory_scaler * 0.8
+            self.max_memory_scaler = n_batches * first_metric * 0.8
 
         states = self._get_next_states()
 
@@ -953,7 +952,6 @@
                 scale_factor=self.memory_scaling_factor,
             )
             self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding
-            print(f"Max metric calculated: {self.max_memory_scaler}")
         return concatenate_states([first_state, *states])
 
     def next_batch(
```
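With the debug `print` removed, the computed scaler is no longer echoed to stdout, but it is still observable through the new `load_states` return value. A sketch of recovering the old message with logging, continuing the `load_states` example above (`batcher` and `states` are assumptions from that sketch):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("torch_sim.autobatching")

# load_states() calls _get_first_batch(), so the estimate is already applied.
max_scaler = batcher.load_states(states)
logger.info("Max metric calculated: %s", max_scaler)
```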
torch_sim/runners.py (9 changes: 7 additions & 2 deletions)

```diff
@@ -89,7 +89,7 @@ def _configure_batches_iterator(
     elif autobatcher is False:
         batches = [(state, [])]
     else:
-        raise ValueError(
+        raise TypeError(
             f"Invalid autobatcher type: {type(autobatcher).__name__}, "
             "must be bool or BinningAutoBatcher."
         )
```
```diff
@@ -206,7 +206,7 @@ def _configure_in_flight_autobatcher(
     if isinstance(autobatcher, InFlightAutoBatcher):
         autobatcher.return_indices = True
         autobatcher.max_attempts = max_attempts
-    else:
+    elif isinstance(autobatcher, bool):
         if autobatcher:
             memory_scales_with = model.memory_scales_with
             max_memory_scaler = None
@@ -221,6 +221,11 @@
                 max_iterations=max_attempts,
                 max_memory_padding=0.9,
             )
+    else:
+        raise TypeError(
+            f"Invalid autobatcher type: {type(autobatcher).__name__}, "
+            "must be bool or InFlightAutoBatcher."
+        )
     return autobatcher
 
```
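`_configure_in_flight_autobatcher` previously funneled every non-`InFlightAutoBatcher` value through the bool branch; now only actual bools reach it, and anything else fails loudly, matching the `ValueError` to `TypeError` fix in `_configure_batches_iterator`. A minimal, self-contained sketch of the new dispatch shape; the class and function here are stand-ins, not the real torch_sim code:

```python
class InFlightAutoBatcher:
    """Stand-in for torch_sim.autobatching.InFlightAutoBatcher (sketch only)."""


def configure(autobatcher: object) -> object:
    """Sketch of the dispatch in _configure_in_flight_autobatcher."""
    if isinstance(autobatcher, InFlightAutoBatcher):
        return autobatcher  # the real code configures it in place
    if isinstance(autobatcher, bool):
        return autobatcher  # the real code builds a batcher when True
    raise TypeError(
        f"Invalid autobatcher type: {type(autobatcher).__name__}, "
        "must be bool or InFlightAutoBatcher."
    )


configure(True)  # fine
try:
    configure("yes")  # now raises instead of being treated as truthy
except TypeError as exc:
    print(exc)
```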