In [59]:
import numpy as np
import sys
if "../" not in sys.path:
    sys.path.append("../")
from lib.envs.random_walk import RandomWalk
import itertools


In [60]:
# Random-walk environment (lib.envs.random_walk.RandomWalk) shared by the cells below.
env2 = RandomWalk()

In [61]:
def random_policy(V, n_actions=None):
    """Build a uniform-random policy.

    Args:
        V: State-value estimates. Accepted for interface compatibility only;
            a uniform policy does not depend on them.
        n_actions: Number of discrete actions. When omitted, falls back to
            ``env2.action_space.n`` (the notebook-level environment), which
            is what this helper always used before the parameter existed.

    Returns:
        A function ``policy(state)`` returning a 1-D float array of length
        ``n_actions`` whose entries are all ``1 / n_actions``. The ``state``
        argument is ignored.
    """
    if n_actions is None:
        # Backward-compatible default: size the policy from the global env.
        n_actions = env2.action_space.n
    action_probs = np.ones(n_actions, dtype=float) / n_actions

    def policy(state):
        # Return a copy so a caller that mutates the result cannot corrupt
        # the probabilities handed out on later calls.
        return action_probs.copy()

    return policy

In [62]:
def n_step_TD(env, num_episodes, n, discount=1.0, alpha=0.4, policy=None):
    """Estimate state values with n-step TD prediction.

    Args:
        env: Environment exposing ``observation_space.n``,
            ``action_space.n``, ``reset()`` (returns the start state) and
            ``step(action)`` (returns ``(next_state, reward, done, info)``).
            States are used directly as indices into the value table.
        num_episodes: Number of episodes to sample.
        n: Number of real rewards included in each target before
            bootstrapping from the current value estimate. Must be >= 1.
        discount: Discount factor applied per time step.
        alpha: Step size of each value update.
        policy: Optional callable ``policy(state)`` returning the action
            probabilities for ``state``. Defaults to a uniform-random policy
            over ``env.action_space.n`` actions.

    Returns:
        NumPy array of length ``env.observation_space.n`` holding the
        estimated value of every state.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be a positive integer, got {!r}".format(n))

    n_actions = env.action_space.n
    actions = np.arange(n_actions)
    if policy is None:
        # Size the default policy from the environment that was passed in,
        # not from a notebook global, so any compatible env works.
        uniform_probs = np.ones(n_actions, dtype=float) / n_actions

        def policy(state):
            return uniform_probs

    V = np.zeros(env.observation_space.n)

    for _episode in range(num_episodes):
        state = env.reset()
        states = [state]   # states[k] is the state at time k
        rewards = []       # rewards[k] is the reward received on leaving states[k]
        T = float('inf')   # episode length; unknown until `done` is seen

        for t in itertools.count():
            if t < T:
                # Still acting: take one more step and record the transition.
                action = np.random.choice(actions, p=policy(state))
                next_state, reward, done, _ = env.step(action)
                rewards.append(reward)
                states.append(next_state)
                if done:
                    T = t + 1
                state = next_state

            # Time step whose value estimate is updated on this iteration.
            tau = t - n + 1
            if tau >= 0:
                # Discounted sum of up to n rewards, truncated at episode end.
                horizon = min(tau + n, T)
                G = sum(discount ** (i - tau) * rewards[i]
                        for i in range(tau, horizon))
                # Bootstrap only if the n-step window stops short of the
                # terminal state.
                if tau + n < T:
                    G += discount ** n * V[states[tau + n]]
                V[states[tau]] += alpha * (G - V[states[tau]])

            if tau == T - 1:
                break
    return V

In [77]:
# Seed NumPy's global RNG (used by np.random.choice inside n_step_TD) so the
# sampled actions are the same on every Restart & Run All.
# NOTE(review): assumes RandomWalk itself adds no other randomness - confirm.
np.random.seed(0)
V = n_step_TD(env2, num_episodes=20, n=4)

In [78]:
# Show the estimated value of every state; the array index is the state id.
print(V)

[0.         0.         0.00122578 0.00655683 0.01808109 0.04046312
 0.06665874 0.1880114  0.27428006 0.52058462 0.4541326  0.53091539
 0.60474894 0.75215003 0.86844681 0.87187264 0.96231742 0.88869273
 0.94777586 0.99024599 0.        ]


In [None]:
"""
There are two terminal states, located at the two extremes of the chain of states.
    1. The extreme right of the walk has a reward of +1
    2. The extreme left of the walk has a reward of 0
    
    As you can observe from the state values above, the states to the right have
    higher values and the ones to the left have lower values, as expected.
"""