In [1]:
import torch
import gym
import numpy as np
# Build the FrozenLake-v1 environment; the functions defined below read its
# transition table (env.env.P) and the sizes of its state and action spaces.
env = gym.make("FrozenLake-v1")

In [2]:
# gamma: discount factor multiplying successor-state values in the backups.
# threshold: policy evaluation stops once the largest per-state change in a
# sweep is at most this value.
gamma, threshold = 0.99, 1e-4

In [3]:
def policy_evaluation (env, policy, gamma, threshold):
    """Iteratively evaluate a deterministic policy on a tabular environment.

    Starting from an all-zero value estimate, performs synchronous Bellman
    expectation sweeps (each sweep reads only the previous sweep's values)
    until the largest per-state change in a sweep is at most ``threshold``.

    Args:
        env: Object whose ``env.env.P[state][action]`` yields an iterable of
            ``(trans_prob, new_state, reward, done)`` tuples. The ``done``
            flag is ignored here.
        policy: 1-D tensor holding one action index per state. Integer-valued
            floats (e.g. ``2.0``) are accepted and converted to ``int``.
        gamma: Discount factor applied to successor-state values.
        threshold: Convergence tolerance on the maximum absolute change.

    Returns:
        1-D float tensor with one value per state. An empty tensor is
        returned for an empty policy.
    """
    n_state = policy.shape[0]
    V = torch.zeros(n_state)

    # Nothing to evaluate; torch.max below would raise on an empty tensor.
    if n_state == 0:
        return V

    # The transition table does not change between sweeps, so look it up once.
    transitions = env.env.P

    while True:
        V_temp = torch.zeros(n_state)

        for state in range(n_state):
            # The policy is stored as a float tensor by the rest of this file;
            # cast so the action is a real integer index. A float key only
            # works on dict-backed tables by hash coincidence and fails on
            # list/array-backed ones.
            action = int(policy[state].item())

            for trans_prob, new_state, reward, _ in transitions[state][action]:
                V_temp[state] += trans_prob*(reward + gamma*V[new_state])

        max_delta = torch.max(torch.abs(V-V_temp))
        # V_temp is freshly allocated every sweep, so no copy is needed.
        V = V_temp

        if max_delta <= threshold:
            break

    return V

In [4]:
def policy_improvement(env, V, gamma):
    """Derive the greedy policy with respect to a state-value estimate.

    For every state, each action is scored by its one-step lookahead value
    (transition probability times reward plus discounted successor value,
    summed over the outcomes listed in ``env.env.P[state][action]``), and the
    highest-scoring action index is stored for that state.

    Args:
        env: Object exposing ``observation_space.n``, ``action_space.n`` and
            the transition table ``env.env.P``.
        V: Tensor of state values indexed by state.
        gamma: Discount factor applied to successor-state values.

    Returns:
        1-D float tensor with the chosen action index for each state.
    """
    num_states = env.observation_space.n
    num_actions = env.action_space.n
    greedy_policy = torch.zeros(num_states)

    for s in range(num_states):
        q_values = torch.zeros(num_actions)
        for a in range(num_actions):
            outcomes = env.env.P[s][a]
            for prob, s_next, r, _done in outcomes:
                backup = r + gamma*V[s_next]
                q_values[a] += prob*backup

        # Pick the best-scoring action for this state.
        greedy_policy[s] = torch.argmax(q_values)

    return greedy_policy
    

In [5]:
def policy_iteration(env, gamma, threshold):
    """Alternate policy evaluation and greedy improvement until stable.

    Begins from a uniformly random action per state (stored as floats) and
    repeats evaluate-then-improve until the improved policy equals the policy
    that was just evaluated.

    Args:
        env: Environment exposing ``observation_space.n`` and
            ``action_space.n``; it is passed through to ``policy_evaluation``
            and ``policy_improvement``.
        gamma: Discount factor forwarded to both steps.
        threshold: Convergence tolerance forwarded to ``policy_evaluation``.

    Returns:
        Tuple ``(V, policy)``: the values of the last evaluated policy and
        the (unchanged) improved policy.
    """
    num_states = env.observation_space.n
    num_actions = env.action_space.n
    current = torch.randint(high=num_actions, size=(num_states,)).float()

    while True:
        values = policy_evaluation(env, current, gamma, threshold)
        candidate = policy_improvement(env, values, gamma)

        if not torch.equal(candidate, current):
            # Policy changed: evaluate the new one on the next pass.
            current = candidate
            continue

        return values, candidate

In [6]:
# Run policy iteration on the environment with the discount factor and
# evaluation tolerance set above; keeps the returned value tensor and policy.
V_optimal, optimal_policy = policy_iteration (env, gamma, threshold)

In [9]:
# Show the state-value tensor returned by policy_iteration.
print ('Optimal values:\n',V_optimal)

Optimal values:
 tensor([0.5404, 0.4966, 0.4681, 0.4541, 0.5569, 0.0000, 0.3572, 0.0000, 0.5905,
        0.6421, 0.6144, 0.0000, 0.0000, 0.7410, 0.8625, 0.0000])


In [10]:
# Show the per-state action indices returned by policy_iteration.
print ('Optimal policy:\n',optimal_policy)

Optimal policy:
 tensor([0., 3., 3., 3., 0., 0., 0., 0., 3., 1., 0., 0., 0., 2., 1., 0.])
