In [2]:
import numpy as np

In [78]:
class Environment:
    def __init__(self):
        self.value_table = np.zeros((5, 5))
        
    def update(self):
        #copy
        new_value_table = np.zeros((5, 5))
        value_table = self.value_table.copy()
        for row in range(5):
            for col in range(5):
                for action in range(4):
                    r, c, reward = self.step((row, col), action)
                    new_value_table[row, col] += 0.25*(reward + 0.9*value_table[r, c])
        self.value_table = new_value_table
                
    def step(self, state:tuple[int, int], action:int)->tuple[int, int, int]:
        if state == (0, 1):
            row, col = (4, 1)
            reward = 10
            return (row, col, reward)
        elif state == (0, 3):
            row, col = (2, 3)
            reward = 5
            return (row, col, reward)
        else:
            try:
                if action == 0:
                    row, col = state[0]-1, state[1]
                elif action == 1:
                    row, col = state[0]+1, state[1]
                elif action == 2:
                    row, col = state[0], state[1]-1
                else:
                    row, col = state[0], state[1]+1
                value = self.value_table[row, col]
                reward = 0
            except IndexError:
                row, col = state[0], state[1]
                reward = -1 
            return (row, col, reward)
        
    def get_policy(self):
        dir_table = {0: "↑", 1: "↓", 2: "←", 3: "→"}
        direction_list = []
        for row in range(5):
            direction_list_row = []
            for col in range(5):
                if row == 0 and col == 1 or row == 0 and col == 3:
                    direction_list_row.append("*")
                    continue
                action_value = []
                for action in range(4):
                    if action == 0:
                        try:
                            action_value.append(self.value_table[row-1, col])
                        except IndexError:
                            action_value.append(-999)
                    elif action == 1:
                        try:
                            action_value.append(self.value_table[row+1, col])
                        except IndexError:
                            action_value.append(-999)
                    elif action == 2:
                        try:
                            action_value.append(self.value_table[row, col-1])
                        except IndexError:
                            action_value.append(-999)
                    else:
                        try:
                            action_value.append(self.value_table[row, col+1])
                        except IndexError:
                            action_value.append(-999)
                direction_list_row.append(dir_table[np.argmax(action_value)])
            direction_list.append(direction_list_row)
        return direction_list

In [80]:
env = Environment()
print("초기 상태 가치 함수")
print(env.value_table)

for i in range(100):
    env.update()
    print(env.value_table)

print("\n무작위 행동으로 수렴된 상태 가치 함수 (100번 반복)")
print(env.value_table)

print("\n상태 가치 함수로 계산한 정책")
print(np.array(env.get_policy()))

초기 상태 가치 함수
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
[[ 0.   10.    0.    5.   -0.25]
 [ 0.    0.    0.    0.   -0.25]
 [ 0.    0.    0.    0.   -0.25]
 [ 0.    0.    0.    0.   -0.25]
 [-0.25 -0.25 -0.25 -0.25 -0.5 ]]
[[ 2.1375   9.775    3.31875  5.       0.65   ]
 [-0.05625  2.25     0.       1.06875 -0.41875]
 [-0.05625  0.       0.      -0.05625 -0.41875]
 [-0.1125  -0.05625 -0.05625 -0.1125  -0.475  ]
 [-0.475   -0.41875 -0.41875 -0.475   -0.8375 ]]
[[ 2.22609375  9.623125    3.23015625  4.949375    0.73859375]
 [ 0.8803125   2.18671875  1.4934375   1.018125   -0.05171875]
 [-0.1321875   0.4809375  -0.0253125   0.1209375  -0.55796875]
 [-0.2390625  -0.1321875  -0.1321875  -0.2390625  -0.66484375]
 [-0.66484375 -0.55796875 -0.55796875 -0.66484375 -1.090625  ]]
[[ 2.37986719e+00  9.49782812e+00  3.48929297e+00  5.10884375e+00
   7.72765625e-01]
 [ 9.51503906e-01  2.80750781e+00  1.44217969e+00  1.46520703e+00
   8.08203125e-03]
 [ 1