In [1]:
!pip install -q git+https://github.com/Farama-Foundation/MAgent2

# Import Libraries

In [2]:
# Libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast

import numpy as np

from magent2.environments import battle_v4

from dataclasses import dataclass
import collections
import random
import time

# Init

In [3]:
# Init
# Global RNG seed; passed to reseed() and to env.reset(seed=...) further down.
seed = 25
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def compute_output_dim(input_dim, kernel_size, stride, padding):
    """Output length along one axis of a sliding-window (conv/pool) operation.

    :param input_dim: input length along the axis
    :param kernel_size: window length
    :param stride: step between consecutive windows
    :param padding: padding added on each side of the axis
    :return: ``(input_dim - kernel_size + 2 * padding) // stride + 1``
    """
    usable_span = input_dim - kernel_size + 2 * padding
    full_steps = usable_span // stride
    return full_steps + 1

def save_model(model, name):
    """Serialize the model's ``state_dict`` to ``<name>.pth`` via ``torch.save``.

    :param model: object exposing ``state_dict()``
    :param name: output path without the ``.pth`` extension
    """
    weights = model.state_dict()
    checkpoint_path = f'{name}.pth'
    torch.save(weights, checkpoint_path)

def save_data(data, name='data'):
    """Store ``data`` as ``<name>.npy`` using ``np.save``.

    :param data: array-like object to persist
    :param name: output path without the ``.npy`` extension
    """
    target_path = f'{name}.npy'
    np.save(target_path, data)

def reseed(seed):
    """Seed the torch, NumPy and ``random`` generators and pin cuDNN settings.

    :param seed: integer seed applied to all three generators
    """
    # Same order as before: torch first, then NumPy, then the stdlib RNG.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    # Ask cuDNN for deterministic kernels and switch off its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all RNGs as soon as this cell runs so later cells start from a known state.
reseed(seed)

@dataclass
class VdnHyperparameters:
    """Settings consumed by the VDN training loop; every field has a default."""
    lr: float = 0.001  # learning rate given to the Adam optimizers
    gamma: float = 0.99  # discount factor applied to the bootstrapped target
    batch_size: int = 2048  # number of chunks sampled from the replay buffer per update
    update_iter: int = 20  # gradient updates performed per call to the training function
    buffer_limit: int = 9000  # maximum number of transitions kept in each replay buffer
    update_target_interval: int = 20  # episodes between target-network weight copies
    max_episodes: int = 500  # number of training episodes
    max_epsilon: float = 0.9  # exploration rate at episode 0
    min_epsilon: float = 0.1  # floor of the exploration rate
    episode_min_epsilon: int = 200  # episodes over which epsilon decays linearly to min_epsilon
    test_episodes: int = 10  # episodes played per evaluation run
    warm_up_steps: int = 3000  # training starts once a buffer holds more transitions than this
    chunk_size: int = 1  # consecutive transitions per sample (train() only uses it for recurrent nets)
    recurrent: bool = False  # NOTE(review): not read by any code visible in this notebook

# Replay Buffer

In [4]:
# ReplayBuffer

class ReplayBuffer:
    """Bounded FIFO store of team transitions with chunked random sampling.

    Each transition is a tuple ``(state, action, reward, next_state, done)``
    whose elements are per-agent sequences. Once ``buffer_limit`` transitions
    are stored, appending a new one discards the oldest.
    """

    def __init__(self, buffer_limit):
        """
        :param buffer_limit: maximum number of transitions to keep
        """
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        """Append a new transition to the buffer
        :param transition: tuple of (state, action, reward, next_state, done)
        """
        self.buffer.append(transition)

    def sample_chunk(self, batch_size, chunk_size):
        """Sample ``batch_size`` chunks of ``chunk_size`` consecutive transitions.

        Chunks are contiguous in insertion order; nothing here checks whether
        a chunk crosses an episode boundary.

        :param batch_size: number of chunks to sample
        :param chunk_size: number of consecutive transitions in each chunk
        :return: tuple of float32 tensors on ``device``
        (states, actions, rewards, next_states, dones), with shapes:
        [batch_size, chunk_size, n_agents, ...obs_shape],
        [batch_size, chunk_size, n_agents],
        [batch_size, chunk_size, n_agents],
        [batch_size, chunk_size, n_agents, ...obs_shape],
        [batch_size, chunk_size, n_agents]
        :raises ValueError: if ``chunk_size`` is smaller than 1 or larger than
        the number of stored transitions
        """
        stored = len(self.buffer)
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        if stored < chunk_size:
            raise ValueError(
                f"Cannot sample chunks of {chunk_size} transitions from a buffer holding {stored}."
            )

        # randint's upper bound is exclusive; the last valid start index is
        # stored - chunk_size, hence the +1.
        start_idx = np.random.randint(0, stored - chunk_size + 1, batch_size)
        s_lst, a_lst, r_lst, s_prime_lst, done_lst = [], [], [], [], []

        for idx in start_idx:
            for chunk_step in range(idx, idx + chunk_size):
                # (state, action, reward, next_state, done) * num_agent
                s, a, r, s_prime, done = self.buffer[chunk_step]
                s_lst.append(s)
                a_lst.append(a)
                r_lst.append(r)
                s_prime_lst.append(s_prime)
                done_lst.append(done)
        num_agents = len(s_lst[0])
        obs_shape = s_lst[0][0].shape

        s_lst = np.array(s_lst).reshape(batch_size, chunk_size, num_agents, *obs_shape)
        a_lst = np.array(a_lst).reshape(batch_size, chunk_size, num_agents)
        r_lst = np.array(r_lst).reshape(batch_size, chunk_size, num_agents)
        s_prime_lst = np.array(s_prime_lst).reshape(batch_size, chunk_size, num_agents, *obs_shape)
        done_lst = np.array(done_lst, dtype=bool).reshape(batch_size, chunk_size, num_agents)
        # Everything is returned as float32 (actions and dones included);
        # callers cast where they need integer or boolean values.
        return (
            torch.tensor(s_lst, dtype=torch.float32).to(device),
            torch.tensor(a_lst, dtype=torch.float32).to(device),
            torch.tensor(r_lst, dtype=torch.float32).to(device),
            torch.tensor(s_prime_lst, dtype=torch.float32).to(device),
            torch.tensor(done_lst, dtype=torch.float32).to(device)
        )

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

# TeamManager

In [5]:
# TeamManager

class TeamManager:
    """Bookkeeping for agents grouped by team.

    Agent names are expected in the form ``<team>_<id>``. The manager records
    which agents have terminated and exposes per-team views of per-agent data.
    """

    def __init__(self, agents, my_team = None):
        """
        :param agents: list of agent names in the format teamname_agentid
        :param my_team: name of the controlled team; when None, the second
        team (in order of first appearance in ``agents``) is used
        """
        self.agents = agents
        self.teams = self.group_agents()
        self.terminated_agents = set()
        self.my_team = my_team
        self.random_agents = None
        # Build the shuffled agent order up front (one random.sample call).
        self.get_random_agents(1)

    def get_teams(self):
        """
        Get the team names.
        :return: a list of team names
        """
        return list(self.teams.keys())

    def get_my_team(self):
        """
        Get the controlled team's name, defaulting to the second team.
        :return: the team name
        """
        if self.my_team is None:
            self.my_team = self.get_teams()[1]
        return self.my_team

    def get_other_team(self):
        """
        Get the name of the first team that is not the controlled team.
        Falls back to the first team when no other team exists.
        :return: the team name
        """
        my_team = self.get_my_team()
        for team in self.teams:
            if team != my_team:
                return team
        return self.get_teams()[0]

    def get_team_agents(self, team):
        """
        Get the agents in a team.
        :param team: the team name
        :return: a list of agent names in the team
        """
        assert team in self.teams, f"Team [{team}] not found."
        return self.teams[team]

    def get_my_agents(self):
        """
        :return: a list of agent names in the controlled team
        """
        return self.get_team_agents(self.get_my_team())

    def get_other_agents(self):
        """
        :return: a list of agent names in the other team
        """
        return self.get_team_agents(self.get_other_team())

    def group_agents(self):
        """
        Group agents by their team.
        The team is everything before the last underscore, so team names may
        themselves contain underscores.
        :return: a dictionary with team names as keys and a list of agent names as values
        """
        teams = collections.defaultdict(list)
        for agent in self.agents:
            team, _ = agent.rsplit('_', 1)
            teams[team].append(agent)
        return teams

    def get_info_of_team(self, team, data, default=None):
        """
        Get the information of a team.
        :param team: the team name
        :param data: a dictionary keyed by agent name
        :param default: value used for agents missing from data
        :return: a dictionary with agent names as keys and their data (or default) as values
        """
        assert team in self.teams, f"Team [{team}] not found."
        return {
            agent: (data[agent] if agent in data else default)
            for agent in self.get_team_agents(team)
        }

    def reset(self):
        """
        Forget all recorded terminations.
        """
        self.terminated_agents = set()

    def is_team_terminated(self, team):
        """
        Check if all agents in a team are terminated.
        :param team: the team name
        :return: True if all agents in the team are terminated, False otherwise
        """
        assert team in self.teams, f"Team [{team}] not found."
        return all(agent in self.terminated_agents for agent in self.teams[team])

    def terminate_agent(self, agent):
        """
        Mark an agent as terminated.
        :param agent: the agent name
        """
        self.terminated_agents.add(agent)

    def has_terminated_teams(self):
        """
        Check if any team is terminated.
        :return: True if at least one team has all of its agents terminated
        """
        return any(self.is_team_terminated(team) for team in self.teams)

    def get_my_terminated_agents(self):
        """
        :return: a list of terminated agents that belong to the controlled team
        """
        return list(self.terminated_agents.intersection(self.get_my_agents()))

    def get_other_team_remains(self):
        """
        Get the remaining agents in the other team.
        :return: a list of agents of the other team that are not terminated
        """
        return [
            agent
            for agent in self.get_team_agents(self.get_other_team())
            if agent not in self.terminated_agents
        ]

    def get_random_agents(self, rate):
        """
        Return the first ``rate`` share of a shuffled list of the controlled
        team's agents. The shuffled order is created on first use and reused.
        :param rate: the rate of random agents to return
        :return: a list of random agents with the length of int(rate * num_agents)
        """
        num_agents = len(self.get_my_agents())
        if self.random_agents is None:
            self.random_agents = random.sample(self.get_my_agents(), num_agents)
        return self.random_agents[:int(num_agents * rate)]

    @staticmethod
    def merge_terminates_truncates(terminates, truncates):
        """
        Merge terminates and truncates into one dictionary.
        An agent missing from truncates is treated as not truncated.
        :param terminates: a dictionary with agent names as keys and boolean values as values
        :param truncates: a dictionary with agent names as keys and boolean values as values
        :return: a dictionary with the keys of terminates and boolean values as values
        """
        return {
            agent: terminates[agent] or truncates.get(agent, False)
            for agent in terminates
        }

# Training and evaluation code

In [6]:
# train

def train(q, q_target, memory, optimizer, gamma, batch_size, update_iter=10, chunk_size=10, grad_clip_norm=5):
    """Run ``update_iter`` VDN gradient updates on ``q`` from replayed chunks.

    For every step of a sampled chunk the per-agent Q-values of the taken
    actions are summed over agents and regressed (smooth L1) onto the summed
    team reward plus the discounted, summed maximum target Q-value. Agents
    whose ``done`` flag is set are masked out of both sums.

    :param q: online network exposing ``recurrent``, ``init_hidden(batch_size)``
    and ``__call__(obs, hidden) -> (q_values, hidden)``
    :param q_target: target network with the same interface (not updated here)
    :param memory: buffer exposing ``sample_chunk(batch_size, chunk_size)``
    :param optimizer: optimizer over ``q``'s parameters
    :param gamma: discount factor
    :param batch_size: number of chunks per update
    :param update_iter: number of gradient updates to perform
    :param chunk_size: steps per chunk; forced to 1 when ``q`` is not recurrent
    :param grad_clip_norm: maximum L2 norm of the gradients
    :return: list with the loss value of each update
    """
    q.train()
    q_target.eval()
    chunk_size = chunk_size if q.recurrent else 1
    losses = []

    scaler = GradScaler()

    for i in range(update_iter):
        # Get data from buffer
        states, actions, rewards, next_states, dones = memory.sample_chunk(batch_size, chunk_size)

        hidden = q.init_hidden(batch_size).to(device)
        target_hidden = q_target.init_hidden(batch_size).to(device)

        loss = 0
        for step_i in range(chunk_size):
            with autocast():
                q_out, hidden = q(states[:, step_i].to(device), hidden)  # [batch_size, num_agents, n_actions]
                q_out = q_out.to(device)
                hidden = hidden.to(device)
                q_a = q_out.gather(2, actions[:, step_i, :].unsqueeze(-1).long().to(device)).squeeze(-1)  # [batch_size, num_agents]: q values of actions taken
                sum_q = (q_a * (1 - dones[:, step_i].to(device))).sum(dim=1, keepdims=True)  # [batch_size, 1]

                with torch.no_grad():
                    max_q_prime, target_hidden = q_target(next_states[:, step_i].to(device), target_hidden.detach())
                    target_hidden = target_hidden.to(device)
                    max_q_prime = max_q_prime.max(dim=2)[0].squeeze(-1)  # [batch_size, num_agents]
                    target_q = rewards[:, step_i, :].to(device).sum(dim=1, keepdims=True)  # [batch_size, 1]
                    target_q += gamma * ((1 - dones[:, step_i].to(device)) * max_q_prime.to(device)).sum(dim=1, keepdims=True)

                loss += F.smooth_l1_loss(sum_q, target_q.detach())

                # Per-agent mask of finished agents at this step
                done_mask = dones[:, step_i].to(device).bool()  # [batch_size, num_agents]

                if done_mask.any():  # only do the work when some agent finished
                    # Reset the hidden state of finished agents to a fresh initial
                    # state. This is done out-of-place: `hidden` can share storage
                    # with a tensor autograd saved for backward, so writing into it
                    # would invalidate the graph.
                    reset = done_mask.unsqueeze(-1)  # [batch_size, num_agents, 1]
                    fresh_hidden = q.init_hidden(batch_size).to(device=device, dtype=hidden.dtype)
                    fresh_target_hidden = q_target.init_hidden(batch_size).to(device=device, dtype=target_hidden.dtype)
                    hidden = torch.where(reset, fresh_hidden, hidden)
                    target_hidden = torch.where(reset, fresh_target_hidden, target_hidden)

        losses.append(loss.item())
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale here; unscale them
        # first so the clipping threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(q.parameters(), grad_clip_norm, norm_type=2)
        scaler.step(optimizer)
        scaler.update()

    print('Loss: ' + " ".join([str(round(loss, 2)) for loss in losses]))
    return losses

In [7]:
# run_episode

def run_episode(env, q, opponent_q, memory=None, random_rate=0, epsilon=0.1):
    """Run an episode in self-play mode

    ``q`` controls the team returned by ``TeamManager.get_my_team()`` and
    ``opponent_q`` controls the other team. The episode stops when a whole
    team has terminated or when the other team has 3 or fewer agents left.

    :param env: parallel environment exposing ``reset()``, ``step(actions)`` and ``agents``
    :param q: network choosing the actions of the controlled team
    :param opponent_q: network choosing the actions of the other team
    :param memory: optional replay buffer; when given, the controlled team's
    transition of every step is stored with ``memory.put``
    :param random_rate: unused, kept for interface compatibility
    :param epsilon: exploration rate used for both teams
    :return: total score (sum of the controlled team's rewards) of the episode
    """
    observations, infos = env.reset()
    team_manager = TeamManager(env.agents)
    my_team = team_manager.get_my_team()
    opponent_team = team_manager.get_other_team()

    hidden = q.init_hidden()
    # The opponent's hidden state must come from the opponent's own network.
    opponent_hidden = opponent_q.init_hidden()
    score = 0.0

    while not team_manager.has_terminated_teams():
        # Fill rows with zeros for terminated agents
        for agent in team_manager.agents:
            if agent not in observations or observations[agent] is None:
                observations[agent] = np.zeros(q.n_obs, dtype=np.float32)
                team_manager.terminate_agent(agent)

        # Get observations for the current team and opponent
        my_team_observations = team_manager.get_info_of_team(my_team, observations)
        opponent_observations = team_manager.get_info_of_team(opponent_team, observations)

        # Action selection never needs gradients; no_grad avoids building an
        # autograd graph (which would otherwise grow through the hidden state).
        with torch.no_grad():
            # Get actions for my team
            obs_tensor = torch.tensor(np.array(list(my_team_observations.values()))).unsqueeze(0)
            actions, hidden = q.sample_action(obs_tensor, hidden, epsilon)

            # Get actions for the opponent team (self-play logic)
            opponent_obs_tensor = torch.tensor(np.array(list(opponent_observations.values()))).unsqueeze(0)
            opponent_actions, opponent_hidden = opponent_q.sample_action(opponent_obs_tensor, opponent_hidden, epsilon)

        my_team_actions = {
            agent: action
            for agent, action in zip(
                my_team_observations.keys(), actions.squeeze(0).cpu().data.numpy().tolist()
            )
        }
        opponent_team_actions = {
            agent: action
            for agent, action in zip(
                opponent_observations.keys(), opponent_actions.squeeze(0).cpu().data.numpy().tolist()
            )
        }

        # Combine actions
        agent_actions = {**my_team_actions, **opponent_team_actions}

        # Terminated agents use None action
        for agent in team_manager.terminated_agents:
            agent_actions[agent] = None

        # Step the environment
        observations, agent_rewards, agent_terminations, agent_truncations, agent_infos = env.step(agent_actions)
        score += sum(team_manager.get_info_of_team(my_team, agent_rewards, 0).values())

        if memory is not None:
            # Fill rows with zeros for terminated agents
            next_observations = [
                observations[agent]
                if agent in observations and observations[agent] is not None
                else np.zeros(q.n_obs, dtype=np.float32)
                for agent in team_manager.get_my_agents()
            ]
            # Terminated agents (None action) are stored with action 0
            my_team_actions = [
                agent_actions[agent]
                if agent in agent_actions and agent_actions[agent] is not None
                else 0
                for agent in team_manager.get_my_agents()
            ]

            memory.put((
                list(my_team_observations.values()),
                my_team_actions,
                list(team_manager.get_info_of_team(my_team, agent_rewards, 0).values()),
                next_observations,
                list(team_manager.get_info_of_team(
                    my_team,
                    TeamManager.merge_terminates_truncates(agent_terminations, agent_truncations)).values())
            ))

        # Check for termination
        for agent, done in agent_terminations.items():
            if done:
                team_manager.terminate_agent(agent)
        for agent, done in agent_truncations.items():
            if done:
                team_manager.terminate_agent(agent)

        # Stop early once the other team has 3 or fewer agents left
        if len(team_manager.get_other_team_remains()) <= 3:
            break

    print('Score:', score)
    return score

In [8]:
# eval

def evaluate_model(env, num_episodes, model1, model2, run_episode_fn):
    """
    Play ``num_episodes`` greedy (epsilon=0) episodes and average the score.

    :param env: Environment
    :param num_episodes: How many episodes to test
    :param model1: model passed as the first (scored) player to run_episode_fn
    :param model2: model passed as the opponent to run_episode_fn
    :param run_episode_fn: function to run an episode
    :return: average score over num_episodes
    """
    # Both networks are put into inference mode before playing.
    for net in (model1, model2):
        net.eval()
    total_score = sum(
        run_episode_fn(env, model1, model2, epsilon=0)
        for _ in range(num_episodes)
    )
    return total_score / num_episodes

In [9]:
# run

def run_model_train_test(
        env,
        test_env,
        model_team1,
        model_team2,
        target_model_team1,
        target_model_team2,
        save_name_team1,
        save_name_team2,
        team_manager,
        hp,
        train_fn,
        run_episode_fn,
        num_test_runs=1,
):
    """
    Run the training and testing loop for two models that play each other.

    Every episode collects one rollout per model (each into its own replay
    buffer), trains each model once its buffer exceeds ``hp.warm_up_steps``,
    periodically syncs the target networks, and during the last 20 episodes
    evaluates and checkpoints both models.

    :param env: Training environment
    :param test_env: Testing environment
    :param model_team1: training model of team 1
    :param model_team2: training model of team 2
    :param target_model_team1: target model of team 1
    :param target_model_team2: target model of team 2
    :param save_name_team1: name used in team 1's checkpoint files
    :param save_name_team2: name used in team 2's checkpoint files
    :param team_manager: TeamManager (unused here, kept for interface compatibility)
    :param hp: Hyperparameters; ``hp.min_epsilon`` is lowered to 0.05 in place
    once either training score exceeds 200
    :param train_fn: training function
    :param run_episode_fn: function to run an episode
    :param num_test_runs: number of evaluation runs averaged per test phase
    :return: train_scores_team1, train_scores_team2, test_scores_team1,
    test_scores_team2, losses_team1, losses_team2
    """
    reseed(seed)
    # One replay buffer per team
    memory_team1 = ReplayBuffer(hp.buffer_limit)
    memory_team2 = ReplayBuffer(hp.buffer_limit)

    test_env.reset(seed=seed)
    env.reset(seed=seed)

    # Start the target networks from the online networks' weights
    target_model_team1.load_state_dict(model_team1.state_dict())
    target_model_team2.load_state_dict(model_team2.state_dict())

    train_scores_team1, train_scores_team2 = [], []
    test_scores_team1, test_scores_team2 = [], []
    losses_team1, losses_team2 = [], []

    optimizer_team1 = optim.Adam(model_team1.parameters(), lr=hp.lr)
    optimizer_team2 = optim.Adam(model_team2.parameters(), lr=hp.lr)

    # Train and test
    start_train = time.time()
    for episode_i in range(hp.max_episodes):
        start = time.time()
        print(f'Episodes {episode_i + 1} / {hp.max_episodes}')
        # Linear epsilon decay from max_epsilon to min_epsilon over episode_min_epsilon episodes
        epsilon = max(hp.min_epsilon,
                      hp.max_epsilon - (hp.max_epsilon - hp.min_epsilon) * (episode_i / (hp.episode_min_epsilon)))

        # Collect data
        model_team1.eval()
        model_team2.eval()

        train_score_team1 = run_episode_fn(env, model_team1, model_team2, memory_team1, epsilon=epsilon)
        train_score_team2 = run_episode_fn(env, model_team2, model_team1, memory_team2, epsilon=epsilon)

        train_scores_team1.append(train_score_team1)
        train_scores_team2.append(train_score_team2)

        # Explore less once either model scores well
        if train_score_team1 > 200 or train_score_team2 > 200:
            hp.min_epsilon = 0.05

        # Train models
        if memory_team1.size() > hp.warm_up_steps:
            print("Training Team 1:")
            model_team1.train()
            episode_losses_team1 = train_fn(
                model_team1, target_model_team1, memory_team1, optimizer_team1,
                hp.gamma, hp.batch_size, hp.update_iter, hp.chunk_size
            )
            losses_team1.append(episode_losses_team1)

        if memory_team2.size() > hp.warm_up_steps:
            print("Training Team 2:")
            model_team2.train()
            episode_losses_team2 = train_fn(
                model_team2, target_model_team2, memory_team2, optimizer_team2,
                hp.gamma, hp.batch_size, hp.update_iter, hp.chunk_size
            )
            losses_team2.append(episode_losses_team2)

        # Periodically copy the online weights into the target networks
        if episode_i % hp.update_target_interval == 0 and episode_i > 0:
            target_model_team1.load_state_dict(model_team1.state_dict())
            target_model_team2.load_state_dict(model_team2.state_dict())

        # Test phase (last 20 episodes only)
        if episode_i >= hp.max_episodes - 20:
            print("Test phase for both teams:")
            model_team1.eval()
            model_team2.eval()

            avg_test_score_team1 = 0
            avg_test_score_team2 = 0

            for _ in range(num_test_runs):
                avg_test_score_team1 += evaluate_model(test_env, hp.test_episodes, model_team1, model_team2, run_episode_fn)
                avg_test_score_team2 += evaluate_model(test_env, hp.test_episodes, model_team2, model_team1,  run_episode_fn)

            avg_test_score_team1 /= num_test_runs
            avg_test_score_team2 /= num_test_runs

            test_scores_team1.append(avg_test_score_team1)
            test_scores_team2.append(avg_test_score_team2)

            # Each checkpoint stores the weights of its own team's model
            save_model(model_team1, f'vdn-{save_name_team1}-{episode_i}')
            save_model(model_team2, f'vdn-{save_name_team2}-{episode_i}')

            print(f"Team 1 Avg Test Score: {avg_test_score_team1:.2f}")
            print(f"Team 2 Avg Test Score: {avg_test_score_team2:.2f}")
            print('#' * 90)

        print(f'Time: {time.time() - start}')
        print(f'Total Time: {time.time() - start_train}')
        print('-' * 90)

    env.close()
    test_env.close()

    return train_scores_team1, train_scores_team2, test_scores_team1, test_scores_team2, losses_team1, losses_team2

# Model

In [10]:
# VDN

class VdnQNet(nn.Module):
    """Per-agent Q-network shared by all agents of one team.

    Every agent's observation goes through the same CNN encoder, an optional
    GRU cell, and a linear head producing one Q-value per action.
    """

    def __init__(self, agents, observation_spaces, action_spaces, recurrent=False):
        """
        :param agents: list of agent names controlled by this network
        :param observation_spaces: mapping agent name -> space with a ``shape``
        of (height, width, channels)
        :param action_spaces: mapping agent name -> space with ``n`` actions
        :param recurrent: when True, a GRU cell carries a hidden state between steps
        """
        super(VdnQNet, self).__init__()
        self.agents = agents
        self.num_agents = len(agents)
        self.recurrent = recurrent
        self.hx_size = 32   # latent repr size
        self.n_obs = observation_spaces[agents[0]].shape    # observation space size of agents
        self.n_act = action_spaces[agents[0]].n  # action space size of agents

        stride1, stride2 = 1, 1
        padding1, padding2 = 1, 1
        kernel_size1, kernel_size2 = 3, 3
        pool_kernel_size, pool_stride = 2, 2

        def encoded_dim(size):
            """Length of one spatial axis after both conv + max-pool stages."""
            after_stage1 = compute_output_dim(size, kernel_size1, stride1, padding1) // pool_stride
            return compute_output_dim(after_stage1, kernel_size2, stride2, padding2) // pool_stride

        # n_obs is a tuple (height, width, channels); each axis is tracked
        # separately so non-square observations get the right flattened size.
        out_height = encoded_dim(self.n_obs[0])
        out_width = encoded_dim(self.n_obs[1])

        # Compute the final flattened size
        flattened_size = out_height * out_width * 64
        self.feature_cnn = nn.Sequential(
            nn.Conv2d(self.n_obs[2], 32, kernel_size=kernel_size1, stride=stride1, padding=padding1),
            nn.MaxPool2d(kernel_size=pool_kernel_size, stride=pool_stride),
            nn.Conv2d(32, 64, kernel_size=kernel_size2, stride=stride2, padding=padding2),
            nn.MaxPool2d(kernel_size=pool_kernel_size, stride=pool_stride),
            nn.Flatten(),
            nn.Linear(flattened_size, self.hx_size),
            nn.ReLU()
        )
        if recurrent:
            self.gru =  nn.GRUCell(self.hx_size, self.hx_size)  # shape: hx_size, hx_size
        self.q_val = nn.Linear(self.hx_size, self.n_act)    # shape: hx_size, n_actions

    def forward(self, obs, hidden):
        """Predict q values for each agent's actions in the batch
        :param obs: [batch_size, num_agents, height, width, channels]
        :param hidden: [batch_size, num_agents, hx_size]
        :return: q_values: [batch_size, num_agents, n_actions], hidden: [batch_size, num_agents, hx_size]
        (for a non-recurrent network the incoming hidden state is returned as is)
        """
        obs = obs.to(device)
        hidden = hidden.to(device)

        batch_size, num_agents, height, width, channels = obs.shape
        # Channels-first layout for Conv2d, with batch and agent axes folded together
        obs = obs.permute(0, 1, 4, 2, 3)  # (batch_size, num_agents, channels, height, width)
        obs = obs.reshape(batch_size * num_agents, channels, height, width)

        x = self.feature_cnn(obs)  # (batch_size * num_agents, hx_size)

        if self.recurrent:
            hidden = hidden.reshape(batch_size * num_agents, -1)  # (batch_size * num_agents, hx_size)
            x = self.gru(x, hidden)  # (batch_size * num_agents, hx_size)

        q_values = self.q_val(x)  # (batch_size * num_agents, n_actions)
        q_values = q_values.view(batch_size, num_agents, -1)  # (batch_size, num_agents, n_actions)

        if self.recurrent:
            next_hidden = x.view(batch_size, num_agents, -1)  # (batch_size, num_agents, hx_size)
        else:
            next_hidden = hidden.view(batch_size, num_agents, -1)

        return q_values, next_hidden

    def sample_action(self, obs, hidden, epsilon=1e3):
        """Choose action with epsilon-greedy policy, for each agent in the batch

        The explore/exploit draw is made once per batch row, so all agents of
        a row either act randomly or greedily together. Actions are returned
        as a float tensor.

        :param obs: a batch of observations, [batch_size, num_agents, ...n_obs]
        :param hidden: a batch of hidden states, [batch_size, num_agents, hx_size]
        :param epsilon: exploration rate (the default of 1e3 always explores)
        :return: actions: [batch_size, num_agents], hidden: [batch_size, num_agents, hx_size]
        """
        obs = obs.to(device)
        hidden = hidden.to(device)

        q_values, hidden = self.forward(obs, hidden)    # [batch_size, num_agents, n_actions], [batch_size, num_agents, hx_size]
        # epsilon-greedy action selection: choose random action with epsilon probability
        mask = (torch.rand((q_values.shape[0],), device=device) <= epsilon)  # [batch_size]
        actions = torch.empty((q_values.shape[0], q_values.shape[1]), device=device)  # [batch_size, num_agents]
        actions[mask] = torch.randint(0, q_values.shape[2], actions[mask].shape, device=device).float()
        actions[~mask] = q_values[~mask].argmax(dim=2).float()  # choose action with max q value
        return actions, hidden   # [batch_size, num_agents], [batch_size, num_agents, hx_size]

    def init_hidden(self, batch_size=1):
        """Return an all-zero hidden state of shape [batch_size, num_agents, hx_size]."""
        return torch.zeros((batch_size, self.num_agents, self.hx_size), device=device)

# Training

In [11]:
# Training

# Names used in the checkpoint and metric file names of each team.
save_name_team1 = 'vdn_blue'
save_name_team2 = 'vdn_red'

# Hyperparameters
# NOTE(review): recurrent=True is set here, but the VdnQNet constructors below
# are called without a `recurrent` argument, so they use its default (False)
# and the trained models are feed-forward.
hp = VdnHyperparameters(
    lr=0.002,
    gamma=0.99,
    batch_size=512,
    buffer_limit=9000,
    max_episodes=200,
    max_epsilon=0.9,
    min_epsilon=0.1,
    episode_min_epsilon=100,
    test_episodes=1,
    warm_up_steps=3000,
    update_iter=20,
    chunk_size=1,
    update_target_interval=20,
    recurrent=True
)
print(hp)

# Create environment
env = battle_v4.parallel_env(map_size=45)
test_env = battle_v4.parallel_env(map_size=45)

# Reset first so that env.agents is available for grouping into teams.
env.reset(seed=seed)
test_env.reset(seed=seed)
team_manager = TeamManager(env.agents)

# Create models for two teams
# Team 1 uses TeamManager's "my" team (the second team in env.agents order),
# team 2 uses the other one; each team gets an online and a target network.
q_team1 = VdnQNet(team_manager.get_my_agents(), env.observation_spaces, env.action_spaces).to(device)
q_target_team1 = VdnQNet(team_manager.get_my_agents(), env.observation_spaces, env.action_spaces).to(device)

q_team2 = VdnQNet(team_manager.get_other_agents(), env.observation_spaces, env.action_spaces).to(device)
q_target_team2 = VdnQNet(team_manager.get_other_agents(), env.observation_spaces, env.action_spaces).to(device)

# Run training for both teams
train_scores_team1, train_scores_team2, test_scores_team1, test_scores_team2, losses_team1, losses_team2 = run_model_train_test(
    env, 
    test_env, 
    q_team1, q_team2, 
    q_target_team1, q_target_team2, 
    save_name_team1, save_name_team2, 
    team_manager, 
    hp, 
    train, 
    run_episode
)

# Save data for Team 1
save_data(np.array(train_scores_team1), f'{save_name_team1}-train_scores')
save_data(np.array(test_scores_team1), f'{save_name_team1}-test_scores')
save_data(np.array(losses_team1), f'{save_name_team1}-losses')

# Save data for Team 2
save_data(np.array(train_scores_team2), f'{save_name_team2}-train_scores')
save_data(np.array(test_scores_team2), f'{save_name_team2}-test_scores')
save_data(np.array(losses_team2), f'{save_name_team2}-losses')

VdnHyperparameters(lr=0.002, gamma=0.99, batch_size=512, update_iter=20, buffer_limit=9000, update_target_interval=20, max_episodes=200, max_epsilon=0.9, min_epsilon=0.1, episode_min_epsilon=100, test_episodes=1, warm_up_steps=3000, chunk_size=1, recurrent=True)
Episodes 1 / 200
Score: -3624.1451344182715
Score: -3543.2751333350316
Time: 10.778191804885864
Total Time: 10.778225660324097
------------------------------------------------------------------------------------------
Episodes 2 / 200
Score: -3597.900132276118
Score: -3566.365135463886
Time: 9.90833830833435
Total Time: 20.68661117553711
------------------------------------------------------------------------------------------
Episodes 3 / 200
Score: -3668.98513507843
Score: -3455.720130195841
Time: 9.734773874282837
Total Time: 30.42143416404724
------------------------------------------------------------------------------------------
Episodes 4 / 200
Score: -3627.135135267861
Score: -3421.2801308939233
Training Team 1:


  scaler = GradScaler()
  with autocast():


Loss: 9.86 9.91 9.99 9.79 7.1 4.23 4.06 2.02 1.32 1.7 1.57 1.22 1.14 0.82 0.89 0.83 0.62 0.66 0.6 0.52
Training Team 2:
Loss: 12.31 12.38 12.04 11.99 10.44 10.65 8.4 5.96 2.46 3.33 5.31 5.6 4.04 2.55 2.18 2.69 2.48 2.14 1.6 1.96
Time: 32.1305775642395
Total Time: 62.552058935165405
------------------------------------------------------------------------------------------
Episodes 5 / 200
Score: -3838.765141826123
Score: -2874.3001016406342
Training Team 1:
Loss: 1.0 0.79 1.03 1.15 0.94 0.63 0.9 1.03 0.72 0.68 0.82 0.55 0.68 0.83 0.6 0.45 0.52 0.47 0.47 0.48
Training Team 2:
Loss: 4.15 3.06 3.76 3.22 3.1 3.62 3.01 1.94 2.19 2.76 2.27 1.58 1.08 1.48 1.67 1.25 0.97 1.15 1.43 1.24
Time: 31.476871967315674
Total Time: 94.02897667884827
------------------------------------------------------------------------------------------
Episodes 6 / 200
Score: -2322.3150825891644
Score: -3348.7601206768304
Training Team 1:
Loss: 0.77 0.71 0.72 0.56 0.58 0.65 0.61 0.43 0.71 0.63 0.34 0.55 0.68 0.47 0.26