In [16]:
import time
import random
from collections import deque
from pprint import pprint

import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from gymnasium import spaces
from stable_baselines3.common.buffers import DictReplayBuffer
from mlagents_envs.environment import UnityEnvironment, ActionTuple

from utils_policy_train import *

In [2]:
CONFIG = './train_config/standard.yaml'

# Args

In [3]:
args = parse_args_from_file(CONFIG)
args.seed = random.randint(0, 2**16)
# args.name = generate_funny_name()

pprint(vars(args))

{'actor_network_layers': [128, 128, 128],
 'alpha': 1.0,
 'alpha_lr': 0.0004,
 'autotune': True,
 'batch_size': 256,
 'bootstrap_batch_proportion': 0.8,
 'buffer_size': 120000,
 'cuda': True,
 'env_id': 'std',
 'exp_name': 'base+wp',
 'gamma': 0.995,
 'learning_starts': 1000,
 'loss_log_interval': 100,
 'metrics_log_interval': 300,
 'metrics_smoothing': 0.985,
 'noise_clip': 0.5,
 'policy_frequency': 4,
 'policy_lr': 0.0004,
 'q_ensemble_n': 5,
 'q_lr': 0.0004,
 'q_network_layers': [128, 128, 128],
 'seed': 59409,
 'target_network_frequency': 1,
 'tau': 0.005,
 'torch_deterministic': True,
 'total_timesteps': 50000,
 'update_per_step': 1}


# Seeding

In [4]:

# seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic

device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
print(f"Using device: {device}")

Using device: cuda


# Start Environment

In [5]:

# Create the channel
env_info = CustomChannel()

# env setup
env = UnityEnvironment(None, seed=args.seed, side_channels=[env_info])

In [6]:
env.reset()

# Environment Variables and Log

In [7]:

run_name = f"{args.exp_name}_{int(time.time()) - 1751796000}"
args.full_name = run_name

# writer to track performance
writer = SummaryWriter(f"ens_train/{run_name}")
writer.add_text(
    "Algorithm Hyperparameters",
    "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
for dict in env_info.settings:
    writer.add_text(
        dict,
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in env_info.settings[dict].items()])),
    )

In [8]:
# select the correct agent's behavoir (works if there is only one behavoir)
behavior_names = list(env.behavior_specs.keys())
if len(behavior_names) > 1:
    raise Exception("Multiple behaviors found.")
BEHAVIOUR_NAME = behavior_names[0]

In [11]:
env_info.settings

{'agent_settings': {'step_after_goal': 10,
  'max_step': 1000,
  'max_movement_speed': 3.0,
  'max_turn_speed': 160.0,
  'agent_spawn_offset': 0.15000000596046448,
  'move_smooth_time': 0.10000000149011612,
  'goal_reward': 10.0,
  'wall_hit_penalty': 0.0,
  'wall_hit_speed_weight': 0.0,
  'progress_reward': 0.05000000074505806,
  'stagnation_penalty': -0.019999999552965164,
  'ema_range_penalty': 5.0,
  'ema_smoothing': 0.010999999940395355},
 'ray_sensor_settings': {'sensor_name': 'RayPerceptionSensor',
  'rays_per_direction': 8,
  'max_ray_degrees': 75.0,
  'sphere_cast_radius': 0.0,
  'ray_length': 8.0,
  'observation_stacks': 4,
  'alternating_ray_order': False,
  'use_batched_raycasts': True,
  'min_observation': 0.0,
  'max_observation': 1.0,
  'ignore_last_ray': False},
 'behavior_parameters_settings': {'behavior_name': 'turtlebot?team=0',
  'observation_size': 7,
  'stacked_vector': 4,
  'min_observation': -256.0,
  'max_observation': 256.0,
  'continuous_actions': 2,
  'min_a

In [12]:
# BEHAVIOUR_NAME = env_info.settings['behavoir_parameters_settings']['behavior_name']

RAY_STACK = env_info.settings['ray_sensor_settings']['observation_stacks']
RAY_PER_DIRECTION = env_info.settings['ray_sensor_settings']['rays_per_direction']
RAYCAST_MIN = env_info.settings['ray_sensor_settings']['min_observation']
RAYCAST_MAX = env_info.settings['ray_sensor_settings']['max_observation']
DELETE_LAST_RAY = env_info.settings['ray_sensor_settings']['ignore_last_ray']


STATE_STACK = env_info.settings['behavior_parameters_settings']['stacked_vector']
STATE_SIZE = env_info.settings['behavior_parameters_settings']['observation_size']
STATE_MIN = env_info.settings['behavior_parameters_settings']['min_observation']
STATE_MAX = env_info.settings['behavior_parameters_settings']['max_observation']

ACTION_SIZE = env_info.settings['behavior_parameters_settings']['continuous_actions']
ACTION_MIN = env_info.settings['behavior_parameters_settings']['min_action']
ACTION_MAX = env_info.settings['behavior_parameters_settings']['max_action']

if DELETE_LAST_RAY:
    RAYCAST_SHAPE = (RAY_STACK, 2*RAY_PER_DIRECTION) 
else:
    RAYCAST_SHAPE = (RAY_STACK, 2*RAY_PER_DIRECTION + 1) 
STATE_SHAPE = (STATE_SIZE*STATE_STACK, )
ACTION_SHAPE = (ACTION_SIZE, )

In [13]:

# creating the training networks
actor = DenseActor(RAYCAST_SHAPE, STATE_SHAPE[0], ACTION_SHAPE[0], ACTION_MIN, ACTION_MAX, args.actor_network_layers).to(device)
actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr)

qf_ensemble = [DenseSoftQNetwork(RAYCAST_SHAPE, STATE_SHAPE[0], ACTION_SHAPE[0], args.q_network_layers).to(device) for _ in range(args.q_ensemble_n)]
qf_ensemble_target = [DenseSoftQNetwork(RAYCAST_SHAPE, STATE_SHAPE[0], ACTION_SHAPE[0], args.q_network_layers).to(device) for _ in range(args.q_ensemble_n)]
for q_t, q in zip(qf_ensemble_target, qf_ensemble):
    q_t.load_state_dict(q.state_dict())

par = []
for q in qf_ensemble:
    par += list(q.parameters())
qf_optimizer = torch.optim.Adam(
    par,
    lr=args.q_lr
)

# Replay Buffer

In [17]:

# definition of the gym spaces for the replay buffer
observation_space = spaces.Dict({
    "raycast": spaces.Box(low=RAYCAST_MIN, high=RAYCAST_MAX, shape=RAYCAST_SHAPE, dtype=np.float32),
    "state": spaces.Box(low=STATE_MIN, high=STATE_MAX, shape=STATE_SHAPE, dtype=np.float32),
})
action_space = spaces.Box(low=ACTION_MIN, high=ACTION_MAX, shape=ACTION_SHAPE, dtype=np.float32)

# initialization of the tailored replay buffer
rb = DictReplayBuffer(
    args.buffer_size,
    observation_space=observation_space,
    action_space=action_space,
    device=device,
    handle_timeout_termination=True,
    optimize_memory_usage=False
)

# start algorithm

In [18]:
# Automatic entropy tuning
if args.autotune:
    target_entropy = -torch.prod(torch.Tensor(ACTION_SHAPE).to(device)).item()
    log_alpha = torch.zeros(1, requires_grad=True, device=device)
    alpha = log_alpha.exp().clamp(min=1e-4).item()
    a_optimizer = optim.Adam([log_alpha], lr=args.alpha_lr)
else:
    alpha = args.alpha

In [21]:
# savepath and variables to the model checkpointing
save_path = './new_models/' + run_name

os.makedirs(save_path, exist_ok=True)
best_reward = -float('inf')

# to keep track of the episode length
episodic_stats = None
initial_movements = {}  # agent_id -> movement
obs = collect_data_after_step(env, env_info)

In [23]:
# start training
print('Start Training')
start_time = time.time()
global_step = 0
while global_step < args.total_timesteps:

    # actions for each agent in the environment
    # dim = (naagents, action_space)
    for id in obs:
        agent_obs = obs[id]
        
        # terminated agents are not considered
        if agent_obs[4]:
            continue
        
        # algo logic
        if global_step < args.learning_starts * 2:
            # change this to use the handcrafted starting policy or a previously trained policy
            
            action = get_initial_action(id, initial_movements)
            # action, _, _ = old_actor.get_action(torch.Tensor([obs[id][0]]), 
            #                                 torch.Tensor([obs[id][1]]),
            #                                 0.5)
            # action = action[0].detach().numpy()
        else:
            # training policy
            action, _, _, _ = actor.get_action(torch.Tensor([obs[id][0]]).to(device), 
                                            torch.Tensor([obs[id][1]]).to(device))
            action = action[0].detach().cpu().numpy()
        
        # memorize the action taken for the next step
        agent_obs[3] = action
        
        # the first dimention of the action is the "number of agent"
        # Always 1 if "set_action_for_agent" is used
        a = ActionTuple(continuous=np.array([action]))
        env.set_action_for_agent(BEHAVIOUR_NAME, id, a)
    
    # environment step
    env.step()
    next_obs = collect_data_after_step(env, env_info)
         
    # asynchronous reception of other info from ended episode
    while env_info.msg_queue:
        msg = env_info.msg_queue.pop()
        
        if global_step >= args.learning_starts:
            if episodic_stats == None:
                episodic_stats = {
                    "length": msg["length"],
                    "reward": msg["reward"],
                    "success": msg["success"],
                    "collisions": msg["collisions"],
                }
            else:
                for s in episodic_stats:
                    episodic_stats[s] = episodic_stats[s]*args.metrics_smoothing + (1 - args.metrics_smoothing)*msg[s]
        
    # save data to reply buffer; handle `terminal_observation`
    for id in obs:
        prev_agent_obs = obs[id]
        # consider every agent that in the previous step was not terminated
        # in this way are excluded the agents that are already considered before and don't have a 
        # couple prev_obs - next_obs and a reward
        if prev_agent_obs[4] or id not in next_obs:
            continue
            
        next_agent_obs = next_obs[id]
        
        # add the data to the replay buffer
        rb.add(obs = {'raycast': prev_agent_obs[0], 'state': prev_agent_obs[1]}, 
               next_obs = {'raycast': next_agent_obs[0], 'state': next_agent_obs[1]},
               action = np.array(prev_agent_obs[3]), 
               reward = next_agent_obs[2], 
               done = next_agent_obs[4],
               infos = [{}])
        
    # crucial step, easy to overlook, update the previous observation
    obs = next_obs
    
    # Training loop
    for _ in range(args.update_per_step):

        # Log episodic stats periodically
        if episodic_stats is not None and global_step % args.metrics_log_interval == 0:
            print_text = f"[{global_step}/{args.total_timesteps}] "
            for s in episodic_stats:
                writer.add_scalar("episodic_stats/" + s, episodic_stats[s], global_step)
                print_text += f"|{s}: {episodic_stats[s]:.5f}"
            print_text += f'| SPS: {int(global_step / (time.time() - start_time))}'
            print(print_text)

        # Save best models based on reward
        if episodic_stats is not None and episodic_stats["reward"] > best_reward:
            best_reward = episodic_stats["reward"]
            torch.save(actor.state_dict(), os.path.join(save_path, 'actor_best.pth'))
            for i, qf in enumerate(qf_ensemble):
                torch.save(qf.state_dict(), os.path.join(save_path, f'qf{i+1}_best.pth'))
            for i, qft in enumerate(qf_ensemble_target):
                torch.save(qft.state_dict(), os.path.join(save_path, f'qf{i+1}_target_best.pth'))

        # Start learning after a warm-up phase
        if global_step > args.learning_starts:

            # Sample a batch from replay buffer
            data = rb.sample(args.batch_size)

            with torch.no_grad():
                # Compute target action with exploration noise
                next_action, next_log_pi, _, _ = actor.get_action(
                    data.next_observations['raycast'], 
                    data.next_observations['state']
                )

                if args.noise_clip > 0:
                    noise = torch.randn_like(next_action) * args.noise_clip
                    next_action = torch.clamp(next_action + noise, -1, 1)

                # Compute target Q-value (min over ensemble)
                target_q_values = []
                for q_target in qf_ensemble_target:
                    q_val = q_target(
                        data.next_observations['raycast'], 
                        data.next_observations['state'], 
                        next_action
                    )
                    target_q_values.append(q_val)
                stacked_target_q = torch.stack(target_q_values)
                min_qf_next_target = stacked_target_q.min(dim=0).values - alpha * next_log_pi
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * min_qf_next_target.view(-1)

            # Q-function updates (with bootstrapping)
            q_losses = []
            q_vals = []
            batch_size = int(data.actions.shape[0] * args.bootstrap_batch_proportion)
            for q in qf_ensemble:
                # Bootstrap indices
                indices = torch.randint(0, batch_size, (batch_size,), device=data.actions.device)
                
                obs_raycast = data.observations['raycast'][indices]
                obs_state = data.observations['state'][indices]
                actions = data.actions[indices]
                target = next_q_value[indices]

                # Compute Q loss
                q_val = q(obs_raycast, obs_state, actions).view(-1)
                loss = F.mse_loss(q_val, target)
                q_losses.append(loss)
                q_vals.append(q_val)
                
            total_q_loss = torch.stack(q_losses).mean()
            qf_optimizer.zero_grad()
            total_q_loss.backward()
            qf_optimizer.step()
            
            # Track Q-value statistics
            q_std = torch.stack(q_vals).std(dim=0).mean().item()
            q_vals = torch.stack(q_vals).mean()
            
            # Delayed policy (actor) update
            if global_step % args.policy_frequency == 0:
                for _ in range(args.policy_frequency):
                    pi, log_pi, _, _ = actor.get_action(data.observations['raycast'], data.observations['state'])
                    actor_entropy = - (log_pi.exp() * log_pi).sum(dim=-1).mean()

                    q_pi_vals = [q(data.observations['raycast'], data.observations['state'], pi) for q in qf_ensemble]
                    min_qf_pi = torch.min(torch.stack(q_pi_vals), dim=0).values.view(-1)

                    actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    actor_optimizer.step()

                    # Automatic entropy tuning (if enabled)
                    if args.autotune:
                        with torch.no_grad():
                            _, log_pi, _, _ = actor.get_action(data.observations['raycast'], data.observations['state'])
                        alpha_loss = (-log_alpha * (log_pi + target_entropy)).mean()

                        a_optimizer.zero_grad()
                        alpha_loss.backward()
                        a_optimizer.step()
                        alpha = log_alpha.exp().item()

            # Soft update target Q-networks
            if global_step % args.target_network_frequency == 0:
                for q, q_t in zip(qf_ensemble, qf_ensemble_target):
                    for param, target_param in zip(q.parameters(), q_t.parameters()):
                        target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

            # Log training losses and stats
            if global_step % args.loss_log_interval == 0:
                for i in range(len(qf_ensemble)):
                    writer.add_scalar(f"q_loss/qf{i+1}", q_losses[i].item(), global_step)

                writer.add_scalar("loss/qf_mean", torch.stack(q_losses).mean().item(), global_step)
                writer.add_scalar("loss/actor", actor_loss.item(), global_step)
                
                writer.add_scalar("stats/qf_val_mean ", q_vals, global_step)
                writer.add_scalar("stats/qf_val_var ", q_std, global_step)
                
                writer.add_scalar("stats/policy_entropy", actor_entropy.item(), global_step)
                writer.add_scalar("stats/SPS", int(global_step / (time.time() - start_time)), global_step)
                if args.autotune:
                    writer.add_scalar("loss/alpha", alpha, global_step)
                    writer.add_scalar("loss/alpha_loss", alpha_loss.item(), global_step)

        elif global_step == args.learning_starts:
            print("Start Learning")

        # Step counter
        global_step += 1

Start Training
Start Learning
[1200/50000] |length: 1001.00000|reward: -19.47134|success: 0.00000|collisions: 2.97023| SPS: 10
[1500/50000] |length: 1001.00000|reward: -18.69947|success: 0.00000|collisions: 3.14229| SPS: 10
[1800/50000] |length: 1001.00000|reward: -18.66213|success: 0.00000|collisions: 3.13532| SPS: 10


  action, _, _, _ = actor.get_action(torch.Tensor([obs[id][0]]).to(device),


[2100/50000] |length: 1001.00000|reward: -18.55758|success: 0.00000|collisions: 3.09548| SPS: 10
[2400/50000] |length: 1001.00000|reward: -18.48240|success: 0.00000|collisions: 3.14854| SPS: 10
[2700/50000] |length: 1001.00000|reward: -18.37394|success: 0.00000|collisions: 3.16728| SPS: 10
[3000/50000] |length: 1001.00000|reward: -18.12937|success: 0.00000|collisions: 3.30433| SPS: 10
[3300/50000] |length: 1001.00000|reward: -17.96807|success: 0.00000|collisions: 3.32542| SPS: 10
[3600/50000] |length: 1001.00000|reward: -17.94221|success: 0.00000|collisions: 3.39455| SPS: 10
[3900/50000] |length: 993.22850|reward: -17.39589|success: 0.01455|collisions: 3.37553| SPS: 10
[4200/50000] |length: 993.57300|reward: -17.27667|success: 0.01391|collisions: 3.47575| SPS: 10
[4500/50000] |length: 994.21686|reward: -17.06456|success: 0.01270|collisions: 3.69544| SPS: 10
[4800/50000] |length: 980.31978|reward: -16.53815|success: 0.02696|collisions: 3.85774| SPS: 10
[5100/50000] |length: 982.11258|re

KeyboardInterrupt: 

# Close Environment

In [24]:
# close environment
env.close()

In [25]:
# save trained networks, actor and critics
torch.save(actor.state_dict(), os.path.join(save_path, 'actor_final.pth'))
for i, qf in enumerate(qf_ensemble):
    torch.save(qf.state_dict(), os.path.join(save_path, f'qf{i+1}_final.pth'))
for i, qft in enumerate(qf_ensemble_target):
    torch.save(qft.state_dict(), os.path.join(save_path, f'qf{i+1}_target_final.pth'))