In [1]:
from gym.envs.toy_text.frozen_lake import FrozenLakeEnv, UP, DOWN, LEFT, RIGHT

from tensorflow import keras
from tensorflow.keras import backend as K

from keras_gym.value_functions import LinearQ
from keras_gym.algorithms import MonteCarloQ, Reinforce
from keras_gym.policies import ValuePolicy, SoftmaxPolicy
from keras_gym.metrics import SoftmaxPolicyLossWithLogits
from keras_gym.utils import feature_vector


env = FrozenLakeEnv(is_slippery=False)


# Behavior side: a linear Q-function, a value-based policy that acts on it,
# and a Monte Carlo algorithm that updates it.
Q = LinearQ(env, lr=0.1)
behavior_policy = ValuePolicy(Q)
behavior_algo = MonteCarloQ(Q)


# Sizes needed to build the policy network below. The input width is read off
# the feature vector that keras_gym produces for one sampled observation.
observation_space = env.observation_space
num_features = feature_vector(observation_space.sample(), observation_space).size
num_actions = env.action_space.n


# function approximator for our policy
def create_model():
    """Build and compile the linear softmax-policy network.

    The model has two inputs:
      * the state feature vector, of length ``num_features``;
      * a per-sample advantage of shape ``[1]``, which never enters the
        forward pass and is only handed to the loss.

    The output is one logit per action (``num_actions`` of them), produced by
    a single Dense layer whose kernel starts at zero.

    Returns:
        A ``keras.Model`` compiled with ``SoftmaxPolicyLossWithLogits`` and
        SGD at learning rate 0.1.
    """
    state_features = keras.Input(shape=[num_features])
    advantages = keras.Input(shape=[1])

    # Zero kernel (and Keras' default zero bias) means all actions start with
    # equal logits.
    logits_layer = keras.layers.Dense(num_actions, kernel_initializer='zeros')
    logits = logits_layer(state_features)

    # The loss closes over the advantages input rather than reading it from
    # y_true.
    policy_loss = SoftmaxPolicyLossWithLogits(advantages)

    model = keras.Model(inputs=[state_features, advantages], outputs=logits)
    model.compile(
        loss=policy_loss,
        # NOTE(review): `lr` is the legacy keyword. Newer Keras releases accept
        # only `learning_rate`. It is kept here for the TF version this
        # notebook targets; confirm before upgrading TF.
        optimizer=keras.optimizers.SGD(lr=0.1))
    return model


# The policy under development: a softmax policy over the logits of a freshly
# built model, trained with REINFORCE.
model = create_model()
policy = SoftmaxPolicy(env, model)
algo = Reinforce(policy)



def display_proba(behavior_policy, policy, s):
    """Print the action probabilities of two policies at state ``s``.

    One table is printed for ``behavior_policy`` (labelled ``b``) and one for
    ``policy`` (labelled ``pi``). Each row shows a probability and its action
    name. Every action whose probability equals that table's maximum is
    marked with ``*``. A blank line follows the second table.

    Args:
        behavior_policy: object whose ``proba(s).p`` is iterable over actions.
        policy: object whose ``proba(s).p`` is iterable over actions.
        s: the state to inspect; also echoed in each table header.
    """
    # Assumes the position in ``proba`` is the gym action id — TODO confirm
    # against keras_gym's ``proba``.
    action_names = {UP: 'up', DOWN: 'down', LEFT: 'left', RIGHT: 'right'}

    def print_table(label, proba):
        # Print one policy's probability table under the header ``label``.
        # Built-in max() rather than np.max: numpy is never imported in this
        # file, so np.max raised NameError.
        pmax = max(proba)
        print('\n{}(a|s={}):'.format(label, s))
        print('\n'.join(
            "{2} {1:.3f} - {0}".format(action_names[a], p, '*' if p == pmax else ' ')
            for a, p in enumerate(proba)))

    print_table('b', behavior_policy.proba(s).p)
    print_table('pi', policy.proba(s).p)
    print()


def run_episode(use_new_policy=True, epsilon=0.1, update=False, render=False):
    """Play one episode in ``env``, optionally learning from it.

    Args:
        use_new_policy: if True, act with the softmax ``policy``; otherwise
            act with ``behavior_policy``.
        epsilon: exploration rate forwarded to
            ``behavior_policy.epsilon_greedy``. Only used when
            ``use_new_policy`` is False and ``update`` is True.
        update: if True, choose actions with ``policy.thompson`` or
            ``behavior_policy.epsilon_greedy``, and pass every transition to
            both ``behavior_algo.update`` and ``algo.update``. If False, choose
            actions greedily and perform no updates.
        render: if True, render the environment and print both policies'
            probabilities before each step, then render once more after the
            episode ends.
    """
    def select_action(state):
        # Pick the action-selection method from the two flags.
        if use_new_policy:
            return policy.thompson(state) if update else policy.greedy(state)
        if update:
            return behavior_policy.epsilon_greedy(state, epsilon)
        return behavior_policy.greedy(state)

    state = env.reset()
    done = False
    while not done:
        if render:
            env.render()
            display_proba(behavior_policy, policy, state)

        action = select_action(state)
        state_next, reward, done, _info = env.step(action)

        if update:
            # Both learners see exactly the same transition.
            behavior_algo.update(state, action, reward, state_next, done)
            algo.update(state, action, reward, state_next, done)

        state = state_next

    if render:
        # Show the terminal state as well.
        env.render()


# Training: 200 episodes acting with policy.thompson, updating both learners.
for _episode in range(200):
    run_episode(use_new_policy=True, update=True)

# Inspection: one rendered episode acting with policy.greedy, no updates.
run_episode(use_new_policy=True, render=True)


[41mS[0mFFF
FHFH
FFFH
HFFG

b(a|s=0):
  0.251 - left
  0.247 - down
  0.250 - right
* 0.252 - up

pi(a|s=0):
  0.246 - left
  0.254 - down
* 0.255 - right
  0.245 - up

  (Right)
S[41mF[0mFF
FHFH
FFFH
HFFG

b(a|s=1):
  0.252 - left
  0.247 - down
* 0.252 - right
  0.249 - up

pi(a|s=1):
  0.248 - left
  0.252 - down
* 0.256 - right
  0.245 - up

  (Right)
SF[41mF[0mF
FHFH
FFFH
HFFG

b(a|s=2):
  0.248 - left
  0.251 - down
  0.247 - right
* 0.254 - up

pi(a|s=2):
  0.245 - left
* 0.256 - down
  0.251 - right
  0.248 - up

  (Down)
SFFF
FH[41mF[0mH
FFFH
HFFG

b(a|s=6):
  0.248 - left
* 0.255 - down
  0.248 - right
  0.249 - up

pi(a|s=6):
  0.245 - left
* 0.257 - down
  0.252 - right
  0.245 - up

  (Down)
SFFF
FHFH
FF[41mF[0mH
HFFG

b(a|s=10):
  0.248 - left
* 0.255 - down
  0.248 - right
  0.248 - up

pi(a|s=10):
  0.245 - left
* 0.257 - down
  0.252 - right
  0.245 - up

  (Down)
SFFF
FHFH
FFFH
HF[41mF[0mG

b(a|s=14):
  0.251 - left
  0.247 - down
* 0.256 - right
  0.247 