In [None]:
import random
from copy import deepcopy as dc

"""
Interface expected of the ``env`` object passed to PolicyIteration.

Env
--------
Attributes
width, height, reward, possible_actions, all_state
--------
--------
Methods
get_reward, state_after_action, get_transition_prob, get_all_states
--------
"""

class PolicyIteration:
  """Tabular policy iteration over a 2-D grid of states.

  The agent keeps two tables, both indexed as ``table[state[0]][state[1]]``
  with ``env.height`` outer entries and ``env.width`` inner entries:

  * ``value_table``  -- current state-value estimate of every state.
  * ``policy_table`` -- for every state, a list of action probabilities.
    Entry ``i`` of that list belongs to ``env.possible_actions[i]``.
    The terminal state holds an empty list (no action is taken there).

  Args:
    env: Environment object. This class reads ``env.width``,
      ``env.height`` and ``env.possible_actions`` and calls
      ``env.get_all_states()``, ``env.state_after_action(state, action)``
      and ``env.get_reward(state, action)``.
    discount_factor: Discount applied to the value of the next state.
    terminal_state: ``(row, col)`` index of the terminal state, which is
      skipped during evaluation and improvement. It must lie inside the
      grid, otherwise ``IndexError`` is raised.
  """

  def __init__(self, env, discount_factor=0.9, terminal_state=(2, 2)):
    self.env = env
    self.discount_factor = discount_factor
    # Stored as a list so that it compares equal to list-typed states.
    self.terminal_state = list(terminal_state)

    self.value_table = [[0.0] * env.width for _ in range(env.height)]

    # Start from the uniform random policy. Every cell gets its OWN list;
    # ``[[...]] * width`` would alias one list across the whole row.
    n_actions = len(env.possible_actions)
    uniform = 1.0 / n_actions
    self.policy_table = [
        [[uniform] * n_actions for _ in range(env.width)]
        for _ in range(env.height)
    ]
    row, col = self.terminal_state
    self.policy_table[row][col] = []

  def _is_terminal(self, state):
    """Return True when ``state`` is the configured terminal state."""
    return list(state) == self.terminal_state

  def _action_value(self, state, action):
    """Return the one-step lookahead ``reward + discount * V(next_state)``."""
    next_state = self.env.state_after_action(state, action)
    reward = self.env.get_reward(state, action)
    return reward + self.discount_factor * self.get_value(next_state)

  def policy_evaluation(self):
    """Run one sweep of the Bellman expectation equation.

    Every non-terminal state's value is recomputed from the *previous*
    value table (synchronous update), rounded to two decimals, and the
    finished table then replaces ``self.value_table``. The terminal
    state's value stays at 0.0.
    """
    new_value_table = [[0.0] * self.env.width
                       for _ in range(self.env.height)]

    for state in self.env.get_all_states():
      if self._is_terminal(state):
        continue

      probs = self.get_policy(state)
      new_value = 0.0
      for idx, action in enumerate(self.env.possible_actions):
        new_value += probs[idx] * self._action_value(state, action)

      new_value_table[state[0]][state[1]] = round(new_value, 2)

    self.value_table = new_value_table

  def policy_improvement(self):
    """Make the policy greedy with respect to the current value table.

    For every non-terminal state the action value
    ``reward + discount * V(next_state)`` is computed for each action.
    All actions that attain the maximum share the probability mass
    equally; every other action gets probability 0.
    """
    new_policy_table = dc(self.policy_table)
    n_actions = len(self.env.possible_actions)

    for state in self.env.get_all_states():
      if self._is_terminal(state):
        continue

      best_value = float("-inf")
      best_actions = []
      for idx, action in enumerate(self.env.possible_actions):
        # Rank actions by their action value only. Weighting by the
        # current policy probability would zero out actions the current
        # policy never takes and break the greedy choice.
        action_value = self._action_value(state, action)

        if action_value > best_value:
          best_value = action_value
          best_actions = [idx]
        elif action_value == best_value:
          best_actions.append(idx)

      greedy_probs = [0.0] * n_actions
      share = 1 / len(best_actions)
      for idx in best_actions:
        greedy_probs[idx] = share

      new_policy_table[state[0]][state[1]] = greedy_probs

    self.policy_table = new_policy_table

  def get_value(self, state):
    """Return the stored value of ``state`` rounded to two decimals."""
    return round(self.value_table[state[0]][state[1]], 2)

  def get_policy(self, state):
    """Return the list of action probabilities stored for ``state``."""
    return self.policy_table[state[0]][state[1]]

  def get_action(self, state):
    """Sample an action index for ``state`` from its policy.

    Returns:
      The index (into the state's probability list) of the sampled
      action, or ``None`` when the state has an empty policy, as the
      terminal state does.
    """
    # Threshold drawn from {0.00, 0.01, ..., 0.99}.
    threshold = random.randrange(100) / 100

    cumulative = 0.0
    for idx, prob in enumerate(self.get_policy(state)):
      cumulative += prob
      if threshold < cumulative:
        return idx
    return None