In [1]:
from abc import ABC, abstractmethod
import numpy as np
import random
import sys

In [2]:
class BaseEnvironment(ABC):
    """Abstract interface that every environment in this notebook implements."""

    @abstractmethod
    def get_states(self):
        """Override this method to return the set of states."""

    @abstractmethod
    def get_actions(self, state):
        """Override this method to return the set of actions available in ``state``."""

    @abstractmethod
    def get_all_actions(self):
        """Override this method to return every action the environment defines."""

    @abstractmethod
    def transition(self, current_state, action):
        """Override this method to return the state reached from ``current_state`` via ``action``.

        The signature takes the state explicitly so that it matches how
        implementations are called: ``transition(current_state, action)``.
        """

    @abstractmethod
    def do_action_and_get_reward(self, action):
        """Override this method to implement action execution.

        Implementations execute ``action`` and report the resulting reward.
        """

In [3]:
class GridWorld(BaseEnvironment):
    """5x5 grid world with two special cells that teleport the agent.

    States are ``(row, col)`` tuples and actions are ``(d_row, d_col)`` offsets.
    Any action taken in ``(0, 1)`` moves the agent to ``(4, 1)`` with reward 10,
    and any action taken in ``(0, 3)`` moves it to ``(2, 3)`` with reward 5.
    Every other move that changes the state yields reward 0; a move that would
    leave the grid keeps the agent in place and yields reward -1.
    """

    UP = (-1,0)
    DOWN = (1,0)
    LEFT = (0,-1)
    RIGHT = (0,1)

    label = {(-1,0):"UP", (1,0):"DOWN", (0,-1):"LEFT", (0,1):"RIGHT"}

    # Grid dimensions and the cell in which the agent starts (and restarts).
    ROWS = 5
    COLS = 5
    START_STATE = (0, 0)

    # Special cells: state -> (destination, reward), applied whatever the action.
    SPECIAL_TRANSITIONS = {
        (0, 1): ((4, 1), 10),
        (0, 3): ((2, 3), 5),
    }

    def __init__(self):
        """Build the state list and the per-state action table, then place the agent at the start."""
        self.states = [(row, col) for row in range(self.ROWS) for col in range(self.COLS)]
        # Set copy of the states for O(1) "is this cell on the grid?" checks.
        self._state_set = frozenset(self.states)
        self.all_actions = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        # Only actions that stay on the grid are offered in a state. Filtering
        # all_actions keeps the RIGHT, LEFT, DOWN, UP order within every list,
        # which matters because greedy policies break ties by list order.
        self.actions = {
            state: [
                action for action in self.all_actions
                if self._shift(state, action) in self._state_set
            ]
            for state in self.states
        }
        self.current_state = self.START_STATE

    @staticmethod
    def _shift(state, action):
        """Return ``state`` displaced by ``action`` without any bounds check."""
        return (state[0] + action[0], state[1] + action[1])

    def get_states(self):
        """Return the list of all ``(row, col)`` states."""
        return self.states

    def get_actions(self, state):
        """Return the list of actions that keep the agent on the grid from ``state``."""
        return self.actions[state]

    def get_all_actions(self):
        """Return the list of the four movement actions."""
        return self.all_actions

    def transition(self, current_state, action):
        """Return the cell reached from ``current_state`` via ``action``.

        If the move would leave the grid, ``current_state`` is returned unchanged.
        """
        new_state = self._shift(current_state, action)
        return new_state if new_state in self._state_set else current_state

    def do_action_and_get_reward(self, action):
        """Execute ``action`` from the current state.

        Updates ``self.current_state`` and returns ``(new_current_state, reward)``.
        """
        # Special cells ignore the action and jump to their fixed destination.
        special = self.SPECIAL_TRANSITIONS.get(self.current_state)
        if special is not None:
            self.current_state, reward = special
            return self.current_state, reward

        new_state = self.transition(self.current_state, action)

        # An unchanged state means the move was blocked by the grid boundary.
        if new_state == self.current_state:
            return self.current_state, -1

        self.current_state = new_state
        return self.current_state, 0

    def reset(self):
        """Move the agent back to the start cell."""
        self.current_state = self.START_STATE

In [4]:
class BasePolicy(ABC):
    """Abstract interface for a policy, i.e. a rule that picks an action for a state."""

    @abstractmethod
    def apply(self, state):
        """Override this method to implement policy application.

        Returns the action given the state.
        """

In [5]:
class EpsilonGreedyPolicy(BasePolicy):
    """Epsilon-greedy action selection over a shared Q-table.

    With probability ``epsilon`` a uniformly random available action is
    returned; otherwise an action with the highest Q-value is returned, with
    ties broken uniformly at random.
    """

    def __init__(self, environment, Q_table, epsilon):
        """Store the collaborators.

        environment -- object providing ``get_actions(state)``
        Q_table     -- mapping ``Q_table[state][action] -> value``; kept by
                       reference, so later updates to the table are seen here
        epsilon     -- exploration probability compared against ``random.random()``
        """
        self.environment = environment
        self.Q_table = Q_table
        self.epsilon = epsilon

    def apply(self, state):
        """Return the action chosen for ``state``."""
        actions = self.environment.get_actions(state)
        if random.random() < self.epsilon:
            # Choose an action at random with probability epsilon
            return random.choice(actions)

        # Choose the best action according to Q_table with probability 1-epsilon.
        # All actions sharing the maximum value are collected so that ties are
        # really broken at random; keeping only the first maximum would bias
        # the agent towards whichever action is listed first.
        best_value = max(
            (self.Q_table[state][action] for action in actions), default=None
        )
        best_actions = [
            action for action in actions
            if self.Q_table[state][action] == best_value
        ]
        # With no available actions this raises IndexError, like random.choice([]).
        return random.choice(best_actions)

In [6]:
class SARSA:
    """Tabular SARSA learner that fills ``Q_table`` by interacting with an environment.

    Each update uses the action actually selected for the next state by the
    epsilon-greedy policy: ``Q(s,a) += alpha * (r + gamma * Q(s',a') - Q(s,a))``.
    """

    def __init__(self, environment, gamma, alpha, epsilon, episodes, steps_per_episode=20):
        """Create the learner and a zero-initialised Q-table.

        environment       -- object providing ``get_states``, ``get_actions``,
                             ``do_action_and_get_reward``, ``reset`` and a
                             ``current_state`` attribute
        gamma             -- discount factor applied to the next state-action value
        alpha             -- step size of each update
        epsilon           -- exploration probability of the learning policy
        episodes          -- number of episodes run by ``apply``
        steps_per_episode -- number of environment steps per episode (default 20)
        """
        self.environment = environment
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.episodes = episodes
        self.steps_per_episode = steps_per_episode
        # Initialize the value of each state-action pair to 0
        self.Q_table = {
            state: {action: 0 for action in environment.get_actions(state)}
            for state in environment.get_states()
        }
        # Use epsilon-greedy policy for learning; it shares this Q_table object,
        # so it always acts on the latest estimates.
        self.policy = EpsilonGreedyPolicy(environment, self.Q_table, epsilon)

    def apply(self):
        """Run ``episodes`` episodes of ``steps_per_episode`` steps, updating ``Q_table`` in place."""
        for _ in range(self.episodes):
            state = self.environment.current_state
            action = self.policy.apply(state)
            for _ in range(self.steps_per_episode):
                next_state, reward = self.environment.do_action_and_get_reward(action)
                next_action = self.policy.apply(next_state)
                # TD error: r + gamma * Q(s', a') - Q(s, a). The grouping
                # (gamma * Q' - Q) is evaluated first to keep the floating-point
                # results identical to the previous formulation.
                td_error = reward + (
                    self.gamma * self.Q_table[next_state][next_action]
                    - self.Q_table[state][action]
                )
                self.Q_table[state][action] += self.alpha * td_error
                state = next_state
                action = next_action
            # Must reset the environment before trying another episode
            self.environment.reset()

In [7]:
# Seed Python's RNG (used by the epsilon-greedy policy) so that a fresh
# "Restart & Run All" reproduces the same learned Q-table.
SEED = 0
random.seed(SEED)

grid_world = GridWorld()
# Hyper-parameters are passed by keyword so each number is self-describing.
sarsa = SARSA(grid_world, gamma=0.9, alpha=0.3, epsilon=0.05, episodes=1000000)

In [8]:
# Run the training loop; this updates sarsa.Q_table in place.
sarsa.apply()

In [9]:
class GreedyPolicy(BasePolicy):
    """Deterministic policy that always picks a highest-valued action from a Q-table."""

    def __init__(self, environment, Q_table):
        """Keep references to the environment (for available actions) and the Q-table."""
        self.Q_table = Q_table
        self.environment = environment

    def apply(self, state):
        """Return the available action with the largest Q-value in ``state``.

        On ties the action listed first wins. ``None`` is returned when no
        action has a value above ``-sys.float_info.max`` (e.g. no actions at all).
        """
        chosen = None
        highest = -1 * sys.float_info.max
        for candidate in self.environment.get_actions(state):
            value = self.Q_table[state][candidate]
            # Strict comparison: a later action must beat, not equal, the best so far.
            if value > highest:
                highest, chosen = value, candidate
        return chosen

In [10]:
# Greedy policy that reads the Q-table held by the SARSA learner.
greedy_policy = GreedyPolicy(grid_world, sarsa.Q_table)

In [11]:
# Print the policy, one grid row per line
generated_policy = []
for ri in range(5):
    row = [GridWorld.label[greedy_policy.apply((ri, co))] for co in range(5)]
    # The trailing space before the newline keeps the printed layout unchanged.
    print(*row, end=' \n')
    generated_policy.append(row)

# Correct policy derived from the Sutton and Barto book.
# Each cell lists every action that is acceptable in that state.
correct_policy = [
    [['RIGHT'], ['UP', 'RIGHT', 'LEFT', 'DOWN'], ['LEFT'], ['UP', 'RIGHT', 'LEFT', 'DOWN'], ['LEFT']],
    [['UP', 'RIGHT'], ['UP'], ['UP', 'LEFT'], ['LEFT'], ['LEFT']],
    [['UP', 'RIGHT'], ['UP'], ['UP', 'LEFT'], ['UP', 'LEFT'], ['UP', 'LEFT']],
    [['UP', 'RIGHT'], ['UP'], ['UP', 'LEFT'], ['UP', 'LEFT'], ['UP', 'LEFT']],
    [['UP', 'RIGHT'], ['UP'], ['UP', 'LEFT'], ['UP', 'LEFT'], ['UP', 'LEFT']]
]
# Compare the two policies cell by cell. Every cell is checked: stopping at the
# first mismatch in a row would skip the rest of that row and under-count errors.
num_errors = 0
for i in range(5):
    for j in range(5):
        if generated_policy[i][j] not in correct_policy[i][j]:
            num_errors += 1
            print(f"The policies do not match at ({i}, {j}).")
print(f"The number of errors is {num_errors}.")


RIGHT LEFT LEFT DOWN LEFT 
RIGHT UP UP UP LEFT 
UP UP LEFT LEFT UP 
UP UP LEFT UP LEFT 
UP UP LEFT UP UP 
The policies do not match at (1, 3).
The number of errors is 1.


In [12]:
# Optimal value function derived from the Sutton and Barto book
optimal_value_function = [
    [22.0, 24.4, 22.0, 19.4, 17.5],
    [19.8, 22.0, 19.8, 17.8, 16.0],
    [17.8, 19.8, 17.8, 16.0, 14.4],
    [16.0, 17.8, 16.0, 14.4, 13.0],
    [14.4, 16.0, 14.4, 13.0, 11.7]
]
print("Optimal value function:")
for reference_row in optimal_value_function:
    print(reference_row)

# The value of a state is taken as the largest learned action-value among the
# actions available in that state; the result is a 5x5 grid of those maxima.
state_value_function = [
    [
        max(
            sarsa.Q_table[(ri, co)][action]
            for action in sarsa.environment.get_actions((ri, co))
        )
        for co in range(5)
    ]
    for ri in range(5)
]
print("\n")
print("Obtained State-value function:")
for learned_row in state_value_function:
    print(learned_row)

Optimal value function:
[22.0, 24.4, 22.0, 19.4, 17.5]
[19.8, 22.0, 19.8, 17.8, 16.0]
[17.8, 19.8, 17.8, 16.0, 14.4]
[16.0, 17.8, 16.0, 14.4, 13.0]
[14.4, 16.0, 14.4, 13.0, 11.7]


Obtained State-value function:
[21.1980094486983, 23.81067852666498, 21.322992434445673, 18.45910190342609, 16.356612384721192]
[19.038142694510267, 21.360445741463632, 19.186434204057612, 16.631831948582104, 14.752233321865821]
[17.04250594499519, 17.793609062540483, 17.045404513591436, 14.951552870830037, 13.137846308963441]
[15.181008074564147, 17.17367444288248, 14.71157826834008, 13.390902475181583, 12.01122145204464]
[13.535517584690703, 15.396906108468707, 13.591948029876452, 11.999342010609228, 10.779859013454503]
