In [2]:
import gym # openAi gym
from gym import envs
import numpy as np 
'''
in this part we define a function determining the best value fucntion for a single action of a given policy that we use later in our code
'''
def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):
    V=np.zeros(env.env.nS)
    while True:
        delta=0
        for state in range(env.env.nS):
            val=0
            for action, action_prob in enumerate(policy[state]):
                for prob, next_state, reward, done in env.env.P[state][action]:
                    val+=action_prob*prob*(reward+discount_factor*V[next_state])
            delta=max(delta,np.abs(val-V[state]))
            V[state]=val
        if delta< theta:
            break
    return np.array(V)
'''in this part we define a function that use the output of the previous function which would be a value function for a given policy to find the best policy with that value function and again use the function before to find a new best value function for new policy
and find the best policy for that again.
'''
def policy_iteration (env, policy_eval_fn = policy_eval, discount_factor=1.0):
    def one_step_lookahead(state, V):
        A= np.zeros(env.env.nA)
        for a in range(env.env.nA):
            for prob, next_state, reward, done in env.env.P[state][a]:
                A[a]+=prob*(reward+discount_factor*V[next_state])
        return A
    policy= np.ones([env.env.nS, env.env.nA])/env.env.nA
    while True:
        curr_pol_val=policy_eval_fn(policy,env,discount_factor)
        policy_stable= True
        for state in range(env.env.nS):
            chosen_act=np.argmax(policy[state])
            act_values=one_step_lookahead(state,curr_pol_val)
            best_act= np.argmax(act_values)
            if chosen_act != best_act:
                policy_stable= False
            policy[state]=np.eye(env.env.nA)[best_act]
        if policy_stable:
            return policy, curr_pol_val
        V_f= curr_pol_val
    return policy

env = gym.make('Taxi-v3')
aaa=policy_iteration(env,policy_eval,discount_factor=0.99)
aaa[0]
print("here is your result")
print(aaa)
print("here is your value function")
final_policy = aaa[0]
policy_eval(final_policy,env,discount_factor=0.99)


here is your result
(array([[0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1., 0.],
       ...,
       [0., 1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0.]]), array([944.72316569, 864.01270478, 903.55686813, 873.75019718,
       789.53759087, 864.01270478, 789.53757231, 816.76645272,
       864.01272333, 826.02673924, 903.55686813, 835.38053503,
       807.59881631, 826.02673924, 807.59879776, 873.75019718,
       955.27593403, 873.75021631, 913.69381566, 883.58606752,
       934.27593403, 854.37257773, 893.52129945, 864.01269521,
       798.52282815, 873.75021631, 798.52280978, 826.02672968,
       854.3725961 , 816.76647185, 893.52129945, 826.02672968,
       816.76649022, 835.38055415, 816.76647185, 883.58606752,
       944.72317469, 883.58608645, 903.5568775 , 893.52128998,
       883.58610464, 807.59880713, 844.82885195, 816.76646238,
       844.82887014, 923.9331565 , 844.82885195, 873.75020684,
       844.

array([944.72316569, 864.01270478, 903.55686813, 873.75019718,
       789.53759087, 864.01270478, 789.53757231, 816.76645272,
       864.01272333, 826.02673924, 903.55686813, 835.38053503,
       807.59881631, 826.02673924, 807.59879776, 873.75019718,
       955.27593403, 873.75021631, 913.69381566, 883.58606752,
       934.27593403, 854.37257773, 893.52129945, 864.01269521,
       798.52282815, 873.75021631, 798.52280978, 826.02672968,
       854.3725961 , 816.76647185, 893.52129945, 826.02672968,
       816.76649022, 835.38055415, 816.76647185, 883.58606752,
       944.72317469, 883.58608645, 903.5568775 , 893.52128998,
       883.58610464, 807.59880713, 844.82885195, 816.76646238,
       844.82887014, 923.9331565 , 844.82885195, 873.75020684,
       844.82887014, 807.59880713, 883.58608645, 816.76646238,
       826.0267668 , 844.82885195, 826.02674861, 893.52128998,
       893.52132673, 934.27592494, 893.52130873, 903.55686813,
       873.7502436 , 798.52281906, 835.38056343, 807.59