### 时序差分（TD）预测：TD(0)

In [1]:
import gym
import matplotlib
import numpy as np
from collections import defaultdict
from gym.envs.toy_text.cliffwalking import CliffWalkingEnv

In [2]:
# Instantiate the cliff-walking environment directly from its class; this
# module-level `env` is what the later cells pass to td_prediction.
env = CliffWalkingEnv()

In [3]:
def policy(state, nA):
    """Return the action distribution of a uniform random policy.

    Args:
        state: Current environment state. It is ignored: every state gets
            the same distribution.
        nA: Number of available actions.

    Returns:
        A NumPy array of length ``nA`` in which every entry equals ``1 / nA``.
    """
    # Divide by nA instead of a hard-coded 4 so the probabilities sum to 1
    # for any action count (the result is unchanged when nA == 4).
    return np.ones(nA) / nA

In [4]:
def td_prediction(env, n, discount=1.0, epsilon=0.1, alpha=0.5):
    """Estimate state values with one-step temporal-difference updates.

    Runs ``n`` episodes on ``env``. At every step an action is sampled from
    the distribution returned by ``policy`` and the value of the state just
    left is moved towards ``reward + discount * V[next_state]``.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` (unpacked
            here as a 4-tuple) and ``action_space.n``.
        n: Number of episodes to run.
        discount: Factor applied to the next state's value in the target.
        epsilon: Accepted but not used anywhere in this function.
        alpha: Step size applied to the TD error.

    Returns:
        A ``defaultdict(float)`` mapping each state seen to its estimate.
    """
    values = defaultdict(float)  # states not seen yet default to 0.0
    for _ in range(n):
        current = env.reset()
        finished = False
        while not finished:
            action_probs = policy(current, env.action_space.n)
            chosen = np.random.choice(
                np.arange(len(action_probs)), p=action_probs
            )
            successor, reward, finished, _info = env.step(chosen)
            # One-step target: bootstrap from the successor's current estimate.
            target = reward + discount * values[successor]
            values[current] += alpha * (target - values[current])
            current = successor
    return values

In [5]:
# Entry point: evaluate the policy over a fixed number of episodes and
# print the resulting state-value table.
if __name__ == "__main__":
    n = 1000  # number of episodes passed to td_prediction
    V = td_prediction(env, n)
    print(V)

defaultdict(<class 'float'>, {36: -49087.967088934914, 24: -48999.17544442721, 25: -48942.77021206708, 13: -48832.96401186856, 12: -48959.78489109466, 0: -48593.427308997845, 1: -48700.85963793693, 14: -48770.797612896116, 2: -48480.5508483626, 15: -48584.14255438803, 27: -48968.73817298539, 26: -49034.10668798987, 16: -48706.33243983546, 3: -47900.74327907154, 4: -47455.674720725394, 28: -48952.814831745665, 29: -49103.39490122186, 30: -48493.10458006617, 17: -47863.42301899384, 31: -48362.13765188126, 18: -47871.70195584639, 19: -47054.701644371846, 7: -45322.78394858206, 8: -44102.98053546661, 9: -42100.4405613477, 21: -47127.35287969097, 22: -31234.49259185284, 34: -41591.12474690312, 35: -6142.858917104286, 33: -45450.438877383465, 32: -47697.16438223451, 20: -44508.708632301175, 5: -46917.662374235195, 6: -46616.09090426532, 10: -35083.6697115397, 11: -28830.119930876466, 23: -17253.535563505764, 47: 0.0})
