In [1]:
import gymnasium as gym
import ray
from ray.rllib.algorithms.a2c import A2CConfig
import numpy as np

import dm_memorytasks

2023-10-10 22:59:15,555	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2023-10-10 22:59:29,491	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [None]:
import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig


# Define your problem using python and Farama-Foundation's gymnasium API:
class ParrotEnv(gym.Env):
    """Environment in which an agent must learn to repeat the seen observations.

    Observations are float numbers indicating the to-be-repeated values,
    e.g. -1.0, 5.1, or 3.2.

    The action space is always the same as the observation space.

    Rewards are r=-abs(observation - action), for all steps.

    Episodes never terminate on their own; they are truncated after 10 steps.
    """

    def __init__(self, config):
        """Sets up the shared action/observation space.

        Args:
            config: Mapping of env settings (RLlib passes its `env_config`
                here); `None` is treated as an empty mapping. The optional
                key "parrot_shriek_range" holds the space used for both
                actions and observations and defaults to
                `gym.spaces.Box(-1.0, 1.0, shape=(1, ))`.
        """
        config = config or {}
        # Make the space (for actions and observations) configurable.
        self.action_space = config.get(
            "parrot_shriek_range", gym.spaces.Box(-1.0, 1.0, shape=(1, )))
        # Since actions should repeat observations, their spaces must be the
        # same (this is the very same object, not a copy).
        self.observation_space = self.action_space
        self.cur_obs = None
        self.episode_len = 0

    def reset(self, *, seed=None, options=None):
        """Resets the episode and returns the initial observation of the new one.

        Args:
            seed: Optional seed. When given, it seeds both the env's own RNG
                (via `gym.Env.reset`) and the observation space, so the
                sequence of sampled observations becomes reproducible.
            options: Unused; accepted for gymnasium API compatibility.

        Returns:
            Tuple of (initial observation, empty info-dict).
        """
        # The gymnasium API expects subclasses to forward the seed.
        super().reset(seed=seed)
        if seed is not None:
            # `Space.sample()` draws from the space's own RNG, not from
            # `self.np_random`, so the space has to be seeded explicitly.
            self.observation_space.seed(seed)
        # Reset the episode len.
        self.episode_len = 0
        # Sample a random number from our observation space.
        self.cur_obs = self.observation_space.sample()
        # Return initial observation.
        return self.cur_obs, {}

    def step(self, action):
        """Takes a single step in the episode given `action`.

        Args:
            action: The agent's guess for the current observation.

        Returns:
            Tuple of (new observation, reward, terminated, truncated,
            info-dict). `terminated` is always False, `truncated` becomes
            True from the 10th step on, and the info-dict is empty.
        """
        # Set `truncated` flag after 10 steps.
        self.episode_len += 1
        terminated = False
        truncated = self.episode_len >= 10
        # r = -abs(obs - action), summed over the elements of the observation.
        reward = -sum(abs(self.cur_obs - action))
        # Set a new observation (random sample).
        self.cur_obs = self.observation_space.sample()
        return self.cur_obs, reward, terminated, truncated, {}


# Create an RLlib Algorithm instance from a PPOConfig to learn how to
# act in the above environment.
config = (
    PPOConfig()
    .environment(
        # Env class to use (here: our gym.Env sub-class from above).
        env=ParrotEnv,
        # Config dict to be passed to our custom env's constructor.
        env_config={
            "parrot_shriek_range": gym.spaces.Box(-5.0, 5.0, (1, ))
        },
    )
    # Parallelize environment rollouts.
    .rollouts(num_rollout_workers=3)
)
# Use the config's `build()` method to construct a PPO object.
algo = config.build()

# Train for n iterations and report results (mean episode rewards).
# Every step's reward is -abs(observation - action), so the best possible
# episode reward is 0.0 (all 10 guesses match their observations exactly);
# the reported mean should move towards 0.0 as training progresses.
try:
    for i in range(5):
        results = algo.train()
        print(f"Iter: {i}; avg. reward={results['episode_reward_mean']}")
finally:
    # Release the rollout workers even if training fails; `algo` is rebound
    # by a later cell, after which this instance could no longer be stopped.
    algo.stop()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2023-10-10 22:59:44,029	INFO worker.py:1642 -- Started a local Ray instance.


In [2]:
# wrapper for psych lab -> gymnasium env
class PsychLab(gym.Env):
    """Gymnasium adapter around a dm_memorytasks environment run from docker.

    Observations are the backend's 'RGB_INTERLEAVED' frames. Actions are a
    dict of four one-element float32 boxes in [-1, 1], keyed by
    'STRAFE_LEFT_RIGHT', 'MOVE_BACK_FORWARD', 'LOOK_LEFT_RIGHT' and
    'LOOK_DOWN_UP'.
    """

    def __init__(self, env_config):
        """Starts the dockerized backend and declares the gymnasium spaces.

        Args:
            env_config: Optional mapping of settings (RLlib passes its
                `env_config` here); `None` or an empty mapping selects the
                defaults. Recognized keys:
                'seed' (default 123),
                'level_name' (default 'spot_diff_extrapolate'),
                'docker_image' (default
                'gcr.io/deepmind-environments/dm_memorytasks:v1.0.1').
        """
        env_config = env_config or {}

        # Initialize the PsychLab environment with the provided config
        env_settings = dm_memorytasks.EnvironmentSettings(
            seed=env_config.get('seed', 123),
            level_name=env_config.get('level_name', 'spot_diff_extrapolate'),
        )
        self.env = dm_memorytasks.load_from_docker(
            name=env_config.get(
                'docker_image',
                'gcr.io/deepmind-environments/dm_memorytasks:v1.0.1'),
            settings=env_settings,
        )
        self.action_spec = self.env.action_spec()
        rgb_spec = self.env.observation_spec()['RGB_INTERLEAVED']

        # NOTE(review): each action is declared with shape (1,); confirm this
        # matches what the backend's action spec accepts.
        action_names = (
            'STRAFE_LEFT_RIGHT',
            'MOVE_BACK_FORWARD',
            'LOOK_LEFT_RIGHT',
            'LOOK_DOWN_UP',
        )
        self.action_space = gym.spaces.Dict({
            name: gym.spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
            for name in action_names
        })
        # Declare the dtype the backend reports for its frames instead of
        # Python `int` (int64), so the space matches the arrays returned by
        # `reset()` and `step()`.
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=rgb_spec.shape, dtype=rgb_spec.dtype)

    def reset(self, *, seed=None, options=None):
        """Resets the backend and returns (initial RGB frame, empty info-dict).

        `seed` only seeds gymnasium's own RNG; the backend's seed is fixed at
        construction time through its environment settings. `options` is
        unused.
        """
        super().reset(seed=seed)
        # Reset the PsychLab environment and return the initial observation
        timestep = self.env.reset()
        return timestep.observation['RGB_INTERLEAVED'], {}

    def step(self, action):
        """Forwards `action` to the backend and converts the resulting timestep.

        Returns:
            Tuple of (RGB frame, reward, terminated, truncated, info-dict).
            `terminated` is always False; the end of an episode is reported
            through `truncated` (the timestep's `last()` flag). The info-dict
            is empty.
        """
        timestep = self.env.step(action)
        # dm_env-style timesteps carry `reward=None` on the first step of an
        # episode; report that as 0.0 so callers always receive a number.
        reward = 0.0 if timestep.reward is None else timestep.reward
        return timestep.observation['RGB_INTERLEAVED'], reward, False, timestep.last(), {}

    def close(self):
        """Shuts down the backend environment started in `__init__`."""
        self.env.close()

In [None]:
# Create an RLlib Algorithm instance
config = A2CConfig()
# Optional tuning, e.g.:
# config = config.training(lr=0.0003, train_batch_size=52)
config = config.resources(num_gpus=0)

# Build an Algorithm object from the config, using PsychLab as the env.
algo = config.build(env=PsychLab)

# Train for n iterations and report results (mean episode rewards).
try:
    for i in range(5):
        results = algo.train()
        print(f"Iter: {i}; avg. reward={results['episode_reward_mean']}")
finally:
    # Always release the algorithm's workers (and the environments they
    # hold), even when training raises.
    algo.stop()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2023-10-10 22:50:37,805	INFO worker.py:1642 -- Started a local Ray instance.
