In [1]:
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

### Environment

In [2]:
class GridWorld:
    def __init__(self):
        self.x = 0
        self.y = 0

    def step(self, a):
        if a == 0:
            self.move_right()
        elif a == 1:
            self.move_left()
        elif a == 2:
            self.move_up()
        elif a == 3:
            self.move_down()

        reward = -1
        done = self.is_done()
        return (self.x, self.y), reward, done

    def move_right(self):
        self.y += 1
        if self.y > 3:
            self.y = 3

    def move_left(self):
        self.y -= 1
        if self.y < 0:
            self.y = 0

    def move_up(self):
        self.x -= 1
        if self.x < 0:
            self.x = 0

    def move_down(self):
        self.x += 1
        if self.x > 3:
            self.x = 3

    def is_done(self):
        if self.x == 3 and self.y == 3:
            return True
        else:
            return False

    def get_state(self):
        return (self.x, self.y)

    def reset(self):
        self.x = 0
        self.y = 0
        return (self.x, self.y)


### Agent

In [3]:
class Agent:
    def __init__(self):
        pass

    def select_action(self):
        coin = random.random()
        if coin < 0.25:
            action = 0
        elif coin < 0.5:
            action = 1
        elif coin < 0.75:
            action = 2
        else:
            action = 3
        return action

### MC

In [4]:
def MC_evaluation(gamma: float = 1.0, alpha: float = 0.001, num_episodes: int = 50000):
    env = GridWorld()
    agent = Agent()
    V = np.zeros((4, 4), dtype=float)
    
    for episode in tqdm(range(num_episodes)):
        env.reset()
        state = env.get_state()
        terminated = False
        trajectory = []
        while not terminated:
            action = agent.select_action()
            next_state, reward, terminated = env.step(action)
            trajectory.append((state, reward))
            state = next_state

        G = 0
        for transition in trajectory[::-1]:
            state, reward = transition
            G = reward + gamma*G
            V[state] = V[state] + alpha*(G - V[state])

    return V


### One-step TD

In [5]:
def One_Step_TD_evaluation(gamma: float = 1.0, alpha: float = 0.001, num_episodes: int = 50000):
    env = GridWorld()
    agent = Agent()
    V = np.zeros((4, 4), dtype=float)
    
    for episode in tqdm(range(num_episodes)):
        env.reset()
        state = env.get_state()
        terminated = False
        while not terminated:
            action = agent.select_action()
            next_state, reward, terminated = env.step(action)
            V[state] = V[state] + alpha * (reward + gamma * V[next_state] - V[state])
            state = next_state
    return V

### N-step TD

In [6]:
def N_Step_TD_evaluation(n: int, gamma: float = 1.0, alpha: float = 0.001, num_episodes: int = 50000):
    env = GridWorld()
    agent = Agent()
    V = np.zeros((4, 4), dtype=float)
    
    for episode in tqdm(range(num_episodes)):
        env.reset()
        state = env.get_state()
        terminated = False
        trajectory = []
        trajectory.append((state, 0))
        T = np.inf
        t = 0
        
        while True:
            if t < T:
                action = agent.select_action()
                next_state, reward, terminated = env.step(action)
                trajectory.append((state, reward))
                
                if terminated:
                    T = t + 1
                else:
                    state = next_state
                
            tau = t - n + 1
            if tau >= 0:
                G = 0
                for i in range(tau + 1, min(tau + n + 1, T + 1)):
                    G += np.power(gamma, i - tau - 1) * trajectory[i - 1][1]
                if tau + n < T:
                    G += np.power(gamma, n) * V[trajectory[tau + n][0]]
                V[trajectory[tau][0]] += alpha * (G - V[trajectory[tau][0]])
            if tau == T - 1:
                break
            
            t += 1
    return V

In [7]:
print("MC:", MC_evaluation())
print("TD(0):", One_Step_TD_evaluation())
for n in [2, 3, 4, 5, 10, 20]:
    print(f"{n}-step TD:", N_Step_TD_evaluation(n))

100%|██████████| 50000/50000 [00:02<00:00, 23713.02it/s]


MC: [[-58.83863793 -58.76293453 -57.47678121 -55.27356081]
 [-58.49380835 -55.36528549 -51.59867416 -46.9189773 ]
 [-54.14887321 -51.70238293 -43.69344589 -31.94178156]
 [-50.50331016 -45.00110269 -30.35996954   0.        ]]


100%|██████████| 50000/50000 [00:02<00:00, 20909.30it/s]


TD(0): [[-58.24218197 -56.2696774  -53.21787154 -50.6497432 ]
 [-56.33610727 -53.49983492 -48.69766837 -43.98591504]
 [-53.38461657 -48.84019297 -39.62675312 -29.38989848]
 [-50.85609321 -44.35489245 -29.18641246   0.        ]]


100%|██████████| 50000/50000 [00:11<00:00, 4201.46it/s]


2-step TD: [[-58.97615095 -57.16265775 -54.04242489 -51.47193741]
 [-57.07993237 -54.27139759 -49.34511961 -44.68002261]
 [-54.19273035 -49.59521942 -41.09070014 -39.96876671]
 [-51.1754361  -45.15644746 -38.86775757   0.        ]]


100%|██████████| 50000/50000 [00:14<00:00, 3414.29it/s]


3-step TD: [[-58.76535374 -56.67811873 -53.76951313 -51.23389203]
 [-56.87954977 -54.10646931 -48.79871629 -44.5217288 ]
 [-53.65931363 -48.98768838 -40.11770205 -38.48560602]
 [-51.61010445 -45.18975118 -39.29081079   0.        ]]


100%|██████████| 50000/50000 [00:17<00:00, 2875.56it/s]


4-step TD: [[-58.43711717 -55.95447468 -52.56769571 -49.82503799]
 [-56.13261163 -52.87183063 -47.97055328 -42.92139439]
 [-53.44159211 -48.65658471 -39.33854354 -37.82579109]
 [-51.23957351 -43.94466911 -37.97164198   0.        ]]


100%|██████████| 50000/50000 [00:19<00:00, 2506.81it/s]


5-step TD: [[-57.64545814 -55.6408881  -53.07337495 -50.0285047 ]
 [-56.0469075  -52.57359134 -48.17039731 -44.0588493 ]
 [-53.03198438 -47.68805735 -38.91878643 -38.43445102]
 [-50.41189012 -42.2798974  -37.23235966   0.        ]]


100%|██████████| 50000/50000 [00:32<00:00, 1546.66it/s]


10-step TD: [[-60.05326182 -57.4514421  -54.30990279 -51.38334077]
 [-58.12224859 -54.70367833 -50.60562422 -46.37460811]
 [-54.74309817 -49.97377345 -41.59657487 -39.83903496]
 [-51.53638691 -44.81786411 -39.16800497   0.        ]]


100%|██████████| 50000/50000 [00:53<00:00, 935.46it/s] 

20-step TD: [[-59.19236005 -57.30449697 -54.23283098 -50.41387899]
 [-56.67984293 -54.43201054 -50.32083253 -45.18343911]
 [-54.91599469 -50.58953781 -42.45278647 -40.480507  ]
 [-52.81282423 -48.36942643 -42.16864308   0.        ]]



