Add batched action selection for stateless networks#19

Merged
jdbloom merged 1 commit into main from feat/batch-action-forward
Apr 8, 2026
Conversation

jdbloom commented Apr 8, 2026

Summary

Adds a choose_actions_batch() method that processes multiple observations in a single forward pass, reducing action selection from N sequential network calls to one batched call.

Supported: DQN, DDQN, DDPG, TD3 (stateless — no memory between calls)
Not supported: RDDPG (LSTM state concerns — explicitly raises NotImplementedError)
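As an illustration of the idea, here is a minimal NumPy sketch with a toy linear Q-network; names and shapes are hypothetical, not the project's implementation:

```python
import numpy as np

def q_forward(weights: np.ndarray, obs_batch: np.ndarray) -> np.ndarray:
    """Toy linear Q-network stand-in (hypothetical): maps an (N, obs_dim)
    batch of observations to (N, n_actions) Q-values."""
    return obs_batch @ weights

def choose_actions_batch(weights: np.ndarray, observations) -> list[int]:
    """Stack N observations into one array, run a single forward pass,
    then take the greedy action per row -- 1 call instead of N."""
    batch = np.atleast_2d(np.asarray(observations, dtype=np.float64))
    q_values = q_forward(weights, batch)  # shape (N, n_actions)
    return q_values.argmax(axis=1).tolist()
```

Because the network is stateless, the batched result is identical to calling it once per observation, which is what makes this safe for DQN/DDQN/DDPG/TD3 but not for a recurrent network.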

Motivation

Profiling showed choose_action accounts for 19.8% of Python time in RL-CT (called 4× per step, once per robot). Batching reduces this to ~1 call.

Why not batch R-GSP-N/A-GSP-N: The LSTM hidden state bug (discards h_t/c_t) means these are accidentally stateless today, but we should not optimize around a bug. When the LSTM is fixed to maintain state, batching would change behavior. See LSTM hidden state bug in TODO.
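The guard described above can be sketched roughly like this (class name and message are illustrative, not the repo's actual code):

```python
class RDDPGActorSketch:
    """Hypothetical stand-in for the recurrent RDDPG actor."""

    def choose_actions_batch(self, observations):
        # Batching would interleave hidden-state updates across robots,
        # so the recurrent variant refuses rather than silently relying
        # on the accidental statelessness of the current LSTM bug.
        raise NotImplementedError(
            "RDDPG maintains LSTM state between calls; batched action "
            "selection is only supported for stateless networks."
        )
```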

Tests

9 new tests verify that batched output matches sequential output for all 4 algorithms:

  • DQN: exact match
  • DDQN: exact match
  • DDPG: allclose (float tolerance)
  • TD3: allclose after warmup
  • RDDPG: correctly raises NotImplementedError

205/205 tests pass.
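One of the equivalence checks can be sketched like this (toy tanh actor shaped like a DDPG policy; not the repo's test code):

```python
import numpy as np

def actor_forward(weights: np.ndarray, obs_batch: np.ndarray) -> np.ndarray:
    """Toy deterministic actor (hypothetical): tanh-squashed linear layer
    producing continuous actions, shaped like a DDPG actor's output."""
    return np.tanh(obs_batch @ weights)

def test_batched_matches_sequential():
    rng = np.random.default_rng(42)
    weights = rng.normal(size=(6, 2))
    observations = rng.normal(size=(4, 6))
    # One batched forward pass over all four observations...
    batched = actor_forward(weights, observations)
    # ...must agree (up to float tolerance) with four sequential passes.
    sequential = np.stack(
        [actor_forward(weights, o[None, :])[0] for o in observations]
    )
    assert np.allclose(batched, sequential)
```

For the discrete algorithms the PR asserts exact equality instead of allclose, since argmax over identical Q-values cannot drift.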

Downstream implications

  • RL-CT Agent can now call choose_actions_batch([obs_0, obs_1, obs_2, obs_3]) instead of 4× choose_agent_action(obs_i)
  • Scaling: Batching becomes more impactful with larger swarms (8, 16, 32 robots)
  • Distributed execution: once agents move to separate processes (TODO item N), each process calls choose_action individually; batching is a single-process optimization that won't apply in the fully distributed case, but it helps in the base/threaded fidelity layers
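At the call site, the change looks roughly like this (StubActor is a made-up stand-in; the real Actor also handles epsilon/noise in its dispatch):

```python
import numpy as np

class StubActor:
    """Hypothetical minimal actor showing the call-site change."""

    def __init__(self, weights: np.ndarray):
        self.w = weights

    def choose_agent_action(self, obs) -> int:
        # One forward pass for a single observation.
        return int(np.argmax(np.asarray(obs) @ self.w))

    def choose_actions_batch(self, observations) -> list[int]:
        # One forward pass for all N observations at once.
        q = np.asarray(observations) @ self.w  # shape (N, n_actions)
        return q.argmax(axis=1).tolist()

actor = StubActor(np.eye(3))
observations = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0],
                [0.0, 0.0, 3.0], [0.5, 0.4, 0.1]]
# Before: 4 sequential network calls
sequential = [actor.choose_agent_action(o) for o in observations]
# After: 1 batched call
batched = actor.choose_actions_batch(observations)
```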

🤖 Generated with Claude Code

New methods:
- DQN_DDQN_choose_action_batch: single forward pass for N observations
- DDPG_choose_action_batch: batched continuous actions
- TD3_choose_action_batch: batched with noise
- Actor.choose_actions_batch: dispatch with epsilon/noise handling

Only supports stateless networks (DQN, DDQN, DDPG, TD3).
RDDPG explicitly raises NotImplementedError — has LSTM state concerns.

9 tests verify batched output matches sequential for all algorithms.

Profiling showed choose_action is 19.8% of Python time (4 sequential calls).
Batching reduces this to ~1 call, expected ~60% reduction in that category.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@jdbloom jdbloom merged commit 92ced9c into main Apr 8, 2026
4 checks passed