In [17]:
!pip install -q -U git+https://github.com/Farama-Foundation/MAgent2

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [18]:
from magent2.environments import battle_v4
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import random
from collections import deque
from tqdm import tqdm

In [19]:
# Parallel-API battle environment. max_cycles is set to the same 300-step cap
# that the training loop uses for its inner step loop.
env = battle_v4.parallel_env(
    map_size=45,
    max_cycles=300,
)

In [20]:
# Hyperparameters and environment-derived sizes shared by every later cell.
config = dict(
    # --- sizes read from the environment ("red_0" is used as the probe agent) ---
    obs_shape=env.observation_space("red_0").shape,
    action_dims=int(env.action_space("red_0").n),
    # --- optimisation ---
    learning_rate=0.001,
    # --- epsilon-greedy exploration (multiplied by epsilon_decay once per episode) ---
    epsilon=1.0,
    epsilon_decay=0.998,
    epsilon_min=0.05,
    # --- Q-learning ---
    gamma=0.98,  # discount factor
    batch_size=512,
    tau=0.005,  # soft (Polyak) target-update coefficient
    # --- periodic hard target sync, in episodes ---
    red_update_interval=2,
    blue_update_interval=2,
    # --- training length ---
    num_episode=150,
    num_step=300,
    # --- compute device ---
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

In [21]:
class DuelingQNetwork(nn.Module):
    """Convolutional Q-network with a dueling (value + advantage) head.

    Two 3x3, stride-1, padding-1 convolutions (which keep the spatial size)
    feed a 128-unit fully connected layer. That hidden vector is then mapped
    to a single state value and to one advantage per action.

    Args:
        obs_shape: Observation shape. ``obs_shape[-1]`` is used as the number
            of input channels and ``obs_shape[0] * obs_shape[1]`` as the
            spatial size.
        actions_dim: Number of discrete actions, i.e. the output width.
    """

    def __init__(self, obs_shape, actions_dim):
        super().__init__()

        in_channels = obs_shape[-1]
        # Padding keeps the spatial dims, so the flattened conv output has
        # 64 channels times the original spatial area.
        flat_features = 64 * obs_shape[0] * obs_shape[1]

        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)

        self.fc1 = nn.Linear(flat_features, 128)

        # Dueling heads: scalar state value and per-action advantage.
        self.value_stream = nn.Linear(128, 1)
        self.advantage_stream = nn.Linear(128, actions_dim)

    def forward(self, x):
        """Return Q-values of shape ``(batch, actions_dim)``.

        Args:
            x: 4-D batch laid out channels-last, ``(batch, dim0, dim1, channels)``.
        """
        # Conv2d expects channels-first, so move the last axis to position 1.
        features = x.permute(0, 3, 1, 2)

        for conv in (self.conv1, self.conv2):
            features = F.relu(conv(features))

        flat = features.reshape(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))

        value = self.value_stream(hidden)
        advantage = self.advantage_stream(hidden)

        # Subtract the per-sample mean advantage before adding the value, so
        # the mean Q-value over actions equals the value-stream output.
        centered_advantage = advantage - advantage.mean(dim=1, keepdim=True)
        return value + centered_advantage

In [22]:
def test_qnetwork():
    """Smoke-test DuelingQNetwork on a random batch.

    Builds a network for a (13, 13, 5) observation and 21 actions, feeds it a
    random batch of 10 observations, prints the output shape and asserts that
    it is ``(batch, actions_dim)``.

    Raises:
        AssertionError: If the network returns an unexpected output shape.
    """
    obs_shape = (13, 13, 5)
    actions_dim = 21
    batch_size = 10
    qnetwork = DuelingQNetwork(obs_shape, actions_dim)

    test_input = torch.randn(batch_size, *obs_shape)
    with torch.no_grad():  # only the shape is checked; no autograd graph needed
        output = qnetwork(test_input)
    print(output.shape)

    # Fail loudly instead of relying on someone eyeballing the printed shape.
    expected_shape = (batch_size, actions_dim)
    assert tuple(output.shape) == expected_shape, (
        f"expected Q-values of shape {expected_shape}, got {tuple(output.shape)}"
    )

test_qnetwork()

torch.Size([10, 21])


In [23]:
class ReplayBuffer:
    """Bounded FIFO memory of ``(state, action, reward, next_state, done)`` tuples."""

    def __init__(self, buffer_size):
        """Create an empty buffer holding at most ``buffer_size`` transitions."""
        self.buffer_size = buffer_size
        # A deque with maxlen drops its oldest entry when a new one is
        # appended at capacity.
        self.buffer = deque(maxlen=buffer_size)

    def add(self, state, action, reward, next_state, done):
        """Store a single transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size, device):
        """Draw ``batch_size`` distinct transitions and batch them on ``device``.

        States and next states are joined with ``torch.cat`` along dim 0, so
        each stored state must already be a tensor with a leading batch axis.
        Actions (long), rewards (float32) and dones (float32) are returned as
        column tensors of shape ``(batch_size, 1)``.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)``.

        Raises:
            ValueError: Propagated from ``random.sample`` when ``batch_size``
                exceeds the number of stored transitions.
        """
        batch = random.sample(self.buffer, batch_size)

        # Transpose the list of tuples into one tuple per field.
        state_col, action_col, reward_col, next_state_col, done_col = zip(*batch)

        def as_column(values, dtype):
            # Scalars -> (batch_size, 1) tensor on the requested device.
            return torch.tensor(values, dtype=dtype).unsqueeze(1).to(device)

        state_batch = torch.cat(state_col).to(device)
        action_batch = as_column(action_col, torch.long)
        reward_batch = as_column(reward_col, torch.float32)
        next_state_batch = torch.cat(next_state_col).to(device)
        done_batch = as_column(done_col, torch.float32)

        return (state_batch, action_batch, reward_batch, next_state_batch, done_batch)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

In [24]:
def update_network(network, target_network):
    """Hard-sync: overwrite ``target_network``'s state with a copy of ``network``'s."""
    source_state = network.state_dict()
    target_network.load_state_dict(source_state)

In [25]:
# Red team: online Q-network plus a target network of the same architecture.
red_q_network = DuelingQNetwork(config["obs_shape"], config["action_dims"]).to(config["device"])
red_target_q_network = DuelingQNetwork(config["obs_shape"], config["action_dims"]).to(config["device"])

# The target starts as an exact copy of the online weights.
update_network(red_q_network, red_target_q_network)

# Only the online network is optimised; transitions go into a 100k-entry buffer.
red_buffer = ReplayBuffer(100_000)
red_optimizer = optim.Adam(red_q_network.parameters(), lr=config["learning_rate"])

# Kept as the final expression: eval() returns the module, so the cell shows its repr.
red_target_q_network.eval()

DuelingQNetwork(
  (conv1): Conv2d(5, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=10816, out_features=128, bias=True)
  (value_stream): Linear(in_features=128, out_features=1, bias=True)
  (advantage_stream): Linear(in_features=128, out_features=21, bias=True)
)

In [26]:
# Blue team: online Q-network plus a target network of the same architecture.
blue_q_network = DuelingQNetwork(config["obs_shape"], config["action_dims"]).to(config["device"])
blue_target_q_network = DuelingQNetwork(config["obs_shape"], config["action_dims"]).to(config["device"])

# The target starts as an exact copy of the online weights.
update_network(blue_q_network, blue_target_q_network)

# Only the online network is optimised; transitions go into a 100k-entry buffer.
blue_buffer = ReplayBuffer(100_000)
blue_optimizer = optim.Adam(blue_q_network.parameters(), lr=config["learning_rate"])

# Kept as the final expression: eval() returns the module, so the cell shows its repr.
blue_target_q_network.eval()

DuelingQNetwork(
  (conv1): Conv2d(5, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=10816, out_features=128, bias=True)
  (value_stream): Linear(in_features=128, out_features=1, bias=True)
  (advantage_stream): Linear(in_features=128, out_features=21, bias=True)
)

In [None]:
# Self-play training: the red and blue teams each train their own dueling
# double-DQN from their own replay buffer, acting epsilon-greedily with a
# shared epsilon that decays once per episode.


def _select_team_actions(q_network, team, obs, epsilon):
    """Return ``{agent: action}`` with one epsilon-greedy action per agent in ``team``."""
    if not team:
        return {}

    team_state = torch.stack(
        [torch.tensor(obs[agent], dtype=torch.float32) for agent in team]
    ).to(config["device"])

    with torch.no_grad():
        greedy_actions = torch.argmax(q_network(team_state), dim=1)

    random_actions = torch.randint(
        0, config["action_dims"], (len(team),), device=config["device"]
    )
    explore = torch.rand(len(team), device=config["device"]) < epsilon
    chosen = torch.where(explore, random_actions, greedy_actions)

    return {agent: action.item() for agent, action in zip(team, chosen)}


def _store_team_transitions(buffer, team, obs, actions, rewards, next_obs,
                            terminations, truncations, done_agents):
    """Add this step's transition of every agent in ``team`` to ``buffer``.

    Agents that terminated or were truncated on this step are added to
    ``done_agents``. Returns the sum of the team's rewards for the step.
    """
    team_reward = 0
    for agent in team:
        state = torch.tensor(obs[agent], dtype=torch.float32).unsqueeze(0).to(config["device"])
        next_state = torch.tensor(next_obs[agent], dtype=torch.float32).unsqueeze(0).to(config["device"])
        reward = rewards.get(agent, 0.0)
        # Read the flags from the dicts returned by env.step(). Looking them up
        # via the post-step env.agents would miss an agent the environment has
        # already removed from that list, storing its terminal transition with
        # done=False.
        done = bool(terminations.get(agent, False) or truncations.get(agent, False))

        buffer.add(state, actions[agent], reward, next_state, done)
        team_reward += reward

        if done:
            done_agents.add(agent)
    return team_reward


def _train_step(q_network, target_q_network, optimizer, buffer):
    """Run one double-DQN update followed by a soft target update.

    Returns the scalar loss, or ``None`` when ``buffer`` holds fewer than
    ``config["batch_size"]`` transitions and no update was made.
    """
    if len(buffer) < config["batch_size"]:
        return None

    states, batch_actions, batch_rewards, next_states, batch_dones = buffer.sample(
        config["batch_size"], config["device"]
    )

    q_values = q_network(states).gather(1, batch_actions)
    with torch.no_grad():
        # Double DQN: the online network picks the next action, the target
        # network evaluates it.
        next_actions = q_network(next_states).argmax(1, keepdim=True)
        next_q_values = target_q_network(next_states).gather(1, next_actions)
        target_q_values = batch_rewards + (1 - batch_dones) * config["gamma"] * next_q_values

    loss = F.mse_loss(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Soft (Polyak) update: target <- tau * online + (1 - tau) * target.
    with torch.no_grad():
        for target_param, local_param in zip(target_q_network.parameters(), q_network.parameters()):
            target_param.copy_(config["tau"] * local_param + (1.0 - config["tau"]) * target_param)

    return loss.item()


# Loss histories live outside the episode loop so they cover the whole run;
# re-creating them per episode would leave only the final episode's losses
# for the cell that saves them.
red_losses = []
blue_losses = []

pbar = tqdm(range(config["num_episode"]))
for episode in pbar:
    # reset() returns an (obs, infos) tuple in some installs and a bare
    # observation dict in others; accept both.
    reset_result = env.reset()
    obs = reset_result[0] if isinstance(reset_result, tuple) else reset_result

    red_total_reward = 0
    blue_total_reward = 0
    done_agents = set()

    for step in range(config["num_step"]):
        red_team = [agent for agent in env.agents if "red" in agent and agent not in done_agents]
        blue_team = [agent for agent in env.agents if "blue" in agent and agent not in done_agents]

        # Nobody left to act: the episode is over, so stop stepping the env.
        if not red_team and not blue_team:
            break

        actions = {}
        actions.update(_select_team_actions(red_q_network, red_team, obs, config["epsilon"]))
        actions.update(_select_team_actions(blue_q_network, blue_team, obs, config["epsilon"]))

        next_obs, rewards, terminations, truncations, infos = env.step(actions)

        red_total_reward += _store_team_transitions(
            red_buffer, red_team, obs, actions, rewards, next_obs,
            terminations, truncations, done_agents,
        )
        blue_total_reward += _store_team_transitions(
            blue_buffer, blue_team, obs, actions, rewards, next_obs,
            terminations, truncations, done_agents,
        )

        obs = next_obs

        red_loss = _train_step(red_q_network, red_target_q_network, red_optimizer, red_buffer)
        if red_loss is not None:
            red_losses.append(red_loss)

        blue_loss = _train_step(blue_q_network, blue_target_q_network, blue_optimizer, blue_buffer)
        if blue_loss is not None:
            blue_losses.append(blue_loss)

    pbar.set_postfix({
        'Red Reward': red_total_reward,
        'Blue Reward': blue_total_reward,
        'Epsilon': config['epsilon']
    })

    # Decay exploration once per episode, never dropping below epsilon_min.
    if config["epsilon"] > config["epsilon_min"]:
        config["epsilon"] = max(config["epsilon_min"], config["epsilon"] * config["epsilon_decay"])

    # NOTE(review): this periodic hard copy runs in addition to the per-step
    # soft update in _train_step; confirm that both are intended.
    if episode % config["red_update_interval"] == 0:
        update_network(red_q_network, red_target_q_network)
    if episode % config["blue_update_interval"] == 0:
        update_network(blue_q_network, blue_target_q_network)

 26%|██▌       | 39/150 [09:12<25:32, 13.81s/it, Red Reward=-725, Blue Reward=-738, Epsilon=0.927]

In [None]:
# Write the blue online network's weights (its state_dict) to disk.
torch.save(
    blue_q_network.state_dict(),
    "blue_dueling_q_default.pth",
)

In [None]:
import numpy as np

# Dump each team's recorded training losses to its own .npy file.
np.save("red_losses.npy", np.asarray(red_losses))
np.save("blue_losses.npy", np.asarray(blue_losses))