Random Walk
---
## n-step TD Method

<img style="float" src="rw-game.png" alt="drawing" width="700"/>

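In this MRP, all episodes start in the center state, C, and then proceed either left or right by one state on each step, with equal probability. Episodes terminate at either the extreme left or the extreme right. When an episode terminates on the right, a reward of +1 occurs; all other rewards are zero. The code below uses a larger walk of 19 states numbered 0–18: states 0 and 18 are terminal, and every episode starts at the center state 9. Instead of always moving with equal probability, the agent chooses moves ε-greedily and learns action values with n-step SARSA.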

<img style="float" src="n-step.png" alt="drawing" width="700"/>

In [3]:
import numpy as np

In [2]:
# Random-walk layout: states 0 .. NUM_STATES-1, terminal at both ends,
# every episode starting from the middle state.
NUM_STATES = 19
END_0, END_1 = 0, NUM_STATES - 1  # left / right terminal states (0 and 18)
START = NUM_STATES // 2  # centre state (9)

In [50]:
class RandomWalk:
    """n-step SARSA agent on a 1-D random walk with terminal states at both ends.

    States are the integers ``0 .. NUM_STATES - 1``; ``END_0`` and ``END_1`` are
    terminal. Entering ``END_1`` gives reward 1, every other step gives 0.
    Action values are stored as ``self.Q_values[state][action]``.

    Args:
        n: number of steps in the n-step return.
        start: state every episode starts from.
        end: initial value of the "episode finished" flag.
        exp_rate: probability of choosing a uniformly random action (epsilon).
        lr: step size (alpha) of the Q-value update.
        gamma: discount factor.
    """

    def __init__(self, n, start=START, end=False, exp_rate=0.3, lr=0.1, gamma=1):
        self.actions = ["right", "left"]
        self.start = start  # remembered so reset() returns here, not to the global START
        self.state = start  # current state
        self.end = end
        self.n = n
        self.exp_rate = exp_rate
        self.lr = lr
        self.gamma = gamma
        self.state_actions = []  # NOTE(review): unused by this class; kept for compatibility
        # Q estimates, initialised to 0 for every (state, action) pair
        self.Q_values = {
            i: {a: 0 for a in self.actions} for i in range(NUM_STATES)
        }

    def chooseAction(self):
        """Return an epsilon-greedy action for the current state.

        With probability ``exp_rate`` a uniformly random action is returned.
        Otherwise an action with the highest Q value is returned; ties are
        broken uniformly at random so the choice does not depend on the order
        of ``self.actions``.
        """
        if np.random.uniform(0, 1) <= self.exp_rate:
            return np.random.choice(self.actions)
        # greedy action: only choose randomly among the *maximal* actions
        q = self.Q_values[self.state]
        best_value = max(q[a] for a in self.actions)
        best_actions = [a for a in self.actions if q[a] == best_value]
        return np.random.choice(best_actions)

    def takeAction(self, action):
        """Move one state left or right and return the new state.

        Any action other than "left" moves right. Entering ``END_0`` or
        ``END_1`` sets ``self.end``; once the episode has ended the state no
        longer changes.
        """
        new_state = self.state
        if not self.end:
            if action == "left":
                new_state = self.state - 1
            else:
                new_state = self.state + 1

            if new_state in [END_0, END_1]:
                self.end = True
        self.state = new_state
        return self.state

    def giveReward(self):
        """Return the reward for the current state: 1 at ``END_1``, else 0."""
        if self.state == END_0:
            return 0
        if self.state == END_1:
            return 1
        # other states
        return 0

    def reset(self):
        """Start a new episode from the configured start state."""
        self.state = self.start
        self.end = False

    def play(self, rounds=100):
        """Run ``rounds`` episodes, updating ``self.Q_values`` with n-step SARSA.

        Within an episode, ``states[t]``, ``actions[t]`` and ``rewards[t]`` hold
        S_t, A_t and R_t (``rewards[0]`` is a placeholder so indices line up).
        The pair visited at time ``tau = t - n + 1`` is updated as soon as its
        n-step return is available.
        """
        for _ in range(rounds):
            self.reset()
            t = 0
            T = np.inf  # episode length; unknown until a terminal state is reached
            action = self.chooseAction()

            actions = [action]
            states = [self.state]
            rewards = [0]  # placeholder for R_0 so that rewards[i] is R_i
            while True:
                if t < T:
                    state = self.takeAction(action)  # S_{t+1}
                    reward = self.giveReward()  # R_{t+1}

                    states.append(state)
                    rewards.append(reward)

                    if self.end:
                        T = t + 1
                    else:
                        action = self.chooseAction()
                        actions.append(action)  # A_{t+1}
                # time step whose (state, action) estimate is updated now
                tau = t - self.n + 1
                if tau >= 0:
                    # discounted sum of R_{tau+1} .. R_{min(tau+n, T)}
                    G = 0
                    for i in range(tau + 1, min(tau + self.n + 1, T + 1)):
                        G += self.gamma ** (i - tau - 1) * rewards[i]
                    if tau + self.n < T:
                        # bootstrap from Q(S_{tau+n}, A_{tau+n}) if the episode is still running then
                        boot_state = states[tau + self.n]
                        boot_action = actions[tau + self.n]
                        G += self.gamma ** self.n * self.Q_values[boot_state][boot_action]
                    # move Q(S_tau, A_tau) towards the n-step return
                    s, a = states[tau], actions[tau]
                    self.Q_values[s][a] += self.lr * (G - self.Q_values[s][a])

                if tau == T - 1:
                    break  # the last non-terminal time step has been updated

                t += 1

In [51]:
# Train an n-step SARSA agent with n = 3 for 10 episodes.
rw = RandomWalk(3)
rw.play(rounds=10)

8
left
7
right
8
left
7
right
8
right
9
left
8
right
9
left
8
left
7
right
8
left
7
right
8
right
9
left
8
right
9
left
8
left
7
left
6
right
7
left
6
right
7
right
8
right
9
left
8
right
9
right
10
left
9
left
8
right
9
left
8
right
9
left
8
left
7
left
6
left
5
right
6
right
7
right
8
left
7
right
8
left
7
right
8
right
9
right
10
right
11
left
10
left
9
left
8
left
7
right
8
right
9
right
10
left
9
left
8
right
9
left
8
left
7
right
8
left
7
right
8
left
7
left
6
right
7
left
6
left
5
left
4
right
5
right
6
left
5
right
6
right
7
left
6
left
5
left
4
left
3
right
4
left
3
right
4
left
3
right
4
left
3
left
2
left
1
left
0
10
right
11
left
10
left
9
left
8
left
7
left
6
left
5
right
6
left
5
left
4
left
3
right
4
left
3
right
4
right
5
left
4
left
3
left
2
right
3
left
2
right
3
left
2
left
1
right
2
right
3
left
2
left
1
left
0
8
left
7
right
8
left
7
left
6
left
5
left
4
left
3
left
2
right
3
left
2
left
1
left
0
8
left
7
right
8
right
9
left
8
right
9
left
8
left
7
left
6
left
5
r

In [52]:
# Display the learned action-value table: state -> {action: Q value}.
rw.Q_values

{0: {'right': 0, 'left': 0},
 1: {'right': 0.0, 'left': 0.0},
 2: {'right': 0.0, 'left': 0.0},
 3: {'right': 0.0, 'left': 0.0},
 4: {'right': 0.0, 'left': 0.0},
 5: {'right': 0.0, 'left': 0.0},
 6: {'right': 0.0, 'left': 0.0},
 7: {'right': 0.0021628906358292154, 'left': 0.0},
 8: {'right': 0.006299897923822763, 'left': 0.0018576683301777784},
 9: {'right': 0.010970905357511235, 'left': 0.003192969863174953},
 10: {'right': 0.013157134073890749, 'left': 0.007630349691143989},
 11: {'right': 0.025244426971728278, 'left': 0.007257494292953509},
 12: {'right': 0.027499127309872758, 'left': 0.012660401872068757},
 13: {'right': 0.07764635339809756, 'left': 0.02132244746355568},
 14: {'right': 0.06240818197880054, 'left': 0.0744978247734191},
 15: {'right': 0.33653808048235845, 'left': 0.049948013984016715},
 16: {'right': 0.45535690804823586, 'left': 0.08608119060794119},
 17: {'right': 0.46855900000000006, 'left': 0.03948410089424843},
 18: {'right': 0, 'left': 0}}