Commit 4380ebc

task labels from envs objects; pep8 fix
NickLucche committed Apr 21, 2022
1 parent 848c779 commit 4380ebc
Showing 2 changed files with 68 additions and 23 deletions.
avalanche/benchmarks/scenarios/rl_scenario.py (66 changes: 44 additions & 22 deletions)
@@ -9,7 +9,9 @@
 # Website: avalanche.continualai.org                                          #
 ################################################################################
 """Reinforcement Learning scenario definitions."""
-from avalanche.benchmarks.scenarios import CLExperience, ExperienceAttribute, CLScenario, EagerCLStream
+from avalanche.benchmarks.scenarios import (
+    CLExperience, ExperienceAttribute,
+    CLScenario, EagerCLStream)
 from typing import Callable, List, Union, Dict
 import numpy as np
 import torch
@@ -30,11 +32,13 @@ class RLExperience(CLExperience):
     """Experience for Continual Reinforcement Learning purposes.

     The experience provides access to a `gym.Env` environment.
-    Such environment can also be created lazily by providing a function.
+    Such environment can also be created lazily by
+    providing a function.
     """

-    def __init__(self, env: Union[Env, Callable[[], Env]], n_envs: int = 1, task_label: int = None, current_experience: int = None, origin_stream=None):
-        # current experience and origin stream are set when iterating a CLStream by default
+    def __init__(self, env: Union[Env, Callable[[], Env]], n_envs: int = 1,
+                 task_label: int = None, current_experience: int = None,
+                 origin_stream=None):
         super().__init__(current_experience, origin_stream)
         self.env = env
         self.n_envs = n_envs
@@ -44,7 +48,7 @@ def __init__(self, env: Union[Env, Callable[[], Env]], n_envs: int = 1,

     @property
     def environment(self) -> Env:
-        # supports dynamic creation of environment, useful to instantiate envs for evaluation
+        # support dynamic/lazy environment creation
         if not isinstance(self.env, Env):
             return self.env()
         return self.env
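
The `environment` property above is the hook for lazy creation: when `env` holds a zero-argument callable instead of a `gym.Env`, the callable is invoked on each access, which is handy for instantiating fresh evaluation environments. A minimal usage sketch of both paths (hypothetical, not part of this commit; it assumes gym is installed and that `RLExperience` is importable from this module):

import gym
from avalanche.benchmarks.scenarios.rl_scenario import RLExperience

# eager: the environment object is built up front
eager_exp = RLExperience(gym.make('CartPole-v1'), n_envs=1, task_label=0)

# lazy: a factory is stored; it is only called when .environment is read
lazy_exp = RLExperience(lambda: gym.make('CartPole-v1'), n_envs=1, task_label=0)

assert isinstance(eager_exp.environment, gym.Env)
assert isinstance(lazy_exp.environment, gym.Env)  # factory invoked here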
@@ -72,34 +76,52 @@ def __init__(self, envs: List[Env],
         """Init.

         Args:
-        :param envs: list of gym environments to be used for training the agent.
-            Each environment will be wrapped within a RLExperience.
+        :param envs: list of gym environments to be used for training the
+            agent. Each environment will be wrapped within a RLExperience.
         :param n_parallel_envs: number of parallel agent-environment
-            interactions to run for each experience. If an int is provided, the same
-            degree of parallelism will be used for every environment.
+            interactions to run for each experience. If an int is provided,
+            the same degree of parallelism will be used for every env.
         :param eval_envs: list of gym environments
-            to be used for evaluating the agent. Each environment will be wrapped
-            within a RLExperience.
-        :param wrappers_generators: list of `gym.Wrapper` generator functions
-            which are applied to some environment, each identified by its id.
-            It represents behavior added as post-processing steps (e.g. reward scaling).
-        :param task_labels: whether to add task labels to RLExperience. A task label
-            is assigned to each different environment, in the order they're provided in `envs`.
-        :param shuffle: whether to randomly shuffle `envs`. Defaults to False.
+            to be used for evaluating the agent. Each environment will
+            be wrapped within a RLExperience.
+        :param wrappers_generators: dict mapping env ids to a list of
+            `gym.Wrapper` generators. Wrappers represent behavior
+            added as post-processing steps (e.g. reward scaling).
+        :param task_labels: whether to add task labels to RLExperience.
+            A task label is assigned to each different environment,
+            in the order they're provided in `envs`.
+        :param shuffle: whether to randomly shuffle `envs`.
+            Defaults to False.
         """

         n_experiences = len(envs)
         if type(n_parallel_envs) is int:
             n_parallel_envs = [n_parallel_envs] * n_experiences
         assert len(n_parallel_envs) == len(envs)
+        # this is so that we can infer the task labels
+        assert all([isinstance(e, Env) for e in envs]), "Lazy instantiation of\
+            training environments is not supported"
         assert all([n > 0 for n in n_parallel_envs]
-                   ), "Number of parallel environments must be a positive integer"
+                   ), "Number of parallel environments\
+                       must be a positive integer"
         tr_envs = envs
         eval_envs = eval_envs or []
         self._num_original_envs = len(tr_envs)
         self.n_envs = n_parallel_envs
-        # this shouldn't contain duplicate envs, but it's difficult to ensure if scenario isn't created through a benchmark generator
-        tr_task_labels = list(range(len(envs)))
+        # this can contain shallow copies of envs to have multiple
+        # experiences from the same task
+        tr_task_labels = []
+        env_occ = {}
+        j = 0
+        # assign task label by checking whether the same instance of env is
+        # provided multiple times (shallow copy only)
+        for e in envs:
+            if e in env_occ:
+                tr_task_labels.append(env_occ[e])
+            else:
+                tr_task_labels.append(j)
+                env_occ[e] = j
+                j += 1

         # eval_task_labels = list(range(len(eval_envs)))
         self._wrappers_generators = wrappers_generators
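
The loop added above keys a dict on the environment objects themselves: repeating the same `Env` instance in `envs` (a shallow copy) reuses its task label, while each distinct instance gets a fresh one. A self-contained sketch of that labelling logic, using placeholder objects rather than real environments (illustration only, not library code):

env_a, env_b = object(), object()  # stand-ins for gym.Env instances
envs = [env_a, env_b, env_a]       # env_a listed twice -> one shared task

labels, seen, next_label = [], {}, 0
for e in envs:
    if e in seen:                  # same instance as before -> reuse label
        labels.append(seen[e])
    else:                          # new instance -> assign a new label
        labels.append(next_label)
        seen[e] = next_label
        next_label += 1

assert labels == [0, 1, 0]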
@@ -113,8 +135,8 @@ def __init__(self, envs: List[Env],
         tr_task_labels = tr_task_labels if task_labels else [
             None] * len(tr_envs)

-        tr_exps = [RLExperience(
-            tr_envs[i], n_parallel_envs[i], tr_task_labels[i]) for i in range(len(tr_envs))]
+        tr_exps = [RLExperience(tr_envs[i], n_parallel_envs[i],
+                                tr_task_labels[i]) for i in range(len(tr_envs))]
         tstream = EagerCLStream("train", tr_exps)
         # we're only supporting single process envs in evaluation atm
         eval_exps = [RLExperience(e, 1) for e in eval_envs]
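
Taken together, passing the same environment object twice should now produce two experiences that share one task label. A hypothetical end-to-end sketch (it assumes gym is installed, `RLScenario` is importable from this module, and `shuffle` keeps its documented default of False so stream order matches `envs`):

import gym
from avalanche.benchmarks.scenarios.rl_scenario import RLScenario

env = gym.make('CartPole-v1')
scenario = RLScenario([env, gym.make('Acrobot-v1'), env],  # env repeated
                      n_parallel_envs=1, task_labels=True,
                      eval_envs=[gym.make('CartPole-v1')])

labels = [exp.task_label for exp in scenario.train_stream]
assert labels == [0, 1, 0]  # third experience reuses the first task label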
tests/benchmarks/scenarios/test_rl_scenario.py (25 changes: 24 additions & 1 deletion)
@@ -11,7 +11,7 @@
 @unittest.skipIf(skip, reason="Need gym to run these tests")
 def test_simple_scenario():
     n_envs = 3
-    envs = [gym.make('CartPole-v1') for _ in range(n_envs)]
+    envs = [gym.make('CartPole-v1')]*n_envs
     rl_scenario = RLScenario(envs, n_parallel_envs=1,
                              task_labels=True, eval_envs=[])
     tr_stream = rl_scenario.train_stream
@@ -21,6 +21,29 @@ def test_simple_scenario():
     for i, exp in enumerate(tr_stream):
         assert exp.current_experience == i
         env = exp.environment
+        # same envs -> same task label
+        assert exp.task_label == 0
         assert isinstance(env, gym.Env)
         obs = env.reset()
         assert isinstance(obs, np.ndarray)
+
+
+@unittest.skipIf(skip, reason="Need gym to run these tests")
+def test_multiple_envs():
+    envs = [gym.make('CartPole-v0'), gym.make('CartPole-v1'),
+            gym.make('Acrobot-v1')]
+    rl_scenario = RLScenario(envs, n_parallel_envs=1,
+                             task_labels=True, eval_envs=[])
+    tr_stream = rl_scenario.train_stream
+    assert len(tr_stream) == 3
+
+    for i, exp in enumerate(tr_stream):
+        assert exp.current_experience == i == exp.task_label
+
+    # deep copies of the same env are considered as different tasks
+    envs = [gym.make('CartPole-v1') for _ in range(3)]
+    rl_scenario = RLScenario(envs, n_parallel_envs=1,
+                             task_labels=True, eval_envs=[])
+    for i, exp in enumerate(rl_scenario.train_stream):
+        assert exp.task_label == i