In [1]:
import os
import datetime
from typing import Optional, Tuple

import gym
import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy
from tianshou.trainer import offpolicy_trainer
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger

from TaskAllocation.RL_Policies.Custom_Classes import CustomNet
from TaskAllocation.RL_Policies.Custom_Classes import CustomCollector
from TaskAllocation.RL_Policies.Custom_Classes import CustomParallelToAECWrapper

#from CustomClass_multi_head import CustomNet
from TaskAllocation.RL_Policies.Custom_Classes_simplified import CustomNetSimple
#from Custom_Classes_simplified import CustomCollectorSimple
#from Custom_Classes_simplified import CustomParallelToAECWrapperSimple

from TaskAllocation.RL_Policies.CustomClasses_Transformer_Reduced import CustomNetReduced
from TaskAllocation.RL_Policies.CustomClass_MultiHead_Transformer import CustomNetMultiHead
import importlib

from mUAV_TA.DroneEnv import MultiDroneEnv
#from tianshou_DQN import train
# ---------------------------------------------------------------------------
# Run configuration: everything a reader might want to tune lives here.
# ---------------------------------------------------------------------------
# Network architecture to build in _get_agents(); one of
# "CustomNet", "CustomNetSimple", "CustomNetReduced", "CustomNetMultiHead".
model = "CustomNetMultiHead"
# Free-form experiment tag; appended to the model name to form the run name.
test_num = "_Eval_TBTA_02_simplified_UCF_mask01_seed0All"

# Number of environment copies placed in the training / testing vector envs.
train_env_num = 10
test_env_num = 10

name = model + test_num

# Checkpoint file names (relative to `policy_path`).
load_policy_name = 'policy_CustomNetMultiHead_Eval_TBTA_02_simplified_UCF1_new_rew_updR.pth'
save_policy_name = f'policy_{name}.pth'
policy_path = "dqn_Custom"
# When True, _get_agents() restores the weights stored in `load_policy_name`.
load_model = False

# log: every run gets its own time-stamped TensorBoard directory.
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = name + now

log_path = os.path.join('./', "Logs", "dqn", log_name)

# Hyper-parameters of the DQN policy. Only "lr", "discount_factor",
# "estimation_step" and "target_update_freq" are read by the code in this
# notebook; "optimizer" is recorded for the run description.
dqn_params = {"discount_factor": 0.98,
              "estimation_step": 5,
              "target_update_freq": 100,
              "optimizer": "Adam",
              "lr": 1e-4}

# Arguments forwarded to offpolicy_trainer, plus the epsilon values used by
# the train_fn / test_fn callbacks ("tn_eps_max" / "ts_eps_max").
trainer_params = {"max_epoch": 1000,
                  "step_per_epoch": 250 * train_env_num,
                  "step_per_collect": 20,
                  "episode_per_test": 1,
                  "batch_size": 32,
                  "update_per_step": 1/25,
                  "tn_eps_max": 0.85,
                  "ts_eps_max": 0.0,
                  }

# Human-readable run description (printed and written to TensorBoard).
# The backslash continuations are inside the string literal, so the leading
# whitespace of each continued line is part of the text.
Run_Data = f'{name}\n\
        Loaded_Model: {load_policy_name if load_model else "no"} \n\
        log_path: {log_path} \n\
        train/test_env_num: {train_env_num} / {test_env_num} \n\
        model: {model} \n\
        dqn_params: {dqn_params} \n\
        trainer_params: {trainer_params} \n\
        obs: Task Info -> Dist / Quality for own drone \
            Agents_info -> Post_next / Time_next / Type \
            Scene:  agents= F1:2, F2:2, R1:3, R2:3,     \
                 tasks= Att: 8 , Rec:22'

model_load_path = os.path.join(policy_path, load_policy_name)
model_save_path = os.path.join(policy_path, save_policy_name)
# Make sure the checkpoint and log directories exist before training starts.
os.makedirs(policy_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)


def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    policy_load_path = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    """Build the multi-agent policy manager used for training and evaluation.

    Args:
        agent_learn: Ready-made policy to share between all agents. When
            ``None`` a new network (selected by the module-level ``model``
            string) is created and wrapped in a ``DQNPolicy``.
        agent_opponent: Currently unused; kept so existing callers keep working.
        optim: Optimizer for the new network. When ``None`` an Adam optimizer
            with ``dqn_params["lr"]`` is created. Ignored (returned as given)
            when ``agent_learn`` is supplied.
        policy_load_path: Optional checkpoint path to restore into a newly
            built policy. When ``None`` the module-level ``load_model`` /
            ``model_load_path`` settings decide whether anything is loaded.

    Returns:
        ``(policy, optim, agent_names)`` where ``policy`` is a
        ``MultiAgentPolicyManager`` in which every agent uses the *same*
        policy object, ``optim`` is the optimizer (``None`` if none was
        created or given) and ``agent_names`` is ``env.agents``.

    Raises:
        ValueError: If ``model`` does not name a known network class.
    """
    env = _get_env()
    agent_name = env.agents[0]  # Get the name of the first agent

    agent_observation_space = env.observation_space
    # Per-agent input width = sum of the first dimension of these entries.
    agent_feature_keys = (
        "agent_position",
        "agent_state",
        "agent_type",
        "next_free_time",
        "position_after_last_task",
        # "agent_relay_area",
    )
    state_shape_agent = sum(
        agent_observation_space[key].shape[0] for key in agent_feature_keys
    )

    # Hard-coded task input width.
    # NOTE(review): presumably 30 tasks x 3 features each - confirm against the env.
    state_shape_task = 30 * 3 #env.observation_space["tasks_info"].shape[0]

    action_shape = env.action_space[agent_name].shape[0]
    #action_shape = env.action_space[agent_name].n

    if agent_learn is None:
        # model: all candidate networks share the same constructor signature.
        net_classes = {
            "CustomNet": CustomNet,
            "CustomNetSimple": CustomNetSimple,
            "CustomNetReduced": CustomNetReduced,
            "CustomNetMultiHead": CustomNetMultiHead,
        }
        if model not in net_classes:
            raise ValueError(
                f"Unknown model {model!r}; expected one of {sorted(net_classes)}"
            )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        net = net_classes[model](
            state_shape_agent=state_shape_agent,
            state_shape_task=state_shape_task,
            action_shape=action_shape,
            hidden_sizes=[128,128],
            device=device,
        ).to(device)

        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=dqn_params["lr"])

        agent_learn = DQNPolicy(
            model=net,
            optim=optim,
            discount_factor=dqn_params["discount_factor"],
            estimation_step=dqn_params["estimation_step"],
            target_update_freq=dqn_params["target_update_freq"],
        )

        # An explicit path wins; otherwise fall back to the notebook settings.
        checkpoint_path = policy_load_path
        if checkpoint_path is None and load_model:
            checkpoint_path = model_load_path

        if checkpoint_path is not None:
            # Load the saved checkpoint; map_location lets a checkpoint saved
            # on GPU be restored on a CPU-only machine.
            agent_learn.load_state_dict(
                torch.load(checkpoint_path, map_location=device)
            )
            print(f'Loaded-> {checkpoint_path}')

    # Built outside the branch above so that a caller-supplied `agent_learn`
    # works as well (previously `agents` was undefined in that case).
    agents = [agent_learn for _ in range(len(env.agents))]

    policy = MultiAgentPolicyManager(agents, env)

    return policy, optim, env.agents


def _get_env():
    """Environment factory: DummyVectorEnv is given this callable once per copy.

    Creates a ``MultiDroneEnv`` with its default arguments, wraps it in
    ``CustomParallelToAECWrapper`` and returns the result wrapped in
    tianshou's ``PettingZooEnv``.
    """
    # Alternative wrapper kept for reference: parallel_to_aec_wrapper(...)
    return PettingZooEnv(CustomParallelToAECWrapper(MultiDroneEnv()))

# Print the run description so the settings are recorded in the cell output.
print(Run_Data)

CustomNetMultiHead_Eval_TBTA_02_simplified_UCF_mask01_seed0All
        Loaded_Model: no 
        log_path: ./Logs\dqn\CustomNetMultiHead_Eval_TBTA_02_simplified_UCF_mask01_seed0All230915-112008 
        train/test_env_num: 10 / 10 
        model: CustomNetMultiHead 
        dqn_params: {'discount_factor': 0.98, 'estimation_step': 5, 'target_update_freq': 100, 'optminizer': 'Adam', 'lr': 0.0001} 
        trainer_params: {'max_epoch': 1000, 'step_per_epoch': 2500, 'step_per_collect': 20, 'episode_per_test': 1, 'batch_size': 32, 'update_per_step': 0.04, 'tn_eps_max': 0.85, 'ts_eps_max': 0.0} 
        obs: Task Info -> Dist / Quality for own drone             Agents_info -> Post_next / Time_next / Type             Scene:  agents= F1:2, F2:2, R1:3, R2:3,                      tasks= Att: 8 , Rec:22


In [2]:
if __name__ == "__main__":
    # Training driver: builds the vector envs, the shared DQN policy, the
    # collectors and the TensorBoard logger, then runs tianshou's
    # offpolicy_trainer with the settings from `trainer_params`.

    torch.set_grad_enabled(True)
    # ======== Step 1: Environment setup =========
    train_envs = DummyVectorEnv([_get_env for _ in range(train_env_num)])
    test_envs = DummyVectorEnv([_get_env for _ in range(test_env_num)])

    # seed
    seed = 1
    np.random.seed(seed)

    torch.manual_seed(seed)
    train_envs.seed(seed)
    test_envs.seed(seed)

    # ======== Step 2: Agent setup =========
    policy, optim, agents = _get_agents()

    # ======== Step 3: Collector setup =========
    train_collector = CustomCollector(
        policy,
        train_envs,
        #VectorReplayBuffer(100_000, len(train_envs)),
        PrioritizedVectorReplayBuffer(100_000, len(train_envs), alpha=0.6, beta=0.4),
        exploration_noise=True
    )
    test_collector = CustomCollector(policy, test_envs, exploration_noise=True)

    # Pre-fill the replay buffer before the first gradient update.
    train_collector.collect(n_episode=10)

    # ======== tensorboard logging setup =========
    writer = SummaryWriter(log_path)
    writer.add_text(name, str(Run_Data))
    logger = TensorboardLogger(writer)

    # ======== Step 4: Callback functions setup =========
    def save_best_fn(policy):
        """Persist the first agent's policy weights whenever a new best test reward is found."""
        print("Best Saved")
        torch.save(policy.policies[agents[0]].state_dict(), model_save_path)

    def stop_fn(mean_rewards):
        """Stop training once the mean test reward reaches the target value."""
        return mean_rewards >= 9939.0

    def train_fn(epoch, env_step):
        """Set the training epsilon, decayed linearly with the epoch index.

        Goes from ``tn_eps_max`` towards ``tn_eps_max / 100`` as ``epoch``
        approaches ``max_epoch``.
        """
        eps_max = trainer_params['tn_eps_max']
        epsilon = eps_max - (eps_max - eps_max/100)*(epoch/trainer_params['max_epoch'])
        # _get_agents() registers one policy object for every agent id, so
        # updating the first entry is enough.
        policy.policies[agents[0]].set_eps(epsilon)

    def test_fn(epoch, env_step):
        """Set the (constant) evaluation epsilon before each test phase."""
        epsilon = trainer_params['ts_eps_max']
        policy.policies[agents[0]].set_eps(epsilon)

    def reward_metric(rews):
        """Reduce collected rewards to one value per episode.

        The trainer calls ``.mean()`` / ``.std()`` on the returned array, so
        the episode axis has to survive; returning a single scalar (as this
        callback did before) always yields a reported std of 0. Any axes
        after the first are averaged; 1-D input is returned unchanged. The
        overall mean is the same as ``rews.mean()``.
        """
        rews = np.asarray(rews)
        if rews.ndim > 1:
            return rews.mean(axis=tuple(range(1, rews.ndim)))
        return rews

    # ======== Step 5: Run the trainer =========
    try:
        result = offpolicy_trainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=trainer_params['max_epoch'],
            step_per_epoch=trainer_params['step_per_epoch'],
            step_per_collect=trainer_params['step_per_collect'],
            episode_per_test=trainer_params['episode_per_test'],
            batch_size=trainer_params['batch_size'],
            train_fn=train_fn,
            test_fn=test_fn,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            update_per_step=trainer_params['update_per_step'],
            logger=logger,
            test_in_train=False,
            reward_metric=reward_metric,
            show_progress=True
        )
    finally:
        # Flush pending TensorBoard events even when training is interrupted
        # (e.g. by a KeyboardInterrupt from the notebook's stop button).
        writer.close()

    # return result, policy.policies[agents[1]]
    print(f"\n==========Result==========\n{result}")
    print("\n(the trained policy can be accessed via policy.policies[agents[0]])")



Best Saved


Epoch #1: 2501it [00:25, 97.85it/s, agent0/loss=417.506, agent1/loss=373.098, agent2/loss=471.721, agent3/loss=331.841, agent4/loss=258.785, agent5/loss=285.056, agent6/loss=271.582, agent7/loss=228.886, agent8/loss=216.837, agent9/loss=291.650, env_step=2500, len=50, n/ep=10, n/st=20, rew=124.65]                           


Best Saved
Epoch #1: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #2: 2501it [00:26, 94.17it/s, agent0/loss=375.417, agent1/loss=303.714, agent2/loss=387.532, agent3/loss=322.924, agent4/loss=233.661, agent5/loss=251.161, agent6/loss=211.441, agent7/loss=156.173, agent8/loss=183.473, agent9/loss=187.043, env_step=5000, len=50, n/ep=10, n/st=20, rew=130.45]                           


Epoch #2: test_reward: 171.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #3: 2501it [00:25, 96.82it/s, agent0/loss=355.877, agent1/loss=310.226, agent3/loss=292.760, agent5/loss=228.061, agent6/loss=243.793, agent7/loss=147.065, agent8/loss=190.982, agent9/loss=155.933, env_step=7500, len=50, n/ep=10, n/st=20, rew=129.45]                                                                     


Epoch #3: test_reward: 178.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #4: 2501it [00:25, 97.19it/s, agent0/loss=306.719, agent1/loss=288.721, agent2/loss=354.972, agent3/loss=295.086, agent4/loss=263.616, agent5/loss=205.578, agent6/loss=208.707, agent7/loss=162.531, agent8/loss=164.855, agent9/loss=162.124, env_step=10000, len=50, n/ep=10, n/st=20, rew=129.65]                          


Epoch #4: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #5: 2501it [00:26, 96.18it/s, agent0/loss=292.150, agent1/loss=321.369, agent2/loss=365.657, agent3/loss=312.415, agent4/loss=242.300, agent5/loss=206.563, agent6/loss=192.250, agent7/loss=183.088, agent8/loss=158.472, agent9/loss=151.118, env_step=12500, len=50, n/ep=10, n/st=20, rew=131.40]                           


Epoch #5: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #6: 2501it [00:25, 96.95it/s, agent0/loss=290.651, agent1/loss=351.040, agent2/loss=315.815, agent3/loss=294.568, agent4/loss=213.765, agent5/loss=220.859, agent6/loss=188.190, agent7/loss=160.909, agent8/loss=179.511, agent9/loss=195.776, env_step=15000, len=50, n/ep=10, n/st=20, rew=119.45]                           


Epoch #6: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #7: 2501it [00:25, 97.34it/s, agent0/loss=324.812, agent1/loss=268.491, agent2/loss=278.987, agent4/loss=227.590, agent5/loss=215.703, agent6/loss=222.071, agent7/loss=194.296, agent8/loss=143.941, agent9/loss=203.409, env_step=17500, len=50, n/ep=10, n/st=20, rew=120.85]                                                


Epoch #7: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #8: 2501it [00:25, 97.63it/s, agent0/loss=335.559, agent1/loss=269.076, agent2/loss=255.567, agent3/loss=287.629, agent4/loss=221.696, agent5/loss=242.184, agent6/loss=187.150, agent7/loss=162.961, agent8/loss=153.234, agent9/loss=188.034, env_step=20000, len=50, n/ep=10, n/st=20, rew=129.55]                           


Epoch #8: test_reward: 178.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #9: 2501it [00:25, 98.24it/s, agent0/loss=265.818, agent1/loss=336.490, agent2/loss=298.191, agent3/loss=258.956, agent4/loss=250.403, agent5/loss=209.058, agent6/loss=192.833, agent7/loss=175.554, agent8/loss=138.391, agent9/loss=194.672, env_step=22500, len=50, n/ep=10, n/st=20, rew=119.75]                           


Epoch #9: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #10: 2501it [00:25, 96.52it/s, agent0/loss=334.730, agent1/loss=304.283, agent2/loss=336.221, agent3/loss=333.799, agent4/loss=238.331, agent5/loss=203.181, agent6/loss=190.910, agent7/loss=174.257, agent8/loss=186.789, agent9/loss=167.554, env_step=25000, len=50, n/ep=10, n/st=20, rew=131.45]                           


Epoch #10: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #11: 2501it [00:25, 97.46it/s, agent0/loss=306.260, agent1/loss=294.430, agent2/loss=327.952, agent3/loss=278.398, agent4/loss=262.150, agent5/loss=188.899, agent6/loss=216.669, agent7/loss=179.533, agent8/loss=197.427, agent9/loss=193.360, env_step=27500, len=50, n/ep=10, n/st=20, rew=129.90]                          


Epoch #11: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #12: 2501it [00:26, 95.01it/s, agent0/loss=292.344, agent1/loss=299.057, agent2/loss=319.800, agent3/loss=296.016, agent5/loss=213.239, agent6/loss=216.816, agent7/loss=148.879, agent8/loss=188.242, agent9/loss=182.572, env_step=30000, len=50, n/ep=10, n/st=20, rew=123.80]                                               


Epoch #12: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #13: 2501it [00:26, 96.16it/s, agent0/loss=279.837, agent1/loss=258.550, agent2/loss=321.798, agent3/loss=265.726, agent4/loss=219.625, agent5/loss=241.862, agent6/loss=189.645, agent7/loss=168.761, agent8/loss=180.560, agent9/loss=203.979, env_step=32500, len=50, n/ep=10, n/st=20, rew=129.25]                           


Epoch #13: test_reward: 178.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #14: 2501it [00:25, 96.74it/s, agent0/loss=302.774, agent1/loss=298.980, agent2/loss=321.387, agent3/loss=306.369, agent4/loss=253.795, agent5/loss=250.208, agent6/loss=206.361, agent7/loss=161.192, agent8/loss=192.379, agent9/loss=175.556, env_step=35000, len=50, n/ep=10, n/st=20, rew=132.10]                          


Epoch #14: test_reward: 178.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #15: 2501it [00:26, 95.79it/s, agent0/loss=288.426, agent1/loss=268.768, agent2/loss=328.627, agent3/loss=261.210, agent4/loss=220.821, agent5/loss=243.814, agent6/loss=234.697, agent7/loss=158.927, agent8/loss=173.163, agent9/loss=183.073, env_step=37500, len=50, n/ep=10, n/st=20, rew=123.05]                           


Epoch #15: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #16: 2501it [00:25, 96.43it/s, agent0/loss=302.495, agent1/loss=306.410, agent2/loss=346.393, agent3/loss=263.068, agent4/loss=251.472, agent5/loss=176.898, agent6/loss=210.685, agent7/loss=148.721, agent8/loss=198.846, agent9/loss=162.321, env_step=40000, len=50, n/ep=10, n/st=20, rew=123.45]                          


Epoch #16: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #17: 2501it [00:33, 74.18it/s, agent0/loss=326.562, agent1/loss=286.830, agent2/loss=308.606, agent3/loss=251.101, agent4/loss=240.564, agent5/loss=231.447, agent6/loss=194.348, agent7/loss=162.189, agent8/loss=170.978, agent9/loss=167.030, env_step=42500, len=50, n/ep=10, n/st=20, rew=119.75]                          


Epoch #17: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #18: 2501it [00:26, 93.87it/s, agent0/loss=310.983, agent1/loss=294.639, agent2/loss=316.898, agent3/loss=276.640, agent4/loss=241.680, agent5/loss=197.317, agent6/loss=174.742, agent8/loss=202.808, agent9/loss=174.926, env_step=45000, len=50, n/ep=10, n/st=20, rew=137.40]                                                


Epoch #18: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #19: 2501it [00:24, 100.82it/s, agent0/loss=300.691, agent1/loss=273.878, agent2/loss=333.881, agent3/loss=311.997, agent4/loss=212.543, agent5/loss=237.911, agent6/loss=218.433, agent7/loss=175.147, agent8/loss=185.780, agent9/loss=183.532, env_step=47500, len=50, n/ep=10, n/st=20, rew=121.30]                         


Epoch #19: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #20: 2501it [00:25, 98.04it/s, agent0/loss=292.698, agent1/loss=280.069, agent2/loss=332.158, agent4/loss=233.349, agent5/loss=217.139, agent6/loss=227.668, agent8/loss=164.166, agent9/loss=191.784, env_step=50000, len=50, n/ep=10, n/st=20, rew=129.20]                                                                    


Epoch #20: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #21: 2501it [00:26, 93.66it/s, agent0/loss=307.996, agent1/loss=302.629, agent2/loss=320.332, agent3/loss=272.884, agent4/loss=210.390, agent5/loss=192.428, agent6/loss=214.994, agent7/loss=170.376, agent8/loss=180.742, agent9/loss=203.504, env_step=52500, len=50, n/ep=10, n/st=20, rew=129.90]                          


Epoch #21: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #22: 2501it [00:27, 92.01it/s, agent0/loss=317.117, agent1/loss=279.147, agent2/loss=303.981, agent3/loss=257.219, agent4/loss=233.617, agent5/loss=242.314, agent6/loss=199.425, agent7/loss=159.269, agent8/loss=161.747, agent9/loss=209.379, env_step=55000, len=50, n/ep=10, n/st=20, rew=140.35]                          


Epoch #22: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #23: 2501it [00:26, 93.88it/s, agent0/loss=346.299, agent1/loss=280.530, agent2/loss=313.836, agent3/loss=260.613, agent4/loss=231.653, agent5/loss=250.558, agent6/loss=182.435, agent7/loss=156.408, agent8/loss=163.535, agent9/loss=166.251, env_step=57500, len=50, n/ep=10, n/st=20, rew=127.65]                          


Epoch #23: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #24: 2501it [00:27, 90.25it/s, agent0/loss=279.075, agent1/loss=314.546, agent2/loss=330.236, agent3/loss=273.303, agent4/loss=242.439, agent5/loss=234.815, agent6/loss=205.747, agent7/loss=170.814, agent8/loss=165.332, agent9/loss=179.075, env_step=60000, len=50, n/ep=10, n/st=20, rew=127.55]                          


Epoch #24: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #25: 2501it [00:28, 89.06it/s, agent1/loss=307.620, agent2/loss=316.537, agent3/loss=312.568, agent4/loss=214.778, agent5/loss=225.706, agent6/loss=203.245, agent7/loss=156.737, agent8/loss=153.553, agent9/loss=168.319, env_step=62500, len=50, n/ep=10, n/st=20, rew=124.75]                                               


Epoch #25: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #26: 2501it [00:26, 95.06it/s, agent0/loss=287.276, agent1/loss=291.683, agent2/loss=325.264, agent3/loss=290.950, agent4/loss=205.411, agent5/loss=216.765, agent6/loss=177.068, agent7/loss=174.495, agent8/loss=170.630, env_step=65000, len=50, n/ep=10, n/st=20, rew=121.05]                                               


Epoch #26: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #27: 2501it [00:26, 94.17it/s, agent0/loss=309.077, agent1/loss=330.318, agent2/loss=267.690, agent3/loss=299.747, agent5/loss=197.439, agent6/loss=239.299, agent7/loss=170.448, agent8/loss=158.004, agent9/loss=207.314, env_step=67500, len=50, n/ep=10, n/st=20, rew=128.55]                                               


Epoch #27: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #28: 2501it [00:27, 91.68it/s, agent0/loss=344.823, agent1/loss=244.570, agent2/loss=312.697, agent3/loss=279.002, agent4/loss=263.306, agent5/loss=216.365, agent6/loss=194.697, agent7/loss=163.609, agent8/loss=165.350, agent9/loss=227.235, env_step=70000, len=50, n/ep=10, n/st=20, rew=140.65]                          


Epoch #28: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #29: 2501it [00:27, 90.49it/s, agent0/loss=298.226, agent1/loss=275.676, agent2/loss=312.465, agent3/loss=295.339, agent4/loss=235.302, agent5/loss=208.419, agent6/loss=198.158, agent7/loss=160.231, agent8/loss=186.116, agent9/loss=197.580, env_step=72500, len=50, n/ep=10, n/st=20, rew=138.30]                          


Epoch #29: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #30: 2501it [00:27, 91.62it/s, agent0/loss=321.338, agent1/loss=285.784, agent2/loss=323.969, agent3/loss=303.493, agent4/loss=224.371, agent5/loss=202.512, agent6/loss=184.035, agent7/loss=190.416, agent8/loss=166.101, env_step=75000, len=50, n/ep=10, n/st=20, rew=123.85]                                               


Epoch #30: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #31: 2501it [00:26, 94.09it/s, agent0/loss=347.199, agent1/loss=307.842, agent2/loss=319.264, agent3/loss=275.331, agent4/loss=238.220, agent5/loss=192.232, agent6/loss=228.339, agent7/loss=166.561, agent8/loss=142.611, agent9/loss=168.492, env_step=77500, len=50, n/ep=10, n/st=20, rew=127.00]                          


Epoch #31: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #32: 2501it [00:26, 93.08it/s, agent0/loss=314.382, agent1/loss=285.903, agent2/loss=323.425, agent3/loss=259.907, agent5/loss=229.742, agent6/loss=201.197, agent7/loss=168.305, agent8/loss=186.400, agent9/loss=199.391, env_step=80000, len=50, n/ep=10, n/st=20, rew=147.85]                                               


Epoch #32: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #33: 2501it [00:26, 93.48it/s, agent0/loss=316.189, agent1/loss=330.835, agent2/loss=348.731, agent3/loss=270.586, agent4/loss=254.148, agent5/loss=220.492, agent6/loss=182.112, agent7/loss=167.851, agent8/loss=202.122, agent9/loss=233.750, env_step=82500, len=50, n/ep=10, n/st=20, rew=132.40]                          


Epoch #33: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #34: 2501it [00:26, 92.95it/s, agent0/loss=279.945, agent1/loss=293.953, agent2/loss=328.591, agent3/loss=241.120, agent4/loss=220.460, agent5/loss=238.969, agent6/loss=193.427, agent7/loss=141.324, agent8/loss=212.721, agent9/loss=179.663, env_step=85000, len=50, n/ep=10, n/st=20, rew=133.00]                          


Epoch #34: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #35: 2501it [00:27, 92.48it/s, agent0/loss=286.020, agent1/loss=278.468, agent2/loss=305.507, agent3/loss=310.749, agent4/loss=232.562, agent5/loss=202.144, agent7/loss=166.613, agent8/loss=172.223, agent9/loss=187.738, env_step=87500, len=50, n/ep=10, n/st=20, rew=117.65]                                               


Epoch #35: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #36: 2501it [00:26, 93.37it/s, agent0/loss=321.890, agent1/loss=291.230, agent2/loss=278.743, agent3/loss=267.652, agent4/loss=216.351, agent5/loss=186.896, agent6/loss=203.673, agent7/loss=170.350, agent8/loss=175.924, agent9/loss=213.891, env_step=90000, len=50, n/ep=10, n/st=20, rew=136.55]                          


Epoch #36: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #37: 2501it [00:31, 80.25it/s, agent0/loss=262.218, agent1/loss=270.854, agent2/loss=291.875, agent3/loss=238.901, agent4/loss=251.010, agent5/loss=236.898, agent6/loss=210.598, agent7/loss=165.182, agent8/loss=168.741, agent9/loss=201.521, env_step=92500, len=50, n/ep=10, n/st=20, rew=131.65]                          


Epoch #37: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #38: 2501it [00:33, 74.67it/s, agent0/loss=266.692, agent2/loss=289.587, agent3/loss=254.507, agent4/loss=281.885, agent5/loss=234.803, agent6/loss=208.525, agent7/loss=178.074, agent8/loss=184.938, agent9/loss=206.286, env_step=95000, len=50, n/ep=10, n/st=20, rew=134.40]                                               


Epoch #38: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #39: 2501it [00:31, 79.68it/s, agent0/loss=292.200, agent1/loss=254.341, agent2/loss=307.186, agent3/loss=291.060, agent4/loss=234.870, agent5/loss=221.243, agent6/loss=189.876, agent7/loss=179.378, agent8/loss=197.130, agent9/loss=151.580, env_step=97500, len=50, n/ep=10, n/st=20, rew=133.45]                          


Epoch #39: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #40: 2501it [00:28, 86.38it/s, agent0/loss=312.082, agent1/loss=282.167, agent2/loss=259.800, agent3/loss=317.560, agent4/loss=231.462, agent5/loss=209.304, agent6/loss=215.407, agent7/loss=160.469, agent8/loss=183.468, agent9/loss=213.945, env_step=100000, len=50, n/ep=10, n/st=20, rew=132.35]                          


Epoch #40: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #41: 2501it [00:25, 97.30it/s, agent1/loss=301.464, agent2/loss=302.065, agent3/loss=247.236, agent4/loss=244.221, agent5/loss=202.660, agent6/loss=245.265, agent7/loss=155.041, agent8/loss=135.648, agent9/loss=169.203, env_step=102500, len=50, n/ep=10, n/st=20, rew=126.70]                                                


Epoch #41: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #42: 2501it [00:26, 95.61it/s, agent0/loss=268.121, agent1/loss=276.616, agent2/loss=329.330, agent3/loss=272.602, agent4/loss=237.715, agent5/loss=208.494, agent6/loss=220.229, agent7/loss=181.553, agent8/loss=160.522, agent9/loss=180.430, env_step=105000, len=50, n/ep=10, n/st=20, rew=132.10]                           


Epoch #42: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #43: 2501it [00:24, 102.12it/s, agent0/loss=292.185, agent1/loss=294.978, agent2/loss=324.129, agent3/loss=276.206, agent4/loss=230.394, agent5/loss=212.137, agent6/loss=212.421, agent7/loss=184.997, agent8/loss=210.425, agent9/loss=196.244, env_step=107500, len=50, n/ep=10, n/st=20, rew=126.40]                          


Epoch #43: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #44: 2501it [00:26, 93.44it/s, agent0/loss=290.197, agent1/loss=339.638, agent2/loss=306.395, agent3/loss=229.450, agent4/loss=238.354, agent5/loss=240.514, agent6/loss=203.391, agent7/loss=175.689, agent8/loss=176.905, agent9/loss=174.896, env_step=110000, len=50, n/ep=10, n/st=20, rew=134.10]                          


Epoch #44: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #45: 2501it [00:25, 99.22it/s, agent0/loss=312.329, agent1/loss=277.399, agent2/loss=307.014, agent3/loss=306.218, agent4/loss=237.380, agent6/loss=208.140, agent7/loss=198.171, agent8/loss=177.247, agent9/loss=166.998, env_step=112500, len=50, n/ep=10, n/st=20, rew=131.90]                                                


Epoch #45: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #46: 2501it [00:25, 99.62it/s, agent0/loss=313.633, agent1/loss=269.011, agent2/loss=318.349, agent4/loss=263.798, agent5/loss=201.032, agent7/loss=155.675, agent8/loss=193.950, env_step=115000, len=50, n/ep=10, n/st=20, rew=132.25]                                                                                          


Epoch #46: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #47: 2501it [00:24, 102.49it/s, agent0/loss=329.599, agent1/loss=304.350, agent2/loss=332.752, agent4/loss=242.942, agent5/loss=211.380, agent6/loss=200.698, agent7/loss=176.900, agent9/loss=220.485, env_step=117500, len=50, n/ep=10, n/st=20, rew=144.80]                                                                    


Epoch #47: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #48: 2501it [00:24, 101.11it/s, agent0/loss=318.454, agent1/loss=300.936, agent2/loss=306.291, agent3/loss=283.871, agent4/loss=260.479, agent5/loss=216.311, agent6/loss=231.276, agent7/loss=174.063, agent8/loss=184.184, agent9/loss=229.167, env_step=120000, len=50, n/ep=10, n/st=20, rew=130.05]                          


Epoch #48: test_reward: 185.000000 ± 0.000000, best_reward: 185.000000 ± 0.000000 in #1


Epoch #49:  38%|###7      | 940/2500 [00:10<00:17, 91.67it/s, agent0/loss=305.855, agent1/loss=270.213, agent2/loss=296.074, agent3/loss=266.801, agent4/loss=272.704, agent5/loss=227.276, agent6/loss=199.069, agent7/loss=187.088, agent9/loss=207.807, env_step=120920, len=50, n/ep=0, n/st=20, rew=137.35]                       


KeyboardInterrupt: 

In [None]:
from typing import Optional, Tuple
import os
import numpy as np
import torch
from tianshou.env import DummyVectorEnv
from tianshou.trainer import offpolicy_trainer
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger
import torch

import mUAV_TA.MultiDroneEnvUtils as utils
#from Custom_Classes import CustomCollector

def _get_env_eval():
    """Build one evaluation environment; used as the factory callable for DummyVectorEnv.

    The scenario is described by a ``utils.DroneEnvOptions`` object whose agent
    and task counts come from the local ``case`` dict, and which is handed to
    ``MultiDroneEnv``.

    Returns:
        A ``PettingZooEnv`` wrapping ``CustomParallelToAECWrapper(MultiDroneEnv(config))``.
    """
    # Counts per agent type (F1, F2, R1, R2) and per task type (Att, Rec).
    case = {'case': 0, 'F1': 2, 'F2': 2, "R1": 3, 'R2': 3, "Att": 8, "Rec": 22}

    config = utils.DroneEnvOptions(
        render_mode='human',
        render_speed=-1,
        simulation_frame_rate=1 / 60,
        simulator_module="Internal",
        max_time_steps=300,
        action_mode="TaskAssign",
        agents={"F1": case['F1'], "F2": case['F2'], "R1": case['R1'], "R2": case['R2']},
        tasks={"Att": case['Att'], "Rec": case['Rec']},
        random_init_pos=True,
        num_obstacles=0,
        hidden_obstacles=False,
        fail_rate=0,
        info="No Info",
    )

    # The options used to be built and then dropped (MultiDroneEnv() was called
    # with no arguments), so the evaluation ran on the environment defaults.
    # NOTE(review): passed positionally, mirroring the MultiDroneEnv(None) call
    # elsewhere in this notebook - assumes the first constructor parameter is
    # the options object; confirm against mUAV_TA.DroneEnv.
    env_paralell = MultiDroneEnv(config)
    env = CustomParallelToAECWrapper(env_paralell)

    return PettingZooEnv(env)


# Checkpoint to evaluate. The file name below already carries the "policy_"
# prefix and the ".pth" extension, so it is used as-is; prefixing it again
# produced "policy_policy_....pth", which pointed at no file and was never used.
name = 'policy_CustomNetMultiHead_Eval_TBTA_02_simplified_UCF_mask01_seed0All.pth'
load_policy_name = name

log_path = os.path.join('./', "Logs", "dqn", name)

# Rebuild the policy objects (_get_agents is defined elsewhere in the notebook),
# then load the stored weights from the checkpoint named in this cell rather
# than from whatever save_policy_name happens to be left in the kernel.
policy, optim, _ = _get_agents()
model_save_path = os.path.join("dqn_Custom", load_policy_name)

# Load the saved checkpoint into the 'agent0' policy.
policy_test = policy.policies['agent0']
policy_test.load_state_dict(torch.load(model_save_path))

envs = DummyVectorEnv([_get_env_eval for _ in range(1)])

# Evaluation mode with epsilon set to 0.
policy_test.eval()
policy_test.set_eps(0.00)

# NOTE(review): the collector drives the whole manager `policy`, while the
# weights, eval() and eps above are applied to policy.policies['agent0'] only.
# That is sufficient only if _get_agents() registers the same policy object for
# every agent - confirm.
collector = CustomCollector(policy, envs, exploration_noise=True)

results = collector.collect(n_episode=5)
results

In [None]:
# Keep only the episode rewards above -10, then report their mean.
episode_rewards = results['rews']
kept_rewards = episode_rewards[episode_rewards > -10]
print(np.mean(kept_rewards))

# Histogram of the kept episode rewards.
import matplotlib.pyplot as plt
plt.hist(kept_rewards, bins=100)
plt.show()


In [None]:
from turtle import st
import torch
from tianshou.data import Batch

# Rebuild the policy objects and load the saved weights into the 'agent0' policy.
policy, optim, _ = _get_agents()
model_save_path = os.path.join("dqn_Custom", save_policy_name)
policy_test = policy.policies['agent0']
state_saved = torch.load(model_save_path)
policy_test.load_state_dict(state_saved)
policy_test.eval()
policy_test.set_eps(0.00)

# Drive the raw environment directly (no vector env, no collector).
env = MultiDroneEnv(None)
env.max_time_steps = 50

# Run 10 rollouts by hand. The inner loop has a single iteration, so `episode`
# is always 0 and every rollout is reset with seed=0.
for i in range(10):
    for episode in range(1):

        obs, _ = env.reset(seed=episode)
        info = env.get_initial_state()

        drones = info["drones"]
        tasks = info["tasks"]

        # Placeholder flags so the loop condition holds before the first step;
        # both are replaced by the dicts returned from env.step().
        done = {0: False}
        truncations = {0: False}

        episodo_reward = 0

        while not all(done.values()) and not all(truncations.values()):

            # NOTE(review): reads a private attribute of the agent selector to
            # find out whose turn it is - confirm there is no public accessor.
            agent_id = "agent" + str(env.agent_selector._current_agent)

            # Single-observation batch; an empty info dict is attached to it.
            obs_batch = Batch(obs=[obs[agent_id]], info=[{}])

            # Inference only: no autograd bookkeeping is needed for the forward pass.
            with torch.no_grad():
                action = policy_test(obs_batch).act
            action = {agent_id: action[0]}

            obs, reward, done, truncations, info = env.step(action)

            # Accumulate the step reward averaged over env.n_agents.
            episodo_reward += sum(reward.values()) / env.n_agents

    print(episodo_reward)
