From 544615e9c38221d799569a81666952f23e6285ae Mon Sep 17 00:00:00 2001
From: orionarcher
Date: Fri, 18 Apr 2025 08:20:42 -0700
Subject: [PATCH 1/5] throw error if autobatcher type is wrong

---
 torch_sim/runners.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/torch_sim/runners.py b/torch_sim/runners.py
index 148cafa1..be43147d 100644
--- a/torch_sim/runners.py
+++ b/torch_sim/runners.py
@@ -89,7 +89,7 @@ def _configure_batches_iterator(
     elif autobatcher is False:
         batches = [(state, [])]
     else:
-        raise ValueError(
+        raise TypeError(
             f"Invalid autobatcher type: {type(autobatcher).__name__}, "
             "must be bool or BinningAutoBatcher."
         )
@@ -206,7 +206,7 @@ def _configure_in_flight_autobatcher(
     if isinstance(autobatcher, InFlightAutoBatcher):
         autobatcher.return_indices = True
         autobatcher.max_attempts = max_attempts
-    else:
+    elif isinstance(autobatcher, bool):
         if autobatcher:
             memory_scales_with = model.memory_scales_with
             max_memory_scaler = None
@@ -221,6 +221,11 @@
             max_iterations=max_attempts,
             max_memory_padding=0.9,
         )
+    else:
+        raise TypeError(
+            f"Invalid autobatcher type: {type(autobatcher).__name__}, "
+            "must be bool or InFlightAutoBatcher."
+        )
     return autobatcher

From 3443ccaa0adcec3c4cc79b8fdd7b72f6b1036fc7 Mon Sep 17 00:00:00 2001
From: orionarcher
Date: Fri, 18 Apr 2025 08:30:27 -0700
Subject: [PATCH 2/5] make InFlight.load_states return max memory scaler

---
 torch_sim/autobatching.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py
index e5e75c3d..2ac4c526 100644
--- a/torch_sim/autobatching.py
+++ b/torch_sim/autobatching.py
@@ -853,6 +853,7 @@ def load_states(
         self.first_batch_returned = False
         self._first_batch = self._get_first_batch()
+        return self.max_memory_scaler

     def _get_next_states(self) -> list[SimState]:
         """Add states from the iterator until max_memory_scaler is reached.
@@ -953,7 +954,6 @@ def _get_first_batch(self) -> SimState:
             scale_factor=self.memory_scaling_factor,
         )
         self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding
-        print(f"Max metric calculated: {self.max_memory_scaler}")
         return concatenate_states([first_state, *states])

     def next_batch(

From 89962224b8a7429404ea9513730b36a7b1a147d8 Mon Sep 17 00:00:00 2001
From: orionarcher
Date: Fri, 18 Apr 2025 08:35:02 -0700
Subject: [PATCH 3/5] refactor _get_first_batch to save a memory estimation
 evaluation

---
 torch_sim/autobatching.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py
index 2ac4c526..991230a4 100644
--- a/torch_sim/autobatching.py
+++ b/torch_sim/autobatching.py
@@ -929,19 +929,17 @@ def _get_first_batch(self) -> SimState:
         self.current_idx += [0]
         self.swap_attempts.append(0)  # Initialize attempt counter for first state
         self.iterator_idx += 1
-        # self.total_metric += first_metric

         # if max_metric is not set, estimate it
         has_max_metric = bool(self.max_memory_scaler)
         if not has_max_metric:
-            self.max_memory_scaler = estimate_max_memory_scaler(
+            n_batches = determine_max_batch_size(
+                first_state,
                 self.model,
-                [first_state],
-                [first_metric],
                 max_atoms=self.max_atoms_to_try,
                 scale_factor=self.memory_scaling_factor,
             )
-            self.max_memory_scaler = self.max_memory_scaler * 0.8
+            self.max_memory_scaler = n_batches * first_metric * 0.8

         states = self._get_next_states()

From 30de238366fea877d56122c4ef74ccc85df553a6 Mon Sep 17 00:00:00 2001
From: orionarcher
Date: Fri, 18 Apr 2025 08:41:14 -0700
Subject: [PATCH 4/5] fix testing

---
 tests/test_runners.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_runners.py b/tests/test_runners.py
index 2c5e7d4f..59564351 100644
--- a/tests/test_runners.py
+++ b/tests/test_runners.py
@@ -515,10 +515,10 @@ def test_optimize_with_default_autobatcher(
     """Test optimize with autobatcher."""

     def mock_estimate(*args, **kwargs) -> float:  # noqa: ARG001
-        return 10_000.0
+        return 200

     monkeypatch.setattr(
-        "torch_sim.autobatching.estimate_max_memory_scaler", mock_estimate
+        "torch_sim.autobatching.determine_max_batch_size", mock_estimate
     )
     states = [ar_supercell_sim_state, fe_supercell_sim_state, ar_supercell_sim_state]
     triple_state = initialize_state(

From 26ac3adb07eb113d9d718ca4648abd0023291087 Mon Sep 17 00:00:00 2001
From: orionarcher
Date: Fri, 18 Apr 2025 08:49:40 -0700
Subject: [PATCH 5/5] lint

---
 tests/test_runners.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/test_runners.py b/tests/test_runners.py
index 59564351..0b440c88 100644
--- a/tests/test_runners.py
+++ b/tests/test_runners.py
@@ -517,9 +517,7 @@ def test_optimize_with_default_autobatcher(
     def mock_estimate(*args, **kwargs) -> float:  # noqa: ARG001
         return 200

-    monkeypatch.setattr(
-        "torch_sim.autobatching.determine_max_batch_size", mock_estimate
-    )
+    monkeypatch.setattr("torch_sim.autobatching.determine_max_batch_size", mock_estimate)

     states = [ar_supercell_sim_state, fe_supercell_sim_state, ar_supercell_sim_state]
     triple_state = initialize_state(
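
A minimal, self-contained sketch of the isinstance dispatch that PATCH 1/5
introduces. The configure() function and the bare-bones InFlightAutoBatcher
below are illustrative stand-ins, not torch_sim's real signatures; only the
bool-or-batcher contract and the TypeError come from the patch.

    # Stand-in for torch_sim.autobatching.InFlightAutoBatcher (hypothetical).
    class InFlightAutoBatcher:
        pass

    def configure(autobatcher):
        if isinstance(autobatcher, InFlightAutoBatcher):
            return autobatcher  # user-supplied batcher is used as-is
        if isinstance(autobatcher, bool):
            # True -> build a default batcher; False -> no autobatching
            return InFlightAutoBatcher() if autobatcher else None
        # Anything else is the wrong *type* of argument, so the patch raises
        # TypeError rather than the less precise ValueError.
        raise TypeError(
            f"Invalid autobatcher type: {type(autobatcher).__name__}, "
            "must be bool or InFlightAutoBatcher."
        )

    try:
        configure("yes")
    except TypeError as exc:
        print(exc)  # Invalid autobatcher type: str, must be bool or InFlightAutoBatcher.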
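
PATCH 3/5 replaces the estimate_max_memory_scaler call with
determine_max_batch_size plus arithmetic on the already-computed
first_metric, and PATCH 2/5 makes load_states return the resulting budget so
callers and tests can observe it. A back-of-the-envelope illustration of the
new formula, with invented numbers; only the expression
n_batches * first_metric * 0.8 comes from the patch.

    # determine_max_batch_size reports how many copies of the first state
    # fit in memory; first_metric is that state's memory metric.
    n_batches = 12        # hypothetical result of determine_max_batch_size
    first_metric = 200.0  # hypothetical metric, e.g. the state's atom count

    # Scale the per-state metric by the measured batch capacity, keeping
    # a 20% safety margin below it.
    max_memory_scaler = n_batches * first_metric * 0.8
    print(max_memory_scaler)  # 1920.0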