# In [1]:
from email import utils
import os
import datetime
from typing import Optional, Tuple
import json


os.environ["WANDB_NOTEBOOK_NAME"] = "Tianshow_Centralized_Training"

import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.env.pettingzoo_env_parallel import PettingZooParallelEnv

from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy, RainbowPolicy
from tianshou.trainer import OffpolicyTrainer
from torch.utils.tensorboard import SummaryWriter

# from pettingzoo.sisl import pursuit_v4
from pettingzoo.mpe import simple_spread_v3
#import Mods.TaskSpreadEnv as TaskSpreadEnv

from TaskAllocation.RL_Policies.DNN_Spread import DNN_Spread
from TaskAllocation.RL_Policies.MPE_Task_MultiHead import MPE_Task_MultiHead

#import Mods.TaskPursuitEnv as TaskPursuitEnv
import Mods.ActionLoggerWrapper as ActionLoggerWrapper
import Mods.VDNPolicy as VDNPolicy
import Mods.PettingZooParallelEnv2 as PettingZooParallelEnv2
import Mods.CollectorMA as CollectorMA

from TaskAllocation.RL_Policies.Custom_Classes import CustomNet
from TaskAllocation.RL_Policies.Custom_Classes import CustomCollector
from TaskAllocation.RL_Policies.Custom_Classes import CustomParallelToAECWrapper

# Add specific modification to tianshou
import wandb
from tianshou.utils import WandbLogger
from tianshou.utils.logger.base import LOG_DATA_TYPE

def new_write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
    """Replacement for ``WandbLogger.write`` that sends metrics straight to wandb.

    The step counter is stored in ``data`` under the ``step_type`` key (the
    caller's dict is modified in place) and the resulting dict is handed to
    ``wandb.log``.

    Args:
        step_type: Key under which the step counter is recorded.
        step: Current value of that step counter.
        data: Metrics to log; receives the extra ``step_type`` entry.
    """
    data.update({step_type: step})
    wandb.log(data)


# Monkey-patch: WandbLogger.write now resolves to new_write.
WandbLogger.write = new_write

from pettingzoo.utils import wrappers
import gym

class ActionLoggerWrapper(gym.Wrapper):
    """Gym wrapper that records every action and logs them to wandb on reset.

    Each action passed to ``step`` is appended to ``self.actions``. On the
    next ``reset`` the collected actions are flattened, turned into a
    histogram, logged to wandb under ``"actions_histogram"`` and the list is
    cleared.

    NOTE(review): this class reuses the name bound earlier in the file by
    ``import Mods.ActionLoggerWrapper as ActionLoggerWrapper`` and therefore
    shadows that module from this point on - confirm this is intended.
    """

    def __init__(self, env):
        """Wrap ``env`` and start with an empty action log."""
        super().__init__(env)
        self.actions = []

    def step(self, action):
        """Record ``action`` and forward it to the wrapped environment."""
        self.actions.append(action)
        return self.env.step(action)

    def reset(self, **kwargs):
        """Log the actions gathered since the last reset, then reset the env.

        All keyword arguments are forwarded unchanged to the wrapped
        environment's ``reset``; its return value is returned as-is.
        """
        if self.actions:
            self._log_actions_histogram()
            self.actions = []
        return self.env.reset(**kwargs)

    def _log_actions_histogram(self):
        """Send a histogram of the recorded actions to wandb (best effort)."""
        # Local import keeps this block self-contained.
        import logging

        # Bring every action to a 1-D array so they can be concatenated.
        formatted_actions = [np.array(a).flatten() for a in self.actions]
        flattened_actions = np.concatenate(formatted_actions)

        try:
            hist_data, bin_edges = np.histogram(flattened_actions, bins='auto')
            wandb.log({"actions_histogram": wandb.Histogram(np_histogram=(hist_data, bin_edges))})
        except Exception as exc:
            # Logging is auxiliary: never let it break the rollout, but do
            # report the failure instead of discarding it silently.
            logging.getLogger(__name__).warning(
                "Could not log actions histogram: %s", exc
            )


# from tianshou_DQN import train

# ---- Run identification ----
# Other model names that appeared here before: "DNN_Spread",
# "MPE_Task_MultiHead", "CNN_ATT_SISL", "MultiHead_SISL".
model = "VDN_Spread"
test_num = "_Desk_01_8feat"
policyModel = "VDN"

train_env_num = 10
test_env_num = 10

name = model + test_num

# ---- Logging ----
# A timestamp suffix keeps the log directory / checkpoint name unique per run.
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = name + str(now)
log_path = os.path.join('./', "Logs", "dqn_sisl", log_name)

# ---- Policy checkpoints ----
load_policy_name = 'policy_MPE_Task_MultiHead_Desk_01_8feat240112-145703_29_BestRew.pth'
save_policy_name = f'policy_{log_name}'
policy_path = "vdn_Spread"

Policy_Config = {
    "same_policy": False,
    "load_model": False,
    "freeze_CNN": False,
}

# Keyword arguments for simple_spread_v3 (library defaults noted inline).
Spread_Config = {
    "N": 3,                       # Default = 3
    "local_ratio": 0.5,           # Default = 0.5
    "max_cycles": 25,             # Default = 25
    "continuous_actions": False,  # Default = False
    "render_mode": None,          # Default = None
}

max_cycles = Spread_Config["max_cycles"]
n_agents = Spread_Config["N"]

# NOTE(review): the key "optminizer" is misspelled; it is kept unchanged here
# because renaming it would change the key under which the value is recorded.
dqn_params = {
    "discount_factor": 0.90,
    "estimation_step": 5,
    "target_update_freq": 12000,
    "optminizer": "Adam",
    "lr": 0.00005,
}

trainer_params = {
    "max_epoch": 500,
    "step_per_epoch": 24000,
    "step_per_collect": 6000,
    "episode_per_test": 50,
    "batch_size": 128,
    # Only applied once a collect finishes (repeated as many times as
    # necessary to meet this ratio).
    "update_per_step": 1 / 1000,
    "tn_eps_max": 0.8,
    "ts_eps_max": 0.01,
    "warmup_size": 10,
}

# Flat view of every setting of this run. Built as a *new* dict so that
# dqn_params keeps only the DQN settings (previously runConfig was an alias
# of dqn_params and the updates leaked every other key into it).
runConfig = {**dqn_params, **Policy_Config, **trainer_params, **Spread_Config}

model_load_path = os.path.join(policy_path, load_policy_name)
model_save_path = os.path.join(policy_path, save_policy_name)

# Make sure the checkpoint and log directories exist before they are used.
os.makedirs(policy_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)

def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    policy_load_path = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    """Build the per-agent policies and wrap them in a ``VDNMAPolicy``.

    A probe environment from ``_get_env`` supplies the observation/action
    spaces and the agent list. Depending on ``Policy_Config["same_policy"]``
    either one policy is created and shared by all agents, or three policies
    are created (any further agents reuse the first one).

    Args:
        agent_learn: Only used when the global ``policyModel`` is neither
            ``"DQN"`` nor ``"VDN"``; otherwise a fresh ``DQNPolicy`` replaces it.
        agent_opponent: Not used by this function.
        optim: Ignored on input; a new Adam optimizer is created per network.
        policy_load_path: Not used by this function; checkpoints are read
            from the global ``model_load_path``.

    Returns:
        ``(policy, optim, agents)``: the multi-agent policy, the optimizer of
        the *last* network that was created, and ``env.agents``.

    Raises:
        ValueError: If the global ``model`` is not a known network name, or
            if ``policyModel`` is unknown and no ``agent_learn`` was supplied.
    """
    env = _get_env()
    obs_shape = env.observation_space.shape
    action_space = env.action_space

    # Resolve the device once and reuse it everywhere below.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # One shared policy, or a fixed pool of three (instead of len(env.agents)).
    policies_number = 1 if Policy_Config["same_policy"] else 3

    agents = []
    for _ in range(policies_number):
        # "VDN_Spread" uses the same network class as "DNN_Spread".
        if model in ("DNN_Spread", "VDN_Spread"):
            net = DNN_Spread(
                obs_shape=obs_shape[0],
                action_shape=5,
                device=device,
            ).to(device)
        elif model == "MPE_Task_MultiHead":
            net = MPE_Task_MultiHead(
                num_tasks=Spread_Config['N'] * 2 + 5,
                num_features_per_task=2,  # 6 + 2 + 1,
                device=device,
            ).to(device)
        else:
            # Previously this surfaced later as a NameError on ``net``.
            raise ValueError(f"Unknown model name: {model!r}")

        optim = torch.optim.Adam(
            net.parameters(), lr=dqn_params["lr"], weight_decay=0.0, amsgrad=True
        )

        # "DQN" and "VDN" both use a DQNPolicy per agent; the value mixing is
        # left to the VDNMAPolicy wrapper created below.
        if policyModel in ("DQN", "VDN"):
            agent_learn = DQNPolicy(
                model=net,
                optim=optim,
                action_space=action_space,
                discount_factor=dqn_params["discount_factor"],
                estimation_step=dqn_params["estimation_step"],
                target_update_freq=dqn_params["target_update_freq"],
                reward_normalization=False,
                clip_loss_grad=False,
            )
        elif agent_learn is None:
            # Without this guard a ``None`` policy would be used below.
            raise ValueError(f"Unknown policyModel: {policyModel!r}")

        if Policy_Config["load_model"] is True:
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            agent_learn.load_state_dict(
                torch.load(model_load_path, map_location=device)
            )
            print(f'Loaded-> {model_load_path}')

        agents.append(agent_learn)

    if Policy_Config["same_policy"]:
        # Every agent shares the single policy object.
        agents = [agents[0]] * len(env.agents)
    else:
        # Agents beyond the created pool reuse the first policy
        # (a non-positive difference adds nothing).
        agents.extend([agents[0]] * (len(env.agents) - policies_number))

    policy = VDNPolicy.VDNMAPolicy(policies=agents, env=env, device=device)

    return policy, optim, env.agents

def _get_env(test=False):
    """Create one ``simple_spread_v3`` parallel env wrapped in ``PettingZooParallelEnv``.

    This function is needed to provide callables for DummyVectorEnv.
    All environment settings, including the render mode, are read from the
    global ``Spread_Config`` (the render mode used to be the literal
    ``" human"``, whose leading space made it an unrecognised value).

    Args:
        test: Currently unused; kept so existing callers keep working.

    Returns:
        The wrapped parallel environment.
    """
    env = simple_spread_v3.parallel_env(
        max_cycles=Spread_Config["max_cycles"],
        local_ratio=Spread_Config["local_ratio"],
        N=Spread_Config["N"],
        continuous_actions=Spread_Config["continuous_actions"],
        render_mode=Spread_Config["render_mode"],
    )

    # Earlier experiments wrapped the env differently (AEC conversion,
    # ActionLoggerWrapper, PettingZooEnv); the parallel wrapper is used now.
    env = PettingZooParallelEnv(env)

    return env

# print(json.dumps(runConfig, indent=4))


# In [9]:
def _get_envT():
    """Environment factory for evaluation runs (no rendering).

    Creates a ``simple_spread_v3`` parallel environment from the values in
    ``Spread_Config`` with ``render_mode=None`` and returns it wrapped in
    ``PettingZooParallelEnv``. Passed as a callable to ``DummyVectorEnv``.
    """
    # Copy only the constructor settings; the render mode is forced to None
    # here rather than read from the config.
    spread_kwargs = {
        key: Spread_Config[key]
        for key in ("max_cycles", "local_ratio", "N", "continuous_actions")
    }
    parallel_env = simple_spread_v3.parallel_env(render_mode=None, **spread_kwargs)

    return PettingZooParallelEnv(parallel_env)


# ======== Step 0: Policy ========
policy, optim, agents = _get_agents()

# ======== Step 1: Environment setup ========
test_env_num = 1
test_envs = DummyVectorEnv([_get_envT] * test_env_num)

# ======== Step 2: Seeding ========
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
test_envs.seed(seed)

# ======== Step 3: Evaluation settings ========
episodes = 50
render = False
policy_name = "policy_VDN_Spread_Desk_01_8feat240119-232748_41_"

# ======== Step 4: Put every agent's policy in evaluation mode ========
for agent in agents:
    # Checkpoint location for this agent. NOTE: the path is only computed;
    # the load call below is commented out, so no weights are read here.
    model_file = name + ".pth" if Policy_Config["same_policy"] else name + agent + ".pth"
    model_path = os.path.join("vdn_SISL", model_file)

    agent_policy = policy.policies[agent]
    agent_policy.set_eps(0.01)
    # agent_policy.load_state_dict(torch.load(model_load_path))
    agent_policy.eval()

# envs = DummyVectorEnv([_get_env for _ in range(1)])

# ======== Step 5: One prioritized replay buffer per agent ========
agents_buffers_test = {}
for agent in agents:
    agents_buffers_test[agent] = PrioritizedVectorReplayBuffer(
        30_000, test_env_num, alpha=0.6, beta=0.4
    )

# ======== Step 6: Collect evaluation episodes ========
collector = CollectorMA.CollectorMA(
    policy, test_envs, agents_buffers_test, exploration_noise=True
)

if render:
    render_delay = 0.0
else:
    render_delay = None
results = collector.collect(n_episode=episodes, render=render_delay)

# ======== Step 7: Report reward statistics ========
print("Mean: ", results['rew'])
for label, reducer in (("Std:  ", np.std), ("Max:  ", np.max), ("Min:  ", np.min)):
    print(label, reducer(results['rews']))
# print("Mean Len: " , np.mean(results['lens']))

# Output of the cell above:
# Mean:  -60.50102659959429
# Std:   12.191940338544022
# Max:   -35.676358312353955
# Min:   -91.80732706641545
