### Importing Packages

In [1]:
import random


### Define the environment


In [2]:
n_states = 10
n_actions = 2
state_space = range(n_states)
action_space = range(n_actions)


### Initialize Q values


In [3]:
Q = {}
for s in state_space:
    for a in action_space:
        Q[(s,a)] = 0.0


### Set hyperparameters


In [4]:
alpha = 0.5  # learning rate
epsilon = 0.1  # exploration rate
gamma = 0.9  # discount factor

### Define the SARSA algorithm


In [5]:
def sarsa(num_episodes):
    for episode in range(num_episodes):
        s = random.choice(state_space)  # start in a random state
        if random.random() < epsilon:
            a = random.choice(action_space)  # choose a random action
        else:
            # choose the action with the highest Q value
            a = max(action_space, key=lambda x: Q[(s,x)])

        while True:
            # take the chosen action and observe the next state and reward
            if a == 0:
                r = s
                s_next = max(s-1, 0)
            else:
                r = n_states - s
                s_next = min(s+1, n_states-1)

            # choose the next action based on epsilon-greedy policy
            if random.random() < epsilon:
                a_next = random.choice(action_space)
            else:
                a_next = max(action_space, key=lambda x: Q[(s_next,x)])

            # update Q value for current state-action pair
            Q[(s,a)] += alpha * (r + gamma*Q[(s_next,a_next)] - Q[(s,a)])

            # update state and action
            s = s_next
            a = a_next

            # end the episode when the goal state is reached
            if s == n_states-1:
                break

    return Q


### Run SARSA algorithm


In [6]:
num_episodes = 1000
Q = sarsa(num_episodes)

### Print learned Q values


In [7]:
for s in state_space:
    print("State", s, "Q values:", [Q[(s,a)] for a in action_space])

State 0 Q values: [58.865880762761606, 65.89476824333084]
State 1 Q values: [60.077313766619596, 61.97680547888258]
State 2 Q values: [57.683914727105005, 59.00006780218227]
State 3 Q values: [55.79032265275241, 56.70079323530368]
State 4 Q values: [55.01438170943484, 55.219267913067796]
State 5 Q values: [54.65042745756463, 54.61518002973817]
State 6 Q values: [55.226379827516, 54.76086898905501]
State 7 Q values: [56.705202308766104, 54.33572621735438]
State 8 Q values: [59.021542181953, 56.49051161836125]
State 9 Q values: [62.108374365686004, 55.85990417100403]
