In [1]:
import os, sys

# Put the parent directory on the import path (once) so that packages living
# next to this notebook's folder can be imported by the cells below.
ruta_raiz = os.path.abspath("..")
already_importable = ruta_raiz in sys.path
if not already_importable:
    sys.path.append(ruta_raiz)

In [11]:
import numpy as np
import random
import torch
import torch.nn as nn
from collections import deque
from mpe2 import simple_tag_v3
import supersuit as ss
from NGUMultiAgent.NGUMulti import NGUMultiAgent
from NGU.DQN import DQN as NGU_DQN
import pandas as pd
import matplotlib.pyplot as plt

In [3]:
class FakeEnv:
    """Minimal stand-in object that only carries an observation and an action space.

    Attributes:
        observation_space: The observation space passed to the constructor.
        action_space: The action space passed to the constructor.
    """

    def __init__(self, observation_space, action_space):
        # Store both spaces unchanged; no validation or copying is performed.
        self.observation_space, self.action_space = observation_space, action_space

In [6]:
# Seed handed to the environment reset in the training cell.
seed = 42

# Keyword arguments forwarded unchanged to every NGU_DQN agent constructor.
base_params = {
    "learning_rate": 0.001,
    "buffer_size": 1000000,
    "learning_starts": 5000,
    "batch_size": 128,
    "tau": 1.0,
    "gamma": 0.99,
    "train_freq": 16,
    "gradient_steps": 4,
    "target_update_interval": 2000,
    "exploration_fraction": 0.1,
    "exploration_initial_eps": 1,
    "exploration_final_eps": 0.1,
    "max_grad_norm": 10,
    "verbose": 0,
    "beta": 0.1,
}

In [None]:
def run_agent(filename, total_timesteps=200_000):
    """Train the adversaries of a `simple_tag_v3` scenario and save their rewards.

    Creates a parallel environment with one good agent and three adversaries,
    builds one `NGU_DQN` learner per adversary from the module-level
    `base_params`, shares the replay buffer of ``adversary_0`` through the
    trainer, trains with `NGUMultiAgent` and finally calls
    ``trainer.save_rewards_to_csv(filename)``.

    Args:
        filename: Destination path handed to ``save_rewards_to_csv``. Its
            parent directory is created beforehand if it does not exist.
        total_timesteps: Training budget forwarded to `NGUMultiAgent`.
            Defaults to 200_000, the value that was previously hard-coded.
    """
    # Create the output folder before training so a missing directory is
    # detected immediately rather than after the whole (long) training run.
    output_dir = os.path.dirname(filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    env = simple_tag_v3.parallel_env(
        render_mode='rgb_array',
        num_good=1,
        num_adversaries=3,
        num_obstacles=0,
        max_cycles=150,
        dynamic_rescaling=True,
    )
    try:
        # `env.agents` is read after the reset, as in the original code.
        env.reset(seed=seed)
        agents = env.agents

        # Only agents whose name starts with "adversary" get a learner.
        agent_dict = {
            agent: NGU_DQN(FakeEnv(env.observation_space(agent), env.action_space(agent)), **base_params)
            for agent in agents if agent.startswith("adversary")
        }

        trainer = NGUMultiAgent(env, agent_dict, total_timesteps=total_timesteps, log_interval=100)
        trainer.share_replay_buffer("adversary_0")  # Share the replay buffer of the first adversary agent
        trainer.learn()
        trainer.save_rewards_to_csv(filename)
        # Optional post-training steps (must run while the environment is still open):
        #trainer.plot_total_rewards()
        #trainer.evaluate(episodes=20)
        #trainer.render_and_save(num_tests=5, save_path="demo_shared_def.mp4", fps=5, max_steps=50)
    finally:
        # Release the environment's resources even if training or saving fails.
        env.close()

In [29]:
# Launch two training runs, each writing its rewards to its own CSV file.
for i in (1, 2):
    run_agent(f"runs/shared_buffer_run_{i}.csv")

In [None]:
agent = "adversary_0"

# Overlay the reward curve of `agent` from every saved run on a single figure.
for i in (1, 2):
    df = pd.read_csv(f"runs/shared_buffer_run_{i}.csv")
    episodes, rewards = df["episode"], df[agent]
    plt.plot(episodes, rewards, label=f"Run {i}")

# Figure decoration; the legend is built from the labels set in the loop above.
plt.title(f"Recompensas de {agent} en múltiples runs")
plt.xlabel("Episodio")
plt.ylabel("Reward total")
plt.grid(True)
plt.legend()
plt.show()