In [3]:
import numpy as np

In [4]:
class Env:
    """Grid world on a 4x4 board with a workplace, a home and a park area.

    The state is a 2-D float array holding one code per cell: 0 for an empty
    cell, 2 for the workplace, 3 for home, -1 for every park cell and 1 for
    the agent. The agent marker is written last, so it replaces the code of
    the cell the agent is standing on.

    Actions are the integers in ``action_space``: 0 = up (y - 1),
    1 = down (y + 1), 2 = left (x - 1), 3 = right (x + 1). Moves are clipped
    to the board bounds.
    """

    def __init__(self):
        self.agent_pos = {'y' : 0, 'x' : 0}
        self.home_pos = {'y' : 3, 'x' : 3}

        self.workplace_pos = {'y' : 0, 'x' : 3}
        # Rows and columns covered by the park: the park is every
        # (row, column) combination of the two lists.
        self.park_area = {'y' : [1, 2], 'x' : [1, 2]}

        self.y_min, self.x_min, self.y_max, self.x_max = 0, 0, 3, 3

        # set up state
        self.state = self.set_state(self.agent_pos['y'], self.agent_pos['x'])

        # One state per agent cell, enumerated row by row.
        self.state_space = [
            self.set_state(y, x)
            for y in range(self.y_min, self.y_max + 1)
            for x in range(self.x_min, self.x_max + 1)
        ]

        self.action_space = [0, 1, 2, 3] # Up, Down, Left, Right

    def set_state(self, y_agent, x_agent):
        """Build the board array for an agent standing at (y_agent, x_agent).

        Returns a new float array; ``self.state`` is not modified here.
        """
        state = np.zeros([self.y_max + 1, self.x_max + 1])
        state[self.workplace_pos['y'], self.workplace_pos['x']] = 2
        state[self.home_pos['y'], self.home_pos['x']] = 3
        # np.ix_ selects the full rows x columns block. Indexing with the two
        # lists directly would pair them element-wise and mark only the
        # diagonal cells, disagreeing with the park test used by reward().
        state[np.ix_(self.park_area['y'], self.park_area['x'])] = -1
        # Written last so the agent is always visible, whatever cell it is on.
        state[y_agent, x_agent] = 1

        return state

    def reset(self):
        """Put the agent back on cell (0, 0) and return the resulting state."""
        self.agent_pos = {'y' : 0, 'x' : 0}
        self.state = self.set_state(self.agent_pos['y'], self.agent_pos['x'])

        return self.state

    def step(self, action):
        """Apply ``action`` (possibly replaced by a random one) and advance.

        With probability 0.3 the requested action is swapped for one of the
        other actions in ``action_space``, chosen uniformly.

        Returns:
            (reward, state, done): the reward of the transition, the new
            board array, and whether the agent now stands on the workplace or
            home cell.

        Raises:
            ValueError: if ``action`` is not in ``action_space``.
        """
        # Validate once, up front, so an invalid action fails the same way on
        # every call instead of depending on the random-action draw below.
        if action not in self.action_space:
            raise ValueError("Invalid action value")

        # Update environmental variables
        is_random_action = np.random.choice([0, 1], p = [0.7, 0.3])

        if is_random_action:
            random_action_set = list(self.action_space)
            random_action_set.remove(action)
            action = np.random.choice(random_action_set)

        if action == 0:
            self.agent_pos['y'] = max(self.agent_pos['y'] - 1, self.y_min)
        elif action == 1:
            self.agent_pos['y'] = min(self.agent_pos['y'] + 1, self.y_max)
        elif action == 2:
            self.agent_pos['x'] = max(self.agent_pos['x'] - 1, self.x_min)
        elif action == 3:
            self.agent_pos['x'] = min(self.agent_pos['x'] + 1, self.x_max)
        else:
            # Only reachable when action_space holds a value with no movement rule.
            raise ValueError("Invalid action value")

        prev_state = self.state
        self.state = self.set_state(self.agent_pos['y'], self.agent_pos['x'])

        done = (self.agent_pos == self.workplace_pos) or (self.agent_pos == self.home_pos)

        reward = self.reward(prev_state, action, self.state)

        return reward, self.state, done

    @staticmethod
    def _agent_cell(state):
        """Return the (y, x) of the cell holding the agent marker (value 1)."""
        ys, xs = np.where(state == 1)
        return int(ys[0]), int(xs[0])

    def reward(self, s, a, s_next):
        """Reward for moving from board ``s`` to board ``s_next``.

        ``a`` is accepted for signature compatibility and is not used.

        Precedence, highest first:
            * agent ends in the park: -1.0
            * agent ends at home: 10 on arrival, 0 if it was already there
            * agent ends at the workplace: 5 on arrival, 0 if already there
            * anything else: -0.5
        """
        y, x = self._agent_cell(s)
        y_next, x_next = self._agent_cell(s_next)

        workplace = (self.workplace_pos['y'], self.workplace_pos['x'])
        home = (self.home_pos['y'], self.home_pos['x'])

        if y_next in self.park_area['y'] and x_next in self.park_area['x']:
            return -1.0

        if (y_next, x_next) == home:
            return 0 if (y, x) == home else 10

        if (y_next, x_next) == workplace:
            return 0 if (y, x) == workplace else 5

        return -0.5

In [25]:
# Discount factor applied to the next state-action value in the SARSA TD target.
gamma = 0.95
# Learning-rate decay: sarsa uses alpha = 1 / (1 + k_alpha * episode index).
k_alpha = 1e-3
# Exploration decay: sarsa uses eps = 1 / (1 + k_eps * total steps taken so far).
k_eps = 5e-4

def get_state_index(state_space, state):
    """Return the index of the first entry of ``state_space`` equal to ``state``.

    Two states are considered equal when every element compares equal.
    Fails with an assertion error when no entry matches.
    """
    matching_indexes = (
        index
        for index, candidate in enumerate(state_space)
        if (candidate == state).all()
    )
    first_match = next(matching_indexes, None)

    assert first_match is not None, "Couldn't find the state"
    return first_match

def sarsa(env, n_episodes = 10000, log_interval = 100):
    """Learn a tabular action-value function with on-policy SARSA.

    Actions are drawn epsilon-greedily from the current estimates. The
    exploration rate decays with the total number of environment steps
    (``k_eps``) and the learning rate decays with the episode index
    (``k_alpha``); both constants and ``gamma`` are read from module scope.

    Args:
        env: environment exposing ``state_space``, ``action_space``,
            ``reset()`` returning a state and ``step(action)`` returning
            ``(reward, next_state, done)``.
        n_episodes: number of episodes to run (default 10000).
        log_interval: print the estimates every this many episodes; a falsy
            value disables the progress output (default 100).

    Returns:
        (action_value_matrix, policy): the learned estimates, shaped
        (number of states, number of actions), and a one-hot matrix of the
        same shape marking the argmax action of every state.
    """
    n_actions = len(env.action_space)
    action_value_matrix = np.zeros([len(env.state_space), n_actions])

    def sample_action(eps, action_value):
        # Epsilon-greedy: every action gets eps / n_actions, and the greedy
        # action additionally receives the remaining 1 - eps.
        a_max = action_value.argmax()
        pi = np.zeros([n_actions])
        pi[:] = eps / n_actions
        pi[a_max] = pi[a_max] + 1 - eps
        return np.random.choice(env.action_space, p = pi)

    def get_eps(total_step_count):
        return 1 / (1 + k_eps * total_step_count)

    # Repeat sarsa loop
    total_step_count = 0
    for loop_count in range(n_episodes):
        done = False

        i_s = get_state_index(env.state_space, env.reset())
        a = sample_action(get_eps(total_step_count), action_value_matrix[i_s])

        # The learning rate only depends on the episode index.
        alpha = 1 / (1 + k_alpha * loop_count)

        # Generate one episode
        while not done:
            r, s_next, done = env.step(a)
            i_s_next = get_state_index(env.state_space, s_next)
            a_next = sample_action(get_eps(total_step_count), action_value_matrix[i_s_next])

            td = r + gamma * action_value_matrix[i_s_next, a_next] - action_value_matrix[i_s, a]
            action_value_matrix[i_s, a] = action_value_matrix[i_s, a] + alpha * td

            if done:
                # Keep the estimates of the state that ended the episode at zero.
                action_value_matrix[i_s_next] = 0

            total_step_count += 1

            i_s = i_s_next
            a = a_next

        if log_interval and (loop_count + 1) % log_interval == 0:
            print(f"[{loop_count}] action_value_matrix : \n{action_value_matrix}\n"
            + f"eps : {get_eps(total_step_count):.4f}"
            + f"\talpha : {alpha:.4f}")

    # Greedy one-hot policy: a 1 in the argmax-action column of every state.
    policy = np.zeros([len(env.state_space), n_actions])
    state_indexes = np.arange(len(env.state_space))
    argmax_actions = action_value_matrix.argmax(axis = -1)
    policy[state_indexes, argmax_actions] = 1

    return action_value_matrix, policy

In [28]:
# Print every float with three decimals, preceded by a space.
np.set_printoptions(formatter = {'float' :' {:0.3f}'.format})

# Train on a fresh environment.
env = Env()
action_value_matrix, policy = sarsa(env)

# Greedy action of every state, and each state's value under the one-hot policy.
argmax_actions = np.argmax(action_value_matrix, axis = -1)
value_vector = (policy * action_value_matrix).sum(axis = -1)

# States are enumerated row by row, so a 4x4 reshape puts them back on the grid.
argmax_actions_table = np.reshape(argmax_actions, (4, 4))
value_table = np.reshape(value_vector, (4, 4))
print(f"value_table : \n{value_table}\n"
    + f"argmax_actions : \n{argmax_actions_table}")

[99] action_value_matrix : 
[[ -4.946  -3.448  -2.109  1.819]
 [ -6.030  -6.206  -5.366  3.668]
 [ -2.642  2.981  -3.214  4.997]
 [ 0.000  0.000  0.000  0.000]
 [ -4.575  -3.951  -0.371  -6.505]
 [ -2.662  -5.175  -4.146  -3.399]
 [ -0.097  0.808  -1.937  4.182]
 [ 5.000  5.117  4.687  0.599]
 [ -4.674  -3.814  -4.329  -2.770]
 [ -5.713  -5.870  -3.421  -4.435]
 [ 1.759  -5.444  -5.309  3.905]
 [ 6.710  3.418  5.682  6.344]
 [ -0.834  -6.911  -4.795  -6.690]
 [ 1.136  -6.856  -6.714  -4.795]
 [ -2.807  -3.382  1.979  9.385]
 [ 0.000  0.000  0.000  0.000]]
eps : 0.4949	alpha : 0.9099
[199] action_value_matrix : 
[[ -5.124  -3.440  -4.753  -0.914]
 [ -4.393  -2.113  -4.465  2.908]
 [ -3.889  -4.915  -3.363  4.881]
 [ 0.000  0.000  0.000  0.000]
 [ -5.205  -4.521  -5.225  -2.889]
 [ -0.659  -3.511  -0.327  4.103]
 [ 0.364  -0.159  -7.197  4.084]
 [ 5.065  0.726  1.271  0.723]
 [ -5.432  -6.654  -5.670  1.518]
 [ 0.309  -1.937  -0.769  4.598]
 [ 4.579  8.556  7.059  8.916]
 [ 2.047  9.998 

In [29]:
# Display the one-hot policy returned by sarsa: one row per state, with a 1 in the argmax-action column.
policy

array([[ 0.000,  0.000,  0.000,  1.000],
       [ 0.000,  0.000,  0.000,  1.000],
       [ 0.000,  0.000,  0.000,  1.000],
       [ 1.000,  0.000,  0.000,  0.000],
       [ 0.000,  0.000,  0.000,  1.000],
       [ 0.000,  0.000,  0.000,  1.000],
       [ 0.000,  0.000,  0.000,  1.000],
       [ 0.000,  1.000,  0.000,  0.000],
       [ 0.000,  1.000,  0.000,  0.000],
       [ 0.000,  1.000,  0.000,  0.000],
       [ 0.000,  0.000,  0.000,  1.000],
       [ 0.000,  1.000,  0.000,  0.000],
       [ 0.000,  0.000,  0.000,  1.000],
       [ 0.000,  0.000,  0.000,  1.000],
       [ 0.000,  0.000,  0.000,  1.000],
       [ 1.000,  0.000,  0.000,  0.000]])