# Environment

# In [22]:
import torch
import copy
import numpy as np
import numpy.random as npr

# In [23]:
# Define the grid world environment
class GridWorldEnvironment:
    """Square grid world in which agents shuttle between cells A and B.

    B is fixed at the bottom-right corner ``(size - 1, size - 1)``; A is drawn
    at random on every reset and is never equal to B.  Each agent carries a
    ``reached_A`` flag: while it is ``False`` the agent's destination is A
    (pickup), once it is ``True`` the destination is B (delivery).  Positions
    are ``(y, x)`` tuples with ``(0, 0)`` in the top-left corner.
    """

    # (dy, dx) displacement for every supported direction name.
    _MOVES = {
        "up": (-1, 0),
        "down": (1, 0),
        "left": (0, -1),
        "right": (0, 1),
    }

    # (dy, dx) offsets of the eight cells surrounding an agent, row by row.
    _NEIGHBOUR_OFFSETS = (
        (-1, -1),
        (-1, 0),
        (-1, 1),
        (0, -1),
        (0, 1),
        (1, -1),
        (1, 0),
        (1, 1),
    )

    def __init__(self, size=5, agents_num=4):
        """Create the grid and immediately reset it.

        Args:
            size: Side length of the square grid; must be at least 2 so that
                A can be placed on a cell different from B.
            agents_num: Number of agents; their ids are ``0 .. agents_num - 1``.

        Raises:
            ValueError: If ``size`` is smaller than 2.
        """
        if size < 2:
            raise ValueError("size must be at least 2 so that A and B can differ")
        self.size = size
        self.agents_num = agents_num
        self.agents_positions = {}  # agent id -> (y, x)
        self.agents_reached_A = {}  # agent id -> True once the item is picked up
        self.A_position = None
        self.B_position = (size - 1, size - 1)  # fixed location of B
        self.directions = ["up", "down", "left", "right"]
        self.total_collisions = 0
        self.total_steps = 0
        self.agents_idx = list(range(agents_num))
        self._reset()

    def _random_cell(self):
        """Return a uniformly random ``(y, x)`` cell of the grid."""
        # numpy's randint excludes the upper bound, so `self.size` (not
        # `self.size - 1`) is needed to reach the last row and column.
        return (int(npr.randint(0, self.size)), int(npr.randint(0, self.size)))

    def _reset(self):
        """
        Reset the environment to its initial state.

        A is re-drawn at random (never on B), every agent is placed with equal
        probability either on A with the item already picked up or on B
        without it, and the collision/step counters are cleared.
        """
        # initialize A position, ensuring A and B are not in the same position
        self.A_position = self._random_cell()
        while self.A_position == self.B_position:
            self.A_position = self._random_cell()

        # initialize agents' positions and reached_A status
        self.agents_positions = {}
        self.agents_reached_A = {}
        for idx in self.agents_idx:
            starts_at_A = npr.rand() < 0.5
            self.agents_positions[idx] = (
                self.A_position if starts_at_A else self.B_position
            )
            self.agents_reached_A[idx] = bool(starts_at_A)

        self.total_collisions = 0
        self.total_steps = 0

    def _get_destination(self, agent_idx):
        """
        Get the destination of an agent: "B" once it has reached A, else "A".
        """
        return "B" if self.agents_reached_A[agent_idx] else "A"

    def _find_nearby_collision_agents(self, agent_id):
        """
        Find nearby agents that might collide.

        Returns:
            A list of eight 0/1 flags, one per neighbouring cell in the order
            of ``_NEIGHBOUR_OFFSETS``.  A flag is 1 when the cell holds another
            agent heading to the same destination as ``agent_id``; cells
            outside the grid are always 0.
        """
        y, x = self.agents_positions[agent_id]
        destination_cur = self._get_destination(agent_id)
        # NOTE(review): neighbours are flagged when they share this agent's
        # destination, whereas take_action() counts a collision only between
        # agents with *different* destinations - confirm which one is intended.
        flagged_cells = {
            self.agents_positions[other_agent_id]
            for other_agent_id in self.agents_idx
            if other_agent_id != agent_id
            and self._get_destination(other_agent_id) == destination_cur
        }
        # Agents are always inside the grid, so an out-of-grid neighbour can
        # never be in `flagged_cells` and naturally yields 0.
        return [
            int((y + dy, x + dx) in flagged_cells)
            for dy, dx in self._NEIGHBOUR_OFFSETS
        ]

    def get_state(self, agent_idx):
        """
        Get the state of the environment for a specific agent.

        Returns:
            A 1-D numpy array of 19 values: the agent's (y, x), A's (y, x),
            B's (y, x), the (dy, dx) offset from the agent to A, the (dy, dx)
            offset from the agent to B, the reached_A flag as 0/1, and the
            eight neighbour flags from ``_find_nearby_collision_agents``.
        """
        pos_y, pos_x = self.agents_positions[agent_idx]
        a_y, a_x = self.A_position
        b_y, b_x = self.B_position
        collision_agents = self._find_nearby_collision_agents(agent_idx)

        return np.array(
            [
                pos_y,
                pos_x,
                a_y,
                a_x,
                b_y,
                b_x,
                a_y - pos_y,
                a_x - pos_x,
                b_y - pos_y,
                b_x - pos_x,
                int(self.agents_reached_A[agent_idx]),
                *collision_agents,
            ]
        )

    def _check_done(self, agent_idx):
        """
        Check if the agent stands on B while carrying the item (reached_A set).
        """
        return bool(
            self.agents_positions[agent_idx] == self.B_position
            and self.agents_reached_A[agent_idx]
        )

    def _resolve_move(self, action):
        """Translate an action into a ``(dy, dx)`` displacement.

        ``action`` is either a direction name from ``self.directions`` or an
        integer index into that list (the form a Q-network's argmax yields).

        Raises:
            ValueError: If the action is neither a known name nor a valid index.
        """
        if isinstance(action, str):
            name = action
        else:
            index = int(action)
            if not 0 <= index < len(self.directions):
                raise ValueError(f"Unknown action index: {action!r}")
            name = self.directions[index]
        if name not in self._MOVES:
            raise ValueError(f"Unknown action: {action!r}")
        return self._MOVES[name]

    def take_action(self, action_dict):
        """
        Take an action in the environment and return the next state, reward and collisions.

        Args:
            action_dict: Mapping ``agent id -> action`` with one entry per
                agent; an action is a direction name or an index into
                ``self.directions``.

        Returns:
            A tuple ``(next_states, rewards, collisions)`` where
            ``next_states`` and ``rewards`` are dicts keyed by agent id and
            ``collisions`` is the number of cells that, after the move, hold
            agents heading to both A and B.

        Raises:
            KeyError: If an agent has no entry in ``action_dict``.
            ValueError: If an action cannot be interpreted.
        """
        # Plan every move first so that an invalid action leaves the
        # environment untouched.
        next_positions = {}
        hit_wall = {}  # agent id -> True when the move left the grid
        for agent_idx in self.agents_idx:
            y, x = self.agents_positions[agent_idx]
            dy, dx = self._resolve_move(action_dict[agent_idx])
            new_y, new_x = y + dy, x + dx
            if 0 <= new_y < self.size and 0 <= new_x < self.size:
                next_positions[agent_idx] = (new_y, new_x)  # move
                hit_wall[agent_idx] = False
            else:
                next_positions[agent_idx] = (y, x)  # not move
                hit_wall[agent_idx] = True

        # check collision: group agents by the cell they end up on
        position_to_agents = {}
        for agent_idx, pos in next_positions.items():
            position_to_agents.setdefault(pos, []).append(agent_idx)

        collisions = 0  # number of head-on collisions
        for agents_cur in position_to_agents.values():
            if len(agents_cur) > 1:
                # Destinations are read before the pickup/delivery update
                # below, i.e. they describe where agents were heading.
                dirs = {self._get_destination(agent_idx) for agent_idx in agents_cur}
                if "A" in dirs and "B" in dirs:
                    collisions += 1

        # update agents' positions
        self.agents_positions = next_positions

        # calculate rewards
        rewards = {}
        for agent_idx in self.agents_idx:
            # hitting wall penalty replaces the ordinary step cost
            reward = -5 if hit_wall[agent_idx] else -1

            location = self.agents_positions[agent_idx]
            if self.agents_reached_A[agent_idx]:
                if location == self.B_position:
                    reward += 100  # delivery success
                    self.agents_reached_A[agent_idx] = False
            elif location == self.A_position:
                reward += 50  # pickup success
                self.agents_reached_A[agent_idx] = True

            rewards[agent_idx] = reward  # store reward for this agent

        # accumulate total collisions and steps
        self.total_collisions += collisions
        self.total_steps += 1

        # format next state
        next_states = {
            agent_idx: self.get_state(agent_idx) for agent_idx in self.agents_idx
        }

        return next_states, rewards, collisions

# Agent

# In [24]:
# deep q-learning agent
class Agent:
    """Deep Q-learning agent with a replay buffer and a target network.

    ``model`` is the online Q-network that is trained; ``model2`` is the
    target network used to compute TD targets and is re-synchronised with
    ``model`` every ``copy_frequency`` training steps.
    """

    def __init__(
        self,
        statespace_size,
        action_size,
        gamma=0.9,
        epsilon=1,
        epsilon_decay=0.9995,
        min_epsilon=0.1,
        batch_size=200,
        replay_buffer_size=1000,
        lr=0.001,
        copy_frequency=100,
    ):
        """Store the hyper-parameters and build the networks.

        Args:
            statespace_size: Length of the state vector fed to the network.
            action_size: Number of discrete actions (network outputs).
            gamma: Discount factor used in the TD target.
            epsilon: Initial exploration probability.
            epsilon_decay: Multiplicative decay applied by ``decay_epsilon``.
            min_epsilon: Lower bound for ``epsilon``.
            batch_size: Number of transitions sampled per training step.
            replay_buffer_size: Maximum number of stored transitions.
            lr: Learning rate of the Adam optimizer.
            copy_frequency: Training steps between target-network updates.
        """
        self.statespace_size = statespace_size
        self.action_size = action_size
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.min_epsilon = min_epsilon
        self.batch_size = batch_size
        self.replay_buffer_size = replay_buffer_size  # memory size
        self.lr = lr
        self.copy_frequency = copy_frequency

        self.steps = 0  # count agent's training steps
        self.replay_buffer = []  # memory of (state, action, reward, next_state)

        # initialize the DQN
        self.model, self.model2, self.optimizer, self.loss_fn = self.prepare_torch()

    def prepare_torch(self):
        """Build the online network, its target copy, the optimizer and loss.

        Returns:
            A tuple ``(model, model2, optimizer, loss_fn)``.
        """
        l1 = self.statespace_size
        l2 = 24
        l3 = 24
        l4 = self.action_size  # one Q-value per action
        model = torch.nn.Sequential(
            torch.nn.Linear(l1, l2),
            torch.nn.ReLU(),
            torch.nn.Linear(l2, l3),
            torch.nn.ReLU(),
            torch.nn.Linear(l3, l4),
        )
        # deepcopy duplicates the weights, so both networks start identical
        model2 = copy.deepcopy(model)
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
        return model, model2, optimizer, loss_fn

    def update_target(self):
        """Copy the online network's weights into the target network."""
        self.model2.load_state_dict(self.model.state_dict())

    def get_qvals(self, state):
        """Return the online network's Q-values for ``state`` as a numpy array."""
        state1 = torch.from_numpy(np.asarray(state)).float()
        with torch.no_grad():  # inference only, no graph needed
            qvals = self.model(state1).numpy()
        return qvals

    def get_maxQ(self, s):
        """Return the target network's largest Q-value for ``s`` (0-d tensor)."""
        with torch.no_grad():  # targets must not carry gradients
            return torch.max(self.model2(torch.from_numpy(np.asarray(s)).float())).float()

    def get_action(self, state):
        """Pick an action index epsilon-greedily with respect to the online network."""
        if npr.uniform() < self.epsilon:
            action = npr.choice(self.action_size)
        else:
            qvals = self.get_qvals(state)
            action = np.argmax(qvals)
        return int(action)

    def store_transition(self, state, action, reward, next_state):
        """
        Store the transition in the replay buffer, evicting the oldest entry
        once the buffer holds ``replay_buffer_size`` transitions.
        """
        if len(self.replay_buffer) >= self.replay_buffer_size:
            self.replay_buffer.pop(0)
        self.replay_buffer.append((state, action, reward, next_state))

    def train_one_step(self, states, actions, targets):
        """Run one gradient step on a batch and return the loss as a float.

        Args:
            states: Sequence of state vectors.
            actions: Sequence of action indices taken in those states.
            targets: Sequence of TD targets for the taken actions.
        """
        # convert the states and actions and targets to tensors
        state_batch = torch.tensor(np.asarray(states), dtype=torch.float32)
        action_batch = torch.tensor(np.asarray(actions), dtype=torch.long)
        target_batch = torch.tensor(np.asarray(targets), dtype=torch.float32)

        # get Q-values of the taken actions for the current states
        q_values = self.model(state_batch)
        # squeeze only the gathered column so a batch of one keeps shape (1,)
        predicted_q_values = q_values.gather(1, action_batch.unsqueeze(1)).squeeze(1)

        # calculate the loss
        loss = self.loss_fn(predicted_q_values, target_batch)

        # backpropagation
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train(self):
        """
        Train the agent using the replay buffer.

        Returns:
            The loss of the training step, or ``None`` when the buffer holds
            fewer than ``batch_size`` transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return None  # samples not enough

        # sample a batch (without repetition) from the replay buffer
        picked = npr.choice(len(self.replay_buffer), size=self.batch_size, replace=False)
        minibatch = [self.replay_buffer[i] for i in picked]
        states, actions, rewards, next_states = zip(*minibatch)

        # TD targets: r + gamma * max_a' Q_target(s', a'), computed in one pass
        next_batch = torch.tensor(np.asarray(next_states), dtype=torch.float32)
        with torch.no_grad():
            next_maxQ = torch.max(self.model2(next_batch), dim=1)[0].numpy()
        targets = np.asarray(rewards, dtype=np.float32) + self.gamma * next_maxQ

        # train the model
        loss = self.train_one_step(states, actions, targets)

        # update target network periodically
        self.steps += 1
        if self.steps % self.copy_frequency == 0:
            self.update_target()

        return loss

    # decay epsilon
    def decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay`` without going below ``min_epsilon``."""
        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)

# Training

# In [25]:
from time import time


def train_agents(
    agent, env, max_steps=1500000, max_collisions=4000, max_walltime=600, verbose=True
):
    """Train one shared agent that controls every agent slot of the environment.

    On each iteration an action is chosen for every id in ``env.agents_idx``,
    the environment is advanced once, all resulting transitions are stored in
    the agent's replay buffer, one training step is attempted and epsilon is
    decayed.

    Args:
        agent: Object providing ``get_action``, ``store_transition``,
            ``train``, ``decay_epsilon``, ``replay_buffer``, ``batch_size``
            and ``epsilon``.
        env: Object providing ``_reset``, ``get_state``, ``take_action`` and
            ``agents_idx``.
        max_steps: Step budget; the loop runs while the step count is
            ``<= max_steps``.
        max_collisions: Collision budget; the loop runs while the collision
            count is ``<= max_collisions``.
        max_walltime: Wall-clock budget in seconds, checked before each step.
        verbose: When true, print a progress line every 10000 steps.
    """
    # Imported locally under a distinct name: the module-level `time` is the
    # function from `from time import time`, so `time.time()` would fail.
    from time import time as current_time

    # start time
    start_time = current_time()

    # initialize the environment
    env._reset()

    # initial states
    states = {agent_idx: env.get_state(agent_idx) for agent_idx in env.agents_idx}

    # running totals
    total_collisions = 0
    total_steps = 0

    while total_collisions <= max_collisions and total_steps <= max_steps:
        if current_time() - start_time > max_walltime:
            print("===== Time limit exceeded. =====")
            break

        # central clock - actions are chosen in the fixed order of agents_idx
        actions_dict = {
            agent_idx: agent.get_action(states[agent_idx])
            for agent_idx in env.agents_idx
        }

        # take action in the environment
        next_states, rewards, collisions = env.take_action(actions_dict)

        # store transition in replay buffer
        for agent_idx in env.agents_idx:
            agent.store_transition(
                states[agent_idx],
                actions_dict[agent_idx],
                rewards[agent_idx],
                next_states[agent_idx],
            )

        # train the agent once enough samples are available
        if len(agent.replay_buffer) >= agent.batch_size:
            agent.train()

        # update the states
        states = next_states

        total_collisions += collisions
        total_steps += 1

        agent.decay_epsilon()
        if verbose and total_steps % 10000 == 0:
            elapsed_time = current_time() - start_time
            print(
                f"Steps: {total_steps}/{max_steps}, Collisions: {total_collisions}/{max_collisions}, Epsilon: {agent.epsilon:.3f}, Time Elapsed: {elapsed_time:.2f}s"
            )

    print("Training completed.")
    print(f"Total steps: {total_steps}")
    print(f"Total collisions: {total_collisions}")
    print(f"Final epsilon: {agent.epsilon:.3f}")

# Test