diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 9cd55673..a691de36 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -952,6 +952,8 @@ def _get_first_batch(self) -> SimState: scale_factor=self.memory_scaling_factor, ) self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding + newer_states = self._get_next_states() + states = [*states, *newer_states] return concatenate_states([first_state, *states]) def next_batch( # noqa: C901