In [1]:
import numpy as np
from collections import namedtuple

In [2]:
State = namedtuple('State', ['my_i', 'my_j', 'x', 'other_i', 'other_j', 'a', 'b', 'c', 'd', 'e', 'f'])
# Per-agent state, fields in declaration order (all 0/1 flags unless noted):
# - my_i: 0-based row index of the agent this state describes
# - my_j: 0-based column index of that agent
# - x: 1 if that agent is carrying a block
# - other_i: PDWorld._compute_current_state stores the absolute ROW DISTANCE between
#   the two agents here (not the other agent's row)
# - other_j: likewise, the absolute COLUMN DISTANCE between the two agents
# - a: 1 if pickup location 1 has blocks left
# - b: 1 if pickup location 2 has blocks left
# - c: 1 if dropoff location 1 has capacity left
# - d: 1 if dropoff location 2 has capacity left
# - e: 1 if dropoff location 3 has capacity left
# - f: 1 if dropoff location 4 has capacity left

### PDWorld

In [None]:
class QTable:
    """Q-value table for one PD-World agent.

    Rows are indexed by an encoded State (see ``__encode_state``); columns are the six
    operators in ``OPERATORS`` order: 0='n', 1='s', 2='e', 3='w', 4='p' (pick up),
    5='d' (drop off). The learning rate, discount factor and reward are read from the
    global ``world`` object when the table is updated.
    """

    OPERATORS = ('n', 's', 'e', 'w', 'p', 'd')

    def __init__(self, size: int = 5):
        """Create an all-zero Q-table for a ``size`` x ``size`` board.

        Parameters
        ----------
        size: int
            Board side length (default 5, the board size documented for PDWorld).
            It is taken as an argument instead of being read from the global ``world``,
            because agents (and their Q-tables) are created inside PDWorld.__init__,
            before that global exists.
        """
        self.size = size
        # Mixed-radix place values for the State fields, most significant first:
        # my_i, my_j (size values each), x (2), other_i, other_j (size values each), a..f (2 each).
        self.__multipliers = np.array(
            [size**3 * 128, size**2 * 128, size**2 * 64, size * 64, 64, 32, 16, 8, 4, 2, 1],
            dtype=np.int64,
        )
        # size**4 position/distance combinations times 2**7 for the seven binary flags (x, a..f).
        self.q_table = np.zeros((size**4 * 128, len(self.OPERATORS)))
        self.__operator_dict = {op: idx for idx, op in enumerate(self.OPERATORS)}

    def next_operator(self, current_state: 'State', applicable_operators: list[str], policy: str = 'PRANDOM') -> str:
        """Return the next operator to apply, as a string.

        Parameters
        ----------
        current_state: State
            Named tuple (my_i, my_j, x, other_i, other_j, a, b, c, d, e, f) for this agent.

        applicable_operators: list
            Operators applicable in the current state.

        policy: ['PRANDOM' | 'PEXPLOIT' | 'PGREEDY']
            - 'PRANDOM': a uniformly random applicable operator.
            - 'PEXPLOIT': the best applicable operator with probability 0.8, otherwise a
              random other applicable operator (falls back to the best if there is none).
            - 'PGREEDY': the best applicable operator (ties broken randomly).
            Under every policy, 'p' and then 'd' are chosen whenever they are applicable.

        Returns
        -------
        str
            One of 'n', 's', 'e', 'w', 'p', 'd'.
        """
        assert policy in ['PRANDOM', 'PEXPLOIT', 'PGREEDY'], 'Error: Invalid policy'
        return self.OPERATORS[self.__next_operator(current_state, applicable_operators, policy)]

    def update_q_table(self, previous_state: 'State', operator: str, current_state: 'State', applicable_operators: list[str], policy: str = 'PRANDOM', method: str = 'QL'):
        """Update Q(previous_state, operator) after the transition to current_state.

        Parameters
        ----------
        previous_state: State
            State in which ``operator`` was applied.

        operator: ['n' | 's' | 'e' | 'w' | 'p' | 'd']
            Operator that led from previous_state to current_state.

        current_state: State
            State reached after applying ``operator``.

        applicable_operators: list
            Operators applicable in current_state. QL maximizes over them; SARSA lets
            ``policy`` pick among them.

        policy: ['PRANDOM' | 'PEXPLOIT' | 'PGREEDY']
            Policy used to pick the look-ahead operator (SARSA only).

        method: ['QL' | 'SARSA']
            Update rule.

        Returns
        -------
            None
        """
        assert method in ['QL', 'SARSA']

        row = self.__encode_state(previous_state)
        action = self.__operator_dict[operator]
        next_q = self.q_table[self.__encode_state(current_state)]

        if not applicable_operators:
            # No operator is available from the new state: nothing to bootstrap from.
            future = 0.0
        elif method == 'SARSA':
            # Value of the operator the policy would choose next from the new state.
            future = next_q[self.__next_operator(current_state, applicable_operators, policy)]
        else:  # 'QL'
            future = max(next_q[self.__operator_dict[op]] for op in applicable_operators)

        # NOTE(review): assumes world.reward(position, operator) returns the reward for
        # applying `operator` at `position`; the PDWorld class in this notebook does not
        # define `reward` yet — confirm.
        reward = world.reward((previous_state.my_i, previous_state.my_j), operator)

        # Q <- Q + alpha * (r + gamma * future - Q), i.e. (1 - alpha) * Q + alpha * (r + gamma * future).
        self.q_table[row, action] += world.alpha * (reward + world.gamma * future - self.q_table[row, action])

    def __next_operator(self, current_state: 'State', applicable_operators: list[str], policy: str = 'PRANDOM') -> int:
        """Return the column index (0=n, 1=s, 2=e, 3=w, 4=p, 5=d) of the next operator.

        Only applicable operators are ever returned; see ``next_operator`` for the policies.

        Raises
        ------
        ValueError
            If ``policy`` is not one of the supported policies.
        """
        applicable = [self.__operator_dict[op] for op in applicable_operators]

        # Picking up / dropping off always takes priority when possible.
        if 4 in applicable:
            return 4
        if 5 in applicable:
            return 5

        if policy == 'PRANDOM':
            return int(np.random.choice(applicable))
        if policy not in ('PEXPLOIT', 'PGREEDY'):
            raise ValueError(f'Error: Invalid policy {policy!r}')

        # Restrict the argmax to applicable operators so an illegal move is never chosen.
        q_values = self.q_table[self.__encode_state(current_state)][applicable]
        best_value = q_values.max()
        best = [op for op, q in zip(applicable, q_values) if q == best_value]

        if policy == 'PEXPLOIT' and np.random.rand() >= 0.8:
            others = [op for op in applicable if op not in best]
            if others:
                return int(np.random.choice(others))
            # Every applicable operator ties for best: fall through to exploiting.

        return int(np.random.choice(best))

    def __encode_state(self, state: 'State') -> int:
        """Encode ``state`` into its row index in the Q-table (mixed-radix, my_i most significant)."""
        return int(np.dot(np.asarray(state, dtype=np.int64), self.__multipliers))

    def __decode_index(self, index: int) -> 'State':
        """Decode a Q-table row index back into a State named tuple (inverse of __encode_state)."""
        values = []
        remainder = int(index)
        for place in self.__multipliers:
            digit, remainder = divmod(remainder, int(place))
            values.append(digit)
        return State(*values)


In [None]:
class Agent:
    """One PD-World agent: its position, its copy of the other agent's position, and its Q-table."""

    def __init__(self, q_table_obj=None, agent="F"):
        """
        Arguments:
        - q_table_obj: Q-table this agent should use; when None a fresh QTable() is created
        - agent: "F" for the female agent, anything else for the male agent
        """
        self.identifier = agent
        self.carrying_block = False
        # Honour a supplied Q-table (previously the argument was silently ignored).
        self.q_table = q_table_obj if q_table_obj is not None else QTable()

        # Female agent starts at (1,3) and male at (5,3) in 1-based board coordinates,
        # i.e. (0,2) and (4,2) as 0-based indices.
        is_female = agent == "F"
        self.i, self.j = (0, 2) if is_female else (4, 2)
        self.i_other_agent, self.j_other_agent = (4, 2) if is_female else (0, 2)

    def update_position(self, new_i: int, new_j: int):
        """Move this agent to 0-based cell (new_i, new_j)."""
        self.i = new_i
        self.j = new_j

    def update_other_agent_position(self, new_i: int, new_j: int):
        """Record that the other agent is now at 0-based cell (new_i, new_j)."""
        self.i_other_agent = new_i
        self.j_other_agent = new_j

    def get_distance_to_other_agent(self):
        """Return (|row difference|, |column difference|) to the other agent."""
        return (abs(self.i - self.i_other_agent), abs(self.j - self.j_other_agent))

In [None]:
class PDWorld:
    """Pick-up/drop-off world: two agents (female moves first, then male) on a square board.

    All coordinates stored here are 0-based (row i, column j).
    """

    def __init__(self, size=5, alpha=0.3, gamma=0.5, pickup_capacity=10, dropoff_capacity=5,
                 pickup_locations=None, dropoff_locations=None):
        """
        Arguments (defaults are the values previously documented here):
            - size: side length of the square board (5)
            - alpha: learning rate (0.3)
            - gamma: discount factor (0.5)
            - pickup_capacity: blocks initially available at each pickup location (10)
            - dropoff_capacity: blocks each dropoff location can accept (5)
            - pickup_locations: 0-based (i, j) tuples; default [(2,4), (3,1)],
              i.e. (3,5) and (4,2) in 1-based board coordinates
            - dropoff_locations: 0-based (i, j) tuples; default [(0,0), (0,4), (2,2), (4,4)],
              i.e. (1,1), (1,5), (3,3), (5,5) in 1-based board coordinates
        Agents start at (1,3) (female) and (5,3) (male) in 1-based coordinates.
        _compute_current_state assumes exactly 2 pickup and 4 dropoff locations.
        """
        self.size = size
        self.board = np.zeros((size, size))
        self.alpha = alpha
        self.gamma = gamma
        self.pickup_capacity = pickup_capacity
        self.dropoff_capacity = dropoff_capacity
        # TODO: give each agent an explicit Q-table.
        # NOTE(review): each Agent builds its own QTable(); confirm it is sized for `size`.
        self.female_agent = Agent(None, agent="F")
        self.male_agent = Agent(None, agent="M")  # was "F", which put both agents on one cell
        self.iteration = 0
        self.turn = 0  # 0 if female agent's turn, 1 otherwise

        if pickup_locations is None:
            pickup_locations = [(2, 4), (3, 1)]
        if dropoff_locations is None:
            dropoff_locations = [(0, 0), (0, 4), (2, 2), (4, 4)]
        # dropoff value = blocks delivered so far; pickup value = blocks remaining.
        self.dropoff_locations = {tuple(loc): 0 for loc in dropoff_locations}
        self.pickup_locations = {tuple(loc): pickup_capacity for loc in pickup_locations}

    def setup(self, size: int = None, pickup_locations: list = None, dropoff_locations: list = None):
        """Reset the board and optionally replace the pickup/dropoff locations.

        Arguments:
        - size: size of square board; None keeps the current size
        - pickup_locations: list of 0-based tuples (i,j); each is refilled to pickup_capacity
        - dropoff_locations: list of 0-based tuples (i,j); each starts empty
        """
        if size is not None:
            self.size = size
        self.board = np.zeros((self.size, self.size))

        if pickup_locations:
            self.change_pickup_location(pickup_locations)

        if dropoff_locations:
            self.dropoff_locations.clear()
            for loc in dropoff_locations:
                self.dropoff_locations[tuple(loc)] = 0

    def applicable_operators(self, agent):
        """Return the list of operators `agent` may apply from its current cell.

        A move ('n', 's', 'e', 'w') is applicable when its target cell is on the board and
        not occupied by the other agent. 'd' is applicable when carrying a block at a dropoff
        location with spare capacity. 'p' is applicable when not carrying at a pickup location
        that still has blocks.
        """
        rows, cols = self.board.shape
        other = (agent.i_other_agent, agent.j_other_agent)
        targets = {
            "n": (agent.i - 1, agent.j),
            "s": (agent.i + 1, agent.j),
            "e": (agent.i, agent.j + 1),
            "w": (agent.i, agent.j - 1),
        }
        applicable_ops = [
            op for op, (ti, tj) in targets.items()
            if 0 <= ti < rows and 0 <= tj < cols and (ti, tj) != other
        ]

        here = (agent.i, agent.j)
        if agent.carrying_block and here in self.dropoff_locations \
                and self.dropoff_locations[here] < self.dropoff_capacity:
            applicable_ops.append("d")
        if not agent.carrying_block and self.pickup_locations.get(here, 0) > 0:
            applicable_ops.append("p")
        return applicable_ops

    def apply_operator(self, agent, other_agent, operator):
        """Apply `operator` for `agent`, keeping `other_agent`'s view of its position in sync."""
        assert operator in ["n", "s", "e", "w", "d", "p"], "Error: Unknown Operator"

        steps = {"n": (-1, 0), "s": (1, 0), "e": (0, 1), "w": (0, -1)}
        here = (agent.i, agent.j)
        if operator in steps:
            di, dj = steps[operator]
            agent.update_position(agent.i + di, agent.j + dj)
            other_agent.update_other_agent_position(agent.i, agent.j)
        elif operator == "d":
            agent.carrying_block = False
            self.dropoff_locations[here] += 1
        else:  # "p"
            agent.carrying_block = True
            self.pickup_locations[here] -= 1

    def run(self, steps=500, policy='PRANDOM', method='SARSA'):
        """Play `steps` rounds (female move, then male move), updating each agent's Q-table."""
        for _ in range(steps):
            self._take_turn(self.female_agent, self.male_agent, policy, method)
            self._take_turn(self.male_agent, self.female_agent, policy, method)
            self.iteration += 1

    def _take_turn(self, agent, other_agent, policy, method):
        """Let `agent` choose and apply one operator, then update its Q-table with the transition."""
        prev_state = self._compute_current_state(agent)
        next_op = agent.q_table.next_operator(prev_state, self.applicable_operators(agent), policy=policy)
        self.apply_operator(agent, other_agent, next_op)
        new_state = self._compute_current_state(agent)

        agent.q_table.update_q_table(
            previous_state=prev_state,
            operator=next_op,
            current_state=new_state,
            # The update looks ahead from the NEW state, so pass the operators applicable there.
            applicable_operators=self.applicable_operators(agent),
            policy=policy,
            method=method)

    def _compute_current_state(self, agent=None):
        """Build the State for `agent` (the female agent when None, as before).

        other_i/other_j hold the absolute row/column distance to the other agent. Flags a-b
        come from the pickup locations and c-f from the dropoff locations, in dict order.
        """
        if agent is None:
            agent = self.female_agent
        di, dj = agent.get_distance_to_other_agent()
        pickup_flags = [1 if count > 0 else 0 for count in self.pickup_locations.values()]
        dropoff_flags = [1 if count < self.dropoff_capacity else 0 for count in self.dropoff_locations.values()]

        return State(
            agent.i,
            agent.j,
            int(agent.carrying_block),
            di,
            dj,
            pickup_flags[0],
            pickup_flags[1],
            dropoff_flags[0],
            dropoff_flags[1],
            dropoff_flags[2],
            dropoff_flags[3]
        )

    def change_pickup_location(self, new_pickup_locations: list[tuple]):
        """
        Arguments:
        - new_pickup_locations: list of 0-based tuples (i,j) of new pickup locations on board;
          each is filled to pickup_capacity
        """
        self.pickup_locations.clear()
        for loc in new_pickup_locations:
            self.pickup_locations[tuple(loc)] = self.pickup_capacity

    def save_visual_midrun(self):
        """Placeholder; not implemented yet."""
        pass


### Driver Code

In [None]:
# Global PDWorld instance; QTable refers to it by the name `world`.
# PDWorld's constructor takes these settings explicitly; locations are 0-based (i, j).
world = PDWorld(
    size=5,
    alpha=0.3,
    gamma=0.5,
    pickup_capacity=10,
    dropoff_capacity=5,
    pickup_locations=[(2, 4), (3, 1)],                    # (3,5), (4,2) in 1-based coordinates
    dropoff_locations=[(0, 0), (0, 4), (2, 2), (4, 4)],   # (1,1), (1,5), (3,3), (5,5) in 1-based coordinates
)

In [None]:
# Functions to run experiments.
#
# Locations below are written in 1-based board coordinates (so (5,5) is the bottom-right
# cell of the 5x5 board). PDWorld keys its location dicts by 0-based (i, j), so they are
# converted with _to_board_index before being handed to PDWorld.
#
# NOTE(review):
# - Agent start positions (1,3)/(5,3) are fixed by Agent.__init__. Agents, carried blocks and
#   Q-tables are NOT reset between experiments — confirm whether each needs a fresh world.
# - Only experiment_2 passes `method`; the others use run()'s default ('SARSA') — confirm
#   whether they were meant to use 'QL'.
# - world.summary(), world.display_q_table() and run(total_runs=..., animate=...) are not
#   provided by the PDWorld class defined in this notebook — TODO implement them.

_PICKUPS = ((3, 5), (4, 2))
_DROPOFFS = ((1, 1), (1, 5), (3, 3), (5, 5))


def _to_board_index(locations):
    """Convert 1-based (row, col) board coordinates into 0-based (i, j) indices."""
    return [(i - 1, j - 1) for i, j in locations]


def _configure_world(alpha=0.3, gamma=0.5, pickup_capacity=10, dropoff_capacity=5,
                     pickup_locations=_PICKUPS, dropoff_locations=_DROPOFFS):
    """Apply one experiment's settings to the global `world` on a 5x5 board.

    PDWorld.setup() only accepts a size and location lists, so the learning rate,
    discount factor and capacities are assigned as attributes before calling it.
    """
    world.alpha = alpha
    world.gamma = gamma
    world.pickup_capacity = pickup_capacity
    world.dropoff_capacity = dropoff_capacity
    world.setup(5,
                pickup_locations=_to_board_index(pickup_locations),
                dropoff_locations=_to_board_index(dropoff_locations))


def experiment_1a():
    """500 PRANDOM steps, then 7500 more PRANDOM steps."""
    _configure_world()
    world.run(steps=500, policy='PRANDOM')
    world.run(steps=7500, policy='PRANDOM')
    world.summary()


def experiment_1b():
    """500 PRANDOM steps, then 7500 PGREEDY steps."""
    _configure_world()
    world.run(steps=500, policy='PRANDOM')
    world.run(steps=7500, policy='PGREEDY')
    world.summary()


def experiment_1c():
    """500 PRANDOM steps, then 7500 PEXPLOIT steps."""
    _configure_world()
    world.run(steps=500, policy='PRANDOM')
    world.run(steps=7500, policy='PEXPLOIT')
    world.summary()
    world.display_q_table(agent='male')


def experiment_2():
    """Like experiment_1c, but with the SARSA update requested explicitly."""
    _configure_world()
    world.run(steps=500, policy='PRANDOM', method='SARSA')
    world.run(steps=7500, policy='PEXPLOIT', method='SARSA')
    world.summary()
    world.display_q_table(agent='male')


def experiment_3a():
    """Like experiment_1c with learning rate alpha = 0.15."""
    _configure_world(alpha=0.15)
    world.run(steps=500, policy='PRANDOM')
    world.run(steps=7500, policy='PEXPLOIT')
    world.summary()
    world.display_q_table(agent='male')


def experiment_3b():
    """Like experiment_1c with learning rate alpha = 0.45."""
    _configure_world(alpha=0.45)
    world.run(steps=500, policy='PRANDOM')
    world.run(steps=7500, policy='PEXPLOIT')
    world.summary()
    world.display_q_table(agent='male')


def experiment_4():
    """PEXPLOIT runs before and after moving the pickup locations."""
    _configure_world()
    world.run(steps=500, policy='PRANDOM')
    world.run(total_runs=3, policy='PEXPLOIT', animate=True)
    world.summary()
    world.display_q_table(agent='male')

    # Move the pickups to (1,2) and (4,5) (1-based); change_pickup_location leaves
    # the agents and their Q-tables untouched.
    world.change_pickup_location(_to_board_index([(1, 2), (4, 5)]))
    world.run(total_runs=3, policy='PEXPLOIT', animate=True)
    world.summary()
    world.display_q_table(agent='male')