Fixed vectorized core
- Improved the type-conversion mechanism
- Added a test for vectorized environments
- Fixed the remaining issues in the vectorized core logic
- Fixed an issue in ListDataset
boris-il-forte committed Nov 24, 2023
1 parent 1257ad0 commit cb99e5c
Showing 5 changed files with 129 additions and 17 deletions.
mushroom_rl/core/_impl/type_conversions.py (20 additions, 5 deletions)
@@ -1,4 +1,3 @@
-import numpy
 import numpy as np
 import torch

@@ -52,6 +51,10 @@ def zeros(*dims, dtype):
     def ones(*dims, dtype):
         raise NotImplementedError

+    @staticmethod
+    def copy(array):
+        raise NotImplementedError
+

 class NumpyConversion(DataConversion):
     @staticmethod
@@ -60,7 +63,7 @@ def to_numpy(array):

     @staticmethod
     def to_torch(array):
-        return torch.from_numpy(array).to(TorchUtils.get_device())
+        return None if array is None else torch.from_numpy(array).to(TorchUtils.get_device())

     @staticmethod
     def to_backend_array(cls, array):
@@ -74,11 +77,15 @@ def zeros(*dims, dtype=float):
     def ones(*dims, dtype=float):
         return np.ones(dims, dtype=dtype)

+    @staticmethod
+    def copy(array):
+        return array.copy()
+

 class TorchConversion(DataConversion):
     @staticmethod
     def to_numpy(array):
-        return array.detach().cpu().numpy()
+        return None if array is None else array.detach().cpu().numpy()

     @staticmethod
     def to_torch(array):
@@ -96,15 +103,19 @@ def zeros(*dims, dtype=torch.float32):
     def ones(*dims, dtype=torch.float32):
         return torch.ones(*dims, dtype=dtype, device=TorchUtils.get_device())

+    @staticmethod
+    def copy(array):
+        return array.clone()
+

 class ListConversion(DataConversion):
     @staticmethod
     def to_numpy(array):
-        return numpy.array(array)
+        return np.array(array)

     @staticmethod
     def to_torch(array):
-        return torch.as_tensor(array, device=TorchUtils.get_device())
+        return None if array is None else torch.as_tensor(array, device=TorchUtils.get_device())

     @staticmethod
     def to_backend_array(cls, array):
@@ -118,6 +129,10 @@ def zeros(*dims, dtype=float):
     def ones(*dims, dtype=float):
         return np.ones(dims, dtype=float)

+    @staticmethod
+    def copy(array):
+        return array.copy()

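The new copy hook gives callers one backend-agnostic way to duplicate bookkeeping arrays: numpy and list data go through ndarray.copy(), while torch tensors need clone(), since torch.Tensor has no zero-argument copy() method. A minimal usage sketch, assuming DataConversion.get_converter accepts the backend name the way dataset.py below passes it, and that the module imports from its file path:

import numpy as np
import torch

from mushroom_rl.core._impl.type_conversions import DataConversion

# Duplicate a mask without branching on the backend at the call site.
np_conv = DataConversion.get_converter('numpy')
np_mask = np.ones(4, dtype=bool)
assert np_conv.copy(np_mask) is not np_mask            # ndarray.copy() under the hood

torch_conv = DataConversion.get_converter('torch')
torch_mask = torch.ones(4, dtype=torch.bool)
assert torch_conv.copy(torch_mask) is not torch_mask   # Tensor.clone() under the hood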
mushroom_rl/core/_impl/vectorized_core_logic.py (8 additions, 4 deletions)
@@ -33,7 +33,7 @@ def get_mask(self, last):
         new_mask[max_runs:] = False
         mask[last] = new_mask

-        self._running_envs = mask.copy()
+        self._running_envs = self._converter.copy(mask)

         return mask

@@ -46,15 +46,15 @@ def get_initial_state(self, initial_states):
         return initial_state

     def after_step(self, last):
-        n_active_envs = self._running_envs.sum()
+        n_active_envs = self._running_envs.sum().item()
         self._total_steps_counter += n_active_envs
         self._current_steps_counter += n_active_envs
         self._steps_progress_bar.update(n_active_envs)

-        completed = last.sum()
+        completed = last.sum().item()
         self._total_episodes_counter += completed
         self._current_episodes_counter += completed
-        self._episodes_progress_bar.update(last.sum())
+        self._episodes_progress_bar.update(completed)

     def after_fit(self):
         super().after_fit()
@@ -64,3 +64,7 @@ def after_fit(self):
     def _reset_counters(self):
         super()._reset_counters()
         self._running_envs = self._converter.zeros(self._n_envs, dtype=bool)
+
+    @property
+    def converter(self):
+        return self._converter
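Both fixes in this file address torch semantics that the numpy backend never surfaced: a torch mask has no copy() method (hence the converter indirection above), and sum() on a boolean tensor returns a 0-dim tensor rather than a plain int (hence the item() calls). A short illustration:

import torch

mask = torch.tensor([True, False, True])

running = mask.clone()            # torch's counterpart of ndarray.copy()
completed = mask.sum()            # tensor(2): a 0-dim tensor, not an int

# .item() unwraps the plain Python int that the step/episode counters
# and the tqdm progress bars expect; it also works on numpy scalars,
# so the same code serves both backends.
assert isinstance(completed.item(), int)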
mushroom_rl/core/dataset.py (1 addition, 1 deletion)
@@ -42,7 +42,7 @@ def __init__(self, mdp_info, agent_info, n_steps=None, n_episodes=None):
             self._data = TorchDataset(state_type, state_shape, action_type, action_shape, reward_shape,
                                       policy_state_shape)
         else:
-            self._data = ListDataset()
+            self._data = ListDataset(policy_state_shape is not None)

         self._converter = DataConversion.get_converter(mdp_info.backend)

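The ListDataset fix is a one-liner: the constructor now learns whether the agent keeps a policy state, so the list backend records (or skips) the policy_state fields consistently with TorchDataset. A hedged sketch of the intent; only the constructor call shown in the diff is from the source, the flag name here is hypothetical:

# policy_state_shape is None for stateless policies and a tuple for
# stateful (e.g. recurrent) ones; the list dataset only needs the boolean.
is_stateful = policy_state_shape is not None
self._data = ListDataset(is_stateful)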
mushroom_rl/core/vectorized_core.py (4 additions, 7 deletions)
@@ -101,16 +101,13 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa
     def _run(self, datasets, n_steps, n_episodes, render, quiet, record, initial_states=None):
         self._core_logic.initialize_run(n_steps, n_episodes, initial_states, quiet)

-
-        converter = datasets[0].converter
-
-        last = converter.ones(self.env.number, dtype=bool)
+        last = self._core_logic.converter.ones(self.env.number, dtype=bool)
         mask = None

         while self._core_logic.move_required():
             if last.any():
                 mask = self._core_logic.get_mask(last)
-                self._reset(converter, initial_states, last, mask)
+                self._reset(initial_states, last, mask)

             samples, step_infos = self._step(render, record, mask)

@@ -183,7 +180,7 @@ def _step(self, render, record, mask):

         return (state, action, rewards, next_state, absorbing, last, policy_state, policy_next_state), step_info

-    def _reset(self, converter, initial_states, last, mask):
+    def _reset(self, initial_states, last, mask):
         """
         Reset the states of the agent.

@@ -197,7 +194,7 @@ def _reset(self, initial_states, last, mask):
         self.agent.next_action = None

         if self._episode_steps is None:
-            self._episode_steps = converter.zeros(self.env.number, dtype=int)
+            self._episode_steps = self._core_logic.converter.zeros(self.env.number, dtype=int)
         else:
             self._episode_steps[last] = 0

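Routing allocation through self._core_logic.converter means the bookkeeping arrays always live on the environment backend, independent of which backend the datasets use, and _reset no longer needs a converter parameter threaded through. A small check of what those allocations produce per backend; a sketch assuming the conversion classes import from the file path in the first diff above:

import numpy as np
import torch

from mushroom_rl.core._impl.type_conversions import NumpyConversion, TorchConversion

# What _run and _reset now allocate through the core logic's converter:
# the "last" mask and the per-environment episode-step counters.
assert NumpyConversion.ones(10, dtype=bool).dtype == np.bool_
assert TorchConversion.ones(10, dtype=bool).dtype == torch.bool
assert NumpyConversion.zeros(10, dtype=int).shape == (10,)
assert TorchConversion.zeros(10, dtype=int).shape == (10,)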
tests/environments/test_vectorized_envs.py (96 additions, 0 deletions)
@@ -0,0 +1,96 @@
+import numpy as np
+import torch
+
+from mushroom_rl.core import Agent, VectorCore, VectorizedEnvironment, MDPInfo
+from mushroom_rl.rl_utils import Box
+from mushroom_rl.policy import Policy
+
+
+class DummyPolicy(Policy):
+    def __init__(self, action_shape, backend):
+        self._dim = action_shape[0]
+        self._backend = backend
+        super().__init__()
+
+    def draw_action(self, state, policy_state):
+        if self._backend == 'torch':
+            return torch.randn(state.shape[0], self._dim), None
+        elif self._backend == 'numpy':
+            return np.random.randn(state.shape[0], self._dim), None
+        else:
+            raise NotImplementedError
+
+
+class DummyAgent(Agent):
+    def __init__(self, mdp_info, backend):
+        policy = DummyPolicy(mdp_info.action_space.shape, backend)
+        super().__init__(mdp_info, policy, backend=backend)
+
+    def fit(self, dataset):
+        assert len(dataset.episodes_length) == 20
+
+
+class DummyVecEnv(VectorizedEnvironment):
+    def __init__(self, backend):
+        n_envs = 10
+        state_dim = 3
+
+        horizon = 100
+        gamma = 0.99
+
+        observation_space = Box(0, 200, shape=(3,))
+        action_space = Box(0, 200, shape=(2,))
+
+        mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, backend=backend)
+
+        if backend == 'torch':
+            self._state = torch.empty(n_envs, state_dim)
+        elif backend == 'numpy':
+            self._state = np.empty((n_envs, state_dim))
+        else:
+            raise NotImplementedError
+
+        super().__init__(mdp_info, n_envs)
+
+    def reset_all(self, env_mask, state=None):
+        self._state[env_mask] = torch.randint(size=(env_mask.sum(), self._state.shape[1]), low=2, high=200).float()
+        return self._state, [{}] * self._n_envs
+
+    def step_all(self, env_mask, action):
+        self._state[env_mask] -= 1
+
+        if self.info.backend == 'torch':
+            reward = torch.zeros(self._state.shape[0])
+        elif self.info.backend == 'numpy':
+            reward = np.zeros(self._state.shape[0])
+        else:
+            raise NotImplementedError
+
+        done = (self._state == 0).any(1)
+
+        return self._state, reward, done & env_mask, [{}] * self._n_envs
+
+
+def run_exp(env_backend, agent_backend):
+    torch.random.manual_seed(42)
+
+    env = DummyVecEnv(env_backend)
+    agent = DummyAgent(env.info, agent_backend)
+
+    core = VectorCore(agent, env)
+
+    dataset = core.evaluate(n_steps=2000)
+    assert len(dataset) == 2000
+
+    dataset = core.evaluate(n_episodes=20)
+    assert len(dataset.episodes_length) == 20
+
+    core.learn(n_steps=10000, n_episodes_per_fit=20)
+
+
+def test_vectorized_env_():
+    run_exp(env_backend='torch', agent_backend='torch')
+    run_exp(env_backend='torch', agent_backend='numpy')
+    run_exp(env_backend='numpy', agent_backend='torch')
+    run_exp(env_backend='numpy', agent_backend='numpy')