Solving the Frozen Lake Problem with Policy Iteration


In [1]:
import gymnasium as gym
import numpy as np
env = gym.make("FrozenLake-v1", render_mode="human" , is_slippery=False)

In [2]:

class Policy_iteration():
    def __init__(self , env , n_iterations , discount_factor):
        self.n_iterations = n_iterations
        self.discount_factor = discount_factor
        self.num_states = env.observation_space.n
        self.num_actions = env.action_space.n


        self.value_table = [0] * self.num_states
        self.optimal_policy = list(np.random.randint(0 , self.num_actions , self.num_states))

    def get_optimal_policy(self):
        for _ in range(self.n_iterations):
            for state in range(self.num_states):
                action = self.optimal_policy[state]
                info = env.unwrapped.P[state][action]
                prob = np.array([x[0] for x in info])

                R = np.array([x[2] for x in info]) + \
                self.discount_factor * np.array([self.value_table[x[1]] for x in info])

                v = sum(R * prob)
                self.value_table[state] = v

                action_Q = np.zeros(self.num_actions)
                for action in range(self.num_actions):
                    info = env.unwrapped.P[state][action]
                    prob = np.array([x[0] for x in info])
                    R = np.array([x[2] for x in info]) + \
                    self.discount_factor * np.array([self.value_table[x[1]] for x in info])

                    Q = sum(prob * R)
                    action_Q[action] = Q

                self.optimal_policy[state] = np.argmax(action_Q)
        return self.optimal_policy


In [None]:

cur_state = env.reset()
n_iterations = 1000
discount = 0.9

app = Policy_iteration(env , n_iterations , discount ) 

policy = app.get_optimal_policy()

s = 0 
done = False
while not done:
    t = env.step(int(policy[s]))
    s = int(t[0])
    done = t[2] or t[3]
    env.render()

env.close()

[np.int64(1), np.int64(2), np.int64(1), np.int64(0), np.int64(1), np.int64(0), np.int64(1), np.int64(0), np.int64(2), np.int64(1), np.int64(1), np.int64(0), np.int64(0), np.int64(2), np.int64(2), np.int64(0)]


We can print the obtained optimal policy:

In [6]:
print(list(map(int , policy)))

[1, 2, 1, 0, 1, 0, 1, 0, 2, 1, 1, 0, 0, 2, 2, 0]


As we can observe, our optimal policy tells us to perform the correct action in each state.