In [1]:
import gym
from collections import deque
from gym import spaces
import numpy as np

In [2]:
class ConcatObs(gym.Wrapper):
    """Stack the ``k`` most recent observations into one observation.

    The wrapper keeps a fixed-length buffer of the last ``k`` observations
    returned by the wrapped environment and exposes them as a single array
    with a new leading axis of length ``k``. It follows the 4-tuple
    ``(obs, reward, done, info)`` step API used by the wrapped environment.
    """

    def __init__(self, env, k):
        """Wrap ``env`` so that observations have shape ``(k,) + obs_shape``.

        Args:
            env: Environment to wrap; its ``observation_space`` must expose
                ``shape`` and ``dtype``.
            k: Number of consecutive observations to stack (must be >= 1).

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}")
        gym.Wrapper.__init__(self, env)
        self.k = k
        # maxlen=k makes the deque drop the oldest frame on every append.
        self.frames = deque([], maxlen=k)

        base_space = env.observation_space
        if isinstance(base_space, spaces.Box):
            # Repeat the wrapped bounds along the new stacking axis instead of
            # assuming 0..255 image data, so non-image Box spaces stay correct.
            low = np.repeat(base_space.low[np.newaxis, ...], k, axis=0)
            high = np.repeat(base_space.high[np.newaxis, ...], k, axis=0)
            self.observation_space = spaces.Box(
                low=low, high=high, dtype=base_space.dtype
            )
        else:
            # No per-element bounds available: keep the original 0..255 default.
            self.observation_space = spaces.Box(
                low=0,
                high=255,
                shape=(k,) + base_space.shape,
                dtype=base_space.dtype,
            )

    def reset(self, **kwargs):
        """Reset the wrapped environment and fill the buffer with its first frame.

        Any keyword arguments are forwarded unchanged to the wrapped ``reset``.

        Returns:
            The stacked observation, i.e. ``k`` copies of the initial frame.
        """
        ob = self.env.reset(**kwargs)
        # Appending k times evicts every frame left over from a prior episode.
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        """Step the wrapped environment and return the updated frame stack.

        Returns:
            ``(stacked_obs, reward, done, info)`` where ``reward``, ``done``
            and ``info`` are passed through from the wrapped environment.
        """
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        """Return the buffered frames as one array, oldest frame first."""
        return np.array(self.frames)

In [3]:
# Build the raw Atari environment, then stack the 4 most recent frames.
env = gym.make("BreakoutNoFrameskip-v4")
wrapped_env = ConcatObs(env, k=4)

# Show the observation space advertised by the wrapper.
print(f"The new observation space: {wrapped_env.observation_space}.")

The new observation space: Box([[[[0 0 0]
   [0 0 0]
   [0 0 0]
   ...
   [0 0 0]
   [0 0 0]
   [0 0 0]]

  [[0 0 0]
   [0 0 0]
   [0 0 0]
   ...
   [0 0 0]
   [0 0 0]
   [0 0 0]]

  [[0 0 0]
   [0 0 0]
   [0 0 0]
   ...
   [0 0 0]
   [0 0 0]
   [0 0 0]]

  ...

  [[0 0 0]
   [0 0 0]
   [0 0 0]
   ...
   [0 0 0]
   [0 0 0]
   [0 0 0]]

  [[0 0 0]
   [0 0 0]
   [0 0 0]
   ...
   [0 0 0]
   [0 0 0]
   [0 0 0]]

  [[0 0 0]
   [0 0 0]
   [0 0 0]
   ...
   [0 0 0]
   [0 0 0]
   [0 0 0]]]


 [[[0 0 0]
   [0 0 0]
   [0 0 0]
   ...
   [0 0 0]
   [0 0 0]
   [0 0 0]]

  [[0 0 0]
   [0 0 0]
   [0 0 0]
   ...
   [0 0 0]
   [0 0 0]
   [0 0 0]]

  [[0 0 0]
   [0 0 0]
   [0 0 0]
   ...
   [0 0 0]
   [0 0 0]
   [0 0 0]]

  ...

  [[0 0 0]
   [0 0 0]
   [0 0 0]
   ...
   [0 0 0]
   [0 0 0]
   [0 0 0]]

  [[0 0 0]
   [0 0 0]
   [0 0 0]
   ...
   [0 0 0]
   [0 0 0]
   [0 0 0]]

  [[0 0 0]
   [0 0 0]
   [0 0 0]
   ...
   [0 0 0]
   [0 0 0]
   [0 0 0]]]


 [[[0 0 0]
   [0 0 0]
   [0 0 0]
   ...
   [0 0 0]


In [4]:
# A fresh episode: the stack holds k copies of the first frame.
obs = wrapped_env.reset()
print(f"Initial obs is shape: {obs.shape}")

# Take a single step with action 2; only the stacked observation is needed,
# so reward, done and info are discarded.
obs, _, _, _ = wrapped_env.step(2)
print(f"Obs after step is shape: {obs.shape}")

Initial obs is shape: (4, 210, 160, 3)
Obs after step is shape: (4, 210, 160, 3)


In [5]:
import random
import time

In [6]:
class ObservationWrapper(gym.ObservationWrapper):
    """Rescale every observation by dividing it by 255.

    The declared ``observation_space`` is rescaled the same way, so the space
    reported by the wrapper matches the values ``observation`` returns.
    """

    def __init__ (self, env):
        """Wrap ``env`` and advertise the rescaled observation space.

        Args:
            env: Environment whose observations should be divided by 255.
        """
        super().__init__(env)
        base_space = env.observation_space
        if isinstance(base_space, spaces.Box):
            # Apply the same division to the bounds as to the observations, so
            # both the limits and the resulting dtype follow the returned data.
            scaled_low = base_space.low / 255.0
            scaled_high = base_space.high / 255.0
            self.observation_space = spaces.Box(
                low=scaled_low, high=scaled_high, dtype=scaled_low.dtype
            )

    def observation(self, obs):
        """Return ``obs`` divided element-wise by 255."""
        return obs / 255.0
    
class RewardWrapper(gym.RewardWrapper):
    """Clamp every reward from the wrapped environment into ``[0, 1]``."""

    # No custom constructor: the inherited one already takes just ``env``.

    def reward(self, reward):
        """Return ``reward`` limited to the closed interval ``[0, 1]``."""
        lower_bound, upper_bound = 0, 1
        return np.clip(reward, lower_bound, upper_bound)
    
class ActionWrapper(gym.ActionWrapper):
    """Replace action 3 with a random pick from actions 0, 1 and 2."""

    # No custom constructor: the inherited one already takes just ``env``.

    def action(self, action):
        """Translate the agent's action before it reaches the wrapped env.

        Every action other than 3 passes through untouched; action 3 is
        swapped for a uniformly random choice among 0, 1 and 2.
        """
        if action != 3:
            return action
        return random.choice([0,1,2])

In [7]:
# Compose all three wrappers around a fresh Breakout environment and verify,
# over 500 random steps, that observations and rewards stay within [0, 1].
env = gym.make("BreakoutNoFrameskip-v4")
wrapped_env = ObservationWrapper(RewardWrapper(ActionWrapper(env)))

try:
    obs = wrapped_env.reset()

    for step in range(500):
        action = wrapped_env.action_space.sample()
        obs, reward, done, info = wrapped_env.step(action)

        # Fail if observations were not normalised into [0, 1]. This is an
        # assertion (not a print) so "All checks passed." is never reported
        # after a failed check.
        assert not ((obs > 1.0).any() or (obs < 0.0).any()), \
            "Max and Min value of observations out of range."

        # Fail if rewards were not clipped into [0, 1].
        assert not (reward < 0.0 or reward > 1.0), "Reward out of bounds."

        # Render so the paddle movement can be inspected visually.
        wrapped_env.render()

        time.sleep(0.001)

        # Stepping a finished episode is undefined; start a new one instead.
        if done:
            obs = wrapped_env.reset()
finally:
    # Release the environment (and its render window) even if a check fails.
    wrapped_env.close()

print ("All checks passed.")



All checks passed.


In [8]:
# The wrapper chain as seen from the outermost wrapper.
print("Wrapped Env:", wrapped_env)

# The innermost environment with every wrapper stripped away.
print("Unwrapped Env", wrapped_env.unwrapped)

# Human-readable names for the discrete action indices.
print(
    "Getting the meaning of actions",
    wrapped_env.unwrapped.get_action_meanings(),
)

Wrapped Env: <ObservationWrapper<RewardWrapper<ActionWrapper<TimeLimit<AtariEnv<BreakoutNoFrameskip-v4>>>>>>
Unwrapped Env <AtariEnv<BreakoutNoFrameskip-v4>>
Getting the meaning of actions ['NOOP', 'FIRE', 'RIGHT', 'LEFT']
