In [5]:
import numpy as np

class GridWorld():
    """Deterministic square grid world with a single goal in the bottom-right corner.

    States are one-hot ``size x size`` float arrays marking the agent's cell.
    Actions: 0 = Up, 1 = Down, 2 = Left, 3 = Right; a move that would leave the
    grid keeps the agent on the border. The transition that enters the goal
    cell yields reward 10; every other transition yields 0. An episode is done
    once the agent stands on the goal cell.

    Args:
        size: side length of the grid (default 4, the original 4x4 layout).

    Raises:
        ValueError: if ``size`` < 2 (the start cell would coincide with the goal).
    """

    def __init__(self, size=4):
        if size < 2:
            raise ValueError("size must be at least 2")
        self.size = size
        self.agent_pos = {'y' : 0, 'x' : 0}
        self.goal_pos = {'y' : size - 1, 'x' : size - 1}
        self.y_min, self.x_min, self.y_max, self.x_max = 0, 0, size - 1, size - 1

        # Current observation: one-hot grid marking the agent's cell.
        self.state = self._one_hot(self.agent_pos['y'], self.agent_pos['x'])

        # Every possible one-hot state, enumerated row-major (index = y * size + x).
        self.state_space = [self._one_hot(y, x) for y in range(size) for x in range(size)]

        self.action_space = [0, 1, 2, 3] # Up, Down, Left, Right
        self.gamma = 0.9

    def _one_hot(self, y, x):
        """Return a ``size x size`` zero grid with a single 1 at (y, x)."""
        grid = np.zeros([self.size, self.size])
        grid[y, x] = 1
        return grid

    def reset(self):
        """Put the agent back on the top-left cell and return the new state."""
        self.agent_pos = {'y' : 0, 'x' : 0}
        self.state = self._one_hot(self.agent_pos['y'], self.agent_pos['x'])

        return self.state

    def step(self, action):
        """Apply ``action`` and return ``(reward, next_state, done)``.

        The returned state is the environment's internal ``self.state`` array.

        Raises:
            AssertionError: if ``action`` is not one of 0, 1, 2, 3.
        """
        if action == 0:
            self.agent_pos['y'] = max(self.agent_pos['y'] - 1, self.y_min)
        elif action == 1:
            self.agent_pos['y'] = min(self.agent_pos['y'] + 1, self.y_max)
        elif action == 2:
            self.agent_pos['x'] = max(self.agent_pos['x'] - 1, self.x_min)
        elif action == 3:
            self.agent_pos['x'] = min(self.agent_pos['x'] + 1, self.x_max)
        else:
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            raise AssertionError("Invalid action value")

        prev_state = self.state
        self.state = self._one_hot(self.agent_pos['y'], self.agent_pos['x'])

        done = self.agent_pos == self.goal_pos

        reward = self.reward(prev_state, action, self.state)

        return reward, self.state, done

    def reward(self, s, a, s_next):
        """Return 10 for the transition that enters the goal cell, else 0.

        ``s`` and ``s_next`` are one-hot grids; ``a`` is unused and only kept so
        the signature reads r(s, a, s'). Staying on the goal earns nothing.
        """
        gy, gx = self.goal_pos['y'], self.goal_pos['x']
        entered_goal = s_next[gy, gx] == 1 and s[gy, gx] != 1
        return 10 if entered_goal else 0

    def get_state_index(self, state_space, state):
        """Return the index of ``state`` in ``state_space`` (linear scan).

        Raises:
            AssertionError: if no entry of ``state_space`` equals ``state``.
        """
        for i_s, s in enumerate(state_space):
            if np.array_equal(s, state):
                return i_s
        # Explicit raise (not ``assert``) so a missing state is never returned as None.
        raise AssertionError("Couldn't find the state from the state space")

    def exploring_start(self):
        """Place the agent on a uniformly random non-goal cell and return the state.

        Draws y then x from the global ``np.random`` state, retrying until the
        cell differs from the goal.
        """
        while True:
            y_random = np.random.randint(self.size)
            x_random = np.random.randint(self.size)
            self.agent_pos = {'y' : y_random, 'x' : x_random}
            if self.agent_pos != self.goal_pos:
                break

        self.state = self._one_hot(self.agent_pos['y'], self.agent_pos['x'])

        return self.state

In [10]:
def td_value_prediction(env, policy, gamma=None, alpha=5e-3, num_episodes=10000,
                        log_every=100, rng=None):
    """Estimate the state-value function of ``policy`` with tabular TD(0).

    Each episode starts from ``env.reset()`` and follows ``policy`` until the
    environment reports ``done``. After every transition (s, a, r, s') the
    estimate V(s) moves toward the one-step target r + gamma * V(s'), with
    V(s') taken as 0 when s' is terminal.

    Args:
        env: environment exposing ``state_space``, ``action_space``, ``reset()``,
            ``step(a) -> (reward, next_state, done)`` and
            ``get_state_index(state_space, state)``.
        policy: array of shape (num_states, num_actions); row i holds the action
            probabilities for state i (indexed like ``env.state_space``).
        gamma: discount factor; defaults to ``env.gamma`` (0.9 if the env has none).
        alpha: constant step size.
        num_episodes: number of episodes to run.
        log_every: print the value vector every ``log_every`` episodes;
            ``None`` or 0 disables printing.
        rng: optional ``np.random.Generator`` / ``RandomState`` used to sample
            actions; defaults to the global ``np.random`` state.

    Returns:
        1-D array of length ``len(env.state_space)`` with the value estimates.
    """
    if gamma is None:
        gamma = getattr(env, "gamma", 0.9)
    sampler = np.random if rng is None else rng

    value_vector = np.zeros([len(env.state_space)])

    for loop_count in range(num_episodes):
        done = False
        s = env.reset()
        i_s = env.get_state_index(env.state_space, s)

        # Generate one episode, updating V(s) online after every step.
        while not done:
            a = sampler.choice(env.action_space, p=policy[i_s])
            r, s_next, done = env.step(a)
            i_s_next = env.get_state_index(env.state_space, s_next)

            # A terminal state's value is 0 by definition: never bootstrap from it.
            bootstrap = 0.0 if done else value_vector[i_s_next]
            td_target = r + gamma * bootstrap
            value_vector[i_s] += alpha * (td_target - value_vector[i_s])

            s, i_s = s_next, i_s_next

        if log_every and loop_count % log_every == 0:
            print(f"{loop_count} value vector : \n{value_vector}")

    return value_vector

In [26]:
# Evaluate the uniform-random policy on the 4x4 grid with TD(0).
np.random.seed(0)  # fixed seed so Restart & Run All gives the same values every time

env = GridWorld()
# One row per state, one column per action: every action is equally likely.
n_actions = len(env.action_space)
policy = np.full((len(env.state_space), n_actions), 1.0 / n_actions)

value_vector = td_value_prediction(env, policy).reshape(4, 4)

0 value vector : 
[0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.05 0.   0.
 0.   0.  ]
100 value vector : 
[6.66019436e-04 2.88881569e-03 1.46634301e-02 5.48391040e-02
 2.36361771e-03 1.46886998e-02 8.01268882e-02 2.68843156e-01
 1.11110058e-02 7.77948509e-02 4.56647153e-01 1.93053664e+00
 3.20842055e-02 2.40267654e-01 1.56861330e+00 0.00000000e+00]
200 value vector : 
[0.01529569 0.03697194 0.08075867 0.19032608 0.03356417 0.09888367
 0.30004072 0.78146736 0.07923297 0.2772504  0.87612916 2.85818261
 0.16623232 0.64657375 2.39748374 0.        ]
300 value vector : 
[0.05251759 0.09756063 0.19042059 0.34992417 0.10406079 0.21544155
 0.52562048 1.00548021 0.21976166 0.52147178 1.38349085 3.20525929
 0.36164064 0.95199305 3.1401062  0.        ]
400 value vector : 
[0.11675999 0.17429663 0.30757323 0.46772077 0.18911598 0.36518562
 0.71514093 1.27317756 0.3414731  0.70879422 1.71745799 3.83591638
 0.50355418 1.28487162 3.41474851 0.        ]
500 value vector : 
[0.18891928 0.27

In [27]:
# Learned state values laid out as the 4x4 grid (row = y, column = x); the goal cell (3, 3) is terminal and stays 0.
value_vector

array([[0.70664317, 0.8667051 , 1.16575196, 1.40876578],
       [0.88832708, 1.13017912, 1.72566126, 2.31996894],
       [1.11853846, 1.60770293, 2.94963703, 4.85186042],
       [1.31927106, 2.07075277, 4.71476074, 0.        ]])