<a href="https://colab.research.google.com/github/klinime/SAC/blob/master/sac_bipedalwalker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install box2d-py > /dev/null 2>&1
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1

In [0]:
%tensorflow_version 1.x

import gym
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_probability as tfp
import time
from datetime import datetime, timezone, timedelta
import pickle
from pyvirtualdisplay import Display

import logging
# Only show errors from pyvirtualdisplay (the headless display used when recording videos).
logging.getLogger('pyvirtualdisplay').setLevel(level=logging.ERROR)
from tensorflow.python.util import deprecation
# Silence TF deprecation warnings via a private module flag (may change between TF releases).
deprecation._PRINT_DEPRECATION_WARNINGS = False

from google.colab import drive
# Mount Google Drive; checkpoints, the log, the replay pool and videos are all written under PATH below.
drive.mount('/content/gdrive')
# Root directory for saved artifacts. Paths are built by string concatenation, so keep the trailing '/'.
PATH = '/content/gdrive/My Drive/Colab Notebooks/sac_checkpoints/'

In [0]:
class ReplayPool:
    """Fixed-capacity ring buffer that stores each transition field in its own numpy array.

    Every field named in ``fields`` becomes an attribute holding an array of
    shape ``[max_size] + list(shape)``. Once the pool is full, new samples
    overwrite the oldest ones.

    Args:
        max_size: capacity (converted with ``int`` so e.g. ``1e6`` is accepted).
        fields: mapping ``field_name -> attrs`` where ``attrs['shape']`` is the
            per-sample shape and the optional ``attrs['dtype']`` /
            ``attrs['initializer']`` control storage (default: float64 zeros).
    """

    def __init__(self, max_size, fields):
        max_size = int(max_size)
        self._max_size = max_size

        self.fields = {}
        self.field_names = []
        self.add_fields(fields)

        self._pointer = 0  # next slot to write
        self._size = 0     # number of valid samples stored

    @property
    def size(self):
        """Number of valid samples currently stored (never more than ``max_size``)."""
        return self._size

    def add_fields(self, fields):
        """Register ``fields`` and allocate a backing array for each one.

        A field's ``dtype`` is honoured so that e.g. float32 observations do not
        silently take float64 storage (twice the memory).
        """
        self.fields.update(fields)
        self.field_names += list(fields.keys())

        for field_name, field_attrs in fields.items():
            field_shape = [self._max_size] + list(field_attrs['shape'])
            dtype = field_attrs.get('dtype')
            initializer = field_attrs.get('initializer')
            if initializer is None:
                # dtype=None falls back to numpy's float64 default.
                storage = np.zeros(field_shape, dtype=dtype)
            else:
                storage = initializer(field_shape)
                if dtype is not None:
                    storage = np.asarray(storage, dtype=dtype)
            setattr(self, field_name, storage)

    def _advance(self, count=1):
        """Move the write pointer forward by ``count`` (wrapping) and grow ``size`` up to capacity."""
        self._pointer = (self._pointer + count) % self._max_size
        self._size = min(self._size + count, self._max_size)

    def add_sample(self, **kwargs):
        """Add one sample; pass one keyword argument per registered field."""
        self.add_samples(1, **kwargs)

    def add_samples(self, num_samples=1, **kwargs):
        """Write ``num_samples`` samples starting at the write pointer, wrapping around.

        Raises:
            KeyError: if a registered field is missing from ``kwargs``.
        """
        # The destination slots are the same for every field, so compute them once.
        idx = np.arange(self._pointer, self._pointer + num_samples) % self._max_size
        for field_name in self.field_names:
            getattr(self, field_name)[idx] = kwargs.pop(field_name)

        self._advance(num_samples)

    def random_indices(self, batch_size):
        """Return ``batch_size`` indices drawn uniformly (with replacement) from stored samples.

        Returns an empty list when the pool is empty.
        """
        if self._size == 0:
            return []
        return np.random.randint(0, self._size, batch_size)

    def random_batch(self, batch_size, field_name_filter=None):
        """Return a dict ``field_name -> array`` for a uniformly random batch."""
        random_indices = self.random_indices(batch_size)
        return self.batch_by_indices(random_indices, field_name_filter)

    def batch_by_indices(self, indices, field_name_filter=None):
        """Gather ``indices`` from every field accepted by ``field_name_filter`` (all if None)."""
        field_names = self.field_names
        if field_name_filter is not None:
            field_names = [
                field_name for field_name in field_names
                if field_name_filter(field_name)
            ]

        return {
            field_name: getattr(self, field_name)[indices]
            for field_name in field_names
        }

    def save(self, path=None):
        """Pickle the whole pool to ``path`` (default: ``PATH + 'replay_pool.pkl'``)."""
        if path is None:
            path = PATH + 'replay_pool.pkl'
        with open(path, 'wb') as file:
            pickle.dump(self, file, -1)
            print('ReplayPool saved.')

    def load(self, path=None):
        """Restore this pool in place from a file written by :meth:`save`.

        Args:
            path: pickle file to read (default: ``PATH + 'replay_pool.pkl'``).

        NOTE: ``pickle.load`` can execute arbitrary code; only load files you wrote.
        """
        if path is None:
            path = PATH + 'replay_pool.pkl'
        with open(path, 'rb') as file:
            loaded = pickle.load(file)
        # Assigning to `self` would only rebind a local name and discard the data,
        # so copy the unpickled state (arrays, pointer, size, fields) into this object.
        self.__dict__.update(loaded.__dict__)
        print('ReplayPool loaded.')


class SimpleReplayPool(ReplayPool):
    """Replay pool pre-configured with the standard transition fields.

    Stores ``observations``, ``next_observations`` and ``actions`` with the given
    shapes, plus a scalar ``rewards`` and ``terminals`` entry per sample. Any
    remaining positional/keyword arguments (e.g. ``max_size``) are forwarded to
    ``ReplayPool``.
    """

    def __init__(self, observation_shape, action_shape, *args, **kwargs):
        self._observation_shape = observation_shape
        self._action_shape = action_shape

        def field(shape, dtype):
            # One fresh spec dict per field, in the {'shape', 'dtype'} format ReplayPool reads.
            return {'shape': shape, 'dtype': dtype}

        fields = {
            'observations': field(self._observation_shape, 'float32'),
            # Observations are stored twice (current and next) on purpose: it costs
            # memory but avoids having to reason about episode boundaries.
            'next_observations': field(self._observation_shape, 'float32'),
            'actions': field(self._action_shape, 'float32'),
            'rewards': field([], 'float32'),
            # terminals[i] is set when a terminal was received at step i.
            'terminals': field([], 'bool'),
        }

        super().__init__(*args, fields=fields, **kwargs)


class Sampler():
    """Base class that steps an environment with a policy and stores transitions in a pool.

    Args:
        max_episode_length: episodes are cut off after this many steps.
        prefill_steps: number of uniformly random steps taken by ``initialize``
            to seed the pool before training.
    """
    def __init__(self, max_episode_length, prefill_steps=10000):
        self._max_episode_length = max_episode_length
        self._prefill_steps = prefill_steps

        self.env = None
        self.policy = None
        self.pool = None

    def initialize(self, env, policy, pool):
        """Attach ``env``/``policy``/``pool`` and prefill the pool with random-action steps."""
        self.env = env
        self.policy = policy
        self.pool = pool

        class UniformPolicy:
            """Ignores the observation and samples each action component uniformly in [-1, 1]."""
            def __init__(self, action_dim):
                self._action_dim = action_dim

            def eval(self, _):
                # NOTE(review): assumes the env's action bounds are [-1, 1]; confirm for other envs.
                return np.random.uniform(-1, 1, self._action_dim)

        uniform_exploration_policy = UniformPolicy(env.action_space.shape[0])
        for _ in range(self._prefill_steps):
            self.sample(uniform_exploration_policy)

    def set_policy(self, policy):
        """Replace the policy used by ``sample`` when none is passed explicitly."""
        self.policy = policy

    def sample(self, policy=None):
        """Take one environment step with ``policy`` (default: ``self.policy``).

        Abstract. The signature matches how ``initialize`` calls it and how
        subclasses override it.
        """
        raise NotImplementedError

    def random_batch(self, batch_size):
        """Return a uniformly random batch from the attached pool."""
        return self.pool.random_batch(batch_size)

    def terminate(self):
        """Release the environment.

        Gym environments expose ``close()`` rather than ``terminate()``. Call
        ``terminate`` only if the env actually defines it.
        """
        terminate_env = getattr(self.env, 'terminate', None)
        if callable(terminate_env):
            terminate_env()
        else:
            self.env.close()


class SimpleSampler(Sampler):
    """Single-environment sampler: each ``sample`` call takes one step and stores it in the pool.

    Also tracks episode statistics: current length/return, last and best
    episode return, number of finished episodes and total steps taken.
    """
    def __init__(self, **kwargs):
        super(SimpleSampler, self).__init__(**kwargs)

        self._episode_length = 0
        self._episode_return = 0
        self._last_episode_return = 0
        self._max_episode_return = -np.inf
        self._n_episodes = 0
        self._current_observation = None  # None until the first reset
        self._total_samples = 0

    def sample(self, policy=None):
        """Step the env once with ``policy`` (default: ``self.policy``) and store the transition.

        An episode ends when the env reports ``done`` or after
        ``max_episode_length`` steps. In both cases the env is reset for the next call.
        """
        policy = self.policy if policy is None else policy
        if self._current_observation is None:
            self._current_observation = self.env.reset()

        action = policy.eval(self._current_observation)
        next_observation, reward, terminal, info = self.env.step(action)
        self._episode_length += 1
        self._episode_return += reward
        self._total_samples += 1

        self.pool.add_sample(
            observations=self._current_observation,
            actions=action,
            rewards=reward,
            terminals=self._is_true_terminal(terminal, info),
            next_observations=next_observation)

        if terminal or self._episode_length >= self._max_episode_length:
            self._current_observation = self.env.reset()
            self._episode_length = 0
            self._max_episode_return = max(self._max_episode_return,
                                           self._episode_return)
            self._last_episode_return = self._episode_return

            self._episode_return = 0
            self._n_episodes += 1

        else:
            self._current_observation = next_observation

    @staticmethod
    def _is_true_terminal(terminal, info):
        """Return ``terminal`` unless the episode was only cut off by a gym time limit.

        A time-limit cut-off is not an absorbing state. Storing it as terminal
        would stop the critic from bootstrapping off the next state's value.
        This relies on the ``'TimeLimit.truncated'`` info key set by gym's
        TimeLimit wrapper. Envs or gym versions that don't set it behave exactly
        as before.
        """
        truncated = isinstance(info, dict) and bool(info.get('TimeLimit.truncated', False))
        return bool(terminal) and not truncated

In [0]:
def build_mlp_model(name, input_size, output_size, n_layers, size, activation='relu', output_activation=None):
  """Build a fully connected Keras network.

  Args:
    name: model name.
    input_size: length of each input vector.
    output_size: number of units in the final layer.
    n_layers: number of hidden layers; at least one is always created.
    size: units per hidden layer.
    activation: activation of the hidden layers.
    output_activation: activation of the final layer (None = linear).

  Returns:
    A `keras.Sequential` model.
  """
  model = keras.Sequential(name=name)
  for layer_idx in range(max(n_layers, 1)):
    # Only the first hidden layer declares the input shape.
    shape_kwargs = {'input_shape': (input_size,)} if layer_idx == 0 else {}
    model.add(keras.layers.Dense(size, activation=activation, **shape_kwargs))
  model.add(keras.layers.Dense(output_size, activation=output_activation))
  return model

class GaussianPolicy(keras.Model):
  """Tanh-squashed diagonal Gaussian policy.

  The wrapped MLP maps observations to `output_size` values, which are split in
  half along axis 1. The first half is the Gaussian mean and the second half is
  the log standard deviation, clipped to [-20, 2]. Actions are `tanh` of a
  sample from that Gaussian, so `output_size` should be twice the action
  dimension.
  """
  def __init__(self, name, input_size, output_size, n_layers, size, activation='relu', output_activation=None):
    super(GaussianPolicy, self).__init__(name=name)
    # Cached Keras backend function used by `eval`; built on first use.
    self._f = None
    self.mlp = build_mlp_model(name + '_mlp', input_size, output_size, n_layers, size,
                               activation, output_activation)

  def call(self, obs_no):
    """Sample squashed actions for a batch of observations.

    Returns:
      (actions, log_probs): `tanh` of the Gaussian sample, and that sample's
      Gaussian log-density minus the `tanh` change-of-variables correction.
    """
    net_out = self.mlp(obs_no)
    mean, log_std = tf.split(net_out, num_or_size_splits=2, axis=1)
    std = tf.exp(tf.clip_by_value(log_std, -20., 2.))
    gaussian = tfp.distributions.MultivariateNormalDiag(loc=mean, scale_diag=std)
    pre_tanh = gaussian.sample()
    log_prob = gaussian.log_prob(pre_tanh) - self.squash_correction(pre_tanh)
    return tf.tanh(pre_tanh), log_prob

  def squash_correction(self, acs_na):
    """Per-sample sum over action dims of log(1 - tanh(u)^2) for pre-tanh actions u.

    Uses the identity log(1 - tanh(u)^2) = 2 * (log 2 + u - softplus(2u)) so that
    1 - tanh(u)^2 is never computed directly.
    """
    per_dim = np.log(2) + acs_na - tf.nn.softplus(2 * acs_na)
    return 2 * tf.reduce_sum(per_dim, axis=1)

  def eval(self, observation):
    """Return one sampled (stochastic) squashed action for a single 1-D observation.

    NOTE(review): relies on Keras having populated `self.inputs`/`self.outputs`
    after the model was called on a symbolic input; confirm for the TF version in use.
    """
    assert self.built and observation.ndim == 1
    if self._f is None:
      sampled_action = self.outputs[0]
      self._f = keras.backend.function(self.inputs, [sampled_action])
    fetched = self._f([observation[None]])
    return fetched[0].flatten()

In [0]:
class SAC_Trainer():
  """Soft Actor-Critic trainer (variant with a state-value network V) on a TF1 graph.

  Builds a tanh-squashed Gaussian policy, a value network V with a Polyak-averaged
  target copy, and two Q networks (their minimum is used). It then alternates one
  environment step with one gradient step on the policy, V, Q1 and Q2.

  Args:
    sess: `tf.Session` used for training, checkpointing and restoring.
    params: hyperparameter dict. Reads 'env', 'alpha', 'gamma', 'tau', 'start_iter',
      'n_iter', 'batch_size', 'ep_len', 'log_freq', 'learning_rate', 'n_layers',
      'size' and, optionally, 'env_name' (used to create the video env).
  """
  def __init__(self, sess, params):
    self.sess = sess
    # Keep our own reference so graph construction doesn't read a notebook-global `params`.
    self.params = params
    self.env = params['env']
    self.ob_dim = self.env.observation_space.shape[0]
    self.ac_dim = self.env.action_space.shape[0]

    self.alpha = params['alpha']            # entropy coefficient on log pi
    self.gamma = params['gamma']            # discount on the next-state value
    self.tau = params['tau']                # Polyak rate for the target V network
    self.start_iter = params['start_iter']
    self.n_iter = params['n_iter']
    self.batch_size = params['batch_size']
    self.ep_len = params['ep_len']          # env/gradient steps per iteration

    self.build()
    self.log = None  # rows of [eval_reward, seconds_elapsed]
    self.log_freq = params['log_freq']
    # Enough checkpoints to keep every periodic save plus the final one.
    self.saver = tf.train.Saver(max_to_keep=self.n_iter//self.log_freq+1)

  def build(self):
    """Create models, placeholders, losses, per-model Adam updates and target-update ops."""
    models = self.define_models()
    self.define_placeholders()
    losses = self.define_losses()
    optimizer = tf.train.AdamOptimizer(learning_rate=self.params['learning_rate'], name='optimizer')
    # zip stops after the four losses, so the target V network (last model) gets no
    # gradient op; it is only moved by `target_update_ops`.
    self.training_ops = [optimizer.minimize(loss=loss, var_list=model.trainable_variables)
                         for loss, model in zip(losses, models)]
    self.target_update_ops = self.define_target_update()

  def define_models(self):
    """Create the networks; returns (policy, V, Q1, Q2, target V)."""
    n_layers, size = self.params['n_layers'], self.params['size']
    # The policy network outputs a mean and a log-std per action dimension.
    self.policy = GaussianPolicy('policy', self.ob_dim, self.ac_dim * 2, n_layers, size)
    self.v_func = build_mlp_model('v_func', self.ob_dim, 1, n_layers, size)
    self.q_func1 = build_mlp_model('q_func1', self.ob_dim + self.ac_dim, 1, n_layers, size)
    self.q_func2 = build_mlp_model('q_func2', self.ob_dim + self.ac_dim, 1, n_layers, size)
    self.target_v_func = build_mlp_model('target_v_func', self.ob_dim, 1, n_layers, size)
    return self.policy, self.v_func, self.q_func1, self.q_func2, self.target_v_func

  def define_placeholders(self):
    """Placeholders for a replay batch (suffixes: n = batch, o = obs dim, a = action dim)."""
    self.obs_no_ph = tf.placeholder(tf.float32, shape=(None, self.ob_dim), name='obs_no')
    self.acs_na_ph = tf.placeholder(tf.float32, shape=(None, self.ac_dim), name='acs_na')
    self.rews_n_ph = tf.placeholder(tf.float32, shape=(None, ), name='rews_n')
    self.next_obs_no_ph = tf.placeholder(tf.float32, shape=(None, self.ob_dim), name='next_obs_no')
    self.terminals_n_ph = tf.placeholder(tf.float32, shape=(None, ), name='terminals_n')

  def define_losses(self):
    """Build the SAC losses.

    Returns:
      (policy_loss, v_func_loss, q_func1_loss, q_func2_loss), in the same order as
      the first four models returned by `define_models`.
    """
    # Fresh actions (and their log-probs) from the current policy.
    acs_na, logp_n = self.policy(self.obs_no_ph)

    # Q-values of the replayed actions: regression targets for the Q networks.
    q_input = tf.concat([self.obs_no_ph, self.acs_na_ph], axis=1)
    q1_n = tf.squeeze(self.q_func1(q_input), axis=1)
    q2_n = tf.squeeze(self.q_func2(q_input), axis=1)

    # Q-values of the policy's own actions. The V target and the policy objective must
    # be evaluated at a ~ pi(.|s) so the policy gradient flows through Q. Using the
    # replayed actions made the policy loss independent of Q.
    policy_q_input = tf.concat([self.obs_no_ph, acs_na], axis=1)
    policy_q_n = tf.minimum(tf.squeeze(self.q_func1(policy_q_input), axis=1),
                            tf.squeeze(self.q_func2(policy_q_input), axis=1))

    v_n = tf.squeeze(self.v_func(self.obs_no_ph), axis=1)
    target_v_n = tf.stop_gradient(policy_q_n - self.alpha * logp_n)
    next_v_n = tf.squeeze(self.target_v_func(self.next_obs_no_ph), axis=1)
    # Terminal transitions do not bootstrap from the next state's value.
    target_q_n = tf.stop_gradient(self.rews_n_ph + self.gamma * next_v_n * (1 - self.terminals_n_ph))

    self.policy_loss = tf.reduce_mean(self.alpha * logp_n - policy_q_n)
    self.v_func_loss = tf.losses.mean_squared_error(target_v_n, v_n)
    self.q_func1_loss = tf.losses.mean_squared_error(target_q_n, q1_n)
    self.q_func2_loss = tf.losses.mean_squared_error(target_q_n, q2_n)
    return self.policy_loss, self.v_func_loss, self.q_func1_loss, self.q_func2_loss

  def define_target_update(self):
    """Ops that move each target-V variable a fraction `tau` toward its V counterpart."""
    return [tf.assign(target, (1 - self.tau) * target + self.tau * source)
            for target, source in zip(self.target_v_func.trainable_variables, self.v_func.trainable_variables)]

  def eval_reward(self, sampler):
    """Sum of rewards in one random replay batch (a cheap proxy, not an on-policy evaluation)."""
    return np.sum(sampler.random_batch(self.batch_size)['rewards'])

  def save(self, sampler, i=None):
    """Checkpoint the session, write the log, save the replay pool and record a video.

    Args:
      sampler: sampler whose pool is saved.
      i: iteration number added to the checkpoint filename (optional).
    """
    print('\nSaving session...')
    filename = 'iter_{:04d}_'.format(i) if i is not None else ''
    pst_tz = timezone(timedelta(hours=-8), name='PST')
    filename = filename + datetime.now(tz=pst_tz).strftime('%Y_%m_%d_%H_%M_%S') + '.ckpt'
    self.saver.save(self.sess, PATH + filename)
    np.save(PATH + 'log.npy', self.log)
    print('Session saved.')
    sampler.pool.save()
    self.log_video(i)

  def load(self, filename):
    """Restore variables from checkpoint `PATH + filename` and reload the log."""
    self.saver.restore(self.sess, PATH + filename)
    self.log = np.load(PATH + 'log.npy')
    print('Session restored.')

  def _make_eval_env(self):
    """Create a fresh env with the same id as the training env, used for video rollouts."""
    env_id = self.params.get('env_name')
    if env_id is None:
      env_id = self.env.spec.id
    return gym.make(env_id)

  def log_video(self, i):
    """Record one rollout of the current policy (up to `ep_len` steps) and print its return.

    Args:
      i: iteration number naming the video directory; None writes to 'video/latest'.
    """
    print('\nLogging video...')
    if getattr(self, '_display', None) is None:
      # Start a single virtual display and reuse it for every video, instead of
      # starting (and leaking) a new one per call. NOTE(review): it is intentionally
      # never stopped; presumably the renderer keeps its connection to the first display.
      self._display = Display(visible=0, size=(1400, 900))
      self._display.start()
    video_dir = PATH + ('video/{:04d}'.format(i) if i is not None else 'video/latest')
    rewards = []
    # Use a separate env. Wrapping, resetting and closing the training env would
    # disrupt the episode the sampler is in the middle of.
    env = gym.wrappers.Monitor(self._make_eval_env(), video_dir,
                               video_callable=lambda episode_id: True, force=True)
    try:
      ob = env.reset()
      for t in range(self.ep_len):
        env.render()
        ac = self.policy.eval(ob)
        ob, rew, done, _ = env.step(ac)
        rewards.append(rew)
        if done:
          print('Episode finished in {} steps.'.format(t+1))
          break
      print('Total reward: {:.3f}'.format(np.sum(rewards)))
    finally:
      env.close()
    print('Logging complete.')

  def run_training_loop(self, sampler):
    """Run iterations `start_iter .. start_iter + n_iter - 1`.

    Each iteration takes `ep_len` environment steps, each followed by one training
    step and one target update. Every `log_freq` iterations it logs and checkpoints,
    and it saves once more at the end.
    """
    start_time = time.time()
    end_iter = self.start_iter + self.n_iter
    for itr in range(self.start_iter, end_iter):
      if itr % self.log_freq == 0:
        print('\n====================Iter {}/{}===================='.format(itr+1, end_iter))
      for t in range(self.ep_len):
        sampler.sample()
        batch = sampler.random_batch(self.batch_size)
        feed_dict = {
            self.obs_no_ph: batch['observations'],
            self.acs_na_ph: batch['actions'],
            self.rews_n_ph: batch['rewards'],
            self.next_obs_no_ph: batch['next_observations'],
            self.terminals_n_ph: batch['terminals'],
        }
        self.sess.run(self.training_ops, feed_dict=feed_dict)
        # Separate run so the Polyak update happens after this step's gradient updates.
        self.sess.run(self.target_update_ops)

      if itr % self.log_freq == 0:
        print('\nBeginning evaluation...')
        reward = self.eval_reward(sampler)
        row = [[reward, time.time() - start_time]]
        self.log = np.array(row) if self.log is None else np.concatenate([self.log, row])
        print('EvalReward: {:.3f}'.format(self.log[-1][0]))
        print('TimeElapsed: {:.3f}s'.format(self.log[-1][1]))
        self.save(sampler, itr)
    self.save(sampler, end_iter)
    print('Loop completed. Total time: {}'.format(time.time() - start_time))

In [0]:
# Run configuration. To resume, set 'load_checkpoint' to a checkpoint filename under
# PATH; 'start_iter' sets the first iteration index used for logging and filenames.
params = {'env_name': 'BipedalWalker-v2',
          'seed': 0,
          'start_iter': 0,
          'n_iter': 1000,
          'batch_size': 256,
          'ep_len': 1600,
          'alpha': 0.2,
          'gamma': 0.99,
          'tau': 0.01,
          'learning_rate': 5e-4,
          'n_layers': 2,
          'size': 256,
          'log_freq': 50,
          'load_checkpoint': None}

with tf.Session() as sess:
  seed = params['seed']
  tf.set_random_seed(seed)
  np.random.seed(seed)
  env = gym.make(params['env_name'])
  env.seed(seed)
  params['env'] = env

  try:
    trainer = SAC_Trainer(sess, params)
    replay_pool = SimpleReplayPool(
        observation_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        max_size=1000000)
    if params['load_checkpoint']:
      # Resuming: variables and the replay pool are restored, so skip random prefill.
      trainer.load(params['load_checkpoint'])
      replay_pool.load()
      sampler = SimpleSampler(max_episode_length=params['ep_len'], prefill_steps=0)
    else:
      sess.run(tf.global_variables_initializer())
      # Fresh run: use the sampler's default random-action prefill.
      sampler = SimpleSampler(max_episode_length=params['ep_len'])
    sampler.initialize(env, trainer.policy, replay_pool)
    trainer.run_training_loop(sampler)
  finally:
    # Release the training env even if training is interrupted or fails.
    env.close()

In [0]:
import io
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay

def render_video(filename):
  """Embed an MP4 stored under PATH in the notebook output as a looping, autoplaying video.

  Args:
    filename: path relative to PATH, e.g. 'video/0000/<name>.mp4'.
  """
  # Read-only binary mode; the context manager closes the handle (it used to leak).
  with io.open(PATH + filename, 'rb') as video_file:
    video = video_file.read()
  encoded = base64.b64encode(video)
  ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
              loop controls style="height: 400px;">
              <source src="data:video/mp4;base64,{0}" type="video/mp4" />
            </video>'''.format(encoded.decode('ascii'))))

In [0]:
# render_video('video/0/openaigym.video.0.5301.video000000.mp4')