In [3]:
from collections import namedtuple
from typing import Optional

import numpy as np

In [9]:
# Immutable state record.  my_i/my_j and other_i/other_j are grid coordinates
# (presumably this agent's and the other agent's position -- confirm); x and
# a-f are weighted as single bits (0/1) by the Q-table state encoding.
State = namedtuple('State', 'my_i my_j x other_i other_j a b c d e f')

#### PDWorld

In [12]:
class PDWorld:
    """Configuration of the pickup/dropoff (PD) grid world.

    The defaults are the settings every experiment in this notebook starts
    from; ``setup`` overrides any subset of them.
    """

    def __init__(self):
        self.size = 5
        # Learning rate and discount factor must be plain floats (a trailing
        # comma here would turn them into 1-tuples and break arithmetic).
        self.alpha = 0.3
        self.gamma = 0.5
        self.pickup_capacity = 10
        self.dropoff_capacity = 5
        self.agent_start_locations = [(1, 3), (5, 3)]
        self.pickup_locations = [(3, 5), (4, 2)]
        self.dropoff_locations = [(1, 1), (1, 5), (3, 3), (5, 5)]
        self.board = np.zeros((self.size, self.size))

    def setup(self, size: Optional[int] = None, alpha: Optional[float] = None,
              gamma: Optional[float] = None, pickup_locations: Optional[list] = None,
              dropoff_locations: Optional[list] = None,
              pickup_capacity: Optional[int] = None,
              dropoff_capacity: Optional[int] = None,
              agent_start_locations: Optional[list] = None):
        """Override any subset of the world settings.

        Every argument defaults to ``None``, meaning "keep the current
        value", so callers may pass only what they want to change
        (e.g. ``setup(pickup_locations=[(1, 2), (4, 5)])``).  Falsy but
        meaningful values such as ``alpha=0`` are applied.

        Parameters
        ----------
        size : int, optional
            Grid side length; the board is re-created with this shape.
        alpha, gamma : float, optional
            Learning rate and discount factor.
        pickup_locations, dropoff_locations : list, optional
            Grid coordinates of pickup / dropoff cells (copied).
        pickup_capacity, dropoff_capacity : int, optional
            Capacities passed by the experiment functions.
        agent_start_locations : list, optional
            Starting coordinates of the agents (copied).
        """
        if size is not None:
            self.size = size
            # Re-create the board so its shape always matches the grid size.
            self.board = np.zeros((size, size))

        scalar_overrides = {
            'alpha': alpha,
            'gamma': gamma,
            'pickup_capacity': pickup_capacity,
            'dropoff_capacity': dropoff_capacity,
        }
        for name, value in scalar_overrides.items():
            if value is not None:
                setattr(self, name, value)

        list_overrides = {
            'pickup_locations': pickup_locations,
            'dropoff_locations': dropoff_locations,
            'agent_start_locations': agent_start_locations,
        }
        for name, value in list_overrides.items():
            if value is not None:
                # Copy so later mutation of the caller's list cannot leak in.
                setattr(self, name, list(value))


#### Driver

In [None]:
# Global PDWorld instance.  The experiment functions configure it through
# `world.setup(...)`, and QTable reads its size, alpha and gamma through this
# module-level name, so it must be created before either is used.
world = PDWorld()

In [1]:
# functions to run experiments

def _setup_standard_world(alpha=0.3):
    """Configure the global `world` with the settings shared by every experiment.

    Only the learning rate differs between experiments; the lists are
    built fresh on each call so no experiment can mutate another's config.
    """
    world.setup(size=5,
                alpha=alpha,
                gamma=0.5,
                pickup_capacity=10,
                dropoff_capacity=5,
                agent_start_locations=[(1, 3), (5, 3)],
                pickup_locations=[(3, 5), (4, 2)],
                dropoff_locations=[(1, 1), (1, 5), (3, 3), (5, 5)])


def experiment_1a():
    """Experiment 1a: 500 PRANDOM steps, then 7500 more PRANDOM steps."""
    _setup_standard_world()
    world.run(steps=500, policy='PRANDOM')
    world.run(steps=7500, policy='PRANDOM')
    world.summary()


def experiment_1b():
    """Experiment 1b: 500 PRANDOM steps, then 7500 PGREEDY steps."""
    _setup_standard_world()
    world.run(steps=500, policy='PRANDOM')
    world.run(steps=7500, policy='PGREEDY')
    world.summary()


def experiment_1c():
    """Experiment 1c: 500 PRANDOM steps, then 7500 PEXPLOIT steps; shows the male agent's Q-table."""
    _setup_standard_world()
    world.run(steps=500, policy='PRANDOM')
    world.run(steps=7500, policy='PEXPLOIT')
    world.summary()
    world.display_q_table(agent='male')


def experiment_2():
    """Experiment 2: like 1c, but both runs use the SARSA update method."""
    _setup_standard_world()
    world.run(steps=500, policy='PRANDOM', method='SARSA')
    world.run(steps=7500, policy='PEXPLOIT', method='SARSA')
    world.summary()
    world.display_q_table(agent='male')


def experiment_3a():
    """Experiment 3a: experiment 1c with learning rate alpha = 0.15."""
    _setup_standard_world(alpha=0.15)
    world.run(steps=500, policy='PRANDOM')
    world.run(steps=7500, policy='PEXPLOIT')
    world.summary()
    world.display_q_table(agent='male')


def experiment_3b():
    """Experiment 3b: experiment 1c with learning rate alpha = 0.45."""
    _setup_standard_world(alpha=0.45)
    world.run(steps=500, policy='PRANDOM')
    world.run(steps=7500, policy='PEXPLOIT')
    world.summary()
    world.display_q_table(agent='male')


def experiment_4():
    """Experiment 4: PEXPLOIT for 3 runs, then move the pickups and run 3 more.

    The second phase only changes the pickup locations (via a partial
    `world.setup` call) instead of repeating the full setup.
    """
    _setup_standard_world()
    world.run(steps=500, policy='PRANDOM')
    world.run(total_runs=3, policy='PEXPLOIT', animate=True)
    world.summary()
    world.display_q_table(agent='male')

    # Relocate the pickup cells; all other settings stay as configured above.
    world.setup(pickup_locations=[(1, 2), (4, 5)])
    world.run(total_runs=3, policy='PEXPLOIT', animate=True)
    world.summary()
    world.display_q_table(agent='male')


#### Q-Table

In [None]:
class QTable:
    """Q-value table for the PD world.

    Rows are indexed by the integer encoding of a ``State`` (see
    ``_encode_state``); the 6 columns are operator indices.
    """

    def __init__(self):
        # Remember the grid size the table was built for, so state encoding
        # always agrees with the table's row count even if `world.size`
        # is changed later by `world.setup`.
        self.size = world.size
        # size**4 combinations of the four coordinates times 2**7 for the
        # seven 0/1 fields (x, a-f); 6 operator columns.
        self.q_table = np.zeros((self.size**4 * 2**7, 6))

    def _encode_state(self, state: 'State') -> int:
        """Encode the given state into its row index in the Q-table.

        The encoding is mixed-radix: each coordinate is a base-``size``
        digit and each of x, a-f is a single bit, which yields a unique row
        in ``[0, size**4 * 2**7)`` for every valid state.

        Parameters
        ----------
        state : State
            Named tuple containing state information.  Coordinates must be
            0-based (``0 <= coord < size``) and flags must be 0 or 1.
            NOTE(review): the world's location lists look 1-based -- confirm
            callers convert before encoding.

        Returns
        -------
        int
            Integer index of the state in the Q-table.

        Raises
        ------
        ValueError
            If a coordinate or flag is out of range; such values would
            otherwise silently alias another state's row or overflow the table.
        """
        size = self.size
        coords = (state.my_i, state.my_j, state.other_i, state.other_j)
        flags = (state.x, state.a, state.b, state.c, state.d, state.e, state.f)
        if not all(0 <= c < size for c in coords) or not all(f in (0, 1) for f in flags):
            raise ValueError(f"state out of range for a {size}x{size} grid: {state}")
        return (
            state.my_i * size**3 * 2**7 +
            state.my_j * size**2 * 2**7 +
            state.x * size**2 * 2**6 +
            state.other_i * size * 2**6 +
            state.other_j * 2**6 +
            state.a * 2**5 +
            state.b * 2**4 +
            state.c * 2**3 +
            state.d * 2**2 +
            state.e * 2 +
            state.f)

    def next_operator(self, current_state: 'State', method: str = 'QL', policy: str = 'PRANDOM'):
        """Choose the next operator for ``current_state`` (not implemented yet).

        Raises
        ------
        NotImplementedError
            Always.  The policy logic (PRANDOM / PGREEDY / PEXPLOIT, the
            names used by the experiment functions) has not been written.
        """
        # TODO: restrict to world.applicable_operators(current_state) and pick
        # one according to `policy`.
        raise NotImplementedError("next_operator is not implemented yet")

    def _max_q(self, row: int, operators) -> float:
        """Return the largest Q-value in ``row`` among the given operator indices.

        Returns 0.0 when ``operators`` is empty (no future value to bootstrap from).
        """
        operators = list(operators)
        if not operators:
            return 0.0
        return float(self.q_table[row, operators].max())

    def _update_q_table(self, current_state: 'State', action: int, next_state: 'State',
                        method: str = 'QL', next_action: Optional[int] = None):
        """Apply one temporal-difference update to Q(current_state, action).

        Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * target)

        ``target`` is ``max_{a'} Q(s', a')`` over the operators applicable in
        ``next_state`` for Q-learning (``method='QL'``, the default), and
        ``Q(s', next_action)`` for ``method='SARSA'``.

        Parameters
        ----------
        current_state, next_state : State
            States before and after taking ``action``.
        action : int
            Operator column index; indices below 4 receive ``world.penalty``,
            the others ``world.reward``.
        method : str
            ``'SARSA'`` or anything else for Q-learning.
        next_action : int, optional
            Operator actually chosen in ``next_state``; required for SARSA.

        Raises
        ------
        ValueError
            If ``method == 'SARSA'`` and ``next_action`` is None.
        """
        row = self._encode_state(current_state)
        next_row = self._encode_state(next_state)

        if method == 'SARSA':
            if next_action is None:
                raise ValueError("SARSA update requires next_action")
            future = self.q_table[next_row, next_action]
        else:
            # Q-learning bootstraps from the best *Q-value* among applicable
            # operators, not from the largest operator index.
            future = self._max_q(next_row, world.applicable_operators(next_state))

        reward = world.penalty if action < 4 else world.reward
        self.q_table[row, action] = (
            (1 - world.alpha) * self.q_table[row, action] +
            world.alpha * (reward + world.gamma * future)
        )
