Randow Walk
---
## n-step TD Method

<img style="float" src="rw-game.png" alt="drawing" width="700"/>

In this MRP, all episodes start in the center state, C, then proceed either left or right by one state on each step, with equal probability. Episodes terminate either on the extreme left or the extreme right. When an episode terminates on the right, a reward of +1 occurs; all other rewards are zero.

<img style="float" src="n-step.png" alt="drawing" width="700"/>

In [1]:
import numpy as np

In [24]:
# 19 states (not including the ending state)
NUM_STATES = 19
START = 9
END_0 = 0
END_1 = 20

In [32]:
class RandomWalk:
    
    def __init__(self, n, start=START, end=False, lr=0.1, gamma=1):
        self.actions = ["left", "right"]
        self.state = start  # current state
        self.end = end
        self.n = n
        self.lr = lr
        self.gamma = gamma
        self.state_actions = []
        # init q estimates
        self.Q_values = {}
        for i in range(NUM_STATES+2):  
            self.Q_values[i] = {}
            for a in self.actions:
                self.Q_values[i][a] = 0
                
    def chooseAction(self):    
        action = np.random.choice(self.actions)
        return action 
    
    def takeAction(self, action):
        new_state = self.state
        if not self.end:
            if action == "left":
                new_state = self.state-1
            else:
                new_state = self.state+1
            
            if new_state in [END_0, END_1]:
                self.end = True
        self.state = new_state
        return self.state
    
    def giveReward(self):
        if self.state == END_0:
            return -1
        if self.state == END_1:
            return 1
        # other states
        return 0
    
    def reset(self):
        self.state = START
        self.end = False
           
    def play(self, rounds=100):
        for _ in range(rounds):
            self.reset()
            t = 0
            T = np.inf
            action = self.chooseAction()
            
            actions = [action]
            states = [self.state]
            rewards = [0]
            while True:
                if t < T:
                    state = self.takeAction(action)  # next state
                    reward = self.giveReward()  # next state-reward
                    
                    states.append(state)
                    rewards.append(reward)
                    
                    if self.end:
                        print("End at state {} | number of states {}".format(state, len(states)))
                        T = t+1
                    else:
                        action = self.chooseAction()
                        actions.append(action)  # next action
                # state tau being updated
                tau = t - self.n + 1
                if tau >= 0:
                    G = 0
                    for i in range(tau+1, min(tau+self.n+1, T+1)):
                        G += np.power(self.gamma, i-tau-1)*rewards[i]
                    if tau+self.n < T:
                        state_action = (states[tau+self.n], actions[tau+self.n])
                        G += np.power(self.gamma, self.n)*self.Q_values[state_action[0]][state_action[1]]
                    # update Q values
                    state_action = (states[tau], actions[tau])
                    self.Q_values[state_action[0]][state_action[1]] += self.lr*(G-self.Q_values[state_action[0]][state_action[1]])
                
                if tau == T-1:
                    break
                
                t += 1

In [33]:
rw = RandomWalk(n=3)
rw.play(100)

End at state 0 | number of states 106
End at state 0 | number of states 188
End at state 20 | number of states 40
End at state 0 | number of states 26
End at state 20 | number of states 182
End at state 20 | number of states 222
End at state 0 | number of states 80
End at state 0 | number of states 26
End at state 20 | number of states 24
End at state 20 | number of states 46
End at state 20 | number of states 94
End at state 0 | number of states 98
End at state 0 | number of states 52
End at state 20 | number of states 68
End at state 20 | number of states 192
End at state 20 | number of states 136
End at state 20 | number of states 56
End at state 0 | number of states 126
End at state 0 | number of states 84
End at state 0 | number of states 80
End at state 0 | number of states 376
End at state 0 | number of states 130
End at state 20 | number of states 48
End at state 20 | number of states 44
End at state 20 | number of states 132
End at state 20 | number of states 48
End at state 2

In [34]:
rw.Q_values

{0: {'left': 0, 'right': 0},
 1: {'left': -0.997534965295042, 'right': -0.7239519692230982},
 2: {'left': -0.8392118401720433, 'right': -0.6394526508921221},
 3: {'left': -0.7554599294286835, 'right': -0.5419782947847438},
 4: {'left': -0.6718963513067013, 'right': -0.5201297882185493},
 5: {'left': -0.5782429463877724, 'right': -0.4160937563589277},
 6: {'left': -0.4932368848752557, 'right': -0.388801696560297},
 7: {'left': -0.43554033317074614, 'right': -0.22944565819591475},
 8: {'left': -0.3266028546177513, 'right': -0.14702311616942723},
 9: {'left': -0.2906020444017145, 'right': -0.07640697619864936},
 10: {'left': -0.22761135053524775, 'right': -0.033160735652345084},
 11: {'left': -0.16659698517527222, 'right': 0.03302730515835058},
 12: {'left': -0.06489122906030385, 'right': 0.17646238703182698},
 13: {'left': 0.0605155055503486, 'right': 0.28393965128880116},
 14: {'left': 0.13247040692234296, 'right': 0.42325915736689446},
 15: {'left': 0.2604776914779341, 'right': 0.51770

In [35]:
np.arange(-20, 22, 2) / 20.0

array([-1. , -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1,  0. ,
        0.1,  0.2,  0.3,  0.4,  0.5,  0.6,  0.7,  0.8,  0.9,  1. ])

In [42]:
np.arange(-18, 20, 2) / 18.0

array([-1.        , -0.88888889, -0.77777778, -0.66666667, -0.55555556,
       -0.44444444, -0.33333333, -0.22222222, -0.11111111,  0.        ,
        0.11111111,  0.22222222,  0.33333333,  0.44444444,  0.55555556,
        0.66666667,  0.77777778,  0.88888889,  1.        ])