In [1]:
%pip install -q git+https://github.com/Farama-Foundation/MAgent2

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for magent2 (pyproject.toml) ... [?25l[?25hdone
Note: you may need to restart the kernel to use updated packages.


In [2]:
from magent2.environments import battle_v4
import torch
import torch.nn as nn
import numpy as np
import collections
import random

try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x, *args, **kwargs: x  # Fallback: tqdm becomes a no-op

def compute_output_dim(input_dim, kernel_size, stride, padding):
    """Return the output length of a convolution along one spatial axis.

    :param input_dim: size of the input along the axis
    :param kernel_size: size of the kernel along the axis
    :param stride: step between consecutive kernel applications
    :param padding: amount of padding added to *each* side of the input
    :return: number of positions the kernel can be applied at (floor division)
    """
    padded_extent = input_dim + 2 * padding
    full_steps = (padded_extent - kernel_size) // stride
    return full_steps + 1

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint locations (Kaggle dataset mounts). When running outside Kaggle,
# point these at the local files instead: 'red.pt', 'red_final.pt', 'vdn.pth'.
red_path = '/kaggle/input/red-model/red.pt'
red_final_path = '/kaggle/input/red-model/red_final.pt'
vdn_path = '/kaggle/input/vdn-model/vdn.pth'

In [3]:
class QNetwork(nn.Module):
    """Convolutional Q-network: two 3x3 convolutions followed by a three-layer MLP.

    :param observation_shape: observation shape laid out as (height, width, channels)
    :param action_shape: number of discrete actions (size of the output layer)
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
        )

        # Push one random channels-first observation through the conv stack
        # to find out how many features the first linear layer receives.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).numel()

        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )

    def forward(self, x):
        """Return Q-values of shape (batch, action_shape).

        :param x: channels-first observation(s), either (C, H, W) or (B, C, H, W);
                  an unbatched input is treated as a batch of one
        """
        assert x.ndim >= 3, "only support magent input observation"
        features = self.cnn(x)
        n_rows = 1 if features.ndim == 3 else features.shape[0]
        return self.network(features.reshape(n_rows, -1))

  and should_run_async(code)


In [4]:
class FinalQNetwork(nn.Module):
    """Convolutional Q-network with a separate output layer.

    The body is two 3x3 convolutions plus a two-layer MLP ending in ``Tanh``;
    the final linear layer is kept apart as ``last_layer`` so the 84-dimensional
    latent feeding it can be inspected through ``last_latent`` after a forward pass.

    :param observation_shape: observation shape laid out as (height, width, channels)
    :param action_shape: number of discrete actions (size of the output layer)
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
        )

        # Push one random channels-first observation through the conv stack
        # to find out how many features the first linear layer receives.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).numel()

        # LayerNorm after each linear layer was tried and left disabled.
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.Tanh(),
        )
        self.last_layer = nn.Linear(84, action_shape)

    def forward(self, x):
        """Return Q-values of shape (batch, action_shape).

        Side effect: stores the pre-output latent in ``self.last_latent``.

        :param x: channels-first observation(s), either (C, H, W) or (B, C, H, W);
                  an unbatched input is treated as a batch of one
        """
        assert x.ndim >= 3, "only support magent input observation"
        features = self.cnn(x)
        n_rows = 1 if features.ndim == 3 else features.shape[0]
        latent = self.network(features.reshape(n_rows, -1))
        self.last_latent = latent
        return self.last_layer(latent)

In [5]:
class TeamManager:
    """Bookkeeping helper that groups agents into teams and tracks terminations.

    Agent names are expected to look like ``<team>_<id>`` (e.g. ``red_0``); the
    team is everything before the last underscore.

    :param agents: iterable of agent names
    :param my_team: name of the team treated as "mine"; when omitted, the second
                    team in order of first appearance is used
    """

    def __init__(self, agents, my_team = None):
        self.agents = agents
        self.teams = self.group_agents()
        self.terminated_agents = set()
        self.my_team = my_team
        self.random_agents = None
        # Fix the random ordering of my agents once, at construction time.
        self.get_random_agents(1)

    def get_teams(self):
        """
        Get the team names.
        :return: a list of team names
        """
        return list(self.teams.keys())

    def get_my_team(self):
        """
        Get the name of my team, defaulting to (and caching) the second team.
        :return: the team name
        """
        if self.my_team is None:
            self.my_team = self.get_teams()[1]
        return self.my_team

    def get_team_agents(self, team):
        """
        Get the agents in a team.
        :param team: the team name
        :return: a list of agent names in the team
        """
        assert team in self.teams, f"Team [{team}] not found."
        return self.teams[team]

    def get_my_agents(self):
        """
        Get the agents of my team.
        :return: a list of agent names
        """
        return self.get_team_agents(self.get_my_team())

    def group_agents(self):
        """
        Group ``self.agents`` by their team.
        :return: a dictionary with team names as keys and a list of agent names as values
        """
        teams = collections.defaultdict(list)
        for agent in self.agents:
            # Split on the LAST underscore so team names may contain underscores.
            team = agent.rsplit('_', 1)[0]
            teams[team].append(agent)
        # Hand back a plain dict: looking up an unknown team must fail instead of
        # silently creating an empty team (which would count as "terminated").
        return dict(teams)

    def get_info_of_team(self, team, data, default=None):
        """
        Get the information of a team.
        :param team: the team name
        :param data: a dictionary keyed by agent name
        :param default: value used for agents of the team that are missing from data
        :return: a dictionary with the team's agent names as keys and their data as values
        """
        assert team in self.teams, f"Team [{team}] not found."
        return {agent: data[agent] if agent in data else default
                for agent in self.get_team_agents(team)}

    def reset(self):
        """Forget all recorded terminations (the random agent order is kept)."""
        self.terminated_agents = set()

    def is_team_terminated(self, team):
        """
        Check if all agents in a team are terminated.
        :param team: the team name
        :return: True if all agents in the team are terminated, False otherwise
        """
        assert team in self.teams, f"Team [{team}] not found."
        return all(agent in self.terminated_agents for agent in self.teams[team])

    def terminate_agent(self, agent):
        """
        Mark an agent as terminated.
        :param agent: the agent name
        """
        self.terminated_agents.add(agent)

    def has_terminated_teams(self):
        """
        Check if any team is terminated.
        :return: True if at least one team has all of its agents terminated
        """
        return any(self.is_team_terminated(team) for team in self.teams)

    def get_my_terminated_agents(self):
        """
        Get the terminated agents of my team, in team order.
        :return: a list of agent names
        """
        # Filtering the team list (rather than intersecting sets) keeps the
        # result order deterministic.
        return [agent for agent in self.get_my_agents() if agent in self.terminated_agents]

    def get_other_team_remains(self):
        """
        Get the remaining (non-terminated) agents of the first team that is not mine.
        :return: a list of agent names
        """
        my_team = self.get_my_team()
        other_team = [team for team in self.teams if team != my_team][0]
        return [agent for agent in self.get_team_agents(other_team) if agent not in self.terminated_agents]

    def get_random_agents(self, rate):
        """
        Return a prefix of a random ordering of my agents.

        The ordering is sampled on first use and reused afterwards, so a larger
        rate always extends the selection returned for a smaller one.
        :param rate: the fraction of my agents to return
        :return: a list of agents with the length of int(rate * num_agents)
        """
        my_agents = self.get_my_agents()
        if self.random_agents is None:
            self.random_agents = random.sample(my_agents, len(my_agents))
        return self.random_agents[:int(len(my_agents) * rate)]

    @staticmethod
    def merge_terminates_truncates(terminates, truncates):
        """
        Merge terminates and truncates into one dictionary.
        :param terminates: a dictionary with agent names as keys and boolean values as values
        :param truncates: a dictionary with agent names as keys and boolean values as values
        :return: a dictionary keyed by the agents in terminates, true when the agent
                 is terminated or truncated
        """
        # An agent absent from truncates is treated as not truncated instead of
        # raising KeyError.
        return {agent: done or truncates.get(agent, False)
                for agent, done in terminates.items()}

In [6]:
class VdnQNet(nn.Module):
    """Per-agent Q-network with weights shared by every agent of a team.

    A small CNN encodes each agent's observation into a latent of size
    ``hx_size``; an optional GRU cell carries a hidden state; a linear head
    produces one Q-value per action.

    :param agents: list of agent names; the first one is used to look up the spaces
    :param observation_spaces: mapping agent name -> space exposing ``.shape`` as
                               (height, width, channels)
    :param action_spaces: mapping agent name -> space exposing ``.n``
    :param recurrent: when True, pass the latent through a GRU cell
    """

    def __init__(self, agents, observation_spaces, action_spaces, recurrent=False):
        super().__init__()
        self.agents = agents
        self.num_agents = len(agents)
        self.recurrent = recurrent
        self.hx_size = 32   # latent repr size
        self.n_obs = observation_spaces[agents[0]].shape    # (height, width, channels)
        self.n_act = action_spaces[agents[0]].n  # action space size of agents

        height, width, channels = self.n_obs
        conv_stack = [
            nn.Conv2d(channels, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
        ]
        # Measure the flattened feature size with a dummy pass so that
        # non-square observations are handled too (the size used to be derived
        # from the height alone). A zero tensor consumes no random numbers, so
        # seeded weight initialisation is unaffected.
        with torch.no_grad():
            probe = torch.zeros(1, channels, height, width)
            flattened_size = nn.Sequential(*conv_stack)(probe).shape[1]

        # Layer positions (0, 2 and 5 hold parameters) must stay as they are so
        # that existing checkpoints keep loading.
        self.feature_cnn = nn.Sequential(
            *conv_stack,
            nn.Linear(flattened_size, self.hx_size),
            nn.ReLU()
        )
        if recurrent:
            self.gru = nn.GRUCell(self.hx_size, self.hx_size)  # shape: hx_size, hx_size
        self.q_val = nn.Linear(self.hx_size, self.n_act)    # shape: hx_size, n_actions

    def _model_device(self):
        """Return the device the network's weights currently live on."""
        return self.q_val.weight.device

    def forward(self, obs, hidden):
        """Predict q values for each agent's actions in the batch
        :param obs: [batch_size, num_agents, height, width, channels]
        :param hidden: [batch_size, num_agents, hx_size]
        :return: q_values: [batch_size, num_agents, n_actions], hidden: [batch_size, num_agents, hx_size]
        """
        # TODO: need to have a done_mask param
        # Follow the module's own device so inputs always match the weights.
        model_device = self._model_device()
        obs = obs.to(model_device)
        hidden = hidden.to(model_device)

        batch_size, num_agents, height, width, channels = obs.shape
        # Fold the agent axis into the batch axis and go channels-first, so the
        # CNN handles every agent in a single pass.
        obs = obs.permute(0, 1, 4, 2, 3)
        obs = obs.reshape(batch_size * num_agents, channels, height, width)

        x = self.feature_cnn(obs)  # (batch_size * num_agents, hx_size)

        if self.recurrent:
            flat_hidden = hidden.reshape(batch_size * num_agents, -1)
            x = self.gru(x, flat_hidden)  # (batch_size * num_agents, hx_size)
            next_hidden = x.view(batch_size, num_agents, -1)
        else:
            # Without a GRU the incoming hidden state is passed through untouched.
            next_hidden = hidden.view(batch_size, num_agents, -1)

        # Predict Q-values for all agents, then restore the (batch, agent) axes.
        q_values = self.q_val(x).view(batch_size, num_agents, -1)

        return q_values, next_hidden

    # NOTE(review): the default epsilon=1e3 makes every call fully random; it looks
    # like a typo for 1e-3. It is left unchanged because callers may rely on it.
    def sample_action(self, obs, hidden, epsilon=1e3):
        """Choose actions with an epsilon-greedy policy.

        Exploration is decided once per batch entry: with probability epsilon
        all agents of that entry act randomly, otherwise all act greedily.
        :param obs: a batch of observations, [batch_size, num_agents, height, width, channels]
        :param hidden: a batch of hidden states, [batch_size, num_agents, hx_size]
        :param epsilon: exploration rate
        :return: actions (float tensor): [batch_size, num_agents], hidden: [batch_size, num_agents, hx_size]
        """
        # TODO: need to have a done_mask param
        q_values, hidden = self.forward(obs, hidden)    # [batch_size, num_agents, n_actions], [batch_size, num_agents, hx_size]
        n_batch, n_agents, n_actions = q_values.shape
        out_device = q_values.device

        # torch.rand draws from [0, 1); the strict "<" guarantees that epsilon=0
        # never explores.
        explore = torch.rand((n_batch,), device=out_device) < epsilon  # [batch_size]
        actions = torch.empty((n_batch, n_agents), device=out_device)  # [batch_size, num_agents]
        actions[explore] = torch.randint(0, n_actions, actions[explore].shape, device=out_device).float()
        actions[~explore] = q_values[~explore].argmax(dim=2).float()  # choose action with max q value
        return actions, hidden   # [batch_size, num_agents], [batch_size, num_agents, hx_size]

    def init_hidden(self, batch_size=1):
        """Return an all-zero hidden state of shape [batch_size, num_agents, hx_size]."""
        return torch.zeros((batch_size, self.num_agents, self.hx_size), device=self._model_device())

In [7]:
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Battle environment: 45x45 map, episodes capped at `max_cycles` cycles.
max_cycles = 300
env = battle_v4.env(max_cycles=max_cycles, map_size=45)

def random_policy(env, agent, obs):
    """Pick a uniformly sampled action from the agent's action space.

    :param env: environment exposing ``action_space(agent)``
    :param agent: name of the acting agent
    :param obs: current observation (ignored)
    :return: whatever the action space's ``sample()`` returns
    """
    space = env.action_space(agent)
    return space.sample()

# --- Red team: baseline pretrained Q-network --------------------------------
# Checkpoints are deserialised onto the CPU first and only then moved to `device`.
q_network = QNetwork(
    env.observation_space("red_0").shape,
    env.action_space("red_0").n,
)
red_state = torch.load(red_path, weights_only=True, map_location="cpu")
q_network.load_state_dict(red_state)
q_network.to(device)

# --- Red team: final pretrained Q-network -----------------------------------
final_q_network = FinalQNetwork(
    env.observation_space("red_0").shape,
    env.action_space("red_0").n,
)
red_final_state = torch.load(red_final_path, weights_only=True, map_location="cpu")
final_q_network.load_state_dict(red_final_state)
final_q_network.to(device)

# --- Blue team: VDN network -------------------------------------------------
# Reset the environment before reading `env.agents`, which TeamManager groups by team.
env.reset()
team_manager = TeamManager(env.agents)

vdn_network = VdnQNet(
    team_manager.get_my_agents(),
    env.observation_spaces,
    env.action_spaces,
)
vdn_state = torch.load(vdn_path, weights_only=True, map_location="cpu")
vdn_network.load_state_dict(vdn_state)
vdn_network.to(device)


def pretrain_policy(env, agent, obs):
    """Greedy action of the baseline pretrained ``q_network``.

    :param env: environment (unused)
    :param agent: name of the acting agent (unused)
    :param obs: observation laid out as (height, width, channels)
    :return: index of the action with the highest Q-value
    """
    # Channels-first layout plus a leading batch axis of size one.
    chw = torch.Tensor(obs).float().permute([2, 0, 1])
    batch = chw.unsqueeze(0).to(device)
    with torch.no_grad():
        best = q_network(batch).argmax(dim=1)
    return best.cpu().numpy()[0]

def final_pretrain_policy(env, agent, obs):
    """Greedy action of the final pretrained ``final_q_network``.

    :param env: environment (unused)
    :param agent: name of the acting agent (unused)
    :param obs: observation laid out as (height, width, channels)
    :return: index of the action with the highest Q-value
    """
    # Channels-first layout plus a leading batch axis of size one.
    chw = torch.Tensor(obs).float().permute([2, 0, 1])
    batch = chw.unsqueeze(0).to(device)
    with torch.no_grad():
        best = final_q_network(batch).argmax(dim=1)
    return best.cpu().numpy()[0]

def blue_policy(env, agent, obs):
    """Action chosen by ``vdn_network`` for a single agent, with epsilon set to 0.

    :param env: environment (unused)
    :param agent: name of the acting agent (unused)
    :param obs: observation laid out as (height, width, channels)
    :return: the selected action as a Python int
    """
    # The network expects [batch, agent, H, W, C]; add two leading axes of size one.
    single_obs = torch.tensor(obs, dtype=torch.float32)[None, None].to(device)
    # A fresh hidden state is created on every call.
    fresh_hidden = vdn_network.init_hidden(batch_size=1)

    with torch.no_grad():
        chosen, _ = vdn_network.sample_action(single_obs, fresh_hidden, epsilon=0)
    return int(chosen.item())

def run_eval(env, red_policy, blue_policy, n_episode: int = 100):
    """Play red against blue for several episodes and summarise the outcome.

    A step counts as a kill for the acting agent's team when the reward reported
    for it exceeds 4.5 (this assumes the default reward setup). A team wins an
    episode when it has at least 5 more kills than the other; anything else is
    a draw.

    :param env: AEC-style environment with "red_*" and "blue_*" agents
    :param red_policy: callable (env, agent, observation) -> action for red agents
    :param blue_policy: callable (env, agent, observation) -> action for blue agents
    :param n_episode: number of episodes to play
    :return: dict with win rates and per-agent average rewards of both teams
    """
    team_size = len(env.env.action_spaces) // 2
    policies = {"red": red_policy, "blue": blue_policy}
    wins = {"red": [], "blue": []}
    avg_rewards = {"red": [], "blue": []}

    for _ in tqdm(range(n_episode)):
        env.reset()
        kills = {"red": 0, "blue": 0}
        returns = {"red": 0, "blue": 0}

        for agent in env.agent_iter():
            observation, reward, termination, truncation, info = env.last()
            team = agent.split("_")[0]

            kills[team] += reward > 4.5  # This assumes default reward setup
            returns[team] += reward

            if termination or truncation:
                action = None  # this agent has died
            else:
                action = policies[team](env, agent, observation)

            env.step(action)

        lead = kills["red"] - kills["blue"]
        if lead >= 5:
            winner = "red"
        elif lead <= -5:
            winner = "blue"
        else:
            winner = "draw"

        for team in ("red", "blue"):
            wins[team].append(winner == team)
            avg_rewards[team].append(returns[team] / team_size)

    return {
        "winrate_red": np.mean(wins["red"]),
        "winrate_blue": np.mean(wins["blue"]),
        "average_rewards_red": np.mean(avg_rewards["red"]),
        "average_rewards_blue": np.mean(avg_rewards["blue"]),
    }

# Evaluate the blue VDN policy against each red opponent in turn (30 episodes each).
red_opponents = [
    ("Eval with random policy", random_policy),
    ("Eval with trained policy", pretrain_policy),
    ("Eval with final trained policy", final_pretrain_policy),
]

separator = "=" * 20
print(separator)
for title, red_opponent in red_opponents:
    print(title)
    summary = run_eval(
        env=env, red_policy=red_opponent, blue_policy=blue_policy, n_episode=30
    )
    print(summary)
    print(separator)



Eval with random policy


100%|██████████| 30/30 [08:09<00:00, 16.33s/it]


{'winrate_red': 0.0, 'winrate_blue': 1.0, 'average_rewards_red': -3.961059815738994, 'average_rewards_blue': 2.53202465689423}
Eval with trained policy


100%|██████████| 30/30 [04:52<00:00,  9.76s/it]


{'winrate_red': 0.0, 'winrate_blue': 0.9666666666666667, 'average_rewards_red': 2.5313415457264603, 'average_rewards_blue': 2.469069887490352}
Eval with final trained policy


100%|██████████| 30/30 [07:13<00:00, 14.44s/it]

{'winrate_red': 0.03333333333333333, 'winrate_blue': 0.7666666666666667, 'average_rewards_red': 0.8238024302942241, 'average_rewards_blue': 1.8392366024504563}



