In [1]:
import gymnasium as gym
import numpy as np

In [2]:
def policy_evaluation(env, policy, gamma=1.0, theta=0.00001, max_iter=100):
    """Iterative policy evaluation on a tabular transition model.

    Repeatedly sweeps all states and replaces ``V[s]`` in place with the expected
    one-step return under ``policy``. Updates are in place, so states later in a
    sweep already see the new values of earlier states.

    Args:
        env: object exposing ``nS``, ``nA`` and a transition table ``P`` where
            ``P[s][a]`` is an iterable of ``(prob, next_state, reward, done)``.
            If ``env`` has an ``unwrapped`` attribute (Gymnasium wrapper), ``P``
            is read from ``env.unwrapped`` instead of going through the wrapper.
        policy: indexable as ``policy[s][a]``; the probability of taking
            action ``a`` in state ``s``.
        gamma: discount factor.
        theta: stop once the largest per-sweep change falls below this value.
        max_iter: hard cap on the number of sweeps. When it is hit, the values
            may not yet be converged to ``theta``.

    Returns:
        np.ndarray of length ``env.nS`` with the state-value estimates.
    """
    # Read the model from the base env: attribute forwarding through Gymnasium
    # wrappers is deprecated (and removed in newer releases).
    transitions = getattr(env, "unwrapped", env).P

    V = np.zeros(env.nS)
    for _ in range(max_iter):
        delta = 0.0
        for s in range(env.nS):
            v = 0.0
            for a, action_prob in enumerate(policy[s]):
                for prob, next_state, reward, done in transitions[s][a]:
                    # Do not bootstrap past a terminal transition: the episode ends there.
                    v += action_prob * prob * (reward + gamma * V[next_state] * (not done))
            delta = max(delta, abs(v - V[s]))
            V[s] = v
        if delta < theta:
            break
    return V


In [3]:
def policy_improvement(env, policy_eval_fn=policy_evaluation, gamma=1.0, max_iter=100):
    """Policy iteration: alternate policy evaluation and greedy improvement.

    Starts from the uniform random policy. Each iteration evaluates the current
    policy with ``policy_eval_fn`` and then makes every state greedy (one-hot on
    ``argmax`` of the one-step lookahead). Stops when no state's greedy action
    differs from its current ``argmax`` action, or after ``max_iter`` iterations.

    Args:
        env: object exposing ``nS``, ``nA`` and ``P`` (``P[s][a]`` is an iterable
            of ``(prob, next_state, reward, done)``). ``P`` is read from
            ``env.unwrapped`` when that attribute exists.
        policy_eval_fn: callable ``(env, policy, gamma=...) -> V``.
        gamma: discount factor.
        max_iter: maximum number of evaluate/improve rounds (must be >= 1). This
            bounds the loop even if ties make the greedy action keep switching.

    Returns:
        ``(policy, V)``: ``policy`` is an ``(nS, nA)`` one-hot array; ``V`` is the
        value array returned by the last call to ``policy_eval_fn``.

    Raises:
        ValueError: if ``max_iter`` < 1.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be >= 1")

    # Avoid deprecated attribute forwarding through Gymnasium wrappers.
    transitions = getattr(env, "unwrapped", env).P

    def one_step_lookahead(state, V):
        """Return the expected return of each action in ``state`` given values ``V``."""
        A = np.zeros(env.nA)
        for a in range(env.nA):
            for prob, next_state, reward, done in transitions[state][a]:
                # No bootstrapping past a terminal transition.
                A[a] += prob * (reward + gamma * V[next_state] * (not done))
        return A

    # Uniform random starting policy (np.zeros(...) / nA would be all zeros,
    # which is not a probability distribution and evaluates to V == 0).
    policy = np.ones([env.nS, env.nA]) / env.nA
    one_hot = np.eye(env.nA)

    for _ in range(max_iter):
        V = policy_eval_fn(env, policy, gamma=gamma)

        policy_stable = True
        for s in range(env.nS):
            chosen_a = np.argmax(policy[s])
            best_a = np.argmax(one_step_lookahead(s, V))
            if chosen_a != best_a:
                policy_stable = False
            policy[s] = one_hot[best_a]
        if policy_stable:
            break
    return policy, V

In [None]:
def policy_render(env, policy, n=100, render=True):
    """Roll out a greedy policy for ``n`` episodes and return the win rate.

    Each step takes ``argmax(policy[state])``. An episode ends when ``env.step``
    reports ``terminated`` or ``truncated``. It counts as a win only when it
    terminates with reward 1; a truncated (time-limited) episode is a loss.

    Args:
        env: Gymnasium-style env: ``reset() -> (obs, info)`` and
            ``step(a) -> (obs, reward, terminated, truncated, info)``.
        policy: indexable as ``policy[state]``; only the argmax of each row is used.
        n: number of episodes to play (must be positive).
        render: if True, call ``env.render()`` after every step.
            NOTE(review): Gymnasium envs in "human" mode typically already render
            inside ``step()``; confirm whether this extra call is wanted.

    Returns:
        Fraction of the ``n`` episodes that were won.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError("n must be a positive number of episodes")

    wins = 0
    for _ in range(n):
        state, _ = env.reset()
        episode_over = False
        while not episode_over:
            action = np.argmax(policy[state])
            state, reward, terminated, truncated, _ = env.step(action)
            if render:
                env.render()
            # Stop on truncation too; otherwise we would keep stepping a finished episode.
            episode_over = terminated or truncated
        if terminated and reward == 1:
            wins += 1
    return wins / n

# Build the 4x4 FrozenLake env; "human" render mode opens a window.
env = gym.make("FrozenLake-v1", map_name="4x4", render_mode="human")
# The DP helpers read env.nS / env.nA. Derive them from the spaces instead of
# hard-coding 16 / 4 so that changing map_name keeps them consistent.
env.nS = env.observation_space.n
env.nA = env.action_space.n

# Baseline: value of the uniform random policy, kept under its own name.
random_policy = np.ones([env.nS, env.nA]) / env.nA
v_random = policy_evaluation(env, random_policy)

# Policy iteration, then roll out the greedy policy; the last expression is the win rate.
policy, v = policy_improvement(env)
policy_render(env, policy)

  logger.warn(
2024-11-05 15:03:34.671 python[39428:12017097] +[IMKClient subclass]: chose IMKClient_Legacy
2024-11-05 15:03:34.671 python[39428:12017097] +[IMKInputSession subclass]: chose IMKInputSession_Legacy


In [12]:
# Inspect the env's observation space (the displayed output for this map is Discrete(16)).
a = env.observation_space

In [13]:
# Display `a`, the observation space assigned in the previous cell.
a

Discrete(16)

In [5]:
# Sanity check: play one episode with uniformly random actions, then close the env.
try:
    observation, info = env.reset()

    episode_over = False
    while not episode_over:
        action = env.action_space.sample()  # random action; ignores observation/info
        observation, reward, terminated, truncated, info = env.step(action)

        episode_over = terminated or truncated
finally:
    # Close the env (and its "human" render window) even if the rollout raises
    # or is interrupted.
    env.close()