In [6]:
%%capture
!git clone https://github.com/cuongtv312/marl-delivery.git
%cd marl-delivery
!pip install -r requirements.txt

In [7]:
%%capture
!pip install stable-baselines3

In [23]:
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import time
from collections import deque
from tqdm.auto import tqdm
from env import Environment
from greedyagent import GreedyAgents
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# Hyperparameters: module-level constants read by the MAPPO class below.
# Discount factor; used in compute_gae and in the potential-shaping term in rollout.
GAMMA = 0.99
# GAE lambda; used in compute_gae.
LAMBDA = 0.95
# PPO clipping range: probability ratios are clamped to [1 - CLIP_EPS, 1 + CLIP_EPS].
CLIP_EPS = 0.3
# Learning rate the Adam optimiser is created with.
LR_INITIAL = 2e-4
# Learning rate written into the optimiser after behaviour-cloning pretraining (only when args.bc is set).
LR_PPO = 1e-4
# Number of passes over one rollout inside MAPPO.update.
PPO_EPOCHS = 16
# Number of rollout steps per gradient step inside MAPPO.update.
MINIBATCH_SIZE = 256
# Number of environment steps collected per rollout.
HORIZON = 512
# Entropy bonus coefficient: annealed linearly from HIGH to LOW over the training iterations.
ENTROPY_COEF_HIGH = 1.0
ENTROPY_COEF_LOW = 0.01
# Weight of the value (critic) loss in the total loss.
VALUE_COEF = 0.5
# Gradient-norm clipping threshold.
MAX_GRAD_NORM = 0.5
# Window length (deque maxlen) of the running reward average; one entry is appended per training iteration.
MAX_EPISODES = 100
# Behaviour cloning: number of epochs, and number of single-state gradient steps per epoch.
BC_EPOCHS = 10
BC_BATCH_SIZE = 128
# Number of rollout/update iterations MAPPO.train loops over.
UPDATES = 2000

# Restrict PyTorch to a single intra-op CPU thread.
torch.set_num_threads(1)

def set_seed(seed):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch, so that code relying on any of the three is reproducible.

    Args:
        seed: integer seed applied to all three generators.
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

class RunningMeanStd:
    """Running estimate of the per-feature mean and variance of a data stream.

    Batches are merged with the parallel (count-weighted) mean/variance
    update, so the statistics after several ``update`` calls match those of
    the concatenated data, up to the small ``epsilon`` pseudo-count.
    """

    def __init__(self, epsilon=1e-4, shape=()):
        """
        Args:
            epsilon: initial pseudo-count; keeps the first update well defined.
            shape: shape of one sample (the batch axis is not included).
        """
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        """Fold a batch of samples into the running statistics.

        Args:
            x: array-like or torch tensor whose first axis is the batch axis.
               Tensors are detached and copied to host memory first, so CUDA
               tensors and tensors that require grad are accepted. A 0-d
               input is treated as a batch holding one sample. An empty batch
               is ignored.
        """
        if isinstance(x, torch.Tensor):
            # np.array() cannot convert CUDA or grad-tracking tensors directly.
            x = x.detach().cpu().numpy()
        x = np.asarray(x, dtype=np.float64)
        if x.ndim == 0:
            x = x.reshape(1)
        batch_count = x.shape[0]
        if batch_count == 0:
            # The mean of an empty batch is NaN and would poison the statistics for good.
            return

        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        # Combined sum of squared deviations of the two groups.
        M2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
        self.mean = self.mean + delta * batch_count / tot_count
        self.var = M2 / tot_count
        self.count = tot_count

    def normalize(self, x):
        """Return ``x`` standardised with the current running mean and variance."""
        return (x - self.mean) / np.sqrt(self.var + 1e-8)

class StateEncoder(nn.Module):
    """Encode a (spatial grid, flat feature vector) observation into one embedding.

    Three feature groups are concatenated and passed through an MLP:
      * a CNN over the 6-channel grid, pooled to a 256-d vector;
      * a transformer over per-agent tokens, mean-pooled to ``d_model``;
      * the last entry of the flat vector, taken as a single scalar feature.

    NOTE(review): the TransformerEncoderLayer is built with its default
    dropout, so forward passes in train() mode are stochastic - confirm this
    is intended for PPO rollouts.
    """

    def __init__(self, n_agents, embed_dim=256, d_model=64, nhead=4, num_layers=2):
        """
        Args:
            n_agents: number of agent tokens fed to the transformer.
            embed_dim: width of the output embedding.
            d_model: transformer token width.
            nhead: number of attention heads.
            num_layers: number of transformer encoder layers.
        """
        super().__init__()
        self.n_agents = n_agents

        conv_stack = [
            nn.Conv2d(6, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU(),
            # Global average pooling makes the CNN output independent of the map size.
            nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        self.agent_embed = nn.Linear(6, d_model)
        self.pos_embed = nn.Embedding(n_agents, d_model)
        self.transformer = TransformerEncoder(
            TransformerEncoderLayer(d_model, nhead, batch_first=True),
            num_layers,
        )

        # CNN features (256) + pooled agent tokens (d_model) + one scalar.
        fused_width = 256 + d_model + 1
        self.mlp = nn.Sequential(
            nn.Linear(fused_width, embed_dim), nn.ReLU(),
            nn.Linear(embed_dim, embed_dim), nn.ReLU(),
        )

    def forward(self, spatial, non_spatial):
        """
        Args:
            spatial: grid tensor fed to the CNN; the first conv expects 6 channels.
            non_spatial: 2-D tensor; its first ``6 * n_agents`` columns are
                reshaped to ``(B, n_agents, 6)``, i.e. read as ``n_agents``
                consecutive groups of six values, and its last column is used
                as a scalar feature.

        Returns:
            Tensor of shape ``(B, embed_dim)``.
        """
        batch = spatial.size(0)
        map_feat = self.cnn(spatial)

        per_agent = non_spatial[:, :self.n_agents * 6].view(batch, self.n_agents, 6)
        slot_ids = torch.arange(self.n_agents, device=spatial.device)
        # The (n_agents, d_model) slot embedding broadcasts over the batch axis.
        tokens = self.agent_embed(per_agent) + self.pos_embed(slot_ids)
        pooled = self.transformer(tokens).mean(dim=1)

        scalar_feat = non_spatial[:, -1:]
        fused = torch.cat((map_feat, pooled, scalar_feat), dim=1)
        return self.mlp(fused)

class ActorCritic(nn.Module):
    """Linear policy and value heads on top of a shared state embedding.

    The actor emits ``action_per_agent`` logits for each agent; the critic
    emits one value per agent.
    """

    def __init__(self, embed_dim, n_agents, action_per_agent):
        """
        Args:
            embed_dim: width of the incoming embedding.
            n_agents: number of agents sharing the embedding.
            action_per_agent: size of each agent's discrete action set.
        """
        super().__init__()
        self.n_agents = n_agents
        self.ap = action_per_agent
        self.actor = nn.Linear(embed_dim, n_agents * action_per_agent)
        self.critic = nn.Linear(embed_dim, n_agents)

    def forward(self, h):
        """
        Args:
            h: embedding tensor whose first axis is the batch axis.

        Returns:
            ``(dists, val)``: a list with one ``Categorical`` per agent (each
            batched over the first axis of ``h``) and the critic output of
            shape ``(B, n_agents)``.
        """
        agent_logits = self.actor(h).view(h.size(0), self.n_agents, self.ap)
        policies = [
            torch.distributions.Categorical(logits=one_agent)
            for one_agent in agent_logits.unbind(dim=1)
        ]
        return policies, self.critic(h)

class MAPPO:
    def __init__(self, env, device, max_steps, args):
        """Build the networks, optimiser, LR schedule and bookkeeping state.

        Args:
            env: environment object; this class reads ``n_robots``, ``n_rows``
                and ``n_cols`` from it and calls ``reset()`` / ``step()``.
            device: torch device that holds the networks and all tensors.
            max_steps: divisor used to normalise the time-step and deadline
                features in ``convert_state``.
            args: parsed options; ``train`` reads ``args.bc``.
        """
        self.env = env
        self.device = device
        self.max_steps = max_steps
        self.args = args
        n = env.n_robots
        # 5 move codes x 3 package codes per agent (decoded as a // 3 and a % 3 in rollout).
        ap = 5 * 3

        self.encoder = StateEncoder(n_agents=n, embed_dim=256, d_model=64, nhead=4, num_layers=2)
        if min(env.n_rows, env.n_cols) < 2:
            # Same CNN without the MaxPool2d(2) stage - presumably because a map
            # with a side shorter than 2 cannot be pooled.
            self.encoder.cnn = nn.Sequential(
                nn.Conv2d(6, 64, kernel_size=3, padding=1), nn.ReLU(),
                nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(),
                nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU(),
                nn.AdaptiveAvgPool2d((1,1)), nn.Flatten()
            )
        # Move to the device only after the optional swap; otherwise the
        # replacement CNN would stay on the CPU while the rest sits on `device`.
        self.encoder = self.encoder.to(device)

        self.ac = ActorCritic(embed_dim=256, n_agents=n, action_per_agent=ap).to(device)
        self.optimizer = optim.Adam(
            list(self.encoder.parameters()) + list(self.ac.parameters()), lr=LR_INITIAL)
        # Multiply the learning rate by 0.9 every 100 scheduler steps.
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.9)
        self.writer = SummaryWriter()
        # Sliding window of per-iteration rewards used for the logged running average.
        self.ep_rewards = deque(maxlen=MAX_EPISODES)
        self.best = -np.inf
        self.total_ts = 0
        self.norm = RunningMeanStd(shape=(n,))
        self.ent_coef = ENTROPY_COEF_HIGH

    def create_spatial_grid(self, state, pos, nearest_targets):
        n_rows, n_cols = self.env.n_rows, self.env.n_cols
        grid = torch.zeros((6, n_rows, n_cols), dtype=torch.float32, device=self.device)
        grid[0] = torch.tensor(state['map'], dtype=torch.float32, device=self.device)
        for robot in state['robots']:
            r, c = int(robot[0]), int(robot[1])
            if 0 <= r < n_rows and 0 <= c < n_cols:
                grid[1, r, c] = 1.0
        for robot in state['robots']:
            r, c = int(robot[0]), int(robot[1])
            if 0 <= r < n_rows and 0 <= c < n_cols and robot[2] > -1:
                grid[2, r, c] = 1.0
        waiting = {(int(p[1]), int(p[2])) for p in state['packages'] if p[5] == -1}
        for r, c in waiting:
            if 0 <= r < n_rows and 0 <= c < n_cols:
                grid[3, r, c] = 1.0
        ends = {(int(p[3]), int(p[4])) for p in state['packages']}
        for r, c in ends:
            if 0 <= r < n_rows and 0 <= c < n_cols:
                grid[4, r, c] = 1.0
        for target in nearest_targets:
            if target is not None:
                r, c = target
                if 0 <= r < n_rows and 0 <= c < n_cols:
                    grid[5, r, c] = 1.0
        return grid

    def get_distance(self, r1, c1, r2, c2):
        return abs(r1 - r2) + abs(c1 - c2)

    def compute_potential(self, state):
        robots = np.array(state['robots'])
        pkg_list = state['packages']
        potential = 0.0
        for i in range(self.env.n_robots):
            r_robot, c_robot = int(robots[i][0]), int(robots[i][1])
            carrying = robots[i][2] > -1
            if carrying:
                pkg_id = int(robots[i][2])
                for p in pkg_list:
                    if p[0] == pkg_id and p[5] != -1:
                        r_target, c_target = int(p[3]), int(p[4])
                        potential -= self.get_distance(r_robot, c_robot, r_target, c_target)
                        break
            else:
                waiting_packages = [p for p in pkg_list if p[5] == -1]
                if waiting_packages:
                    dists = [self.get_distance(r_robot, c_robot, int(p[1]), int(p[2])) for p in waiting_packages]
                    potential -= min(dists)
        return potential

    def convert_state(self, state, t):
        robots = np.array(state['robots'], dtype=np.float32)
        carry_status = (robots[:, 2] > -1).astype(np.float32)
        pkg_list = state['packages']
        pos = robots[:, :2]
        n_agents = self.env.n_robots
        max_rows_cols = max(self.env.n_rows, self.env.n_cols)
        normalized_pos = pos / max_rows_cols
        normalized_nearest_start = np.zeros((n_agents, 2), dtype=np.float32)
        normalized_nearest_end = np.zeros((n_agents, 2), dtype=np.float32)
        normalized_nearest_dead = np.zeros((n_agents, 1), dtype=np.float32)
        nearest_targets = []

        for i in range(n_agents):
            r_robot = int(robots[i][0])
            c_robot = int(robots[i][1])
            if carry_status[i] == 0:
                waiting_packages = [p for p in pkg_list if p[5] == -1]
                if waiting_packages:
                    dists = [self.get_distance(r_robot, c_robot, int(p[1]), int(p[2])) for p in waiting_packages]
                    idx_min = np.argmin(dists)
                    nearest_pkg = waiting_packages[idx_min]
                    normalized_nearest_start[i] = [nearest_pkg[1] / max_rows_cols, nearest_pkg[2] / max_rows_cols]
                    normalized_nearest_end[i] = [nearest_pkg[3] / max_rows_cols, nearest_pkg[4] / max_rows_cols]
                    normalized_nearest_dead[i] = [nearest_pkg[6] / self.max_steps]
                    nearest_targets.append((int(nearest_pkg[1]), int(nearest_pkg[2])))
                else:
                    normalized_nearest_start[i] = [0, 0]
                    normalized_nearest_end[i] = [0, 0]
                    normalized_nearest_dead[i] = [0]
                    nearest_targets.append(None)
            else:
                pkg_id = int(robots[i][2])
                found = False
                for p in pkg_list:
                    if p[0] == pkg_id and p[5] != -1:
                        normalized_nearest_start[i] = [p[1] / max_rows_cols, p[2] / max_rows_cols]
                        normalized_nearest_end[i] = [p[3] / max_rows_cols, p[4] / max_rows_cols]
                        normalized_nearest_dead[i] = [p[6] / self.max_steps]
                        nearest_targets.append((int(p[3]), int(p[4])))
                        found = True
                        break
                if not found:
                    normalized_nearest_start[i] = [0, 0]
                    normalized_nearest_end[i] = [0, 0]
                    normalized_nearest_dead[i] = [0]
                    nearest_targets.append(None)

        t_norm = float(t) / self.max_steps
        non_spatial = np.concatenate([
            carry_status.flatten(),
            normalized_nearest_start.flatten(),
            normalized_nearest_end.flatten(),
            normalized_nearest_dead.flatten(),
            [t_norm]
        ], axis=0).astype(np.float32)
        spatial_grid = self.create_spatial_grid(state, pos, nearest_targets)
        return spatial_grid.clone().detach().to(self.device).unsqueeze(0), torch.tensor(non_spatial, device=self.device).unsqueeze(0)
    def rollout(self):
        """Collect ``HORIZON`` environment steps with the current policy.

        The environment is reset at the start of every rollout and again
        whenever an episode ends. The time feature passed to
        ``convert_state`` is the step index within the current episode, so
        it restarts at 0 after a reset. Rewards are augmented with the
        shaping term ``gamma * Phi(s') - Phi(s)``, split evenly across
        robots.

        NOTE(review): on a terminal step ``Phi(s')`` is still read from the
        terminal state; confirm whether it should be zero to match the mask.
        NOTE(review): because each rollout starts from ``env.reset()``, steps
        beyond ``HORIZON`` of an episode are never visited - confirm intent.

        Returns:
            ``(obs, acts, lps, vals, shaped_rews, masks, rews)``: per-step
            lists of length ``HORIZON``, except ``vals``, which carries one
            extra bootstrap value for the state after the last step.
        """
        obs, acts, lps, vals, rews, shaped_rews, masks = [], [], [], [], [], [], []
        move_codes = ['S', 'L', 'R', 'U', 'D']
        n_robots = self.env.n_robots

        state = self.env.reset()
        ep_t = 0  # step index inside the current episode
        for _ in range(HORIZON):
            spatial, non_spatial = self.convert_state(state, ep_t)
            with torch.no_grad():
                h = self.encoder(spatial, non_spatial)
                dists, val = self.ac(h)
            actions = [dist.sample() for dist in dists]
            log_probs = [dists[j].log_prob(actions[j]).detach() for j in range(self.ac.n_agents)]
            Phi_current = self.compute_potential(state)

            # Decode each flat action index into a (move code, package code) tuple.
            integer_actions = [action.item() for action in actions]
            env_actions = [(move_codes[a // 3], str(a % 3)) for a in integer_actions]
            state, reward_list, done, _ = self.env.step(env_actions)

            # The environment may return one shared scalar (Python or NumPy) or
            # a per-robot sequence; normalise to a Python list with one entry per robot.
            if np.ndim(reward_list) == 0:
                reward_list = [float(reward_list)] * n_robots
            elif isinstance(reward_list, np.ndarray):
                reward_list = reward_list.tolist()

            Phi_next = self.compute_potential(state)
            shaping_term = (GAMMA * Phi_next - Phi_current) / n_robots
            shaped_reward_list = [r + shaping_term for r in reward_list]

            mask = 0.0 if done else 1.0
            obs.append((spatial, non_spatial))
            acts.append(torch.stack(actions))
            lps.append(torch.stack(log_probs))
            vals.append(val.squeeze(0))
            rews.append(torch.tensor(reward_list, device=self.device, dtype=torch.float32))
            shaped_rews.append(torch.tensor(shaped_reward_list, device=self.device, dtype=torch.float32))
            masks.append(mask)

            if done:
                state = self.env.reset()
                ep_t = 0
            else:
                ep_t += 1

        # Bootstrap value for the state that follows the last collected step.
        spatial, non_spatial = self.convert_state(state, ep_t)
        with torch.no_grad():
            h = self.encoder(spatial, non_spatial)
            _, val = self.ac(h)
        vals.append(val.squeeze(0))
        return obs, acts, lps, vals, shaped_rews, masks, rews


    def compute_gae(self, rewards, values, masks):
        returns = []
        gae = torch.zeros(self.ac.n_agents, device=self.device)
        for t in reversed(range(HORIZON)):
            delta = rewards[t] + GAMMA * values[t + 1] * masks[t] - values[t]
            gae = delta + GAMMA * LAMBDA * masks[t] * gae
            returns.insert(0, gae + values[t])
        returns = torch.stack(returns)
        advs = returns - torch.stack(values[:-1])
        advs = (advs - advs.mean()) / (advs.std() + 1e-8)
        return returns, advs

    def update(self, obs, acts, lps, vals, shaped_rews, masks):
        """Run ``PPO_EPOCHS`` passes of clipped-surrogate PPO over one rollout.

        Args:
            obs: per-step ``(spatial, non_spatial)`` tensor pairs, each with
                a leading batch axis of size 1.
            acts: per-step stacked action tensors.
            lps: per-step stacked log-probabilities under the rollout policy.
            vals: per-step value tensors plus one bootstrap entry.
            shaped_rews: per-step shaped reward tensors.
            masks: per-step episode-continuation flags.

        Returns:
            ``(policy_loss, value_loss, entropy)`` averaged over every
            gradient step that was taken.
        """
        # 1. Returns and advantages via GAE.
        returns, advs = self.compute_gae(shaped_rews, vals, masks)
        # The running statistics live in NumPy; copy to host memory first so
        # this also works when the tensors are on a CUDA device.
        # NOTE(review): self.norm is updated here but nothing visible in this
        # file reads it back.
        self.norm.update(advs.detach().cpu().numpy())

        n_steps = returns.size(0)

        # Stack the rollout once; minibatches are then plain tensor indexing.
        spatial_all = torch.cat([step_obs[0] for step_obs in obs], dim=0)      # [T, C, H, W]
        non_spatial_all = torch.cat([step_obs[1] for step_obs in obs], dim=0)  # [T, D]
        acts_all = torch.stack(acts, dim=0).squeeze(-1)                        # [T, n_agents]
        old_lps_all = torch.stack(lps, dim=0).squeeze(-1)                      # [T, n_agents]

        params = list(self.encoder.parameters()) + list(self.ac.parameters())

        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_entropy = 0.0
        num_updates = 0

        for _ in range(PPO_EPOCHS):
            # Fresh permutation each epoch so minibatch composition varies
            # between passes over the same rollout.
            order = torch.randperm(n_steps, device=self.device)
            for start in range(0, n_steps, MINIBATCH_SIZE):
                batch_idx = order[start:start + MINIBATCH_SIZE]

                spatial_batch = spatial_all[batch_idx]
                non_spatial_batch = non_spatial_all[batch_idx]
                act_batch = acts_all[batch_idx]
                old_lp_batch = old_lps_all[batch_idx]
                return_batch = returns[batch_idx]
                adv_batch = advs[batch_idx]

                # Single forward pass; Categorical.logits holds the
                # normalised log-probabilities of each agent's policy.
                h = self.encoder(spatial_batch, non_spatial_batch)
                dists, vals_pred = self.ac(h)
                logp = torch.stack([dist.logits for dist in dists], dim=1)    # [B', n_agents, ap]

                # Log-probability of the action actually taken by each agent.
                new_lp = logp.gather(-1, act_batch.unsqueeze(-1)).squeeze(-1)  # [B', n_agents]

                # Mean per-agent policy entropy.
                entropy = (-(logp * logp.exp()).sum(dim=-1)).mean()

                # PPO clipped surrogate objective.
                ratios = torch.exp(new_lp - old_lp_batch)
                surr1 = ratios * adv_batch
                surr2 = torch.clamp(ratios, 1 - CLIP_EPS, 1 + CLIP_EPS) * adv_batch
                policy_loss = -torch.min(surr1, surr2).mean()

                value_loss = F.mse_loss(vals_pred, return_batch)

                loss = policy_loss + VALUE_COEF * value_loss - self.ent_coef * entropy

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(params, MAX_GRAD_NORM)
                self.optimizer.step()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy.item()
                num_updates += 1

        # Divide by the number of steps actually taken; this stays correct
        # when the rollout length is not a multiple of MINIBATCH_SIZE.
        num_updates = max(num_updates, 1)
        return (
            total_policy_loss / num_updates,
            total_value_loss / num_updates,
            total_entropy / num_updates
        )


    def _pretrain_with_bc(self):
        """Behaviour-cloning warm-up against the greedy expert.

        Runs ``BC_EPOCHS`` epochs of ``BC_BATCH_SIZE`` single-state gradient
        steps, each on a freshly reset environment, then sets the optimiser's
        learning rate to ``LR_PPO``.

        NOTE(review): this assumes ``GreedyAgents(env, device)`` exposes
        ``step(state)`` returning one integer action index per robot -
        confirm against greedyagent.py.
        """
        expert = GreedyAgents(self.env, self.device)
        for epoch in range(BC_EPOCHS):
            loss_total = 0.0
            for _ in range(BC_BATCH_SIZE):
                state = self.env.reset()
                spatial, non_spatial = self.convert_state(state, 0)
                expert_acts = expert.step(state)
                h = self.encoder(spatial, non_spatial)
                dists, _ = self.ac(h)
                # Negative log-likelihood of the expert's action, summed over agents.
                loss = 0.0
                for j in range(self.ac.n_agents):
                    loss += -dists[j].log_prob(torch.tensor(expert_acts[j], device=self.device))
                loss = loss.mean()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                loss_total += loss.item()
            self.writer.add_scalar('BC_Loss', loss_total / BC_BATCH_SIZE, epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = LR_PPO

    def train(self):
        """Optionally pretrain with behaviour cloning, then run the PPO loop.

        Each of the ``UPDATES`` iterations collects one rollout, performs one
        PPO update, logs scalars to TensorBoard, and saves a checkpoint to
        ``best_model.pth`` whenever the rollout reward beats the best so far.
        The entropy coefficient is annealed linearly from
        ``ENTROPY_COEF_HIGH`` towards ``ENTROPY_COEF_LOW``.

        NOTE(review): the ``--updates`` and ``--save_path`` options parsed in
        ``main()`` are not read here; the iteration count and checkpoint name
        are left as they were because other cells may depend on them -
        confirm before wiring the options through.
        """
        if self.args.bc:
            self._pretrain_with_bc()

        try:
            for it in tqdm(range(UPDATES)):
                self.ent_coef = ENTROPY_COEF_HIGH - (ENTROPY_COEF_HIGH - ENTROPY_COEF_LOW) * (it / UPDATES)
                obs, acts, lps, vals, shaped_rews, masks, rews = self.rollout()
                # Unshaped reward summed over the rollout, averaged over robots.
                total_reward = sum([r.sum().item() for r in rews]) / self.env.n_robots
                self.ep_rewards.append(total_reward)
                self.writer.add_scalar('Reward/Total', total_reward, it)

                # The window is never empty here: a value was appended just above.
                avg_reward = np.mean(self.ep_rewards)
                self.writer.add_scalar('Reward/Average', avg_reward, it)

                policy_loss, value_loss, entropy = self.update(obs, acts, lps, vals, shaped_rews, masks)
                self.writer.add_scalar('Reward', total_reward, it)
                self.writer.add_scalar('Policy_Loss', policy_loss, it)
                self.writer.add_scalar('Value_Loss', value_loss, it)
                self.writer.add_scalar('Entropy', entropy, it)
                self.writer.add_scalar('Learning_Rate', self.scheduler.get_last_lr()[0], it)
                print(f"Iteration {it}: Total Reward = {total_reward}, Average Reward = {avg_reward}")
                if total_reward > self.best:
                    self.best = total_reward
                    torch.save({
                        'encoder': self.encoder.state_dict(),
                        'ac': self.ac.state_dict(),
                        'optimizer': self.optimizer.state_dict()
                    }, 'best_model.pth')
                self.total_ts += HORIZON
                self.scheduler.step()
        finally:
            # Push buffered scalars to disk even when training is interrupted
            # (for example by stopping the cell).
            self.writer.flush()

    def evaluate_policy(self, env, num_episodes=10):
        total_rewards = []
        total_steps = []
        for _ in range(num_episodes):
            state = env.reset()
            done = False
            episode_reward = 0
            episode_steps = 0
            while not done:
                spa, non = self.convert_state(state, episode_steps)
                h = self.encoder(spa.unsqueeze(0), non.unsqueeze(0))
                dists, _ = self.ac(h)
                actions = [d.sample().item() for d in dists]
                env_actions = [( ['S','L','R','U','D'][a//3], str(a%3) ) for a in actions]
                state, reward, done, _ = env.step(env_actions)
                episode_reward += sum(reward) if isinstance(reward, (list, np.ndarray)) else reward
                episode_steps += 1
            total_rewards.append(episode_reward)
            total_steps.append(episode_steps)
        return np.mean(total_rewards), np.mean(total_steps)

def main():
    """Parse options, build the environment and launch MAPPO training.

    ``parse_known_args`` is used so that argv entries this parser does not
    recognise are returned separately and do not abort the run.

    NOTE(review): ``--updates`` and ``--save_path`` are parsed, but nothing
    visible in this file reads them back.
    """
    cli = argparse.ArgumentParser()
    typed_options = (
        ('--map', str, 'map2.txt'),
        ('--num_agents', int, 5),
        ('--n_packages', int, 100),
        ('--max_steps', int, 1000),
        ('--seed', int, 10),
        ('--updates', int, 1000),
    )
    for flag, kind, fallback in typed_options:
        cli.add_argument(flag, type=kind, default=fallback)
    cli.add_argument('--bc', action='store_true', help='use behavioral cloning pretraining')
    cli.add_argument('--save_path', type=str, default='mappo_best.pth')
    args, _unrecognised = cli.parse_known_args()

    set_seed(args.seed)

    env_settings = dict(
        map_file=args.map,
        n_robots=args.num_agents,
        n_packages=args.n_packages,
        max_time_steps=args.max_steps,
        move_cost=-0.01,
        delivery_reward=10.0,
        delay_reward=1.0,
    )
    env = Environment(**env_settings)

    # Use the GPU when one is available.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    trainer = MAPPO(env, device, args.max_steps, args)
    trainer.train()

if __name__ == '__main__':
    main()

  0%|          | 0/2000 [00:00<?, ?it/s]

Iteration 0: Total Reward = -16.879999202489852, Cumulative Reward = 0.0, Average Reward = -16.879999202489852
Iteration 1: Total Reward = -15.569999206066132, Cumulative Reward = 0.0, Average Reward = -16.224999204277992
Iteration 2: Total Reward = -14.049999192357063, Cumulative Reward = 0.0, Average Reward = -15.49999920030435
Iteration 3: Total Reward = -14.439999336004258, Cumulative Reward = 0.0, Average Reward = -15.234999234229326
Iteration 4: Total Reward = -15.979999211430549, Cumulative Reward = 0.0, Average Reward = -15.38399922966957
Iteration 5: Total Reward = -14.859999188780785, Cumulative Reward = 0.0, Average Reward = -15.29666588952144
Iteration 6: Total Reward = -16.589999094605446, Cumulative Reward = 0.0, Average Reward = -15.481427775962013
Iteration 7: Total Reward = -6.309999829530716, Cumulative Reward = 0.0, Average Reward = -14.3349992826581
Iteration 8: Total Reward = -16.349999204277992, Cumulative Reward = 0.0, Average Reward = -14.558888162838088
Iterati

KeyboardInterrupt: 