Skip to content

Commit

Permalink
Allow option to provide list of seeds in BatchEnv (#290)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcCote committed Jan 9, 2022
1 parent 753faea commit 5b6a9b3
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 6 deletions.
18 changes: 12 additions & 6 deletions textworld/envs/batch/batch_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,12 @@ def load(self, game_files: List[str]) -> None:
env.result()

def seed(self, seed=None):
# Use a different seed for each env to decorrelate batch examples.
rng = np.random.RandomState(seed)
seeds = list(rng.randint(65635, size=self.batch_size))
seeds = seed
if seeds is None or isinstance(seeds, int):
# Use a different seed for each env to decorrelate batch examples.
rng = np.random.RandomState(seeds)
seeds = list(rng.randint(65635, size=self.batch_size))

for env, seed in zip(self.envs, seeds):
env.call_sync("seed", seed)

Expand Down Expand Up @@ -209,9 +212,12 @@ def load(self, game_files: List[str]) -> None:
env.load(game_file)

def seed(self, seed=None):
# Use a different seed for each env to decorrelate batch examples.
rng = np.random.RandomState(seed)
seeds = list(rng.randint(65635, size=self.batch_size))
seeds = seed
if seeds is None or isinstance(seeds, int):
# Use a different seed for each env to decorrelate batch examples.
rng = np.random.RandomState(seeds)
seeds = list(rng.randint(65635, size=self.batch_size))

for env, seed in zip(self.envs, seeds):
env.seed(seed)

Expand Down
30 changes: 30 additions & 0 deletions textworld/envs/batch/tests/test_batch_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import textworld.gym
from textworld import EnvInfos
from textworld.utils import make_temp_directory
from textworld.envs import JerichoEnv
from textworld.envs.batch.batch_env import AsyncBatchEnv, SyncBatchEnv


def test_batch_env():
Expand Down Expand Up @@ -33,3 +35,31 @@ def test_batch_env():
# env.close()
del env
print("OKAY")


def test_seed():
batch_size = 4
env_options = EnvInfos(inventory=True, description=True, admissible_commands=True)
env_fns = [lambda: JerichoEnv(env_options) for _ in range(batch_size)]

env = SyncBatchEnv(env_fns)
seeds = env.seed(1234)
for seed, env_ in zip(seeds, env.envs):
assert seed == env_._seed

env.seed(range(batch_size))
for seed, env_ in zip(range(batch_size), env.envs):
assert seed == env_._seed

env.close()

env = AsyncBatchEnv(env_fns)
seeds = env.seed(1234)
for seed, env_ in zip(seeds, env.envs):
assert seed == env_.get_sync("_seed")

env.seed(range(batch_size))
for seed, env_ in zip(range(batch_size), env.envs):
assert seed == env_.get_sync("_seed")

env.close()

0 comments on commit 5b6a9b3

Please sign in to comment.