# SARSA(lambda)

In [1]:
# Find the value function of policy
import numpy as np

# display output
import random
from random import uniform
import time
from IPython.display import display, clear_output

In [2]:
actions = [[-1, 0], [0, 1], [1, 0], [0, -1]] #up, right, down, left = (clockwise from up) 
action_count = len(actions) # total number of actions
gridSize = 5 # create a square grid of gridSize by gridSize
state_count = gridSize*gridSize # total number of states

In [3]:
class Gridworld():
    def __init__(self, gridSize):
        self.valueMap = np.zeros((gridSize, gridSize))
        self.states = [[i, j] for i in range(gridSize) for j in range(gridSize)]
        self.size = gridSize
        self.new_pos = [0, 0] # initialize new position for p_transition
        self.transition_prob = 1 # deterministic
    
    def initial_state(self):        # return initial state
        return grid.states[gridSize*gridSize-1]
   
    def transition_reward(self, current_pos, action): # return the transition probability

        # get next position: state: [0, 0], action: [0, 1], new_state = [0, 1]
        self.new_pos = np.array(current_pos) + np.array(action)

        # normally, reward = 0
        reward = 0

        # if new pos results in off the grid, return reward -1
        if -1 in self.new_pos or self.size in self.new_pos:
            reward = -1
        # if in state A, receive + 10
        if current_pos == [0, 1]:
            reward = 10
        # if in state B, receive + 5
        if current_pos == [0, 3]:
            reward = 5

        # if taking an action crosses the border; agent's new_pos is the same as the current pos
        if -1 in self.new_pos or self.size in self.new_pos: 
            self.new_pos = current_pos
            
        # if in state A, transition to state A'
        if current_pos == [0, 1]:
            self.new_pos = [4, 1]
            
        # if in state B, transition to state B'
        if current_pos == [0, 3]:
            self.new_pos = [2, 3]

        return self.new_pos, reward

In [4]:
# create a grid object
grid = Gridworld(5)

In [5]:
# get initial state (bottom right)
grid.initial_state()

[4, 4]

### SARSA(Lamda)

In [6]:
# initialize q values for all state action pairs
Q_values = np.random.randint(0,1000,size = (state_count, action_count))
Q_values

array([[812, 956, 618, 304],
       [848, 840, 479, 487],
       [536, 876, 133, 449],
       [595, 773, 757, 691],
       [189, 252, 219, 304],
       [926, 957, 878, 357],
       [  6, 297, 294, 771],
       [920, 341, 423, 157],
       [707, 568, 687,  12],
       [846,  14, 370, 116],
       [ 96, 435, 478, 896],
       [511,  67, 839, 929],
       [473, 591, 971, 972],
       [205, 424, 817, 386],
       [307, 352, 827, 783],
       [235, 436, 924, 342],
       [511, 561,  86, 877],
       [861, 362, 926, 834],
       [734, 719, 576, 234],
       [605,  88, 978, 455],
       [737, 675, 176, 834],
       [879, 666, 723, 386],
       [468, 818, 972, 191],
       [459, 960, 366,  23],
       [ 51, 442, 276, 216]])

In [7]:
# intialize parameters
gamma = 0.99
epsilon = 0.8
lamda = 0.9
alpha = 0.1

In [8]:
# iterate 500 times: each time, generating an episode of 200 steps
max_steps = 200

# define variables for keeping track of time steps
Terminal = max_steps
t_list=[]
for i in range(1,max_steps+1):
    t = Terminal - i
    t_list.append(t)

In [9]:
def choose_action(state, epsilon):
    
    # choose an action type: explore or exploit
    action_type = int(np.random.choice(2, 1, p=[epsilon,1-epsilon]))

    # find best action based on Q values
    best_action_index = np.argmax(Q_values[state])

    # pick a random action
    random_action_index = random.choice(range(4))

    # choose an action based on exploit or explore
    if action_type == 0:
        
        # while random action is the same as the best action, pick a new action
        while random_action_index == best_action_index:
            random_action_index = random.choice(range(4))
        
        # explore
        # print("explore")
        action_index = random_action_index
    else:
        # exploit
        # print("exploit")
        action_index = best_action_index
        
    return action_index

In [11]:
# initialize q values for all state action pairs
Q_values = np.zeros((state_count, action_count))

In [12]:
# iteration 500 times
for iteration in range(500):
    
    # initialize delta
    delta = 0
    
    # initialize S,A (? should i choose an Action using epsilon-greedy here or just select an Action?)
    state_vector = grid.initial_state()
    state_index = grid.states.index(state_vector)
    
    # initialize  eligibility traces for all state action pairs of all states to 0
    z_values = np.zeros((state_count, action_count))
    
    action_index = choose_action(state_index, epsilon)
    action_vector = actions[action_index]
    
#     print("EPISODE--------------------------------------------------")
#     print("state: ", state_index)
#     print("action: ", action_index)
    
    # iteration 200 steps of the episode
    for i in range(max_steps):
#         print("STEP-------------------------------------------------")
        # Take action A, oberserve R, S'
        next_state_vector, reward = grid.transition_reward(state_vector, action_vector)
        next_state_index = grid.states.index(list(next_state_vector))
        
#         print("next_state: ", next_state_index)
#         print("reward: ", reward)
        # Choose A' from S' using policy derived from Q (eg. epsilon-greedy)
        next_action_index = choose_action(next_state_index, epsilon)
        next_action_vector = actions[next_action_index]
        
#         print("next_action: ", next_action_index)
        
        # update the action-value form of the TD error
        delta = reward + gamma*Q_values[next_state_index][next_action_index] - Q_values[state_index][action_index]
        
#         print("delta: ", delta)
        
        # accumulate traces (? big S and big A?)
        z_values[state_index][action_index] +=1
        
#         print("z_values: ", z_values[state_index][action_index])
        
        # update Q value
        Q_values[state_index][action_index] = Q_values[state_index][action_index] + alpha*delta*z_values[state_index][action_index]
        
#         print("Q_values: ", Q_values[state_index][action_index])
        
        # update z value
        z_values[state_index][action_index] = gamma*lamda*z_values[state_index][action_index]
        
#         print("z_values: ", z_values[state_index][action_index])
        
        # update state and action vector
        state_vector = list(next_state_vector)
        state_index = grid.states.index(state_vector)
        action_vector = list(next_action_vector)
        action_index = next_action_index

In [13]:
z_values

array([[0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 1.684881  , 0.        , 0.        ],
       [0.        , 0.        , 1.684881  , 0.891     ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.891     , 0.        , 0.        ],
       [0.        , 2.39222897, 1.684881  , 0.891     ],
       [0.891     , 2.39222897, 1.684881  , 0.891     ],
       [0.891     , 0.891     , 3.02247601, 1.684881  ],
       [0.        , 0.        , 0.891     , 0.891     ],
       [0.        , 0.        , 0.891     , 0.        ],
       [1.684881  , 0.        , 0.891     , 0.        ],
       [3.02247601, 0.891     , 4.08436728, 4.08436728],
       [1.684881  , 3.02247601, 0.891     , 3.58402613],
       [0.891     , 1.684881  ,

In [14]:
np.set_printoptions(precision=2)
Q_values

array([[ -4.69,   0.46,  -4.62,  -3.85],
       [  0.54,   0.59,   0.5 ,   0.78],
       [ -2.94,  -2.07,  -4.49,   0.71],
       [ -2.15,  -2.17,  -2.43,  -2.54],
       [ -6.97,  -5.4 ,  -6.06,  -2.35],
       [ -3.14,  -4.99,  -8.05,  -6.85],
       [  0.49,  -4.35,  -7.53,  -5.6 ],
       [ -2.52,  -5.76,  -7.75,  -3.67],
       [ -2.44,  -6.23,  -7.54,  -4.67],
       [ -5.98,  -7.01,  -7.84,  -4.33],
       [ -5.48,  -8.05,  -8.77,  -8.62],
       [ -4.37,  -7.83,  -8.67,  -8.53],
       [ -5.02,  -8.3 ,  -8.9 ,  -6.24],
       [ -5.63,  -8.43,  -9.68,  -6.98],
       [ -7.36,  -8.95,  -9.4 ,  -7.45],
       [ -7.67,  -8.68,  -9.67,  -9.48],
       [ -7.98,  -8.47,  -9.12,  -9.73],
       [ -7.41,  -8.89,  -9.76,  -8.29],
       [ -8.07,  -9.54,  -9.91,  -9.41],
       [ -7.71, -10.13,  -9.88,  -9.17],
       [ -8.74,  -9.5 , -11.34, -12.75],
       [ -8.53,  -9.68, -10.62,  -9.74],
       [ -8.29, -10.38,  -9.93,  -9.76],
       [ -8.78,  -9.75, -10.97,  -9.83],
       [ -9.36, 

In [15]:
# PRINT POLICY TABLE ################################################################################
# import pandas library
import pandas as pd
# define column and index
columns=range(grid.size)
index = range(grid.size)
# define dataframe to represent policy table
policy_table = pd.DataFrame(index = index, columns=columns)

# iterate through policy to make a table that represents action number
# as action name (eg. left, right, up, down)
for state in range(len(Q_values)):
    
    # find the best action at each state
    best_action = np.argmax(Q_values[state])

    # get action name
    if best_action == 0:
        action_name = 'up'
    elif best_action == 1:
        action_name = 'right'
    elif best_action == 2:
        action_name = 'down'
    else:
        action_name = 'left'

    # calculate the row and column coordinate of the current state number
    row = int(state/grid.size)
    column = round((state/grid.size - int(state/grid.size))*grid.size)
            
    # assign action name
    policy_table.loc[row][column] = action_name

print("Policy Table: ")
print(policy_table)
print()

Policy Table: 
       0     1     2   3     4
0  right  left  left  up  left
1     up    up    up  up  left
2     up    up    up  up    up
3     up    up    up  up    up
4     up    up    up  up  left

