In [12]:
import logging

import numpy as np
import keras_gym as km
from tensorflow import keras
from gym.envs.toy_text.frozen_lake import FrozenLakeEnv, LEFT, RIGHT, UP, DOWN


# surface the per-episode INFO records emitted under the TrainMonitor logger
logging.basicConfig(level=logging.INFO)


# single-letter labels for the four moves, used when printing the policy
actions = {
    LEFT: 'L',
    RIGHT: 'R',
    UP: 'U',
    DOWN: 'D',
}

# non-slippery FrozenLake, wrapped in a TrainMonitor
env = km.utils.TrainMonitor(FrozenLakeEnv(is_slippery=False))


class LinearFunctionApproximator(km.FunctionApproximator):
    """Function approximator whose body only one-hot encodes the state.

    The body adds no trainable layers of its own; it just turns the integer
    state index into a length-16 one-hot vector.
    """

    def body(self, S, variable_scope):
        """Return the one-hot encoding of the state tensor ``S``.

        Parameters
        ----------
        S : tensor
            State input; presumably a batch of integer state indices
            (TODO confirm against the keras-gym version in use).
        variable_scope : str
            Not used by this body; kept so the signature stays unchanged.

        Returns
        -------
        tensor
            ``S`` one-hot encoded with depth 16.
        """
        # 16 = number of cells in the 4x4 FrozenLake grid rendered below.
        # The backend is reached through the imported `keras` module: the bare
        # alias `K` is never imported in this notebook, so relying on it only
        # works when it happens to be left over in the kernel namespace.
        one_hot = keras.layers.Lambda(
            lambda x: keras.backend.one_hot(x, 16))
        return one_hot(S)


# one function approximator, shared by the state-value function and the policy
mlp = LinearFunctionApproximator(env, lr=0.01)
v = km.V(
    mlp,
    gamma=0.9,
    bootstrap_n=1,
)
pi = km.SoftmaxPolicy(
    mlp,
    update_strategy='ppo',
)

# policy (actor) and value function (critic) are updated together
actor_critic = km.ActorCritic(pi, v)


# static parameters
target_model_sync_period = 20  # sync the target model every this many env.T steps
num_episodes = 250             # number of training episodes
num_steps = 30                 # upper bound on steps per episode


# train: run `num_episodes` episodes of at most `num_steps` steps each
for ep in range(num_episodes):
    s = env.reset()

    for t in range(num_steps):
        # sample the action from the target-model version of the policy
        a = pi(s, use_target_model=True)
        s_next, r, done, info = env.step(a)

        # a move that leaves the state unchanged gets a small penalty
        r = -0.1 if np.array_equal(s_next, s) else r

        actor_critic.update(s, a, r, done)

        # sync the target model whenever env.T is a multiple of the period
        if not env.T % target_model_sync_period:
            pi.sync_target_model(tau=1.0)

        if done:
            break

        s = s_next


# replay one more episode with greedy actions, rendering every step
s = env.reset()
env.render()

for t in range(num_steps):

    # state value first, then the probability of each action
    print("  V(s) = {:.3f}".format(v(s)))
    for i, p in enumerate(pi.proba(s)):
        print("  π({:s}|s) = {:.3f}".format(actions[i], p))

    # take the greedy action and show the resulting grid
    a = pi.greedy(s)
    s, r, done, info = env.step(a)
    env.render()

    if done:
        break


INFO:TrainMonitor:ep: 1, T: 7, G: 0, avg(r): 0.000, t: 6, dt: 985.397ms
INFO:TrainMonitor:ep: 2, T: 14, G: 0, avg(r): 0.000, t: 6, dt: 7.441ms
INFO:TrainMonitor:ep: 3, T: 19, G: 0, avg(r): 0.000, t: 4, dt: 8.277ms
INFO:TrainMonitor:ep: 4, T: 29, G: 0, avg(r): 0.000, t: 9, dt: 7.619ms
INFO:TrainMonitor:ep: 5, T: 44, G: 0, avg(r): 0.000, t: 14, dt: 88.140ms
INFO:TrainMonitor:ep: 6, T: 58, G: 0, avg(r): 0.000, t: 13, dt: 8.627ms
INFO:TrainMonitor:ep: 7, T: 64, G: 0, avg(r): 0.000, t: 5, dt: 5.647ms
INFO:TrainMonitor:ep: 8, T: 71, G: 0, avg(r): 0.000, t: 6, dt: 5.502ms
INFO:TrainMonitor:ep: 9, T: 74, G: 0, avg(r): 0.000, t: 2, dt: 2.344ms
INFO:TrainMonitor:ep: 10, T: 77, G: 0, avg(r): 0.000, t: 2, dt: 2.441ms
INFO:TrainMonitor:ep: 11, T: 80, G: 0, avg(r): 0.000, t: 2, dt: 1.724ms
INFO:TrainMonitor:ep: 12, T: 86, G: 0, avg(r): 0.000, t: 5, dt: 6.369ms
INFO:TrainMonitor:ep: 13, T: 89, G: 0, avg(r): 0.000, t: 2, dt: 1.492ms
INFO:TrainMonitor:ep: 14, T: 97, G: 0, avg(r): 0.000, t: 7, dt: 6.177

INFO:TrainMonitor:ep: 114, T: 567, G: 0, avg(r): 0.000, t: 2, dt: 1.922ms
INFO:TrainMonitor:ep: 115, T: 575, G: 0, avg(r): 0.000, t: 7, dt: 9.578ms
INFO:TrainMonitor:ep: 116, T: 582, G: 1, avg(r): 0.167, t: 6, dt: 6.883ms
INFO:TrainMonitor:ep: 117, T: 586, G: 0, avg(r): 0.000, t: 3, dt: 4.040ms
INFO:TrainMonitor:ep: 118, T: 589, G: 0, avg(r): 0.000, t: 2, dt: 1.447ms
INFO:TrainMonitor:ep: 119, T: 592, G: 0, avg(r): 0.000, t: 2, dt: 1.093ms
INFO:TrainMonitor:ep: 120, T: 595, G: 0, avg(r): 0.000, t: 2, dt: 1.018ms
INFO:TrainMonitor:ep: 121, T: 598, G: 0, avg(r): 0.000, t: 2, dt: 3.878ms
INFO:TrainMonitor:ep: 122, T: 602, G: 0, avg(r): 0.000, t: 3, dt: 5.701ms
INFO:TrainMonitor:ep: 123, T: 609, G: 1, avg(r): 0.167, t: 6, dt: 6.412ms
INFO:TrainMonitor:ep: 124, T: 612, G: 0, avg(r): 0.000, t: 2, dt: 1.159ms
INFO:TrainMonitor:ep: 125, T: 615, G: 0, avg(r): 0.000, t: 2, dt: 1.146ms
INFO:TrainMonitor:ep: 126, T: 619, G: 0, avg(r): 0.000, t: 3, dt: 7.127ms
INFO:TrainMonitor:ep: 127, T: 622, G: 

INFO:TrainMonitor:ep: 225, T: 1,122, G: 1, avg(r): 0.167, t: 6, dt: 11.607ms
INFO:TrainMonitor:ep: 226, T: 1,129, G: 1, avg(r): 0.167, t: 6, dt: 12.305ms
INFO:TrainMonitor:ep: 227, T: 1,136, G: 1, avg(r): 0.167, t: 6, dt: 8.932ms
INFO:TrainMonitor:ep: 228, T: 1,139, G: 0, avg(r): 0.000, t: 2, dt: 1.235ms
INFO:TrainMonitor:ep: 229, T: 1,146, G: 1, avg(r): 0.167, t: 6, dt: 8.150ms
INFO:TrainMonitor:ep: 230, T: 1,153, G: 1, avg(r): 0.167, t: 6, dt: 8.891ms
INFO:TrainMonitor:ep: 231, T: 1,160, G: 1, avg(r): 0.167, t: 6, dt: 10.066ms
INFO:TrainMonitor:ep: 232, T: 1,167, G: 1, avg(r): 0.167, t: 6, dt: 9.875ms
INFO:TrainMonitor:ep: 233, T: 1,174, G: 1, avg(r): 0.167, t: 6, dt: 13.845ms
INFO:TrainMonitor:ep: 234, T: 1,181, G: 1, avg(r): 0.167, t: 6, dt: 7.969ms
INFO:TrainMonitor:ep: 235, T: 1,188, G: 1, avg(r): 0.167, t: 6, dt: 7.692ms
INFO:TrainMonitor:ep: 236, T: 1,195, G: 1, avg(r): 0.167, t: 6, dt: 9.675ms
INFO:TrainMonitor:ep: 237, T: 1,202, G: 1, avg(r): 0.167, t: 6, dt: 6.136ms
INFO:Tra


[41mS[0mFFF
FHFH
FFFH
HFFG
  V(s) = 0.400


INFO:TrainMonitor:ep: 251, T: 1,294, G: 1, avg(r): 0.167, t: 6, dt: 99.422ms


  π(L|s) = 0.003
  π(D|s) = 0.057
  π(R|s) = 0.935
  π(U|s) = 0.005
  (Right)
S[41mF[0mFF
FHFH
FFFH
HFFG
  V(s) = 0.477
  π(L|s) = 0.008
  π(D|s) = 0.038
  π(R|s) = 0.947
  π(U|s) = 0.006
  (Right)
SF[41mF[0mF
FHFH
FFFH
HFFG
  V(s) = 0.607
  π(L|s) = 0.012
  π(D|s) = 0.949
  π(R|s) = 0.027
  π(U|s) = 0.011
  (Down)
SFFF
FH[41mF[0mH
FFFH
HFFG
  V(s) = 0.723
  π(L|s) = 0.008
  π(D|s) = 0.915
  π(R|s) = 0.065
  π(U|s) = 0.012
  (Down)
SFFF
FHFH
FF[41mF[0mH
HFFG
  V(s) = 0.753
  π(L|s) = 0.011
  π(D|s) = 0.937
  π(R|s) = 0.040
  π(U|s) = 0.013
  (Down)
SFFF
FHFH
FFFH
HF[41mF[0mG
  V(s) = 0.920
  π(L|s) = 0.011
  π(D|s) = 0.139
  π(R|s) = 0.824
  π(U|s) = 0.027
  (Right)
SFFF
FHFH
FFFH
HFF[41mG[0m


In [6]:
# Show the loss object attached to the policy's training model.
# The policy is named `pi` in this notebook; `policy` is never defined here,
# so the original expression only worked with leftover kernel state.
pi.train_model.loss

<keras_gym.losses.policy_based.ClippedSurrogateLoss at 0x7f3d9c9ffb38>