In [1]:
import numpy as np
import random

class SARSA:
    """
    Tabular SARSA agent with an ε-greedy behaviour policy.

    Q-values are stored in a nested dict ``Q[state][action]`` and updated
    with the on-policy rule
    ``Q(s, a) += alpha * (r + gamma * Q(s', a') - Q(s, a))``.
    """

    def __init__(self, states, actions, alpha=0.1, gamma=0.9, epsilon=0.1):
        """
        Initialize SARSA agent.
        :param states: List of all possible states (must be hashable).
        :param actions: List of all possible actions (must be hashable).
        :param alpha: Learning rate.
        :param gamma: Discount factor.
        :param epsilon: Exploration rate for ε-greedy policy.
        """
        self.states = states
        self.actions = actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        # Q-table, one entry per (state, action) pair. Floats (not ints) so
        # every stored value has the same type before and after updates.
        self.Q = {state: {action: 0.0 for action in actions} for state in states}

    def choose_action(self, state):
        """
        Choose an action using ε-greedy policy.
        :param state: Current state.
        :return: Chosen action. With probability ``epsilon`` a uniformly
            random action, otherwise the action with the highest Q-value
            (ties resolve to the first such action in the Q-table order).
        """
        if random.random() < self.epsilon:
            return random.choice(self.actions)  # Explore
        action_values = self.Q[state]
        return max(action_values, key=action_values.get)  # Exploit

    def update(self, state, action, reward, next_state, next_action):
        """
        Update the Q-value using SARSA update rule.
        :param state: Current state.
        :param action: Action taken in the current state.
        :param reward: Reward received after taking the action.
        :param next_state: Next state reached after taking the action.
        :param next_action: Action taken in the next state, or ``None`` when
            ``next_state`` is terminal (no further action is taken).
        """
        if next_action is None:
            # Terminal transition: there is no future return to bootstrap
            # from, so the target is just the immediate reward. This avoids
            # forcing callers to invent a dummy action for terminal states.
            next_q = 0.0
        else:
            next_q = self.Q[next_state][next_action]
        td_target = reward + self.gamma * next_q
        td_error = td_target - self.Q[state][action]
        self.Q[state][action] += self.alpha * td_error

    def get_q_value(self, state, action):
        """
        Get the Q-value for a state-action pair.
        :param state: State.
        :param action: Action.
        :return: Q-value.
        """
        return self.Q[state][action]

# Example Usage
if __name__ == "__main__":
    # Toy problem: three ordinary states, one terminal state, two actions.
    states = ['S1', 'S2', 'S3', 'Terminal']
    actions = ['A1', 'A2']

    # Agent with the default learning rate, discount factor and exploration rate.
    sarsa_agent = SARSA(states, actions)

    # Two hand-written episodes. Each step is
    # (state, action, reward, next_state, next_action); a next_action of
    # None marks the transition into the terminal state.
    episodes = [
        [
            ('S1', 'A1', 1, 'S2', 'A2'),
            ('S2', 'A2', 1, 'S3', 'A1'),
            ('S3', 'A1', 1, 'Terminal', None),
        ],
        [
            ('S1', 'A2', 1, 'S3', 'A1'),
            ('S3', 'A1', 1, 'Terminal', None),
        ],
    ]

    # Replay every recorded transition through the SARSA update.
    for episode in episodes:
        for state, action, reward, next_state, next_action in episode:
            # On the terminal step substitute the dummy action 'A1'; the
            # terminal state's Q-values are never updated, so they stay 0.
            followup_action = 'A1' if next_action is None else next_action
            sarsa_agent.update(state, action, reward, next_state, next_action=followup_action)

    # Print the learned Q-table, one line per (state, action) pair.
    for state in states:
        for action in actions:
            print(f"Q({state}, {action}) = {sarsa_agent.get_q_value(state, action):.2f}")

Q(S1, A1) = 0.10
Q(S1, A2) = 0.11
Q(S2, A1) = 0.00
Q(S2, A2) = 0.10
Q(S3, A1) = 0.19
Q(S3, A2) = 0.00
Q(Terminal, A1) = 0.00
Q(Terminal, A2) = 0.00
