In [None]:
import argparse
import json
import os
import random
import sys
import time
import traceback
from collections import deque
from decimal import Decimal
from pprint import pprint

import numpy as np
import torch
import wandb

In [None]:
import torch.optim as optim
from mlagents_envs.environment import UnityEnvironment, ActionTuple
from mlagents_envs.side_channel.environment_parameters_channel import EnvironmentParametersChannel

In [None]:
from gymnasium import spaces 
from stable_baselines3.common.buffers import ReplayBuffer

from training_utils import *
from testing_utils import *

In [None]:
import argparse
import sys

def parse_args(default_config_path="./config/train_new_obs.yaml"):
    """
    Build the run configuration from a config file plus command-line overrides.

    The config file path is the first command-line token when that token is
    not an option (does not start with "-"); otherwise ``default_config_path``
    is used. Every top-level key of the config file becomes a ``--<key>``
    option whose default is the value found in the file, so any key can be
    overridden from the command line.

    When the process was launched by a Jupyter kernel (the command line carries
    ``-f <connection file>`` or ``--f=<connection file>``), all command-line
    tokens are ignored and ``default_config_path`` is used.

    Args:
        default_config_path: Config file to load when no path is given on the
            command line, or when running inside a notebook.

    Returns:
        argparse.Namespace holding ``config_path`` plus one attribute per
        config key.
    """
    import ast  # local import: only needed to parse list/dict overrides

    def _literal_as(expected_type):
        """Return an argparse ``type`` callable that parses a Python literal of ``expected_type``."""
        def _convert(text):
            try:
                parsed = ast.literal_eval(text)
            except (ValueError, SyntaxError) as exc:
                raise argparse.ArgumentTypeError(
                    f"expected a {expected_type.__name__} literal, got {text!r}"
                ) from exc
            # Accept "(1, 2)" for a list option and "[1, 2]" for a tuple option.
            if expected_type in (list, tuple) and isinstance(parsed, (list, tuple)):
                return expected_type(parsed)
            if isinstance(parsed, expected_type):
                return parsed
            raise argparse.ArgumentTypeError(
                f"expected a {expected_type.__name__} literal, got {text!r}"
            )
        return _convert

    # --- Notebook handling: never let kernel arguments reach the parser ---
    argv = sys.argv[1:]
    # Kernels are started with either "-f <file>" or "--f=<file>". The script's
    # own options are always "--<key>", so a bare "-f" can only come from a kernel.
    launched_by_kernel = any(tok == "-f" or tok.startswith("--f=") for tok in argv)
    if not argv or launched_by_kernel:
        argv = [default_config_path]

    # --- Config path: only a leading non-option token counts ---
    # Taking "the first free-standing token" instead would mistake an option's
    # value (e.g. the "0.5" in "--lr 0.5") for the config path.
    if argv and not argv[0].startswith("-"):
        config_path, remaining_argv = argv[0], argv[1:]
    else:
        config_path, remaining_argv = default_config_path, argv
    print(f"Config path: {config_path}")

    # --- Read the parameters from the config file ---
    file_config_dict = parse_config_file(config_path)

    # --- Main parser: one option per config key, typed after the file's value ---
    parser = argparse.ArgumentParser(
        description="Training Script",
        usage="%(prog)s [config_path] [--<key> VALUE ...]",
    )
    for key, value in (file_config_dict or {}).items():
        if isinstance(value, bool):
            # bool must be tested before the generic branch: bool("False") is True.
            arg_type = str2bool
        elif value is None:
            arg_type = str
        elif isinstance(value, (list, tuple, dict)):
            # list("[1, 2]") would yield a list of characters; parse the literal instead.
            arg_type = _literal_as(type(value))
        else:
            arg_type = type(value)
        parser.add_argument(f"--{key}", type=arg_type, default=value)

    # --- Final parse; anything unrecognised is reported, not fatal ---
    args, unknown = parser.parse_known_args(remaining_argv)
    args.config_path = config_path
    if unknown:
        print("Ignored unknown args:", unknown)
    return args


# Testing Function

In [None]:
def test(env, args, env_info, actor, BEHAVIOUR_NAME, STATE_SIZE, DEVICE):
    """
    Drive the Unity environment with ``actor`` for ``args.total_timesteps`` steps.

    Each iteration asks the actor for an action for every non-terminated agent,
    steps the environment once, and folds any episode-end messages queued on
    ``env_info.stop_msg_queue`` into the episodic statistics.

    Args:
        env: Environment exposing ``set_action_for_agent`` and ``step``.
        args: Namespace providing ``total_timesteps``, ``learning_starts``,
            ``metrics_smoothing`` and ``metrics_log_interval``.
        env_info: Side channel whose ``stop_msg_queue`` holds episode-end messages.
        actor: Policy exposing ``get_action(tensor)``; the first returned value
            is used as the batch of actions.
        BEHAVIOUR_NAME: Behaviour identifier passed through to the environment.
        STATE_SIZE: Passed through to ``collect_data_after_step``.
        DEVICE: Torch device the observation tensor is moved to.

    Returns:
        Tuple ``(testing_stats, episodic_stats, success_stats, failure_stats,
        test_data)``. ``test_data`` is always an empty dict here: nothing in
        this function fills it.
    """
    testing_stats = {
        "time/python_time": RunningMean(),
        "time/unity_time": RunningMean(),
    }

    test_data = {}
    episodic_stats = {}
    success_stats = {}
    failure_stats = {}

    start_time = time.time()
    unity_end_time = -1
    unity_start_time = -1

    global_step = 0
    print(f'[{global_step}/{args.total_timesteps}] Starting Testing')

    # obs maps an agent id to a per-agent record; this function reads index 0
    # (actor input) and index 3 (skip flag) and writes index 2 (last action).
    obs = collect_data_after_step(env, BEHAVIOUR_NAME, STATE_SIZE)

    while global_step < args.total_timesteps:

        # One action per agent currently present in the environment.
        for agent_id in obs:
            agent_obs = obs[agent_id]

            # Terminated agents are not considered.
            if agent_obs[3]:
                continue

            # Pure inference: no autograd graph is needed for the forward pass.
            with torch.no_grad():
                action, _, _ = actor.get_action(torch.Tensor([agent_obs[0]]).to(DEVICE))
            action = action[0].detach().cpu().numpy()

            # Remember the action taken, for the next step.
            agent_obs[2] = action

            # The first dimension of the action is the "number of agents";
            # it is always 1 when set_action_for_agent is used.
            agent_action = ActionTuple(continuous=np.array([action]))
            env.set_action_for_agent(BEHAVIOUR_NAME, agent_id, agent_action)

        # --- ENVIRONMENT STEP ---
        # python_time: gap between the end of the previous env.step() and the
        # start of this one. unity_time: duration of env.step() itself.
        unity_start_time = time.time()
        if unity_end_time > 0 and global_step > args.learning_starts:
            testing_stats['time/python_time'].update(unity_start_time - unity_end_time)

        env.step()
        unity_end_time = time.time()
        if global_step > args.learning_starts:
            testing_stats['time/unity_time'].update(unity_end_time - unity_start_time)

        next_obs = collect_data_after_step(env, BEHAVIOUR_NAME, STATE_SIZE)

        # Drain the episode-end messages received during this step.
        while env_info.stop_msg_queue:
            msg = env_info.stop_msg_queue.pop()

            # NOTE(review): this gate uses >= while the timing gates above use >.
            if global_step >= args.learning_starts:
                update_stats_from_message(episodic_stats, success_stats, failure_stats, msg, args.metrics_smoothing)
                if episodic_stats['ep_count'] % args.metrics_log_interval == 0:
                    print_update(global_step, args.total_timesteps, start_time, episodic_stats)

        # Crucial step, easy to overlook: the new observation becomes the current one.
        obs = next_obs

        # Step counter
        global_step += 1

    return testing_stats, episodic_stats, success_stats, failure_stats, test_data

# Start Testing Code

In [None]:
# Load the main run configuration, then the three auxiliary config files
# whose paths it carries.
args = parse_args()
agent_config, obstacles_config, other_config = (
    parse_config_file(getattr(args, path_attr))
    for path_attr in ("agent_config_path", "obstacles_config_path", "other_config_path")
)

# Draw a fresh random seed for this run; it replaces any seed parse_args produced.
args.seed = random.randint(0, 1 << 16)

print('Training with the following parameters:')
pprint(vars(args))

# Dump each auxiliary configuration under its own heading.
for _heading, _config in (
    ('agent_config:', agent_config),
    ('obstacles_config:', obstacles_config),
    ('other_config:', other_config),
):
    print(_heading)
    pprint(_config)
del _heading, _config

In [None]:
# Pick "cuda:<index>" (e.g. "cuda:2") when CUDA is available and args.cuda is a
# non-negative device index; otherwise fall back to the CPU.
device_str = f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"

DEVICE = torch.device(device_str)
print(f"Using device: {DEVICE}")

In [None]:
# Seed Python's, NumPy's and PyTorch's random number generators with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(args.seed)
del _seed_fn

# cuDNN determinism is taken from the run configuration.
torch.backends.cudnn.deterministic = args.torch_deterministic
print(f'Seed: {args.seed}')

In [None]:
# Side channels handed to the Unity environment below: env_info is the
# project's CustomChannel, param_channel carries environment parameters.
env_info = CustomChannel()
param_channel = EnvironmentParametersChannel()

# Forward the agent and obstacle configuration through the parameter channel,
# using the 'ag_' and 'obs_' prefixes respectively.
print('Applying Unity settings from config...')
apply_unity_settings(param_channel, agent_config, 'ag_')
apply_unity_settings(param_channel, obstacles_config, 'obs_')

# Dry run: stop before launching the Unity build.
# sys.exit is used instead of the interactive-only exit() helper, which is
# injected by the site module and is not meant for use in programs.
if args.test_lib:
    print('Testing Ended')
    sys.exit(0)

# Launch the Unity build and attach both side channels.
print(f'Starting Unity Environment from build: {args.build_path}')
env = UnityEnvironment(
    args.build_path,
    seed=args.seed,
    side_channels=[env_info, param_channel],
    no_graphics=args.headless,
    worker_id=args.worker_id,
)
print('Unity Environment connected.')

In [None]:
# Reset the Unity environment before running the test loop.
print('Resetting environment...')
env.reset()

In [None]:
# Name the run "<exp_name>_<current epoch seconds minus args.base_time>" and
# record the name on args as well.
run_name = args.run_name = f"{args.exp_name}_{int(time.time()) - args.base_time}"
print(f"Run name: {run_name}")

# Directory associated with this run; created if it does not exist yet.
# NOTE(review): the name embeds the current time, so this directory is new on
# every execution, yet load_models() is later asked to read '_best' weights
# from it - confirm where the trained weights are expected to come from.
save_path = f'./models/{run_name}'
os.makedirs(save_path, exist_ok=True)
print(f'saving to path: {save_path}')

In [None]:
# Behaviour identifier in the "<behavior_name>?team=<team>" form.
BEHAVIOUR_NAME = (
    other_config['behavior_name']
    + '?team='
    + other_config['team']
)

# Ray-cast observation: per-direction ray count, value bounds and total size
# (two sides of RAY_PER_DIRECTION rays plus one).
RAY_PER_DIRECTION, RAYCAST_MIN, RAYCAST_MAX = (
    other_config[key]
    for key in ('rays_per_direction', 'rays_min_observation', 'rays_max_observation')
)
RAYCAST_SIZE = 2 * RAY_PER_DIRECTION + 1

# State observation: one entry fewer than the configured size, plus value bounds.
# NOTE(review): the reason for dropping one entry is not visible here - confirm.
STATE_SIZE = other_config['state_observation_size'] - 1
STATE_MIN, STATE_MAX = (
    other_config[key]
    for key in ('state_min_observation', 'state_max_observation')
)

# Action space: dimensionality and bounds.
ACTION_SIZE, ACTION_MIN, ACTION_MAX = (
    other_config[key]
    for key in ('action_size', 'min_action', 'max_action')
)

# Network input size: one (state + ray-cast) frame per stacked observation.
TOTAL_STATE_SIZE = (STATE_SIZE + RAYCAST_SIZE) * args.input_stack

In [None]:
print('Creating and loading actor and critic networks...')

# ===== Actor =====
actor = OldDenseActor(
    TOTAL_STATE_SIZE, ACTION_SIZE, ACTION_MIN, ACTION_MAX, args.actor_network_layers
).to(DEVICE)

# ===== Q ensemble and its target copies (args.q_ensemble_n networks each) =====
qf_ensemble = [
    OldDenseSoftQNetwork(TOTAL_STATE_SIZE, ACTION_SIZE, args.q_network_layers).to(DEVICE)
    for _ in range(args.q_ensemble_n)
]
qf_ensemble_target = [
    OldDenseSoftQNetwork(TOTAL_STATE_SIZE, ACTION_SIZE, args.q_network_layers).to(DEVICE)
    for _ in range(args.q_ensemble_n)
]

# ===== Load saved weights =====
# NOTE(review): save_path is derived from a freshly timestamped run name, so
# confirm that '_best' checkpoints can actually exist in that directory.
load_models(actor, qf_ensemble, qf_ensemble_target, save_path, suffix='_best')

# ===== Optimizers (created after the weights are loaded) =====
actor_optimizer = optim.Adam(actor.parameters(), lr=args.policy_lr)

# A single optimizer over the parameters of every Q network in the ensemble.
par = [param for q_net in qf_ensemble for param in q_net.parameters()]
qf_optimizer = optim.Adam(par, lr=args.q_lr)

# ===== Observation stack =====
obs_stack = DenseStackedObservations(
    args.input_stack, STATE_SIZE + RAYCAST_SIZE, args.n_envs
)


In [None]:
# Run the test loop with the loaded actor and keep everything it returns.
testing_stats, episodic_stats, success_stats, failure_stats, test_data = test(
    env=env,
    args=args,
    env_info=env_info,
    actor=actor,
    BEHAVIOUR_NAME=BEHAVIOUR_NAME,
    STATE_SIZE=STATE_SIZE,
    DEVICE=DEVICE,
)

In [21]:
# Save dataset to JSON if accumulation is enabled.
# NOTE(review): neither CONFIG_DICT nor dataset is defined in the visible part
# of this notebook - confirm where they come from (this cell may be a leftover
# from an earlier version).
if CONFIG_DICT['accumulate_data']:

    def convert_all_to_float(obj):
        """
        Recursively replace values the json module cannot encode.

        dicts, lists, tuples and NumPy arrays are walked; NumPy floats and
        Decimal become float, NumPy integers become int, NumPy booleans become
        bool, and NumPy scalar dict keys become plain Python scalars. Anything
        else is returned unchanged.
        """
        if isinstance(obj, dict):
            return {
                (key.item() if isinstance(key, np.generic) else key): convert_all_to_float(value)
                for key, value in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [convert_all_to_float(item) for item in obj]
        if isinstance(obj, np.ndarray):
            return convert_all_to_float(obj.tolist())
        if isinstance(obj, (np.floating, Decimal)):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        return obj

    # open() does not create missing directories.
    os.makedirs('./results', exist_ok=True)

    # File name carries the current epoch seconds minus a fixed offset.
    # NOTE(review): presumably the same offset as args.base_time - confirm.
    with open(f'./results/test_{int(time.time()) - 1751796000}.json', 'w') as file:
        json.dump(convert_all_to_float(dataset), file)


# Close Environment

In [22]:
# Close the Unity environment created earlier in the notebook.
env.close()