In [3]:
import numpy as np
from collections import namedtuple

In [9]:
# Immutable snapshot of what one agent observes. Field meanings are inferred
# from the names and from QTable's encoding (which treats x and a..f as
# binary digits) -- TODO confirm against the environment code:
#   my_i, my_j       : this agent's grid coordinates
#   x                : binary flag for this agent (presumably "carrying a block")
#   other_i, other_j : the other agent's grid coordinates
#   a .. f           : six binary flags (presumably one per pickup/drop-off
#                      cell; the default layout has 2 + 4 = 6 such cells)
State = namedtuple('State', 'my_i my_j x other_i other_j a b c d e f')

#### PDWorld

In [12]:
class PDWorld:
    """Configuration of the pickup/drop-off grid world.

    Holds the grid side length ``size``, an all-zero ``board`` array of shape
    ``(size, size)``, and the lists of pickup and drop-off cells.

    NOTE(review): the default locations include ``(5, 5)`` on a 5x5 grid, so
    locations appear to be 1-based (row, col) pairs, while ``board`` is a
    0-based numpy array -- convert before indexing ``board``.
    """

    # Grid side length used when no size has been configured.
    DEFAULT_SIZE = 5

    def __init__(self):
        self.size = self.DEFAULT_SIZE
        # Allocate the board here so the attribute exists even when setup()
        # is never called or is called without a size.
        self.board = np.zeros((self.size, self.size))
        self.pickup_locations = [(3, 5), (4, 2)]
        self.dropoff_locations = [(1, 1), (1, 5), (3, 3), (5, 5)]

    def setup(self, size: int = None, pickup_locations: list = None,
              dropoff_locations: list = None):
        """(Re)configure the world; arguments left as ``None`` keep current values.

        Parameters
        ----------
        size : int, optional
            New side length of the square grid. When given, ``board`` is
            re-allocated as a ``(size, size)`` array of zeros.
        pickup_locations : list, optional
            Replacement list of pickup cells (copied).
        dropoff_locations : list, optional
            Replacement list of drop-off cells (copied).

        Raises
        ------
        ValueError
            If ``size`` is given and is smaller than 1.
        """
        if size is not None:
            if size < 1:
                raise ValueError(f"size must be >= 1, got {size!r}")
            self.size = size
            self.board = np.zeros((size, size))
        if pickup_locations is not None:
            # Copy so later mutation of the caller's list doesn't leak in.
            self.pickup_locations = list(pickup_locations)
        if dropoff_locations is not None:
            self.dropoff_locations = list(dropoff_locations)


#### Controller

In [13]:
# Single shared world instance; QTable() reads world.size from it.
world = PDWorld()
world.setup(5)  # 5x5 grid; setup() also allocates world.board

# Experiment 1b pseudocode
# world.setup(size = 5,
#             pickup_locations = [(3,5), (4,2)],
#             dropoff_locations = [(1,1), (1,5), (3,3), (5,5)])
# world.run(steps=500, policy='PRANDOM', alpha=0.3, gamma=0.5)
# world.run(steps=7500, policy='PGREEDY', alpha=0.3, gamma=0.5)
# world.summary()
# world.display_q_table(agent='male', x=0)

# Experiment 4 pseudocode
# world.setup(size = 5,
#             pickup_locations = [(3,5), (4,2)],
#             dropoff_locations = [(1,1), (1,5), (3,3), (5,5)])
# world.run(steps=500, policy='PRANDOM', alpha=0.3, gamma=0.5, method='SARSA', animate=True)
# world.run(total_runs=3, policy='PEXPLOIT', alpha=0.3, gamma=0.5, method='SARSA', animate=True)
# world.setup(pickup_locations = [(1,2), (4,5)])
# world.run(total_runs=3, policy='PEXPLOIT', alpha=0.3, gamma=0.5, method='SARSA', animate=True)
# world.summary()

#### Q-Table

In [None]:
class QTable:
    """Dense Q-value table with one row per encoded state and one column per action.

    States are mapped to rows by :meth:`_encode_state`.
    """

    # Number of action columns in the table.
    N_ACTIONS = 6

    def __init__(self, size: int = None):
        """Allocate an all-zero Q-table.

        Parameters
        ----------
        size : int, optional
            Side length of the square grid. Defaults to the global
            ``world.size`` so existing ``QTable()`` calls keep working.
        """
        if size is None:
            size = world.size
        self.size = size
        # Four grid coordinates (size**4 combinations) times seven binary
        # fields x, a..f (2**7 combinations).
        self.q_table = np.zeros((size**4 * 2**7, self.N_ACTIONS))

    @staticmethod
    def _encode_state(state: 'State', size: int = None) -> int:
        """Encode the given state into its row index in the Q-table.

        The index is a mixed-radix number whose digits, from most to least
        significant, are ``my_i, my_j, x, other_i, other_j, a, b, c, d, e, f``,
        with radix ``size`` for coordinates and 2 for flags. Every valid state
        therefore maps to a distinct row in ``[0, size**4 * 2**7)``.

        Declared as a ``staticmethod`` so it works both as
        ``QTable._encode_state(state)`` and as ``self._encode_state(state)``.

        NOTE(review): PDWorld locations look 1-based; convert coordinates to
        0-based before building the State passed here.

        Parameters
        ----------
        state : State
            Named tuple containing state information. Coordinates must be
            0-based integers in ``[0, size)``; the flag fields must be 0 or 1.
        size : int, optional
            Grid side length; defaults to the global ``world.size``.

        Returns
        -------
        int
            Integer index of the state's row in the Q-table.

        Raises
        ------
        ValueError
            If a coordinate or flag is out of range. Without this check such a
            state would silently alias another state's row or index past the
            end of the table.
        """
        if size is None:
            size = world.size

        for name in ('my_i', 'my_j', 'other_i', 'other_j'):
            value = getattr(state, name)
            if not 0 <= value < size:
                raise ValueError(
                    f"state.{name}={value!r} is outside [0, {size})")
        for name in ('x', 'a', 'b', 'c', 'd', 'e', 'f'):
            value = getattr(state, name)
            if value not in (0, 1):
                raise ValueError(f"state.{name}={value!r} must be 0 or 1")

        # Horner evaluation of the mixed-radix number, most significant first.
        index = state.my_i
        index = index * size + state.my_j
        index = index * 2 + state.x
        index = index * size + state.other_i
        index = index * size + state.other_j
        for flag in (state.a, state.b, state.c, state.d, state.e, state.f):
            index = index * 2 + flag
        return int(index)
    