In [1]:
import os
import datetime
from typing import Optional, Tuple
import json
import numpy as np
import torch

from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env_parallel import PettingZooParallelEnv

#from PettingZooParallelEnv import PettingZooParallelEnv

from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy, RainbowPolicy
from tianshou.trainer import OffpolicyTrainer
from torch.utils.tensorboard import SummaryWriter
from DNN_B_ACE import DNN_B_ACE
from GodotRLPettingZooWrapper import GodotRLPettingZooWrapper

from CollectorMA import CollectorMA
from MAParalellPolicy import MAParalellPolicy


####---------------------------#######
# Tianshou adjustment: route WandbLogger.write straight to wandb.log.
import wandb
# wandb reads this to locate the notebook file for code saving, so it must
# carry the real notebook extension (.ipynb).
os.environ["WANDB_NOTEBOOK_NAME"] = "Tianshow_Training_GoDot.ipynb"
from tianshou.utils import WandbLogger
from tianshou.utils.logger.base import LOG_DATA_TYPE


def new_write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
    """Log ``data`` to wandb, adding ``step`` under the key ``step_type``.

    Replacement for ``WandbLogger.write``. The caller's mapping is copied
    rather than mutated, so the dict handed in is left unchanged.

    Args:
        self: The ``WandbLogger`` instance (unused).
        step_type: Key under which ``step`` is recorded (e.g. "train/env_step").
        step: Step counter value to record alongside the data.
        data: Metric name -> value mapping to log.
    """
    wandb.log({**data, step_type: step})


WandbLogger.write = new_write
####---------------------------#######


# ======== Experiment identity and paths =========
# Previously used alternatives: "SISL_Task_MultiHead", "CNN_ATT_SISL", "MultiHead_SISL".
model = "DNN_B_ACE"
test_num = "_B_ACE01"
policyModel = "DQN"
name = model + test_num

train_env_num = 1
test_env_num = 1

now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = name + str(now)
log_path = os.path.join('./', "Logs", "dqn_sisl", log_name)

load_policy_name = 'policy_SISL_Task_MultiHead_Desk_NewExpCor231219-173711_44.pth'
save_policy_name = f'policy_{log_name}'
policy_path = "dqn_B_ACE"

model_load_path = os.path.join(policy_path, load_policy_name)
model_save_path = os.path.join(policy_path, save_policy_name)
os.makedirs(policy_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)

Policy_Config = {
    "same_policy": True,
    "load_model": False,
    "freeze_CNN": False,
}

# Passed verbatim as **kwargs to GodotRLPettingZooWrapper, so only keys the
# wrapper accepts belong here.
B_ACE_Config = {
    'task': 'b_ace_v2',
    'env_path': 'BVR_AirCombat/bin/B_ACE_v2.exe',
    'show_window': True,
    'seed': 10,
    'port': 12500,
    'action_repeat': 20,
    'speedup': 2000,
    'num_allies': 1,
    'num_enemies': 1,
    'action_type': 'Low_Level_Discrete',
    'enemies_baseline': 'duck',
    'full_observation': 0,
    'actions_2d': 0,
}

# B_ACE_Config has no "max_cycles" entry, so indexing it raised KeyError.
# The value is only used for logging intervals downstream.
# NOTE(review): 1000 is a placeholder default; confirm the episode length.
DEFAULT_MAX_CYCLES = 1000
max_cycles = B_ACE_Config.get("max_cycles", DEFAULT_MAX_CYCLES)
# NOTE(review): fixed at 4 (sizes batch_size); the env config above has 1 ally.
n_agents = 4

dqn_params = {
    "discount_factor": 0.98,
    "estimation_step": 20,
    "target_update_freq": 1000,  # was max_cycles * n_agents
    "optminizer": "Adam",  # key spelling kept: it is logged to wandb as-is
    "lr": 0.00016,
}

trainer_params = {
    "max_epoch": 500,
    "step_per_epoch": 20000,
    "step_per_collect": 400,
    "episode_per_test": 30,
    "batch_size": 32 * n_agents,
    # Updates run only after a collect finishes (as many as needed to meet the ratio).
    "update_per_step": 1 / 50,
    "tn_eps_max": 0.10,
    "ts_eps_max": 0.01,
    "warmup_size": 1,
}

# Build a fresh dict instead of aliasing dqn_params (the old `runConfig =
# dqn_params` + update() silently copied every other config into dqn_params).
# Later dicts win on key clashes, matching the original update() order.
runConfig = {**dqn_params, **Policy_Config, **trainer_params, **B_ACE_Config}
runConfig.setdefault("max_cycles", max_cycles)


def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    policy_load_path: Optional[str] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    """Build the multi-agent policy used for training and evaluation.

    An environment is created via ``_get_env`` to read the observation and
    action spaces and the agent ids. For each distinct policy a ``DNN_B_ACE``
    network is wrapped in a ``DQNPolicy``. When ``Policy_Config["same_policy"]``
    is true, one policy object is shared by every agent.

    Args:
        agent_learn: Kept for interface compatibility; overwritten by the
            freshly built policy.
        agent_opponent: Unused; kept for interface compatibility.
        optim: Optimizer handed to every ``DQNPolicy``. When ``None`` (default)
            a new optimizer is created for each network from
            ``dqn_params["optminizer"]`` and ``dqn_params["lr"]``.
        policy_load_path: Checkpoint loaded when ``Policy_Config["load_model"]``
            is true. Defaults to the module-level ``model_load_path``.

    Returns:
        ``(policy, optim, agent_ids)``: the ``MAParalellPolicy`` wrapper, the
        optimizer of the last policy built, and ``env.agents``.

    Raises:
        ValueError: If ``model`` or ``policyModel`` is not a supported choice.
    """
    if model != "DNN_B_ACE":
        raise ValueError(f"Unsupported model: {model!r}")
    if policyModel != "DQN":
        raise ValueError(f"Unsupported policyModel: {policyModel!r}")

    env = _get_env()
    agent_observation_space = env.observation_space
    action_space = env.action_space
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Up to 4 independent policies, but never more than there are agents,
    # otherwise the policy list would not line up with env.agents.
    if Policy_Config["same_policy"]:
        policies_number = 1
    else:
        policies_number = min(4, len(env.agents))

    load_path = policy_load_path if policy_load_path is not None else model_load_path

    agents = []
    last_optim = optim
    for _ in range(policies_number):
        net = DNN_B_ACE(
            obs_shape=agent_observation_space,
            action_shape=4,  # NOTE(review): hard-coded number of discrete actions
            device=device,
        ).to(device)

        # DQNPolicy.learn steps its optimizer, so it must not be None; each
        # network gets its own optimizer over its own parameters.
        agent_optim = optim
        if agent_optim is None:
            optimizer_cls = getattr(torch.optim, dqn_params.get("optminizer", "Adam"))
            agent_optim = optimizer_cls(net.parameters(), lr=dqn_params["lr"])
        last_optim = agent_optim

        agent_learn = DQNPolicy(
            model=net,
            optim=agent_optim,
            action_space=action_space,
            discount_factor=dqn_params["discount_factor"],
            estimation_step=dqn_params["estimation_step"],
            target_update_freq=dqn_params["target_update_freq"],
            reward_normalization=False,
            clip_loss_grad=False,
        )

        if Policy_Config["load_model"]:
            # map_location lets a GPU-saved checkpoint load on a CPU-only host.
            agent_learn.load_state_dict(torch.load(load_path, map_location=device))
            print(f'Loaded-> {load_path}')

        agents.append(agent_learn)

    # Agents without a dedicated policy reuse the first one (with same_policy
    # this yields the single shared policy for every agent).
    if agents:
        agents.extend([agents[0]] * (len(env.agents) - len(agents)))

    policy = MAParalellPolicy(policies=agents, env=env, device=device)

    return policy, last_optim, env.agents

def _get_env():
    """Return a fresh B_ACE environment instance.

    DummyVectorEnv expects zero-argument factories, so this function builds
    the Godot PettingZoo wrapper from ``B_ACE_Config`` (action space left
    unconverted, on the CPU). It then adapts the wrapper with Tianshou's
    ``PettingZooParallelEnv``.
    """
    godot_env = GodotRLPettingZooWrapper(
        convert_action_space=False,
        device='cpu',
        **B_ACE_Config,
    )
    return PettingZooParallelEnv(godot_env)
   

# print(json.dumps(runConfig, indent=4))


ModuleNotFoundError: No module named 'DNN_B_ACE'

In [2]:
if __name__ == "__main__":

    torch.set_grad_enabled(True)

    # ======== Step 1: Environment setup =========
    train_envs = DummyVectorEnv([_get_env for _ in range(train_env_num)])
    test_envs = DummyVectorEnv([_get_env for _ in range(test_env_num)])

    # Seed numpy, torch and the environments so runs are repeatable.
    seed = 100
    np.random.seed(seed)
    torch.manual_seed(seed)
    train_envs.seed(seed)
    test_envs.seed(seed)

    # ======== Step 2: Agent setup =========
    policy, optim, agents = _get_agents()

    # ======== Step 3: Collector setup =========
    def _make_agent_buffers(num_envs):
        """One prioritized replay buffer per agent, with one sub-buffer per env.

        A vector buffer's buffer count must match the number of environments
        of the collector that fills it.
        """
        return {
            agent: PrioritizedVectorReplayBuffer(30_000, num_envs, alpha=0.6, beta=0.4)
            for agent in agents
        }

    agents_buffers_training = _make_agent_buffers(len(train_envs))
    # Sized by the *test* envs (previously len(train_envs) by mistake).
    agents_buffers_test = _make_agent_buffers(len(test_envs))

    train_collector = CollectorMA(
        policy,
        train_envs,
        agents_buffers_training,
        exploration_noise=True,
    )
    test_collector = CollectorMA(policy, test_envs, agents_buffers_test, exploration_noise=True)

    print("Buffer Warming Up ")
    for _ in range(trainer_params["warmup_size"]):
        train_collector.collect(n_episode=train_env_num)
        print(".", end="")

    # ======== Logging setup (wandb + tensorboard) =========
    # `max_cycles` is the module-level value; runConfig may not hold the key.
    logger = WandbLogger(
        train_interval=max_cycles,
        test_interval=1,
        update_interval=max_cycles,
        save_interval=1,
        write_flush=True,
        project="B_ACE01",
        name=log_name,
        entity=None,
        run_id=log_name,
        config=runConfig,
        monitor_gym=True,
    )

    writer = SummaryWriter(log_path)
    writer.add_text("args", str(runConfig))
    logger.load(writer)

    # Mutable counter, advanced by reward_metric and used to tag checkpoints.
    global_step_holder = [0]

    # ======== Step 4: Callback functions setup =========
    def _set_eps(epsilon):
        """Set the exploration rate on every distinct learning policy."""
        if Policy_Config["same_policy"]:
            policy.policies[agents[0]].set_eps(epsilon)
        else:
            for agent in agents:
                policy.policies[agent].set_eps(epsilon)

    def _checkpoint(pol, shared_suffix, per_agent_suffix):
        """Save policy weights tagged with the current step; return the tag.

        Shared policy -> ``<save>_<step><shared_suffix>``.
        Per-agent     -> ``<save>_<step>_<agent><per_agent_suffix>``.
        """
        step_tag = str(global_step_holder[0])
        prefix = model_save_path + "_" + step_tag
        if Policy_Config["same_policy"]:
            torch.save(pol.policies[agents[0]].state_dict(), prefix + shared_suffix)
        else:
            for agent in agents:
                torch.save(pol.policies[agent].state_dict(), prefix + "_" + agent + per_agent_suffix)
        return step_tag

    def save_best_fn(policy):
        step_tag = _checkpoint(policy, "_BestRew.pth", ".pth")
        if Policy_Config["same_policy"]:
            print("Best Saved Rew", step_tag)
        else:
            print("Bests Saved Rew", step_tag)

    def save_test_best_fn(policy):
        # Per-agent files get a "_BestLen" suffix so they no longer overwrite
        # the files written by save_best_fn at the same step.
        step_tag = _checkpoint(policy, "_BestLen.pth", "_BestLen.pth")
        print("Best Saved Length", step_tag)

    def stop_fn(mean_rewards):
        # Threshold is effectively unreachable: train for the full max_epoch.
        return mean_rewards >= 99999939.0

    def train_fn(epoch, env_step):
        # Linear decay from tn_eps_max down to tn_eps_max / 100 over max_epoch.
        eps_max = trainer_params['tn_eps_max']
        eps_min = eps_max / 100
        epsilon = eps_max - (eps_max - eps_min) * (epoch / trainer_params['max_epoch'])
        _set_eps(epsilon)

    def test_fn(epoch, env_step):
        _set_eps(trainer_params['ts_eps_max'])

        # Periodic checkpoint every 10 reward_metric calls.
        if global_step_holder[0] % 10 == 0:
            step_tag = str(global_step_holder[0])
            _checkpoint(policy, "_Step.pth", "Step" + step_tag + ".pth")
            print("Steps Policy Saved ", step_tag)

    def reward_metric(rews):
        # Collapse per-agent rewards into one team reward per episode.
        global_step_holder[0] += 1
        return np.sum(rews, axis=1)

    # ======== Step 5: Run the trainer =========
    offPolicyTrainer = OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=trainer_params['max_epoch'],
        step_per_epoch=trainer_params['step_per_epoch'],
        step_per_collect=trainer_params['step_per_collect'],
        episode_per_test=trainer_params['episode_per_test'],
        batch_size=trainer_params['batch_size'],
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        # save_test_best_fn=save_test_best_fn,
        update_per_step=trainer_params['update_per_step'],
        logger=logger,
        test_in_train=True,
        reward_metric=reward_metric,
        show_progress=True,
    )

    try:
        result = offPolicyTrainer.run()
    finally:
        # Release the tensorboard writer and the Godot env connections even
        # when training is interrupted.
        writer.close()
        train_envs.close()
        test_envs.close()

    print(f"\n==========Result==========\n{result}")
    print("\n(the trained policy can be accessed via policy.policies[agents[0]])")

waiting for remote GODOT connection on port 11008
connection established
action space {'desiredG_input': {'action_type': 'continuous', 'size': 1}, 'hdg_input': {'action_type': 'continuous', 'size': 1}, 'level_input': {'action_type': 'continuous', 'size': 1}, 'shoot_input': {'action_type': 'discrete', 'size': 2}}
observation space {'obs': {'size': [15], 'space': 'box'}}
waiting for remote GODOT connection on port 11008
connection established
action space {'desiredG_input': {'action_type': 'continuous', 'size': 1}, 'hdg_input': {'action_type': 'continuous', 'size': 1}, 'level_input': {'action_type': 'continuous', 'size': 1}, 'shoot_input': {'action_type': 'discrete', 'size': 2}}
observation space {'obs': {'size': [15], 'space': 'box'}}
waiting for remote GODOT connection on port 11008
connection established
action space {'desiredG_input': {'action_type': 'continuous', 'size': 1}, 'hdg_input': {'action_type': 'continuous', 'size': 1}, 'level_input': {'action_type': 'continuous', 'size': 1



({'agent_0': {'obs': [0, 0.184642740885417, 0.406399993896484, 0, 0.846153846153846, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]}, 'agent_1': {'obs': [0.0617333323160807, 0.184642740885417, 0.406399993896484, 0, 0.846153846153846, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]}}, {'agent_0': {}, 'agent_1': {}})
2
{'agent_0': {}, 'agent_1': {}}
({'agent_0': {'obs': [0, 0.184642740885417, 0.406399993896484, 0, 0.846153846153846, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]}, 'agent_1': {'obs': [0.0617333323160807, 0.184642740885417, 0.406399993896484, 0, 0.846153846153846, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]}}, {'agent_0': {}, 'agent_1': {}})
2
{'agent_0': {}, 'agent_1': {}}
Buffer Warming Up 
Batch(
    agent_0: 3,
    agent_1: 3,
)


IndexError: invalid index to scalar variable.