forked from llSourcell/AI_for_video_games_demo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
policy_iteration_demo.py
77 lines (68 loc) · 2.46 KB
/
policy_iteration_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Solving FrozenLake8x8 environment using Policy iteration.
Author : Moustafa Alzantot (malzantot@ucla.edu)
"""
import numpy as np
import gym
from gym import wrappers
def run_episode(env, policy, gamma=1.0, render=False):
    """Play one episode under `policy` and return its discounted return.

    Args:
        env: Environment exposing `reset()` (returns the first observation)
            and `step(action)` (returns a 4-tuple whose first three items
            are the next observation, the reward and the done flag).
        policy: Indexable by observation; each entry is converted with
            `int()` before being handed to `env.step`.
        gamma: Discount factor; the reward of step `t` is weighted by
            `gamma ** t`.
        render: When true, `env.render()` is called before every step.

    Returns:
        The sum of the discounted rewards collected until `done` is true.
    """
    state = env.reset()
    discounted_return = 0
    t = 0
    finished = False
    while not finished:
        if render:
            env.render()
        action = int(policy[state])
        state, reward, finished, _ = env.step(action)
        discounted_return += gamma ** t * reward
        t += 1
    return discounted_return
def evaluate_policy(env, policy, gamma=1.0, n=100):
    """Return the mean return of `policy` over `n` episodes.

    Each episode is played with `run_episode` (rendering disabled) using
    the given discount factor `gamma`; the per-episode returns are averaged
    with `np.mean`.
    """
    episode_returns = []
    for _ in range(n):
        episode_returns.append(run_episode(env, policy, gamma, False))
    return np.mean(episode_returns)
def extract_policy(v, gamma=1.0, env=None):
    """Return the policy that is greedy with respect to value function `v`.

    Args:
        v: Indexable of state values, one entry per state.
        gamma: Discount factor applied to the successor state's value.
        env: Environment providing `nS`, `nA` and the transition table `P`,
            where every element of `P[s][a]` unpacks into four items used as
            `(p, s_, r, _)`: weight, successor state, reward, and an ignored
            fourth field (presumably the done flag -- confirm against the
            environment). When omitted, a module-level `env` is used, which
            is how the original two-argument form behaved.

    Returns:
        A float array of length `env.nS` holding, for each state, the index
        of the action with the largest one-step look-ahead value (the first
        such index on ties, per `np.argmax`).

    Raises:
        NameError: if `env` is omitted and no module-level `env` exists.
    """
    if env is None:
        # Legacy call style: fall back to the global the script entry point
        # creates, so existing two-argument callers keep working.
        if "env" not in globals():
            raise NameError(
                "extract_policy() needs an `env` argument or a module-level `env`"
            )
        env = globals()["env"]

    policy = np.zeros(env.nS)
    for s in range(env.nS):
        q_sa = np.zeros(env.nA)
        for a in range(env.nA):
            q_sa[a] = sum(p * (r + gamma * v[s_]) for p, s_, r, _ in env.P[s][a])
        policy[s] = np.argmax(q_sa)
    return policy


def compute_policy_v(env, policy, gamma=1.0):
    """Iteratively evaluate the value function of `policy`.

    Starting from all zeros, every sweep recomputes each state's value from
    the previous sweep's values, and the loop stops once the summed absolute
    change over all states is at most 1e-10. Alternatively, one could
    formulate a set of linear equations in terms of v[s] and solve them
    directly.

    Args:
        env: Environment providing `nS` and the transition table `P`.
        policy: Indexable by state; each entry is the action taken there.
            Entries are cast with `int()` because `extract_policy` stores
            actions in a float array.
        gamma: Discount factor applied to the successor state's value.

    Returns:
        A float array of length `env.nS` with the converged state values.
    """
    v = np.zeros(env.nS)
    eps = 1e-10
    while True:
        # Snapshot so the whole sweep reads the previous iterate only.
        prev_v = np.copy(v)
        for s in range(env.nS):
            policy_a = int(policy[s])
            v[s] = sum(
                p * (r + gamma * prev_v[s_]) for p, s_, r, _ in env.P[s][policy_a]
            )
        if np.sum(np.fabs(prev_v - v)) <= eps:
            # value converged
            break
    return v


def policy_iteration(env, gamma=1.0):
    """Run the policy-iteration algorithm on `env`.

    Alternates policy evaluation (`compute_policy_v`) and greedy policy
    improvement (`extract_policy`) from a random initial policy until the
    policy stops changing or 200000 iterations have been performed.

    Args:
        env: Environment providing `nS`, `nA` and the transition table `P`.
        gamma: Discount factor used for both evaluation and improvement.

    Returns:
        The last policy held when the loop ends: the stable policy if the
        convergence check succeeded (a message is printed in that case),
        otherwise the policy from the final iteration.
    """
    policy = np.random.choice(env.nA, size=(env.nS))  # initialize a random policy
    max_iterations = 200000
    # `gamma` is used as passed by the caller; it is not reset here.
    for i in range(max_iterations):
        old_policy_v = compute_policy_v(env, policy, gamma)
        # Improve against the environment being solved, not a global one.
        new_policy = extract_policy(old_policy_v, gamma, env)
        if np.all(policy == new_policy):
            print('Policy-Iteration converged at step %d.' % (i + 1))
            break
        policy = new_policy
    return policy
if __name__ == '__main__':
    # Script entry point: solve FrozenLake8x8 with policy iteration and
    # report the average return of the resulting policy.
    # Keep `env` bound at module level: extract_policy can resolve it as a
    # global.
    env = gym.make('FrozenLake8x8-v0').unwrapped
    best_policy = policy_iteration(env, gamma=1.0)
    mean_return = evaluate_policy(env, best_policy, gamma=1.0)
    print('Average scores = ', np.mean(mean_return))