In [28]:
import matplotlib.pyplot as plt
import operator

In [71]:
class model_RL():
    """Deterministic square grid world solved by policy or value iteration.

    The grid has ``size * size`` cells numbered row-major from 0, so cell
    ``state`` sits at row ``state // size`` and column ``state % size``.
    The four actions ``'n'``, ``'s'``, ``'w'``, ``'e'`` move one cell up, down,
    left and right respectively. Every action taken in the treasure cell has
    reward 0; every action taken anywhere else has reward -1.

    Attributes:
        size: side length of the grid.
        states: ``range(size * size)``.
        treasure_pos: index of the treasure cell.
        actions: list of the four action labels.
        gamma: discount factor applied to the successor's value.
        rewards: dict mapping ``(state, action)`` to the immediate reward.
        policy: list with one action label per state.
        value: list with one value estimate per state.
    """

    # Row/column offset of every action.
    _MOVES = {'n': (-1, 0), 's': (1, 0), 'w': (0, -1), 'e': (0, 1)}
    # Order in which neighbours are examined; on ties the earliest one wins.
    _ACTION_ORDER = ('n', 's', 'w', 'e')

    def __init__(self, size, treasure_pos, gamma = 0.9):
        """Build the grid, its reward table and the starting policy/values.

        Args:
            size: side length of the square grid (at least 1).
            treasure_pos: row-major index of the treasure cell.
            gamma: discount factor. Note that with ``gamma >= 1`` policy
                evaluation of a policy that never reaches the treasure does
                not converge.

        Raises:
            ValueError: if ``size`` is smaller than 1 or ``treasure_pos`` is
                not a cell of the grid.
        """
        if size < 1:
            raise ValueError("size must be at least 1, got {!r}".format(size))
        if not 0 <= treasure_pos < size * size:
            raise ValueError(
                "treasure_pos must be in [0, {}), got {!r}".format(size * size, treasure_pos)
            )
        self.size = size
        self.states = range(size * size)
        self.treasure_pos = treasure_pos
        self.actions = list(self._ACTION_ORDER)
        self.gamma = gamma
        self.rewards = {
            (state, action): 0 if state == treasure_pos else -1
            for state in self.states
            for action in self.actions
        }
        self.policy = self._initial_policy()
        self.value = [0] * (size * size)

    def _initial_policy(self):
        """Return the starting policy: 's' on the top row, 'n' everywhere else.

        The top row cannot move north, so it is pointed south to keep every
        initial move inside the grid.
        """
        return ['s'] * self.size + ['n'] * (self.size * self.size - self.size)

    def _move(self, state, action):
        """Return the cell reached from ``state`` via ``action``, or None if
        that move would leave the grid.

        Raises:
            ValueError: if ``action`` is not one of 'n', 's', 'w', 'e'.
        """
        if action not in self._MOVES:
            raise ValueError("unknown action: {!r}".format(action))
        d_row, d_col = self._MOVES[action]
        row, col = divmod(state, self.size)
        row, col = row + d_row, col + d_col
        if 0 <= row < self.size and 0 <= col < self.size:
            return row * self.size + col
        return None

    def _greedy_action(self, state, default):
        """Return the action leading to the in-grid neighbour with the highest
        current value, or ``default`` when no neighbour qualifies.

        Comparing successor values alone is enough here because all actions
        available in one state share the same immediate reward.
        """
        best_action, best_value = default, float('-inf')
        for action in self._ACTION_ORDER:
            neighbour = self._move(state, action)
            # Strict '>' keeps the earliest action on ties.
            if neighbour is not None and self.value[neighbour] > best_value:
                best_action, best_value = action, self.value[neighbour]
        return best_action

    def reset_value(self):
        """Set every state's value estimate back to 0."""
        self.value = [0] * (self.size * self.size)

    def reset(self):
        """Restore both the value estimates and the policy to their initial state."""
        self.reset_value()
        self.policy = self._initial_policy()

    def update_policy(self):
        """Make the policy greedy with respect to the current values.

        The policy list is updated in place; a state with no in-grid neighbour
        (only possible on a 1x1 grid) keeps its current action.

        Returns:
            The (mutated) ``self.policy`` list.
        """
        for state in self.states:
            self.policy[state] = self._greedy_action(state, self.policy[state])
        return self.policy

    # given current state and action, get the next state
    def get_next_state(self, state, action):
        """Return the state reached by taking ``action`` in ``state``.

        A move that would leave the grid keeps the agent where it is, instead
        of wrapping into another row or producing a negative index.

        Raises:
            ValueError: if ``action`` is not one of 'n', 's', 'w', 'e'.
        """
        neighbour = self._move(state, action)
        return state if neighbour is None else neighbour

    # given current state and action, get the reward
    def get_reward(self, state, action):
        """Return the immediate reward stored for ``(state, action)``."""
        return self.rewards[(state, action)]

    def update_to_expected_value(self, tol = 0.01):
        """Evaluate the current policy by repeated synchronous sweeps.

        Starting from the current ``self.value``, every sweep recomputes each
        state's value as reward plus discounted value of the state the policy
        leads to. Sweeping stops once the largest per-state change is at most
        ``tol``; at least one sweep is always performed.

        Args:
            tol: convergence threshold on the largest per-state change.
        """
        previous = list(self.value)
        while True:
            current = [
                self.rewards[(state, self.policy[state])]
                + self.gamma * previous[self.get_next_state(state, self.policy[state])]
                for state in self.states
            ]
            # The treasure is terminal: its value is just its own reward.
            current[self.treasure_pos] = self.rewards[(self.treasure_pos, 'n')]
            # Largest per-state change; a difference of sums could cancel out.
            delta = max(abs(new - old) for new, old in zip(current, previous))
            previous = current
            if delta <= tol:
                break
        self.value = previous

    def get_max_value_action(self, state):
        """Return the action leading to the highest-valued in-grid neighbour.

        Ties are resolved in the order n, s, w, e; 'n' is returned when the
        state has no in-grid neighbour.
        """
        return self._greedy_action(state, 'n')

    # policy iteration
    def policy_iteration(self, tol = 0.01):
        """Alternate policy evaluation and greedy improvement until the policy
        stops changing.

        Values are reset to 0 before each new evaluation round; the values of
        the final round are left in ``self.value``.

        Args:
            tol: convergence threshold forwarded to the evaluation step.
        """
        previous_policy = list(self.policy)
        while True:
            self.update_to_expected_value(tol)
            improved_policy = self.update_policy()
            if improved_policy == previous_policy:
                break
            previous_policy = list(improved_policy)
            self.reset_value()

    def update_value(self):
        """Perform one synchronous value-iteration sweep.

        Every non-treasure state gets reward plus discounted value of its best
        neighbour, all read from the values held before the sweep. The
        treasure state's value is left untouched.
        """
        updated = list(self.value)
        for state in self.states:
            if state == self.treasure_pos:
                continue
            action = self.get_max_value_action(state)
            successor = self.get_next_state(state, action)
            updated[state] = self.rewards[(state, action)] + self.gamma * self.value[successor]
        self.value = updated

    # value iteration
    def value_iteration(self, tol = 0.01):
        """Run value-iteration sweeps until converged, then derive the policy.

        Sweeping stops once the largest per-state change of a sweep is at most
        ``tol``; at least one sweep is always performed.

        Args:
            tol: convergence threshold on the largest per-state change.
        """
        while True:
            previous = list(self.value)
            self.update_value()
            delta = max(abs(new - old) for new, old in zip(self.value, previous))
            if delta <= tol:
                break
        self.update_policy()


In [72]:
# 4x4 grid; the treasure sits in cell 9, i.e. row 2, column 1 (row-major, 0-based).
RL1 = model_RL(size=4, treasure_pos=9)

In [73]:
# Solve the grid with policy iteration and show the resulting action per cell.
RL1.policy_iteration()
print(RL1.policy)
# Restore the initial policy and zero values so value iteration starts from scratch.
RL1.reset()
# Solve again with value iteration; the printed policy can be compared with the one above.
RL1.value_iteration()
print(RL1.policy)

['s', 's', 's', 's', 's', 's', 's', 's', 'e', 'n', 'w', 'w', 'n', 'n', 'n', 'n']
['s', 's', 's', 's', 's', 's', 's', 's', 'e', 'n', 'w', 'w', 'n', 'n', 'n', 'n']
