In [212]:
from gym import Env
from gym.spaces import Discrete, Box
import numpy as np
import random
%matplotlib notebook
import matplotlib.pyplot as plt
from tqdm import tqdm

from env_control import Env_control

## Create Environment class

In [213]:
class SnakeEnv(Env):
    """Gym-style wrapper around ``Env_control`` for the snake game.

    Every ``reset()`` builds a fresh ``Env_control`` for ``grid_size``; ``step()`` delegates to it.
    The action space is ``Discrete(3)`` (presumably relative turns LEFT / STRAIGHT / RIGHT —
    TODO confirm against ``Env_control``).

    Args:
        grid_size: two-element sequence (list or tuple) with the grid dimensions.
    """

    def __init__(self, grid_size=(20, 20)):
        # list() copies the caller's sequence (so later mutation cannot leak in) and also
        # accepts tuples; Env_control still receives a list as before.
        self.grid_size = list(grid_size)
        self.viewer = None
        self.env_control = None  # created by reset()
        self.last_state = None
        self.action_space = Discrete(3)

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done, info)`` from ``Env_control.step``.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        if self.env_control is None:
            raise RuntimeError("SnakeEnv.step() called before reset()")
        self.last_state, reward, done, info = self.env_control.step(action)
        return self.last_state, reward, done, info

    def render(self, mode='human', close=False, frame_speed=.1):
        """Draw channel 0 of the current grid, opening a figure on the first call.

        ``mode`` and ``close`` are accepted for Gym compatibility and ignored.

        Args:
            frame_speed: seconds to pause after drawing, so the animation is visible.
        """
        if self.viewer is None:
            self.fig = plt.figure(figsize=(10, 10))
            self.viewer = self.fig.add_subplot(111)
            plt.ion()
            self.fig.show()
        else:
            self.viewer.clear()
        # Draw on every call, including the one that creates the figure, so no frame is lost.
        self.viewer.imshow(self.env_control.grid.grid[:, :, 0])
        plt.pause(frame_speed)
        self.fig.canvas.draw()

    def reset(self):
        """Start a new game on a fresh ``Env_control`` and return its initial state."""
        self.env_control = Env_control(self.grid_size)
        self.last_state = self.env_control.reset()
        return self.last_state

    def seed(self, x):
        """No-op: seeding is not forwarded to ``Env_control``; seed the global RNGs instead."""
        pass

## Create Agent class
Tabular Q-learning agent (off-policy TD(0) control) with ε-greedy exploration.

In [214]:
class QLearnAgent():
    """Tabular Q-learning (off-policy TD(0) control) agent with epsilon-greedy exploration.

    States may be any hashable object. Each new state is assigned the next free row of the
    Q-table ``self.q`` (shape ``(num_states, num_actions)``) the first time it is seen.

    Args:
        num_actions: number of discrete actions.
        num_states: Q-table capacity, i.e. the maximum number of distinct states.
        epsilon: probability of taking a uniformly random action.
        step_size: learning rate (alpha).
        discount: discount factor (gamma).
    """

    def __init__(self, num_actions, num_states, epsilon=0.02, step_size=0.5, discount=1):
        self.num_actions = num_actions
        self.num_states = num_states
        self.epsilon = epsilon
        self.step_size = step_size
        self.discount = discount
        self.state_dict = {}  # state -> row index in self.q

        self.q = np.zeros((self.num_states, self.num_actions))
        # (state row, action) of the previous step; the target of the next update.
        self.prev_state_idx = None
        self.prev_action = None

    def _state_index(self, state):
        """Return the Q-table row for ``state``, registering it if it is new.

        Raises:
            IndexError: if more than ``num_states`` distinct states are encountered.
        """
        state_idx = self.state_dict.get(state)
        if state_idx is None:
            state_idx = len(self.state_dict)
            if state_idx >= self.num_states:
                raise IndexError(
                    f"encountered more than num_states={self.num_states} distinct states; "
                    "increase num_states")
            self.state_dict[state] = state_idx
        return state_idx

    def _select_action(self, state_idx):
        """Epsilon-greedy action for Q-table row ``state_idx`` (greedy ties broken randomly)."""
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.num_actions)
        return self.argmax(self.q[state_idx, :])

    def _update(self, target):
        """Move Q(prev_state, prev_action) towards ``target`` by a fraction ``step_size``."""
        old = self.q[self.prev_state_idx, self.prev_action]
        self.q[self.prev_state_idx, self.prev_action] = (1 - self.step_size) * old + self.step_size * target

    def agent_start(self, state):
        """Begin an episode in ``state`` and return the first action (no update)."""
        state_idx = self._state_index(state)
        action = self._select_action(state_idx)
        self.prev_state_idx = state_idx
        self.prev_action = action
        return action

    def agent_step(self, reward, state):
        """Update Q for the previous (state, action) and return the next action.

        Args:
            reward: reward received for the previous action.
            state: state reached after the previous action.
        """
        state_idx = self._state_index(state)
        action = self._select_action(state_idx)

        # Q-learning bootstraps from the greedy value of the new state. Only the max value is
        # needed, so no random tie-break; it is computed before the write, so it is unaffected
        # even when the new state equals the previous one.
        self._update(reward + self.discount * np.max(self.q[state_idx, :]))

        self.prev_state_idx = state_idx
        self.prev_action = action
        return action

    def agent_end(self, reward):
        """Final update for a terminal transition (the terminal state's value is 0)."""
        self._update(reward)

    def agent_cleanup(self):
        """Forget the previous (state, action) so no update can leak across episodes."""
        self.prev_state_idx = None
        self.prev_action = None

    def agent_message(self, message):
        pass

    def argmax(self, q_values):
        """Return the index of the largest value in ``q_values``, breaking ties randomly."""
        q_values = np.asarray(q_values)
        ties = np.flatnonzero(q_values == q_values.max())
        return np.random.choice(ties)

## Train agent

In [232]:
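# STATE DESCRIPTION (an 11-element binary tuple made of three parts):
#   moving direction: [0, 0, 0, 0]; UP, DOWN, LEFT, RIGHT | 4 possibilities (exactly one bit set)
#   reward direction: [0, 0, 0, 0]; UP, DOWN, LEFT, RIGHT | 4 + 2 + 2 = 8 possibilities
#                     (4 axis-aligned, plus UP+LEFT/UP+RIGHT and DOWN+LEFT/DOWN+RIGHT diagonals)
#   danger:           [0, 0, 0];    LEFT, STRAIGHT, RIGHT  | 8 possibilities
# e.g. (0,0,1,0, 0,1,1,0, 0,0,1) -> at most 4 * 8 * 8 = 256 distinct states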

In [257]:
num_episodes = 100000
seed = 0

# Seed both global RNGs: the agent draws from numpy; Env_control may use `random`
# (assumption — Env_control's source is not shown here).
np.random.seed(seed)
random.seed(seed)

env = SnakeEnv([20, 20])
# num_states is the Q-table capacity; it must cover every distinct state the env can emit.
agent = QLearnAgent(num_actions=3, num_states=512, epsilon=0.2, step_size=0.1, discount=0.95)
scores = list()       # snake length at the end of each episode
all_rewards = list()  # total (undiscounted) reward of each episode
episode = []          # left empty here; the replay cell below collects transitions into it

for _ in tqdm(range(num_episodes)):
    state = env.reset()
    # The first action comes from agent_start (no update yet).
    action = agent.agent_start(state)
    episode_reward = 0

    # Act, observe, then either finish (terminal update) or update-and-choose the next action.
    # Ordering matters: agent_step must receive the reward produced by the action it is updating.
    while True:
        state, reward, done, info = env.step(action)
        episode_reward += reward
        if done:
            agent.agent_end(reward)  # terminal update: no bootstrap
            break
        action = agent.agent_step(reward, state)

    scores.append(len(env.env_control.snake.body))
    all_rewards.append(episode_reward)

100%|██████████| 100000/100000 [04:53<00:00, 340.91it/s]


In [259]:
# Per-episode totals are very noisy over many episodes, so overlay a moving average.
window = 1000

fig, ax = plt.subplots(figsize=(5, 5))
ax.plot(all_rewards, alpha=0.3, label='total reward per episode')
if len(all_rewards) >= window:
    smoothed = np.convolve(all_rewards, np.ones(window) / window, mode='valid')
    # 'valid' output index k averages episodes k .. k+window-1; plot it at the window's end.
    ax.plot(np.arange(window - 1, len(all_rewards)), smoothed,
            label=f'{window}-episode moving average')
ax.set_xlabel('episode')
ax.set_ylabel('total reward')
ax.set_title('Training progress')
ax.legend()
plt.show()

<IPython.core.display.Javascript object>

## Visualise the agent's performance

In [263]:
# Replay one episode with the learned greedy policy. No exploration and no Q updates, so
# the demo neither perturbs nor depends on leftover training-loop variables.
max_steps = 1000  # a greedy snake can circle forever without eating; cap the replay


def greedy_action(agent, state):
    """Best learned action for ``state``; uniformly random if the state was never visited."""
    state_idx = agent.state_dict.get(state)
    if state_idx is None:
        return np.random.randint(agent.num_actions)
    return agent.argmax(agent.q[state_idx])


state = env.reset()
env.viewer = None  # force render() to open a fresh figure
episode = []       # (state, action, reward) transitions of this replay

for _ in range(max_steps):
    action = greedy_action(agent, state)
    next_state, reward, done, info = env.step(action)
    episode.append((state, action, reward))
    state = next_state
    env.render()
    if done:
        break

<IPython.core.display.Javascript object>