If you are using Google Colab, uncomment the lines below to upload the data files to your Colab session.

In [None]:
# from google.colab import files
# uploaded = files.upload()
# %ls

# Import libraries
Do not use any other libraries.

In [None]:
from mdp import FiniteStateMDP
from wumpus import _OBJ_KEYS, _OBS_KEYS, WumpusMDP, WumpusState, Actions
from matplotlib import pyplot as plt
import numpy as np

# Load datasets

In [None]:
def get_wumpus_worlds(plot=False):
    """Build the three Wumpus worlds used in this notebook.

    The first world is constructed with ``WumpusMDP(3, 4, -0.1, 10)`` and
    holds a wumpus, a pit and a goal; the other two are constructed with
    ``WumpusMDP(10, 10, -0.1, 20)`` and each hold a single goal, at
    ``[9, 9]`` and ``[5, 5]`` respectively.
    NOTE(review): the meaning of the four constructor arguments is defined
    in ``wumpus.py`` -- presumably grid size, step reward and gold reward;
    confirm there.

    Args:
        plot: When true, call ``display()`` on each world after all three
            have been built.

    Returns:
        A list ``[world_one, world_two, world_three]`` in that order.
    """
    small_world = WumpusMDP(3, 4, -0.1, 10)
    hazards_and_goal = (
        ('wumpus', [1, 2], -100),
        ('pit', [0, 2], -50),
        ('goal', [2, 3], 100),
    )
    for kind, position, reward in hazards_and_goal:
        small_world.add_obstacle(kind, position, reward)

    corner_goal_world = WumpusMDP(10, 10, -0.1, 20)
    corner_goal_world.add_obstacle('goal', [9, 9], 100)

    centre_goal_world = WumpusMDP(10, 10, -0.1, 20)
    centre_goal_world.add_obstacle('goal', [5, 5], 100)

    worlds = [small_world, corner_goal_world, centre_goal_world]
    if plot:
        for world in worlds:
            world.display()
    return worlds

# Tabular Q-learning

In [None]:
class QLearningAgent:
    """Tabular Q-learning agent with epsilon-greedy exploration.

    Q-values and visit counts are stored in dictionaries keyed by
    ``(state.i, action)``; pairs that were never updated read as ``0.0``
    and ``0`` respectively.

    Args:
        mdp: The environment. The agent calls ``mdp.actions_at(state)``
            and ``mdp.is_terminal(state)`` on it.
        discount_factor: Weight applied to the estimated value of the
            next state in the update target.
        exploration_rate: Probability with which ``choose_action`` picks a
            uniformly random action instead of a greedy one.
    """

    def __init__(self, mdp, discount_factor=0.9, exploration_rate=0.2):
        self.mdp = mdp
        self.discount_factor = discount_factor
        self.exploration_rate = exploration_rate
        self.q_values = {}  # key: (state.i, action), value: Q-value
        self.q_visits = {}  # key: (state.i, action), value: visit count

    def get_q_value(self, state, action):
        """Return the stored Q-value for the pair, or 0.0 if never updated."""
        return self.q_values.get((state.i, action), 0.0)

    def get_q_visit_count(self, state, action):
        """Return how many times the pair has been updated (0 if never)."""
        return self.q_visits.get((state.i, action), 0)

    def choose_action(self, state):
        """Pick an action for ``state`` with an epsilon-greedy rule.

        With probability ``exploration_rate`` a uniformly random available
        action is returned; otherwise an action with the highest current
        Q-value is returned, with ties broken uniformly at random.

        Returns:
            One of ``mdp.actions_at(state)``, or ``None`` when that
            collection is empty.
        """
        actions = list(self.mdp.actions_at(state))
        if not actions:
            return None
        if np.random.random() < self.exploration_rate:
            return actions[np.random.randint(len(actions))]
        values = [self.get_q_value(state, action) for action in actions]
        best_value = max(values)
        greedy = [a for a, v in zip(actions, values) if v == best_value]
        # Random tie-breaking keeps the agent from always favouring the
        # first listed action while every Q-value is still at its default.
        return greedy[np.random.randint(len(greedy))]

    def update_q_value(self, state, action, reward, next_state):
        """Apply one Q-learning update for an observed transition.

        The target is ``reward + discount_factor * max_a Q(next_state, a)``,
        where the bootstrap term is 0 if ``next_state`` is terminal or has
        no available actions.
        NOTE(review): the step size is taken as ``1 / (1 + visits)`` of the
        updated pair because the class exposes no explicit learning-rate
        parameter -- confirm this matches the intended schedule.

        Returns:
            The new Q-value stored for ``(state, action)``.
        """
        # Get max Q-value for next state
        if self.mdp.is_terminal(next_state):
            best_next = 0.0
        else:
            best_next = max(
                (self.get_q_value(next_state, a)
                 for a in self.mdp.actions_at(next_state)),
                default=0.0,
            )
        # Update Q-value using learning rate
        key = (state.i, action)
        visits = self.get_q_visit_count(state, action)
        learning_rate = 1.0 / (visits + 1)
        old_q = self.get_q_value(state, action)
        target = reward + self.discount_factor * best_next
        new_q = old_q + learning_rate * (target - old_q)
        self.q_values[key] = new_q
        # Increment visit count for state-action pair
        self.q_visits[key] = visits + 1

        return new_q

# Tabular SARSA

In [None]:
class SARSAgent(QLearningAgent):
    """Tabular SARSA agent.

    Shares storage and action selection with ``QLearningAgent`` but
    bootstraps from the action actually chosen in the next state instead
    of the greedy maximum, which is why ``update_q_value`` takes an extra
    ``next_action`` argument.
    """

    def __init__(self, mdp, discount_factor=0.9, exploration_rate=0.2):
        super().__init__(mdp, discount_factor, exploration_rate)

    def update_q_value(self, state, action, reward, next_state, next_action):
        """Apply one SARSA update for an observed transition.

        The target is
        ``reward + discount_factor * Q(next_state, next_action)``; the
        bootstrap term is 0 when ``next_state`` is terminal or
        ``next_action`` is ``None``.
        NOTE(review): the step size is taken as ``1 / (1 + visits)`` of the
        updated pair because no explicit learning-rate parameter exists --
        confirm this matches the intended schedule.

        Returns:
            The new Q-value stored for ``(state, action)``.
        """
        if next_action is None or self.mdp.is_terminal(next_state):
            next_q = 0.0
        else:
            next_q = self.get_q_value(next_state, next_action)
        # Update Q-value using learning rate
        key = (state.i, action)
        visits = self.get_q_visit_count(state, action)
        learning_rate = 1.0 / (visits + 1)
        old_q = self.get_q_value(state, action)
        target = reward + self.discount_factor * next_q
        new_q = old_q + learning_rate * (target - old_q)
        self.q_values[key] = new_q
        # Increment visit count for state-action pair
        self.q_visits[key] = visits + 1
        return new_q

# Training algorithm

In [None]:
def train_agent(agent, episodes):
    """Skeleton of the episode loop for training a tabular agent.

    Runs ``episodes`` iterations; each is meant to step through the
    environment until ``agent.mdp.is_terminal(state)`` is true, applying a
    SARSA update for ``SARSAgent`` instances and a Q-learning update for
    any other agent.

    NOTE(review): this is an unfinished template. ``state`` is read by the
    ``while`` condition but never assigned, so calling the function raises
    ``NameError``; both update branches are ``pass`` and ``state`` is never
    advanced. The MDP calls needed to obtain the initial state and to
    sample a transition and reward are not visible in this file.

    Args:
        agent: Agent exposing an ``mdp`` attribute.
        episodes: Number of episodes to run.
    """
    for episode in range(episodes):
        # Initialize state and first action
        # TODO: `state` (and the first action) must be assigned here.
        while not agent.mdp.is_terminal(state):
            # Take action and observe next state and reward
            # Q-learning and SARSA agents differ here.
            if isinstance(agent, SARSAgent):
                # SARSA update requires next action
                pass
            else:
                # Q-learning update does not require next action
                pass
            # Move to next state
            # TODO: without reassigning `state` here the loop cannot end.

## Plotting helper

In [None]:
def pyplot_policy(agent):
    """Draw the agent's greedy policy on a colour-coded map of its world.

    Every grid cell is filled with a colour describing what occupies it
    and labelled with an arrow (or pick-up symbol) for the action with the
    highest Q-value; terminal states are labelled ``'T'``. Ties between
    actions resolve to the first one returned by ``mdp.actions_at``.

    Args:
        agent: Object exposing ``mdp`` (with ``height``, ``width``,
            ``states``, ``is_terminal``, ``actions_at``, ``obs_at`` and
            ``obj_at``) and ``get_q_value(state, action)``.
    """
    mdp = agent.mdp
    height, width = mdp.height, mdp.width
    policy_grid = np.full((height, width), '', dtype=object)
    world_grid = np.full((height, width), 'empty', dtype=object)
    action_symbols = {
        Actions.UP: '↑',
        Actions.DOWN: '↓',
        Actions.LEFT: '←',
        Actions.RIGHT: '→',
        Actions.PICK_UP: '⧉',
    }
    # Extract policy
    for state in mdp.states:
        if mdp.is_terminal(state):
            policy_grid[state.y, state.x] = 'T'
            continue
        q_values = {
            action: agent.get_q_value(state, action)
            for action in mdp.actions_at(state)
        }
        # max() returns the first maximal key, i.e. the first best action.
        best_action = max(q_values, key=q_values.get)
        policy_grid[state.y, state.x] = action_symbols[best_action]
    # Extract world layout for visualization; objects are written after
    # obstacles, so an object overrides an obstacle in the same cell.
    for state in mdp.states:
        pos = (state.y, state.x)
        for obs in _OBS_KEYS:
            if mdp.obs_at(obs, state.pos):
                world_grid[pos] = obs
        for obj in _OBJ_KEYS:
            if mdp.obj_at(obj, state.pos):
                world_grid[pos] = obj
    # Map world elements to colors
    color_map = {
        'empty': 'white',
        'wumpus': 'black',
        'pit': 'brown',
        'goal': 'gold',
        'gold': 'yellow',
        'immune': 'cyan',
    }
    cell_colors = np.vectorize(color_map.get)(world_grid)
    # Plotting
    plt.figure(figsize=(width, height))
    ax = plt.gca()
    for y in range(height):
        # The row is mirrored here and the y-axis is inverted further
        # down; together they place row y == 0 at the bottom of the figure.
        reverse_y = height - y - 1
        for x in range(width):
            # Use facecolor/edgecolor explicitly: passing ``color`` makes
            # matplotlib override the edge colour, hiding the cell borders.
            ax.add_patch(plt.Rectangle(
                (x, reverse_y), 1, 1,
                facecolor=cell_colors[y, x], edgecolor='black',
            ))
            ax.text(x + 0.5, reverse_y + 0.5, policy_grid[y, x],
                    ha='center', va='center', fontsize=20)
    plt.xlim(0, width)
    plt.ylim(0, height)
    ax.invert_yaxis()
    plt.xticks(np.arange(width + 1))
    plt.yticks(np.arange(height + 1))
    plt.grid(False)
    plt.show()

# Run the model and visualize the results