In [22]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

In [23]:
class GridWorld:
  def __init__(self) -> None:
    self.size = 5
    self.start = (0, 0)
    self.gamma = 0.9
    self.actions = [(-1,0), (0,1), (1,0), (0,-1)]#up,right,down,left
    self.V = np.zeros((self.size, self.size))
    self.pi = np.ones((self.size, self.size, len(self.actions))) / len(self.actions) #uniform policy

  def is_valid_state(self, state):
    return (0 <= state[0] < self.size and 
            0 <= state[1] < self.size)
    
  # def get_reward_and_next_state(self, state):
  #   i, j = state
  #   if i == 0 and j == 1:
  #     return 10, (4, 1)
  #   elif i == 0 and j == 3:
  #     return 5, (2, 3)
  #   return 0, state
      
  def policy_eval(self):
    theta = 0.001
    while True:
      delta = 0
      for i in range(self.size):
        for j in range(self.size):
          current_state = (i, j)
          Vold = self.V[i, j]
          
          if i == 0 and j == 1:
            self.V[i, j] = 10 + self.gamma * self.V[4, 1]
          elif i == 0 and j == 3:
            self.V[i, j] = 5 + self.gamma * self.V[2, 3]
          else:
            value = 0
            for idx, action in enumerate(self.actions):
              next_state = (i + action[0], j + action[1])
              if self.is_valid_state(next_state):
                value += self.pi[i, j, idx] * (0 + self.gamma * self.V[next_state])
              else:
                value += self.pi[i, j, idx] * (-1 + self.gamma * self.V[current_state])
            self.V[i, j] = value
          delta = max(delta, abs(Vold - self.V[i, j]))
      if delta < theta:
        break
  
  def policy_improvement(self):
    policy_stable = True
    for i in range(self.size):
      for j in range(self.size):
        current_state = (i, j)
        old_action = self.pi[i, j].copy()
        Q = np.zeros(len(self.actions))
        if(i == 0 and j == 1):
          Q[0] = 10 + self.gamma * self.V[4][1]
        elif i == 0 and j == 3:
          Q[2] = 5 + self.gamma * self.V[2][3]
        else:
          for idx, action in enumerate(self.actions):
            next_state = (i + action[0], j + action[1])
            if next_state[0] >= 0 and next_state[0] < self.size and next_state[1] >= 0 and next_state[1] < self.size:
              Q[idx] = 0 + self.gamma * self.V[next_state]
            else:
              Q[idx] = -1 + self.gamma * self.V[current_state]
        new_best_action = np.argmax(Q)
        new_action = np.zeros(len(self.actions))
        new_action[new_best_action] = 1
        self.pi[i][j] = new_action
        if not np.array_equal(old_action, new_action):
          policy_stable = False

    return policy_stable
  
  def policy_iteration(self):
    policy_stable = False
    while not policy_stable:
      self.policy_eval()
      policy_stable = self.policy_improvement()
    return self.V, self.pi


    
    
    

In [24]:
#running policy iteration
grid = GridWorld()
V, pi = grid.policy_iteration()
print(V)

[[21.97748462 24.41942766 21.97748489 19.41942766 17.47748489]
 [19.77973615 21.97748489 19.7797364  17.80176276 16.02158649]
 [17.80176254 19.7797364  17.80176276 16.02158649 14.41942784]
 [16.02158629 17.80176276 16.02158649 14.41942784 12.97748505]
 [14.41942766 16.02158649 14.41942784 12.97748505 11.67973655]]


In [25]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

In [26]:
class GridWorld:
  def __init__(self) -> None:
    self.size = 5
    self.start = (0, 0)
    self.gamma = 0.9
    self.actions = [(-1,0), (0,1), (1,0), (0,-1)]#up,right,down,left
    self.V = np.zeros((self.size, self.size))
    self.pi = np.ones((self.size, self.size, len(self.actions))) / len(self.actions) #uniform policy
  def pick_action(self):
    return np.random.choice(self.actions)
  
      
  def policy_eval(self):
    theta = 0.001
    while True:
      delta = 0
      for i in range(self.size):
        for j in range(self.size):
          current_state = (i, j)
          reward, next_state = 0, (0,0)
          Vold = self.V[current_state]
          if i == 0 and j == 1:
            reward = 10
            next_state = (4,1)
            self.V[current_state] = reward + self.gamma * self.V[next_state]
          elif i == 0 and j == 3:
            reward = 5
            next_state = (2,3)
            self.V[current_state] = reward + self.gamma * self.V[next_state]
          else:
            sum = 0
            for idx, action in enumerate(self.actions):
              next_state = (i + action[0], j + action[1])
              if next_state[0] >= 0 and next_state[0] < self.size and next_state[1] >= 0 and next_state[1] < self.size:
                sum += self.pi[i][j][idx] * (0 + self.gamma * self.V[next_state])  
              else:
                sum += self.pi[i][j][idx] * (-1 + self.gamma * self.V[current_state])
            self.V[current_state] = sum
          delta = max(delta, abs(Vold - self.V[current_state]))
      if delta < theta:
        break
    return self.V
  
  def policy_improvement(self):
    policy_stable = True
    for i in range(self.size):
      for j in range(self.size):
        current_state = (i, j)
        old_action = self.pi[i][j]
        Q = np.zeros(len(self.actions))
        if(i == 0 and j == 1):
          Q[0] = 10 + self.gamma * self.V[4][1]
        elif i == 0 and j == 3:
          Q[2] = 5 + self.gamma * self.V[2][3]
        else:
          for idx, action in enumerate(self.actions):
            next_state = (i + action[0], j + action[1])
            if next_state[0] >= 0 and next_state[0] < self.size and next_state[1] >= 0 and next_state[1] < self.size:
              Q[idx] = 0 + self.gamma * self.V[next_state]
            else:
              Q[idx] = -1 + self.gamma * self.V[current_state]
        new_best_action = np.argmax(Q)
        new_action = np.zeros(len(self.actions))
        new_action[new_best_action] = 1
        self.pi[i][j] = new_action
        if not np.array_equal(old_action, new_action):
          policy_stable = False

    return policy_stable
  
  def policy_iteration(self):
    policy_stable = False
    while not policy_stable:
      self.policy_eval()
      policy_stable = self.policy_improvement()
    return self.V, self.pi


    
    
    

In [27]:
#running policy iteration
grid = GridWorld()
V, pi = grid.policy_iteration()
print(V)

[[ 3.31359559  8.79292942  4.43113177  5.32556099  1.4955287 ]
 [ 1.52582318  2.99591435  2.2534199   1.91064941  0.55045095]
 [ 0.05486787  0.74165922  0.67626363  0.36114423 -0.40025498]
 [-0.96965064 -0.43208514 -0.35180898 -0.58272448 -1.18027658]
 [-1.85380443 -1.34185832 -1.22622928 -1.42007309 -1.97241846]]
