Q-learning with a small neural network on gym's FrozenLake. Code adapted from https://www.baeldung.com/cs/reinforcement-learning-neural-network

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import gym
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
from tqdm import tqdm

# Seed every RNG the notebook relies on so that Restart & Run All is repeatable:
# NumPy drives the epsilon-greedy exploration, and TensorFlow's global seed
# governs the network's random weight initialization (at least with tf.keras 2.x).
# Note: some GPU kernels can still be nondeterministic.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# The environment

In [None]:
"""
SFFF       (S: starting point, safe)
FHFH       (F: frozen surface, safe)
FFFH       (H: hole, fall to your doom)
HFFG       (G: goal, where the frisbee is located)

state = row * ncol + col

LEFT = 0
DOWN = 1
RIGHT = 2
UP = 3

The episode ends when you reach the goal or fall in a hole.
You receive a reward of 1 if you reach the goal, and zero otherwise.

https://github.com/openai/gym/blob/master/gym/envs/toy_text/frozen_lake.py
"""

# Deterministic lake: with is_slippery=False the ice never redirects the agent,
# so each action moves in the chosen direction.
env = gym.make(
    "FrozenLake-v1",
    is_slippery=False,
)

In [None]:
def encode_state(state, n_states=None):
    """One-hot encode a discrete state index as a one-row batch for the model.

    Args:
        state: integer state index in ``[0, n_states)``.
        n_states: length of the encoding; defaults to ``env.observation_space.n``.

    Returns:
        Float array of shape ``(1, n_states)`` holding 1.0 at column ``state``
        and 0.0 elsewhere.

    Raises:
        ValueError: if ``state`` is outside ``[0, n_states)``. Slicing an identity
            matrix would silently yield an empty array instead, which only fails
            later and far from the cause.
    """
    if n_states is None:
        n_states = env.observation_space.n
    if not 0 <= state < n_states:
        raise ValueError(f"state {state} is outside [0, {n_states})")
    # Build the single row directly instead of materializing an n x n identity.
    encoded = np.zeros((1, n_states))
    encoded[0, state] = 1.0
    return encoded


# Collect every state that ends an episode (holes and the goal) by scanning the
# environment's transition table: P[state][action] is a list of
# (prob, next_state, reward, done) tuples. All outcomes are inspected (not just
# the first), so this also works for slippery lakes. `env.unwrapped` avoids
# relying on wrappers forwarding the `P` attribute. Loop variables live only
# inside the comprehension, so no globals such as `state` are clobbered.
terminal_states = sorted({
    next_state
    for actions in env.unwrapped.P.values()
    for transitions in actions.values()
    for _prob, next_state, _reward, done in transitions
    if done
})

# Agent setup

In [None]:
# Q-learning hyperparameters.
num_episodes = 500  # number of training episodes

# Weight of the bootstrapped next-state value in the Q-learning target.
discount_factor = 0.95

# Epsilon-greedy exploration: starting probability of taking a random action;
# it is multiplied by eps_decay_factor at the start of every training episode.
eps = 0.4
eps_decay_factor = 0.995

In [None]:
# Q-network: maps a one-hot encoded state to one linear Q-value per action.
n_states = env.observation_space.n
n_actions = env.action_space.n

adam = Adam(learning_rate=0.01)
loss_fn = MeanSquaredError()

state_input = Input(shape=(n_states,))
hidden = Dense(20, activation='relu')(state_input)
q_values_out = Dense(n_actions, activation='linear')(hidden)
model = Model(inputs=state_input, outputs=q_values_out)
model.summary()

In [None]:
@tf.function
def train_step(inputs, targets):
    """Run one optimizer step moving the Q-network's outputs toward `targets`.

    Args:
        inputs: batch of encoded states fed to `model`.
        targets: desired Q-values; expected to match the shape of
            `model(inputs)` (one row per input, one column per action).

    Returns:
        The loss value computed before the update. Callers are free to ignore it;
        it is returned so training progress can be monitored.
    """
    with tf.GradientTape() as tape:
        # Forward pass; the tape records the ops so gradients can be taken.
        # The outputs are Q-value estimates (linear head), not logits.
        q_values = model(inputs, training=True)
        loss_value = loss_fn(targets, q_values)

    # Gradients of the loss with respect to every trainable weight.
    grads = tape.gradient(loss_value, model.trainable_weights)
    # One Adam step that updates the weights to reduce the loss.
    adam.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value

# Train the agent

In [None]:
# Epsilon-greedy Q-learning. Decay a local copy of `eps` so that re-running this
# cell starts again from the configured exploration rate. Uses the gym < 0.26 API
# (reset() -> state, step() -> 4-tuple).
epsilon = eps
for _episode in tqdm(range(num_episodes)):
    state = env.reset()
    epsilon *= eps_decay_factor
    done = False
    while not done:
        # Current Q-values are used both to pick the greedy action and as the base
        # of the training target; the network is not updated in between, so one
        # forward pass suffices. Calling the model directly avoids predict()'s
        # per-call overhead and progress-bar output.
        q_current = model(encode_state(state), training=False).numpy()[0]
        if np.random.random() < epsilon:
            action = np.random.randint(0, env.action_space.n)
        else:
            action = np.argmax(q_current)
        new_state, reward, done, _ = env.step(action)

        # `done` is also set when the episode is cut off by a time limit; only a
        # genuine terminal state (hole or goal) should stop bootstrapping.
        reached_terminal_state = new_state in terminal_states
        if reached_terminal_state and reward == 0:
            # Terminal with no reward means a hole: give a bit of negative reward to dying.
            reward = -0.05

        if reached_terminal_state:
            target = reward  # no future value after a terminal state
        else:
            q_next = model(encode_state(new_state), training=False).numpy()[0]
            target = reward + discount_factor * np.max(q_next)

        # Only the taken action's Q-value is moved; the others keep their current
        # estimates so they contribute zero error to the loss.
        target_vector = q_current.copy()
        target_vector[action] = target
        train_step(encode_state(state), target_vector.reshape(1, -1))
        state = new_state

# Visualize the learned Q-values

In [None]:
# One forward pass over every one-hot state (the rows of the identity matrix)
# yields the full table of learned Q-values: one row per state, with columns in
# action-id order (LEFT=0, DOWN=1, RIGHT=2, UP=3).
q_table = (
    pd.DataFrame(
        model.predict(np.identity(env.observation_space.n)),
        columns=['left', 'down', 'right', 'up'],
    )
    .rename_axis('state')
)

In [None]:
# Heatmap of the learned Q-values (expected discounted return) for non-terminal
# states. Terminal states are excluded because the training loop stops as soon as
# one is reached, so they are never used as training inputs.
non_terminal_q = q_table.loc[~q_table.index.isin(terminal_states)]
fig, ax = plt.subplots(figsize=(3, 6))
sns.heatmap(non_terminal_q, annot=True, cmap='coolwarm', ax=ax)
ax.set_title('Q-values, non-terminal states');

In [None]:
# Heatmap of the learned Q-values for all states. Rows for terminal states are
# never used as training inputs, so their values are not meaningful estimates.
fig, ax = plt.subplots(figsize=(3, 6))
sns.heatmap(q_table, annot=True, cmap='coolwarm', ax=ax)
ax.set_title('Q-values, all states');