In [1]:
import os
import datetime
from typing import Optional, Tuple

import gym
import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy
from tianshou.trainer import offpolicy_trainer
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger

from TaskAllocation.RL_Policies.Custom_Classes import CustomNet
from TaskAllocation.RL_Policies.Custom_Classes import CustomCollector
from TaskAllocation.RL_Policies.Custom_Classes import CustomParallelToAECWrapper

#from CustomClass_multi_head import CustomNet
from TaskAllocation.RL_Policies.Custom_Classes_simplified import CustomNetSimple
#from Custom_Classes_simplified import CustomCollectorSimple
#from Custom_Classes_simplified import CustomParallelToAECWrapperSimple

from TaskAllocation.RL_Policies.CustomClasses_Transformer_Reduced import CustomNetReduced
from TaskAllocation.RL_Policies.CustomClass_MultiHead_Transformer import CustomNetMultiHead
import importlib

from mUAV_TA.DroneEnv import MultiUAVEnv
#from tianshou_DQN import train
# ======== Run configuration =========
# Network architecture to train. Must be one of:
# "CustomNet", "CustomNetSimple", "CustomNetReduced", "CustomNetMultiHead".
model = "CustomNetMultiHead"
# Free-form experiment tag appended to the model name.
test_num = "_Eval_TBTA_Relative_Representation_01_"

# Number of environments in the training / test vector environments.
train_env_num = 10
test_env_num = 10

# Experiment name; used for the checkpoint file name and the log directory.
name = model + test_num

# Checkpoint file names, both resolved relative to `policy_path`.
load_policy_name = 'policy_CustomNetMultiHead_Eval_TBTA_02_simplified_UCF1_new_rew_updR.pth'
save_policy_name = f'policy_{name}.pth'
policy_path = "dqn_Custom"
# When True, the weights stored in `load_policy_name` are restored before training.
load_model = False

# log
# A timestamp suffix keeps every run in its own log directory.
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = name + now

log_path = os.path.join('./', "Logs", "dqn", log_name)

# DQN hyper-parameters. "optimizer" is recorded for the run summary only;
# the optimizer itself is constructed in _get_agents().
dqn_params = {"discount_factor": 0.98,
              "estimation_step": 50,
              "target_update_freq": 100,
              "optimizer": "Adam",
              "lr": 5e-5}

# Arguments forwarded to the off-policy trainer, plus the epsilon settings
# read by the train/test callbacks ("tn_eps_max" / "ts_eps_max").
# NOTE(review): 300 in "step_per_epoch" and "update_per_step" presumably
# matches the episode length of the environment -- confirm.
trainer_params = {"max_epoch": 1000,
                  "step_per_epoch": 300 * train_env_num,
                  "step_per_collect": 50,
                  "episode_per_test": 1,
                  "batch_size": 128,
                  "update_per_step": 1/300,
                  "tn_eps_max": 0.85,
                  "ts_eps_max": 0.0,
                  }

# Human-readable run summary; printed below and attached to the TensorBoard log.
# The trailing backslashes are line continuations inside the string literal, so
# the indentation of the continued lines is part of the text.
Run_Data = f'{name}\n\
        Loaded_Model: {load_policy_name if load_model else "no"} \n\
        log_path: {log_path} \n\
        train/test_env_num: {train_env_num} / {test_env_num} \n\
        model: {model} \n\
        dqn_params: {dqn_params} \n\
        trainer_params: {trainer_params} \n\
        obs: Task Info -> 12 info \
            Agents_info -> no \
            Scene:  agents= F1:2, F2:2, R1:6, R2:0,     \
                 tasks= Att: 4 , Rec:20, Threats:4'

# Resolve checkpoint locations and make sure the output directories exist.
model_load_path = os.path.join(policy_path, load_policy_name)
model_save_path = os.path.join(policy_path, save_policy_name)
os.makedirs(policy_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)


def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    policy_load_path = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    """Build the multi-agent policy manager used for training and evaluation.

    Every agent of the environment is driven by the *same* policy object, so
    updating or configuring one entry of the manager affects all agents.

    Args:
        agent_learn: Policy to share between all agents. When ``None`` a new
            ``DQNPolicy`` is built around the network selected by the
            module-level ``model`` setting.
        agent_opponent: Unused; kept for signature compatibility.
        optim: Optimizer for the newly built network. When ``None`` an Adam
            optimizer with ``dqn_params["lr"]`` is created. Ignored when
            ``agent_learn`` is supplied.
        policy_load_path: Optional checkpoint to restore into a newly built
            policy. When ``None``, ``model_load_path`` is restored if the
            module-level ``load_model`` flag is set. Ignored when
            ``agent_learn`` is supplied.

    Returns:
        ``(policy, optim, agents)``: the ``MultiAgentPolicyManager``, the
        optimizer (``None`` if ``agent_learn`` was supplied without one) and
        the list of agent names taken from the environment.

    Raises:
        ValueError: if ``model`` does not name a known network class.
    """
    env = _get_env()
    agent_name = env.agents[0]  # Get the name of the first agent

    # Size of the per-agent part of the observation: the sum of the leading
    # dimensions of the agent-related observation entries.
    agent_observation_space = env.observation_space
    agent_obs_keys = (
        "agent_position",
        "agent_state",
        "agent_type",
        "next_free_time",
        "position_after_last_task",
        # "agent_relay_area",
    )
    state_shape_agent = sum(
        agent_observation_space[key].shape[0] for key in agent_obs_keys
    )

    # NOTE(review): hard-coded task-block size; presumably
    # (number of task slots) x (features per task) -- confirm against
    # env.observation_space["tasks_info"].shape[0].
    state_shape_task = 31 * 11

    action_shape = env.action_space[agent_name].shape[0]
    #action_shape = env.action_space[agent_name].n

    if agent_learn is None:
        # model
        net_classes = {
            "CustomNet": CustomNet,
            "CustomNetSimple": CustomNetSimple,
            "CustomNetReduced": CustomNetReduced,
            "CustomNetMultiHead": CustomNetMultiHead,
        }
        if model not in net_classes:
            # Fail early with a clear message instead of a NameError on `net`.
            raise ValueError(
                f"Unknown model {model!r}; expected one of {sorted(net_classes)}"
            )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        net = net_classes[model](
            state_shape_agent=state_shape_agent,
            state_shape_task=state_shape_task,
            action_shape=action_shape,
            hidden_sizes=[128, 128],
            device=device,
        ).to(device)

        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=dqn_params["lr"])

        agent_learn = DQNPolicy(
            model=net,
            optim=optim,
            discount_factor=dqn_params["discount_factor"],
            estimation_step=dqn_params["estimation_step"],
            target_update_freq=dqn_params["target_update_freq"],
        )

        # An explicit path wins; otherwise fall back to the notebook-level
        # `load_model` / `model_load_path` settings.
        if policy_load_path is not None:
            checkpoint_path = policy_load_path
        elif load_model:
            checkpoint_path = model_load_path
        else:
            checkpoint_path = None

        if checkpoint_path is not None:
            # map_location lets a checkpoint saved on GPU be restored on a
            # CPU-only machine (and vice versa).
            agent_learn.load_state_dict(
                torch.load(checkpoint_path, map_location=device)
            )
            print(f'Loaded-> {checkpoint_path}')

    # Built outside the branch above so that a caller-supplied `agent_learn`
    # is also wrapped (previously `agents` was undefined in that case).
    agents = [agent_learn for _ in range(len(env.agents))]

    policy = MultiAgentPolicyManager(agents, env)

    return policy, optim, env.agents


def _get_env():
    """Create one training environment instance.

    Used as the factory callable handed to ``DummyVectorEnv``: builds a
    ``MultiUAVEnv``, adapts it with ``CustomParallelToAECWrapper`` and wraps
    the result in Tianshou's ``PettingZooEnv``.
    """
    aec_env = CustomParallelToAECWrapper(MultiUAVEnv())
    return PettingZooEnv(aec_env)

# Echo the run summary so the settings of this run are recorded in the cell output.
print(Run_Data)

CustomNetMultiHead_Eval_TBTA_Relative_Representation_01_
        Loaded_Model: no 
        log_path: ./Logs\dqn\CustomNetMultiHead_Eval_TBTA_Relative_Representation_01_231007-115545 
        train/test_env_num: 10 / 10 
        model: CustomNetMultiHead 
        dqn_params: {'discount_factor': 0.98, 'estimation_step': 50, 'target_update_freq': 100, 'optminizer': 'Adam', 'lr': 5e-05} 
        trainer_params: {'max_epoch': 1000, 'step_per_epoch': 3000, 'step_per_collect': 50, 'episode_per_test': 1, 'batch_size': 128, 'update_per_step': 0.0033333333333333335, 'tn_eps_max': 0.85, 'ts_eps_max': 0.0} 
        obs: Task Info -> 12 info             Agents_info -> no             Scene:  agents= F1:2, F2:2, R1:6, R2:0,                      tasks= Att: 4 , Rec:20, Threats:4


In [2]:
if __name__ == "__main__":

    # Make sure autograd is on, in case an earlier cell disabled it.
    torch.set_grad_enabled(True)

    # ======== Step 1: Environment setup =========
    train_envs = DummyVectorEnv([_get_env for _ in range(train_env_num)])
    test_envs = DummyVectorEnv([_get_env for _ in range(test_env_num)])

    # seed
    seed = 1
    np.random.seed(seed)
    torch.manual_seed(seed)
    train_envs.seed(seed)
    test_envs.seed(seed)

    # ======== Step 2: Agent setup =========
    policy, optim, agents = _get_agents()

    # ======== Step 3: Collector setup =========
    train_collector = CustomCollector(
        policy,
        train_envs,
        #VectorReplayBuffer(100_000, len(train_envs)),
        PrioritizedVectorReplayBuffer(100_000, len(train_envs), alpha=0.6, beta=0.4),
        exploration_noise=True,
    )
    test_collector = CustomCollector(policy, test_envs, exploration_noise=True)

    # Run two episodes on the test environments before training starts.
    # NOTE(review): presumably a smoke test of the env/policy wiring -- confirm.
    test_collector.collect(n_episode=2)

    # ======== tensorboard logging setup =========
    writer = SummaryWriter(log_path)
    writer.add_text(name, str(Run_Data))
    logger = TensorboardLogger(writer)

    # Mutable call counter incremented by reward_metric(); nothing in this
    # cell reads it (it was the step index for optional debugging hooks).
    global_step_holder = [0]

    # ======== Step 4: Callback functions setup =========
    def save_best_fn(policy):
        """Persist the first agent's policy weights to `model_save_path`."""
        print("Best Saved")
        torch.save(policy.policies[agents[0]].state_dict(), model_save_path)

    def stop_fn(mean_rewards):
        """Stop training once the mean test reward reaches the target.

        NOTE(review): origin of the 9939.0 threshold is not visible here.
        """
        return mean_rewards >= 9939.0

    def train_fn(epoch, env_step):
        """Linearly decay the training epsilon from eps_max to eps_max/100.

        The policy stored under agents[0] is the object shared by every agent
        (see _get_agents), so setting it once is sufficient.
        """
        eps_max = trainer_params['tn_eps_max']
        eps_min = eps_max / 100
        epsilon = eps_max - (eps_max - eps_min) * (epoch / trainer_params['max_epoch'])
        policy.policies[agents[0]].set_eps(epsilon)

    def test_fn(epoch, env_step):
        """Use the fixed evaluation epsilon while testing."""
        epsilon = trainer_params['ts_eps_max']
        policy.policies[agents[0]].set_eps(epsilon)

    def reward_metric(rews):
        """Collapse the reward array handed over by the trainer to its mean."""
        global_step_holder[0] += 1
        return rews.mean()

    # ======== Step 5: Run the trainer =========
    try:
        result = offpolicy_trainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=trainer_params['max_epoch'],
            step_per_epoch=trainer_params['step_per_epoch'],
            step_per_collect=trainer_params['step_per_collect'],
            episode_per_test=trainer_params['episode_per_test'],
            batch_size=trainer_params['batch_size'],
            train_fn=train_fn,
            test_fn=test_fn,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            update_per_step=trainer_params['update_per_step'],
            logger=logger,
            test_in_train=False,
            reward_metric=reward_metric,
            show_progress=True,
        )
    finally:
        # Close the event file even when training is interrupted or raises,
        # so the TensorBoard log of a partial run is flushed to disk.
        writer.close()

    # return result, policy.policies[agents[1]]
    print(f"\n==========Result==========\n{result}")
    print("\n(the trained policy can be accessed via policy.policies[agents[0]])")

  tasks_info = torch.tensor(task_values, dtype=torch.float32).to(self.device)  # Convert tasks_info to tensor


Best Saved


Epoch #1: 3001it [00:15, 188.15it/s, env_step=3000, len=300, n/ep=10, n/st=50, rew=266.88]                          


Epoch #1: test_reward: 196.650000 ± 0.000000, best_reward: 280.800000 ± 0.000000 in #0


Epoch #2: 3001it [00:14, 203.91it/s, env_step=6000, len=300, n/ep=10, n/st=50, rew=255.84]                          


Epoch #2: test_reward: 264.050000 ± 0.000000, best_reward: 280.800000 ± 0.000000 in #0


Epoch #3: 3001it [00:14, 202.28it/s, env_step=9000, len=300, n/ep=10, n/st=50, rew=228.44]                          


Epoch #3: test_reward: 187.950000 ± 0.000000, best_reward: 280.800000 ± 0.000000 in #0


Epoch #4: 3001it [00:14, 206.13it/s, env_step=12000, len=300, n/ep=10, n/st=50, rew=242.75]                          


Epoch #4: test_reward: 206.300000 ± 0.000000, best_reward: 280.800000 ± 0.000000 in #0


Epoch #5: 3001it [00:13, 222.14it/s, env_step=15000, len=300, n/ep=10, n/st=50, rew=227.66]                          


Best Saved
Epoch #5: test_reward: 289.100000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #6: 3001it [00:14, 213.84it/s, env_step=18000, len=300, n/ep=10, n/st=50, rew=218.98]                          


Epoch #6: test_reward: 240.750000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #7: 3001it [00:13, 219.67it/s, env_step=21000, len=300, n/ep=10, n/st=50, rew=229.80]                          


Epoch #7: test_reward: 192.350000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #8: 3001it [00:14, 211.43it/s, env_step=24000, len=300, n/ep=10, n/st=50, rew=230.83]                          


Epoch #8: test_reward: 220.650000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #9: 3001it [00:13, 224.85it/s, env_step=27000, len=300, n/ep=10, n/st=50, rew=233.51]                          


Epoch #9: test_reward: 258.700000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #10: 3001it [00:14, 210.45it/s, env_step=30000, len=300, n/ep=10, n/st=50, rew=227.38]                          


Epoch #10: test_reward: 263.800000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #11: 3001it [00:13, 217.93it/s, env_step=33000, len=300, n/ep=10, n/st=50, rew=225.02]                          


Epoch #11: test_reward: 264.300000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #12: 3001it [00:14, 205.44it/s, env_step=36000, len=300, n/ep=10, n/st=50, rew=233.53]                          


Epoch #12: test_reward: 163.500000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #13: 3001it [00:14, 213.86it/s, env_step=39000, len=300, n/ep=10, n/st=50, rew=218.05]                          


Epoch #13: test_reward: 171.300000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #14: 3001it [00:14, 203.15it/s, env_step=42000, len=300, n/ep=10, n/st=50, rew=216.34]                          


Epoch #14: test_reward: 176.950000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #15: 3001it [00:13, 219.26it/s, env_step=45000, len=300, n/ep=10, n/st=50, rew=210.64]                          


Epoch #15: test_reward: 285.000000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #16: 3001it [00:12, 231.73it/s, env_step=48000, len=300, n/ep=10, n/st=50, rew=232.84]                          


Epoch #16: test_reward: 276.400000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #17: 3001it [00:12, 236.41it/s, env_step=51000, len=300, n/ep=10, n/st=50, rew=223.85]                          


Epoch #17: test_reward: 286.700000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #18: 3001it [00:12, 233.70it/s, env_step=54000, len=300, n/ep=10, n/st=50, rew=224.74]                          


Epoch #18: test_reward: 211.350000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #19: 3001it [00:12, 236.58it/s, env_step=57000, len=300, n/ep=10, n/st=50, rew=207.27]                          


Epoch #19: test_reward: 260.400000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #20: 3001it [00:12, 235.70it/s, env_step=60000, len=300, n/ep=10, n/st=50, rew=232.19]                          


Epoch #20: test_reward: 255.450000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #21: 3001it [00:12, 237.09it/s, env_step=63000, len=300, n/ep=10, n/st=50, rew=237.53]                          


Epoch #21: test_reward: 189.950000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #22: 3001it [00:12, 234.12it/s, env_step=66000, len=300, n/ep=10, n/st=50, rew=219.05]                          


Epoch #22: test_reward: 242.650000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #23: 3001it [00:13, 216.57it/s, env_step=69000, len=300, n/ep=10, n/st=50, rew=210.03]                          


Epoch #23: test_reward: 244.350000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #24: 3001it [00:13, 225.50it/s, env_step=72000, len=300, n/ep=10, n/st=50, rew=220.18]                          


Epoch #24: test_reward: 237.850000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #25: 3001it [00:13, 219.24it/s, env_step=75000, len=300, n/ep=10, n/st=50, rew=202.28]                          


Epoch #25: test_reward: 164.900000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #26: 3001it [00:13, 227.06it/s, env_step=78000, len=300, n/ep=10, n/st=50, rew=219.67]                          


Epoch #26: test_reward: 200.050000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #27: 3001it [00:14, 214.35it/s, env_step=81000, len=300, n/ep=10, n/st=50, rew=211.48]                          


Epoch #27: test_reward: 241.700000 ± 0.000000, best_reward: 289.100000 ± 0.000000 in #5


Epoch #28: 3001it [00:13, 219.90it/s, env_step=84000, len=300, n/ep=10, n/st=50, rew=226.74]                          


Best Saved
Epoch #28: test_reward: 294.550000 ± 0.000000, best_reward: 294.550000 ± 0.000000 in #28


Epoch #29: 3001it [00:13, 222.41it/s, env_step=87000, len=300, n/ep=10, n/st=50, rew=210.50]                          


Epoch #29: test_reward: 250.250000 ± 0.000000, best_reward: 294.550000 ± 0.000000 in #28


Epoch #30:  73%|#######3  | 2200/3000 [00:09<00:03, 235.31it/s, env_step=89200, len=300, n/ep=0, n/st=50, rew=210.50]

In [None]:
from typing import Optional, Tuple
import os
import numpy as np
import torch
from tianshou.env import DummyVectorEnv
from tianshou.trainer import offpolicy_trainer
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger
import torch

import mUAV_TA.MultiDroneEnvUtils as utils
#from Custom_Classes import CustomCollector

def _get_env_eval():
    """Create one evaluation environment instance.

    Used as the factory callable handed to ``DummyVectorEnv`` for evaluation.
    Returns a ``PettingZooEnv`` wrapping a ``MultiUAVEnv`` adapted by
    ``CustomParallelToAECWrapper``.
    """
    case = {'case': 0, 'F1': 2, 'F2': 2, "R1": 12, 'R2': 3, "Att": 4, "Rec": 22}

    # NOTE(review): `config` is built but never handed to the environment, so
    # the environment below runs with its default options. Passing it was
    # probably intended -- confirm the MultiUAVEnv constructor signature (and
    # that the trained network's input sizes fit this scenario) before wiring
    # it in.
    config = utils.DroneEnvOptions(
            render_mode = 'human',
            render_speed = -1,
            simulation_frame_rate = 1 / 50,
            simulator_module = "Internal",
            max_time_steps = 300,
            action_mode= "TaskAssign",
            agents= {"F1" : case['F1'], "F2" : case['F2'], "R1" : case['R1']},
            tasks= { "Att" : case['Att'], "Rec" : case['Rec']},
            random_init_pos = True,
            num_obstacles = 0,
            hidden_obstacles = False,
            fail_rate = 0,
            info = "No Info" )

    # `MultiDroneEnv` is not imported anywhere in this notebook; the class
    # imported from mUAV_TA.DroneEnv (and used by _get_env) is MultiUAVEnv.
    env_paralell = MultiUAVEnv()
    #env = parallel_to_aec_wrapper(env_paralell)
    env = CustomParallelToAECWrapper(env_paralell)

    return PettingZooEnv(env)


# Create a new instance of the policy with the same architecture as the saved policy.
# This cell deliberately does not reassign `name`, `load_policy_name` or
# `log_path`: they belong to the training configuration, were not used here,
# and overwriting them would silently change a later re-run of the training cell.
policy, optim, _ = _get_agents()
model_save_path = os.path.join("dqn_Custom", save_policy_name)

# Load the saved checkpoint (the file written by save_best_fn during training).
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
policy_test = policy.policies['agent0']
policy_test.load_state_dict(
    torch.load(
        model_save_path,
        map_location="cuda" if torch.cuda.is_available() else "cpu",
    )
)

envs = DummyVectorEnv([_get_env_eval for _ in range(1)])
# Evaluation mode with epsilon 0: no random exploration actions.
policy_test.eval()
policy_test.set_eps(0.00)

#collector = CustomCollector(policy.policies['agent0'], envs, exploration_noise=True)
#collector = CustomCollector(policy_test, envs, exploration_noise=False)
collector = CustomCollector(policy, envs, exploration_noise=True)

#results = collector.collect(n_episode=1)
results = collector.collect(n_episode=5)#, gym_reset_kwargs={'seed' :2})
results

In [None]:
import matplotlib.pyplot as plt

# Episodes whose reward is not above this value are left out of the summary.
# NOTE(review): presumably -10 marks failed/invalid episodes -- confirm.
REWARD_FLOOR = -10

episode_rewards = results['rews']
valid_rewards = episode_rewards[episode_rewards > REWARD_FLOOR]

if valid_rewards.size == 0:
    # Avoid a nan mean (and an empty plot) when every episode was filtered out.
    print(f"No episode reward above {REWARD_FLOOR}; nothing to summarise.")
else:
    print(np.mean(valid_rewards))

    # Histogram of the retained episode rewards.
    fig, ax = plt.subplots()
    ax.hist(valid_rewards, bins=100)
    ax.set(
        title="Evaluation episode rewards",
        xlabel="episode reward",
        ylabel="number of episodes",
    )
    plt.show()


In [None]:
from turtle import st
import torch
from tianshou.data import Batch

# NOTE(review): the `from turtle import st` import at the top of this cell is
# not used by the code below and pulls in tkinter -- it looks like an
# accidental editor auto-import.

# Rebuild the policy and restore the checkpoint written during training.
policy, optim, _ = _get_agents()
model_save_path = os.path.join("dqn_Custom", save_policy_name)
policy_test = policy.policies['agent0']
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
state_saved = torch.load(
    model_save_path,
    map_location="cuda" if torch.cuda.is_available() else "cpu",
)
policy_test.load_state_dict(state_saved)
# Evaluation mode with epsilon 0: no random exploration actions.
policy_test.eval()
policy_test.set_eps(0.00)

# Initialise the raw (unwrapped) environment.
# `MultiDroneEnv` is not imported anywhere in this notebook; the class imported
# from mUAV_TA.DroneEnv (and used by _get_env) is MultiUAVEnv.
#env = DummyVectorEnv([_get_env for _ in range(1)])
env = MultiUAVEnv()
env.max_time_steps = 50

# Simulate the interaction with the environment manually: one rollout per
# seed. (Previously the episode index was always 0, so all ten rollouts used
# the same seed.)
n_episodes = 10
for episode in range(n_episodes):

    #env.render_speed = 1
    obs, _ = env.reset(seed=episode)
    # Return value unused; the call is kept in case it has side effects -- TODO confirm.
    env.get_initial_state()

    # Placeholder flags so the loop condition holds before the first step.
    done = {0: False}
    truncations = {0: False}

    episodo_reward = 0

    while not all(done.values()) and not all(truncations.values()):

        agent_id = "agent" + str(env.agent_selector._current_agent)
        # Single-observation batch; an empty info dict is added for the observation.
        obs_batch = Batch(obs=[obs[agent_id]], info=[{}])

        # Forward the batch through the policy to get the action.
        # Inference only, so no autograd graph is needed.
        with torch.no_grad():
            action = policy_test(obs_batch).act
        action = {agent_id: action[0]}

        obs, reward, done, truncations, info = env.step(action)

        # Accumulate the per-agent mean of the rewards returned by this step.
        episodo_reward += sum(reward.values()) / env.n_agents

    print(episodo_reward)
