In [None]:
# Let IPython pick the matplotlib backend for this session.
%matplotlib auto

In [None]:
! pip install numpy
! pip install tensorflow
! pip install tqdm
! pip install matplotlib

# Policy gradients play Catch
This notebook uses the same Catch environment from the [Q learning plays Catch example](../Catch/Q%20learning%20plays%20Catch.ipynb), but now uses policy gradients (specifically the REINFORCE algorithm) to come up with a winning policy.

Below are some nice references that offer some intuition on policy gradients:
- https://towardsdatascience.com/policy-gradients-in-a-nutshell-8b72f9743c5d
- https://medium.com/@jonathan_hui/rl-policy-gradients-explained-9b13b688b146

In [1]:
from collections import deque
from enum import Enum
from random import sample

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, HTML
from matplotlib import animation
from tqdm import tqdm
from tensorflow import keras

# Environment
We use more or less the same `Catch` environment as we did for the DQN implementation. However, this time there will be a catch (pun intended): we'll sometimes randomly move the piece of fruit left or right to make the game a bit more exciting to play.

In [33]:
class Action(Enum):
    """Moves available to the basket; the values double as policy-output indices."""

    LEFT = 0   # shift the basket one column to the left
    STAY = 1   # keep the basket where it is
    RIGHT = 2  # shift the basket one column to the right

class Catch:
    """Grid-world "Catch" game: a fruit falls from the top row towards a basket.

    Positions are ``(row, column)`` tuples. The basket lives on the bottom row
    and occupies the half-open column range
    ``[basket_left, basket_left + basket_size)``, i.e. exactly ``basket_size``
    cells. Rendering, movement limits and catch detection all use this same
    convention, so what the agent observes is exactly what counts as a catch.

    Parameters
    ----------
    grid_size : int
        Number of rows and columns of the square grid.
    basket_size : int
        Width of the basket in cells.
    """

    def __init__(self, grid_size=10, basket_size=3):
        self.grid_size = grid_size
        self.basket_size = basket_size
        self.reset()

    def reset(self):
        """Start a new game with random fruit and basket columns."""
        # Fruit starts at a random column at the top row
        fruit = (0, np.random.randint(self.grid_size))

        # The basket starts at a random column at the bottom row. The `+ 1`
        # makes the right-most admissible position (basket touching the last
        # column) reachable, since `randint`'s upper bound is exclusive.
        basket_left = (
            self.grid_size - 1,
            np.random.randint(self.grid_size - self.basket_size + 1),
        )

        self.terminated = False
        self._state = [fruit, basket_left]

    @property
    def state(self):
        """Return ``(fruit, basket_left, basket_right)``.

        ``basket_right`` is the exclusive right edge of the basket, i.e. the
        column just past its last cell.
        """
        fruit, basket_left = self._state
        basket_right = (basket_left[0], basket_left[1] + self.basket_size)
        return fruit, basket_left, basket_right

    def observe(self):
        """Return the board as a ``(grid_size, grid_size)`` array of zeros and ones."""
        grid = np.zeros((self.grid_size, self.grid_size))
        fruit, basket_left, basket_right = self.state

        grid[fruit] = 1
        grid[basket_left[0], basket_left[1]:basket_right[1]] = 1

        return grid

    def _is_permitted_basket_position(self, basket_left):
        """Return True if a basket whose left edge is ``basket_left`` fits on the grid."""
        return 0 <= basket_left[1] <= self.grid_size - self.basket_size

    def _get_desired_basket_position(self, action):
        """Return the basket's left position after ``action``, ignoring the grid limits.

        ``action`` may be an ``Action`` member or its integer value; anything
        else raises ``ValueError`` from the enum lookup.
        """
        action = Action(action)
        row, column = self._state[1]
        shift = {Action.LEFT: -1, Action.STAY: 0, Action.RIGHT: 1}[action]
        return (row, column + shift)

    def _fruit_in_basket(self):
        """Return True if the fruit is on the basket's row and within its cells."""
        fruit, basket_left, basket_right = self.state

        return ((fruit[0] == basket_left[0]) and
                (basket_left[1] <= fruit[1] < basket_right[1]))

    def _fruit_on_last_row(self):
        """Return True once the fruit has reached the bottom row."""
        fruit, _, _ = self.state
        return fruit[0] == self.grid_size - 1

    def act(self, action):
        """Advance the game by one step.

        The basket moves first (if the move keeps it on the grid), then the
        fruit drops one row and, with probability 0.1 each, also drifts one
        column left or right (wrapping around the grid edges).

        Returns
        -------
        (reward, terminated) : (float, bool)
            ``reward`` is ``1.`` if the fruit was caught, ``-1.`` if it was
            missed and ``0.`` while the fruit is still falling.
        """
        assert not self.terminated, "Cannot play a terminated game. Call .reset() to reset."

        # Change the basket position if it is an allowed position
        desired_basket_position = self._get_desired_basket_position(action)
        if self._is_permitted_basket_position(desired_basket_position):
            self._state[1] = desired_basket_position

        # Move the fruit and check if it can still be caught
        fruit_offset = np.random.choice([-1, 0, 1], p=[0.1, 0.8, 0.1])
        fruit_row, fruit_column = self._state[0]
        self._state[0] = (
            fruit_row + 1,
            (fruit_column + fruit_offset) % self.grid_size,
        )

        if not self._fruit_on_last_row():
            # Continue playing
            return (0., self.terminated)

        # The fruit reached the bottom row: the game is over, won or lost.
        self.terminated = True
        reward = 1. if self._fruit_in_basket() else -1.
        return (reward, self.terminated)

# REINFORCE
The basic form of the REINFORCE algorithm is, implementation-wise, far less involved than, for example, a DQN. For instance, we don't need a replay buffer to be able to process updates. All one needs to implement the basic form of REINFORCE is the update equation for the parameters of the policy network:

$$
    \theta_{k+1} \leftarrow \theta_k + \alpha \sum_t \nabla \log \pi_{\theta_k}(a_t|s_t) G_t.
$$
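where $\pi_{\theta_k}$ is the current policy, $G_t$ is the discounted future reward from time step $t$ onwards, and $\alpha$ is the learning rate. (The implementation below uses a discount factor of 1, i.e. no discounting.)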

In [34]:
# Number of episodes the training loop iterates over.
MAX_EPISODES = 5_000
# Number of filters passed to each Conv2D layer of the policy network.
N_FILTERS = 30

In [37]:
# Policy network: the Lambda layer appends a trailing channel axis to the
# observed grid, two 3x3 convolutions extract features, and the final softmax
# layer outputs one probability per `Action`.
model = keras.models.Sequential([
    keras.layers.Lambda(lambda x: x[..., None]),
    keras.layers.Conv2D(N_FILTERS, (3, 3), activation='relu'),
    keras.layers.Conv2D(N_FILTERS, (3, 3), activation='relu'),
    keras.layers.Flatten(),
    keras.layers.Dense(len(Action), activation='softmax'),
])

In [42]:
env = Catch()
episodes_iterator = tqdm(range(MAX_EPISODES))
optimizer = tf.keras.optimizers.Adam(0.0001)

total_rewards = []
for episode in episodes_iterator:
    env.reset()
    rewards = []
    log_probs = []

    # Only the forward computation (roll-out and loss) has to be recorded on
    # the tape; the gradient computation itself happens after the context exits.
    with tf.GradientTape() as tape:
        while not env.terminated:
            # Use the network to obtain a policy for the current state and sample an action
            # from this policy. The log-probabilities double as logits for the sampler.
            log_policy_probs = tf.math.log(
                model(env.observe()[None, ...])
            )
            # Convert to a plain int so it can be used both as an enum value and as an index.
            action_index = int(
                tf.random.categorical(logits=log_policy_probs, num_samples=1)[0, 0]
            )
            sampled_action = Action(action_index)

            # Act using the sampled policy
            reward, _ = env.act(sampled_action)

            # Store the reward and the log probability tensor related to this action
            rewards.append(reward)
            log_probs.append(log_policy_probs[0, action_index])

        # Return-to-go G_t = r_t + r_{t+1} + ... + r_T (no discounting). A reversed
        # cumulative sum keeps G_t aligned with the log-probability of step t.
        returns_to_go = tf.cumsum(tf.stack(rewards), reverse=True)

        # Use the update formula described above this cell to derive a loss that,
        # when differentiated, will lead to the right update for the parameters of the
        # policy network
        loss = -1. * tf.reduce_sum(returns_to_go * tf.stack(log_probs))

    # Apply the gradients
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    total_rewards.append(sum(rewards))
    episodes_iterator.set_description(
        "Average reward (last 50 games / total): %.2f/%.2f" % (np.mean(total_rewards[-50:]), np.mean(total_rewards))
    )

Average reward (last 50 games / total): 0.60/0.72: 100%|██████████| 5000/5000 [02:19<00:00, 35.86it/s]


# Results
We plot results in the same way as we do in the [Q learning plays Catch example](../Catch/Q%20learning%20plays%20Catch.ipynb) notebook.

In [45]:
def plot_image_sequence(images):
    """Display a sequence of 2-D frames as an HTML5 video in the notebook.

    Parameters
    ----------
    images : numpy.ndarray
        Frames indexed along the first axis; the last two axes are taken as
        (rows, columns) of each frame.
    """
    n_rows, n_cols = images.shape[-2:]
    # `figsize` is (width, height), hence columns first.
    fig, ax = plt.subplots(figsize=(n_cols, n_rows))
    im = ax.imshow(images[0])

    def animate(i):
        im.set_array(images[i])
        return (im,)

    anim = animation.FuncAnimation(fig, animate, frames=len(images), interval=500, repeat_delay=1, repeat=True)
    try:
        display(HTML(anim.to_html5_video()))
    finally:
        # The video has been rendered (or rendering failed); release the figure
        # so it neither leaks nor shows up as an extra static plot.
        plt.close(fig)
    

# Number of games played with the trained policy for the animation.
N_GAMES = 10

states = []
for _ in range(N_GAMES):
    env.reset()
    states.append(env.observe()[None, :])
    while not env.terminated:
        # Call the model directly for this single observation, as the training loop does.
        probs = model(env.observe()[None, :]).numpy()[0]

        # The softmax output is float32 and may not sum to 1 within the tolerance
        # `np.random.choice` enforces; renormalise in float64 to avoid a ValueError.
        probs = np.asarray(probs, dtype=np.float64)
        probs /= probs.sum()

        action = Action(
            int(np.random.choice(len(Action), p=probs))
        )

        env.act(action)

        states.append(env.observe()[None, :])

# Stack the (1, rows, columns) frames into one (n_frames, rows, columns) array.
images = np.vstack(states)

plot_image_sequence(images)