In [1]:
import os
import datetime
from typing import Optional, Tuple
import json


os.environ["WANDB_NOTEBOOK_NAME"] = "Tianshow_Centralized_Training"

import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy, RainbowPolicy
from tianshou.trainer import OffpolicyTrainer
from torch.utils.tensorboard import SummaryWriter

from pettingzoo.sisl import pursuit_v4

from TaskAllocation.RL_Policies.MultiHead_SISL import MultiHead_SISL
from TaskAllocation.RL_Policies.DNN_SISL import DNN_SISL
from TaskAllocation.RL_Policies.CNN_SISL import CNN_SISL
from TaskAllocation.RL_Policies.CNN_ATT_SISL import CNN_ATT_SISL
from TaskAllocation.RL_Policies.SISL_Task_MultiHead import SISL_Task_MultiHead


from Mods.MemoryBuffer import StateMemoryVectorReplayBuffer
from Mods.MemoryBuffer import MemoryOffpolicyTrainer
import Mods.MemPursuitEnv as MemPursuitEnv
from Mods.OffPolicyTrainerMod import OffPolicyTrainerMod

import Mods.TaskPursuitEnv as TaskPursuitEnv

from TaskAllocation.RL_Policies.Custom_Classes import CustomNet
from TaskAllocation.RL_Policies.Custom_Classes import CustomCollector
from TaskAllocation.RL_Policies.Custom_Classes import CustomParallelToAECWrapper

# Add specific modification to tianshou
import wandb
from tianshou.utils import WandbLogger
from tianshou.utils.logger.base import LOG_DATA_TYPE

def new_write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
    """Log ``data`` straight to wandb, adding ``step`` under the key ``step_type``.

    Installed below as ``WandbLogger.write`` so metrics are sent directly via
    ``wandb.log``. A merged copy is logged, so the caller's ``data`` dict is
    not modified.

    Args:
        self: The ``WandbLogger`` instance (not used).
        step_type: Key under which the step counter is stored,
            e.g. "train/env_step".
        step: Current step counter value.
        data: Metrics to log.
    """
    wandb.log({**data, step_type: step})


WandbLogger.write = new_write

#from tianshou_DQN import train
# ---- Experiment selection ----
# Other selectable models: "SISL_Task_MultiHead", "CNN_ATT_SISL", "MultiHead_SISL", "DNN_SISL".
model = "CNN_SISL"
test_num = "_Desk_CNN_FV5"
policyModel = "DQN"

train_env_num = 10
test_env_num = 30

name = model + test_num

# ---- Logging ----
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")  # already a str
log_name = name + now
log_path = os.path.join('./', "Logs", "dqn_sisl", log_name)

# ---- Policy checkpoints ----
load_policy_name = 'policy_SISL_Task_MultiHead_Desk_NewExpCor231219-173711_44.pth'
save_policy_name = f'policy_{log_name}'
policy_path = "dqn_SISL"

Policy_Config = {
    "same_policy": True,   # a single network is shared by every pursuer
    "load_model": False,   # load model_load_path into each policy at start-up
    "freeze_CNN": False,   # freeze conv layers; optimise only policy_fn/value_fn
}

SISL_Config = {
    "max_cycles": 500,         # default: 500
    "x_size": 16,              # default: 16
    "y_size": 16,              # default: 16
    "shared_reward": False,    # default: True
    "n_evaders": 30,           # default: 30
    "n_pursuers": 8,           # default: 10
    "obs_range": 9,            # default: 7
    "n_catch": 2,              # default: 2
    "freeze_evaders": False,   # default: False
    "tag_reward": 0.01,        # default: 0.01
    "catch_reward": 5.0,       # default: 5.0
    "urgency_reward": -0.1,    # default: -0.1
    "surround": True,          # default: True
    "constraint_window": 1.0,  # default: 1.0
    ###---- Additional Config ----###
    # "att_memory" : False,
    # "max_tasks" : 10
}

max_cycles = SISL_Config["max_cycles"]
n_agents = SISL_Config["n_pursuers"]

dqn_params = {
    "discount_factor": 0.98,
    "estimation_step": 20,
    "target_update_freq": 1000,  # previously max_cycles * n_agents
    # NOTE: key spelling kept as-is because it is logged in the wandb config;
    # it is not read anywhere in this notebook's visible code.
    "optminizer": "Adam",
    "lr": 0.00016,
}

trainer_params = {
    "max_epoch": 500,
    "step_per_epoch": 20000,
    "step_per_collect": 400,
    "episode_per_test": 30,
    "batch_size": 32 * n_agents,
    # Updates run only after a collect finishes (as many as needed to meet this ratio).
    "update_per_step": 1 / 50,
    "tn_eps_max": 0.10,
    "ts_eps_max": 0.01,
    "warmup_size": 1,
}

# Flat config logged to wandb/TensorBoard. Built as a NEW dict: the previous
# `runConfig = dqn_params` aliased dqn_params, so the updates leaked every
# trainer/policy/env key into dqn_params itself. Later keys win, as before.
runConfig = {**dqn_params, **Policy_Config, **trainer_params, **SISL_Config}

model_load_path = os.path.join(policy_path, load_policy_name)
model_save_path = os.path.join(policy_path, save_policy_name)
os.makedirs(policy_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)

def _build_sisl_net(obs_space, device):
    """Instantiate the network named by the module-level ``model`` and move it to ``device``.

    Raises:
        ValueError: If ``model`` is not one of the supported network names.
    """
    if model == "MultiHead_SISL":
        net = MultiHead_SISL(obs_shape=obs_space, num_tasks=5, hidden_sizes=32, device=device)
    elif model == "SISL_Task_MultiHead":
        net = SISL_Task_MultiHead(num_tasks=20, num_features_per_task=9, device=device)
    elif model == "DNN_SISL":
        net = DNN_SISL(obs_shape=obs_space, action_shape=5, device=device)
    elif model == "CNN_SISL":
        # NOTE(review): the recorded run with obs_range=9 failed with a
        # (30x2304 and 256x64) matmul shape error; confirm CNN_SISL derives
        # its flatten size from obs_shape.
        net = CNN_SISL(obs_shape=obs_space.shape, action_shape=5, device=device)
    elif model == "CNN_ATT_SISL":
        net = CNN_ATT_SISL(obs_shape=obs_space.shape, action_shape=5, device=device)
    else:
        # Previously an unknown name left `net` unbound and surfaced as a NameError.
        raise ValueError(f"Unknown model {model!r}")
    return net.to(device)


def _build_sisl_optim(net):
    """Create the Adam optimizer for ``net``, honouring ``Policy_Config["freeze_CNN"]``."""
    if Policy_Config["freeze_CNN"]:
        net.freeze_conv_layers()  # Freeze the convolutional layers
        # Only the heads are trainable when the conv trunk is frozen.
        head_params = list(net.policy_fn.parameters()) + list(net.value_fn.parameters())
        return torch.optim.Adam(head_params, lr=dqn_params["lr"])
    return torch.optim.Adam(net.parameters(), lr=dqn_params["lr"], weight_decay=0.0, amsgrad=True)


def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    policy_load_path=None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    """Build the multi-agent policy manager for the pursuit environment.

    One network, optimizer and DQN policy is built per distinct policy.
    There is a single shared one when ``Policy_Config["same_policy"]`` is set.
    Otherwise there are up to four, and the remaining agents reuse the first.

    Args:
        agent_learn: Pre-built policy used for every slot when ``policyModel``
            is not "DQN"; ignored (overwritten) for DQN.
        agent_opponent: Unused; kept for interface compatibility.
        optim: Unused as input (a fresh optimizer is always built); kept for
            interface compatibility.
        policy_load_path: Checkpoint loaded into each policy when
            ``Policy_Config["load_model"]`` is true. Defaults to the
            module-level ``model_load_path``.

    Returns:
        ``(policy_manager, optimizer_of_last_built_policy, env_agent_ids)``.

    Raises:
        ValueError: If ``model`` is unknown, or ``policyModel`` is not "DQN"
            and no ``agent_learn`` was supplied.
    """
    env = _get_env()
    agent_observation_space = env.observation_space
    action_space = env.action_space
    device = "cuda" if torch.cuda.is_available() else "cpu"
    load_path = model_load_path if policy_load_path is None else policy_load_path

    # Capped at the agent count: with fewer than 4 agents the old hard-coded 4
    # produced more policies than agents and MultiAgentPolicyManager rejected them.
    policies_number = 1 if Policy_Config["same_policy"] else min(4, len(env.agents))

    agents = []
    for _ in range(policies_number):
        net = _build_sisl_net(agent_observation_space, device)
        optim = _build_sisl_optim(net)

        if policyModel == "DQN":
            agent_learn = DQNPolicy(
                model=net,
                optim=optim,
                action_space=action_space,
                discount_factor=dqn_params["discount_factor"],
                estimation_step=dqn_params["estimation_step"],
                target_update_freq=dqn_params["target_update_freq"],
                reward_normalization=False,
                clip_loss_grad=False,
            )
        elif agent_learn is None:
            raise ValueError(f"Unsupported policyModel {policyModel!r} and no agent_learn supplied")

        if Policy_Config["load_model"]:
            # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
            agent_learn.load_state_dict(torch.load(load_path, map_location=device))
            print(f'Loaded-> {load_path}')

        agents.append(agent_learn)

    # Pad to one entry per environment agent; extra agents reuse the first policy.
    if Policy_Config["same_policy"]:
        agents = [agents[0]] * len(env.agents)
    else:
        agents.extend([agents[0]] * (len(env.agents) - policies_number))

    policy = MultiAgentPolicyManager(policies=agents, env=env)

    return policy, optim, env.agents

def _get_env(test=False):
    """Build one pursuit_v4 environment wrapped for Tianshou.

    Used as the environment factory for ``DummyVectorEnv``. Every pursuit
    parameter is read from the module-level ``SISL_Config``. To use the task
    variant instead, swap ``pursuit_v4.env`` for ``TaskPursuitEnv.env``.

    Args:
        test: Accepted for call-site compatibility. Train and test
            environments are currently built identically; the two
            previously duplicated branches were byte-for-byte the same.

    Returns:
        A ``PettingZooEnv`` wrapping the AEC pursuit environment.
    """
    # Explicit key list: extra, non-env entries may be added to SISL_Config
    # (see the commented "att_memory"/"max_tasks"), so don't splat it wholesale.
    env_keys = (
        "max_cycles",
        "x_size",
        "y_size",
        "shared_reward",
        "n_evaders",
        "n_pursuers",
        "obs_range",
        "n_catch",
        "freeze_evaders",
        "tag_reward",
        "catch_reward",
        "urgency_reward",
        "surround",
        "constraint_window",
    )
    env_kwargs = {key: SISL_Config[key] for key in env_keys}

    env = pursuit_v4.env(**env_kwargs, render_mode=None)  # "human" to visualise

    return PettingZooEnv(env)

# print(json.dumps(runConfig, indent=4))


In [2]:
if __name__ == "__main__":

    torch.set_grad_enabled(True)

    # ======== Step 1: Environment setup =========
    train_envs = DummyVectorEnv([_get_env for _ in range(train_env_num)])
    # test=True keeps the test factory correct if _get_env ever distinguishes them.
    test_envs = DummyVectorEnv([lambda: _get_env(test=True) for _ in range(test_env_num)])

    # seed
    seed = 100
    np.random.seed(seed)
    torch.manual_seed(seed)
    train_envs.seed(seed)
    test_envs.seed(seed)

    # ======== Step 2: Agent setup =========
    policy, optim, agents = _get_agents()

    # ======== Step 3: Collector setup =========
    # Alternatives previously in this call: VectorReplayBuffer(300_000, len(train_envs))
    # and StateMemoryVectorReplayBuffer(300_000, len(train_envs), memory_size=10).
    train_collector = Collector(
        policy,
        train_envs,
        PrioritizedVectorReplayBuffer(300_000, len(train_envs), alpha=0.6, beta=0.4),
        exploration_noise=True,
    )
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    print("Buffer Warming Up ")
    # Warm-up collection is currently disabled, so the buffer starts empty:
    # for _ in range(trainer_params["warmup_size"]):
    #     train_collector.collect(n_episode=train_env_num)
    #     print(".", end="")

    # Buffer length divided by max_cycles * n_pursuers
    # (presumably ~ number of full episodes stored; confirm transition granularity).
    len_buffer = len(train_collector.buffer) / (SISL_Config["max_cycles"] * SISL_Config["n_pursuers"])
    print("\nBuffer Lenght: ", len_buffer)

    info = {"Buffer": "PriorizedReplayBuffer", " Warmup_ep": len_buffer}  # currently unused

    # ======== Logging setup (wandb + TensorBoard) =========
    logger = WandbLogger(
        train_interval=runConfig["max_cycles"] * runConfig["n_pursuers"],
        test_interval=1,
        update_interval=runConfig["max_cycles"],
        save_interval=1,
        write_flush=True,
        project="SISL_Eval01",
        name=log_name,
        entity=None,
        run_id=log_name,
        config=runConfig,
        monitor_gym=True,
    )

    writer = SummaryWriter(log_path)
    writer.add_text("args", str(runConfig))
    logger.load(writer)

    # Mutable counter incremented by reward_metric and used to tag checkpoints.
    global_step_holder = [0]

    # ======== Step 4: Callback functions setup =========
    def _set_eps(epsilon):
        """Apply ``epsilon`` to the learning policy (or to every agent's policy)."""
        if Policy_Config["same_policy"]:
            policy.policies[agents[0]].set_eps(epsilon)
        else:
            for agent in agents:
                policy.policies[agent].set_eps(epsilon)

    def save_best_fn(policy):
        """Checkpoint the policy; the trainer calls this on a new best test reward."""
        if Policy_Config["same_policy"]:
            torch.save(policy.policies[agents[0]].state_dict(), model_save_path + "_" + str(global_step_holder[0]) + "_BestRew.pth")
            print("Best Saved Rew", str(global_step_holder[0]))
        else:
            for agent in agents:
                torch.save(policy.policies[agent].state_dict(), model_save_path + "_" + str(global_step_holder[0]) + "_" + agent + ".pth")
            print("Bests Saved Rew", str(global_step_holder[0]))

    def save_test_best_fn(policy):
        """Checkpoint tagged "BestLen" (not currently passed to the trainer)."""
        if Policy_Config["same_policy"]:
            torch.save(policy.policies[agents[0]].state_dict(), model_save_path + "_" + str(global_step_holder[0]) + "_BestLen.pth")
            print("Best Saved Length", str(global_step_holder[0]))
        else:
            for agent in agents:
                torch.save(policy.policies[agent].state_dict(), model_save_path + "_" + str(global_step_holder[0]) + "_" + agent + ".pth")
            print("Best Saved Length", str(global_step_holder[0]))

    def stop_fn(mean_rewards):
        # Threshold is effectively unreachable: training runs for all max_epoch epochs.
        return mean_rewards >= 99999939.0

    def train_fn(epoch, env_step):
        """Linearly decay training epsilon from tn_eps_max to tn_eps_max/100 over max_epoch."""
        eps_max = trainer_params['tn_eps_max']
        epsilon = eps_max - (eps_max - eps_max / 100) * (epoch / trainer_params['max_epoch'])
        _set_eps(epsilon)

    def test_fn(epoch, env_step):
        """Set the test epsilon and save a step checkpoint whenever the counter is a multiple of 10."""
        _set_eps(trainer_params['ts_eps_max'])

        if global_step_holder[0] % 10 == 0:
            if Policy_Config["same_policy"]:
                torch.save(policy.policies[agents[0]].state_dict(), model_save_path + "_" + str(global_step_holder[0]) + "_Step.pth")
                print("Steps Policy Saved ", str(global_step_holder[0]))
            else:
                for agent in agents:
                    torch.save(policy.policies[agent].state_dict(), model_save_path + "_" + str(global_step_holder[0]) + "_" + agent + "Step" + str(global_step_holder[0]) + ".pth")
                print("Steps Policy Saved ", str(global_step_holder[0]))

    def reward_metric(rews):
        """Bump the checkpoint counter and return per-episode rewards summed over agents."""
        global_step_holder[0] += 1
        return np.sum(rews, axis=1)

    # ======== Step 5: Run the trainer =========
    offPolicyTrainer = OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=trainer_params['max_epoch'],
        step_per_epoch=trainer_params['step_per_epoch'],
        step_per_collect=trainer_params['step_per_collect'],
        episode_per_test=trainer_params['episode_per_test'],
        batch_size=trainer_params['batch_size'],
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        # save_test_best_fn=save_test_best_fn,
        update_per_step=trainer_params['update_per_step'],
        logger=logger,
        test_in_train=True,
        reward_metric=reward_metric,
        show_progress=True,
    )

    try:
        result = offPolicyTrainer.run()
    finally:
        # Close the TensorBoard writer even if training raises part-way through.
        writer.close()
    print(f"\n==========Result==========\n{result}")
    print("\n(the trained policy can be accessed via policy.policies[agents[0]])")



Buffer Warming Up 

Buffer Lenght:  0.0


[34m[1mwandb[0m: Currently logged in as: [33mandrekuros[0m. Use [1m`wandb login --relogin`[0m to force relogin
  from IPython.core.display import HTML, display  # type: ignore


Steps Policy Saved  0


RuntimeError: mat1 and mat2 shapes cannot be multiplied (30x2304 and 256x64)

In [None]:
# Save the shared (agents[0]) policy weights, tagged with the current step counter.
torch.save(
    policy.policies[agents[0]].state_dict(),
    f"{model_save_path}_{global_step_holder[0]}.pth",
)
print("Steps Policy Saved ", str(global_step_holder[0]))
            

In [None]:
def _get_envT():
    """Return a Tianshou-wrapped TaskPursuitEnv environment.

    Suitable as a ``DummyVectorEnv`` factory. Each pursuit parameter is taken
    from the module-level ``SISL_Config``. Rendering is off (``render_mode=None``).
    """
    pursuit_keys = (
        "max_cycles",
        "x_size",
        "y_size",
        "shared_reward",
        "n_evaders",
        "n_pursuers",
        "obs_range",
        "n_catch",
        "freeze_evaders",
        "tag_reward",
        "catch_reward",
        "urgency_reward",
        "surround",
        "constraint_window",
    )
    task_env = TaskPursuitEnv.env(
        **{key: SISL_Config[key] for key in pursuit_keys},
        render_mode=None,  # "html" / "human" to visualise
    )
    return PettingZooEnv(task_env)


# ======== Evaluation: one greedy episode on TaskPursuitEnv =========
test_env_num = 1

# Seed BEFORE building the agents: previously _get_agents() ran first, so the
# network initialisation (and thus the evaluated policy, when no checkpoint is
# loaded) was not reproducible.
seed = 100
np.random.seed(seed)
torch.manual_seed(seed)

# Weights are freshly initialised unless Policy_Config["load_model"] is true,
# in which case _get_agents() loads model_load_path.
policy, optim, agents = _get_agents()

# ======== Step 1: Environment setup =========
test_envs = DummyVectorEnv([_get_envT for _ in range(test_env_num)])
test_envs.seed(seed)

episodes = 1
render = False

for agent in agents:
    policy.policies[agent].set_eps(0.00)  # fully greedy
    # policy.policies[agent].load_state_dict(torch.load(model_load_path))
    policy.policies[agent].eval()

collector = CustomCollector(policy, test_envs, exploration_noise=False)

results = collector.collect(n_episode=episodes, render=0.02 if render else None)

print("FinalRew: ", np.sum(results['rews'], axis=1))
print("Finished: ", results['lens'], " Steps")