In [5]:
import numpy as np
import pprint
import sys
if "../" not in sys.path:
  sys.path.append("../") 
from lib.envs.gridworld import GridworldEnv

In [6]:
pp = pprint.PrettyPrinter(indent=2)
env = GridworldEnv()

In [7]:
# Taken from Policy Evaluation Exercise!

def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):
    """
    Evaluate a policy given an environment and a full description of the environment's dynamics.
    
    Args:
        policy: [S, A] shaped matrix representing the policy.
        env: OpenAI env. env.P represents the transition probabilities of the environment.
            env.P[s][a] is a list of transition tuples (prob, next_state, reward, done).
            env.nS is a number of states in the environment. 
            env.nA is a number of actions in the environment.
        theta: We stop evaluation once our value function change is less than theta for all states.
        discount_factor: Gamma discount factor.
    
    Returns:
        Vector of length env.nS representing the value function.
    """
    # Start with a random (all 0) value function
    V = np.zeros(env.nS)
    while True:
        delta = 0
        # For each state, perform a "full backup"
        for s in range(env.nS):
            v = 0
            # Look at the possible next actions
            for a, action_prob in enumerate(policy[s]):
                # For each action, look at the possible next states...
                for  prob, next_state, reward, done in env.P[s][a]:
                    # Calculate the expected value
                    v += action_prob * prob * (reward + discount_factor * V[next_state])
            # How much our value function changed (across any states)
            delta = max(delta, np.abs(v - V[s]))
            V[s] = v
        # Stop evaluating once our value function change is below a threshold
        if delta < theta:
            break
    return np.array(V)

In [13]:
def policy_improvement(env, policy_eval_fn=policy_eval, discount_factor=1.0):
    """
    Policy Improvement Algorithm. Iteratively evaluates and improves a policy
    until an optimal policy is found.
    
    Args:
        env: The OpenAI envrionment.
        policy_eval_fn: Policy Evaluation function that takes 3 arguments:
            policy, env, discount_factor.
        discount_factor: gamma discount factor.
        
    Returns:
        A tuple (policy, V). 
        policy is the optimal policy, a matrix of shape [S, A] where each state s
        contains a valid probability distribution over actions.
        V is the value function for the optimal policy.
        
    """
    # Start with a random policy
    policy = np.ones([env.nS, env.nA]) / env.nA
    
    while True:
        # Implement this!
        break
    
    return policy, np.zeros(env.nS)

In [14]:
policy, v = policy_improvement(env)
print("Policy Probability Distribution:")
print(policy)
print("")

print("Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):")
print(np.reshape(np.argmax(policy, axis=1), env.shape))
print("")

print("Value Function:")
print(v)
print("")

print("Reshaped Grid Value Function:")
print(v.reshape(env.shape))
print("")



Policy Probability Distribution:
[[ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]
 [ 0.25  0.25  0.25  0.25]]

Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):
[[0 0 0 0]
 [0 0 0 0]
 [0 0 0 0]
 [0 0 0 0]]

Value Function:
[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]

Reshaped Grid Value Function:
[[ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]]



In [18]:
import numpy as np
import gym

def policy_iteration(env, gamma=0.99, theta=1e-6):
    nS = env.observation_space.n  # Number of states
    nA = env.action_space.n  # Number of actions

    # Initialize the policy arbitrarily
    policy = np.ones([nS, nA]) / nA

    while True:
        # Policy Evaluation (Calculate the state-value function V_pi)
        V = np.zeros(nS)
        while True:
            delta = 0
            for s in range(nS):
                v = V[s]
                weighted_rewards = [sum(p * (r + gamma * V[s_next]) for p, s_next, r, _ in env.P[s][a]) for a in range(nA)]
                V[s] = sum(policy[s][a] * weighted_rewards[a] for a in range(nA))
                delta = max(delta, abs(v - V[s]))
            if delta < theta:
                break

        policy_stable = True

        # Policy Improvement
        for s in range(nS):
            old_action = np.argmax(policy[s])
            action_values = [sum(p * (r + gamma * V[s_next]) for p, s_next, r, _ in env.P[s][a]) for a in range(nA)]
            best_action = np.argmax(action_values)

            # Update the policy
            new_policy = np.zeros(nA)
            new_policy[best_action] = 1
            policy[s] = new_policy

            if old_action != best_action:
                policy_stable = False

        if policy_stable:
            break

    return policy, V

# Create the environment
env = gym.make('Taxi-v3')

# Run Policy Iteration
optimal_policy, optimal_value = policy_iteration(env)

# Print the optimal policy and value function
print("Optimal Policy:")
print(optimal_policy)

print("Optimal Value Function:")
print(optimal_value)


Optimal Policy:
[[0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1. 0.]
 ...
 [0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]]
Optimal Value Function:
[944.72357234 864.01312811 903.55729147 873.75063348 789.53799752
 864.01312811 789.53799565 816.76688902 864.01312999 826.02716258
 903.55729147 835.38097132 807.59922297 826.02716258 807.59922109
 873.75063348 955.27633662 873.75063542 913.69423476 883.58649945
 934.27633662 854.37299683 893.52171855 864.01312715 798.52323074
 873.75063542 798.52322888 826.02716161 854.37299869 816.76689095
 893.52171855 826.02716161 816.76689281 835.38097326 816.76689095
 883.58649945 944.72357326 883.58650137 903.55729242 893.5217176
 883.58650321 807.59922204 844.82926686 816.76688999 844.8292687
 923.93357142 844.82926686 873.75063446 844.8292687  807.59922204
 883.58650137 816.76688999 826.02716536 844.82926686 826.02716353
 893.5217176  893.52172131 934.2763357  893.52171949 903.55729147
 873.75063818 798.52322982 835.3809742