In [1]:
import os
from typing import Optional, Tuple

import gym
import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy
from tianshou.trainer import offpolicy_trainer
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger


from Custom_Classes import CustomNet
from Custom_Classes import CustomCollector
from Custom_Classes import CustomParallelToAECWrapper

#from CustomClass_multi_head import CustomNet
from Custom_Classes_simplified import CustomNetSimple
#from Custom_Classes_simplified import CustomCollectorSimple
#from Custom_Classes_simplified import CustomParallelToAECWrapperSimple

from CustomClasses_Transformer_Reduced import CustomNetReduced
import importlib

from DroneEnv import MultiDroneEnv
from tianshou_DQN import train


model = "CustomNetSimple" # "CustomNet" or "CustomNetSimple" or "CustomNetReduced"
test_num = "Eval_TBTA_01_max30agents_timeRew"

train_env_num = 10
test_env_num = 10

name = model + test_num

load_policy_name = f'policy_CustomNetSimple1605_01_1_Priorized_1605_01_1_Priorized.pth'
save_policy_name = f'policy_{name}.pth'
policy_path = "dqn_Custom"
load_model = False

log_path = os.path.join('./', "Logs", "dqn", name)

dqn_params = {"discount_factor": 0.99, 
              "estimation_step": 1, 
              "target_update_freq": 100,
              "optminizer": "Adam",
              "lr": 1e-4  }

trainer_params = {"max_epoch": 200,
                  "step_per_epoch": 600 * train_env_num,
                  "step_per_collect": 50 * train_env_num,
                  "episode_per_test": 10 * test_env_num,
                  "batch_size" : 32,
                  "update_per_step": 0.1,
                  "tn_eps_max": 0.8,
                  "ts_eps_max": 0.001,
                  }

Run_Data = f'{name}\n\
        Loaded_Model: {load_policy_name if load_model == True else "no"} \n\
        log_path: {log_path} \n\
        train/test_env_num: {train_env_num} / {test_env_num} \n\
        model: {model} \n\
        dqn_params: {dqn_params} \n\
        trainer_params: {trainer_params} \n\
        obs: Task Info -> Dist | Quality for each drone \
            agents= F1:2, R1:4 | tasks= Rec:15, Att:5'

model_load_path = os.path.join(policy_path, load_policy_name)  
model_save_path = os.path.join(policy_path, save_policy_name)        
os.makedirs(os.path.join(policy_path), exist_ok=True)  
os.makedirs(os.path.join(log_path), exist_ok=True)


def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    policy_load_path = None
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    
    env = _get_env()
    agent_name = env.agents[0]  # Get the name of the first agent
           
    agent_observation_space = env.observation_space('agent0') # assuming 'agent0' is a valid agent name
    state_shape_agent_position = agent_observation_space["agent_position"].shape[0]
    state_shape_agent_state = agent_observation_space["agent_state"].shape[0]
    state_shape_agent_type = agent_observation_space["agent_type"].shape[0]
    state_shape_next_free_time = agent_observation_space["next_free_time"].shape[0]
    state_shape_position_after_last_task = agent_observation_space["position_after_last_task"].shape[0]       
    #state_shape_agent_relay_area = agent_observation_space["agent_relay_area"].shape[0]
        
    state_shape_agent = (state_shape_agent_position + state_shape_agent_state +
                     state_shape_agent_type+ state_shape_next_free_time + state_shape_position_after_last_task #+                     
                     #state_shape_agent_relay_area
                     )                 
    

    state_shape_task = env.observation_space('agent0')["tasks_info"].shape[0]
                  
    action_shape = env.action_space[agent_name].shape[0]
    #action_shape = env.action_space[agent_name].n
               
    if agent_learn is None:
        # model
        if model == "CustomNet":        
            net = CustomNet(
                state_shape_agent=state_shape_agent,
                state_shape_task=state_shape_task,
                action_shape=action_shape,
                hidden_sizes=[128,128],
                device="cuda" if torch.cuda.is_available() else "cpu",
            ).to("cuda" if torch.cuda.is_available() else "cpu")
        
        if model == "CustomNetSimple":
            net = CustomNetSimple(            
                state_shape_agent=state_shape_agent,
                state_shape_task=state_shape_task,
                action_shape=action_shape,
                hidden_sizes=[128,128],
                device="cuda" if torch.cuda.is_available() else "cpu",
            ).to("cuda" if torch.cuda.is_available() else "cpu")
        
        if model == "CustomNetReduced":
            net = CustomNetReduced(            
                state_shape_agent=state_shape_agent,
                state_shape_task=state_shape_task,
                action_shape=action_shape,
                hidden_sizes=[128,128],
                device="cuda" if torch.cuda.is_available() else "cpu",
            ).to("cuda" if torch.cuda.is_available() else "cpu")

    
        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=dqn_params["lr"])
    
        agent_learn = DQNPolicy(
            model=net,
            optim=optim,
            discount_factor= dqn_params["discount_factor"],
            estimation_step=dqn_params["estimation_step"],
            target_update_freq=dqn_params["target_update_freq"],
        )  
        
        if load_model == True:
            # Load the saved checkpoint             
            agent_learn.load_state_dict(torch.load(model_load_path))
            print(f'Loaded-> {model_load_path}')
            
        
        agents = [agent_learn for _ in range(len(env.agents))]
        
    policy = MultiAgentPolicyManager(agents, env)    
        
    return policy, optim, env.agents


def _get_env():
    """This function is needed to provide callables for DummyVectorEnv."""
    env_paralell = MultiDroneEnv()
    #env = parallel_to_aec_wrapper(env_paralell)    
    env = CustomParallelToAECWrapper(env_paralell)
    
    return PettingZooEnv(env)

print(Run_Data)

  from .autonotebook import tqdm as notebook_tqdm


CustomNetSimpleEval_TBTA_01_max30agents_timeRew
        Loaded_Model: no 
        log_path: ./Logs\dqn\CustomNetSimpleEval_TBTA_01_max30agents_timeRew 
        train/test_env_num: 10 / 10 
        model: CustomNetSimple 
        dqn_params: {'discount_factor': 0.99, 'estimation_step': 1, 'target_update_freq': 100, 'optminizer': 'Adam', 'lr': 0.0001} 
        trainer_params: {'max_epoch': 200, 'step_per_epoch': 6000, 'step_per_collect': 500, 'episode_per_test': 100, 'batch_size': 32, 'update_per_step': 0.1, 'tn_eps_max': 0.8, 'ts_eps_max': 0.001} 
        obs: Task Info -> Dist | Quality for each drone             agents= F1:2, R1:4 | tasks= Rec:15, Att:5


In [2]:
if __name__ == "__main__":
                        
    torch.set_grad_enabled(True) 
    # ======== Step 1: Environment setup =========
    train_envs = DummyVectorEnv([_get_env for _ in range(train_env_num)])
    test_envs = DummyVectorEnv([_get_env for _ in range(test_env_num)]) 

    # seed
    seed = 1
    np.random.seed(seed)
    
    torch.manual_seed(seed)
    train_envs.seed(seed)
    test_envs.seed(seed)

    # ======== Step 2: Agent setup =========
    policy, optim, agents = _get_agents()
    

    # ======== Step 3: Collector setup =========
    train_collector = CustomCollector(
        policy,
        train_envs,
        #VectorReplayBuffer(100_000, len(train_envs)),
        PrioritizedVectorReplayBuffer( 100_000, len(train_envs), alpha=0.6, beta=0.4) ,       
        exploration_noise=True        
    )
    test_collector = CustomCollector(policy, test_envs, exploration_noise=True)
     
    train_collector.collect(n_step=trainer_params['batch_size'] * train_env_num)
    #test_collector.collect(n_step=trainer_params['batch size'] * train_env_num)
    
    # ======== tensorboard logging setup =========
    #         
    writer = SummaryWriter(log_path)
    writer.add_text(name, str(Run_Data))
    logger = TensorboardLogger(writer)
        
    # ======== Step 4: Callback functions setup =========
    def save_best_fn(policy):                
        torch.save(policy.policies[agents[0]].state_dict(), model_save_path)

    def stop_fn(mean_rewards):
        return mean_rewards >= 9939.0

    def train_fn(epoch, env_step):
        epsilon = trainer_params['tn_eps_max'] - (trainer_params['tn_eps_max'] - trainer_params['tn_eps_max']/100)*(epoch/trainer_params['max_epoch'])  
        policy.policies[agents[0]].set_eps(epsilon)

    def test_fn(epoch, env_step):
        epsilon = trainer_params['ts_eps_max']#0.01#max(0.001, 0.1 - epoch * 0.001)
        policy.policies[agents[0]].set_eps(epsilon)
        
    def reward_metric(rews):       
        #print(rews)
        return rews.mean()#[:,0]
                           
    # ======== Step 5: Run the trainer =========
    result = offpolicy_trainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,        
        max_epoch=trainer_params['max_epoch'],
        step_per_epoch=trainer_params['step_per_epoch'],
        step_per_collect=trainer_params['step_per_collect'],
        episode_per_test= trainer_params['episode_per_test'],
        batch_size=trainer_params['batch_size'],
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        update_per_step=trainer_params['update_per_step'],
        logger=logger,
        test_in_train=False,
        reward_metric=reward_metric,
        show_progress = True
        )

    # return result, policy.policies[agents[1]]
    print(f"\n==========Result==========\n{result}")
    print("\n(the trained policy can be accessed via policy.policies[agents[0]])")



.


Epoch #1: 6001it [00:57, 104.27it/s, agent0/loss=44.005, agent1/loss=48.470, agent2/loss=54.151, agent3/loss=52.552, agent4/loss=51.525, agent5/loss=51.641, agent6/loss=53.283, agent7/loss=53.875, env_step=6000, len=20, n/ep=20, n/st=500, rew=-61.19]                          


Epoch #1: test_reward: -122.731227 ± 0.000000, best_reward: -122.731227 ± 0.000000 in #1


Epoch #2: 6001it [00:56, 106.81it/s, agent0/loss=45.230, agent1/loss=46.569, agent2/loss=52.473, agent3/loss=53.451, agent4/loss=52.020, agent5/loss=50.733, agent6/loss=46.244, agent7/loss=48.068, env_step=12000, len=20, n/ep=20, n/st=500, rew=-65.33]                          


Epoch #2: test_reward: -122.568748 ± 0.000000, best_reward: -122.568748 ± 0.000000 in #2


Epoch #3: 6001it [00:52, 113.61it/s, agent0/loss=49.411, agent1/loss=55.326, agent2/loss=56.932, agent3/loss=57.767, agent4/loss=55.754, agent5/loss=55.356, agent6/loss=52.357, agent7/loss=58.205, env_step=18000, len=20, n/ep=20, n/st=500, rew=-65.92]                          


Epoch #3: test_reward: -111.153582 ± 0.000000, best_reward: -111.153582 ± 0.000000 in #3


Epoch #4: 6001it [00:53, 112.15it/s, agent0/loss=49.535, agent1/loss=44.727, agent2/loss=52.089, agent3/loss=53.266, agent4/loss=50.469, agent5/loss=54.346, agent6/loss=50.432, agent7/loss=51.138, env_step=24000, len=20, n/ep=20, n/st=500, rew=-58.80]                          


Epoch #4: test_reward: -113.329863 ± 0.000000, best_reward: -111.153582 ± 0.000000 in #3


Epoch #5: 6001it [00:53, 112.01it/s, agent0/loss=47.641, agent1/loss=50.662, agent2/loss=54.256, agent3/loss=58.042, agent4/loss=51.461, agent5/loss=52.365, agent6/loss=50.186, agent7/loss=51.987, env_step=30000, len=20, n/ep=20, n/st=500, rew=-44.82]                          


Epoch #5: test_reward: -99.981940 ± 0.000000, best_reward: -99.981940 ± 0.000000 in #5


Epoch #6: 6001it [00:53, 112.80it/s, agent0/loss=44.687, agent1/loss=48.225, agent2/loss=53.519, agent3/loss=58.533, agent4/loss=48.159, agent5/loss=50.163, agent6/loss=50.327, agent7/loss=51.512, env_step=36000, len=20, n/ep=20, n/st=500, rew=-63.36]                          


Epoch #6: test_reward: -102.386741 ± 0.000000, best_reward: -99.981940 ± 0.000000 in #5


Epoch #7: 6001it [00:56, 106.83it/s, agent0/loss=44.952, agent1/loss=47.633, agent2/loss=52.411, agent3/loss=55.221, agent4/loss=52.466, agent5/loss=48.846, agent6/loss=51.925, agent7/loss=49.230, env_step=42000, len=20, n/ep=20, n/st=500, rew=-57.92]                          


Epoch #7: test_reward: -122.514011 ± 0.000000, best_reward: -99.981940 ± 0.000000 in #5


Epoch #8: 6001it [01:04, 92.40it/s, agent0/loss=41.841, agent1/loss=48.439, agent2/loss=56.073, agent3/loss=55.842, agent4/loss=46.269, agent5/loss=51.161, agent6/loss=54.383, agent7/loss=49.591, env_step=48000, len=20, n/ep=20, n/st=500, rew=-65.95]                           


Epoch #8: test_reward: -122.355592 ± 0.000000, best_reward: -99.981940 ± 0.000000 in #5


Epoch #9: 6001it [00:58, 102.72it/s, agent0/loss=40.942, agent1/loss=50.257, agent2/loss=56.668, agent3/loss=52.830, agent4/loss=51.225, agent5/loss=51.018, agent6/loss=53.929, agent7/loss=52.039, env_step=54000, len=20, n/ep=20, n/st=500, rew=-59.65]                          


Epoch #9: test_reward: -111.803740 ± 0.000000, best_reward: -99.981940 ± 0.000000 in #5


Epoch #10: 6001it [00:54, 111.07it/s, agent0/loss=46.319, agent1/loss=44.039, agent2/loss=55.065, agent3/loss=57.394, agent4/loss=53.123, agent5/loss=49.781, agent6/loss=57.363, agent7/loss=50.822, env_step=60000, len=20, n/ep=20, n/st=500, rew=-52.14]                          


Epoch #10: test_reward: -100.524584 ± 0.000000, best_reward: -99.981940 ± 0.000000 in #5


Epoch #11: 6001it [00:54, 109.48it/s, agent0/loss=47.614, agent1/loss=46.677, agent2/loss=54.228, agent3/loss=57.915, agent4/loss=55.204, agent5/loss=52.375, agent6/loss=53.106, agent7/loss=52.966, env_step=66000, len=20, n/ep=20, n/st=500, rew=-50.62]                          


Epoch #11: test_reward: -102.659393 ± 0.000000, best_reward: -99.981940 ± 0.000000 in #5


Epoch #12: 6001it [00:56, 106.55it/s, agent0/loss=46.481, agent1/loss=43.388, agent2/loss=52.687, agent3/loss=54.851, agent4/loss=52.549, agent5/loss=57.449, agent6/loss=53.116, agent7/loss=49.814, env_step=72000, len=20, n/ep=20, n/st=500, rew=-49.22]                          


Epoch #12: test_reward: -112.109918 ± 0.000000, best_reward: -99.981940 ± 0.000000 in #5


Epoch #13: 6001it [00:56, 106.05it/s, agent0/loss=45.269, agent1/loss=46.166, agent2/loss=50.160, agent3/loss=56.463, agent4/loss=51.226, agent5/loss=46.671, agent6/loss=48.983, agent7/loss=50.193, env_step=78000, len=20, n/ep=20, n/st=500, rew=-49.97]                          


Epoch #13: test_reward: -91.159605 ± 0.000000, best_reward: -91.159605 ± 0.000000 in #13


Epoch #14: 6001it [00:56, 106.69it/s, agent0/loss=45.597, agent1/loss=46.099, agent2/loss=52.516, agent3/loss=53.531, agent4/loss=52.430, agent5/loss=51.809, agent6/loss=52.179, agent7/loss=50.895, env_step=84000, len=20, n/ep=20, n/st=500, rew=-52.97]                          


Epoch #14: test_reward: -94.365820 ± 0.000000, best_reward: -91.159605 ± 0.000000 in #13


Epoch #15: 6001it [00:57, 104.55it/s, agent0/loss=44.585, agent1/loss=47.039, agent2/loss=58.021, agent3/loss=60.403, agent4/loss=52.775, agent5/loss=47.663, agent6/loss=52.518, agent7/loss=52.514, env_step=90000, len=20, n/ep=20, n/st=500, rew=-51.63]                          


Epoch #15: test_reward: -106.168187 ± 0.000000, best_reward: -91.159605 ± 0.000000 in #13


Epoch #16: 6001it [00:58, 102.43it/s, agent0/loss=47.688, agent1/loss=45.293, agent2/loss=52.356, agent3/loss=51.931, agent4/loss=53.204, agent5/loss=52.388, agent6/loss=48.739, agent7/loss=50.901, env_step=96000, len=20, n/ep=20, n/st=500, rew=-36.67]                          


Epoch #16: test_reward: -102.219267 ± 0.000000, best_reward: -91.159605 ± 0.000000 in #13


Epoch #17: 6001it [01:01, 97.64it/s, agent0/loss=44.910, agent1/loss=46.573, agent2/loss=56.722, agent3/loss=52.270, agent4/loss=52.519, agent5/loss=51.494, agent6/loss=52.629, agent7/loss=46.386, env_step=102000, len=20, n/ep=20, n/st=500, rew=-51.06]                           


Epoch #17: test_reward: -73.350536 ± 0.000000, best_reward: -73.350536 ± 0.000000 in #17


Epoch #18: 6001it [01:04, 92.83it/s, agent0/loss=44.373, agent1/loss=43.526, agent2/loss=51.181, agent3/loss=54.851, agent4/loss=49.629, agent5/loss=48.653, agent6/loss=48.919, agent7/loss=52.269, env_step=108000, len=20, n/ep=20, n/st=500, rew=-47.32]                           


Epoch #18: test_reward: -67.762214 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #19: 6001it [01:06, 90.78it/s, agent0/loss=44.646, agent1/loss=48.938, agent2/loss=50.446, agent3/loss=53.762, agent4/loss=51.257, agent5/loss=48.735, agent6/loss=49.590, agent7/loss=51.519, env_step=114000, len=20, n/ep=20, n/st=500, rew=-34.25]                           


Epoch #19: test_reward: -67.803602 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #20: 6001it [01:06, 90.06it/s, agent0/loss=41.181, agent1/loss=48.876, agent2/loss=54.024, agent3/loss=56.195, agent4/loss=49.159, agent5/loss=48.959, agent6/loss=51.627, agent7/loss=47.547, env_step=120000, len=20, n/ep=20, n/st=500, rew=-47.30]                           


Epoch #20: test_reward: -92.401393 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #21: 6001it [01:06, 90.27it/s, agent0/loss=47.030, agent1/loss=46.907, agent2/loss=52.330, agent3/loss=54.376, agent4/loss=49.732, agent5/loss=52.831, agent6/loss=46.594, agent7/loss=52.192, env_step=126000, len=20, n/ep=20, n/st=500, rew=-49.46]                           


Epoch #21: test_reward: -102.178212 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #22: 6001it [01:08, 87.21it/s, agent0/loss=49.780, agent1/loss=47.162, agent2/loss=51.683, agent3/loss=55.048, agent4/loss=53.501, agent5/loss=47.064, agent6/loss=48.398, agent7/loss=54.077, env_step=132000, len=20, n/ep=20, n/st=500, rew=-65.34]                           


Epoch #22: test_reward: -101.567857 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #23: 6001it [01:06, 89.94it/s, agent0/loss=44.192, agent1/loss=42.804, agent2/loss=54.847, agent3/loss=51.754, agent4/loss=55.215, agent5/loss=47.701, agent6/loss=51.348, agent7/loss=49.540, env_step=138000, len=20, n/ep=20, n/st=500, rew=-49.59]                           


Epoch #23: test_reward: -85.517960 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #24: 6001it [01:16, 78.62it/s, agent0/loss=43.860, agent1/loss=44.385, agent2/loss=50.963, agent3/loss=52.210, agent4/loss=46.348, agent5/loss=49.277, agent6/loss=51.343, agent7/loss=55.603, env_step=144000, len=20, n/ep=20, n/st=500, rew=-45.93]                           


Epoch #24: test_reward: -90.826644 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #25: 6001it [01:20, 74.92it/s, agent0/loss=40.691, agent1/loss=39.578, agent2/loss=53.150, agent3/loss=54.879, agent4/loss=50.689, agent5/loss=54.551, agent6/loss=50.701, agent7/loss=52.510, env_step=150000, len=20, n/ep=20, n/st=500, rew=-53.00]                           


Epoch #25: test_reward: -90.959073 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #26: 6001it [01:17, 77.46it/s, agent0/loss=43.185, agent1/loss=45.009, agent2/loss=51.404, agent3/loss=57.435, agent4/loss=50.608, agent5/loss=50.412, agent6/loss=50.067, agent7/loss=56.817, env_step=156000, len=20, n/ep=20, n/st=500, rew=-52.85]                           


Epoch #26: test_reward: -101.722793 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #27: 6001it [01:20, 74.83it/s, agent0/loss=46.880, agent1/loss=43.701, agent2/loss=56.612, agent3/loss=52.855, agent4/loss=48.515, agent5/loss=49.923, agent6/loss=44.946, agent7/loss=51.620, env_step=162000, len=20, n/ep=20, n/st=500, rew=-54.63]                           


Epoch #27: test_reward: -113.900749 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #28: 6001it [01:22, 72.71it/s, agent0/loss=43.314, agent1/loss=44.781, agent2/loss=56.449, agent3/loss=51.167, agent4/loss=48.654, agent5/loss=54.273, agent6/loss=49.653, agent7/loss=53.822, env_step=168000, len=20, n/ep=20, n/st=500, rew=-33.44]                           


Epoch #28: test_reward: -102.445426 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #29: 6001it [01:25, 70.53it/s, agent0/loss=45.599, agent1/loss=45.482, agent2/loss=52.768, agent3/loss=59.788, agent4/loss=46.908, agent5/loss=52.710, agent6/loss=46.240, agent7/loss=54.593, env_step=174000, len=20, n/ep=20, n/st=500, rew=-55.19]                           


Epoch #29: test_reward: -84.445437 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #30: 6001it [01:19, 75.88it/s, agent0/loss=41.450, agent1/loss=45.666, agent2/loss=51.687, agent3/loss=53.780, agent4/loss=44.513, agent5/loss=49.693, agent6/loss=50.493, agent7/loss=47.766, env_step=180000, len=20, n/ep=20, n/st=500, rew=-42.44]                           


Epoch #30: test_reward: -104.371129 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #31: 6001it [01:14, 80.08it/s, agent0/loss=42.344, agent1/loss=41.150, agent2/loss=54.315, agent3/loss=54.785, agent4/loss=51.016, agent5/loss=51.330, agent6/loss=51.774, agent7/loss=49.521, env_step=186000, len=20, n/ep=20, n/st=500, rew=-44.30]                           


Epoch #31: test_reward: -88.334219 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #32: 6001it [01:13, 81.13it/s, agent0/loss=43.748, agent1/loss=43.929, agent2/loss=51.783, agent3/loss=54.761, agent4/loss=48.248, agent5/loss=47.242, agent6/loss=54.637, agent7/loss=50.496, env_step=192000, len=20, n/ep=20, n/st=500, rew=-52.94]                           


Epoch #32: test_reward: -89.628420 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #33: 6001it [01:15, 79.00it/s, agent0/loss=39.421, agent1/loss=43.091, agent2/loss=50.155, agent3/loss=53.375, agent4/loss=45.925, agent5/loss=45.855, agent6/loss=49.221, agent7/loss=52.472, env_step=198000, len=20, n/ep=20, n/st=500, rew=-47.73]                           


Epoch #33: test_reward: -102.342758 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #34: 6001it [01:19, 75.67it/s, agent0/loss=44.750, agent1/loss=43.406, agent2/loss=52.743, agent3/loss=52.081, agent4/loss=47.212, agent5/loss=47.764, agent6/loss=51.709, agent7/loss=51.642, env_step=204000, len=20, n/ep=20, n/st=500, rew=-53.61]                           


Epoch #34: test_reward: -102.182579 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #35: 6001it [01:17, 77.85it/s, agent0/loss=41.941, agent1/loss=42.030, agent2/loss=53.012, agent3/loss=49.484, agent4/loss=52.407, agent5/loss=49.459, agent6/loss=48.021, agent7/loss=49.713, env_step=210000, len=20, n/ep=20, n/st=500, rew=-55.53]                           


Epoch #35: test_reward: -101.153009 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #36: 6001it [01:15, 79.24it/s, agent0/loss=41.093, agent1/loss=45.472, agent2/loss=50.128, agent3/loss=51.419, agent4/loss=49.030, agent5/loss=47.873, agent6/loss=52.649, agent7/loss=50.067, env_step=216000, len=20, n/ep=20, n/st=500, rew=-40.65]                           


Epoch #36: test_reward: -90.367024 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #37: 6001it [01:16, 78.15it/s, agent0/loss=41.835, agent1/loss=42.500, agent2/loss=52.793, agent3/loss=53.582, agent4/loss=50.711, agent5/loss=50.717, agent6/loss=47.862, agent7/loss=49.405, env_step=222000, len=20, n/ep=20, n/st=500, rew=-70.68]                           


Epoch #37: test_reward: -113.624404 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #38: 6001it [01:20, 74.14it/s, agent0/loss=43.465, agent1/loss=43.865, agent2/loss=52.488, agent3/loss=51.728, agent4/loss=46.796, agent5/loss=47.505, agent6/loss=47.673, agent7/loss=50.884, env_step=228000, len=20, n/ep=20, n/st=500, rew=-66.70]                           


Epoch #38: test_reward: -109.116386 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #39: 6001it [01:22, 72.94it/s, agent0/loss=42.955, agent1/loss=36.723, agent2/loss=52.032, agent3/loss=51.849, agent4/loss=49.408, agent5/loss=48.027, agent6/loss=46.362, agent7/loss=50.744, env_step=234000, len=20, n/ep=20, n/st=500, rew=-56.14]                           


Epoch #39: test_reward: -113.949044 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #40: 6001it [01:27, 68.64it/s, agent0/loss=44.546, agent1/loss=43.099, agent2/loss=50.735, agent3/loss=55.517, agent4/loss=45.006, agent5/loss=49.418, agent6/loss=51.833, agent7/loss=45.281, env_step=240000, len=20, n/ep=20, n/st=500, rew=-56.59]                           


Epoch #40: test_reward: -113.568913 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #41: 6001it [01:41, 59.22it/s, agent0/loss=40.739, agent1/loss=42.295, agent2/loss=52.036, agent3/loss=52.975, agent4/loss=47.394, agent5/loss=51.062, agent6/loss=43.559, agent7/loss=50.752, env_step=246000, len=20, n/ep=20, n/st=500, rew=-69.33]                           


Epoch #41: test_reward: -122.683673 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #42: 6001it [01:37, 61.65it/s, agent0/loss=42.405, agent1/loss=42.783, agent2/loss=50.208, agent3/loss=54.205, agent4/loss=45.514, agent5/loss=47.840, agent6/loss=47.635, agent7/loss=49.187, env_step=252000, len=20, n/ep=20, n/st=500, rew=-81.38]                          


Epoch #42: test_reward: -122.008687 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #43: 6001it [01:39, 60.25it/s, agent0/loss=45.007, agent1/loss=47.963, agent2/loss=49.532, agent3/loss=54.324, agent4/loss=46.912, agent5/loss=44.553, agent6/loss=49.624, agent7/loss=49.378, env_step=258000, len=20, n/ep=20, n/st=500, rew=-69.43]                          


Epoch #43: test_reward: -122.519477 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #44: 6001it [01:39, 60.16it/s, agent0/loss=43.293, agent1/loss=47.155, agent2/loss=52.332, agent3/loss=54.366, agent4/loss=47.796, agent5/loss=46.811, agent6/loss=48.498, agent7/loss=55.498, env_step=264000, len=20, n/ep=20, n/st=500, rew=-71.81]                          


Epoch #44: test_reward: -122.635236 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #45: 6001it [01:37, 61.27it/s, agent0/loss=43.188, agent1/loss=44.461, agent2/loss=53.638, agent3/loss=54.508, agent4/loss=50.798, agent5/loss=44.846, agent6/loss=50.714, agent7/loss=50.646, env_step=270000, len=20, n/ep=20, n/st=500, rew=-71.64]                          


Epoch #45: test_reward: -122.500108 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #46: 6001it [01:35, 62.83it/s, agent0/loss=40.459, agent1/loss=42.872, agent2/loss=50.238, agent3/loss=51.742, agent4/loss=49.737, agent5/loss=44.985, agent6/loss=47.387, agent7/loss=46.162, env_step=276000, len=20, n/ep=20, n/st=500, rew=-63.16]                          


Epoch #46: test_reward: -112.762813 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #47: 6001it [01:31, 65.75it/s, agent0/loss=42.686, agent1/loss=45.137, agent2/loss=48.122, agent3/loss=52.488, agent4/loss=47.351, agent5/loss=44.689, agent6/loss=46.409, agent7/loss=49.371, env_step=282000, len=20, n/ep=20, n/st=500, rew=-68.94]                           


Epoch #47: test_reward: -113.689719 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #48: 6001it [01:29, 67.24it/s, agent0/loss=42.491, agent1/loss=39.274, agent2/loss=48.104, agent3/loss=48.980, agent4/loss=47.086, agent5/loss=48.140, agent6/loss=47.343, agent7/loss=51.732, env_step=288000, len=20, n/ep=20, n/st=500, rew=-70.41]                           


Epoch #48: test_reward: -113.487396 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #49: 6001it [01:26, 69.20it/s, agent0/loss=43.123, agent1/loss=42.283, agent2/loss=48.272, agent3/loss=53.457, agent4/loss=45.953, agent5/loss=51.768, agent6/loss=47.721, agent7/loss=50.184, env_step=294000, len=20, n/ep=20, n/st=500, rew=-67.08]                           


Epoch #49: test_reward: -113.995590 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #50: 6001it [01:30, 66.24it/s, agent0/loss=43.595, agent1/loss=40.073, agent2/loss=52.063, agent3/loss=52.549, agent4/loss=48.132, agent5/loss=49.516, agent6/loss=45.710, agent7/loss=52.230, env_step=300000, len=20, n/ep=20, n/st=500, rew=-71.96]                           


Epoch #50: test_reward: -122.439957 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #51: 6001it [01:24, 70.77it/s, agent0/loss=39.537, agent1/loss=40.700, agent2/loss=52.459, agent3/loss=51.977, agent4/loss=43.682, agent5/loss=43.395, agent6/loss=46.545, agent7/loss=47.839, env_step=306000, len=20, n/ep=20, n/st=500, rew=-71.01]                           


Epoch #51: test_reward: -122.222687 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #52: 6001it [01:23, 72.17it/s, agent0/loss=39.061, agent1/loss=39.907, agent2/loss=48.453, agent3/loss=50.496, agent4/loss=48.416, agent5/loss=46.641, agent6/loss=48.838, agent7/loss=49.444, env_step=312000, len=20, n/ep=20, n/st=500, rew=-67.92]                           


Epoch #52: test_reward: -102.557873 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #53: 6001it [01:24, 71.39it/s, agent0/loss=38.204, agent1/loss=46.217, agent2/loss=51.109, agent3/loss=52.009, agent4/loss=46.278, agent5/loss=43.250, agent6/loss=48.319, agent7/loss=47.300, env_step=318000, len=20, n/ep=20, n/st=500, rew=-68.65]                           


Epoch #53: test_reward: -113.763659 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #54: 6001it [01:22, 72.98it/s, agent0/loss=40.501, agent1/loss=41.576, agent2/loss=49.395, agent3/loss=54.620, agent4/loss=44.257, agent5/loss=44.803, agent6/loss=47.237, agent7/loss=46.455, env_step=324000, len=20, n/ep=20, n/st=500, rew=-60.77]                           


Epoch #54: test_reward: -112.932111 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #55: 6001it [01:24, 71.08it/s, agent0/loss=40.444, agent1/loss=38.834, agent2/loss=48.947, agent3/loss=51.749, agent4/loss=42.773, agent5/loss=45.779, agent6/loss=46.340, agent7/loss=51.357, env_step=330000, len=20, n/ep=20, n/st=500, rew=-62.14]                           


Epoch #55: test_reward: -102.321374 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #56: 6001it [01:22, 72.53it/s, agent0/loss=42.099, agent1/loss=40.957, agent2/loss=49.053, agent3/loss=49.265, agent4/loss=43.214, agent5/loss=44.664, agent6/loss=44.720, agent7/loss=50.344, env_step=336000, len=20, n/ep=20, n/st=500, rew=-60.99]                           


Epoch #56: test_reward: -102.312722 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #57: 6001it [01:22, 72.44it/s, agent0/loss=41.060, agent1/loss=43.335, agent2/loss=48.592, agent3/loss=52.385, agent4/loss=47.117, agent5/loss=48.352, agent6/loss=46.800, agent7/loss=47.413, env_step=342000, len=20, n/ep=20, n/st=500, rew=-61.42]                           


Epoch #57: test_reward: -102.196611 ± 0.000000, best_reward: -67.762214 ± 0.000000 in #18


Epoch #58:  25%|##5       | 1500/6000 [00:20<01:00, 74.55it/s, agent0/loss=42.089, agent1/loss=41.102, agent2/loss=47.246, agent3/loss=48.743, agent4/loss=45.764, agent5/loss=50.565, agent6/loss=42.312, agent7/loss=50.357, env_step=343000, len=20, n/ep=20, n/st=500, rew=-62.76]


KeyboardInterrupt: 

In [None]:

from typing import Optional, Tuple

import numpy as np
import torch
from tianshou.env import DummyVectorEnv
from tianshou.trainer import offpolicy_trainer
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger
import torch
#from Custom_Classes import CustomCollector

# Create a new instance of the policy with the same architecture as the saved policy
policy, optim, _ = _get_agents()
model_save_path = os.path.join("dqn_Custom", save_policy_name)        

# Load the saved checkpoint
policy_test = policy.policies['agent0']
policy_test.load_state_dict(torch.load(model_save_path ))

envs = DummyVectorEnv([_get_env for _ in range(1)])

envs.max_time_steps = 200
#policy_test.policies['agent0'].eval()
#policy.policies['agent0'].set_eps(0.9)

policy_test.eval()
policy_test.set_eps(0.00)


#collector = CustomCollector(policy.policies['agent0'], envs, exploration_noise=True)
#collector = CustomCollector(policy_test, envs, exploration_noise=False)
collector = CustomCollector(policy, envs, exploration_noise=True)

results = collector.collect(n_episode=10)
#collector.collect(n_episode=1, render=1 / 5000)






In [None]:
results

{'n/ep': 10,
 'n/st': 400,
 'rews': array([[146.60998543, 146.60998543, 146.60998543, 146.60998543,
         146.60998543, 146.60998543, 146.60998543, 146.60998543],
        [149.65674368, 149.65674368, 149.65674368, 149.65674368,
         149.65674368, 149.65674368, 149.65674368, 149.65674368],
        [150.12206002, 150.12206002, 150.12206002, 150.12206002,
         150.12206002, 150.12206002, 150.12206002, 150.12206002],
        [150.34100261, 150.34100261, 150.34100261, 150.34100261,
         150.34100261, 150.34100261, 150.34100261, 150.34100261],
        [136.37620211, 136.37620211, 136.37620211, 136.37620211,
         136.37620211, 136.37620211, 136.37620211, 136.37620211],
        [145.54718066, 145.54718066, 145.54718066, 145.54718066,
         145.54718066, 145.54718066, 145.54718066, 145.54718066],
        [144.85799617, 144.85799617, 144.85799617, 144.85799617,
         144.85799617, 144.85799617, 144.85799617, 144.85799617],
        [146.66782528, 146.66782528, 146.6678252

In [None]:
max(results['rews'])
print(np.mean(results['rews'][results['rews'] > -10]))


#create a function  to print a histogram of the results['rews']
import matplotlib.pyplot as plt
plt.hist(results['rews'][results['rews'] > -10], bins=100)
plt.show()


ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [None]:
from turtle import st
import torch
from tianshou.data import Batch

# load policy as in your original code
policy, optim, _ = _get_agents()
model_save_path = os.path.join("dqn_Custom", save_policy_name)        
policy_test = policy.policies['agent0']
state_saved = torch.load(model_save_path )
print(policy_test)
policy_test.load_state_dict(state_saved)
policy_test.eval()
policy_test.set_eps(0.00)

# initialize your environment
#env = DummyVectorEnv([_get_env for _ in range(1)])
env = MultiDroneEnv(None)
env.max_time_steps = 100

# simulate the interaction with the environment manually
for episode in range(1):  # simulate 10 episodes
    
    obs, _  = env.reset(seed=episode)         
    info         = env.get_initial_state()
    
    drones = info["drones"]
    tasks = info["tasks"]
        
    done = {0 : False}
    truncations = {0 : False}
    
    episodo_reward = 0
    #obs, reward, done, truncations, info = env.step(action)

    while not all(done.values()) and not all(truncations.values()):
        
        agent_id = "agent" + str(env.agent_selector._current_agent)
        # Create a Batch of observations
        obs_batch = Batch(obs=obs[agent_id], info=[{}])  # add empty info for each observation
        
        #print(obs_batch)
        # Forward the batch of observations through the policy to get the actions
        action = policy_test(obs_batch).act
        action = {agent_id : action[0]}
       
        obs, reward, done, truncations, info = env.step(action)
          
        episodo_reward += sum(reward.values())/env.n_agents

       

print(episodo_reward)
