In [None]:
import json
import os
import random
import time
from datetime import timedelta
from decimal import Decimal

import numpy as np
import torch

from mlagents_envs.base_env import ActionTuple
from mlagents_envs.environment import UnityEnvironment

from utils_policy_train import *
from utils_testing import *
from utils_uf_methods import *

# UF Methods

In [None]:
# Registry of uncertainty-scoring callables, keyed by method name.
# Every entry is called as fn(ray_obs, state_obs, action) and returns a scalar score.
uf_methods = {}

In [None]:
# Probabilistic Transition Model

def prob_world_score(ray_obs, state_obs,
                     action,
                     prob_model,
                     input_mean, input_std):
    """Score an (observation, action) pair with the probabilistic world model.

    The ray observation (flattened), the state vector and the action are
    concatenated into one row. A fixed subset of its columns is normalized and
    fed to ``prob_model``; the predicted variances of the first batch row are
    summed into a single scalar.

    Args:
        ray_obs: Array-like ray observation; flattened before use.
            The column slicing assumes 4 stacks of 17 rays -- TODO confirm.
        state_obs: 1-D array-like state vector.
            The column slicing assumes 4 stacks of 7 values -- TODO confirm.
        action: 1-D array-like action, or ``None``.
        prob_model: Callable returning ``(mean, variance)`` for a batch.
        input_mean: Normalization mean, broadcastable to the selected features.
        input_std: Normalization std, broadcastable to the selected features.

    Returns:
        float: ``0.0`` when ``action`` is ``None``; otherwise the summed
        predicted variance.
    """
    # If no action is provided, return neutral score
    if action is None:
        return 0.0

    # Force float32: torch.tensor() would otherwise inherit float64 from NumPy
    # inputs and the float32 model would reject the batch.
    ray_obs = torch.tensor(ray_obs, dtype=torch.float32)
    state_obs = torch.tensor(state_obs, dtype=torch.float32)
    action = torch.tensor(action, dtype=torch.float32)

    # Concatenate all observations and the action into a single row vector
    obs_concat = torch.cat([ray_obs.flatten(), state_obs, action]).unsqueeze(0)

    # Keep columns 17..67 (rays minus the first 17), columns 75..95 (state minus
    # its first 7 values) and the last 2 columns (presumably the action).
    x = torch.cat([obs_concat[:, 17:17*4],
                   obs_concat[:, 17*4 + 7: 17*4 + 7*4],
                   obs_concat[:, -2:]], dim=1)

    # Normalize input with training statistics
    x = (x - input_mean) / input_std

    # Forward pass: only the predictive variance is used (no gradient needed)
    with torch.no_grad():
        _, var = prob_model(x)

    # Return scalar uncertainty score
    return float(var[0].sum())


# Restore the probabilistic world model from its checkpoint bundle.
# NOTE: weights_only=False unpickles arbitrary objects -- only load trusted files.
prob_method = torch.load('./u_e_test/prob_world_method.pth', weights_only=False)
prob_world = ProbabilisticWorldModel(**prob_method['model_args'])
prob_world.load_state_dict(prob_method['model_parameters'])
prob_world.eval()


def _prob_world_uf(ray_obs, state_obs, action):
    """Adapt prob_world_score to the (ray_obs, state_obs, action) registry signature."""
    return prob_world_score(
        ray_obs, state_obs, action,
        prob_world,
        prob_method['input_mean'], prob_method['input_std']
    )


# Expose the scorer through the UF method registry
uf_methods['prob_world_model'] = _prob_world_uf


In [None]:
# Monte Carlo Dropout World Model

def mcd_world_score(ray_obs, state_obs,
                    action,
                    mcd_model,
                    input_mean, input_std,
                    n_samples=20):
    """Score an (observation, action) pair with the MC-Dropout world model.

    The ray observation (flattened), the state vector and the action are
    concatenated into one row. A fixed subset of its columns is normalized and
    passed to ``mcd_model.predict``; the second element of its result is
    treated as the variance and summed over the first batch row.

    Args:
        ray_obs: Array-like ray observation; flattened before use.
            The column slicing assumes 4 stacks of 17 rays -- TODO confirm.
        state_obs: 1-D array-like state vector.
            The column slicing assumes 4 stacks of 7 values -- TODO confirm.
        action: 1-D array-like action, or ``None``.
        mcd_model: Object exposing ``predict(x, n_samples=...)`` that returns a
            3-tuple whose second element is the variance.
        input_mean: Normalization mean, broadcastable to the selected features.
        input_std: Normalization std, broadcastable to the selected features.
        n_samples: Value forwarded to ``predict`` (presumably the number of
            stochastic forward passes). Defaults to 20.

    Returns:
        float: ``0.0`` when ``action`` is ``None``; otherwise the summed variance.
    """
    # If no action is provided, return neutral score
    if action is None:
        return 0.0

    # Force float32: torch.tensor() would otherwise inherit float64 from NumPy
    # inputs and the float32 model would reject the batch.
    ray_obs = torch.tensor(ray_obs, dtype=torch.float32)
    state_obs = torch.tensor(state_obs, dtype=torch.float32)
    action = torch.tensor(action, dtype=torch.float32)

    # Concatenate observations and action into a single row vector
    obs_concat = torch.cat([ray_obs.flatten(), state_obs, action]).unsqueeze(0)

    # Keep columns 17..67 (rays minus the first 17), columns 75..95 (state minus
    # its first 7 values) and the last 2 columns (presumably the action).
    x = torch.cat([obs_concat[:, 17:17*4],
                   obs_concat[:, 17*4 + 7: 17*4 + 7*4],
                   obs_concat[:, -2:]], dim=1)

    # Normalize input with training statistics
    x = (x - input_mean) / input_std

    # NOTE(review): callers put the model in eval() mode; predict() is assumed
    # to re-enable dropout internally -- confirm in MCDropoutWorldModel.
    with torch.no_grad():
        _, var, _ = mcd_model.predict(x, n_samples=n_samples)

    # Return scalar uncertainty score
    return float(var[0].sum())


# Restore the Monte Carlo Dropout world model from its checkpoint bundle.
# NOTE: weights_only=False unpickles arbitrary objects -- only load trusted files.
mcd_method = torch.load('./u_e_test/mcd_world_method.pth', weights_only=False)
mcd_world = MCDropoutWorldModel(**mcd_method['model_args'])
mcd_world.load_state_dict(mcd_method['model_parameters'])
mcd_world.eval()


def _mcd_world_uf(ray_obs, state_obs, action):
    """Adapt mcd_world_score to the (ray_obs, state_obs, action) registry signature."""
    return mcd_world_score(
        ray_obs, state_obs, action,
        mcd_world,
        mcd_method['input_mean'], mcd_method['input_std']
    )


# Expose the scorer through the UF method registry
uf_methods['mcd_world_model'] = _mcd_world_uf


In [None]:
# Q-network Ensemble

def qnet_ensemble_score(ray_obs, state_obs,
                        action,
                        qnet_ens):
    """Score an (observation, action) pair by Q-ensemble disagreement.

    Every member of ``qnet_ens`` is evaluated on the same single-sample batch
    and the (unbiased) variance of all returned values is used as the score.

    Args:
        ray_obs: Array-like ray observation (a batch dimension is prepended).
        state_obs: Array-like state vector (a batch dimension is prepended).
        action: Array-like action (a batch dimension is prepended), or ``None``.
        qnet_ens: Iterable of callables ``q(ray, state, action)`` returning a tensor.

    Returns:
        float: ``0.0`` when ``action`` is ``None`` (same convention as the
        world-model scorers); otherwise the variance across ensemble outputs.
    """
    # Without an action there is nothing to evaluate: return the neutral score
    if action is None:
        return 0.0

    # Build float32 tensors directly and add the batch axis with unsqueeze();
    # wrapping the arrays in a Python list takes the slow list-of-ndarray
    # conversion path and inherits float64 from NumPy inputs.
    ray_obs = torch.tensor(ray_obs, dtype=torch.float32).unsqueeze(0)
    state_obs = torch.tensor(state_obs, dtype=torch.float32).unsqueeze(0)
    action = torch.tensor(action, dtype=torch.float32).unsqueeze(0)

    with torch.no_grad():
        # Compute Q-values for each model in the ensemble
        q_vals = torch.stack([
            q(ray_obs, state_obs, action) for q in qnet_ens
        ])

    # Variance across ensemble predictions (disagreement = uncertainty)
    return float(torch.var(q_vals.flatten()))


# Restore the 5-member Q-network ensemble from its checkpoint bundle.
# NOTE: weights_only=False unpickles arbitrary objects -- only load trusted files.
qnet_method = torch.load('./u_e_test/qnet_method.pth', weights_only=False)
qnet_ensemble = [DenseSoftQNetwork(**qnet_method['model_args']) for _ in range(5)]

# Give each member its own weights and switch it to evaluation mode
for member_idx in range(len(qnet_ensemble)):
    member = qnet_ensemble[member_idx]
    member.load_state_dict(qnet_method['model_parameters'][member_idx])
    member.eval()


def _qnet_ensemble_uf(ray_obs, state_obs, action):
    """Adapt qnet_ensemble_score to the (ray_obs, state_obs, action) registry signature."""
    return qnet_ensemble_score(
        ray_obs, state_obs, action,
        qnet_ensemble
    )


# Expose the scorer through the UF method registry
uf_methods['qnet_ensemble'] = _qnet_ensemble_uf


In [None]:
# Random Network Distillation

def rnd_score(ray_obs, state_obs,
              action,
              source_model, predictor_model,
              input_mean, input_std):
    """Random Network Distillation novelty score for one observation.

    The flattened ray observation and the state vector are joined into one
    normalized row, run through both networks, and the summed squared
    difference of their outputs (times 100) is returned. ``action`` is
    accepted for signature compatibility but not used.

    Returns:
        float: Scaled squared prediction error of the first batch row.
    """
    # Build the single-row float32 input from rays followed by state
    features = torch.cat([
        torch.tensor(ray_obs, dtype=torch.float32).flatten(),
        torch.tensor(state_obs, dtype=torch.float32),
    ]).unsqueeze(0)

    # Standardize with the statistics stored alongside the models
    normalized = (features - input_mean) / input_std

    # The predictor is evaluated against the fixed source network
    with torch.no_grad():
        prediction = predictor_model(normalized)
        reference = source_model(normalized)

    # Squared error summed over outputs, scaled by 100
    squared_error = (prediction - reference) ** 2
    return float(squared_error[0].sum() * 100)


# Restore the RND source/predictor pair from its checkpoint bundle.
# NOTE: weights_only=False unpickles arbitrary objects -- only load trusted files.
rnd_method = torch.load('./u_e_test/rnd_method.pth', weights_only=False)
rnd_source, rnd_predictor = (RNDNetwork(**rnd_method['model_args']) for _ in range(2))

# Parameters are stored as [source, predictor]; load each and switch to eval mode
for slot, network in enumerate((rnd_source, rnd_predictor)):
    network.load_state_dict(rnd_method['model_parameters'][slot])
    network.eval()


def _rnd_uf(ray_obs, state_obs, action):
    """Adapt rnd_score to the (ray_obs, state_obs, action) registry signature."""
    return rnd_score(
        ray_obs, state_obs, action,
        rnd_source, rnd_predictor,
        rnd_method['input_mean'], rnd_method['input_std']
    )


# Expose the scorer through the UF method registry
uf_methods['rnd'] = _rnd_uf


In [None]:
# Random baseline
def random_score(ray_obs, state_obs, action):
    """Baseline scorer: ignores its inputs and draws an integer uniformly from 0..100 inclusive."""
    return random.randrange(0, 101)


# Register the random baseline under the 'random' key; it already has the
# (ray_obs, state_obs, action) signature, so no adapter is needed.
uf_methods['random'] = random_score

In [4]:
qnet_method['percentiles']

[(1.0, 0.0043255239725112915),
 (10.0, 0.016104964539408684),
 (20.0, 0.026210537180304527),
 (30.0, 0.03698762133717537),
 (40.0, 0.0494244359433651),
 (50.0, 0.06499352306127548),
 (60.0, 0.08575601130723953),
 (65.0, 0.09925977140665054),
 (70.0, 0.11619290709495544),
 (75.0, 0.13798460364341736),
 (80.0, 0.1697603464126587),
 (85.0, 0.21788763999938965),
 (90.0, 0.32117959856987),
 (95.0, 0.6011334657669067),
 (99.0, 1.9817389249801636)]

In [5]:
mcd_method['percentiles']

[(1.0, 0.26237156987190247),
 (10.0, 0.36038389801979065),
 (20.0, 0.43789416551589966),
 (30.0, 0.5111830234527588),
 (40.0, 0.5836204290390015),
 (50.0, 0.6581688523292542),
 (60.0, 0.7393625974655151),
 (65.0, 0.7873311638832092),
 (70.0, 0.8432287573814392),
 (75.0, 0.9085058569908142),
 (80.0, 0.9850521087646484),
 (85.0, 1.091325044631958),
 (90.0, 1.2421854734420776),
 (95.0, 1.5098358392715454),
 (99.0, 2.2583086490631104)]

In [6]:
rnd_method['percentiles']

[(1.0, 0.13611337542533875),
 (10.0, 0.24117809534072876),
 (20.0, 0.32548123598098755),
 (30.0, 0.40168583393096924),
 (40.0, 0.4800463616847992),
 (50.0, 0.5671817660331726),
 (60.0, 0.6673864722251892),
 (65.0, 0.7246496677398682),
 (70.0, 0.791548490524292),
 (75.0, 0.8710367679595947),
 (80.0, 0.9697459936141968),
 (85.0, 1.1040456295013428),
 (90.0, 1.3025723695755005),
 (95.0, 1.670838475227356),
 (99.0, 2.764556884765625)]

In [7]:
prob_method['percentiles']

[(1.0, 0.07047063857316971),
 (10.0, 0.3038383424282074),
 (20.0, 0.5405747294425964),
 (30.0, 0.7670212388038635),
 (40.0, 1.011854648590088),
 (50.0, 1.281114935874939),
 (60.0, 1.6087371110916138),
 (65.0, 1.818676471710205),
 (70.0, 2.0656540393829346),
 (75.0, 2.4065101146698),
 (80.0, 2.8838393688201904),
 (85.0, 3.568255662918091),
 (90.0, 4.703351020812988),
 (95.0, 7.431720733642578),
 (99.0, 15.704105377197266)]

# Testing Function

In [None]:
def test(CONFIG_DICT,
         env, env_info, env_debug,
         filter_methods,
         actor, device):
    """Roll out ``actor`` in the environment and collect per-episode statistics.

    Each loop iteration steps the environment, then for every reported agent
    either finalizes its episode (terminated flag set) or selects and applies
    an action. A new policy action is computed whenever the agent's frame
    counter reaches ``decision_frame_period``; otherwise the last action is
    reused. When the CBF is enabled, its corrected action replaces the policy
    action only if the CBF activated AND the agent's uncertainty-filter flag
    is set. That flag starts as True and is only updated when the filter is
    enabled, so with the filter disabled the CBF always takes over on activation.

    Args:
        CONFIG_DICT: Run configuration. Keys read here: 'tot_episodes',
            'decision_frame_period', 'frame_per_step', 'var_scale', 'cbf',
            'uncertainty_filter', 'send_debug_action', 'print_interval',
            'accumulate_data', 'message_queue_len_error'.
        env: Environment exposing ``step()`` and ``set_action_for_agent()``.
        env_info: Side channel providing ``settings`` and ``msg_queue``.
        env_debug: Side channel exposing ``send_agent_action_debug``.
        filter_methods: Mapping from method name to a callable
            ``fn(ray_obs, state_obs, action)`` returning an uncertainty score.
        actor: Policy exposing ``get_action(ray_batch, state_batch, var_scale)``.
        device: Torch device the actor inputs are moved to.

    Returns:
        tuple: ``(stats, dataset)`` -- the list of per-episode statistics and,
        when 'accumulate_data' is set, one list of flattened
        ray + state + action rows per episode.

    Raises:
        AssertionError: If ``env_info.msg_queue`` or the list of terminated
            episodes grows beyond 'message_queue_len_error'.
    """
    # Precompute ray angles for CBF checks
    ray_angles_rad = generate_angles_rad(
        env_info.settings['ray_sensor_settings']['rays_per_direction'],
        env_info.settings['ray_sensor_settings']['max_ray_degrees']
    )

    current_episode = 1
    cumulative_obs = {}          # per-agent memory (obs, action, uncertainty info)
    running_episodes = {}        # active episodes data
    terminated_episodes = []     # finished episodes
    stats = []                   # episode statistics
    dataset = []                 # collected dataset

    while current_episode <= CONFIG_DICT['tot_episodes']:

        env.step()
        obs = collect_data_after_step(env, env_info)

        for agent_id in obs:
            agent_obs = obs[agent_id]

            # Handle terminated agents (slot 4 of the observation is the termination flag)
            if agent_obs[4] == 1:
                if agent_id in cumulative_obs:
                    # Remove agent from active lists and finalize episode
                    del cumulative_obs[agent_id]
                    terminated_episodes.append(running_episodes[agent_id])
                    del running_episodes[agent_id]
                else:
                    # Agent killed before its first decision: record an empty episode
                    terminated_episodes.append([])
                    assert agent_id not in running_episodes and agent_id not in cumulative_obs

            else:
                actual_ray_obs = agent_obs[0]
                actual_state_obs = agent_obs[1]

                # Initialize new agent entry
                if agent_id not in cumulative_obs:
                    cumulative_obs[agent_id] = [
                        CONFIG_DICT['decision_frame_period'], # frame counter; starts "due" so the first step decides
                        None,   # last ray obs
                        None,   # last state obs
                        None,   # last action taken
                        0.0,    # last uncertainty estimate
                        True,   # last UF activation
                    ]

                # Time to decide an action
                if cumulative_obs[agent_id][0] >= CONFIG_DICT['decision_frame_period']:
                    cumulative_obs[agent_id][0] = 0

                    # Update ray observations with frame stacking:
                    # drop the oldest row and append the newest one
                    if cumulative_obs[agent_id][1] is None:
                        cumulative_obs[agent_id][1] = actual_ray_obs
                        cumulative_ray_obs = actual_ray_obs
                    else:
                        cumulative_ray_obs = cumulative_obs[agent_id][1][1:, :]
                        cumulative_ray_obs = np.concatenate([cumulative_ray_obs, actual_ray_obs[-1:, :]])

                    # Update state observations with temporal stacking:
                    # drop the oldest observation_size values and append the newest ones
                    if cumulative_obs[agent_id][2] is None:
                        cumulative_obs[agent_id][2] = actual_state_obs
                        cumulative_state_obs = actual_state_obs
                    else:
                        obs_size = env_info.settings['behavior_parameters_settings']['observation_size']
                        cumulative_state_obs = cumulative_obs[agent_id][2][obs_size:]
                        cumulative_state_obs = np.concatenate([cumulative_state_obs, actual_state_obs[-obs_size:]])

                    # Policy action from actor. Inputs are built as float32
                    # tensors with an explicit batch axis (avoids the slow
                    # list-of-ndarray conversion); no autograd graph is needed.
                    with torch.no_grad():
                        action, _, _, _ = actor.get_action(
                            torch.tensor(cumulative_ray_obs, dtype=torch.float32).unsqueeze(0).to(device),
                            torch.tensor(cumulative_state_obs, dtype=torch.float32).unsqueeze(0).to(device),
                            CONFIG_DICT['var_scale']
                        )
                    action = action[0].detach().cpu().numpy()

                    # Uncertainty filter (optional)
                    if CONFIG_DICT['uncertainty_filter']['enabled']:
                        uncertainty_estimate = filter_methods[CONFIG_DICT['uncertainty_filter']['method']](
                            cumulative_ray_obs,
                            cumulative_state_obs,
                            action
                        )
                        cumulative_obs[agent_id][4] = uncertainty_estimate
                        cumulative_obs[agent_id][5] = uncertainty_estimate > CONFIG_DICT['uncertainty_filter']['threshold']

                    # Update agent memory
                    cumulative_obs[agent_id][1] = cumulative_ray_obs
                    cumulative_obs[agent_id][2] = cumulative_state_obs
                    cumulative_obs[agent_id][3] = action

                    # Start new episode if not already tracked
                    if agent_id not in running_episodes:
                        running_episodes[agent_id] = []
                    running_episodes[agent_id].append({
                        'ray': cumulative_ray_obs,
                        'state': cumulative_state_obs,
                        'u_e': cumulative_obs[agent_id][4],
                        'uf_activation': cumulative_obs[agent_id][5],
                        'action': action,
                        'inner_steps': []
                    })

                # Use last predicted action by default
                policy_action = cumulative_obs[agent_id][3]

                # Control Barrier Function (CBF) correction
                cbf_action = np.zeros(2)
                if CONFIG_DICT['cbf']['enabled']:
                    cbf_action = CBF_from_obs(
                        actual_ray_obs[-1], policy_action, env_info,
                        CONFIG_DICT['cbf']['d_safe'],
                        CONFIG_DICT['cbf']['alpha'],
                        CONFIG_DICT['cbf']['d_safe_mul'],
                        ray_angles_rad
                    )
                    # Ensure minimum forward velocity
                    if policy_action[0] > CONFIG_DICT['cbf']['min_forward']:
                        cbf_action[0] = max(CONFIG_DICT['cbf']['min_forward'], cbf_action[0])
                    else:
                        cbf_action[0] = max(policy_action[0], cbf_action[0])

                # Check if CBF activated.
                # NOTE: with the CBF disabled cbf_action stays all-zero, so the
                # recorded norm below is then just the norm of the policy action.
                cbf_activation = CONFIG_DICT['cbf']['enabled'] and np.linalg.norm(cbf_action - policy_action) > 0.0001
                running_episodes[agent_id][-1]['inner_steps'].append([np.linalg.norm(cbf_action - policy_action), cbf_activation])

                # Final action selection (UF + CBF logic)
                final_action = policy_action
                if cumulative_obs[agent_id][5] and cbf_activation:
                    final_action = cbf_action

                # Debug visualization (optional)
                if CONFIG_DICT['send_debug_action']:
                    env_debug.send_agent_action_debug(
                        final_action[0], final_action[1],
                        policy_action[0], policy_action[1],
                        cbf_activation,
                        cbf_action[0], cbf_action[1],
                        cumulative_obs[agent_id][5],
                        CONFIG_DICT['uncertainty_filter']['threshold'],
                        cumulative_obs[agent_id][4]
                    )

                # Apply final action to environment
                a = ActionTuple(continuous=np.array([final_action]))
                env.set_action_for_agent(
                    env_info.settings['behavior_parameters_settings']['behavior_name'], agent_id, a
                )

                # Increment frame counter
                cumulative_obs[agent_id][0] += CONFIG_DICT['frame_per_step']

        # Handle finished episodes: only proceed once every terminated episode
        # has a matching message in the side-channel queue
        if len(env_info.msg_queue) == len(terminated_episodes) and len(terminated_episodes) > 0:
            if len(terminated_episodes) == 1:
                t_msg = env_info.msg_queue.pop()
                t_episode = terminated_episodes.pop()

                if not t_episode:
                    # Empty episodes are reported but do not count towards tot_episodes
                    print(current_episode, '- agent killed too early, step', t_msg['length'])
                else:
                    stats.append(extract_stats(t_episode, t_msg, CONFIG_DICT))

                    if current_episode % CONFIG_DICT['print_interval'] == 0:
                        print_stats_light(stats, CONFIG_DICT['tot_episodes'])

                    current_episode += 1

                    # Save data if required
                    if CONFIG_DICT['accumulate_data']:
                        dataset.append([
                            list(element['ray'].flatten()) + list(element['state']) + list(element['action'])
                            for element in t_episode
                        ])

            else:
                # Several episodes ended together: messages cannot be matched to
                # episodes here, so all of them are discarded
                print(current_episode, '- sovrapposition, deleting', len(terminated_episodes), 'episodes')
                terminated_episodes = []
                env_info.msg_queue = []

        # Safety check: queue should not grow indefinitely
        if len(env_info.msg_queue) > CONFIG_DICT['message_queue_len_error'] or len(terminated_episodes) > CONFIG_DICT['message_queue_len_error']:
            print('ERRORE')
            raise AssertionError('Unexpected queue growth')

    return stats, dataset


# Start Testing Code

In [None]:
# DEBUG DICT
# Configuration for a single interactive/debug run of test().
CONFIG_DICT = {

    'test_name': 'Test',
    
    'send_debug_action': True,  # Only 1 agent supported
    'accumulate_data': False,   # when True, test() collects ray+state+action rows per episode
    'save_stats':False,
    'print_interval':25,        # test() prints stats every this many completed episodes
    
    'message_queue_len_error':10,  # test() raises AssertionError beyond this queue length
    'cuda': True,
    'tot_episodes': 10000,         # test() stops after this many counted episodes
    
    'decision_frame_period': 5,    # frame-counter value at which a new policy action is computed
    'frame_per_step': 1,           # amount added to the frame counter after every env.step()
    
    'var_scale': 2,                # forwarded as the third argument of actor.get_action
    
    'cbf': {
        'd_safe': 1.25, 
        'd_safe_mul': 2, 
        'alpha': 5, 
        'min_forward': 0.05, 
        'enabled':True
    },
    
    'uncertainty_filter': {
        'method': 'mcd_world_model',     # key into uf_methods
        'enabled': True,
        # Equals the 75.0 entry of the mcd_method['percentiles'] output shown earlier
        'threshold': 0.9085058569908142
    },
}

# The frame counter grows by frame_per_step and is compared against
# decision_frame_period; this only prints a warning when the period is not a
# multiple of the step, execution continues regardless.
if CONFIG_DICT['decision_frame_period'] % CONFIG_DICT['frame_per_step'] != 0:
    print("ATTENZIONE ESPLODERA' TUTTO!!")
# Name of the model sub-directory under ./new_models/ to load the actor from
CONFIG_DICT['run_name'] = 'base_2179199' # 'base+wp_2183943'



In [None]:
'''
#TESTING DICT

CONFIG_DICT = {
    'test_name': 'base',
    
    'send_debug_action': False,  # Solo 1 agente supportato
    'accumulate_data': False,
    'save_stats': True,
    'print_interval':25,    
    'message_queue_len_error':10,
    'cuda': True,

    'tot_episodes': 1000,
    
    'decision_frame_period': 5,
    'frame_per_step': 1,
    
    'var_scale': 0.9,
    
    
    #'cbf': {'d_safe': 1.25, 'd_safe_mul': 2,  'alpha': 5, 'min_forward': 0.05, 'enabled':True},
    
    #'uncertainty_filter': {'method': 'mcd_world_model','enabled': True,'threshold': 0.8526}
}

if CONFIG_DICT['decision_frame_period'] % CONFIG_DICT['frame_per_step'] != 0:
    print("ATTENZIONE ESPLODERA' TUTTO!!")
CONFIG_DICT['run_name'] = 'base_2179199' # 'base+wp_2183943'
'''

In [None]:
# Use the GPU only when it is both requested by the config ('cuda' flag,
# defaulting to True when absent) and actually available on this machine.
use_cuda = CONFIG_DICT.get('cuda', True) and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using device: {device}")

In [10]:
# Side channels exchanged with the Unity process: settings/messages and debug drawing
env_info = CustomChannel()
env_debug = DebugSideChannel()

# Environment setup: no executable path is given and the seed is drawn at random
unity_seed = random.randint(-100000, 100000)
env = UnityEnvironment(None, seed=unity_seed, side_channels=[env_info, env_debug])
env.reset()

In [None]:
# Directory holding the saved models of the selected training run
path = './new_models/' + CONFIG_DICT['run_name']

# Observation/action dimensions come from the settings reported by the environment
ray_settings = env_info.settings['ray_sensor_settings']
behavior_settings = env_info.settings['behavior_parameters_settings']

actor = DenseActor(
    (ray_settings['observation_stacks'],
     2 * ray_settings['rays_per_direction'] + 1),
    behavior_settings['observation_size'] * behavior_settings['stacked_vector'],
    behavior_settings['continuous_actions'],
    behavior_settings['min_action'],
    behavior_settings['max_action'],
    [128, 128, 128],
).to(device)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads when running on CPU.
actor.load_state_dict(torch.load(os.path.join(path, 'actor_best.pth'), map_location=device))
actor.eval()

In [None]:
# Candidate CBF configurations: index 0 disables the CBF, index 1 enables it
# with the same values as CONFIG_DICT['cbf'] in the debug configuration.
cbf_conf = [{'d_safe': 0, 'd_safe_mul': 0, 'alpha': 0, 'min_forward': 0, 'enabled':False},
            {'d_safe': 1.25, 'd_safe_mul': 2, 'alpha': 5, 'min_forward': 0.05, 'enabled':True}]

In [None]:
# DEBUG RUN
# Reset the side channel and the environment, then run with the debug configuration
env_info.reset()
env.reset()

stats, dataset = test(
    CONFIG_DICT=CONFIG_DICT,
    env=env,
    env_info=env_info,
    env_debug=env_debug,
    filter_methods=uf_methods,
    actor=actor,
    device=device,
)

In [None]:
'''
# TESTING RUN

start_from = 'SCW_rnd_20pctl_cbf1'
start = True

uf_thresh = [prob_method['percentiles'], mcd_method['percentiles'], qnet_method['percentiles'], rnd_method['percentiles'],
             random_percentiles] 
uf_names = ['prob_world_model', 'mcd_world_model', 'qnet_ensemble', 'rnd','random']

for j, cbf_c in enumerate(cbf_conf):
    
    if j != 1:
        continue
    
    for i, uf_name in enumerate(uf_names):
        
        if i != 4:
            continue
        
        save_dir = f'./results/MO_{uf_name}'
        base_name = f'MO_{uf_name}'
        unc_prob = []
        percentuali = []
        
        for perc, val in uf_thresh[i]:
            unc_prob.append({'method': uf_name, 'enabled': True, 'threshold': float(val)})
            percentuali.append(perc)
        
        for i, unc_c in enumerate(unc_prob):
            
            CONFIG_DICT['test_name'] = f"{base_name}_{int(percentuali[i])}pctl_cbf{j}"
            
            CONFIG_DICT['uncertainty_filter'] = unc_c
            CONFIG_DICT['cbf'] = cbf_c

            print('Starting', CONFIG_DICT['test_name'], '--', CONFIG_DICT['uncertainty_filter']['threshold'])

            if CONFIG_DICT['test_name'] == start_from:
                start = True
            if start:
                
                done = False
                start_time = time.time()
                while not done:
                    env_info.reset()
                    env.reset()
                
                    try:
                        stats, dataset = test(CONFIG_DICT,
                                            env, env_info, env_debug,
                                            uf_methods,
                                            actor, device)
                        done = True
                    except AssertionError:
                        continue

                duration = time.time() - start_time
                duration_str = str(timedelta(seconds=duration))[:-3]
                
                if CONFIG_DICT['save_stats']:
                    save_stats(stats, env_info.settings, CONFIG_DICT,
                            CONFIG_DICT['test_name'] + f'_{int(time.time()) - 1751796000}',
                            save_dir,
                            duration=duration_str)'''

In [None]:
def convert_all_to_float(obj):
    """Recursively convert a nested structure into JSON-serializable built-ins.

    Dicts, lists and tuples are traversed (tuples become lists). NumPy arrays
    are expanded to nested lists, NumPy/Decimal floating values become
    ``float``, NumPy integers become ``int`` and NumPy booleans become
    ``bool``. Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: convert_all_to_float(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_all_to_float(item) for item in obj]
    if isinstance(obj, np.ndarray):
        # tolist() already yields plain Python scalars
        return obj.tolist()
    if isinstance(obj, (np.floating, Decimal)):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    return obj


# Save dataset to JSON if accumulation is enabled
if CONFIG_DICT['accumulate_data']:

    # File name carries the seconds elapsed since a fixed reference timestamp
    with open(f'./results/test_{int(time.time()) - 1751796000}.json', 'w') as file:
        json.dump(convert_all_to_float(dataset), file)


# Close Environment

In [None]:
# Close the Unity environment created in the setup cell
env.close()