# Import

In [7]:
# import os
from IPython.display import clear_output
import numpy as np
import time
import random
from collections import defaultdict

# Config

In [8]:
# Number of training episodes run by the training loop (`range(NUM_EPISODE)`).
NUM_EPISODE = 1000

# Grid dimensions read by GridEnv.__init__; the goal cell is the bottom-right
# corner, i.e. (N_HEIGHT - 1, N_WIDTH - 1).
N_HEIGHT = 7 # grid height (number of rows)
N_WIDTH = 7 # grid width (number of columns)

# Environment

In [9]:
class GridEnv(object):
    """Rectangular grid world with a single goal cell and penalty cells.

    States are ``[row, col]`` integer arrays.  Every episode starts at
    ``[0, 0]`` (see :meth:`reset`); the goal is the bottom-right cell
    ``[height - 1, width - 1]``.

    Rewards returned by :meth:`step`:

    * ``100`` and ``done=True`` when the goal cell is entered,
    * ``-10`` when a cell registered via :meth:`add_obstacle` is entered
      (the episode continues),
    * ``-1`` for every other move.

    Args:
        height: number of rows; defaults to the module constant ``N_HEIGHT``.
        width: number of columns; defaults to the module constant ``N_WIDTH``.
    """

    def __init__(self, height=None, width=None):
        # Fall back to the notebook-level configuration when no explicit
        # size is given, so ``GridEnv()`` behaves exactly as before.
        self.height = N_HEIGHT if height is None else height
        self.width = N_WIDTH if width is None else width

        # Action index -> (row delta, col delta); the row order of
        # ``action_coords`` must match the indices in ``action_dict``.
        self.action_dict = {"up": 0, "right": 1, "down": 2, "left": 3}
        # NOTE: ``np.int`` was only an alias of the builtin ``int`` and was
        # removed in NumPy 1.24, so the builtin is used directly.
        self.action_coords = np.array(
            [[-1, 0], [0, 1], [1, 0], [0, -1]], dtype=int
        )
        self.num_actions = len(self.action_dict)

        self.state_dim = (self.height, self.width)
        self.action_dim = (self.num_actions,)
        # Shape of a Q-table indexed as [row, col, action].
        self.state_action_dim = self.state_dim + self.action_dim

        # Penalty cells, stored as ``[row, col]`` lists.  The default
        # obstacle sits at (6, 5); on a grid smaller than 7x6 that cell lies
        # outside the grid and can never be entered through valid actions.
        self.obstacles = []
        self.add_obstacle(6, 5)

        self.reset()

    def add_obstacle(self, h, w):
        """Register cell ``(h, w)`` as a penalty cell."""
        self.obstacles.append([h, w])

    def get_valid_actions(self):
        """Return the action indices that keep the agent inside the grid.

        The result depends on ``self.current_state`` and is an integer array
        ordered up, down, left, right (only the applicable ones).
        """
        actions = []
        h = self.current_state[0]
        w = self.current_state[1]
        if h > 0:
            actions.append(self.action_dict["up"])
        if h < self.height - 1:
            actions.append(self.action_dict["down"])
        if w > 0:
            actions.append(self.action_dict["left"])
        if w < self.width - 1:
            actions.append(self.action_dict["right"])
        return np.array(actions, dtype=int)

    def reset(self):
        """Move the agent back to the top-left cell and return that state."""
        self.current_state = np.array([0, 0], dtype=int)
        return self.current_state

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, done)``.

        The move is not bounds-checked here; callers are expected to pick
        ``action`` from :meth:`get_valid_actions`.
        """
        # np.add allocates a new array, so states handed out earlier are
        # not mutated by later steps.
        self.current_state = np.add(self.current_state, self.action_coords[action])
        if np.array_equal(self.current_state, [self.height - 1, self.width - 1]):
            reward = 100
            done = True
        elif list(self.current_state) in self.obstacles:
            reward = -10
            done = False
        else:
            reward = -1
            done = False
        return self.current_state, reward, done

# Agent

In [10]:
class Agent(object):
    """Tabular Q-learning agent with an epsilon-greedy behaviour policy.

    Args:
        env: environment object; must expose ``state_action_dim`` (shape of
            the Q-table, indexed ``[row, col, action]``) and
            ``get_valid_actions()``.
        learning_rate: step size applied to the TD error in
            :meth:`update_table`.
        discount_factor: weight of the bootstrapped next-state value.
        epsilon: probability of picking a uniformly random valid action
            during non-greedy action selection.
    """

    def __init__(self, env, learning_rate=0.1, discount_factor=0.99, epsilon=0.9):
        self.env = env

        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon

        # NOTE: ``np.float`` was only an alias of the builtin ``float`` and
        # was removed in NumPy 1.24, so the builtin is used directly.
        self.q_table = np.zeros(env.state_action_dim, dtype=float)

    def update_table(self, state, action, reward, next_state):
        """Apply one Q-learning update for ``(state, action, reward, next_state)``.

        Q(s, a) += learning_rate * (reward + discount * max_a' Q(s', a') - Q(s, a))
        """
        q_prev = self.q_table[state[0], state[1], action]
        # The max ranges over every action slot of next_state, including
        # slots for moves that are never selected there; those slots are
        # never written by this method and therefore stay at their initial 0.
        q_next_best = np.max(self.q_table[next_state[0], next_state[1]])
        q_target = reward + self.discount_factor * q_next_best
        self.q_table[state[0], state[1], action] += self.learning_rate * (q_target - q_prev)

    def get_action(self, state, greedy=False):
        """Pick an action for ``state``.

        With ``greedy=True`` exploration is disabled; otherwise a random
        valid action is chosen with probability ``self.epsilon``.  Ties
        between equally valued greedy actions are broken uniformly at random.

        The candidate actions come from ``self.env.get_valid_actions()``,
        i.e. from the environment's own current state, so ``state`` is
        expected to be the state the environment is currently in.
        """
        epsilon = 0 if greedy else self.epsilon

        valid_actions = self.env.get_valid_actions()
        if random.random() < epsilon:
            return random.choice(valid_actions)

        q_valid = self.q_table[state[0], state[1], valid_actions]
        best = valid_actions[np.flatnonzero(q_valid == np.max(q_valid))]
        return random.choice(best)

# Train

In [12]:
env = GridEnv()
agent = Agent(env)

# One (label, pad width) pair per action index 0..3.  Each label plus its
# right-justified value fills at least 16 characters, matching the
# 16-character border segment drawn above every grid column.
cell_formats = (('# up:', 11), ('# right:', 8), ('# down:', 9), ('# left:', 9))

for n_episode in range(NUM_EPISODE):
    # --- run one episode until the environment reports termination ---
    state = env.reset()
    done = False
    while not done:
        action = agent.get_action(state)
        next_state, reward, done = env.step(action)
        agent.update_table(state, action, reward, next_state)
        state = next_state

    # --- render the whole Q-table as a text grid, one block per grid row ---
    border = '****************' * env.width + "*\n"
    pieces = []
    for row in range(env.height):
        pieces.append(border)
        for action_idx, (label, pad) in enumerate(cell_formats):
            cells = [
                label + ('%.2f ' % agent.q_table[row, col, action_idx]).rjust(pad)
                for col in range(env.width)
            ]
            pieces.append(''.join(cells) + "#\n")
    pieces.append(border)
    pieces.append("num_episode=%d" % n_episode)
    debug_str = ''.join(pieces)

    # Replace the previous episode's output in the notebook cell.
    clear_output()
    print(debug_str)

# save table
np.save("q_table.npy", agent.q_table)



*****************************************************************************************************************
# up:      0.00 # up:      0.00 # up:      0.00 # up:      0.00 # up:      0.00 # up:      0.00 # up:      0.00 #
# right:  79.07 # right:  80.88 # right:  82.70 # right:  84.55 # right:  86.41 # right:  88.30 # right:   0.00 #
# down:   79.07 # down:   80.88 # down:   82.70 # down:   84.55 # down:   86.41 # down:   88.30 # down:   90.20 #
# left:    0.00 # left:   77.28 # left:   79.07 # left:   80.88 # left:   82.70 # left:   84.55 # left:   86.41 #
*****************************************************************************************************************
# up:     77.28 # up:     79.07 # up:     80.88 # up:     82.70 # up:     84.55 # up:     86.41 # up:     88.30 #
# right:  80.88 # right:  82.70 # right:  84.55 # right:  86.41 # right:  88.30 # right:  90.20 # right:   0.00 #
# down:   80.88 # down:   82.70 # down:   84.55 # down:   86.41 # down:   88.30 # down: 