In [1]:
import os
import io
import datetime
from typing import Optional, Tuple
from functools import partial

import gym
import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy, RainbowPolicy
from tianshou.trainer import offpolicy_trainer
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger

#from torchviz import make_dot

from TaskAllocation.RL_Policies.Custom_Classes import CustomNet
from TaskAllocation.RL_Policies.Custom_Classes import CustomCollector
from TaskAllocation.RL_Policies.Custom_Classes import CustomParallelToAECWrapper

#from CustomClass_multi_head import CustomNet
from TaskAllocation.RL_Policies.Custom_Classes_simplified import CustomNetSimple
#from Custom_Classes_simplified import CustomCollectorSimple
#from Custom_Classes_simplified import CustomParallelToAECWrapperSimple

from TaskAllocation.RL_Policies.CustomClasses_Transformer_Reduced import CustomNetReduced
from TaskAllocation.RL_Policies.CustomClass_MultiHead_Transformer import CustomNetMultiHead

from mUAV_TA.MultiDroneEnvUtils import agentEnvOptions

from mUAV_TA.DroneEnv import MultiUAVEnv
#from tianshou_DQN import train
# ---- Experiment configuration ------------------------------------------------
# Network architecture: "CustomNet", "CustomNetSimple", "CustomNetReduced" or "CustomNetMultiHead".
model = "CustomNetMultiHead"
test_num = "_TBTA_NOV06_Emb128"   # run tag appended to the model name
policyModel = "DQN"               # "DQN" or "Rainbow"

train_env_num = 10
test_env_num = 20

name = model + test_num

load_policy_name = 'policy_CustomNetMultiHead_TBTA_NOV06_Emb128_rew2k.pth'
save_policy_name = f'policy_{name}'   # file suffix (".pth" etc.) is appended when saving
policy_path = "dqn_Custom"

same_policy = True   # True: every agent shares one policy; False: "F*" agents get a separate one

load_model = False   # True: initialise the learned policy from `load_policy_name`

# Episode length in environment steps; several DQN/trainer step counts below are multiples of it.
episode_len = 150

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = name + str(now)

log_path = os.path.join('./', "Logs", "dqn", log_name)

dqn_params = {"discount_factor": 0.98,
              "estimation_step": episode_len,
              "target_update_freq": episode_len * 30,
              "optimizer": "Adam",   # informational (logged in Run_Data); the optimizer is built as Adam
              "lr": 5e-5}

trainer_params = {"max_epoch": 500,
                  "step_per_epoch": episode_len * 50,
                  "step_per_collect": episode_len * 10,
                  "episode_per_test": 20,
                  "batch_size": 1500,
                  # Gradient updates per collected env step; the updates for a collect run after it finishes.
                  "update_per_step": 1 / 300,
                  "tn_eps_max": 0.80,   # training epsilon at epoch 0
                  "ts_eps_max": 0.0,    # epsilon used while testing
                  }

config_default = agentEnvOptions(
                 render_speed=-1,
                 simulation_frame_rate = 0.01,
                 action_mode="TaskAssign",
                 simulator_module = "Internal",
                 max_time_steps=episode_len,
                 agents= {"F1" : 0, "F2" : 1, "R1" : 0, "R2" : 0},
                 tasks= { "Att" : 10 , "Rec" : 0, "Hold" : 0},
                 multiple_tasks_per_agent = False,
                 multiple_agents_per_task = True,
                 random_init_pos=False,
                 num_obstacles=0,
                 hidden_obstacles = False,
                 fail_rate = 0.0,
                 threats_list = [],  # e.g. [("T1", 4), ("T2", 2)]
                 fixed_seed = -1,
                 info = "No Info")

# Human-readable run summary; printed below and written to TensorBoard as "Config".
# Trailing double spaces are markdown line breaks for the TensorBoard text plugin.
Run_Data = f'''{name}  
Loaded_Model: {load_policy_name if load_model else "no"}  
log_path: {log_path}  
train/test_env_num: {train_env_num} / {test_env_num}  
model: {model}  
dqn_params: {dqn_params}  
trainer_params: {trainer_params} 
single_policy: {same_policy}

--------- Env ------------  

Rewards Only Final Quality and SQuality
F_Rew / 20 > lre
random_init_pos      : {config_default.random_init_pos}
max_time_steps       : {config_default.max_time_steps}
simulation_frame_rate: {config_default.simulation_frame_rate}
Agents               : {config_default.agents}
tasks                : {config_default.tasks}
threats              : {config_default.threats_list}
seed                 : {config_default.fixed_seed}
'''

model_load_path = os.path.join(policy_path, load_policy_name)
model_save_path = os.path.join(policy_path, save_policy_name)
os.makedirs(policy_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)

def generate_dummy_observation(batch_size=1, sequence_length=31, feature_dim=12):
    """Return a standard-normal tensor shaped (batch_size, sequence_length, feature_dim).

    Intended as a placeholder model input, e.g. for graph visualisation.
    """
    shape = (batch_size, sequence_length, feature_dim)
    return torch.randn(*shape)

def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    policy_load_path = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    """Build a MultiAgentPolicyManager over the agents of a fresh training env.

    Args:
        agent_learn: Pre-built policy to use for every agent. When None, a network
            (selected by the module-level `model`) and policy (selected by
            `policyModel`) are built from scratch.
        agent_opponent: Unused; kept for signature compatibility.
        optim: Optimizer for the first learned network. When None (and
            `agent_learn` is None), Adam over that network's parameters is built.
        policy_load_path: Checkpoint to load into the learned policy. When None,
            `model_load_path` is loaded, but only if `load_model` is True.

    Returns:
        (policy manager, optimizer of the first learned network - None if
        `agent_learn` was supplied without `optim`, list of agent ids).

    Raises:
        ValueError: if `model` or `policyModel` names an unsupported option.
    """
    env = _get_env()
    agent_name = env.agents[0]  # action shape is read from the first agent

    # Agent feature size = sum of the (1-D) agent-related observation components.
    agent_observation_space = env.observation_space
    state_shape_agent = sum(
        agent_observation_space[key].shape[0]
        for key in ("agent_position", "agent_state", "agent_type",
                    "next_free_time", "position_after_last_task")
    )

    # Hard-coded task-feature size; presumably 31 task slots x 13 features - TODO confirm
    # against env.observation_space["tasks_info"].
    state_shape_task = 31 * 13

    action_shape = env.action_space[agent_name].shape[0]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    net_classes = {
        "CustomNet": CustomNet,
        "CustomNetSimple": CustomNetSimple,
        "CustomNetReduced": CustomNetReduced,
        "CustomNetMultiHead": CustomNetMultiHead,
    }

    if agent_learn is None:
        if model not in net_classes:
            raise ValueError(f"Unknown model {model!r}; expected one of {sorted(net_classes)}")
        if policyModel not in ("DQN", "Rainbow"):
            raise ValueError(f"Unknown policyModel {policyModel!r}; expected 'DQN' or 'Rainbow'")

        def build_net():
            # Fresh network of the configured architecture, placed on `device`.
            return net_classes[model](
                state_shape_agent=state_shape_agent,
                state_shape_task=state_shape_task,
                action_shape=action_shape,
                hidden_sizes=[128, 128],
                device=device,
            ).to(device)

        def build_optim(net):
            # Each optimizer must own the parameters of the network its policy trains.
            return torch.optim.Adam(net.parameters(), lr=dqn_params["lr"],
                                    weight_decay=0.0, amsgrad=True)

        def build_policy(net, net_optim):
            # Wrap `net` in the configured tianshou policy.
            if policyModel == "DQN":
                return DQNPolicy(
                    model=net,
                    optim=net_optim,
                    discount_factor=dqn_params["discount_factor"],
                    estimation_step=dqn_params["estimation_step"],
                    target_update_freq=dqn_params["target_update_freq"],
                    reward_normalization=False,
                )
            return RainbowPolicy(
                model=net,
                optim=net_optim,
                num_atoms=31,
                discount_factor=dqn_params["discount_factor"],
                estimation_step=dqn_params["estimation_step"],
                target_update_freq=dqn_params["target_update_freq"],
            )

        net = build_net()
        if optim is None:
            optim = build_optim(net)
        agent_learn = build_policy(net, optim)

        # An explicit policy_load_path always loads; otherwise honour the load_model flag.
        load_path = policy_load_path if policy_load_path is not None else model_load_path
        if load_model or policy_load_path is not None:
            agent_learn.load_state_dict(torch.load(load_path, map_location=device))
            print(f'Loaded-> {load_path}')

        if same_policy:
            agents = [agent_learn for _ in env.agents]
        else:
            # Agents whose id starts with "F" share a second, independently trained policy
            # with its own network AND its own optimizer over that network.
            net2 = build_net()
            agent_learn2 = build_policy(net2, build_optim(net2))
            agents = [agent_learn2 if agent_id[0] == "F" else agent_learn
                      for agent_id in env.agents]
    else:
        # Caller-supplied policy is shared by every agent.
        agents = [agent_learn for _ in env.agents]

    policy = MultiAgentPolicyManager(agents, env)

    return policy, optim, env.agents


def _get_env():
    """Create one training environment instance.

    DummyVectorEnv needs a zero-argument callable per worker, so this factory is
    passed by reference. A MultiUAVEnv configured by `config_default` is wrapped
    with CustomParallelToAECWrapper and then with tianshou's PettingZooEnv.
    """
    parallel_env = MultiUAVEnv(config=config_default)
    aec_env = CustomParallelToAECWrapper(parallel_env)
    wrapped_env = PettingZooEnv(aec_env)
    return wrapped_env


print(Run_Data)


CustomNetMultiHead_TBTA_NOV06_Emb128  
Loaded_Model: no  
log_path: ./Logs\dqn\CustomNetMultiHead_TBTA_NOV06_Emb128231108-171143  
train/test_env_num: 10 / 20  
model: CustomNetMultiHead  
dqn_params: {'discount_factor': 0.98, 'estimation_step': 150, 'target_update_freq': 4500, 'optminizer': 'Adam', 'lr': 5e-05}  
trainer_params: {'max_epoch': 500, 'step_per_epoch': 7500, 'step_per_collect': 1500, 'episode_per_test': 20, 'batch_size': 1500, 'update_per_step': 0.0033333333333333335, 'tn_eps_max': 0.8, 'ts_eps_max': 0.0} 
single_policy: True

--------- Env ------------  

Rewards Only Final Quality and SQuality
F_Rew / 20 > lre
random_init_pos      : False
max_time_steps       : 150
simulation_frame_rate: 0.01
Agents               : {'F1': 0, 'F2': 1, 'R1': 0, 'R2': 0}
tasks                : {'Att': 10, 'Rec': 0, 'Hold': 0}
random_init_pos      : False 
threats              : []
seed                 : -1



In [2]:
if __name__ == "__main__":
    # Train the multi-agent policy with tianshou's off-policy trainer.
    # (Inside a notebook kernel __name__ is "__main__", so this cell always runs.)

    torch.set_grad_enabled(True)

    # ======== Step 1: Environment setup =========
    train_envs = DummyVectorEnv([_get_env for _ in range(train_env_num)])
    test_envs = DummyVectorEnv([_get_env for _ in range(test_env_num)])

    # seed
    seed = 0
    np.random.seed(seed)
    torch.manual_seed(seed)
    train_envs.seed(seed)
    test_envs.seed(seed)

    # ======== Step 2: Agent setup =========
    policy, optim, agents = _get_agents()

    def policies_by_group(manager):
        """Return {checkpoint_suffix: policy} for each distinct learnable policy.

        Follows the split rule used in `_get_agents`: with `same_policy` every agent
        uses the policy of `agents[0]` (suffix ""); otherwise agents whose id starts
        with "F" form one group (suffix "F") and all others another (suffix "R").
        Groups that have no agent in the env are omitted, so no KeyError is raised.
        """
        if same_policy:
            return {"": manager.policies[agents[0]]}
        groups = {}
        for agent_id in agents:
            groups.setdefault("F" if agent_id[0] == "F" else "R", manager.policies[agent_id])
        return groups

    def set_eps_all(manager, epsilon):
        """Set the exploration epsilon on every distinct learnable policy."""
        for group_policy in policies_by_group(manager).values():
            group_policy.set_eps(epsilon)

    # ======== Step 3: Collector setup =========
    train_collector = CustomCollector(
        policy,
        train_envs,
        PrioritizedVectorReplayBuffer(300_000, len(train_envs), alpha=0.6, beta=0.4),
        exploration_noise=True,
    )
    test_collector = CustomCollector(policy, test_envs, exploration_noise=False)

    # Pre-fill the replay buffer with warmup_rounds * train_env_num episodes before training.
    print("Buffer Warming Up ")
    warmup_rounds = 10
    for _ in range(warmup_rounds):
        train_collector.collect(n_episode=train_env_num)
        print(".", end="")

    # Buffer size divided by the episode length (max_time_steps).
    print("\nBuffer Length: ", len(train_collector.buffer) / config_default.max_time_steps)

    # ======== Step 4: Callback functions setup =========
    global_step_holder = [0]  # mutable counter shared by reward_metric and the diagnostics below

    def save_best_fn(policy):
        """Checkpoint each distinct policy to <model_save_path><suffix>.pth."""
        for suffix, group_policy in policies_by_group(policy).items():
            torch.save(group_policy.state_dict(), model_save_path + suffix + ".pth")
        print("Best Saved" if same_policy else "Bests Saved")

    def stop_fn(mean_rewards):
        return mean_rewards >= 9939.0

    def train_fn(epoch, env_step):
        # Linear decay from tn_eps_max (epoch 0) to tn_eps_max / 100 (epoch == max_epoch).
        eps_max = trainer_params['tn_eps_max']
        epsilon = eps_max - (eps_max - eps_max / 100) * (epoch / trainer_params['max_epoch'])
        set_eps_all(policy, epsilon)

    def test_fn(epoch, env_step):
        set_eps_all(policy, trainer_params['ts_eps_max'])

    def reward_metric(rews):
        # Count calls; used as the TensorBoard step by the optional diagnostics below.
        global_step_holder[0] += 1
        # Score episodes by the first reward column only - assumes rews is (n_episodes, n_agents).
        return rews[:, 0]

    # ---- Optional diagnostics (not registered by default) ----
    # Example: some_module.register_forward_hook(partial(register_activation_hook,
    #     layer_name="...", writer=writer, global_step_holder=global_step_holder))
    def register_activation_hook(module, input, output, layer_name, writer, global_step_holder):
        """Forward hook: log a histogram of `output` (its first element if it is a tuple)."""
        if isinstance(output, tuple):
            output = output[0]
        writer.add_histogram(f"activations/{layer_name}", output, global_step_holder[0])

    def log_gradients(policy, writer, global_step_holder, **kwargs):
        """Log a histogram of each parameter gradient of `policy.model` (skips params with no grad)."""
        for param_name, param in policy.model.named_parameters():
            if param.grad is not None:
                writer.add_histogram(f"{param_name}.grad", param.grad, global_step_holder[0])

    # ======== tensorboard logging setup =========
    writer = SummaryWriter(log_path)
    try:
        writer.add_text("Config", str(Run_Data))
        for suffix, group_policy in policies_by_group(policy).items():
            writer.add_text(f"Model{suffix}", str(group_policy.model).replace('\n', '  \n'))

        logger = TensorboardLogger(writer)

        # ======== Step 5: Run the trainer =========
        result = offpolicy_trainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=trainer_params['max_epoch'],
            step_per_epoch=trainer_params['step_per_epoch'],
            step_per_collect=trainer_params['step_per_collect'],
            episode_per_test=trainer_params['episode_per_test'],
            batch_size=trainer_params['batch_size'],
            train_fn=train_fn,
            test_fn=test_fn,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            update_per_step=trainer_params['update_per_step'],
            logger=logger,
            test_in_train=False,
            reward_metric=reward_metric,
            show_progress=True,
        )
    finally:
        # Flush and close the event file even if training raises or is interrupted.
        writer.close()

    print(f"\n==========Result==========\n{result}")
    print("\n(the trained policy can be accessed via policy.policies[agents[0]])")



Buffer Warming Up 
..........
Buffer Lenght:  100.0
Best Saved


Epoch #1: 7501it [01:38, 75.96it/s, F2_agent0/loss=37317.528, env_step=7500, len=150, n/ep=10, n/st=1500, rew=475.10]                           


Best Saved
Epoch #1: test_reward: 236.555568 ± 358.367328, best_reward: 236.555568 ± 358.367328 in #1


Epoch #2: 7501it [01:42, 73.19it/s, F2_agent0/loss=43241.855, env_step=15000, len=150, n/ep=10, n/st=1500, rew=721.97]                           


Best Saved
Epoch #2: test_reward: 306.295364 ± 324.438667, best_reward: 306.295364 ± 324.438667 in #2


Epoch #3: 7501it [01:49, 68.38it/s, F2_agent0/loss=49275.385, env_step=22500, len=150, n/ep=10, n/st=1500, rew=629.67]                           


Best Saved
Epoch #3: test_reward: 814.425608 ± 551.271626, best_reward: 814.425608 ± 551.271626 in #3


Epoch #4: 7501it [01:55, 65.19it/s, F2_agent0/loss=54264.221, env_step=30000, len=150, n/ep=10, n/st=1500, rew=722.94]                           


Best Saved
Epoch #4: test_reward: 1197.753282 ± 262.788948, best_reward: 1197.753282 ± 262.788948 in #4


Epoch #5: 7501it [01:50, 68.17it/s, F2_agent0/loss=63276.336, env_step=37500, len=150, n/ep=10, n/st=1500, rew=770.79]                           


Epoch #5: test_reward: 1121.004961 ± 422.622615, best_reward: 1197.753282 ± 262.788948 in #4


Epoch #6:  40%|####      | 3000/7500 [00:24<00:41, 108.10it/s, F2_agent0/loss=64745.086, env_step=39000, len=150, n/ep=10, n/st=1500, rew=908.97]

In [None]:
from typing import Optional, Tuple
import os
import numpy as np
import torch
from tianshou.env import DummyVectorEnv
from tianshou.trainer import offpolicy_trainer
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger
import torch

import mUAV_TA.MultiDroneEnvUtils as utils
#from Custom_Classes import CustomCollector

def _get_env_eval(config=None):
    """Create one evaluation environment instance (factory for DummyVectorEnv).

    Args:
        config: Optional agentEnvOptions. When None, the default evaluation scenario
            below is used: agents F1=4, F2=4, R1=4; tasks Rec=16, Hold=4; threats
            T1 and T2.

    Returns:
        A PettingZooEnv wrapping CustomParallelToAECWrapper(MultiUAVEnv(config)).
    """
    if config is None:
        # Local name avoids shadowing the module-level training `config_default`.
        config = agentEnvOptions(
                 render_speed=-1,
                 simulation_frame_rate = 0.01,
                 action_mode="TaskAssign",
                 simulator_module = "Internal",
                 max_time_steps=150,
                 agents= {"F1" : 4, "F2" : 4, "R1" : 4},
                 tasks= { "Att" : 0 , "Rec" : 16, "Hold" : 4},
                 multiple_tasks_per_agent = False,
                 multiple_agents_per_task = True,
                 random_init_pos=False,
                 num_obstacles=0,
                 hidden_obstacles = False,
                 fail_rate = 0.0,
                 threats_list = [("T1", 4), ("T2" , 2)],
                 fixed_seed = -1,
                 info = "No Info")

    env_parallel = MultiUAVEnv(config=config)
    env = CustomParallelToAECWrapper(env_parallel)

    return PettingZooEnv(env)


# Evaluate the trained (shared) policy on the scenario defined by `_get_env_eval`.
# Build a policy with the training architecture, then load the checkpoint
# written by save_best_fn (shared-policy case: "<model_save_path>.pth").
policy, optim, train_agents = _get_agents()
model_save_path = os.path.join("dqn_Custom", save_policy_name)

policy_test = policy.policies[train_agents[0]]
policy_test.load_state_dict(torch.load(model_save_path + ".pth", map_location="cpu"))
policy_test.set_eps(0.00)

# The evaluation scenario may use a different agent roster than training, so key the
# manager by the evaluation env's own agent ids (all sharing the loaded policy).
# Otherwise, agents missing from the training manager would receive no policy action.
eval_reference_env = _get_env_eval()
eval_policy = MultiAgentPolicyManager(
    [policy_test for _ in eval_reference_env.agents], eval_reference_env
)
eval_policy.eval()  # disable train-only layers (e.g. dropout) for evaluation

envs = DummyVectorEnv([_get_env_eval for _ in range(1)])

# eps is 0, so exploration noise has no effect for the DQN policy.
collector = CustomCollector(eval_policy, envs, exploration_noise=True)

results = collector.collect(n_episode=10)
results

In [None]:
import matplotlib.pyplot as plt

# Summarise evaluation returns, excluding episodes at or below REWARD_FLOOR
# (presumably failed/invalid episodes - TODO confirm the threshold).
REWARD_FLOOR = -10
valid_rews = results['rews'][results['rews'] > REWARD_FLOOR]

if valid_rews.size:
    print(np.mean(valid_rews))
else:
    print("No episodes above the reward floor")

fig, ax = plt.subplots()
ax.hist(valid_rews, bins=100)
ax.set(title="Evaluation episode returns", xlabel="Episode return", ylabel="Count")
plt.show()


In [None]:
import torch
from tianshou.data import Batch

# Manual rollout of the trained shared policy, stepping the environment by hand.
policy, optim, train_agents = _get_agents()
model_save_path = os.path.join("dqn_Custom", save_policy_name)
policy_test = policy.policies[train_agents[0]]
# save_best_fn writes "<model_save_path>.pth" for the shared-policy case.
state_saved = torch.load(model_save_path + ".pth", map_location="cpu")
policy_test.load_state_dict(state_saved)
policy_test.eval()
policy_test.set_eps(0.00)

# NOTE(review): this loop was written against an older `MultiDroneEnv` API, which is not
# defined in this notebook. Confirm MultiUAVEnv provides reset(seed=...), step(dict),
# agent_selector._current_agent and n_agents as used below.
env = MultiUAVEnv(config=config_default)
env.max_time_steps = 50

n_episodes = 10
for episode in range(n_episodes):
    obs, _ = env.reset(seed=episode)  # a different seed per episode

    done = {0: False}
    truncations = {0: False}
    episode_reward = 0

    while not all(done.values()) and not all(truncations.values()):
        agent_id = "agent" + str(env.agent_selector._current_agent)
        # Single-observation batch; empty info dict per observation.
        obs_batch = Batch(obs=[obs[agent_id]], info=[{}])
        action = policy_test(obs_batch).act

        obs, reward, done, truncations, info = env.step({agent_id: action[0]})
        episode_reward += sum(reward.values()) / env.n_agents

    print(episode_reward)
