In [17]:
# TODO: remove unused imports once the code is complete
import sys
import numpy as np
import random
import matplotlib.pyplot as plt
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from mpl_toolkits.mplot3d import Axes3D
import ipywidgets as widgets
import time
import mpl_toolkits

In [18]:
# Colour constants as (R, G, B) tuples
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)

# Screen dimensions (presumably pixels)
SCREEN_WIDTH = SCREEN_HEIGHT = 800

In [19]:
class Agent:
    """An agent that moves through a bounded 3-D grid world and carries one item at a time.

    Parameters
    ----------
    name : str
        Label used in the messages printed by ``pickup``/``dropoff``.
    position : tuple
        Starting ``(x, y, z)`` cell.
    q_table : dict or None
        Q-values read by ``get_next_action`` with keys
        ``(position, carrying, action)``; ``None`` treats every Q-value as 0.
        NOTE(review): the ``QTable`` class in this notebook exposes
        ``get(state, action)`` rather than dict-style ``get(key, default)``;
        pass a plain dict here or adapt ``_q_value``.
    alpha, gamma, epsilon : float
        Learning hyper-parameters. Stored, but not read by this class.
    world : object
        The world the agent acts in. Its ``size`` attribute, when present,
        bounds movement; without it a ``(3, 3, 3)`` world is assumed.
        ``pickup``/``dropoff`` call ``get_item_at_position``,
        ``is_item_pickupable``, ``remove_item_from_position``,
        ``is_item_dropoffable`` and ``add_item_to_position`` on it.
        NOTE(review): ``PDWorld`` as defined in this notebook does not
        implement those methods -- confirm which world API is intended.
    """

    def __init__(self, name, position, q_table, alpha, gamma, epsilon, world):
        self.name = name
        self.position = position
        self.q_table = q_table
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.world = world
        # Bound movement by the world's actual size so that resizing the world
        # also resizes the region the agent may move in.
        self.world_bounds = tuple(getattr(world, 'size', (3, 3, 3)))

        # Per-episode bookkeeping, cleared by reset().
        self.prev_state = None
        self.prev_action = None
        self.prev_reward = None
        self.carrying = None
        # Action names most recently reported by the world in get_next_action().
        self.valid_actions = []

    def reset(self):
        """Clear the remembered state/action/reward and drop whatever is carried."""
        self.prev_reward = None
        self.prev_action = None
        self.prev_state = None
        self.carrying = None

    def _move(self, dx, dy, dz):
        """Step by (dx, dy, dz); the agent stays put if the target cell is out of bounds."""
        new_pos = (self.position[0] + dx, self.position[1] + dy, self.position[2] + dz)
        if self.is_valid_position(new_pos):
            self.position = new_pos

    def North(self):
        """Move one cell in +y (no-op at the boundary)."""
        self._move(0, 1, 0)

    def East(self):
        """Move one cell in +x (no-op at the boundary)."""
        self._move(1, 0, 0)

    def South(self):
        """Move one cell in -y (no-op at the boundary)."""
        self._move(0, -1, 0)

    def West(self):
        """Move one cell in -x (no-op at the boundary)."""
        self._move(-1, 0, 0)

    def Up(self):
        """Move one cell in +z (no-op at the boundary)."""
        self._move(0, 0, 1)

    def Down(self):
        """Move one cell in -z (no-op at the boundary)."""
        self._move(0, 0, -1)

    def is_valid_position(self, position):
        """Return True if every coordinate lies in ``[0, bound)`` for the world bounds."""
        return all(0 <= coord < bound for coord, bound in zip(position, self.world_bounds))

    def pickup(self):
        """Pick up the item at the current cell, printing why if that is not possible."""
        if self.carrying is not None:
            print(f"{self.name} cannot pick up {self.carrying} as it is already carrying {self.carrying}")
            return

        item = self.world.get_item_at_position(self.position)
        if item is None:
            print(f"{self.name} cannot pick up an item as there is no item at its current position {self.position}")
            return

        if not self.world.is_item_pickupable(item):
            print(f"{self.name} cannot pick up {item} as it is not currently pickupable")
            return

        self.carrying = item
        self.world.remove_item_from_position(self.position)

        print(f"{self.name} picked up {item} at {self.position}")

    def dropoff(self):
        """Drop the carried item at the current cell, printing why if that is not possible."""
        if self.carrying is None:
            print(f"{self.name} cannot drop off an item as it is not currently carrying anything")
            return

        if not self.world.is_item_dropoffable(self.carrying):
            print(f"{self.name} cannot drop off {self.carrying} as it is not currently droppable")
            return

        # Remember the item before clearing `carrying`, so the message names it.
        item = self.carrying
        self.world.add_item_to_position(self.position, item)
        self.carrying = None

        print(f"{self.name} dropped off {item} at {self.position}")

    def _q_value(self, action):
        """Q-value of taking ``action`` from the current (position, carrying) state; 0 if unknown."""
        if self.q_table is None:
            return 0
        return self.q_table.get((self.position, self.carrying, action), 0)

    def _best_action(self):
        """Return a highest-Q action among ``self.valid_actions``, breaking ties at random."""
        q_values = {action: self._q_value(action) for action in self.valid_actions}
        max_q_value = max(q_values.values())
        best_actions = [action for action, q_value in q_values.items() if q_value == max_q_value]
        return random.choice(best_actions)

    def get_next_action(self, world, policy, exploit_prob=0.85):
        """Choose the name of the next action under ``policy``.

        Parameters
        ----------
        world : object
            Must provide ``get_valid_actions(position, carrying)`` returning a
            list of action names (e.g. ``'north'``, ``'pickup'``).
        policy : str
            ``'PRANDOM'``: take pickup/dropoff when available, else a random
            valid action. ``'PEXPLOIT'``: with probability ``exploit_prob``
            take a highest-Q action (ties broken randomly), otherwise a random
            valid action. ``'PGREEDY'``: always take a highest-Q action (ties
            broken randomly). Any other value: a random valid action.
        exploit_prob : float
            Exploitation probability used by ``'PEXPLOIT'``.

        Returns
        -------
        str or None
            One of the world's valid action names, or ``None`` if it reports none.
        """
        # get_valid_actions returns a list of names, so test membership on it
        # rather than unpacking it into a fixed number of flags.
        self.valid_actions = list(world.get_valid_actions(self.position, self.carrying))
        if not self.valid_actions:
            return None

        if policy == 'PRANDOM':
            transfer_actions = [a for a in ('pickup', 'dropoff') if a in self.valid_actions]
            if transfer_actions:
                return random.choice(transfer_actions)
        elif policy == 'PEXPLOIT':
            if random.random() < exploit_prob:
                return self._best_action()
        elif policy == 'PGREEDY':
            return self._best_action()

        # Unknown policy, PRANDOM with no pickup/dropoff available, or the
        # exploration branch of PEXPLOIT: pick any valid action at random.
        return random.choice(self.valid_actions)


In [20]:
class PDWorld:
    """Pickup/dropoff world: a bounded 3-D grid holding blocks and dropoff cells.

    Parameters
    ----------
    size : tuple of int
        Number of cells along (x, y, z); valid coordinates satisfy 0 <= c < size.
    blocks : dict, optional
        Maps (x, y, z) positions to block values. Stored by reference (not copied).
    dropoffs : dict, optional
        Maps (x, y, z) dropoff positions to the value a carried block must
        equal for a dropoff there to be valid. Stored by reference.
    """

    # Movement action names and their (dx, dy, dz) offsets, in the order
    # get_valid_actions reports them.
    _MOVES = (
        ('north', (0, 1, 0)),
        ('east', (1, 0, 0)),
        ('south', (0, -1, 0)),
        ('west', (-1, 0, 0)),
        ('up', (0, 0, 1)),
        ('down', (0, 0, -1)),
    )

    def __init__(self, size=(3, 3, 3), blocks=None, dropoffs=None):
        self.size = size
        self.blocks = {} if blocks is None else blocks
        self.dropoffs = {} if dropoffs is None else dropoffs

    def is_valid_position(self, position):
        """Return True if every coordinate lies in ``[0, size)`` along its axis."""
        return all(0 <= coord < bound for coord, bound in zip(position, self.size))

    def get_block(self, position):
        """Return the block value at ``position``, or None if there is none."""
        return self.blocks.get(position)

    def set_block(self, position, value):
        """Place (or overwrite) a block with ``value`` at ``position``."""
        self.blocks[position] = value

    def remove_block(self, position):
        """Remove the block at ``position``; does nothing if there is none."""
        self.blocks.pop(position, None)

    def get_valid_actions(self, agent_pos, agent_holding):
        """List the action names available to an agent at ``agent_pos``.

        Always includes ``'stay'``, then each in-bounds move in the order
        north, east, south, west, up, down; then ``'pickup'`` when the agent
        holds nothing and a block is at its cell, and ``'dropoff'`` when the
        cell is a dropoff whose value equals what the agent holds.
        """
        valid_actions = ['stay']

        for name, (dx, dy, dz) in self._MOVES:
            target = (agent_pos[0] + dx, agent_pos[1] + dy, agent_pos[2] + dz)
            if self.is_valid_position(target):
                valid_actions.append(name)

        if not agent_holding and agent_pos in self.blocks:
            valid_actions.append('pickup')

        if agent_holding and agent_pos in self.dropoffs and self.dropoffs[agent_pos] == agent_holding:
            valid_actions.append('dropoff')

        return valid_actions

    def plot_world(self, agents=None, backend='notebook'):
        """Show the world as a 3-D orthographic scatter plot.

        Blocks are drawn as squares (red when their value is None, blue
        otherwise), dropoffs as green circles and agents as black triangles.
        Drag with the mouse to rotate; matplotlib's 3-D axes provide that
        rotation themselves, so no custom event handlers are installed.

        Parameters
        ----------
        agents : iterable, optional
            Objects with a ``position`` attribute to draw.
        backend : str or None
            matplotlib backend selected through IPython's ``%matplotlib``
            magic before drawing (default ``'notebook'``). Pass ``None`` to
            keep the current backend. If the cell output is only
            ``<IPython.core.display.Javascript object>``, the frontend does not
            support the ``'notebook'`` backend; consider ``'widget'``
            (requires the ipympl package).
        """
        if backend is not None:
            # Plain-Python equivalent of `%matplotlib <backend>`: the `%` magic
            # syntax only works when IPython preprocesses the cell. Outside
            # IPython `get_ipython` is undefined and the backend is left as is.
            try:
                shell = get_ipython()  # noqa: F821 -- provided by IPython at runtime
            except NameError:
                shell = None
            if shell is not None:
                shell.run_line_magic('matplotlib', backend)

        fig = plt.figure(figsize=(10, 10), dpi=100)
        ax = fig.add_subplot(111, projection='3d', proj_type='ortho')

        # Grid coordinates are integers, so only label integer ticks.
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.set_major_locator(plt.MaxNLocator(integer=True))

        if self.blocks:
            block_xs, block_ys, block_zs = zip(*self.blocks.keys())
            block_colors = ['r' if value is None else 'b' for value in self.blocks.values()]
            ax.scatter(block_xs, block_ys, block_zs, c=block_colors, marker='s', antialiased=True)

        if self.dropoffs:
            dropoff_xs, dropoff_ys, dropoff_zs = zip(*self.dropoffs.keys())
            ax.scatter(dropoff_xs, dropoff_ys, dropoff_zs, c='g', marker='o', antialiased=True)

        if agents is not None:
            agent_positions = [agent.position for agent in agents]
            if agent_positions:
                agent_xs, agent_ys, agent_zs = zip(*agent_positions)
                ax.scatter(agent_xs, agent_ys, agent_zs, c='k', marker='^', antialiased=True)

        # Give every axis the same range, 20% wider than the largest world
        # dimension and centred on the world, so markers on the boundary are
        # not clipped. Done before plt.show() so inline backends render it too.
        half_range = 1.2 * max(self.size) / 2
        for set_lim, extent in zip((ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d), self.size):
            centre = extent / 2
            set_lim(centre - half_range, centre + half_range)

        plt.show()



In [21]:
class QTable:
    """Sparse Q-value store keyed by ``(state, action)``; missing entries read as 0.0."""

    def __init__(self):
        # (state, action) -> Q-value
        self.table = {}

    def get(self, state, action):
        """Return the stored Q-value for ``(state, action)``, or 0.0 if none was recorded."""
        key = (state, action)
        return self.table[key] if key in self.table else 0.0

    def update(self, state, action, value):
        """Overwrite the Q-value stored for ``(state, action)`` with ``value``."""
        self.table[(state, action)] = value

    def get_best_action(self, state, actions):
        """Return an action from ``actions`` with the highest Q-value in ``state``.

        Duplicate entries in ``actions`` are considered once; ties are broken
        uniformly at random with ``random.choice``. Raises ValueError when
        ``actions`` is empty.
        """
        scores = dict((action, self.get(state, action)) for action in actions)
        top_score = max(scores.values())
        leaders = [action for action, score in scores.items() if score == top_score]
        return random.choice(leaders)

In [22]:
# Additional functions
def prevent_collision(agent1, agent2):
    """Nudge ``agent1`` one step in a random direction if it shares a cell with ``agent2``.

    The direction is drawn uniformly from the six movement methods
    (North, East, South, West, Up, Down) and invoked on ``agent1``; whether the
    step actually happens is up to that method, so the agents may still
    overlap afterwards. ``agent2`` is only read, never moved.
    """
    if not (agent1.position == agent2.position):
        return
    direction = random.choice(['North', 'East', 'South', 'West', 'Up', 'Down'])
    getattr(agent1, direction)()

def reward(prev_state, prev_action, state, blocks):
    """Score the transition into ``state`` caused by ``prev_action``.

    ``state`` (and ``prev_state``, when not None) are dicts read through the
    keys ``'valid_actions'``, ``'agent_holding'`` and ``'agent_position'``;
    ``blocks`` maps positions to block values.

    Returns
    -------
    int
        -1  if ``prev_action`` is not in ``state['valid_actions']``;
        +14 / -5  if the agent holds something and the block value at its
            position does / does not equal what it holds;
        otherwise -1, adjusted when a non-None block sits at the agent's
        position: -5 more if ``prev_state`` is given and its position's block
        value differs from ``state['agent_holding']``, else +14.
    """
    if prev_action not in state['valid_actions']:
        # Invalid action (e.g. an out-of-bounds move).
        return -1

    holding = state['agent_holding']
    position = state['agent_position']

    if holding is not None:
        return 14 if blocks.get(position) == holding else -5

    # Not holding anything: base step cost, plus a pickup-related adjustment.
    total = -1
    if position in blocks and blocks[position] is not None:
        unnecessary = prev_state is not None and blocks.get(prev_state['agent_position']) != holding
        total += -5 if unnecessary else 14
    return total



In [23]:
# Initialize the world and the two agents
WORLD_SIZE = 3  # edge length of the cubic world; alter as necessary
world = PDWorld(size=(WORLD_SIZE,) * 3)

# Both agents share hyper-parameters and world; only name and start cell differ.
# NOTE(review): agentF starts at (2, 2, 2), which needs WORLD_SIZE >= 3.
agentM = Agent(name='M', position=(1, 1, 1), q_table=None,
               alpha=0.1, gamma=0.9, epsilon=0.1, world=world)
agentF = Agent(name='F', position=(2, 2, 2), q_table=None,
               alpha=0.1, gamma=0.9, epsilon=0.1, world=world)
agents = [agentM, agentF]

In [24]:
# Render the initial world state with both agents.
world.plot_world(agents=agents)

<IPython.core.display.Javascript object>

In [24]:
# Experiment 1: