In [1]:
import random

In [2]:
class SarsaLearning:
    def __init__(self, length):
        self.length = length
        self.start_state = 0
        self.end_state = length - 1
        self.actions = ['LEFT', 'RIGHT']

        self.alpha = 0.1  # Learning rate
        self.gamma = 0.9  # Discount factor
        self.epsilon = 0.1  # Exploration rate
        # self.q_table = {state: {action: 0.0 for action in self.actions} for state in range(self.length)}
        self.q_table = {-2: {'LEFT': 0.0, 'RIGHT': 0.0}, -1: {'LEFT': 0.0, 'RIGHT': 0.0},
                        0: {'LEFT': 0.0, 'RIGHT': 0.0}, 1: {
                'LEFT': 0.0, 'RIGHT': 0.0}, 2: {'LEFT': 0.0, 'RIGHT': 0.0}}

        print(self.q_table)

    def get_next_state(self, state, action):
        if action == 'LEFT':
            if state == -2:
                return -1

            return state - 1
        elif action == 'RIGHT':
            if state == 2:
                return 1

            return state + 1

    def get_reward(self, state):
        if state == -2:
            return 1
        else:
            return 0

    def is_terminal(self, state):
        return state == -2

    def choose_action(self, state):
        if random.random() < self.epsilon:
            return random.choice(self.actions)
        else:
            return max(self.q_table[state], key=self.q_table[state].get)

    def update_q_value(self, state, action, reward, next_state, next_action):
        old_q_value = self.q_table[state][action]
        next_q_value = self.q_table[next_state][next_action]
        new_q_value = old_q_value + self.alpha * (reward + self.gamma * next_q_value - old_q_value)
        self.q_table[state][action] = new_q_value

    def train(self, episodes):
        for episode in range(episodes):
            state = self.start_state
            action = self.choose_action(state)
            while not self.is_terminal(state):
                next_state = self.get_next_state(state, action)
                reward = self.get_reward(next_state)
                next_action = self.choose_action(next_state)
                self.update_q_value(state, action, reward, next_state, next_action)
                state = next_state
                action = next_action

    def get_best_action(self, state):
        return max(self.q_table[state], key=self.q_table[state].get)

    def print_q_table(self):
        for state in range(-2, 2):
            print(f"State {state}: {self.q_table[state]}")

    def print_optimal_policy(self):
        for state in range(-2, 2):
            if self.is_terminal(state):
                print(f"State {state}: Goal")
            else:
                best_action = self.get_best_action(state)
                print(f"State {state}: Best action = {best_action}")

In [3]:
q = SarsaLearning(1000)
q.train(100)
q.print_q_table()
q.print_optimal_policy()

{-2: {'LEFT': 0.0, 'RIGHT': 0.0}, -1: {'LEFT': 0.0, 'RIGHT': 0.0}, 0: {'LEFT': 0.0, 'RIGHT': 0.0}, 1: {'LEFT': 0.0, 'RIGHT': 0.0}, 2: {'LEFT': 0.0, 'RIGHT': 0.0}}
State -2: {'LEFT': 0.0, 'RIGHT': 0.0}
State -1: {'LEFT': 0.9999734386011123, 'RIGHT': 0.22712169157471837}
State 0: {'LEFT': 0.8965085071648758, 'RIGHT': 0.04013218030692066}
State 1: {'LEFT': 0.2743767040417963, 'RIGHT': 0.0}
State -2: Goal
State -1: Best action = LEFT
State 0: Best action = LEFT
State 1: Best action = LEFT
