Solving the Frozen Lake Problem with Value Iteration


In [3]:
import gymnasium as gym
import numpy as np

In [4]:

class Value_Iteration():
    """Tabular value iteration for an environment with a known transition model.

    The environment passed to the constructor must expose
    ``observation_space.n``, ``action_space.n`` and ``unwrapped.P``, where
    ``P[state][action]`` is a list of
    ``(probability, next_state, reward, done)`` transitions.

    Typical use: call ``get_optimal_value()`` first to fill ``value_table``,
    then ``get_optimal_policy()`` to read off the greedy policy.
    """

    def __init__(self , env , n_iterations , discount_factor):
        """Store the environment and planning hyperparameters.

        Args:
            env: Environment providing the transition model (see class docstring).
            n_iterations: Number of full sweeps over the state space.
            discount_factor: Discount applied to the value of the next state.
        """
        # Keep the environment we were given so the methods never depend on
        # a notebook-level global that happens to be called `env`.
        self.env = env
        self.n_iterations = n_iterations
        self.discount_factor = discount_factor
        self.num_states = env.observation_space.n
        self.num_actions = env.action_space.n

        self.value_table = [0] * self.num_states
        self.optimal_policy = [0] * self.num_states

    def _action_values(self, state):
        """Return the one-step look-ahead value of every action in ``state``.

        For each action this is the expectation, over its transitions, of
        ``reward + discount_factor * value_table[next_state]``.
        """
        # Each transition is (probability, next_state, reward, done).
        transitions_by_action = self.env.unwrapped.P[state]
        return [
            sum(
                probability
                * (reward + self.discount_factor * self.value_table[next_state])
                for probability, next_state, reward, _done in transitions_by_action[action]
            )
            for action in range(self.num_actions)
        ]

    def get_optimal_value(self):
        """Run ``n_iterations`` sweeps of value iteration, updating ``value_table``.

        Values are updated in place, so states later in a sweep already see
        the new values of states earlier in the same sweep.
        """
        for _ in range(self.n_iterations):
            for state in range(self.num_states):
                self.value_table[state] = max(self._action_values(state))

    def get_optimal_policy(self):
        """Extract the greedy policy for the current ``value_table``.

        Returns:
            ``optimal_policy``: a list with one action index (plain ``int``)
            per state. Ties are resolved in favour of the lowest action index.
        """
        # value_table does not change here, so a single pass is sufficient;
        # repeating the extraction would give the same result every time.
        for state in range(self.num_states):
            self.optimal_policy[state] = int(np.argmax(self._action_values(state)))

        return self.optimal_policy


In [5]:
# Planning hyperparameters.
n_iterations = 1000
discount = 0.9

# Deterministic FrozenLake; "human" mode opens a window and redraws it on
# every reset()/step(), so no explicit env.render() call is needed.
env = gym.make("FrozenLake-v1", render_mode="human" , is_slippery=False)

try:
    # Plan on the environment's transition model, then read off the greedy policy.
    app = Value_Iteration(env , n_iterations , discount)
    app.get_optimal_value()

    policy = app.get_optimal_policy()

    optimal_policy = list(map(int , policy))

    # reset() returns (observation, info); start the rollout from the state
    # the environment actually reports instead of assuming it is 0.
    cur_state = env.reset()
    s = int(cur_state[0])
    done = False
    while not done:
        # step() returns (observation, reward, terminated, truncated, info).
        next_state, reward, terminated, truncated, info = env.step(optimal_policy[s])
        s = int(next_state)
        # Stop on reaching a terminal state or on hitting the step limit.
        done = terminated or truncated
finally:
    # Always release the render window, even if the rollout raises.
    env.close()

We can print the obtained optimal policy:

In [6]:
# Show the chosen action index for every state, listed in state order.
print(optimal_policy)

[1, 2, 1, 0, 1, 0, 1, 0, 2, 1, 1, 0, 0, 2, 2, 0]


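As we can observe, our optimal policy tells us to perform the correct action in each state. Actions are encoded as 0 = left, 1 = down, 2 = right and 3 = up, so starting from state 0 the agent moves down, down, right, down, right, right and reaches the goal (state 15) without stepping into a hole. The entries for the hole states (5, 7, 11, 12) and for the goal state are 0 only because every action has the same value in a terminal state, so the tie is resolved in favour of the first action.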