Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions gsp_rl/src/actors/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,39 @@ def choose_action(self, observation, networks, test=False):
else:
raise Exception('[ERROR]: Learning scheme not recognised for action selection ' + networks['learning_scheme'])

def choose_actions_batch(self, observations, networks, test=False):
    """Batched action selection for multiple observations in one forward pass.

    Only supports stateless action networks (DQN, DDQN, DDPG, TD3).
    Does NOT support RDDPG or attention — those have state/memory concerns.

    Args:
        observations: list of observation arrays, one per agent.
        networks: network dict (self.networks).
        test: if True, greedy (no exploration noise/epsilon).

    Returns:
        DQN/DDQN: list of action indices, one per observation.
        DDPG/TD3: (batch, output_size) numpy array of continuous actions.

    Raises:
        NotImplementedError: for learning schemes that cannot be batched
            (RDDPG, attention).
    """
    if networks['learning_scheme'] in {'DQN', 'DDQN'}:
        greedy = self.DQN_DDQN_choose_action_batch(observations, networks)
        if test:
            return greedy
        # Draw epsilon PER OBSERVATION, matching the per-call semantics of
        # the sequential choose_action. A single draw for the whole batch
        # would make all agents explore (or exploit) in lockstep.
        return [np.random.choice(self.action_space)
                if np.random.random() <= self.epsilon else action
                for action in greedy]
    elif networks['learning_scheme'] == 'DDPG':
        actions = self.DDPG_choose_action_batch(observations, networks)
        if not test:
            actions = actions + T.normal(
                0.0, self.noise,
                size=(len(observations), networks['output_size'])
            ).to(networks['actor'].device)
        # Clamp in both modes so raw actor outputs never exceed the bounds.
        actions = T.clamp(actions, -self.min_max_action, self.min_max_action)
        return actions.cpu().detach().numpy()
    elif networks['learning_scheme'] == 'TD3':
        # NOTE(review): `test` is not forwarded — TD3_choose_action_batch
        # always adds exploration noise and advances time_step. Confirm
        # that is intended for evaluation runs.
        return self.TD3_choose_action_batch(observations, networks, self.output_size)
    else:
        raise NotImplementedError(
            f"Batched action selection not supported for {networks['learning_scheme']}. "
            f"Use choose_action() for RDDPG/attention networks.")

def learn(self):
# TODO Not sure why we have n_agents*batch_size + batch_size
if self.networks['replay'].mem_ctr < self.batch_size: # (self.n_agents*self.batch_size + self.batch_size):
Expand Down
30 changes: 30 additions & 0 deletions gsp_rl/src/actors/learning_aids.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ def DQN_DDQN_choose_action(self, observation, networks):
state = T.tensor(observation, dtype = T.float).to(networks['q_eval'].device)
action_values = networks['q_eval'].forward(state)
return T.argmax(action_values).item()

def DQN_DDQN_choose_action_batch(self, observations, networks):
    """Greedy batched action selection for DQN/DDQN.

    Stacks the observations into a single tensor, runs one forward pass
    through the eval network, and returns the argmax action index for
    each row as a plain Python list of ints.
    """
    device = networks['q_eval'].device
    batch = T.tensor(np.array(observations), dtype=T.float).to(device)
    q_values = networks['q_eval'].forward(batch)
    return [int(a) for a in T.argmax(q_values, dim=1)]

def DDPG_choose_action(self, observation, networks):
if networks['learning_scheme'] == 'RDDPG':
Expand All @@ -172,6 +178,30 @@ def DDPG_choose_action(self, observation, networks):
return networks['actor'].forward(state).unsqueeze(0)


def DDPG_choose_action_batch(self, observations, networks):
    """Batched deterministic policy evaluation for DDPG.

    Evaluates the actor on the stacked observations in one forward pass
    and returns the resulting (batch, output_size) tensor on the actor's
    device. Noise and clamping are the caller's responsibility.

    Raises:
        NotImplementedError: for RDDPG — its recurrent actor carries
            per-agent hidden state, so observations cannot be batched
            across agents.
    """
    if networks['learning_scheme'] == 'RDDPG':
        raise NotImplementedError("RDDPG cannot be batched — use sequential choose_action")
    device = networks['actor'].device
    batch = T.tensor(np.array(observations), dtype=T.float).to(device)
    return networks['actor'].forward(batch)

def TD3_choose_action_batch(self, observations, networks, n_actions):
    """Batched action selection for TD3.

    During warmup the policy is ignored and actions are drawn from a
    zero-mean Gaussian; afterwards the actor is evaluated on the whole
    batch in one forward pass. Exploration noise is added in both cases
    and the result is clamped to the actor's action bounds.

    Note: ``self.time_step`` advances once per *batch* here, whereas the
    sequential ``TD3_choose_action`` advances it once per call — callers
    comparing the two must account for the difference.

    Args:
        observations: list of observation arrays, one per agent.
        networks: network dict containing the TD3 'actor'.
        n_actions: number of action dimensions per agent.

    Returns:
        (batch, n_actions) numpy array of clamped actions.
    """
    device = networks['actor'].device
    if self.time_step < self.warmup:
        # Warmup: pure exploration, independent of the (untrained) policy.
        batch_size = len(observations)
        mus = T.tensor(np.random.normal(scale=self.noise,
                                        size=(batch_size, n_actions)),
                       dtype=T.float).to(device)
    else:
        states = T.tensor(np.array(observations), dtype=T.float).to(device)
        # forward() output already lives on the actor's device.
        mus = networks['actor'].forward(states)
    # Exploration noise, then clamp into the valid action range.
    noise = T.tensor(np.random.normal(scale=self.noise, size=mus.shape),
                     dtype=T.float).to(device)
    mus_prime = T.clamp(mus + noise, -networks['actor'].min_max_action,
                        networks['actor'].min_max_action)
    self.time_step += 1
    return mus_prime.cpu().detach().numpy()

def TD3_choose_action(self, observation, networks, n_actions):
if self.time_step < self.warmup:
mu = T.tensor(np.random.normal(scale = self.noise,
Expand Down
125 changes: 125 additions & 0 deletions tests/test_actor/test_batch_choose_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Tests for batched action selection — verify it produces same results as sequential."""

import numpy as np
import torch as T
import pytest

from gsp_rl.src.actors.actor import Actor


@pytest.fixture
def config():
    """Baseline hyper-parameter dict shared by every Actor under test."""
    return dict(
        GAMMA=0.99, TAU=0.005, ALPHA=0.001, BETA=0.001,
        LR=0.001, EPSILON=0.0, EPS_MIN=0.0, EPS_DEC=0.0,
        BATCH_SIZE=8, MEM_SIZE=100, REPLACE_TARGET_COUNTER=10,
        NOISE=0.0, UPDATE_ACTOR_ITER=1, WARMUP=0,
        GSP_LEARNING_FREQUENCY=100, GSP_BATCH_SIZE=8,
    )


class TestDQNBatch:
    def test_batch_matches_sequential(self, config):
        """Batched DQN should return identical actions to sequential calls."""
        config["EPSILON"] = 0.0  # greedy — no randomness
        agent = Actor(id=1, config=config, network="DQN",
                      input_size=8, output_size=4, min_max_action=1, meta_param_size=1)
        obs_list = [np.random.randn(8).astype(np.float32) for _ in range(4)]

        one_by_one = []
        for obs in obs_list:
            one_by_one.append(agent.choose_action(obs, agent.networks, test=True))
        all_at_once = agent.choose_actions_batch(obs_list, agent.networks, test=True)

        assert one_by_one == all_at_once

    def test_batch_returns_correct_count(self, config):
        agent = Actor(id=1, config=config, network="DQN",
                      input_size=8, output_size=4, min_max_action=1, meta_param_size=1)
        obs_list = [np.random.randn(8).astype(np.float32) for _ in range(6)]
        result = agent.choose_actions_batch(obs_list, agent.networks, test=True)
        assert len(result) == 6

    def test_batch_actions_in_range(self, config):
        agent = Actor(id=1, config=config, network="DQN",
                      input_size=8, output_size=9, min_max_action=1, meta_param_size=1)
        obs_list = [np.random.randn(8).astype(np.float32) for _ in range(4)]
        for action in agent.choose_actions_batch(obs_list, agent.networks, test=True):
            assert action in range(9)


class TestDDQNBatch:
    def test_batch_matches_sequential(self, config):
        """Batched DDQN should agree with one-at-a-time greedy selection."""
        config["EPSILON"] = 0.0
        agent = Actor(id=1, config=config, network="DDQN",
                      input_size=8, output_size=4, min_max_action=1, meta_param_size=1)
        obs_list = [np.random.randn(8).astype(np.float32) for _ in range(4)]

        one_by_one = []
        for obs in obs_list:
            one_by_one.append(agent.choose_action(obs, agent.networks, test=True))
        all_at_once = agent.choose_actions_batch(obs_list, agent.networks, test=True)

        assert one_by_one == all_at_once


class TestDDPGBatch:
    def test_batch_matches_sequential(self, config):
        """With zero noise, batched and sequential DDPG actions coincide."""
        config["NOISE"] = 0.0
        agent = Actor(id=1, config=config, network="DDPG",
                      input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1)
        obs_list = [np.random.randn(8).astype(np.float32) for _ in range(4)]

        one_by_one = np.array(
            [agent.choose_action(obs, agent.networks, test=True) for obs in obs_list])
        all_at_once = agent.choose_actions_batch(obs_list, agent.networks, test=True)

        np.testing.assert_allclose(one_by_one, all_at_once, atol=1e-5)

    def test_batch_output_shape(self, config):
        agent = Actor(id=1, config=config, network="DDPG",
                      input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1)
        obs_list = [np.random.randn(8).astype(np.float32) for _ in range(4)]
        result = agent.choose_actions_batch(obs_list, agent.networks, test=True)
        assert result.shape == (4, 2)


class TestTD3Batch:
    def test_batch_matches_sequential_after_warmup(self, config):
        """Past warmup, with zero noise, batch and sequential TD3 agree."""
        config["NOISE"] = 0.0
        config["WARMUP"] = 0
        agent = Actor(id=1, config=config, network="TD3",
                      input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1)
        agent.time_step = 100  # well past warmup
        obs_list = [np.random.randn(8).astype(np.float32) for _ in range(4)]

        # choose_action bumps time_step, so pin it before every sequential call
        anchor = agent.time_step
        one_by_one = []
        for obs in obs_list:
            agent.time_step = anchor
            one_by_one.append(agent.choose_action(obs, agent.networks, test=True))
        one_by_one = np.array(one_by_one)

        agent.time_step = anchor
        all_at_once = agent.choose_actions_batch(obs_list, agent.networks, test=True)

        np.testing.assert_allclose(one_by_one, all_at_once, atol=1e-5)

    def test_batch_output_shape(self, config):
        config["WARMUP"] = 0
        agent = Actor(id=1, config=config, network="TD3",
                      input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1)
        agent.time_step = 100
        obs_list = [np.random.randn(8).astype(np.float32) for _ in range(4)]
        result = agent.choose_actions_batch(obs_list, agent.networks, test=True)
        assert result.shape == (4, 2)


class TestBatchNotSupported:
    def test_rddpg_raises(self, config):
        """RDDPG should explicitly refuse batching."""
        agent = Actor(id=1, config=config, network="DDPG",
                      input_size=8, output_size=2, min_max_action=1.0, meta_param_size=1,
                      gsp=True, recurrent_gsp=True, gsp_input_size=6, gsp_output_size=1)
        obs_list = [np.random.randn(8).astype(np.float32) for _ in range(4)]
        # Main networks are DDPG (batchable); the GSP networks are RDDPG
        with pytest.raises(NotImplementedError):
            agent.choose_actions_batch(obs_list, agent.gsp_networks, test=True)
Loading