In [21]:
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd

from gym.wrappers.monitor import load_results
from copy import deepcopy

#REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))


class Config:
    def __init__(self, **kwargs):
        # read parameters from parents, and children can override the values.
        parents = []
        queue = [self.__class__]
        while queue:
            parent = queue.pop()
            if issubclass(parent, Config) and parent is not Config:
                parents.append(parent)
                for p in reversed(parent.__bases__):
                    queue.append(p)

        params = {}
        for cfg in reversed(parents):
            params.update(cfg.__dict__)

        # Set all instance variable based on kwargs and default class variables
        for key, value in params.items():
            if key.startswith('__'):
                continue

            if key in kwargs:
                # override default with provided parameter
                value = kwargs[key]
            else:
                # Need to make copies of class variables so that they aren't changed by instances
                value = deepcopy(value)

            self.__dict__[key] = value

    def __setattr__(self, name, value):
        if name not in self.__dict__:
            raise AttributeError(f"{self.__class__.__name__} does not have attribute {name}")
        self.__dict__[name] = value

    def __getattr__(self, name):
        # Raise error on assignment of missing variable
        if name not in self.__dict__:
            raise AttributeError(f"{self.__class__.__name__} does not have attribute {name}")
        return self.__dict__[name]

    def as_dict(self):
        return deepcopy(self.__dict__)

    def copy(self):
        return self.__class__(**self.as_dict())

    def get(self, name, default):
        return self.as_dict().get(name, default)

    def __repr__(self):
        return super().__repr__() + "\n" + self.dumps()


def plot_learning_curve(filename, value_dict, xlabel='step'):
    # Plot step vs the mean(last 50 episodes' rewards)
    fig = plt.figure(figsize=(12, 4 * len(value_dict)))

    for i, (key, values) in enumerate(value_dict.items()):
        ax = fig.add_subplot(len(value_dict), 1, i + 1)
        ax.plot(range(len(values)), values)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(key)
        ax.grid('k--', alpha=0.6)

    plt.tight_layout()
    os.makedirs(os.path.join(REPO_ROOT, 'figs'), exist_ok=True)
    plt.savefig(os.path.join(REPO_ROOT, 'figs', filename))


def plot_from_monitor_results(monitor_dir, window=10):
    assert os.path.exists(monitor_dir)
    if monitor_dir.endswith('/'):
        monitor_dir = monitor_dir[:-1]

    data = load_results(monitor_dir)
    n_episodes = len(data['episode_lengths'])
    assert n_episodes > 0

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), tight_layout=True, sharex=True)

    ax1.plot(range(n_episodes), pd.rolling_mean(np.array(data['episode_lengths']), window))
    ax1.set_xlabel('episode')
    ax1.set_ylabel('episode length')
    ax1.grid('k--', alpha=0.6)

    ax2.plot(range(n_episodes), pd.rolling_mean(np.array(data['episode_rewards']), window))
    ax2.set_xlabel('episode')
    ax2.set_ylabel('episode reward')
    ax2.grid('k--', alpha=0.6)

    os.makedirs(os.path.join(REPO_ROOT, 'figs'), exist_ok=True)
    plt.savefig(os.path.join(REPO_ROOT, 'figs', os.path.basename(monitor_dir) + '-monitor'))

In [22]:
import os

import numpy as np
import tensorflow as tf
from gym.spaces import Box, Discrete
from gym.utils import colorize

#from playground.utils.misc import Config
#from playground.utils.misc import REPO_ROOT


class TrainConfig(Config):
    lr = 0.001
    n_steps = 10000
    warmup_steps = 5000
    batch_size = 64
    log_every_step = 1000

    # give an extra bonus if done; only needed for certain tasks.
    done_reward = None


class Policy:
    def __init__(self, env, name, training=True, gamma=0.99, deterministic=False):
        self.env = env
        self.gamma = gamma
        self.training = training
        self.name = name

        if deterministic:
            np.random.seed(1)
            tf.set_random_seed(1)

    @property
    def act_size(self):
        # number of options of an action; this only makes sense for discrete actions.
        if isinstance(self.env.action_space, Discrete):
            return self.env.action_space.n
        else:
            return None

    @property
    def act_dim(self):
        # dimension of an action; this only makes sense for continuous actions.
        if isinstance(self.env.action_space, Box):
            return list(self.env.action_space.shape)
        else:
            return []

    @property
    def state_dim(self):
        # dimension of a state.
        return list(self.env.observation_space.shape)

    def obs_to_inputs(self, ob):
        return ob.flatten()

    def act(self, state, **kwargs):
        pass

    def build(self):
        pass

    def train(self, *args, **kwargs):
        pass

    def evaluate(self, n_episodes):
        reward_history = []
        reward = 0.

        for i in range(n_episodes):
            ob = self.env.reset()
            done = False
            while not done:
                a = self.act(ob)
                new_ob, r, done, _ = self.env.step(a)
                self.env.render()
                reward += r
                ob = new_ob

            reward_history.append(reward)
            reward = 0.

        print("Avg. reward over {} episodes: {:.4f}".format(n_episodes, np.mean(reward_history)))


class BaseModelMixin:
    """Abstract object representing an tensorflow model that can be easily saved/loaded.
    Modified based on https://github.com/devsisters/DQN-tensorflow/blob/master/dqn/base.py
    """

    def __init__(self, model_name, tf_sess_config=None):
        self._saver = None
        self._writer = None
        self._model_name = model_name
        self._sess = None

        if tf_sess_config is None:
            tf_sess_config = {
                'allow_soft_placement': True,
                'intra_op_parallelism_threads': 8,
                'inter_op_parallelism_threads': 4,
            }
        self.tf_sess_config = tf_sess_config

    def scope_vars(self, scope, only_trainable=True):
        collection = tf.GraphKeys.TRAINABLE_VARIABLES if only_trainable else tf.GraphKeys.VARIABLES
        variables = tf.get_collection(collection, scope=scope)
        assert len(variables) > 0
        print(f"Variables in scope '{scope}':")
        for v in variables:
            print("\t" + str(v))
        return variables

    def get_variable_values(self):
        t_vars = tf.trainable_variables()
        vals = self.sess.run(t_vars)
        return {v.name: value for v, value in zip(t_vars, vals)}

    def save_checkpoint(self, step=None):
        print(colorize(" [*] Saving checkpoints...", "green"))
        ckpt_file = os.path.join(self.checkpoint_dir, self.model_name)
        self.saver.save(self.sess, ckpt_file, global_step=step)

    def load_checkpoint(self):
        print(colorize(" [*] Loading checkpoints...", "green"))
        ckpt_path = tf.train.latest_checkpoint(self.checkpoint_dir)
        print(self.checkpoint_dir)
        print("ckpt_path:", ckpt_path)

        if ckpt_path:
            # self._saver = tf.train.import_meta_graph(ckpt_path + '.meta')
            self.saver.restore(self.sess, ckpt_path)
            print(colorize(" [*] Load SUCCESS: %s" % ckpt_path, "green"))
            return True
        else:
            print(colorize(" [!] Load FAILED: %s" % self.checkpoint_dir, "red"))
            return False

    def _get_dir(self, dir_name):
        path = os.path.join(REPO_ROOT, dir_name, self.model_name)
        os.makedirs(path, exist_ok=True)
        return path

    @property
    def log_dir(self):
        return self._get_dir('logs')

    @property
    def checkpoint_dir(self):
        return self._get_dir('checkpoints')

    @property
    def model_dir(self):
        return self._get_dir('models')

    @property
    def tb_dir(self):
        # tensorboard
        return self._get_dir('tb')

    @property
    def model_name(self):
        assert self._model_name, "Not a valid model name."
        return self._model_name

    @property
    def saver(self):
        if self._saver is None:
            self._saver = tf.train.Saver(max_to_keep=5)
        return self._saver

    @property
    def writer(self):
        if self._writer is None:
            self._writer = tf.summary.FileWriter(self.tb_dir, self.sess.graph)
        return self._writer

    @property
    def sess(self):
        if self._sess is None:
            config = tf.ConfigProto(**self.tf_sess_config)
            self._sess = tf.Session(config=config)

        return self._sess

In [7]:
# The prediction by the primary Q network for the actual actions.
action_one_hot = tf.one_hot(actions, env.action_space.n, 1.0, 0.0, name='action_one_hot')
pred = tf.reduce_sum(q * action_one_hot, reduction_indices=-1, name='q_acted')

# The optimization target defined by the Bellman equation and the target network.
max_q_next_by_target = tf.reduce_max(q_target, axis=-1)
y = rewards + (1. - done_flags) * gamma * max_q_next_by_target

# The loss measures the mean squared error between prediction and target.
loss = tf.reduce_mean(tf.square(pred - tf.stop_gradient(y)), name="loss_mse_train")
optimizer = tf.train.AdamOptimizer(0.001).minimize(loss, name="adam_optim")

In [8]:
# Get all the variables in the Q primary network.
q_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="Q_primary")
# Get all the variables in the Q target network.
q_target_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="Q_target")
assert len(q_vars) == len(q_target_vars)

def update_target_q_net_hard():
    # Hard update
    sess.run([v_t.assign(v) for v_t, v in zip(q_target_vars, q_vars)])

def update_target_q_net_soft(tau=0.05):
    # Soft update: polyak averaging.
    sess.run([v_t.assign(v_t * (1. - tau) + v * tau) for v_t, v in zip(q_target_vars, q_vars)])

AssertionError: 