In [2]:
# %load ../../scripts/frozen_lake/actor_critic_ppo.py
import logging

import numpy as np
from gym.envs.toy_text.frozen_lake import FrozenLakeEnv, LEFT, RIGHT, UP, DOWN

from keras_gym.utils import TrainMonitor
from keras_gym.preprocessing import DefaultPreprocessor
from keras_gym.policies import LinearSoftmaxPolicy, ActorCritic
from keras_gym.value_functions import LinearV


# only log records at ERROR level or above are emitted
logging.basicConfig(level=logging.ERROR)


# one-letter labels for the four moves, keyed by gym's action constants
actions = dict(zip((LEFT, RIGHT, UP, DOWN), 'LRUD'))

# non-slippery FrozenLake, wrapped from the inside out:
# DefaultPreprocessor first, then TrainMonitor (which exposes the `env.T`
# counter read by the training loop)
env = TrainMonitor(DefaultPreprocessor(FrozenLakeEnv(is_slippery=False)))


# updatable policy, trained with the 'ppo' update strategy
policy = LinearSoftmaxPolicy(
    env,
    update_strategy='ppo',
    lr=0.1,
)

# linear state-value function with bootstrap_n=1
V = LinearV(
    env,
    bootstrap_n=1,
    gamma=0.9,
    lr=0.1,
)

# couples the policy and the value function so that a single
# `actor_critic.update(...)` call is all the training loop needs
actor_critic = ActorCritic(policy, V)


# static parameters

# number of training episodes
num_episodes = 500

# cap on the number of steps per episode (training and final rollout)
num_steps = 30

# the policy's target model is synced whenever env.T is a multiple of this
target_model_sync_period = 20


# train: every episode starts from a fresh reset and is capped at num_steps
for ep in range(num_episodes):
    s = env.reset()

    for t in range(num_steps):
        # the action comes from the policy's target model
        a = policy(s, use_target_model=True)
        s_next, r, done, info = env.step(a)

        # small incentive to keep moving: a step that leaves the state
        # unchanged gets its reward replaced by a fixed penalty
        r = -0.1 if np.array_equal(s_next, s) else r

        # update first, then (possibly) sync -- this order is intentional
        actor_critic.update(s, a, r, done)

        # NOTE(review): tau=1.0 presumably means a full copy of the
        # primary weights into the target model -- confirm with keras_gym
        if env.T % target_model_sync_period == 0:
            policy.sync_target_model(tau=1.0)

        if not done:
            s = s_next
            continue

        # episode finished: leave the step loop without advancing `s`
        break


# run env one more time to render, this time following policy.greedy
s = env.reset()
env.render()

for t in range(num_steps):

    # value estimate of the current state
    state_value = V(s)
    print("  V(s) = {:.3f}".format(state_value))

    # one line per action; position i in the probability vector is mapped
    # to its letter through the `actions` lookup
    for i, p in enumerate(policy.proba(s)):
        label = actions[i]
        print("  π({:s}|s) = {:.3f}".format(label, p))

    # take the action returned by policy.greedy and draw the new state
    a = policy.greedy(s)
    s, r, done, info = env.step(a)
    env.render()

    # stop as soon as the environment reports the episode is over
    if done:
        break



[41mS[0mFFF
FHFH
FFFH
HFFG
  V(s) = -0.022
  π(L|s) = 0.114
  π(D|s) = 0.383
  π(R|s) = 0.401
  π(U|s) = 0.103
  (Right)
S[41mF[0mFF
FHFH
FFFH
HFFG
  V(s) = 0.006
  π(L|s) = 0.128
  π(D|s) = 0.365
  π(R|s) = 0.392
  π(U|s) = 0.115
  (Right)
SF[41mF[0mF
FHFH
FFFH
HFFG
  V(s) = 0.068
  π(L|s) = 0.145
  π(D|s) = 0.435
  π(R|s) = 0.294
  π(U|s) = 0.127
  (Down)
SFFF
FH[41mF[0mH
FFFH
HFFG
  V(s) = 0.117
  π(L|s) = 0.145
  π(D|s) = 0.439
  π(R|s) = 0.286
  π(U|s) = 0.130
  (Down)
SFFF
FHFH
FF[41mF[0mH
HFFG
  V(s) = 0.336
  π(L|s) = 0.144
  π(D|s) = 0.447
  π(R|s) = 0.278
  π(U|s) = 0.130
  (Down)
SFFF
FHFH
FFFH
HF[41mF[0mG
  V(s) = 0.600
  π(L|s) = 0.141
  π(D|s) = 0.274
  π(R|s) = 0.458
  π(U|s) = 0.127
  (Right)
SFFF
FHFH
FFFH
HFF[41mG[0m
