Aim: Implement the State-Action-Reward-State-Action (SARSA) algorithm in Python and compare it with Q-learning.

In [None]:
import numpy as np
class GridWorld:
    """Deterministic rectangular grid world with one goal cell and blocked cells.

    The agent's position is a ``(row, col)`` tuple and an action is a
    ``(d_row, d_col)`` offset added to it. Moving onto ``goal`` gives reward
    +1 and ends the episode. Trying to move onto an obstacle or off the grid
    gives reward -1 and leaves the agent where it was. Every other move gives
    reward 0.
    """

    def __init__(self, rows, cols, start, goal, obstacles):
        """Create the grid.

        Args:
            rows: Number of rows; valid row indices are ``0 .. rows - 1``.
            cols: Number of columns; valid column indices are ``0 .. cols - 1``.
            start: ``(row, col)`` cell where every episode begins.
            goal: ``(row, col)`` terminal cell.
            obstacles: Collection of ``(row, col)`` cells that cannot be entered.
        """
        self.rows = rows
        self.cols = cols
        self.start = start
        self.goal = goal
        self.obstacles = obstacles
        self.state = start
        self.is_terminal = False

    def reset(self):
        """Put the agent back on ``start`` and clear the terminal flag."""
        self.state = self.start
        self.is_terminal = False

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, is_terminal)``.

        The returned ``state`` is always the agent's actual position after the
        move, i.e. the same value as ``self.state``. A blocked move therefore
        returns the unchanged current cell, never the rejected one.
        """
        row, col = self.state
        d_row, d_col = action
        # Plain Python ints (not numpy scalars) keep states clean as dict keys
        # and readable when printed.
        candidate = (int(row + d_row), int(col + d_col))

        if candidate == self.goal:
            reward = 1
            # Record the move so ``self.state`` agrees with what we return.
            self.state = candidate
            self.is_terminal = True
        elif (
            candidate in self.obstacles
            or not (0 <= candidate[0] < self.rows)
            or not (0 <= candidate[1] < self.cols)
        ):
            # Blocked move: penalise it, but the agent stays put. Reporting the
            # rejected cell here would make learners key Q-values on cells the
            # agent never occupies (including off-grid ones).
            reward = -1
            self.is_terminal = False
        else:
            reward = 0
            self.state = candidate
            self.is_terminal = False

        return self.state, reward, self.is_terminal


In [None]:
class SARSAAgent:
    """Tabular epsilon-greedy agent whose Q-table is filled in by SARSA updates.

    Q-values live in a dict keyed by ``(state, action)``. Pairs that have never
    been written read as 0.
    """

    def __init__(self, actions, alpha=0.1, gamma=0.9, epsilon=0.1):
        """Create an agent with an empty Q-table.

        Args:
            actions: Sequence of actions the agent may choose from.
            alpha: Step size used when updating Q-values.
            gamma: Discount factor applied to the next state's Q-value.
            epsilon: Probability of choosing a uniformly random action.
        """
        self.actions = actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.q_values = {}

    def choose_action(self, state):
        """Return an action for ``state`` using an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random action is returned.
        Otherwise an action with the highest Q-value is returned, with ties
        broken uniformly at random.
        """
        if np.random.uniform(0, 1) < self.epsilon:
            action_index = np.random.choice(len(self.actions))
            return self.actions[action_index]
        q_vals = [self.get_q_value((state, a)) for a in self.actions]
        best_q = max(q_vals)
        # np.argmax always picks the FIRST maximum. In an unexplored state (all
        # Q-values 0) that would always select actions[0]. Random tie-breaking
        # removes that systematic bias.
        best_indices = [i for i, q in enumerate(q_vals) if q == best_q]
        return self.actions[np.random.choice(best_indices)]

    def get_q_value(self, sa_pair):
        """Return Q(state, action) for ``sa_pair``, defaulting to 0 if unseen."""
        return self.q_values.get(sa_pair, 0)

    def update_q_value(self, state_action, new_q_value):
        """Overwrite Q(state, action) for ``state_action`` with ``new_q_value``."""
        self.q_values[state_action] = new_q_value

    def print_q_values(self):
        """Print every stored (state, action) pair and its Q-value."""
        print("Q-values Table:")
        for state_action, value in self.q_values.items():
            state, action = state_action
            print(f"State: {state}, Action: {action}, Q-value: {value}")


In [None]:
def train_sarsa(agent, environment, episodes, max_steps=None):
    """Train ``agent`` on ``environment`` with on-policy SARSA.

    For each transition (s, a, r, s', a') the Q-table is updated with

        Q(s, a) <- Q(s, a) + alpha * (r + gamma * Q(s', a') - Q(s, a))

    where a' is the action the agent will actually take next. When s' is
    terminal there is no future return, so the target is just r.

    Args:
        agent: Object exposing ``alpha``, ``gamma``, ``choose_action``,
            ``get_q_value`` and ``update_q_value`` (e.g. ``SARSAAgent``).
        environment: Object exposing ``reset()``, ``step(action)`` returning
            ``(state, reward, is_terminal)``, and ``state`` / ``is_terminal``
            attributes (e.g. ``GridWorld``).
        episodes: Number of episodes to run.
        max_steps: Optional cap on steps per episode. ``None`` (the default)
            runs each episode until it reaches a terminal state.
    """
    for _ in range(episodes):
        environment.reset()
        state = environment.state
        action = agent.choose_action(state)
        steps = 0
        while not environment.is_terminal and (max_steps is None or steps < max_steps):
            _, reward, is_terminal = environment.step(action)
            steps += 1
            current_q = agent.get_q_value((state, action))

            if is_terminal:
                # Terminal: no bootstrapping and no next action to choose.
                target = reward
                next_state = next_action = None
            else:
                # Use the environment's own record of the agent's position. A
                # blocked move may be reported with the rejected cell, which
                # would desynchronise the Q-table keys from where the agent is.
                next_state = environment.state
                next_action = agent.choose_action(next_state)
                target = reward + agent.gamma * agent.get_q_value((next_state, next_action))

            agent.update_q_value((state, action), current_q + agent.alpha * (target - current_q))
            state, action = next_state, next_action


In [None]:
if __name__ == "__main__":
    # 4x4 grid, start top-left, goal bottom-right, three blocked cells.
    rows, cols = 4, 4
    start = (0, 0)
    goal = (3, 3)
    obstacles = [(1, 1), (2, 1), (2, 2)]

    environment = GridWorld(rows, cols, start, goal, obstacles)
    actions = [(0, 1), (0, -1), (1, 0), (-1, 0)]  # right, left, down, up

    sarsa_agent = SARSAAgent(actions)

    episodes = 1000
    train_sarsa(sarsa_agent, environment, episodes)

    # Print the learned Q-values for SARSA
    sarsa_agent.print_q_values()

    # Evaluate the learned policy greedily. Exploration is a training device,
    # so switch it off while testing and restore it afterwards.
    max_test_steps = 20
    training_epsilon = sarsa_agent.epsilon
    sarsa_agent.epsilon = 0.0

    environment.reset()
    steps = 0
    while not environment.is_terminal and steps < max_test_steps:
        # Read the position from the environment so a blocked move is tracked
        # as "stayed in place" rather than as the rejected cell.
        action = sarsa_agent.choose_action(environment.state)
        environment.step(action)
        steps += 1

    sarsa_agent.epsilon = training_epsilon

    # Only claim success when the episode actually ended at the goal; the
    # loop can also stop because the step budget ran out.
    if environment.is_terminal:
        print(f"SARSA Agent reached the goal in {steps} steps.")
    else:
        print(f"SARSA Agent did not reach the goal within {max_test_steps} steps.")


Q-values Table:
State: (0, 0), Action: (0, 1), Q-value: 0.03747798427751295
State: (0, 1), Action: (0, 1), Q-value: -0.037713090507376895
State: (0, 2), Action: (0, 1), Q-value: -0.0014980724010000012
State: (0, 3), Action: (0, 1), Q-value: -0.271
State: (0, 4), Action: (0, 1), Q-value: -0.1
State: (0, 4), Action: (0, -1), Q-value: -0.00036082463647567104
State: (0, 3), Action: (0, -1), Q-value: -0.0007290000000000002
State: (0, 2), Action: (1, 0), Q-value: -0.009315620662732966
State: (1, 2), Action: (0, 1), Q-value: -0.008100000000000001
State: (1, 3), Action: (0, 1), Q-value: -0.19
State: (1, 4), Action: (0, 1), Q-value: -0.19
State: (1, 4), Action: (0, -1), Q-value: 0.0
State: (1, 3), Action: (0, -1), Q-value: -3.0192994488239993e-05
State: (1, 3), Action: (-1, 0), Q-value: 0.0
State: (0, 3), Action: (1, 0), Q-value: 0.016683084000000004
State: (1, 2), Action: (1, 0), Q-value: -0.19
State: (2, 2), Action: (0, 1), Q-value: 0.9289280992185065
State: (1, 4), Action: (1, 0), Q-value: 0