In [53]:
import numpy as np

class GridWorld():
    def __init__(self):
        self.agent_pos = {'y' : 0, 'x' : 0} # start point
        self.goal_pos = {'y' : 6, 'x' : 4} # end point
        self.home_pos = {'y' : 2, 'x' : 4}

        self.y_min, self.x_min, self.y_max, self.x_max = 0, 0, 6, 4

        self.state = np.zeros([7, 5]) # 7x5 grid 생성
        self.state[self.agent_pos['y'], self.agent_pos['x']] = 1

        self.wall_pos = [{'y' : 1,'x' : 1}, {'y' : 1, 'x' : 2}, {'y' : 2, 'x' : 1}, {'y' : 2, 'x' : 2}]
        self.trap_pos = [{'y' : 4, 'x' : 1}, {'y' : 4, 'x' : 2}]
        self.ice_pos = [{'y' : 6, 'x' : 1}, {'y' : 6, 'x' : 2}, {'y' : 6, 'x' : 3}]
        self.ice_ns = {'y' : 6, 'x' : 0} # ice의 next state

        self.state_space = []
        for y in range(7):
            for x in range(5):
                state = np.zeros([7, 5])
                state[y, x] = 1
                self.state_space.append(state)

        self.action_space = [0, 1, 2, 3] # Up, Down, Left, Right
        self.gamma = 0.9
        self.epsilon = 0.3

    def reset(self):
        self.agent_pos = {'y' : 0, 'x' : 0}
        self.state = np.zeros([7, 5])
        self.state[self.agent_pos['y'], self.agent_pos['x']] = 1

        return self.state

    def step(self, action):
        prev_state = self.state
        prev_pos = self.agent_pos

        if (action == 0):
            self.agent_pos['y'] = max(self.agent_pos['y'] - 1, self.y_min)
        elif (action == 1):
            self.agent_pos['y'] = min(self.agent_pos['y'] + 1, self.y_max)
        elif (action == 2):
            self.agent_pos['x'] = max(self.agent_pos['x'] - 1, self.x_min)
        elif (action == 3):
            self.agent_pos['x'] = min(self.agent_pos['x'] + 1, self.x_max)
        else:
            assert False, "Invalid action value"

        if any(self.agent_pos['y'] == wall['y'] and self.agent_pos['x'] == wall['x'] for wall in self.wall_pos):
            self.agent_pos = prev_pos 

        if any(self.agent_pos['y'] == ice['y'] and self.agent_pos['x'] == ice['x'] for ice in self.ice_pos):
            self.agent_pos = self.ice_ns
        
        self.state = np.zeros([7, 5])
        self.state[self.agent_pos['y'], self.agent_pos['x']] = 1

        done = False
        if (self.agent_pos == self.goal_pos):
            done = True

        reward = self.reward(prev_state, action, self.state)

        return reward, self.state, done

    def reward(self, s, a, s_next):
        reward = -0.5
        y, x = np.where(s == 1)
        y_next, x_next = np.where(s_next == 1)
        if ((y_next == self.goal_pos['y'] and x_next == self.goal_pos['x']) and (y != self.goal_pos['y'] or x != self.goal_pos['x'])):
            reward = 10
        elif ((y_next == self.home_pos['y'] and x_next == self.home_pos['x']) and (y != self.home_pos['y'] or x != self.home_pos['x'])):
            reward = 5
        elif any(y_next == trap['y'] and x_next == trap['x'] for trap in self.trap_pos):
            reward = -1.0

        return reward

    def get_state_index(self, state_space, state):
        for i_s, s in enumerate(state_space):
            if (s == state).all():
                return i_s
        assert False, "Couldn't find the state from the state space"

    def exploring_start(self):
        while (True):
            y_random = np.random.randint(7)
            x_random = np.random.randint(5)
            self.agent_pos = {'y' : y_random, 'x' : x_random}

            # 시작 점이 끝 점이면 한번 더 랜덤
            if (self.agent_pos != self.goal_pos):
                break
        
        self.state = np.zeros([7, 5])
        self.state[self.agent_pos['y'], self.agent_pos['x']] = 1
        
        return self.state

In [54]:
def calc_return(gamma, rewards):
    n = len(rewards)
    rewards = np.array(rewards)
    gammas = gamma * np.ones([n])
    powers = np.arange(n)
    power_of_gammas = np.power(gammas, powers)
    discounted_rewards = rewards * power_of_gammas
    g = np.sum(discounted_rewards)
    return g

In [55]:
def mc_control(env, policy):

    epsilon = 0.3

    action_value_matrix = np.zeros([len(env.state_space), len(env.action_space)]) # 35 x 4
    returns = [[{'n' : 0, 'avg' : 0} for a in env.action_space] for s in env.state_space]
    
    for loop_count in range(5000):
        episode = {
            'states' : list(),
            'actions' : list(),
            'rewards' : list(),
        }
        done = False
        step_count = 0

        s = env.exploring_start()   # set start point

        # Generate an episode
        while not done:
            s_inx = env.get_state_index(env.state_space, s)
            # pi_s = policy[s_inx]
            a = np.random.choice(env.action_space, p = epsilon_policy(policy, s_inx, epsilon))
            r, s_next, done = env.step(a)

            episode['states'].append(s)
            episode['actions'].append(a)
            episode['rewards'].append(r)

            step_count += 1
            s = s_next

            is_dead_lock = False
            if (step_count > 1000):
                is_dead_lock = True
                break
        
        if (is_dead_lock):
            continue

        episode['states'].append(s)

        # state evaluation
        for t in range(step_count):
            s_t = episode['states'][t] # 7x5
            a_t = episode['actions'][t] # action : n
            i_s_t = env.get_state_index(env.state_space, s_t)
            i_a_t = env.action_space.index(a_t) # a_t = i_a_t
            g_t = calc_return(env.gamma, episode['rewards'][t:]) # g_t : n

            # n_prev, avg_prev = returns[i_s_t][i_a_t]['n'], returns[i_s_t][i_a_t]['avg']
            # returns[i_s_t][i_a_t]['avg'] = (avg_prev * n_prev + g_t) / (n_prev + 1)
            returns[i_s_t][i_a_t]['avg'] = returns[i_s_t][i_a_t]['avg'] + 1.0 * (g_t - returns[i_s_t][i_a_t]['avg'])

            # returns[i_s_t][i_a_t]['n'] = n_prev + 1
            action_value_matrix[i_s_t][i_a_t] = returns[i_s_t][i_a_t]['avg']

        # update policy
        for t in range(step_count):
            s_t = episode['states'][t]
            i_s_t = env.get_state_index(env.state_space, s_t)

            a_max = action_value_matrix[i_s_t].argmax()
            policy[i_s_t][:] = 0
            policy[i_s_t][a_max] = 1

        if ((loop_count + 1) % 100 == 0):
            print(f"[{loop_count}] action value matrix : \n{action_value_matrix}")

    return policy, action_value_matrix

In [56]:
def epsilon_policy(policy, state_index, epsilon):
    pi_s = policy[state_index].copy()
    pi_s = pi_s * (1 - epsilon)
    pi_s += epsilon / len(env.action_space)
    return pi_s

In [57]:
env = GridWorld()
policy = []
for i_s, s in enumerate(env.state_space):
    pi = np.array([0.25, 0.25, 0.25, 0.25])
    policy.append(pi)
policy = np.array(policy)

policy, action_value_matrix = mc_control(env, policy)

value_vector = np.sum(policy * action_value_matrix, axis = -1)
value_table = value_vector.reshape(7, 5)

[99] action value matrix : 
[[-4.42771936e+00 -2.50777503e-01 -3.17635018e+00 -3.92315302e+00]
 [-3.53527876e+00 -4.99550464e+00 -3.80350335e+00 -4.96149298e+00]
 [-3.44663755e+00  2.60711500e+00 -2.58245596e+00 -4.80045808e+00]
 [-4.21497855e+00  5.45586835e-01  1.84640350e+00 -2.97372242e+00]
 [-4.99999911e+00 -4.55764826e+00 -8.97184850e-03 -4.26539633e+00]
 [-7.25699753e-01 -2.74858047e+00  7.62386142e-01  2.76913885e-01]
 [-4.79184487e+00  8.63237650e-01 -1.15312978e+00 -4.99500516e+00]
 [-1.71056664e+00  3.45235000e+00 -1.93763175e+00  8.11307335e-01]
 [ 1.16176315e+00  1.16176315e+00 -4.89395524e+00 -1.72538643e+00]
 [-5.08074664e-01  8.22474416e-01 -4.32578610e+00 -9.57267197e-01]
 [-4.24728396e+00 -2.22046972e+00 -2.49842275e+00  1.45700815e+00]
 [-2.26224856e+00 -6.02881611e+00 -4.32255556e+00  3.45235000e+00]
 [-9.85431566e-01  3.40735000e+00 -4.39030001e+00 -4.59997583e+00]
 [ 5.31846982e-01  1.84640350e+00  3.45235000e+00  1.72262321e+01]
 [-4.64169509e+00  5.93500000e+00 

In [58]:
for i in range(len(policy)):
    print(f"{i+1} {policy[i]}")

1 [0. 1. 0. 0.]
2 [0. 1. 0. 0.]
3 [0. 0. 1. 0.]
4 [0. 0. 0. 1.]
5 [0. 1. 0. 0.]
6 [0. 1. 0. 0.]
7 [0. 1. 0. 0.]
8 [0. 0. 1. 0.]
9 [0. 1. 0. 0.]
10 [0. 1. 0. 0.]
11 [0. 0. 0. 1.]
12 [0. 1. 0. 0.]
13 [0. 0. 0. 1.]
14 [1. 0. 0. 0.]
15 [0. 0. 0. 1.]
16 [0. 1. 0. 0.]
17 [0. 0. 0. 1.]
18 [0. 0. 0. 1.]
19 [1. 0. 0. 0.]
20 [0. 1. 0. 0.]
21 [0. 0. 1. 0.]
22 [1. 0. 0. 0.]
23 [0. 1. 0. 0.]
24 [0. 1. 0. 0.]
25 [1. 0. 0. 0.]
26 [0. 1. 0. 0.]
27 [0. 1. 0. 0.]
28 [0. 1. 0. 0.]
29 [0. 1. 0. 0.]
30 [0. 1. 0. 0.]
31 [0. 0. 0. 1.]
32 [0. 1. 0. 0.]
33 [0. 1. 0. 0.]
34 [0. 1. 0. 0.]
35 [0.25 0.25 0.25 0.25]


In [59]:
value_table

array([[ 1.12895815,  1.8792085 ,  1.19128765,  0.74999066,  7.1244535 ],
       [ 2.1744535 ,  2.971615  , -1.72538643,  8.09094258,  8.471615  ],
       [ 2.971615  ,  3.85735   ,  4.18644305,  6.78184832,  4.8415    ],
       [ 4.8415    ,  4.8415    ,  5.935     ,  6.40700815,  7.15      ],
       [ 5.935     ,  5.485     ,  7.15      ,  8.5       ,  5.935     ],
       [ 8.5       , 10.        , 10.        , 10.        , 10.        ],
       [10.        , 10.        , 10.        , 10.        ,  0.        ]])