In [None]:
%matplotlib auto

In [None]:
! pip install numpy
! pip install tensorflow
! pip install tqdm
! pip install matplotlib

# Q-learning playing catch

This notebook shows a very simple implementation of the parts involved in Deep Q-learning by learning a neural network to play the game of Catch. Every time a piece of "fruit" is initiated at a random position on top of the grid. The player has the ability to move the "basket" (a line at the bottom of the screen) left, right or to keep it in position. 

In [1]:
from collections import deque
from enum import Enum
from random import sample

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, HTML
from matplotlib import animation
from tqdm import tqdm
from tensorflow import keras

# The game

In [2]:
class Action(Enum):
    LEFT = 0
    STAY = 1
    RIGHT = 2

class Catch(object):
    def __init__(self, grid_size=10, basket_size=3):
        self.grid_size = grid_size
        self.basket_size = basket_size
        self.reset()
        
    def reset(self):
        # Fruit starts at a random column at the top row
        fruit = (
            0,
            np.random.randint(self.grid_size)
        )
        
        # The basket starts at a random column at the bottom row
        basket_left = (
            self.grid_size - 1,
            np.random.randint(self.grid_size - self.basket_size)            
        )
        
        self.terminated = False
        self._state = [fruit, basket_left]
        
    @property
    def state(self):
        fruit, basket_left = self._state
        
        basket_right = (
            basket_left[0], basket_left[1] + self.basket_size
        )
        
        return fruit, basket_left, basket_right
        
    def observe(self):
        grid = np.zeros((self.grid_size, self.grid_size))
        fruit, basket_left, basket_right = self.state
        
        grid[fruit] = 1
        
        grid[basket_left[0], basket_left[1]:basket_right[1]] = 1
        
        return grid
    
    def _is_permitted_basket_position(self, basket_left):
        return ((basket_left[1] >= 0) and
                (basket_left[1] + self.basket_size < self.grid_size))
    
    def _get_desired_basket_position(self, action):
        current_basket_position = self._state[1]
        
        if action == Action.LEFT:
            desired_basket_position = (current_basket_position[0], current_basket_position[1] - 1)
        elif action == Action.STAY:
            desired_basket_position = current_basket_position
        elif action == Action.RIGHT:
            desired_basket_position = (current_basket_position[0], current_basket_position[1] + 1)
            
        return desired_basket_position
            
    def _fruit_in_basket(self):
        fruit, basket_left, basket_right = self.state
        
        return ((fruit[0] == basket_left[0]) and
                (basket_left[1] <= fruit[1]) and
                (basket_right[1] >= fruit[1]))
    
    def _fruit_on_last_row(self):
        fruit, _, _ = self.state
        return fruit[0] == self.grid_size - 1
    
    def act(self, action):
        assert not self.terminated, "Cannot play a terminated game. Call .reset() to reset."
        
        
        # Change the basket position if it is an allowed position
        desired_basket_position = self._get_desired_basket_position(action)
        
        if self._is_permitted_basket_position(desired_basket_position):
            self._state[1] = desired_basket_position
        
        # Move the fruit and check if it can still be catched
        self._state[0] = (self._state[0][0] + 1, self._state[0][1])
        
        if self._fruit_on_last_row() and self._fruit_in_basket():
            # Game has been won. Fruit was catched
            self.terminated = True
            return (1, self.terminated)
        elif self._fruit_on_last_row() and not self._fruit_in_basket():
            # Game has been lost since fruit was not catched
            self.terminated = True
            return (-1, self.terminated)
        else:
            # Continue playing
            return (0, self.terminated)

# Experience replay
The classes and functions below serve as a very minimalistic implementation for performing experience replay.

In [3]:
class ExperienceReplayBuffer(object):
    def __init__(self, max_memory=100):
        self._buffer = deque(maxlen=max_memory)
        
    @property
    def max_memory(self):
        return self._buffer.maxlen
    
    def append(self, curr_state, action, reward, new_state, terminated):
        self._buffer.append((curr_state, action, reward, new_state, terminated))
        
    def sample(self, k):
        return sample(self._buffer, min(k, len(self._buffer)))
    

def construct_training_batch_from_replay_buffer(replay_buffer, model, batch_size):
    experience = replay_buffer.sample(batch_size)
    current_state = np.vstack([
        curr_state[None, :] for curr_state, _, _, _, _ in experience
    ])

    new_state = np.vstack([
        new_state[None, :] for _, _, _, new_state, _ in experience
    ])

    rewards = np.array([reward for _, _, reward, _, _ in experience])

    curr_state_Q = model.predict(current_state)
    new_state_Q = model.predict(new_state)

    target = curr_state_Q

    for i, (_, action, reward, _, terminated) in enumerate(experience):
        if not terminated:
            target[i, action.value] = reward + 0.9 * np.max(new_state_Q[i])
        else:
            target[i, action.value] = reward

    return current_state, target

In [8]:
HIDDEN_UNITS = 128
N_FILTERS = 30
N_GAMES = 3000
EPSILON = 0.10
BATCH_SIZE = 50

In [13]:
model = keras.models.Sequential()
model.add(keras.layers.Lambda(lambda x: x[..., None]))
model.add(keras.layers.Conv2D(N_FILTERS, (3, 3), activation='relu'))
model.add(keras.layers.Conv2D(N_FILTERS, (3, 3), activation='relu'))
# model.add(keras.layers.Dense(HIDDEN_UNITS, activation='relu'))
# model.add(keras.layers.Dense(HIDDEN_UNITS, activation='relu'))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(len(Action)))

model.compile(optimizer=keras.optimizers.SGD(0.2), loss='mse')

In [15]:
env = Catch(grid_size=10, basket_size=3)
replay_buffer = ExperienceReplayBuffer(max_memory=250)

rewards = []
games_iterator = tqdm(range(N_GAMES))
for game in games_iterator:
    env.reset()
    total_reward = 0
    
    while not env.terminated:
        current_state = env.observe()
        Q = model.predict(current_state[None, :])[0]
        
        if np.random.rand() < EPSILON:
            action = Action(
                np.random.randint(len(Action))
            )
        else:
            action = Action(
                np.argmax(Q)
            )
        
        reward, terminated = env.act(action)
        total_reward += reward
        
        new_state = env.observe()
        
        replay_buffer.append(current_state, action, reward, new_state, env.terminated)

    X, y = construct_training_batch_from_replay_buffer(replay_buffer, model, 50)
    model.train_on_batch(X, y)
    
    rewards.append(total_reward)
    games_iterator.set_description("Average reward %.2f" % np.array(rewards).mean())

Average reward 0.76: 100%|██████████| 3000/3000 [15:58<00:00,  3.13it/s]


# Results
The video below shows a few games played by the network.

In [18]:
def plot_image_sequence(images):
    x_pixels, y_pixels = images.shape[-2:]
    fig, ax = plt.subplots(figsize=(y_pixels, x_pixels))
    im = ax.imshow(images[0])

    def animate(i):
        im.set_array(images[i])
        return (im,)

    anim = animation.FuncAnimation(fig, animate, frames=len(images), interval=500, repeat_delay=1, repeat=True)
    display(HTML(anim.to_html5_video()))
    

states = []
for game in range(10):
    env.reset()
    states.append(env.observe()[None, :])
    while not env.terminated:
        Q = model.predict(env.observe()[None, :])[0]

        action = Action(
            np.argmax(Q)
        )

        env.act(action)
        
        states.append(env.observe()[None, :])
    
images = np.vstack(states)

plot_image_sequence(images)