diff --git a/gsp_rl/src/actors/actor.py b/gsp_rl/src/actors/actor.py
index 4506446..947e5d4 100644
--- a/gsp_rl/src/actors/actor.py
+++ b/gsp_rl/src/actors/actor.py
@@ -368,6 +368,39 @@ def choose_action(self, observation, networks, test=False):
         else:
             raise Exception('[ERROR]: Learning scheme not recognised for action selection ' + networks['learning_scheme'])
 
+    def choose_actions_batch(self, observations, networks, test=False):
+        """Batched action selection for multiple observations in one forward pass.
+
+        Only supports stateless action networks (DQN, DDQN, DDPG, TD3).
+        Does NOT support RDDPG or attention — those have state/memory concerns.
+
+        Args:
+            observations: list of observation arrays, one per agent.
+            networks: network dict (self.networks).
+            test: if True, greedy (no exploration noise/epsilon).
+
+        Returns:
+            list of actions, one per observation.
+        """
+        if networks['learning_scheme'] in {'DQN', 'DDQN'}:
+            if test or np.random.random() > self.epsilon:
+                return self.DQN_DDQN_choose_action_batch(observations, networks)
+            else:
+                return [np.random.choice(self.action_space) for _ in observations]
+        elif networks['learning_scheme'] == 'DDPG':
+            actions = self.DDPG_choose_action_batch(observations, networks)
+            if not test:
+                actions = actions + T.normal(0.0, self.noise,
+                        size=(len(observations), networks['output_size'])).to(networks['actor'].device)
+                actions = T.clamp(actions, -self.min_max_action, self.min_max_action)
+            return actions.cpu().detach().numpy()
+        elif networks['learning_scheme'] == 'TD3':
+            return self.TD3_choose_action_batch(observations, networks, self.output_size)
+        else:
+            raise NotImplementedError(
+                f"Batched action selection not supported for {networks['learning_scheme']}. "
+                f"Use choose_action() for RDDPG/attention networks.")
+
     def learn(self):
         # TODO Not sure why we have n_agents*batch_size + batch_size
         if self.networks['replay'].mem_ctr < self.batch_size: # (self.n_agents*self.batch_size + self.batch_size):
diff --git a/gsp_rl/src/actors/learning_aids.py b/gsp_rl/src/actors/learning_aids.py
index cfee864..aad24e2 100644
--- a/gsp_rl/src/actors/learning_aids.py
+++ b/gsp_rl/src/actors/learning_aids.py
@@ -162,6 +162,12 @@ def DQN_DDQN_choose_action(self, observation, networks):
         state = T.tensor(observation, dtype = T.float).to(networks['q_eval'].device)
         action_values = networks['q_eval'].forward(state)
         return T.argmax(action_values).item()
+
+    def DQN_DDQN_choose_action_batch(self, observations, networks):
+        """Batched action selection for DQN/DDQN. Returns list of action indices."""
+        states = T.tensor(np.array(observations), dtype=T.float).to(networks['q_eval'].device)
+        action_values = networks['q_eval'].forward(states)
+        return T.argmax(action_values, dim=1).cpu().tolist()
 
     def DDPG_choose_action(self, observation, networks):
         if networks['learning_scheme'] == 'RDDPG':
@@ -172,6 +178,30 @@ def DDPG_choose_action(self, observation, networks):
         return networks['actor'].forward(state).unsqueeze(0)
 
+    def DDPG_choose_action_batch(self, observations, networks):
+        """Batched action selection for DDPG. Returns (batch, output_size) numpy array."""
+        if networks['learning_scheme'] == 'RDDPG':
+            # RDDPG uses sequences — cannot batch across robots (stateful LSTM)
+            raise NotImplementedError("RDDPG cannot be batched — use sequential choose_action")
+        states = T.tensor(np.array(observations), dtype=T.float).to(networks['actor'].device)
+        return networks['actor'].forward(states)
+
+    def TD3_choose_action_batch(self, observations, networks, n_actions):
+        """Batched action selection for TD3.
+        Returns a (batch, n_actions) numpy array."""
+        if self.time_step < self.warmup:
+            batch_size = len(observations)
+            mus = T.tensor(np.random.normal(scale=self.noise, size=(batch_size, n_actions)),
+                           dtype=T.float).to(networks['actor'].device)
+        else:
+            states = T.tensor(np.array(observations), dtype=T.float).to(networks['actor'].device)
+            mus = networks['actor'].forward(states).to(networks['actor'].device)
+        noise = T.tensor(np.random.normal(scale=self.noise, size=mus.shape),
+                         dtype=T.float).to(networks['actor'].device)
+        mus_prime = T.clamp(mus + noise, -networks['actor'].min_max_action,
+                            networks['actor'].min_max_action)
+        self.time_step += 1
+        return mus_prime.cpu().detach().numpy()
 
     def TD3_choose_action(self, observation, networks, n_actions):
         if self.time_step < self.warmup:
             mu = T.tensor(np.random.normal(scale = self.noise,
diff --git a/tests/test_actor/test_batch_choose_action.py b/tests/test_actor/test_batch_choose_action.py
new file mode 100644
index 0000000..504f701
--- /dev/null
+++ b/tests/test_actor/test_batch_choose_action.py
@@ -0,0 +1,125 @@
+"""Tests for batched action selection — verify it produces same results as sequential."""
+
+import numpy as np
+import torch as T
+import pytest
+
+from gsp_rl.src.actors.actor import Actor
+
+
+@pytest.fixture
+def config():
+    return {
+        "GAMMA": 0.99, "TAU": 0.005, "ALPHA": 0.001, "BETA": 0.001,
+        "LR": 0.001, "EPSILON": 0.0, "EPS_MIN": 0.0, "EPS_DEC": 0.0,
+        "BATCH_SIZE": 8, "MEM_SIZE": 100, "REPLACE_TARGET_COUNTER": 10,
+        "NOISE": 0.0, "UPDATE_ACTOR_ITER": 1, "WARMUP": 0,
+        "GSP_LEARNING_FREQUENCY": 100, "GSP_BATCH_SIZE": 8,
+    }
+
+
+class TestDQNBatch:
+    def test_batch_matches_sequential(self, config):
+        """Batched DQN should return identical actions to sequential calls."""
+        config["EPSILON"] = 0.0  # greedy — no randomness
+        actor = Actor(id=1, config=config, network="DQN",
+                      input_size=8, output_size=4, min_max_action=1, meta_param_size=1)
+        observations = [np.random.randn(8).astype(np.float32)
+                        for _ in range(4)]
+
+        sequential = [actor.choose_action(obs, actor.networks, test=True) for obs in observations]
+        batched = actor.choose_actions_batch(observations, actor.networks, test=True)
+        assert sequential == batched
+
+    def test_batch_returns_correct_count(self, config):
+        actor = Actor(id=1, config=config, network="DQN",
+                      input_size=8, output_size=4, min_max_action=1, meta_param_size=1)
+        observations = [np.random.randn(8).astype(np.float32) for _ in range(6)]
+        actions = actor.choose_actions_batch(observations, actor.networks, test=True)
+        assert len(actions) == 6
+
+    def test_batch_actions_in_range(self, config):
+        actor = Actor(id=1, config=config, network="DQN",
+                      input_size=8, output_size=9, min_max_action=1, meta_param_size=1)
+        observations = [np.random.randn(8).astype(np.float32) for _ in range(4)]
+        actions = actor.choose_actions_batch(observations, actor.networks, test=True)
+        for a in actions:
+            assert 0 <= a < 9
+
+
+class TestDDQNBatch:
+    def test_batch_matches_sequential(self, config):
+        config["EPSILON"] = 0.0
+        actor = Actor(id=1, config=config, network="DDQN",
+                      input_size=8, output_size=4, min_max_action=1, meta_param_size=1)
+        observations = [np.random.randn(8).astype(np.float32) for _ in range(4)]
+        sequential = [actor.choose_action(obs, actor.networks, test=True) for obs in observations]
+        batched = actor.choose_actions_batch(observations, actor.networks, test=True)
+        assert sequential == batched
+
+
+class TestDDPGBatch:
+    def test_batch_matches_sequential(self, config):
+        config["NOISE"] = 0.0
+        actor = Actor(id=1, config=config, network="DDPG",
+                      input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1)
+        observations = [np.random.randn(8).astype(np.float32) for _ in range(4)]
+
+        sequential = np.array([actor.choose_action(obs, actor.networks, test=True) for obs in observations])
+        batched = actor.choose_actions_batch(observations, actor.networks, test=True)
+
+        np.testing.assert_allclose(sequential, batched,
+                                   atol=1e-5)
+
+    def test_batch_output_shape(self, config):
+        actor = Actor(id=1, config=config, network="DDPG",
+                      input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1)
+        observations = [np.random.randn(8).astype(np.float32) for _ in range(4)]
+        actions = actor.choose_actions_batch(observations, actor.networks, test=True)
+        assert actions.shape == (4, 2)
+
+
+class TestTD3Batch:
+    def test_batch_matches_sequential_after_warmup(self, config):
+        config["NOISE"] = 0.0
+        config["WARMUP"] = 0
+        actor = Actor(id=1, config=config, network="TD3",
+                      input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1)
+        # Advance past warmup
+        actor.time_step = 100
+        observations = [np.random.randn(8).astype(np.float32) for _ in range(4)]
+
+        # Sequential calls each increment time_step, so save/restore
+        saved_ts = actor.time_step
+        sequential = []
+        for obs in observations:
+            actor.time_step = saved_ts  # reset to same point
+            sequential.append(actor.choose_action(obs, actor.networks, test=True))
+        sequential = np.array(sequential)
+
+        actor.time_step = saved_ts
+        batched = actor.choose_actions_batch(observations, actor.networks, test=True)
+
+        np.testing.assert_allclose(sequential, batched, atol=1e-5)
+
+    def test_batch_output_shape(self, config):
+        config["WARMUP"] = 0
+        actor = Actor(id=1, config=config, network="TD3",
+                      input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1)
+        actor.time_step = 100
+        observations = [np.random.randn(8).astype(np.float32) for _ in range(4)]
+        actions = actor.choose_actions_batch(observations, actor.networks, test=True)
+        assert actions.shape == (4, 2)
+
+
+class TestBatchNotSupported:
+    def test_rddpg_raises(self, config):
+        """RDDPG should explicitly refuse batching."""
+        actor = Actor(id=1, config=config, network="DDPG",
+                      input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1,
+                      gsp=True, recurrent_gsp=True, gsp_input_size=6, gsp_output_size=1)
+        observations = [np.random.randn(8).astype(np.float32)
+                        for _ in range(4)]
+        # The main networks are DDPG (batchable), but gsp_networks are RDDPG
+        with pytest.raises(NotImplementedError):
+            actor.choose_actions_batch(observations, actor.gsp_networks, test=True)