In [1]:
import numpy as np
import random
import time

# Seed Python's RNG from the clock, but keep the value in SEED so a run can be
# replayed with random.seed(SEED). Replace it with a fixed int for a fully
# deterministic Restart & Run All.
SEED = time.time_ns()
random.seed(SEED)

NUM_STATES = 100         # four 5x5 rooms: states 0-24, 25-49, 50-74, 75-99
NUM_ACTIONS = 4          # 0 = up, 1 = right, 2 = down, 3 = left
GOAL_STATE = 29          # top-right corner of the top-right room (states 25-49)
GOAL_STATE_REWARD = 100  # reward for reaching GOAL_STATE; every other step yields -1

def display_grid(pi):
    '''Print the greedy action of every state as an arrow map of the four rooms.

    States 0-24 and 25-49 are printed side by side on top, 50-74 and 75-99
    below, five cells per room row. ``-`` marks the doorway between the
    side-by-side rooms on the rows starting at states 10 and 60; the ``|``
    line between the two halves marks the vertical doorways under column 2.

    Args:
        pi: indexable by state; ``np.argmax(pi[s])`` picks the arrow drawn for
            state ``s`` (0 = ``^``, 1 = ``>``, 2 = ``v``, 3 = ``<``).
    '''
    arrows = ['^', '>', 'v', '<']

    def render_cells(first_state):
        # One room row: five states, each drawn as a space followed by its arrow.
        return ''.join(' ' + arrows[np.argmax(pi[s])] for s in range(first_state, first_state + 5))

    def render_half(top_left_state, door_row_state):
        # Each printed line pairs a row of the left room with the same row of
        # the room whose states start 25 later.
        for row_state in range(top_left_state, top_left_state + 25, 5):
            divider = ' -' if row_state == door_row_state else '  '
            print(render_cells(row_state) + divider + render_cells(row_state + 25))

    render_half(0, 10)
    print(' ' * 5 + '|' + ' ' * 11 + '|' + ' ' * 6)
    render_half(50, 60)

class TransitionProbability:
    '''Build the transition tensor of a four-room gridworld.

    The world has 100 states: four 5x5 rooms whose first states are 0, 25, 50
    and 75. Inside a room, local index ``s`` (0-24) sits at row ``s // 5`` and
    column ``s % 5``. ``connections`` opens the doorways 14<->35, 64<->85,
    22<->52 and 47<->77 between rooms.

    ``p[state][action][next_state]`` is the probability of reaching
    ``next_state`` when ``action`` is taken in ``state``; actions are
    0 = up, 1 = right, 2 = down, 3 = left.

    For a move:
      * ``p1`` -- the intended neighbour is reached;
      * ``p2`` -- the agent stays in place;
      * ``alt = (1 - p1 - p2) / 2`` -- the agent slips to each of the two cells
        diagonally flanking the intended neighbour. A slip cell outside the
        room is folded into the intended move (``p1 + alt``).
    If the intended move hits a wall, the agent stays with ``p1 + p2`` and
    slides with ``alt`` to each neighbour along that wall (in a corner, to the
    corner's two in-room neighbours).
    '''

    def __init__(self, p1, p2):
        '''
        Args:
            p1: probability of reaching the intended neighbour.
            p2: probability of staying in the current state.

        Raises:
            ValueError: if ``p1`` or ``p2`` is negative or ``p1 + p2 > 1``,
                which would make the slip probability ``alt`` negative.
        '''
        if p1 < 0 or p2 < 0 or p1 + p2 > 1 + 1e-12:
            raise ValueError(
                f'need p1 >= 0, p2 >= 0 and p1 + p2 <= 1, got p1={p1}, p2={p2}')
        self.p1 = p1
        self.p2 = p2
        # The leftover mass is split evenly between the two slip cells.
        self.alt = (1 - p1 - p2) / 2
        # Indexed as p[state][action][next_state].
        self.p = [[[0 for _ in range(NUM_STATES)] for _ in range(NUM_ACTIONS)] for _ in range(NUM_STATES)]

    def action_stay(self):
        '''For whatever action, p2 is the probability of staying.

        Writes ``p2`` on the diagonal for every state (including state 0) and
        every action; the action_* methods overwrite it where a move is blocked.
        '''
        for state_index in range(NUM_STATES):
            for action_index in range(NUM_ACTIONS):
                self.p[state_index][action_index][state_index] = self.p2

    def action_up(self, state_offset=0):
        '''Fill transitions for action 0 (up) in the room starting at ``state_offset``.'''
        action_offset = 5  # one row up is 5 states back
        action_id = 0
        o = state_offset

        # Rows 1-4, columns 1-3: cell above plus both upper diagonals.
        for state_index in [6, 7, 8, 11, 12, 13, 16, 17, 18, 21, 22, 23]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s - action_offset] = self.p1
            row[s - action_offset + 1] = self.alt
            row[s - action_offset - 1] = self.alt

        # Top row: blocked by the wall, so stay and slide sideways.
        for state_index in [0, 1, 2, 3, 4]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s] = self.p1 + self.p2
            if state_index not in (0, 4):
                row[s - 1] = self.alt
                row[s + 1] = self.alt

        # Top corners: slips go to the corner's two in-room neighbours.
        self.p[4 + o][action_id][9 + o] = self.p[4 + o][action_id][3 + o] = self.alt
        self.p[0 + o][action_id][5 + o] = self.p[0 + o][action_id][1 + o] = self.alt

        # Left column (rows 1-4): only the up-right diagonal exists.
        for state_index in [5, 10, 15, 20]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s - action_offset] = self.p1 + self.alt
            row[s - action_offset + 1] = self.alt

        # Right column (rows 1-4): only the up-left diagonal exists.
        for state_index in [9, 14, 19, 24]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s - action_offset] = self.p1 + self.alt
            row[s - action_offset - 1] = self.alt

    def action_down(self, state_offset=0):
        '''Fill transitions for action 2 (down) in the room starting at ``state_offset``.'''
        action_offset = 5  # one row down is 5 states ahead
        action_id = 2
        o = state_offset

        # Rows 0-3, columns 1-3: cell below plus both lower diagonals.
        for state_index in [1, 2, 3, 6, 7, 8, 11, 12, 13, 16, 17, 18]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s + action_offset] = self.p1
            row[s + action_offset - 1] = self.alt
            row[s + action_offset + 1] = self.alt

        # Bottom row: blocked by the wall, so stay and slide sideways.
        for state_index in [20, 21, 22, 23, 24]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s] = self.p1 + self.p2
            if state_index not in (20, 24):
                row[s - 1] = self.alt
                row[s + 1] = self.alt

        # Bottom corners: slips go to the corner's two in-room neighbours.
        self.p[20 + o][action_id][15 + o] = self.p[20 + o][action_id][21 + o] = self.alt
        self.p[24 + o][action_id][19 + o] = self.p[24 + o][action_id][23 + o] = self.alt

        # Left column (rows 0-3): only the down-right diagonal (s + 6) exists;
        # s + 4 would wrap to the far end of the current row.
        for state_index in [0, 5, 10, 15]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s + action_offset] = self.p1 + self.alt
            row[s + action_offset + 1] = self.alt

        # Right column (rows 0-3): only the down-left diagonal exists.
        for state_index in [4, 9, 14, 19]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s + action_offset] = self.p1 + self.alt
            row[s + action_offset - 1] = self.alt

    def action_right(self, state_offset=0):
        '''Fill transitions for action 1 (right) in the room starting at ``state_offset``.'''
        action_offset = 1
        action_id = 1
        o = state_offset

        # Rows 1-3, columns 0-3: cell to the right plus both right-hand diagonals.
        for state_index in [5, 6, 7, 8, 10, 11, 12, 13, 15, 16, 17, 18]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s + action_offset] = self.p1
            row[s + action_offset - 5] = self.alt
            row[s + action_offset + 5] = self.alt

        # Top row: only the down-right diagonal exists.
        for state_index in [0, 1, 2, 3]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s + action_offset] = self.p1 + self.alt
            row[s + action_offset + 5] = self.alt

        # Bottom row: only the up-right diagonal exists.
        for state_index in [20, 21, 22, 23]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s + action_offset] = self.p1 + self.alt
            row[s + action_offset - 5] = self.alt

        # Right column: blocked by the wall, so stay and slide vertically.
        for state_index in [4, 9, 14, 19, 24]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s] = self.p1 + self.p2
            if state_index not in (4, 24):
                row[s - 5] = self.alt
                row[s + 5] = self.alt

        # Right-hand corners: slips go to the corner's two in-room neighbours.
        self.p[4 + o][action_id][9 + o] = self.p[4 + o][action_id][3 + o] = self.alt
        self.p[24 + o][action_id][23 + o] = self.p[24 + o][action_id][19 + o] = self.alt

    def action_left(self, state_offset=0):
        '''Fill transitions for action 3 (left) in the room starting at ``state_offset``.'''
        action_offset = -1
        action_id = 3
        o = state_offset

        # Rows 1-3, columns 1-4: cell to the left plus both left-hand diagonals.
        for state_index in [6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s + action_offset] = self.p1
            row[s + action_offset - 5] = self.alt
            row[s + action_offset + 5] = self.alt

        # Top row: only the down-left diagonal exists.
        for state_index in [1, 2, 3, 4]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s + action_offset] = self.p1 + self.alt
            row[s + action_offset + 5] = self.alt

        # Bottom row: only the up-left diagonal exists.
        for state_index in [21, 22, 23, 24]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s + action_offset] = self.p1 + self.alt
            row[s + action_offset - 5] = self.alt

        # Left column: blocked by the wall, so stay and slide vertically.
        for state_index in [0, 5, 10, 15, 20]:
            s = state_index + o
            row = self.p[s][action_id]
            row[s] = self.p1 + self.p2
            if state_index not in (0, 20):
                row[s - 5] = self.alt
                row[s + 5] = self.alt

        # Left-hand corners: slips go to the corner's two in-room neighbours.
        self.p[0 + o][action_id][1 + o] = self.p[0 + o][action_id][5 + o] = self.alt
        self.p[20 + o][action_id][15 + o] = self.p[20 + o][action_id][21 + o] = self.alt

    def get_probs(self, state_offset=0, inplace=False):
        '''Fill all four actions for the room starting at ``state_offset``.

        With ``inplace=False`` (default) the p2 stay probabilities are first
        written for every state via ``action_stay`` and the tensor is returned.
        With ``inplace=True`` that step is skipped and ``None`` is returned.
        '''
        if not inplace:
            self.action_stay()

        self.action_down(state_offset)
        self.action_up(state_offset)
        self.action_left(state_offset)
        self.action_right(state_offset)

        if not inplace:
            return self.p

    def _connect(self, s1, s2, action_id):
        '''Open a doorway: taking ``action_id`` in wall cell ``s1`` now leads to ``s2``.

        Replaces the wall behaviour set by the action_* methods (stay with
        ``p1 + p2`` and slide along the wall) with: ``p1`` to ``s2``, ``alt``
        to each of the two cells flanking ``s2``, and ``p2`` of staying in ``s1``.
        '''
        # Slips are perpendicular to the move: a row (+-5) for left/right moves,
        # a column (+-1) for up/down moves.
        side = 5 if action_id in (1, 3) else 1
        row = self.p[s1][action_id]
        # Drop the wall-slide slips next to s1 before writing the doorway.
        row[s1 - side] = 0
        row[s1 + side] = 0
        row[s1] = self.p2
        row[s2] = self.p1
        row[s2 - side] = self.alt
        row[s2 + side] = self.alt

    def connections(self):
        '''Open the four doorways between the rooms, in both directions.'''
        # Top-left room (state 14) <-> top-right room (state 35).
        self._connect(14, 35, 1)
        self._connect(35, 14, 3)

        # Bottom-left room (state 64) <-> bottom-right room (state 85).
        self._connect(64, 85, 1)
        self._connect(85, 64, 3)

        # Top-left room (state 22) <-> bottom-left room (state 52).
        self._connect(22, 52, 2)
        self._connect(52, 22, 0)

        # Top-right room (state 47) <-> bottom-right room (state 77).
        self._connect(47, 77, 2)
        self._connect(77, 47, 0)

    def four_room(self):
        '''Build and return the full tensor: stay probabilities, the four
        rooms at offsets 0, 25, 50, 75, then the doorways.'''
        self.action_stay()

        for i in range(0, 4):
            self.get_probs(i * 25, True)

        self.connections()

        return self.p

In [2]:
class Environment:
    '''Four-room gridworld simulator backed by ``TransitionProbability``.

    Attributes:
        P1, P2: the move / stay probabilities given to ``TransitionProbability``.
        tprobs: ``tprobs[state][action][next_state]`` transition tensor.
        actions: available action ids (0 = up, 1 = right, 2 = down, 3 = left).
    '''

    def __init__(self, p1, p2):
        self.P1, self.P2 = p1, p2

        self.tprobs = TransitionProbability(p1, p2).four_room()

        self.actions = list(range(4))

    def _init_state(self):
        '''Return a uniformly random start state other than the goal.'''
        candidates = [s for s in range(NUM_STATES) if s != GOAL_STATE]
        return random.choice(candidates)

    def _observe(self, state, action):
        '''Sample one transition from ``(state, action)``.

        Returns:
            ``(next_state, reward)``: the reward is ``GOAL_STATE_REWARD`` when
            ``next_state`` is the goal and -1 otherwise.
        '''
        weights = self.tprobs[state][action]
        next_state, = random.choices(range(NUM_STATES), weights)
        reward = GOAL_STATE_REWARD if next_state == GOAL_STATE else -1
        return next_state, reward

In [3]:
class SARSAAgent:
    '''Tabular on-policy SARSA learner with an epsilon-greedy behaviour policy.

    Attributes:
        Eps: exploration probability used by ``_e_greedy_choice``.
        Alpha: learning rate.
        Gamma: discount factor.
        env: environment providing ``_observe(state, action) -> (next_state, reward)``.
        Q: ``Q[state][action]`` value table (0 = up, 1 = right, 2 = down, 3 = left).
    '''

    def __init__(self, epsilon, alpha, gamma, env):
        self.Eps, self.Alpha, self.Gamma = epsilon, alpha, gamma
        self.env = env
        # Arbitrary initial values, except the goal state whose action values are 0.
        self._arbitrary_Q()

    def _arbitrary_Q(self):
        '''(Re)initialise ``Q`` uniformly in [0.5, 1] and zero the goal row.'''
        self.Q = [[random.uniform(0.5, 1) for _action in range(NUM_ACTIONS)]
                  for _state in range(NUM_STATES)]
        self.Q[GOAL_STATE] = [0] * NUM_ACTIONS

    def _e_greedy_choice(self, state):
        '''Pick a random action with probability ``Eps``, otherwise the greedy one.'''
        explore = random.uniform(0, 1) <= self.Eps
        if explore:
            return random.choice([0, 1, 2, 3])
        return np.argmax(self.Q[state])

    def episode_step(self, state, A):
        '''Apply one SARSA update for ``(state, A)`` and return ``(S', A')``.'''
        next_state, reward = self.env._observe(state, A)
        next_action = self._e_greedy_choice(next_state)

        td_target = reward + self.Gamma * self.Q[next_state][next_action]
        current = self.Q[state][A]
        self.Q[state][A] = current + self.Alpha * (td_target - current)

        return next_state, next_action

def run_sarsa(episodes=1000, timesteps=70, p1=1, p2=0, epsilon=0.1, alpha=0.1, gamma=0.9):
    '''Train a SARSA agent on the four-room environment.

    Args:
        episodes: number of training episodes.
        timesteps: maximum steps per episode; an episode also ends as soon as
            ``GOAL_STATE`` is reached.
        p1, p2: move / stay probabilities for ``Environment`` (the defaults
            give deterministic moves).
        epsilon, alpha, gamma: exploration rate, learning rate and discount
            factor for ``SARSAAgent``.

    Returns:
        The trained ``SARSAAgent``.
    '''
    env = Environment(p1, p2)
    learner = SARSAAgent(epsilon, alpha, gamma, env)

    for _ in range(episodes):
        state = env._init_state()
        action = learner._e_greedy_choice(state)

        for _ in range(timesteps):
            state, action = learner.episode_step(state, action)
            if state == GOAL_STATE:
                break

    return learner

In [4]:
# Train with the default settings: 1000 episodes of at most 70 steps each.
sarsa = run_sarsa(episodes=1000, timesteps=70)

In [5]:
def map_q_to_pi(q):
    '''Turn a Q-table into a deterministic greedy policy.

    Args:
        q: ``q[state][action]`` values for states ``0 .. NUM_STATES - 1``.

    Returns:
        A list of ``NUM_STATES`` one-hot lists of length ``NUM_ACTIONS`` with a 1
        at ``np.argmax(q[state])`` (ties resolve to the lowest action id).
    '''
    def greedy_one_hot(action_values):
        one_hot = [0] * NUM_ACTIONS
        one_hot[np.argmax(action_values)] = 1
        return one_hot

    return [greedy_one_hot(q[state]) for state in range(NUM_STATES)]

# Draw the greedy policy learned by SARSA as an arrow map of the four rooms.
display_grid(pi=map_q_to_pi(q=sarsa.Q))

 v > v > v   > > > > ^
 > > v > v   ^ ^ > ^ ^
 > > > > > - ^ > ^ ^ ^
 ^ ^ > ^ ^   ^ > ^ ^ ^
 ^ ^ ^ ^ ^   ^ ^ ^ ^ <
     |           |      
 > > ^ < <   > > ^ < <
 ^ > ^ ^ v   > ^ ^ < <
 ^ > ^ > > - > ^ ^ ^ <
 < > ^ > ^   ^ > ^ < <
 ^ v ^ v ^   ^ ^ ^ ^ ^
