In [1]:
import es_distributed.tf_util as U
import tensorflow as tf

from es_distributed.policies import catcher, CatchPolicy
from es_distributed.es import *
from es_distributed import policies


exp = {
  "config": {
    "calc_obstat_prob": 0.0,
    "episodes_per_batch": 10000,
    "eval_prob": 0.03,
    "l2coeff": 0.005,
    "noise_stdev": 0.02,
    "snapshot_freq": 5,
    "timesteps_per_batch": 100000,
    "return_proc_mode": "centered_rank",
    "episode_cutoff_mode": "env_default"
  },
  "env_id": "catcher",
  "exp_prefix": "humanoid",
  "optimizer": {
    "args": {
      "stepsize": 0.01
    },
    "type": "adam"
  },
  "policy": {
    "args": {
      "connection_type": "ff",
      "hidden_dims": [
        100,
        100
      ],
      "nonlin_type": "tanh"
    },
    "type": "CatchPolicy"
  }
}


config, env, sess, policy = setup(exp, single_threaded=False)
policy.initialize_from('/Users/xxx/snapshot.h5')
policy.rollout(env)

Instructions for updating:
Please use tf.global_variables instead.
Instructions for updating:
Use `tf.variables_initializer` instead.


(array([    0.,     0.,     0.,     0.,     0.,     0.,     0.,  1500.], dtype=float32),
 8)

for i in range(100):
    print(policy.rollout(env)[0][-1])

In [2]:
nextQ = tf.placeholder(shape=policy.a.shape,dtype=tf.float32)
loss = tf.reduce_mean(tf.square(nextQ - policy.a))
trainer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
updateModel = trainer.minimize(loss)


import json
import numpy as np

class Catch(object):
    def __init__(self, grid_size=10):
        self.grid_size = grid_size
        self.reset()

    def _update_state(self, action):
        """
        Input: action and states
        Ouput: new states and reward
        """
        state = self.state
        if action == 0:  # left
            action = -1
        elif action == 1:  # stay
            action = 0
        else:
            action = 1  # right
        f0, f1, basket = state[0]
        new_basket = min(max(1, basket + action), self.grid_size-1)
        f0 += 1
        out = np.asarray([f0, f1, new_basket])
        out = out[np.newaxis]

        assert len(out.shape) == 2
        self.state = out

    def _draw_state(self):
        im_size = (self.grid_size,)*2
        state = self.state[0]
        canvas = np.zeros(im_size)
        canvas[state[0], state[1]] = 1  # draw fruit
        canvas[-1, state[2]-1:state[2] + 2] = 1  # draw basket
        return canvas

    def _get_reward(self):
        fruit_row, fruit_col, basket = self.state[0]
        if fruit_row == self.grid_size-1:
            if abs(fruit_col - basket) <= 1:
                return 1000
            else:
                return -1000
        else:
            return 0

    def _is_over(self):
        if self.state[0, 0] == self.grid_size-1:
            return True
        else:
            return False

    def observe(self):
        canvas = self._draw_state()
        return canvas.reshape((1, -1))

    def act(self, action):
        self._update_state(action)
        reward = self._get_reward()
        game_over = self._is_over()
        return self.observe(), reward, game_over

    def reset(self):
        n = np.random.randint(0, self.grid_size-1, size=1)
        m = np.random.randint(1, self.grid_size-2, size=1)
        self.state = np.asarray([0, n, m])[np.newaxis]


class ExperienceReplay(object):
    def __init__(self, max_memory=100, discount=.9):
        self.max_memory = max_memory
        self.memory = list()
        self.discount = discount

    def remember(self, states, game_over):
        # memory[i] = [[state_t, action_t, reward_t, state_t+1], game_over?]
        self.memory.append([states, game_over])
        if len(self.memory) > self.max_memory:
            del self.memory[0]

    def get_batch(self, policy, batch_size=10):
        len_memory = len(self.memory)
        num_actions = int(policy.a.shape[-1])
        env_dim = self.memory[0][0][0].shape[1]
        inputs = np.zeros((min(len_memory, batch_size), env_dim))
        targets = np.zeros((inputs.shape[0], num_actions))
        for i, idx in enumerate(np.random.randint(0, len_memory,
                                                  size=inputs.shape[0])):
            state_t, action_t, reward_t, state_tp1 = self.memory[idx][0]
            game_over = self.memory[idx][1]

            inputs[i:i+1] = state_t
            # There should be no target values for actions not taken.
            # Thou shalt not correct actions not taken #deep
            targets[i] = policy.act(state_t)[0]
            Q_sa = np.max(policy.act(state_tp1)[0])
            if game_over:  # if game_over is True
                targets[i, action_t] = reward_t
            else:
                # reward_t + gamma * max_a' Q(s', a')
                targets[i, action_t] = reward_t + self.discount * Q_sa
        return inputs, targets


epsilon = .1  # exploration
num_actions = 3  # [move_left, stay, move_right]
epoch = 1000
max_memory = 500
hidden_size = 100
batch_size = 1
grid_size = 10

#model = Sequential()
#model.add(Dense(hidden_size, input_shape=(grid_size**2,), activation='relu'))
#model.add(Dense(hidden_size, activation='relu'))
#model.add(Dense(num_actions))
#model.compile(sgd(lr=.2), "mse")

# If you want to continue training from a previous model, just uncomment the line bellow
# model.load_weights("model.h5")

# Define environment/game
env = Catch(grid_size)

# Initialize experience replay object
exp_replay = ExperienceReplay(max_memory=max_memory)

# Train
win_cnt = 0
for e in range(epoch):
    loss = 0.
    env.reset()
    game_over = False
    # get initial input
    input_t = env.observe()

    while not game_over:
        input_tm1 = input_t
        # get next action
        if np.random.rand() <= epsilon:
            action = np.random.randint(0, num_actions, size=1)
        else:
            q = policy.act(input_tm1)
            action = np.argmax(q[0])

        # apply action, get rewards and new state
        input_t, reward, game_over = env.act(action)
        if reward == 1:
            win_cnt += 1

        # store experience
        exp_replay.remember([input_tm1, action, reward, input_t], game_over)

    # adapt model
    inputs, targets = exp_replay.get_batch(policy, batch_size=batch_size)

    print(policy.act([catcher().reset()])[0])
    #loss += model.train_on_batch(inputs, targets)[0]
    U.eval(updateModel, {policy.o:inputs,nextQ:targets})
    print(policy.act([catcher().reset()])[0])
    print('\n')

    print("Epoch {:03d}/999 | Win count {}".format(e, win_cnt))



[ 0.07838496 -0.08790758 -0.00774639]
[-0.00416617 -0.08159027  0.07238878]


Epoch 000/999 | Win count 0
[ 0.04434976 -0.06134495  0.00959942]
[-0.00434834 -0.05501063  0.07305783]


Epoch 001/999 | Win count 0
[ 0.06748191 -0.088275    0.0208478 ]
[-0.01653543 -0.07426012  0.04711032]


Epoch 002/999 | Win count 0
[ 0.10566166 -0.0513109   0.00590358]
[-0.01355105 -0.05813994  0.03740803]


Epoch 003/999 | Win count 0
[ 0.06690668 -0.10682404 -0.02405518]
[ 0.1034755  -0.05523152 -0.0032712 ]


Epoch 004/999 | Win count 0
[ 0.07494357 -0.06401411 -0.01290841]
[ 0.04430131 -0.11005063  0.00364378]


Epoch 005/999 | Win count 0
[ 0.08484156 -0.04840381  0.00943503]
[ 0.0662381  -0.08355934  0.00513072]


Epoch 006/999 | Win count 0
[ 0.0252526  -0.07277299  0.05147734]
[  1.12523193e+02   5.95594868e-02  -1.99146223e+00]


Epoch 007/999 | Win count 0
[  1.12646606e+02   6.65038377e-02  -1.98713255e+00]
[  2.50107495e+03  -2.52926916e-01  -7.34086752e-01]


Epoch 008/999 | Win count 0
[

[  9.86529881e+18   5.97593366e+19   6.86859966e+18]
[  9.86529881e+18   1.95212902e+19   6.86859966e+18]


Epoch 073/999 | Win count 0
[  9.86529881e+18   1.95212902e+19   6.86859966e+18]
[  6.17379605e+19   1.95212902e+19   6.86859966e+18]


Epoch 074/999 | Win count 0
[  6.17379605e+19   1.95212902e+19   6.86859966e+18]
[  2.01677173e+19   1.95212902e+19   6.86859966e+18]


Epoch 075/999 | Win count 0
[  2.01677173e+19   1.95212902e+19   6.86859966e+18]
[  2.01677173e+19   1.95212902e+19   8.28363792e+19]


Epoch 076/999 | Win count 0
[  2.01677173e+19   1.95212902e+19   8.28363792e+19]
[  2.01677173e+19   1.95212902e+19  -4.74928189e+20]


Epoch 077/999 | Win count 0
[  2.01677173e+19   1.95212902e+19  -4.74928189e+20]
[  6.58813843e+18   1.95212902e+19  -4.74928189e+20]


Epoch 078/999 | Win count 0
[  6.58813843e+18   1.95212902e+19  -4.74928189e+20]
[  6.58813843e+18   6.37696953e+18  -4.74928189e+20]


Epoch 079/999 | Win count 0
[  6.58813843e+18   6.37696953e+18  -4.74928189e

[ nan  nan  nan]


Epoch 202/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 203/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 204/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 205/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 206/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 207/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 208/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 209/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 210/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 211/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 212/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 213/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 214/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 215/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 216/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 217/999 | Win c

[ nan  nan  nan]


Epoch 335/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 336/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 337/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 338/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 339/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 340/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 341/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 342/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 343/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 344/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 345/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 346/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 347/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 348/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 349/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 350/999 | Win c

[ nan  nan  nan]


Epoch 465/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 466/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 467/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 468/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 469/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 470/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 471/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 472/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 473/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 474/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 475/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 476/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 477/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 478/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 479/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 480/999 | Win c

[ nan  nan  nan]
[ nan  nan  nan]


Epoch 600/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 601/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 602/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 603/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 604/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 605/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 606/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 607/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 608/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 609/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 610/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 611/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 612/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 613/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 614/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoc

[ nan  nan  nan]
[ nan  nan  nan]


Epoch 731/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 732/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 733/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 734/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 735/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 736/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 737/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 738/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 739/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 740/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 741/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 742/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 743/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 744/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 745/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoc

[ nan  nan  nan]
[ nan  nan  nan]


Epoch 866/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 867/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 868/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 869/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 870/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 871/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 872/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 873/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 874/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 875/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 876/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 877/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 878/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 879/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoch 880/999 | Win count 0
[ nan  nan  nan]
[ nan  nan  nan]


Epoc

In [3]:
policy.a.shape

TensorShape([Dimension(None), Dimension(3)])

In [4]:
env = catcher()
policy.act([env.reset()])[0]

array([ nan,  nan,  nan], dtype=float32)

In [5]:
targets

array([[ nan,  nan,  nan]])

NameError: name 'inputs' is not defined

In [None]:
_,W1 = sess.run([updateModel,W],feed_dict={inputs1:inputs,nextQ:targets})