In [2]:
import numpy as np

class GridWorld():
    def __init__(self):
        self.agent_pos = {'y' : 0, 'x' : 0}
        self.goal_pos = {'y' : 3, 'x' : 3}
        self.y_min, self.x_min, self.y_max, self.x_max = 0, 0, 3, 3

        self.state = np.zeros([4, 4]) # grid 생성
        self.state[self.agent_pos['y'], self.agent_pos['x']] = 1

        self.state_space = []
        for y in range(4):
            for x in range(4):
                state = np.zeros([4, 4])
                state[y, x] = 1
                self.state_space.append(state)

        self.action_space = [0, 1, 2, 3] # Up, Down, Left, Right
        self.gamma = 0.9

    def reset(self):
        self.agent_pos = {'y' : 0, 'x' : 0}
        self.state = np.zeros([4, 4])
        self.state[self.agent_pos['y'], self.agent_pos['x']] = 1

        return self.state

    def step(self, action):
        if (action == 0):
            self.agent_pos['y'] = max(self.agent_pos['y'] - 1, self.y_min)
        elif (action == 1):
            self.agent_pos['y'] = min(self.agent_pos['y'] + 1, self.y_max)
        elif (action == 2):
            self.agent_pos['x'] = max(self.agent_pos['x'] - 1, self.x_min)
        elif (action == 3):
            self.agent_pos['x'] = min(self.agent_pos['x'] + 1, self.x_max)
        else:
            assert False, "Invalid action value"

        prev_state = self.state
        self.state = np.zeros([4, 4])
        self.state[self.agent_pos['y'], self.agent_pos['x']] = 1

        done = False
        if (self.agent_pos == self.goal_pos):
            done = True

        reward = self.reward(prev_state, action, self.state)

        return reward, self.state, done

    def reward(self, s, a, s_next):
        reward = 0
        y, x = np.where(s == 1)
        y_next, x_next = np.where(s_next == 1)
        if ((y_next == self.goal_pos['y'] and x_next == self.goal_pos['x']) and (y != self.goal_pos['y'] or x != self.goal_pos['x'])):
            reward = 10

        return reward

    def get_state_index(self, state_space, state):
        for i_s, s in enumerate(state_space):
            if (s == state).all():
                return i_s
        assert False, "Couldn't find the state from the state space"

    def exploring_start(self):
        while (True):
            y_random = np.random.randint(4)
            x_random = np.random.randint(4)
            self.agent_pos = {'y' : y_random, 'x' : x_random}
            if (self.agent_pos != self.goal_pos):
                break
        
        self.state = np.zeros([4, 4])
        self.state[self.agent_pos['y'], self.agent_pos['x']] = 1
        
        return self.state

In [3]:
def calc_return(gamma, rewards):
    n = len(rewards)
    rewards = np.array(rewards)
    gammas = gamma * np.ones([n])
    powers = np.arange(n)

    power_of_gammas = np.power(gammas, powers)
    discounted_rewards = rewards * power_of_gammas
    g = np.sum(discounted_rewards)

    return g

In [4]:
def mc_control(env, policy):
    action_value_matrix = np.zeros([len(env.state_space), len(env.action_space)]) # 16 x 4
    returns = [[{'n' : 0, 'avg' : 0} for a in env.action_space] for s in env.state_space]
    
    for loop_count in range(5000):
        episode = {
            'states' : list(),
            'actions' : list(),
            'rewards' : list(),
        }
        done = False
        step_count = 0
        s = env.exploring_start()   # set start point

        # Generate an episode
        while not done:
            s_inx = env.get_state_index(env.state_space, s)
            pi_s = policy[s_inx]
            a = np.random.choice(env.action_space, p = pi_s)
            r, s_next, done = env.step(a)
            print(s)
            print(s_inx)
            print(pi_s)
            print(a)

            episode['states'].append(s)
            episode['actions'].append(a)
            episode['rewards'].append(r)

            step_count += 1
            s = s_next

            is_dead_lock = False
            if (step_count > 1000):
                is_dead_lock = True
                break
        
        if (is_dead_lock):
            continue

        episode['states'].append(s)

        # state evaluation
        for t in range(step_count):
            s_t = episode['states'][t] # 4x4
            a_t = episode['actions'][t] # action : n
            i_s_t = env.get_state_index(env.state_space, s_t)
            i_a_t = env.action_space.index(a_t) # a_t = i_a_t
            g_t = calc_return(env.gamma, episode['rewards'][t:]) # g_t : n

            n_prev, avg_prev = returns[i_s_t][i_a_t]['n'], returns[i_s_t][i_a_t]['avg']
            returns[i_s_t][i_a_t]['avg'] = (avg_prev * n_prev + g_t) / (n_prev + 1)
            returns[i_s_t][i_a_t]['avg'] = returns[i_s_t][i_a_t]['avg'] + 1.0 * (g_t - returns[i_s_t][i_a_t]['avg']) / (n_prev + 1)

            returns[i_s_t][i_a_t]['n'] = n_prev + 1
            action_value_matrix[i_s_t][i_a_t] = returns[i_s_t][i_a_t]['avg']

        # update policy
        for t in range(step_count):
            s_t = episode['states'][t]
            i_s_t = env.get_state_index(env.state_space, s_t)

            a_max = action_value_matrix[i_s_t].argmax()
            policy[i_s_t][:] = 0
            policy[i_s_t][a_max] = 1

        if ((loop_count + 1) % 100 == 0):
            print(f"[{loop_count}] action value matrix : \n{action_value_matrix}")

    return policy, action_value_matrix

In [6]:
env = GridWorld()
policy = []
for i_s, s in enumerate(env.state_space):
    pi = np.array([0.25, 0.25, 0.25, 0.25])
    policy.append(pi)
policy = np.array(policy)

policy, action_value_matrix = mc_control(env, policy)

value_vector = np.sum(policy * action_value_matrix, axis = -1)
value_table = value_vector.reshape(4, 4)

[[1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
0
[0.25 0.25 0.25 0.25]
2
[[1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
0
[0.25 0.25 0.25 0.25]
3
[[0. 1. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
1
[0.25 0.25 0.25 0.25]
2
[[1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
0
[0.25 0.25 0.25 0.25]
0
[[1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
0
[0.25 0.25 0.25 0.25]
1
[[0. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
4
[0.25 0.25 0.25 0.25]
0
[[1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
0
[0.25 0.25 0.25 0.25]
0
[[1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
0
[0.25 0.25 0.25 0.25]
2
[[1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
0
[0.25 0.25 0.25 0.25]
1
[[0. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
4
[0.25 0.25 0.25 0.25]
1
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 0. 0.]]
8
[0.25 0.25 0.25 0.25]
3
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 1. 0. 0.

In [8]:
policy

array([[0.  , 1.  , 0.  , 0.  ],
       [0.  , 0.  , 1.  , 0.  ],
       [0.  , 1.  , 0.  , 0.  ],
       [0.  , 1.  , 0.  , 0.  ],
       [0.  , 1.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 1.  ],
       [0.  , 1.  , 0.  , 0.  ],
       [0.  , 1.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 1.  ],
       [0.  , 0.  , 0.  , 1.  ],
       [0.  , 1.  , 0.  , 0.  ],
       [0.  , 0.  , 1.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 1.  ],
       [0.  , 0.  , 0.  , 1.  ],
       [0.25, 0.25, 0.25, 0.25]])

In [19]:
print("value_table : \n", value_table)
print("action_value_matrix : \n", action_value_matrix)

value_table : 
 [[ 5.9049      5.31441     7.29        6.561     ]
 [ 6.56099793  7.29        8.09999972  8.99999976]
 [ 5.90487607  8.1         9.         10.        ]
 [ 6.56099286  7.28999621  8.09999766  0.        ]]
action_value_matrix : 
 [[ 4.3046721   5.9049      0.          4.782969  ]
 [ 0.          0.          5.31441     0.        ]
 [ 0.          7.29        0.          0.        ]
 [ 0.          0.          6.561       5.9049    ]
 [ 0.          0.          5.50484303  6.56099793]
 [ 0.          0.          5.31441     7.29      ]
 [ 0.          5.31441     0.          8.09999972]
 [ 0.          8.99999976  0.          0.        ]
 [ 0.          5.90487607  0.          0.        ]
 [ 0.          0.          0.          8.1       ]
 [ 5.62441725  0.          0.          9.        ]
 [ 0.         10.          8.1         0.        ]
 [ 0.          0.          0.          6.56099286]
 [ 0.          0.          0.          7.28999621]
 [ 8.09999766  0.          0.          0.