In [2]:
from email import utils
import os
import datetime
from typing import Optional, Tuple
import json


# Export the notebook name through the WANDB_NOTEBOOK_NAME environment
# variable. This runs ahead of the `import wandb` statement further down.
# NOTE(review): presumably wandb reads this to label runs started from a
# notebook -- confirm against the wandb documentation.
os.environ["WANDB_NOTEBOOK_NAME"] = "Tianshow_Centralized_Training"

import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy, RainbowPolicy
from tianshou.trainer import OffpolicyTrainer
from torch.utils.tensorboard import SummaryWriter

# from pettingzoo.sisl import pursuit_v4
# from pettingzoo.mpe import simple_spread_v3
import Mods.TaskSpreadEnv as TaskSpreadEnv

from TaskAllocation.RL_Policies.DNN_Spread import DNN_Spread
from TaskAllocation.RL_Policies.MPE_Task_MultiHead import MPE_Task_MultiHead

#import Mods.TaskPursuitEnv as TaskPursuitEnv
import Mods.ActionLoggerWrapper as ActionLoggerWrapper

from TaskAllocation.RL_Policies.Custom_Classes import CustomNet
from TaskAllocation.RL_Policies.Custom_Classes import CustomCollector
from TaskAllocation.RL_Policies.Custom_Classes import CustomParallelToAECWrapper

# Add specific modification to tianshou
import wandb
from tianshou.utils import WandbLogger
from tianshou.utils.logger.base import LOG_DATA_TYPE

def new_write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
    """Send one logging record directly to wandb.

    Intended as a replacement for ``WandbLogger.write`` (installed just
    below this function).

    Args:
        self: The logger instance the method is bound to (not used).
        step_type: Key under which the step counter is logged.
        step: Current value of the step counter.
        data: Metrics to log. The mapping is left unmodified.
    """
    # Build a fresh payload instead of inserting the step into ``data`` so the
    # caller's dictionary is not changed as a side effect of logging.
    wandb.log({**data, step_type: step})


# Monkey-patch tianshou's WandbLogger so its writes go through new_write.
WandbLogger.write = new_write

from pettingzoo.utils import wrappers
import gym

class ActionLoggerWrapper(gym.Wrapper):
    """Environment wrapper that records actions and logs them to wandb.

    Every action passed to :meth:`step` is buffered. On :meth:`reset` the
    buffered actions are logged to wandb as a histogram under the key
    ``"actions_histogram"`` and the buffer is cleared.
    """

    def __init__(self, env):
        super().__init__(env)
        # Actions seen since the last reset.
        self.actions = []

    def step(self, action):
        """Buffer ``action`` and forward it to the wrapped environment."""
        self.actions.append(action)
        return self.env.step(action)

    def reset(self, **kwargs):
        """Log the buffered actions (best effort), then reset the wrapped env.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's
                ``reset``.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        if self.actions:
            self._log_action_histogram()
            self.actions = []
        return self.env.reset(**kwargs)

    def _log_action_histogram(self):
        """Push a histogram of the buffered actions to wandb.

        Logging is auxiliary, so any failure is reported at DEBUG level and
        never propagates into the environment reset.
        """
        # Local import so this class does not depend on an edit to the file's
        # import section.
        import logging

        try:
            # Flatten each action so scalars and arrays can be concatenated.
            # ``None`` entries are skipped: they would force an object-dtype
            # array, which np.histogram is not expected to handle.
            flat_actions = [
                np.array(a).flatten() for a in self.actions if a is not None
            ]
            if not flat_actions:
                return

            hist_data, bin_edges = np.histogram(
                np.concatenate(flat_actions), bins='auto'
            )
            wandb.log(
                {"actions_histogram": wandb.Histogram(np_histogram=(hist_data, bin_edges))}
            )
        except Exception as exc:  # best-effort boundary: never break reset()
            # Silent by default (the original suppressed the message), but
            # recoverable by enabling DEBUG logging.
            logging.getLogger(__name__).debug(
                "Error in logging histogram: %s", exc, exc_info=True
            )


#from tianshou_DQN import train
# ---------------------------------------------------------------------------
# Experiment configuration (module-level settings read by the functions below)
# ---------------------------------------------------------------------------

# Network architecture selector; other names left by the author:
# "DNN_Spread", "CNN_ATT_SISL", "MultiHead_SISL"
model = "MPE_Task_MultiHead"
test_num = "_Desk_01_8feat"
policyModel = "DQN"

train_env_num = 10
test_env_num = 10

name = model + test_num

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = name + str(now)
log_path = os.path.join('./', "Logs", "dqn_sisl", log_name)

# policy
load_policy_name = 'policy_MPE_Task_MultiHead_Desk_01_8feat240112-174145_29_BestRew.pth'
save_policy_name = f'policy_{log_name}'
policy_path = "dqn_Spread"

Policy_Config = {
    "same_policy": True,
    "load_model": True,
    "freeze_CNN": False,
}

Spread_Config = {
    "N": 3,                       # Default = 3
    "local_ratio": 0.5,           # Default = 0.5
    "max_cycles": 25,             # Default = 25
    "continuous_actions": False,  # Default = False
    "render_mode": None,          # Default = None
}

max_cycles = Spread_Config["max_cycles"]
n_agents = Spread_Config["N"]

dqn_params = {
    "discount_factor": 0.99,
    "estimation_step": 5,
    "target_update_freq": max_cycles * n_agents * 20,
    # NOTE(review): key is misspelled ("optminizer"); kept as-is because it is
    # runtime data. Nothing in the visible code reads it.
    "optminizer": "Adam",
    "lr": 0.00001,
}

trainer_params = {
    "max_epoch": 1000,
    "step_per_epoch": 250 * (max_cycles * n_agents),
    "step_per_collect": 20 * (max_cycles * n_agents),
    "episode_per_test": 50,
    "batch_size": 32 * n_agents,
    # Updates only run once a collect has finished (as many times as needed
    # to meet this ratio).
    "update_per_step": 1 / 100,
    "tn_eps_max": 0.15,
    "ts_eps_max": 0.0,
    "warmup_size": 10,
}

# Merge all settings into one flat dict. A new dict is built (rather than
# aliasing dqn_params and calling .update on it) so dqn_params keeps only the
# DQN hyper-parameters. Key order and override precedence are unchanged.
runConfig = {**dqn_params, **Policy_Config, **trainer_params, **Spread_Config}

model_load_path = os.path.join(policy_path, load_policy_name)
model_save_path = os.path.join(policy_path, save_policy_name)

# Make sure the checkpoint and log directories exist.
os.makedirs(policy_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)

def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    policy_load_path = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    """Build the multi-agent policy manager for the configured environment.

    One network/optimizer/policy triple is created per distinct policy: a
    single shared one when ``Policy_Config["same_policy"]`` is set, otherwise
    one per environment agent.

    Args:
        agent_learn: Only used when ``policyModel`` is not ``"DQN"``; in the
            DQN case a new ``DQNPolicy`` replaces it.
        agent_opponent: Not used by this function.
        optim: Ignored on input; a new Adam optimizer is always created.
        policy_load_path: Checkpoint to load when
            ``Policy_Config["load_model"]`` is set. Defaults to the
            module-level ``model_load_path``.

    Returns:
        ``(policy, optim, agents)``: the ``MultiAgentPolicyManager``, the
        optimizer created last, and ``env.agents``.

    Raises:
        ValueError: If ``model`` names an unsupported network, or if
            ``policyModel`` is not ``"DQN"`` and no ``agent_learn`` was given.
    """
    env = _get_env()
    agent_observation_space = env.observation_space.shape
    action_shape = env.action_space

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Honour the explicit argument; fall back to the configured checkpoint.
    checkpoint_path = model_load_path if policy_load_path is None else policy_load_path

    n_env_agents = len(env.agents)
    # One policy per agent unless all agents share a single policy. This was
    # hard-coded to 4 before, which did not match a 3-agent environment.
    policies_number = 1 if Policy_Config["same_policy"] else n_env_agents

    agents = []
    for _ in range(policies_number):

        if model == "DNN_Spread":
            net = DNN_Spread(
                obs_shape=agent_observation_space[0],
                action_shape=5,
                device=device,
            ).to(device)
        elif model == "MPE_Task_MultiHead":
            net = MPE_Task_MultiHead(
                num_tasks=Spread_Config['N'] * 2 + 5,
                num_features_per_task=2,  # 6 + 2 + 1,
                device=device,
            ).to(device)
        else:
            raise ValueError(f"Unsupported model: {model!r}")

        optim = torch.optim.Adam(
            net.parameters(), lr=dqn_params["lr"], weight_decay=0.0, amsgrad=True
        )

        if policyModel == "DQN":
            agent_learn = DQNPolicy(
                model=net,
                optim=optim,
                action_space=action_shape,
                discount_factor=dqn_params["discount_factor"],
                estimation_step=dqn_params["estimation_step"],
                target_update_freq=dqn_params["target_update_freq"],
                reward_normalization=False,
                clip_loss_grad=False,
            )
        elif agent_learn is None:
            raise ValueError(
                f"Unsupported policyModel {policyModel!r} and no agent_learn was provided"
            )

        if Policy_Config["load_model"]:
            # map_location lets a checkpoint saved on a GPU be restored on a
            # CPU-only machine (and vice versa).
            agent_learn.load_state_dict(
                torch.load(checkpoint_path, map_location=device)
            )
            print(f'Loaded-> {checkpoint_path}')

        agents.append(agent_learn)

    if Policy_Config["same_policy"]:
        # Every agent references the same policy object.
        agents = [agents[0] for _ in range(n_env_agents)]

    policy = MultiAgentPolicyManager(policies=agents, env=env)

    return policy, optim, env.agents

def _get_env(test=False):
    """This function is needed to provide callables for DummyVectorEnv.

    Creates a ``TaskSpreadEnv`` from ``Spread_Config``, wraps it in
    ``ActionLoggerWrapper`` and then in tianshou's ``PettingZooEnv``.

    Args:
        test: Accepted for interface compatibility; not used.

    Returns:
        The wrapped ``PettingZooEnv`` instance.
    """
    env = TaskSpreadEnv.env(
        max_cycles=Spread_Config["max_cycles"],
        local_ratio=Spread_Config["local_ratio"],
        N=Spread_Config["N"],
        continuous_actions=Spread_Config["continuous_actions"],
        # Use the configured render mode. The previous hard-coded " human"
        # (leading space) bypassed Spread_Config and is not the "human" mode.
        render_mode=Spread_Config["render_mode"],
    )

    env = ActionLoggerWrapper(env)
    env = PettingZooEnv(env)

    return env

# print(json.dumps(runConfig, indent=4))


In [4]:
def _get_envT():
    """Environment factory handed to DummyVectorEnv for evaluation.

    Builds a ``TaskSpreadEnv`` from the ``Spread_Config`` settings with
    rendering switched off (``render_mode=None``) and returns it wrapped in
    ``PettingZooEnv``. No action-logging wrapper is applied here.
    """
    # Settings taken verbatim from Spread_Config; render_mode is deliberately
    # not one of them because it is forced to None below.
    forwarded_keys = ("max_cycles", "local_ratio", "N", "continuous_actions")
    env_kwargs = {key: Spread_Config[key] for key in forwarded_keys}

    raw_env = TaskSpreadEnv.env(render_mode=None, **env_kwargs)
    return PettingZooEnv(raw_env)


# Build the multi-agent policy and fetch the environment's agent names.
policy, optim, agents = _get_agents()
test_env_num = 1

# ======== Step 1: Environment setup =========
test_envs = DummyVectorEnv([_get_envT for _ in range(test_env_num)])

# seed
seed = 15
np.random.seed(seed)
torch.manual_seed(seed)
test_envs.seed(seed)

episodes = 50
render = False

# Evaluation settings for every agent's policy: epsilon 0 and eval mode.
for agent in agents:
    agent_policy = policy.policies[agent]
    agent_policy.set_eps(0.00)
    agent_policy.eval()

collector = CustomCollector(policy, test_envs, exploration_noise=False)

# A render delay is only passed when rendering is enabled.
render_delay = 1.0 if render else None
results = collector.collect(n_episode=episodes, render=render_delay)

# Per-episode totals: sum of 'rews' over axis 1
# (presumably the per-agent axis -- TODO confirm against CustomCollector).
episode_totals = np.sum(results['rews'], axis=1)

print("FinalRew: ", episode_totals)
print("Mean: " , np.mean(episode_totals))
print("Std: " , np.std(episode_totals))
print("Max: " , np.max(episode_totals))
print("Min: " , np.min(episode_totals))
# print("Mean Len: " , np.mean(results['lens']))

Loaded-> dqn_Spread\policy_MPE_Task_MultiHead_Desk_01_8feat240112-174145_29_BestRew.pth
FinalRew:  [-120.96800264  -66.89787138  -40.83468072  -16.86045109  -43.71990931
  -47.85819943 -124.58582785  -15.49845184  -28.53288109  -82.68382379
  -41.92962945  -72.48818375  -98.56185421 -105.710163    -31.59049657
  -78.30097248  -22.07308414  -72.55776464  -31.85664005  -36.98912038
  -52.14023537 -146.18971412 -103.29077554  -54.71550599  -45.33200616
 -139.00046628  -18.53452249  -36.20646953  -47.8976879  -172.46940857
  -71.72084027  -70.38217559  -72.01529576  -28.8688919   -65.74283937
  -17.35790224 -154.82250092 -118.10552011  -26.33621204  -29.62611086
  -21.3348159   -33.28609972 -112.7049803   -26.25415881 -226.93321669
  -25.0019167  -231.50854852  -98.15801548  -66.19929329  -64.47540274]
Mean:  -71.14219073953582
Std:  51.12397680585824
Max:  -15.498451844208702
Min:  -231.50854851500605
