<a href="https://colab.research.google.com/github/450586509/reinforcement-learning-practice/blob/master/07_A2C.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Google Colab setup: installs the system packages needed to render Atari
# environments headlessly and downloads a helper script that starts a
# virtual X frame buffer (XVFB).

import os

os.system('apt-get install -y xvfb')
os.system('wget https://raw.githubusercontent.com/yandexdataschool/Practical_DL/fall18/xvfb -O ../xvfb')
os.system('apt-get install -y python-opengl ffmpeg')
os.system('pip install pyglet==1.2.4')

os.system('python -m pip install -U pygame --user')

print('setup complete')

# Launch XVFB when no display is configured (e.g. when running on a
# server) and point DISPLAY at it. `os.system` is used instead of the
# IPython `!bash` magic so the cell is also valid plain Python.
if not os.environ.get("DISPLAY"):
    os.system('bash ../xvfb start')
    os.environ['DISPLAY'] = ':1'

setup complete
Starting virtual X frame buffer: Xvfb.


### 实现A2C（Advantage Actor-Critic）算法

利用多个Atari 2600环境并行训练agent

atari_wrappers.py 提供了 observations 的预处理操作，包括：resize、grayscale、take max between frames、skip frames 和 stack frames。

env_batch.py 提供了并行运行多个环境的功能。（本 notebook 在下面的代码单元中直接内联了这两个文件的内容。）



### 并行运行多个环境

In [0]:
# pylint: skip-file
from multiprocessing import Process, Pipe

from gym import Env, Wrapper, Space
import numpy as np


class SpaceBatch(Space):
    """ A batch of identical spaces that otherwise behaves like its first one.

    All spaces must share the same type, shape and dtype. `sample` draws one
    sample from every space and stacks them along a new leading axis; any
    other attribute not found on the batch is looked up on the first space.

    Raises:
        ValueError: if `spaces` is empty, or the spaces differ in shape or
            dtype.
        TypeError: if the spaces are of different types.
    """

    def __init__(self, spaces):
        if not spaces:
            raise ValueError("spaces must contain at least one space")
        first_type = type(spaces[0])
        first_shape = spaces[0].shape
        first_dtype = spaces[0].dtype
        for space in spaces:
            if not isinstance(space, first_type):
                raise TypeError("spaces have different types: {}, {}"
                                .format(first_type, type(space)))
            if first_shape != space.shape:
                raise ValueError("spaces have different shapes: {}, {}"
                                 .format(first_shape, space.shape))
            if first_dtype != space.dtype:
                raise ValueError("spaces have different data types: {}, {}"
                                 .format(first_dtype, space.dtype))

        self.spaces = spaces
        super(SpaceBatch, self).__init__(shape=self.spaces[0].shape,
                                         dtype=self.spaces[0].dtype)

    def sample(self):
        """ Returns one sample per space, stacked along a new first axis. """
        return np.stack([space.sample() for space in self.spaces])

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. If "spaces" itself
        # is missing (e.g. on an instance being rebuilt by copy/pickle before
        # its state is restored), looking it up here would call __getattr__
        # again and recurse forever, so fail fast instead.
        if attr == "spaces":
            raise AttributeError(attr)
        return getattr(self.spaces[0], attr)


class EnvBatch(Env):
    """ Batch of environments stepped sequentially in the current process.

    Args:
        make_env: a callable creating one environment (then `nenvs` is
            required and the callable is invoked `nenvs` times), or, when
            `nenvs` is None, a list of such callables, one per environment.
        nenvs: number of environments to create from a single callable.

    `step` takes one action per environment and returns stacked
    observations, rewards and done flags plus a list of infos. An
    environment whose episode is done is reset inside `step`, so its
    returned observation is the first one of the next episode.
    """

    def __init__(self, make_env, nenvs=None):
        make_env_functions = self._get_make_env_functions(make_env, nenvs)
        self._envs = [make_env_fn() for make_env_fn in make_env_functions]
        self._nenvs = len(self.envs)
        # NOTE: observation_space is not set here; the subclasses in this
        # module set it themselves.
        self.action_space = SpaceBatch([env.action_space
                                        for env in self._envs])

    def _get_make_env_functions(self, make_env, nenvs):
        """ Normalizes `make_env`/`nenvs` into a list of env factories.

        Raises:
            ValueError: if `nenvs` is None and `make_env` is not a list, or
                if `nenvs` is given and `make_env` is not callable.
        """
        if nenvs is None and not isinstance(make_env, list):
            raise ValueError("When nenvs is None make_env"
                             " must be a list of callables")
        if nenvs is not None and not callable(make_env):
            raise ValueError(
                "When nenvs is not None make_env must be callable")

        if nenvs is not None:
            make_env = [make_env for _ in range(nenvs)]
        return make_env

    @property
    def nenvs(self):
        """ Number of environments in the batch. """
        return self._nenvs

    @property
    def envs(self):
        """ The underlying environments. """
        return self._envs

    def _check_actions(self, actions):
        """ Raises ValueError unless there is exactly one action per env. """
        if len(actions) != self.nenvs:
            raise ValueError(
                "number of actions is not equal to number of envs: "
                "len(actions) = {}, nenvs = {}"
                .format(len(actions), self.nenvs))

    def step(self, actions):
        """ Steps every env with its action, auto-resetting finished ones. """
        self._check_actions(actions)
        obs, rews, resets, infos = [], [], [], []
        for env, action in zip(self._envs, actions):
            ob, rew, done, info = env.step(action)
            if done:
                ob = env.reset()
            obs.append(ob)
            rews.append(rew)
            resets.append(done)
            infos.append(info)
        return np.stack(obs), np.stack(rews), np.stack(resets), infos

    def reset(self):
        """ Resets every environment and returns the stacked observations. """
        return np.stack([env.reset() for env in self.envs])

    def close(self):
        """ Closes every environment in the batch. """
        for env in self._envs:
            env.close()


class SingleEnvBatch(Wrapper, EnvBatch):
    """ Presents a single environment through the batched `EnvBatch` API.

    Observations, rewards and done flags gain a leading batch axis of size 1,
    `step` expects a sequence holding exactly one action, and infos are
    returned as a one-element list. The environment is reset inside `step`
    when its episode is done.
    """

    def __init__(self, env):
        # Wrapper is the first base in the MRO, so this resolves to
        # Wrapper.__init__ rather than EnvBatch.__init__.
        super(SingleEnvBatch, self).__init__(env)
        self.observation_space = SpaceBatch([self.env.observation_space])
        self.action_space = SpaceBatch([self.env.action_space])

    @property
    def nenvs(self):
        """ Always 1: the batch wraps a single environment. """
        return 1

    @property
    def envs(self):
        """ A one-element list holding the wrapped environment. """
        return [self.env]

    def step(self, actions):
        """ Steps the env with `actions[0]`, adding a batch axis to outputs. """
        self._check_actions(actions)
        ob, rew, done, info = self.env.step(actions[0])
        if done:
            ob = self.env.reset()
        return (
            ob[None],
            np.expand_dims(rew, 0),
            np.expand_dims(done, 0),
            [info],
        )

    def reset(self, **kwargs):
        """ Resets the env (forwarding `kwargs`) and adds a batch axis. """
        return self.env.reset(**kwargs)[None]


def worker(parent_connection, worker_connection, make_env_function,
           send_spaces=True):
    """ Runs one environment in a child process and serves commands.

    Adapted from SubprocVecEnv (github.com/openai/baselines).

    Commands are `(cmd, data)` tuples received on `worker_connection`:
      - ("step", action): steps the env (resetting it when the episode is
        done) and replies with `(ob, rew, done, info)`;
      - ("reset", _): resets the env and replies with the observation;
      - ("close", _): stops serving.
    If `send_spaces` is true, `(observation_space, action_space)` is sent
    once before any command is served.

    The environment and `worker_connection` are closed whenever the loop
    exits: on "close", when the parent end of the pipe goes away (EOFError
    from `recv`), or when an unknown command raises NotImplementedError.
    """
    # The parent's end of the pipe is not used by the worker.
    parent_connection.close()
    env = make_env_function()
    try:
        if send_spaces:
            worker_connection.send((env.observation_space, env.action_space))
        while True:
            try:
                cmd, data = worker_connection.recv()
            except EOFError:
                # The parent closed its end; there is nobody left to serve.
                break
            if cmd == "step":
                ob, rew, done, info = env.step(data)
                if done:
                    ob = env.reset()
                worker_connection.send((ob, rew, done, info))
            elif cmd == "reset":
                worker_connection.send(env.reset())
            elif cmd == "close":
                break
            else:
                raise NotImplementedError("Unknown command %s" % cmd)
    finally:
        env.close()
        worker_connection.close()


class ParallelEnvBatch(EnvBatch):
    """ Batch of environments, each running in its own worker process.

    Every environment is created inside a daemon `Process` running `worker`
    and is driven through a `Pipe`. The arguments are the same as for
    `EnvBatch`: a single callable together with `nenvs`, or a list of
    callables when `nenvs` is None.

    Call `close()` to shut down the worker processes and release the pipes.

    Raises:
        ValueError: if no environments are requested.
    """

    def __init__(self, make_env, nenvs=None):
        make_env_functions = self._get_make_env_functions(make_env, nenvs)
        self._nenvs = len(make_env_functions)
        if self._nenvs == 0:
            raise ValueError("ParallelEnvBatch needs at least one environment")
        self._parent_connections, self._worker_connections = zip(*[
            Pipe() for _ in range(self._nenvs)
        ])
        self._processes = [
            Process(
                target=worker,
                args=(parent_connection, worker_connection, make_env_fn),
                daemon=True
            )
            for parent_connection, worker_connection, make_env_fn
            in zip(self._parent_connections,
                   self._worker_connections,
                   make_env_functions)
        ]
        for p in self._processes:
            p.start()
        self._closed = False

        # The worker ends now belong to the child processes. Closing this
        # process's copies means `recv` here raises EOFError instead of
        # blocking forever if a worker exits.
        for conn in self._worker_connections:
            conn.close()

        # Each worker first reports its (observation_space, action_space).
        observation_spaces, action_spaces = [], []
        for conn in self._parent_connections:
            ob_space, ac_space = conn.recv()
            observation_spaces.append(ob_space)
            action_spaces.append(ac_space)
        self.observation_space = SpaceBatch(observation_spaces)
        self.action_space = SpaceBatch(action_spaces)

    @property
    def nenvs(self):
        """ Number of environments (worker processes) in the batch. """
        return self._nenvs

    def step(self, actions):
        """ Sends one action to every worker and gathers the results. """
        self._check_actions(actions)
        # Send all actions first so that the workers step concurrently.
        for conn, a in zip(self._parent_connections, actions):
            conn.send(("step", a))
        results = [conn.recv() for conn in self._parent_connections]
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        """ Resets every environment and returns the stacked observations. """
        for conn in self._parent_connections:
            conn.send(("reset", None))
        return np.stack([conn.recv() for conn in self._parent_connections])

    def close(self):
        """ Stops all worker processes and closes the pipes; idempotent. """
        if self._closed:
            return
        for conn in self._parent_connections:
            conn.send(("close", None))
        for p in self._processes:
            p.join()
        for conn in self._parent_connections:
            conn.close()
        self._closed = True

    def render(self):
        raise ValueError("render not defined for %s" % self)

### 预处理observations

In [0]:
""" Environment wrappers. """
from collections import deque

import cv2
import gym
import gym.spaces as spaces
from gym.envs import atari
import numpy as np
import tensorflow as tf


# Disable OpenCV's OpenCL acceleration globally.
# NOTE(review): presumably done to avoid OpenCL issues when environments run
# in several worker processes — confirm.
cv2.ocl.setUseOpenCL(False)


class EpisodicLife(gym.Wrapper):
    """ Reports done=True whenever the agent loses a life.

    The underlying game's own done flag is exposed as `info["real_done"]`.
    The game is only really reset after a real end; after a life loss,
    `reset` instead advances the game with action 0.
    """

    def __init__(self, env):
        super(EpisodicLife, self).__init__(env)
        self.lives = 0
        self.real_done = True

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        # Record the wrapped env's own done flag before it may be overridden.
        self.real_done = info["real_done"] = done
        lives_now = self.env.unwrapped.ale.lives()
        if self.lives > lives_now > 0:
            # A life was lost but the game goes on: end the episode here.
            done = True
        self.lives = lives_now
        return obs, rew, done, info

    def reset(self, **kwargs):
        if not self.real_done:
            # Game still running after a life loss: just take action 0.
            obs = self.env.step(0)[0]
        else:
            obs = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()
        return obs


class FireReset(gym.Wrapper):
    """ Makes the fire action when resetting the environment.

    Some environments stay frozen until the agent takes the FIRE action.
    This wrapper takes actions 1 (FIRE) and 2 right after every reset so that
    the episode starts automatically.

    Raises:
        ValueError: if the env has fewer than 3 actions or action 1 is not
            'FIRE'.
    """

    def __init__(self, env):
        super(FireReset, self).__init__(env)
        action_meanings = env.unwrapped.get_action_meanings()
        if len(action_meanings) < 3:
            raise ValueError(
                "env.unwrapped.get_action_meanings() must be of length >= 3 "
                f"but is of length {len(action_meanings)}")
        if action_meanings[1] != "FIRE":
            raise ValueError(
                "env.unwrapped.get_action_meanings() must have 'FIRE' "
                f"under index 1, but is {action_meanings}")

    def step(self, action):
        return self.env.step(action)

    def reset(self, **kwargs):
        """ Resets the env, then takes actions 1 and 2. Returns the last
        observation, re-resetting if either action ends the episode. """
        self.env.reset(**kwargs)
        _, _, done, _ = self.env.step(1)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            # Return the first observation of the fresh episode, not the
            # terminal observation of the episode that just ended.
            obs = self.env.reset(**kwargs)
        return obs


class StartWithRandomActions(gym.Wrapper):
    """ Makes a random number of random actions at the beginning of each
    episode.

    The number of actions is drawn uniformly from [0, max_random_actions].
    Random actions are taken only after a real episode end, as reported by
    `info["real_done"]` on the last step (a missing key counts as a real
    end). Resets that follow a non-real done, e.g. a lost life reported by an
    inner wrapper, are left alone.
    """

    def __init__(self, env, max_random_actions=30):
        super(StartWithRandomActions, self).__init__(env)
        self.max_random_actions = max_random_actions
        self.real_done = True

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        self.real_done = info.get("real_done", True)
        return obs, rew, done, info

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        if self.real_done:
            num_random_actions = np.random.randint(self.max_random_actions + 1)
            for _ in range(num_random_actions):
                obs, _, done, _ = self.env.step(self.env.action_space.sample())
                if done:
                    # A random action ended the episode; never keep stepping
                    # a finished environment.
                    obs = self.env.reset(**kwargs)
            self.real_done = False
        return obs


class ImagePreprocessing(gym.ObservationWrapper):
    """ Preprocesses image observations by optionally grayscaling and resizing.

    Observations become arrays of shape `(height, width)` when `grayscale`
    is true, and `(height, width) + original_shape[2:]` otherwise. Grayscale
    conversion assumes RGB input (`cv2.COLOR_RGB2GRAY`).
    """

    def __init__(self, env, width=84, height=84, grayscale=True):
        super(ImagePreprocessing, self).__init__(env)
        self.width = width
        self.height = height
        self.grayscale = grayscale
        ospace = self.env.observation_space
        low, high, dtype = ospace.low.min(), ospace.high.max(), ospace.dtype
        # cv2.resize takes dsize as (width, height) but returns an array of
        # shape (height, width, ...), so the space is declared rows first.
        obs_shape = (height, width)
        if not self.grayscale:
            obs_shape += ospace.shape[2:]
        self.observation_space = spaces.Box(low=low, high=high,
                                            shape=obs_shape, dtype=dtype)

    def observation(self, observation):
        """ Performs image preprocessing. """
        if self.grayscale:
            observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        # interpolation must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, not the interpolation flag.
        return cv2.resize(observation, (self.width, self.height),
                          interpolation=cv2.INTER_AREA)


class MaxBetweenFrames(gym.ObservationWrapper):
    """ Takes the element-wise maximum between two subsequent frames.

    Each returned observation is the maximum of the current frame and the
    previous one; right after a reset, the first frame is returned as is.

    Raises:
        ValueError: if wrapping an Atari env whose id lacks 'NoFrameskip'.
    """

    def __init__(self, env):
        if (isinstance(env.unwrapped, atari.AtariEnv) and
                "NoFrameskip" not in env.spec.id):
            raise ValueError(
                "MaxBetweenFrames requires NoFrameskip in atari env id")
        super(MaxBetweenFrames, self).__init__(env)
        self.last_obs = None

    def observation(self, observation):
        obs = np.maximum(observation, self.last_obs)
        self.last_obs = observation
        return obs

    def reset(self, **kwargs):
        self.last_obs = self.env.reset(**kwargs)
        return self.last_obs


class QueueFrames(gym.ObservationWrapper):
    """ Queues the last `nframes` observations together.

    With `concat=False` the frames are stacked along a new last axis, giving
    shape `obs_shape + (nframes,)`. With `concat=True` they are concatenated
    along the existing last axis. On reset the queue is filled with copies of
    the first observation.
    """

    def __init__(self, env, nframes, concat=False):
        super(QueueFrames, self).__init__(env)
        self.obs_queue = deque([], maxlen=nframes)
        self.concat = concat
        ospace = self.observation_space
        if self.concat:
            oshape = ospace.shape[:-1] + (ospace.shape[-1] * nframes,)
        else:
            oshape = ospace.shape + (nframes,)
        self.observation_space = spaces.Box(
            ospace.low.min(), ospace.high.max(), oshape, ospace.dtype)

    def observation(self, observation):
        self.obs_queue.append(observation)
        frames = list(self.obs_queue)
        if self.concat:
            return np.concatenate(frames, -1)
        # np.stack on the last axis matches the declared `shape + (nframes,)`
        # for any observation rank (identical to np.dstack for 2D frames).
        return np.stack(frames, axis=-1)

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        for _ in range(self.obs_queue.maxlen - 1):
            self.obs_queue.append(obs)
        return self.observation(obs)


class SkipFrames(gym.Wrapper):
    """ Repeats each action for `nskip` steps and returns the final result.

    Rewards are summed over the repeated steps; repetition stops early when
    the episode ends.

    Raises:
        ValueError: if wrapping an Atari env whose id lacks 'NoFrameskip',
            or if `nskip` is less than 1.
    """

    def __init__(self, env, nskip=4):
        super(SkipFrames, self).__init__(env)
        if (isinstance(env.unwrapped, atari.AtariEnv) and
                "NoFrameskip" not in env.spec.id):
            raise ValueError("SkipFrames requires NoFrameskip in atari env id")
        if nskip < 1:
            # With zero steps `step` would have no observation to return.
            raise ValueError(f"nskip must be at least 1, but is {nskip}")
        self.nskip = nskip

    def step(self, action):
        total_reward = 0.0
        for _ in range(self.nskip):
            obs, rew, done, info = self.env.step(action)
            total_reward += rew
            if done:
                break
        return obs, total_reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class ClipReward(gym.RewardWrapper):
    """ Replaces each reward with its sign, mapping it into {-1, 0, 1}. """

    def reward(self, reward):
        clipped = np.sign(reward)
        return clipped


class TFSummaries(gym.Wrapper):
    """ Writes episode statistics as TensorFlow summaries.

    Works with a single env or a batch of envs; the batch size is read from
    `env.unwrapped.nenvs` and defaults to 1. Summaries are written once every
    env in the batch has finished at least one episode since the last write.
    An episode counts as finished when `info["real_done"]` is true, falling
    back to the `done` flag when that key is absent.

    Args:
        env: environment (or env batch) to wrap.
        prefix: summary name prefix; defaults to `env.spec.id`.
        running_mean_size: number of recent episode rewards kept per env for
            the running-mean summary.
        step_var: step variable for the summaries; defaults to
            `tf.train.get_global_step()`.
    """

    def __init__(self, env, prefix=None, running_mean_size=100, step_var=None):
        super(TFSummaries, self).__init__(env)
        self.episode_counter = 0
        self.prefix = prefix or self.env.spec.id
        self.step_var = (step_var if step_var is not None
                         else tf.train.get_global_step())

        nenvs = getattr(self.env.unwrapped, "nenvs", 1)
        self.rewards = np.zeros(nenvs)
        # Plain `bool`: the `np.bool` alias was removed in NumPy 1.24.
        self.had_ended_episodes = np.zeros(nenvs, dtype=bool)
        self.episode_lengths = np.zeros(nenvs)
        self.reward_queues = [deque([], maxlen=running_mean_size)
                              for _ in range(nenvs)]

    def should_write_summaries(self):
        """ Returns true if it's time to write summaries. """
        return np.all(self.had_ended_episodes)

    def add_summaries(self):
        """ Writes summaries and starts a new summary window. """
        tf.contrib.summary.scalar(
            f"{self.prefix}/total_reward",
            tf.reduce_mean([q[-1] for q in self.reward_queues]),
            step=self.step_var)
        tf.contrib.summary.scalar(
            f"{self.prefix}/reward_mean_{self.reward_queues[0].maxlen}",
            tf.reduce_mean([np.mean(q) for q in self.reward_queues]),
            step=self.step_var)
        tf.contrib.summary.scalar(
            f"{self.prefix}/episode_length",
            tf.reduce_mean(self.episode_lengths),
            step=self.step_var)
        if self.had_ended_episodes.size > 1:
            tf.contrib.summary.scalar(
                f"{self.prefix}/min_reward",
                min(q[-1] for q in self.reward_queues),
                step=self.step_var)
            tf.contrib.summary.scalar(
                f"{self.prefix}/max_reward",
                max(q[-1] for q in self.reward_queues),
                step=self.step_var)
        self.episode_lengths.fill(0)
        self.had_ended_episodes.fill(False)

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        self.rewards += rew
        # Lengths are only counted up to each env's first episode end within
        # the current summary window.
        self.episode_lengths[~self.had_ended_episodes] += 1

        info_collection = [info] if isinstance(info, dict) else info
        # atleast_1d handles Python bools, numpy bools and arrays alike.
        done_collection = np.atleast_1d(done)
        done_indices = [i for i, env_info in enumerate(info_collection)
                        if env_info.get("real_done", done_collection[i])]
        for i in done_indices:
            self.had_ended_episodes[i] = True
            self.reward_queues[i].append(self.rewards[i])
            self.rewards[i] = 0

        if self.should_write_summaries():
            self.add_summaries()
        return obs, rew, done, info

    def reset(self, **kwargs):
        self.rewards.fill(0)
        self.episode_lengths.fill(0)
        self.had_ended_episodes.fill(False)
        return self.env.reset(**kwargs)


def nature_dqn_env(env_id, nenvs=None, seed=None,
                   summaries=True, clip_reward=True):
    """ Wraps env as in the Nature DQN paper.

    Without `nenvs`, creates one env with `gym.make(env_id)`, seeds it with
    `seed`, and applies, in order:
      - optional TFSummaries;
      - EpisodicLife;
      - FireReset, if the game has a FIRE action;
      - StartWithRandomActions(30);
      - MaxBetweenFrames;
      - SkipFrames(4);
      - 84x84 grayscale ImagePreprocessing;
      - QueueFrames(4);
      - optional ClipReward.

    With `nenvs`, builds a ParallelEnvBatch of `nenvs` such envs, created by
    recursive calls with summaries and reward clipping turned off. It then
    optionally adds TFSummaries and ClipReward on top of the batch. `seed`
    may then be None (seeds 0..nenvs-1), an int shared by all envs, or a
    sequence of exactly `nenvs` seeds.

    Raises:
        ValueError: if `env_id` lacks 'NoFrameskip' or the number of seeds
            differs from `nenvs`.
    """
    if "NoFrameskip" not in env_id:
        raise ValueError(f"env_id must have 'NoFrameskip' but is {env_id}")

    if nenvs is None:
        env = gym.make(env_id)
        env.seed(seed)
        if summaries:
            env = TFSummaries(env)
        env = EpisodicLife(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireReset(env)
        env = StartWithRandomActions(env, max_random_actions=30)
        env = MaxBetweenFrames(env)
        env = SkipFrames(env, 4)
        env = ImagePreprocessing(env, width=84, height=84, grayscale=True)
        env = QueueFrames(env, 4)
        if clip_reward:
            env = ClipReward(env)
        return env

    if seed is None:
        seed = list(range(nenvs))
    elif isinstance(seed, int):
        seed = [seed] * nenvs
    if len(seed) != nenvs:
        raise ValueError(f"seed has length {len(seed)} but must have "
                         f"length equal to nenvs which is {nenvs}")

    # Each factory binds its own seed through a default argument, avoiding
    # the late-binding-closure pitfall.
    factories = [
        lambda env_seed=env_seed: nature_dqn_env(
            env_id, seed=env_seed, summaries=False, clip_reward=False)
        for env_seed in seed
    ]
    batch = ParallelEnvBatch(factories)
    if summaries:
        batch = TFSummaries(batch, prefix=env_id)
    if clip_reward:
        batch = ClipReward(batch)
    return batch

In [0]:
import numpy as np
# Outside this notebook the wrapper would be imported from its module:
# from atari_wrappers import nature_dqn_env


# Smoke test: build a batch of 8 parallel SpaceInvaders environments with the
# Nature DQN preprocessing, then check the shape and dtype of the initial
# batched observation (8 envs x 84x84 frames x 4 stacked frames).
env = nature_dqn_env("SpaceInvadersNoFrameskip-v4", nenvs=8)
obs = env.reset()
assert obs.shape == (8, 84, 84, 4)
assert obs.dtype == np.uint8