In [None]:
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

In [None]:
def make_table(width, height, value=None):
    """Build a grid as a list of ``height`` independent row lists.

    Each row is the sequence ``value`` repeated ``width`` times, so the usual
    one-element ``value`` yields rows of exactly ``width`` cells.

    Args:
        width: Number of times ``value`` is repeated within each row.
        height: Number of rows in the table.
        value: List of cell contents to repeat. Defaults to ``[0.00]``.

    Returns:
        A list of ``height`` lists. Every cell is its own deep copy of the
        corresponding element of ``value``, so mutating one cell in place
        (for example a per-state probability list) never leaks into another.
    """
    import copy  # local import keeps this notebook cell self-contained

    if value is None:
        # None sentinel instead of a mutable default argument.
        value = [0.00]

    # Same element order as ``value * width``, but without sharing one
    # object between cells when the elements are themselves mutable.
    return [
        [copy.deepcopy(item) for _ in range(width) for item in value]
        for _ in range(height)
    ]

In [None]:
# Note: 1/num below must be true division; on Python 2 that relies on
# 'from __future__ import division' from the first cell.
def make_uniform_dist(num):
    """Return a list of ``num`` equal entries, each holding ``1 / num``.

    Args:
        num: Number of outcomes to spread the probability over.

    Returns:
        A list of length ``num`` whose entries are all ``1 / num``.

    Raises:
        ZeroDivisionError: If ``num`` is 0.
    """
    share = 1 / num
    return [share] * num

In [None]:
import numpy as np
from environment_policy  import GraphicDisplay, Env

In [None]:
class PolicyIteration:
    """Tabular policy-iteration agent for a grid environment.

    The agent keeps two tables indexed as ``table[state[0]][state[1]]``:
    ``value_table`` (one float per state) and ``policy_table`` (one list of
    four action probabilities per state; an empty list for the final state).

    The ``env`` object must provide ``width``, ``height``,
    ``possible_actions``, ``get_all_states()``,
    ``state_after_action(state, action)`` and ``get_reward(state, action)``.
    """

    # Coordinates of the terminal cell; single source of truth for
    # __init__ and is_final_state.
    FINAL_STATE = (2, 2)

    def __init__(self, env, discount_factor=0.9):
        """Create zeroed values and a uniform random policy.

        Args:
            env: Grid environment (see the class docstring for the interface).
            discount_factor: Weight applied to the next state's value.
        """
        self.env = env

        self.value_table = make_table(env.width, env.height)

        # One independent probability list per state, so editing one state's
        # policy in place can never affect another state.
        self.policy_table = [
            [make_uniform_dist(4) for _ in range(env.width)]
            for _ in range(env.height)
        ]

        self.discount_factor = discount_factor

        # The final state has no actions to choose from.
        final_row, final_col = self.FINAL_STATE
        self.policy_table[final_row][final_col] = []

    def is_final_state(self, state):
        """Return True when ``state`` has the coordinates of ``FINAL_STATE``.

        Any sequence type (list, tuple, ...) is accepted; a non-iterable
        ``state`` is simply not the final state.
        """
        try:
            return tuple(state) == tuple(self.FINAL_STATE)
        except TypeError:
            return False

    def get_policy(self, state):
        """Return the action-probability list stored for ``state``.

        For the final state the scalar ``0.0`` is returned instead of a list.
        """
        if self.is_final_state(state):
            return 0.0
        return self.policy_table[state[0]][state[1]]

    def get_action(self, state):
        """Sample an action index from ``[0, 1, 2, 3]`` using the state's policy.

        Not meaningful for the final state: ``get_policy`` returns the scalar
        ``0.0`` there, which is not a probability vector for
        ``np.random.choice``.
        """
        state_policy = self.get_policy(state)
        action_idx = np.random.choice([0, 1, 2, 3], 1, p=state_policy)[0]
        return action_idx

    def get_value(self, state):
        """Return the stored value of ``state`` rounded to two decimals."""
        return round(self.value_table[state[0]][state[1]], 2)

    def interact_env(self, state, action):
        """Query the environment for the outcome of ``action`` in ``state``.

        Returns:
            A ``(next_state, next_value, reward)`` tuple, where ``next_value``
            is this agent's current (rounded) value of ``next_state``.
        """
        next_state = self.env.state_after_action(state, action)
        next_value = self.get_value(next_state)
        reward = self.env.get_reward(state, action)
        return next_state, next_value, reward

    def policy_evaluation(self):
        """Run one evaluation sweep of the current policy over all states.

        New values are written to a fresh table and swapped in at the end, so
        every state is updated from the previous sweep's values. The final
        state is pinned to 0; all other values are rounded to two decimals.
        """
        env = self.env
        next_value_table = make_table(env.width, env.height)

        for state in env.get_all_states():
            if self.is_final_state(state):
                next_value_table[state[0]][state[1]] = 0
                continue

            # The state's policy does not change during the action loop.
            state_policy = self.get_policy(state)
            value = 0.0
            for action in env.possible_actions:
                _, next_value, reward = self.interact_env(state, action)
                value += state_policy[action] * (
                    reward + self.discount_factor * next_value
                )
            next_value_table[state[0]][state[1]] = round(value, 2)

        self.value_table = next_value_table

    def policy_improvement(self):
        """Make the policy greedy with respect to the current value table.

        For every non-final state, the probability mass is split evenly among
        the actions whose ``reward + discount_factor * next_value`` is
        maximal; all other actions get probability 0.0.
        """
        env = self.env

        for state in env.get_all_states():
            if self.is_final_state(state):
                continue

            # -inf guarantees the first action always becomes the initial best.
            best_value = float("-inf")
            best_actions = []
            for idx, action in enumerate(env.possible_actions):
                _, next_value, reward = self.interact_env(state, action)
                action_value = reward + self.discount_factor * next_value

                if action_value == best_value:
                    best_actions.append(idx)
                elif action_value > best_value:
                    best_value = action_value
                    # Rebind rather than list.clear(), which Python 2 lacks.
                    best_actions = [idx]

            next_state_policy = [0.0, 0.0, 0.0, 0.0]
            # 1.0 keeps this a true division on every Python version.
            prob_best_action = 1.0 / len(best_actions)
            for idx in best_actions:
                next_state_policy[idx] = prob_best_action

            self.policy_table[state[0]][state[1]] = next_state_policy

In [None]:
# Wire everything together: build the environment, wrap it in a
# PolicyIteration agent, and hand that agent to the GraphicDisplay front end.
env = Env()
policy_iteration = PolicyIteration(env)
grid_world = GraphicDisplay(policy_iteration)
# NOTE(review): mainloop() presumably blocks until the display window is
# closed (GUI event loop) -- confirm against environment_policy.
grid_world.mainloop()