In [1]:
import os
import datetime
from typing import Optional, Tuple
import json


# Name this notebook run for Weights & Biases. It is set here, before `wandb`
# is imported further down, presumably so the name is visible when wandb
# initialises — confirm against wandb's environment-variable docs.
os.environ["WANDB_NOTEBOOK_NAME"] = "Tianshow_Centralized_Training"

import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy, RainbowPolicy
from tianshou.trainer import OffpolicyTrainer
from torch.utils.tensorboard import SummaryWriter

from pettingzoo.sisl import pursuit_v4

from TaskAllocation.RL_Policies.MultiHead_SISL import MultiHead_SISL
from TaskAllocation.RL_Policies.DNN_SISL import DNN_SISL
from TaskAllocation.RL_Policies.CNN_SISL import CNN_SISL
from TaskAllocation.RL_Policies.CNN_ATT_SISL import CNN_ATT_SISL
from TaskAllocation.RL_Policies.SISL_Task_MultiHead import SISL_Task_MultiHead


from Mods.MemoryBuffer import StateMemoryVectorReplayBuffer
from Mods.MemoryBuffer import MemoryOffpolicyTrainer
import Mods.MemPursuitEnv as MemPursuitEnv
from Mods.OffPolicyTrainerMod import OffPolicyTrainerMod

import Mods.TaskPursuitEnv as TaskPursuitEnv

from TaskAllocation.RL_Policies.Custom_Classes import CustomNet
from TaskAllocation.RL_Policies.Custom_Classes import CustomCollector
from TaskAllocation.RL_Policies.Custom_Classes import CustomParallelToAECWrapper

# Add specific modification to tianshou
import wandb
from tianshou.utils import WandbLogger
from tianshou.utils.logger.base import LOG_DATA_TYPE

def new_write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
    """Send ``data`` to Weights & Biases with the step counter attached.

    Drop-in replacement for ``WandbLogger.write``: the step value is stored
    under the ``step_type`` key and the whole mapping is passed to
    ``wandb.log``. Note that ``data`` is modified in place.
    """
    data.update({step_type: step})
    wandb.log(data)


# Monkey-patch tianshou's W&B logger so every write goes through new_write.
WandbLogger.write = new_write

#from tianshou_DQN import train

# ---- Experiment selection ----
# Alternatives: "CNN_ATT_SISL", "MultiHead_SISL", "DNN_SISL", "CNN_SISL".
model = "SISL_Task_MultiHead"
test_num = "_Desk_Task02"
policyModel = "DQN"

train_env_num = 10
test_env_num = 30

name = model + test_num

# ---- Logging ----
# The timestamp suffix gives every run its own log directory.
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = name + now
log_path = os.path.join('./', "Logs", "dqn_sisl", log_name)

# ---- Policy checkpoints ----
load_policy_name = 'policy_SISL_Task_MultiHead_Desk_NewExpCor231219-173711_44.pth'
save_policy_name = f'policy_{log_name}'
policy_path = "dqn_SISL"

Policy_Config = {
    "same_policy": True,   # one policy object is shared by every agent
    "load_model": False,   # load model_load_path into each freshly built policy
    "freeze_CNN": False,   # freeze conv layers; optimise only policy/value heads
}

SISL_Config = {
    "max_cycles": 500,         # default: 500
    "x_size": 16,              # default: 16
    "y_size": 16,              # default: 16
    "shared_reward": False,    # default: True
    "n_evaders": 30,           # default: 30
    "n_pursuers": 8,           # default: 8
    "obs_range": 15,           # default: 7
    "n_catch": 2,              # default: 2
    "freeze_evaders": False,   # default: False
    "tag_reward": 0.01,        # default: 0.01
    "catch_reward": 5.0,       # default: 5.0
    "urgency_reward": -0.1,    # default: -0.1
    "surround": True,          # default: True
    "constraint_window": 1.0,  # default: 1.0
    ###---- Additional Config ----###
    # "att_memory" : False,
    # "max_tasks" : 10
}

max_cycles = SISL_Config["max_cycles"]
n_agents = SISL_Config["n_pursuers"]

dqn_params = {
    "discount_factor": 0.98,
    "estimation_step": 20,
    "target_update_freq": 1000,  # previously max_cycles * n_agents
    # NOTE(review): key spelling kept as-is; renaming it would change the
    # logged run configuration.
    "optminizer": "Adam",
    "lr": 0.00016,
}

trainer_params = {
    "max_epoch": 500,
    "step_per_epoch": 20000,   # previously 5 * (150 * n_agents)
    "step_per_collect": 400,   # previously * (10 * n_agents)
    "episode_per_test": 30,
    "batch_size": 32 * n_agents,
    # Only applied after a collect finishes (runs as many updates as needed).
    "update_per_step": 1 / 50,
    "tn_eps_max": 0.10,
    "ts_eps_max": 0.01,
    "warmup_size": 1,
}

# Merge into a NEW dict. The previous `runConfig = dqn_params` aliased the
# dict, so every .update() also mutated dqn_params. Later dicts still win on
# key collisions, matching the old update order.
runConfig = {**dqn_params, **Policy_Config, **trainer_params, **SISL_Config}

model_load_path = os.path.join(policy_path, load_policy_name)
model_save_path = os.path.join(policy_path, save_policy_name)
os.makedirs(policy_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)

def _build_net(device, observation_space):
    """Instantiate the network named by the module-level ``model`` setting.

    Args:
        device: ``"cuda"`` or ``"cpu"``; passed to the constructor and ``.to()``.
        observation_space: the env's observation space. Some architectures
            take the space itself, the CNN variants take its ``.shape``.

    Raises:
        ValueError: if ``model`` is not one of the supported architectures.
    """
    if model == "MultiHead_SISL":
        net = MultiHead_SISL(
            obs_shape=observation_space,
            num_tasks=5,
            hidden_sizes=32,
            device=device,
        )
    elif model == "SISL_Task_MultiHead":
        net = SISL_Task_MultiHead(
            num_tasks=20,
            num_features_per_task=9,
            device=device,
        )
    elif model == "DNN_SISL":
        net = DNN_SISL(obs_shape=observation_space, action_shape=5, device=device)
    elif model == "CNN_SISL":
        net = CNN_SISL(obs_shape=observation_space.shape, action_shape=5, device=device)
    elif model == "CNN_ATT_SISL":
        net = CNN_ATT_SISL(obs_shape=observation_space.shape, action_shape=5, device=device)
    else:
        # Previously this fell through and failed later with a NameError on `net`.
        raise ValueError(f"Unknown model {model!r}")
    return net.to(device)


def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    policy_load_path=None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    """Build the multi-agent policy manager for the Pursuit environment.

    Args:
        agent_learn: fallback policy, used only when ``policyModel`` is not
            ``"DQN"``. When ``policyModel == "DQN"`` a new DQNPolicy is built.
        agent_opponent: unused; kept for interface compatibility.
        optim: unused as input; a fresh optimizer is created per network.
        policy_load_path: checkpoint to load when ``Policy_Config["load_model"]``
            is set. Defaults to the module-level ``model_load_path``.

    Returns:
        ``(policy, optim, agent_names)``. ``policy`` is a
        MultiAgentPolicyManager, ``optim`` is the optimizer of the last network
        built, and ``agent_names`` is ``env.agents``.

    Raises:
        ValueError: unknown ``model``, or a non-DQN ``policyModel`` with no
            ``agent_learn`` supplied.
    """
    env = _get_env()
    agent_observation_space = env.observation_space
    action_shape = env.action_space
    device = "cuda" if torch.cuda.is_available() else "cpu"
    load_path = policy_load_path if policy_load_path is not None else model_load_path

    # Shared mode builds one network; otherwise 4 distinct ones (the rest of the
    # agents reuse the first, see below).
    policies_number = 1 if Policy_Config["same_policy"] else 4  # len(env.agents)

    agents = []
    for _ in range(policies_number):
        net = _build_net(device, agent_observation_space)

        if Policy_Config["freeze_CNN"]:
            net.freeze_conv_layers()  # Freeze the convolutional layers
            # Only the heads are trainable, so only they go to the optimizer.
            optim = torch.optim.Adam(
                list(net.policy_fn.parameters()) + list(net.value_fn.parameters()),
                lr=dqn_params["lr"],
            )
        else:
            optim = torch.optim.Adam(
                net.parameters(), lr=dqn_params["lr"], weight_decay=0.0, amsgrad=True
            )

        if policyModel == "DQN":
            agent_learn = DQNPolicy(
                model=net,
                optim=optim,
                action_space=action_shape,
                discount_factor=dqn_params["discount_factor"],
                estimation_step=dqn_params["estimation_step"],
                target_update_freq=dqn_params["target_update_freq"],
                reward_normalization=False,
                clip_loss_grad=False,
            )
        elif agent_learn is None:
            # Previously a None policy was silently handed to the manager.
            raise ValueError(
                f"policyModel {policyModel!r} is not supported and no agent_learn was given"
            )

        if Policy_Config["load_model"]:
            # map_location lets a checkpoint saved on GPU load on CPU and vice versa.
            agent_learn.load_state_dict(torch.load(load_path, map_location=device))
            print(f'Loaded-> {load_path}')

        agents.append(agent_learn)

    if Policy_Config["same_policy"]:
        # Every agent references the same policy object.
        agents = [agents[0]] * len(env.agents)
    else:
        # Pad the remaining agents with the first policy.
        agents += [agents[0]] * (len(env.agents) - policies_number)

    policy = MultiAgentPolicyManager(policies=agents, env=env)
    return policy, optim, env.agents

def _get_env(test=False):
    """Build one wrapped Pursuit environment; used as a DummyVectorEnv factory.

    Args:
        test: accepted for call-site compatibility. Train and test environments
            are currently built identically from ``SISL_Config`` (the two former
            branches were exact duplicates).

    Returns:
        A ``PettingZooEnv`` wrapping a ``pursuit_v4`` AEC environment.
    """
    # Only these SISL_Config keys are forwarded. Extra "Additional Config"
    # entries (e.g. att_memory) are deliberately not passed to pursuit_v4.
    pursuit_keys = (
        "max_cycles", "x_size", "y_size", "shared_reward", "n_evaders",
        "n_pursuers", "obs_range", "n_catch", "freeze_evaders", "tag_reward",
        "catch_reward", "urgency_reward", "surround", "constraint_window",
    )
    env_kwargs = {key: SISL_Config[key] for key in pursuit_keys}

    # Swap in TaskPursuitEnv.env(...) here to use the task-based variant.
    env = pursuit_v4.env(**env_kwargs, render_mode=None)  # "human" to watch

    return PettingZooEnv(env)

# print(json.dumps(runConfig, indent=4))


In [2]:


def _get_envT(obs_range=12):
    """Build one wrapped TaskPursuitEnv environment; used as a DummyVectorEnv factory.

    Every setting comes from ``SISL_Config`` except the observation range,
    which is an explicit override (default 12, not ``SISL_Config["obs_range"]``).

    Args:
        obs_range: observation window passed to ``TaskPursuitEnv.env``. The
            default keeps the previously hard-coded value, so zero-argument
            calls (as made by DummyVectorEnv) behave as before.

    Returns:
        A ``PettingZooEnv`` wrapping the TaskPursuitEnv AEC environment.
    """
    env = TaskPursuitEnv.env(
        max_cycles=SISL_Config["max_cycles"],
        x_size=SISL_Config["x_size"],
        y_size=SISL_Config["y_size"],
        shared_reward=SISL_Config["shared_reward"],
        n_evaders=SISL_Config["n_evaders"],
        n_pursuers=SISL_Config["n_pursuers"],
        obs_range=obs_range,
        n_catch=SISL_Config["n_catch"],
        freeze_evaders=SISL_Config["freeze_evaders"],
        tag_reward=SISL_Config["tag_reward"],
        catch_reward=SISL_Config["catch_reward"],
        urgency_reward=SISL_Config["urgency_reward"],
        surround=SISL_Config["surround"],
        constraint_window=SISL_Config["constraint_window"],
        # att_memory = SISL_Config["att_memory"],
        render_mode=None,  # "human" to watch
    )
    return PettingZooEnv(env)


policy, optim, agents = _get_agents()

# ======== Step 1: Environment setup =========
# NOTE: overrides the test_env_num set in the configuration cell.
test_env_num = 10
test_envs = DummyVectorEnv([_get_envT for _ in range(test_env_num)])

# seed
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)
test_envs.seed(seed)

episodes = 50
render = False

# Alternative checkpoints:
# policy_name = "policy_CNN_SISL_Desk_CNN02240128-083000_3571_BestRew.pth"
# policy_name = "policy_CNN_SISL_Desk_CNN02240128-083000_2618_BestRew.pth"
policy_name = "policy_SISL_Task_MultiHead_Desk_NewExpFix_noActHist240109-173308_410_BestRew.pth"

# The checkpoint directory depends only on the sharing mode, not on the agent.
checkpoint_dir = "dqn_SISL" if Policy_Config["same_policy"] else "vdn_SISL"
model_path = os.path.join(checkpoint_dir, policy_name)
# map_location lets a checkpoint saved on one device load onto the other.
load_device = "cuda" if torch.cuda.is_available() else "cpu"

# With same_policy every agent maps to the same policy object. Configure and
# load each distinct policy once instead of once per agent.
unique_policies = {id(p): p for p in (policy.policies[agent] for agent in agents)}
for agent_policy in unique_policies.values():
    agent_policy.set_eps(0.01)
    agent_policy.load_state_dict(torch.load(model_path, map_location=load_device))
    agent_policy.eval()

test_collector = Collector(policy, test_envs, exploration_noise=True)

results = test_collector.collect(n_episode=episodes, render=0.1 if render else None)

# Total reward per episode. Presumably 'rews' is (episodes, agents), so axis=1
# sums across agents — confirm against the tianshou version in use.
episode_returns = np.sum(results['rews'], axis=1)
print("Mean: ", np.mean(episode_returns))
print("Std:  ", np.std(episode_returns))
print("Max:  ", np.max(episode_returns))
print("Min:  ", np.min(episode_returns))
# print("Mean Len: " , np.mean(results['lens']))
# print("Mean Len: " , np.mean(results['lens']))

RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`