In [None]:
from logging import config
import os
import datetime
from typing import Optional, Tuple
import json

# Name reported to wandb for this notebook; set before `import wandb` further down.
# Raw string: the value contains a literal backslash ("\T" is not a valid escape).
os.environ["WANDB_NOTEBOOK_NAME"] = r".\Tianshow_Centralized_Training"

import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv

from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy, RainbowPolicy
from tianshou.trainer import OffpolicyTrainer
from torch.utils.tensorboard import SummaryWriter

from LOTZ.LOTZ_env import LeadingOnesTrailingZerosEnv

# from TaskAllocation.RL_Policies.MultiHead_SISL import MultiHead_SISL
from TaskAllocation.RL_Policies.DNN_LOTZ import DNN_LOTZ
from TaskAllocation.RL_Policies.MultiHead_LOTZ import MultiHead_LOTZ
from TaskAllocation.RL_Policies.ATT_LOTZ import ATT_LOTZ
# from TaskAllocation.RL_Policies.CNN_SISL import CNN_SISL

from TaskAllocation.RL_Policies.Custom_Classes import CustomNet
from TaskAllocation.RL_Policies.Custom_Classes import CustomCollector
from TaskAllocation.RL_Policies.Custom_Classes import CustomParallelToAECWrapper


#---- Tianshou patch: override PettingZooEnv.reset ----#
from tianshou.env.pettingzoo_env import PettingZooEnv
from typing import Any
from gymnasium import spaces

def _reset(self, *args: Any, **kwargs: Any) -> tuple[dict, dict]:
    """Reset the wrapped PettingZoo AEC env and build Tianshou's observation dict.

    Meant to be bound as ``PettingZooEnv.reset``; ``self`` is the wrapper and
    ``self.env`` the underlying AEC environment.
    NOTE(review): presumably replaces the stock Tianshou implementation to
    change how ``last()`` is called — confirm against the installed version.

    Returns:
        ``(observation_dict, info)``. ``observation_dict`` always holds
        ``"agent_id"`` (the currently selected agent) and ``"obs"``. It also
        holds a boolean ``"mask"`` list when the observation carries an
        ``"action_mask"`` (mask entries equal to 1 become True), or when
        ``self.action_space`` is ``Discrete`` (every action marked legal).
    """
    self.env.reset(*args, **kwargs)
    raw_obs, _reward, _terminated, _truncated, info = self.env.last()
    current_agent = self.env.agent_selection

    has_mask = isinstance(raw_obs, dict) and "action_mask" in raw_obs
    result = {
        "agent_id": current_agent,
        "obs": raw_obs["observation"] if has_mask else raw_obs,
    }
    if has_mask:
        result["mask"] = [flag == 1 for flag in raw_obs["action_mask"]]
    elif isinstance(self.action_space, spaces.Discrete):
        n_actions = self.env.action_space(current_agent).n
        result["mask"] = [True] * n_actions
    return result, info
    
# Monkeypatch: every PettingZooEnv instance (including those built by `_get_env`)
# now uses the `_reset` function defined above as its reset method.
PettingZooEnv.reset = _reset
    
# --- Tianshou patch: make WandbLogger.write log directly via wandb.log -----#
import wandb
from tianshou.utils import WandbLogger
from tianshou.utils.logger.base import LOG_DATA_TYPE
def new_write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
    """Log ``data`` directly to W&B, tagged with the step counter.

    Installed below as ``WandbLogger.write``. ``self`` (the logger) is unused.
    Only ``wandb.log`` is called, so nothing is forwarded to a TensorBoard
    writer by this method.

    Args:
        step_type: key under which ``step`` is recorded in the logged dict.
        step: current step value.
        data: metrics to log; left unmodified.
    """
    # Merge into a new dict rather than writing into the caller's `data`, so the
    # patch has no side effect on the object Tianshou passes in.
    wandb.log({**data, step_type: step})


WandbLogger.write = new_write

#from tianshou_DQN import train
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
project = "LOTZ_Eval"            # W&B project name
model = "MultiHead_LOTZ"         # one of "MultiHead_LOTZ", "DNN_LOTZ", "ATT_LOTZ" (see _get_agents)
test_num = "_NOV01"              # run tag appended to the model name
policyModel = "DQN"              # one of "DQN", "Rainbow" (see _get_agents)

train_env_num = 20
test_env_num = 20

name = model + test_num

load_policy_name = 'policy_MultiHead_LOTZ_NOV01.pth'
save_policy_name = f'policy_{name}'   # ".pth" is appended by save_best_fn
policy_path = "policy_LOTZ"

same_policy = True   # all agents share one policy object
load_model = False   # if True, _get_agents loads the checkpoint at model_load_path

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = name + str(now)

log_path = os.path.join('./', "Logs", "dqn_sisl", log_name)

LOTZ_Config = {
    "string_length": 128,
    "n_agents": 2,
    "seed": 0,
    "m_steps": 256,
    "sp": 0
}

max_cycles = LOTZ_Config["m_steps"]
# Derived from the env config so the two cannot drift apart.
n_agents = LOTZ_Config["n_agents"]

dqn_params = {"discount_factor": 0.95,
              "estimation_step": 3,
              "target_update_freq": 100 * max_cycles,
              # NOTE(review): informational only (and misspelled); _get_agents
              # always builds torch.optim.Adam. Key kept as-is for run comparability.
              "optminizer": "Adam",
              "lr": 1e-3}

trainer_params = {"max_epoch": 500,
                  "step_per_epoch": 250 * max_cycles,
                  "step_per_collect": max_cycles * 50,
                  "episode_per_test": 20,
                  "batch_size": max_cycles * 10,
                  # Updates run only after a collect finishes (as many as needed to meet this ratio)
                  "update_per_step": 1 / (max_cycles * 5),
                  "tn_eps_max": 0.15,   # training epsilon at epoch 0
                  "ts_eps_max": 0.0,    # epsilon used during testing
                  }

# Merged run config (logged to W&B / TensorBoard). Built as a new dict so that
# dqn_params, trainer_params and LOTZ_Config are not mutated.
runConfig = {**dqn_params, **trainer_params, **LOTZ_Config}

model_load_path = os.path.join(policy_path, load_policy_name)
model_save_path = os.path.join(policy_path, save_policy_name)
os.makedirs(policy_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)

def _build_net(observation_space, action_space, device):
    """Instantiate the Q-network selected by the module-level ``model`` setting."""
    if model == "MultiHead_LOTZ":
        net = MultiHead_LOTZ(
            obs_shape=observation_space,
            action_shape=action_space,
            max_len=LOTZ_Config["string_length"],
            device=device,
        )
    elif model == "DNN_LOTZ":
        net = DNN_LOTZ(
            obs_shape=observation_space,
            action_shape=action_space,
            device=device,
        )
    elif model == "ATT_LOTZ":
        net = ATT_LOTZ(
            obs_shape=observation_space,
            action_shape=action_space,
            device=device,
        )
    else:
        raise ValueError(
            f"Unknown model {model!r}; expected 'MultiHead_LOTZ', 'DNN_LOTZ' or 'ATT_LOTZ'"
        )
    return net.to(device)


def _build_policy(net, optim, action_space):
    """Wrap ``net`` in the policy class selected by the module-level ``policyModel``."""
    if policyModel == "DQN":
        return DQNPolicy(
            model=net,
            optim=optim,
            action_space=action_space,
            discount_factor=dqn_params["discount_factor"],
            estimation_step=dqn_params["estimation_step"],
            target_update_freq=dqn_params["target_update_freq"],
            reward_normalization=False,
            clip_loss_grad=False,
        )
    if policyModel == "Rainbow":
        return RainbowPolicy(
            model=net,
            optim=optim,
            action_space=action_space,
            num_atoms=5,
            discount_factor=dqn_params["discount_factor"],
            estimation_step=dqn_params["estimation_step"],
            target_update_freq=dqn_params["target_update_freq"],
        )
    raise ValueError(f"Unknown policyModel {policyModel!r}; expected 'DQN' or 'Rainbow'")


def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    policy_load_path=None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    """Build the multi-agent policy manager for the LOTZ environment.

    Args:
        agent_learn: existing policy to reuse. If None, a network (per ``model``)
            and a policy (per ``policyModel``) are built from scratch.
        agent_opponent: currently unused; kept for signature compatibility.
        optim: optimizer for a freshly built network. If None, Adam with
            ``dqn_params["lr"]`` is created. Ignored when ``agent_learn`` is given.
        policy_load_path: checkpoint to load when ``load_model`` is True. Falls
            back to the module-level ``model_load_path`` when None.

    Returns:
        ``(policy_manager, optim, agent_names)``. Every agent maps to the same
        policy object (parameter sharing).

    Raises:
        ValueError: if ``model`` or ``policyModel`` names an unsupported option.
    """
    env = _get_env()
    agent_observation_space = env.observation_space
    action_shape = env.action_space

    print("Action_Shape: ", action_shape)
    print("agent_observation_space: ", agent_observation_space)

    device = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"

    if agent_learn is None:
        net = _build_net(agent_observation_space, action_shape, device)

        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=dqn_params["lr"], weight_decay=0.0, amsgrad=False)

        agent_learn = _build_policy(net, optim, action_shape)

        if load_model:
            load_path = policy_load_path if policy_load_path is not None else model_load_path
            # map_location lets a checkpoint saved on another device load here.
            agent_learn.load_state_dict(torch.load(load_path, map_location=device))
            print(f'Loaded-> {load_path}')

    # Built outside the branch so a caller-supplied agent_learn also works.
    agents = [agent_learn for _ in range(len(env.agents))]

    policy = MultiAgentPolicyManager(policies=agents, env=env)

    return policy, optim, env.agents

def _get_env():
    """Create one LOTZ environment wrapped in Tianshou's ``PettingZooEnv``.

    DummyVectorEnv needs zero-argument factories, so every setting is read from
    the module-level ``LOTZ_Config`` dict rather than passed in.
    """
    env_kwargs = dict(
        string_length=LOTZ_Config["string_length"],
        n_agents=LOTZ_Config["n_agents"],
        seed=LOTZ_Config["seed"],
        m_steps=LOTZ_Config["m_steps"],
        sp=LOTZ_Config["sp"],
    )
    lotz_env = LeadingOnesTrailingZerosEnv(**env_kwargs)
    return PettingZooEnv(lotz_env)

# Echo the merged run configuration (the same dict is passed to WandbLogger as `config`).
config_json = json.dumps(runConfig, indent=4)
print(config_json)

In [None]:
if __name__ == "__main__":
    # Training cell: build envs and policy, warm up the replay buffer, then run
    # Tianshou's off-policy trainer with W&B + TensorBoard logging.

    torch.set_grad_enabled(True)

    # ======== Step 1: Environment setup =========
    train_envs = DummyVectorEnv([_get_env for _ in range(train_env_num)])
    test_envs = DummyVectorEnv([_get_env for _ in range(test_env_num)])

    # seed
    seed = 0
    np.random.seed(seed)
    torch.manual_seed(seed)
    train_envs.seed(seed)
    test_envs.seed(seed)

    # ======== Step 2: Agent setup =========
    policy, optim, agents = _get_agents()

    # Prioritized replay with one sub-buffer per training env.
    agentsBuffer = PrioritizedVectorReplayBuffer(300_000, len(train_envs), alpha=0.6, beta=0.4)

    # ======== Step 3: Collector setup =========
    train_collector = Collector(
        policy,
        train_envs,
        agentsBuffer,
        exploration_noise=True,
    )
    test_collector = Collector(policy, test_envs, exploration_noise=False)

    # ======== Step 4: Callback functions setup =========
    def save_best_fn(policy):
        """Save the shared agent policy's weights when the trainer reports a new best."""
        torch.save(policy.policies[agents[0]].state_dict(), model_save_path + ".pth")
        print("Best Saved")

    def stop_fn(mean_rewards):
        # NOTE(review): threshold presumably chosen so training never stops early
        # and always runs for max_epoch — confirm.
        return mean_rewards >= 99999939.0

    def train_fn(epoch, env_step):
        """Set the training epsilon.

        Decays linearly from tn_eps_max at epoch 0 to tn_eps_max / 100 at max_epoch.
        """
        eps_max = trainer_params['tn_eps_max']
        epsilon = eps_max - (eps_max - eps_max / 100) * (epoch / trainer_params['max_epoch'])
        if same_policy:
            policy.policies[agents[0]].set_eps(epsilon)
        else:
            # NOTE(review): these keys look left over from another environment
            # (evaluation uses 'agent0'); confirm before setting same_policy = False.
            policy.policies['R_agent0'].set_eps(epsilon)
            policy.policies['F_agent0'].set_eps(epsilon)

    def test_fn(epoch, env_step):
        """Set the (constant) testing epsilon ts_eps_max."""
        epsilon = trainer_params['ts_eps_max']
        if same_policy:
            policy.policies[agents[0]].set_eps(epsilon)
        else:
            policy.policies['R_agent0'].set_eps(epsilon)
            policy.policies['F_agent0'].set_eps(epsilon)

    def reward_metric(rews):
        # Passes the collector's reward array through unchanged.
        return rews

    # ======== Buffer warm-up =========
    # Apply the epoch-0 training epsilon first; otherwise the warm-up episodes use
    # whatever epsilon the policy was constructed with.
    # NOTE(review): that is 0.0 (fully greedy) for Tianshou's DQNPolicy as far as
    # I know — confirm for the installed version.
    train_fn(0, 0)
    print("Buffer Warming Up ")
    train_collector.collect(n_episode=train_env_num)
    print(".", end="")

    len_buffer = len(train_collector.buffer)
    print("\nBuffer Lenght: ", len_buffer)

    # ======== W&B + TensorBoard logging setup =========
    logger = WandbLogger(
        train_interval=runConfig["m_steps"] * runConfig["n_agents"],
        test_interval=1,
        update_interval=runConfig["m_steps"],
        save_interval=1,
        write_flush=True,
        project=project,
        name=log_name,
        entity=None,
        run_id=log_name,
        config=runConfig,
        monitor_gym=True,
    )

    writer = SummaryWriter(log_path)
    writer.add_text("args", str(runConfig))
    logger.load(writer)

    # ======== Step 5: Run the trainer =========
    # try/finally so the TensorBoard writer is closed even if training fails or is interrupted.
    try:
        offPolicyTrainer = OffpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            buffer=agentsBuffer,
            max_epoch=trainer_params['max_epoch'],
            step_per_epoch=trainer_params['step_per_epoch'],
            step_per_collect=trainer_params['step_per_collect'],
            episode_per_test=trainer_params['episode_per_test'],
            batch_size=trainer_params['batch_size'],
            train_fn=train_fn,
            test_fn=test_fn,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            update_per_step=trainer_params['update_per_step'],
            logger=logger,
            test_in_train=False,
            reward_metric=reward_metric,
            show_progress=True,
        )
        result = offPolicyTrainer.run()
    finally:
        writer.close()

    print(f"\n==========Result==========\n{result}")
    print("\n(the trained policy can be accessed via policy.policies[agents[0]])")

In [None]:
# ======== Evaluate the trained policy =========
# Reuses `policy` and `agents` from the training cell above.
# `model` / `policyModel` are re-set because _get_agents reads them; to evaluate
# a saved checkpoint in a fresh kernel instead, rebuild and load, e.g.:
#     policy, optim, agents = _get_agents()
#     policy.policies[agents[0]].load_state_dict(torch.load(os.path.join("policy_LOTZ", name)))
model = "MultiHead_LOTZ"
policyModel = "DQN"
name = 'policy_MultiHead_LOTZ_NOV01.pth'  # checkpoint file name for the reload recipe above

# Same key the training callbacks use; all agents share this policy object.
policy_test = policy.policies[agents[0]]

# Evaluation: switch modules to eval mode (only matters if the network has
# train/eval-dependent layers) and act greedily.
policy.eval()
policy_test.set_eps(0.00)

envs = DummyVectorEnv([_get_env for _ in range(1)])
collector = CustomCollector(policy, envs, exploration_noise=False)

results = collector.collect(n_episode=10)
results

In [None]:
import matplotlib.pyplot as plt

# Only rewards above this floor are summarised.
# NOTE(review): -10 presumably separates failed episodes — confirm the cut-off.
REWARD_FLOOR = -10
kept_rews = results['rews'][results['rews'] > REWARD_FLOOR]

# Mean of the kept rewards (guarded: np.mean of an empty array is nan + a warning).
if kept_rews.size:
    print(np.mean(kept_rews))
else:
    print(f"No rewards above {REWARD_FLOOR}")

# Histogram of the kept rewards (boolean indexing flattens the array).
fig, ax = plt.subplots()
ax.hist(kept_rews, bins=100)
ax.set(title=f"Evaluation rewards (> {REWARD_FLOOR})", xlabel="Reward", ylabel="Count")
plt.show()
