## Imports

---

In [1]:
import sys
import os
import time
import pickle
import io

# Hydra
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf

# WandB / Logging
import wandb

# BenchMARL
import benchmarl.models
from benchmarl.algorithms import *
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment
from benchmarl.hydra_config import (
    load_algorithm_config_from_hydra,
    load_experiment_config_from_hydra,
    load_task_config_from_hydra,
    load_model_config_from_hydra,
)
from benchmarl.experiment.callback import Callback

# Het-Control
from het_control.callback import *
from het_control.environments.vmas import render_callback
from het_control.models.het_control_mlp_empirical import (
    HetControlMlpEmpiricalConfig,
    HetControlMlpEmpirical,
)
from het_control.callbacks.sndESLogger import TrajectorySNDLoggerCallback
from het_control.callbacks.utils import *
from het_control.snd import compute_behavioral_distance

# Scientific
import numpy as np
import torch
from tensordict import TensorDict, TensorDictBase
from typing import List, Dict, Any, Callable, Union

# Visualization
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image



In [2]:
unique_id = f"AD2C_Eval_{int(time.time())}"

# Stash the genuine wandb.init exactly once per kernel. If this cell is re-run,
# the guard stops the patched wrapper from being saved as the "original", which
# would make forced_wandb_init call itself forever.
_wandb_init_saved = hasattr(wandb, "_custom_orig_init")
if not _wandb_init_saved:
    print("Saving original WandB init function...")
    wandb._custom_orig_init = wandb.init


def forced_wandb_init(*args, **kwargs):
    """
    Drop-in replacement for ``wandb.init`` that pins the run id and name to ``unique_id``.

    ``id``, ``name``, ``resume`` and ``reinit`` are overwritten whatever the caller
    passed. Every other argument is forwarded unchanged to the saved original.

    NOTE(review): ``unique_id`` is regenerated only when this cell is re-run.
    Re-running just the experiment cell re-uses the same id. With
    ``resume="allow"``, wandb then resumes that earlier run rather than
    creating a fresh one.
    """
    print(f"\n--- INTERCEPTING WANDB INIT ---")
    kwargs.update(id=unique_id, name=unique_id, resume="allow", reinit=True)
    print(f"Forced ID: {unique_id}")
    print(f"-------------------------------\n")
    # Delegate to the SAVED original, never to the (patched) wandb.init.
    return wandb._custom_orig_init(*args, **kwargs)


# Apply the patch
wandb.init = forced_wandb_init

Saving original WandB init function...


## SND Visualization Plot
---

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import wandb
import torch
import networkx as nx  # Required for the Graph Visualizer

class SNDHeatmapVisualizer:
    """
    Renders a square SND (pairwise behavioral distance) matrix as an annotated
    heatmap and returns it as ``{key_name: wandb.Image}``.

    Args:
        key_name: Logging key for the image.
        vmin, vmax: Colour-scale limits passed to ``imshow``. The defaults
            reproduce the original fixed 0-3 range. Raise ``vmax`` if distances
            can exceed 3, otherwise the colours saturate.
    """
    def __init__(self, key_name="Visuals/SND_Heatmap", vmin: float = 0.0, vmax: float = 3.0):
        self.key_name = key_name
        self.vmin = vmin
        self.vmax = vmax

    def generate(self, snd_matrix, step_count):
        """
        Build the heatmap figure.

        Args:
            snd_matrix: square 2-D NumPy array (as produced by
                ``SNDVisualizationManager._prepare_matrix``).
            step_count: iteration number shown in the title.

        Returns:
            dict mapping ``self.key_name`` to a ``wandb.Image``.
        """
        n_agents = snd_matrix.shape[0]
        agent_labels = [f"Agent {i+1}" for i in range(n_agents)]

        # SND = mean of the strict upper triangle (each unordered pair counted once).
        iu = np.triu_indices(n_agents, k=1)
        snd_value = float(np.mean(snd_matrix[iu])) if len(iu[0]) > 0 else 0.0

        fig, ax = plt.subplots(figsize=(6, 5))

        im = ax.imshow(
            snd_matrix,
            cmap="viridis",
            interpolation="nearest",
            vmin=self.vmin, vmax=self.vmax,
        )

        ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}")

        ax.set_xticks(np.arange(n_agents))
        ax.set_yticks(np.arange(n_agents))
        ax.set_xticklabels(agent_labels)
        ax.set_yticklabels(agent_labels)
        plt.setp(ax.get_xticklabels(), rotation=30, ha="right")

        fig.colorbar(im, ax=ax, label="Distance")

        # White text on the dark (low) third of the viridis scale, black elsewhere.
        # With the default 0..3 range this threshold is 1.0.
        text_threshold = self.vmin + (self.vmax - self.vmin) / 3.0
        for i in range(n_agents):
            for j in range(n_agents):
                val = snd_matrix[i, j]
                ax.text(
                    j, i, f"{val:.2f}",
                    ha="center", va="center",
                    color="white" if val < text_threshold else "black",
                    fontsize=9, fontweight="bold",
                )

        plt.tight_layout()
        img = wandb.Image(fig)
        plt.close(fig)  # free the figure; this runs every evaluation
        return {self.key_name: img}


class SNDBarChartVisualizer:
    """
    Renders one bar per unordered agent pair (i < j) showing its SND distance,
    returned as ``{key_name: wandb.Image}``.

    Args:
        key_name: Logging key for the image.
        vmin, vmax: Nominal y-axis range (default 0-3, as before). The upper
            limit is raised automatically when a pair exceeds ``vmax``, so
            tall bars and their value labels are not clipped.
    """
    def __init__(self, key_name="Visuals/SND_BarChart", vmin: float = 0.0, vmax: float = 3.0):
        self.key_name = key_name
        self.vmin = vmin
        self.vmax = vmax

    def generate(self, snd_matrix, step_count):
        """
        Build the bar chart.

        Args:
            snd_matrix: square 2-D NumPy array of pairwise distances.
            step_count: iteration number shown in the title.

        Returns:
            dict with the image, or ``{}`` when there are fewer than two agents.
        """
        n_agents = snd_matrix.shape[0]

        pairs = [(i, j) for i in range(n_agents) for j in range(i + 1, n_agents)]
        if not pairs:
            return {}

        pair_values = [float(snd_matrix[i, j]) for i, j in pairs]
        pair_labels = [f"A{i+1}-A{j+1}" for i, j in pairs]

        snd_value = float(np.mean(pair_values))

        fig, ax = plt.subplots(figsize=(8, 5))
        bars = ax.bar(pair_labels, pair_values, color="teal")

        ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}")
        ax.set_ylabel("Distance")
        # Keep the nominal range, but add ~10% headroom above the tallest bar
        # when it would otherwise overflow the axis.
        y_top = max(self.vmax, max(pair_values) * 1.1)
        ax.set_ylim(self.vmin, y_top)
        ax.tick_params(axis="x", rotation=45)

        ax.bar_label(bars, fmt="%.2f", padding=3)

        plt.tight_layout()
        img = wandb.Image(fig)
        plt.close(fig)
        return {self.key_name: img}


class SNDGraphVisualizer:
    """
    Renders agents as nodes of a complete graph. Each edge is coloured and
    labelled with the pair's SND distance. Returned as ``{key_name: wandb.Image}``.

    Args:
        key_name: Logging key for the image.
        vmin, vmax: Edge colour-scale limits (default 0-3, as before).
    """
    def __init__(self, key_name="Visuals/SND_NetworkGraph", vmin: float = 0.0, vmax: float = 3.0):
        self.key_name = key_name
        self.vmin = vmin
        self.vmax = vmax

    def generate(self, snd_matrix, step_count):
        """
        Build the network-graph figure.

        Args:
            snd_matrix: square 2-D NumPy array of pairwise distances.
            step_count: iteration number shown in the title.

        Returns:
            dict with the image, or ``{}`` when there are fewer than two agents.
        """
        n_agents = snd_matrix.shape[0]

        pairs = [(i, j) for i in range(n_agents) for j in range(i + 1, n_agents)]
        if not pairs:
            return {}

        pair_values = [float(snd_matrix[i, j]) for i, j in pairs]
        snd_value = float(np.mean(pair_values))

        # Draw everything on one explicit Axes rather than relying on pyplot's
        # "current figure" state.
        fig, ax = plt.subplots(figsize=(7, 7))
        G = nx.Graph()

        for (i, j), value in zip(pairs, pair_values):
            G.add_edge(i, j, weight=value)

        pos = nx.spring_layout(G, seed=42)  # fixed seed -> stable layout across steps
        weights = [G[u][v]['weight'] for u, v in G.edges()]

        nx.draw_networkx_nodes(G, pos, node_size=750, node_color='lightblue', ax=ax)

        label_mapping = {i: f"A{i+1}" for i in range(n_agents)}
        nx.draw_networkx_labels(G, pos, labels=label_mapping, font_size=12, font_weight='bold', ax=ax)

        edges = nx.draw_networkx_edges(
            G, pos,
            edge_color=weights,
            edge_cmap=plt.cm.viridis,
            width=2,
            edge_vmin=self.vmin, edge_vmax=self.vmax,
            ax=ax,
        )

        edge_labels = {(i, j): f"{snd_matrix[i, j]:.2f}" for i, j in pairs}
        nx.draw_networkx_edge_labels(
            G, pos, edge_labels=edge_labels,
            font_color='black', font_size=9, font_weight='bold',
            ax=ax,
        )

        fig.colorbar(edges, ax=ax, label='Distance')
        ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}", fontsize=14)
        ax.axis('off')

        img = wandb.Image(fig)
        plt.close(fig)
        return {self.key_name: img}


class SNDVisualizationManager:
    """
    Manages the individual visualizers and handles ALL data cleaning centrally.

    Plotting is best-effort. A failure while preparing the matrix, or inside
    any single visualizer, is reported and skipped, so a plotting problem never
    aborts the training run that calls this from a callback.
    """
    def __init__(self, visualizers=None):
        """
        Args:
            visualizers: optional iterable of objects exposing
                ``generate(snd_matrix, step_count) -> dict``. Defaults to the
                heatmap, bar-chart and network-graph visualizers.
        """
        if visualizers is None:
            visualizers = [
                SNDHeatmapVisualizer(),
                SNDBarChartVisualizer(),
                SNDGraphVisualizer(),
            ]
        self.visualizers = list(visualizers)

    def _prepare_matrix(self, snd_matrix):
        """
        Convert ``snd_matrix`` into a symmetric, zero-diagonal 2-D NumPy array.

        Accepts torch tensors, NumPy arrays or array-likes. Leading dims beyond
        2-D are peeled by taking index 0. A 1-D input of perfect-square length
        is reshaped to a square matrix. Any other 1-D input is treated as a
        column vector. An (N, 1) or (1, N) input broadcasts to (N, N) during
        symmetrisation.

        NOTE(review): the (N, 1) broadcast fills entry (i, j) with
        (v_i + v_j) / 2, which is a heuristic. Confirm what shape
        ``compute_behavioral_distance`` actually returns.

        Raises:
            ValueError: for scalar inputs or 2-D shapes that cannot become square.
        """
        # 1. Convert to NumPy (the input itself is never mutated)
        if hasattr(snd_matrix, "detach"):
            snd_matrix = snd_matrix.detach().cpu().numpy()
        else:
            snd_matrix = np.asarray(snd_matrix)

        if snd_matrix.ndim == 0:
            raise ValueError("SND matrix must be at least 1-D, got a scalar")

        # 2. "Peel" dimensions until we hit 2D: (1, 2, 2) -> (2, 2), (1, 2, 1) -> (2, 1)
        while snd_matrix.ndim > 2:
            snd_matrix = snd_matrix[0]

        # 3. Handle 1D edge case (if squeeze happened upstream)
        if snd_matrix.ndim == 1:
            size = snd_matrix.shape[0]
            n_agents = int(round(np.sqrt(size)))
            if n_agents * n_agents == size:
                snd_matrix = snd_matrix.reshape(n_agents, n_agents)
            else:
                snd_matrix = snd_matrix[:, None]  # column vector (N, 1)

        rows, cols = snd_matrix.shape
        if rows != cols and 1 not in (rows, cols):
            raise ValueError(f"Cannot build a square SND matrix from shape {snd_matrix.shape}")

        # 4. Enforce symmetry. This also broadcasts (N, 1)/(1, N) to (N, N) and
        #    always allocates a new array, so the caller's data is left untouched.
        snd_matrix = (snd_matrix + snd_matrix.T) / 2.0

        # 5. An agent has zero distance to itself.
        np.fill_diagonal(snd_matrix, 0.0)
        return snd_matrix

    def generate_all(self, snd_matrix, step_count):
        """
        Clean ``snd_matrix`` once and run every visualizer on it.

        Returns:
            dict merging every visualizer's output. It is ``{}`` when the
            matrix could not be prepared. Failing visualizers are skipped.
        """
        try:
            clean_matrix = self._prepare_matrix(snd_matrix)
        except Exception as e:
            print(f"Error preparing SND matrix: {e}")
            print(f"Raw input shape: {getattr(snd_matrix, 'shape', None)}")
            return {}

        all_plots = {}
        for visualizer in self.visualizers:
            try:
                all_plots.update(visualizer.generate(clean_matrix, step_count))
            except Exception as e:
                print(f"Error generating {visualizer.__class__.__name__}: {e}")
                print(f"Failed Matrix Shape: {clean_matrix.shape}")
        return all_plots

In [4]:
class SNDVisualizerCallback(Callback):
    """
    Evaluation-time callback. It computes the pairwise SND matrix for the first
    policy group and logs heatmap, bar-chart and graph images through an
    ``SNDVisualizationManager``.
    """
    def __init__(self):
        super().__init__()
        self.control_group = None   # chosen in on_setup (first policy group)
        self.model = None           # HetControl model; None disables the callback
        self.viz_manager = SNDVisualizationManager()

    def on_setup(self):
        """Pick the first policy group and extract its HetControl model, or disable the callback."""
        policies = self.experiment.group_policies
        if not policies:
            print("\nWARNING: No group policies found. SND Visualizer disabled.\n")
            return

        self.control_group = next(iter(policies))
        # get_het_model must be in scope (presumably from the het_control star-imports).
        self.model = get_het_model(policies[self.control_group])

        if self.model is None:
             print(f"\nWARNING: Could not extract HetModel for group '{self.control_group}'. Visualizer disabled.\n")

    def _get_agent_actions_for_rollout(self, rollout):
        """Return one action tensor per agent by running the model's per-agent forward pass."""
        obs = rollout.get((self.control_group, "observation"))
        batch_dims = obs.shape[:-1]

        def _act(agent_idx):
            # A fresh TensorDict per agent, as in the per-agent loop this replaces.
            # Re-using one across agents is not obviously safe, since _forward may write into it.
            td = TensorDict({(self.control_group, "observation"): obs}, batch_size=batch_dims)
            out = self.model._forward(td, agent_index=agent_idx, compute_estimate=False)
            return out.get(self.model.out_key)

        return [_act(idx) for idx in range(self.model.n_agents)]

    def on_evaluation_end(self, rollouts: List[TensorDict]):
        """Compute the SND matrix from the first evaluation rollout and log the plots."""
        if self.model is None:
            return

        snd_matrix = None
        with torch.no_grad():
            # Only the first rollout is visualised; the loop exits after one pass.
            for rollout in rollouts:
                actions = self._get_agent_actions_for_rollout(rollout)
                distances = compute_behavioral_distance(actions, just_mean=False)
                if distances.ndim > 2:
                    distances = distances.mean(dim=0)  # average over the leading dimension
                snd_matrix = distances.cpu().numpy()
                break

        if snd_matrix is None:
            return

        step = self.experiment.n_iters_performed
        visual_logs = self.viz_manager.generate_all(snd_matrix=snd_matrix, step_count=step)
        self.experiment.logger.log(dict(visual_logs), step=step)

## Extremum Seeking Controller (ESC)
---

In [5]:
import numpy as np
import torch
from typing import Dict, Any, List, Tuple
from benchmarl.experiment.callback import Callback
from het_control.models.het_control_mlp_empirical import HetControlMlpEmpirical
from tensordict import TensorDictBase

# --- 1. Signal Processing Components ---

class FirstOrderLPF:
    """
    Discrete first-order low-pass filter.

    y[k] = a * y[k-1] + (1 - a) * u[k], with a = exp(-T * wc), where T is the
    sampling period and wc the cutoff frequency.
    """
    def __init__(self, sampling_period: float, cutoff_freq: float, initial_value: float = 0.0):
        self.alpha = np.exp(-(sampling_period * cutoff_freq))
        self.prev_val = float(initial_value)  # y[k-1]

    def apply(self, input_val: float) -> float:
        """Feed one sample and return the filtered value."""
        self.prev_val = self.alpha * self.prev_val + (1 - self.alpha) * input_val
        return self.prev_val

class FirstOrderHPF:
    """
    First-order high-pass filter s / (s + wc), discretised with the Tustin
    (bilinear) transform.

    Difference equation:
        (2 + T*wc) * y[k] + (T*wc - 2) * y[k-1] = 2 * (u[k] - u[k-1])
    """
    def __init__(self, sampling_period: float, cutoff_freq: float, initial_input: float = 0.0, initial_output: float = 0.0):
        wc_dt = sampling_period * cutoff_freq
        self.a1 = wc_dt + 2.0   # coefficient of y[k]
        self.b1 = wc_dt - 2.0   # coefficient of y[k-1]

        self.u_prev = float(initial_input)    # u[k-1]
        self.y_prev = float(initial_output)   # y[k-1]

    def apply(self, input_val: float) -> float:
        """Feed one sample and return the high-passed value."""
        inv_a1 = 1.0 / self.a1
        y = inv_a1 * (2.0 * (input_val - self.u_prev) - self.b1 * self.y_prev)
        self.u_prev, self.y_prev = input_val, y
        return y

class PhaseGenerator:
    """
    Handles the dither signal generation (``magnitude * sin(wt)``).

    The phase ``wt`` advances by ``frequency * sampling_period`` each step and
    is wrapped with a modulo into [0, 2*pi]. It therefore stays bounded for any
    step size, including steps larger than 2*pi and negative frequencies. A
    single subtraction would not achieve this.
    """
    def __init__(self, frequency: float, magnitude: float, sampling_period: float):
        self.freq = frequency        # rad/s
        self.mag = magnitude
        self.dt = sampling_period
        self.wt = 0.0

    def step(self) -> Tuple[float, float]:
        """Return ``(carrier, dither)`` for the current phase, then advance the phase."""
        carrier = np.sin(self.wt)
        dither = self.mag * carrier

        # Python's float % with a positive divisor returns a non-negative result.
        self.wt = (self.wt + self.freq * self.dt) % (2 * np.pi)

        return carrier, dither

class GradientEstimator:
    """
    Demodulates the high-passed cost with the dither carrier, then low-pass
    filters the product to obtain a (scaled) gradient estimate.

    Optionally ("adapter") divides the estimate by an exponentially weighted
    RMS of the raw estimate. This gives the returned value a unit scale.

    State exposed for logging:
        gradient_raw: last LPF output (un-normalised gradient estimate).
        grad_mag:     current RMS estimate, sqrt(m2).
    """
    def __init__(self, lpf_cutoff: float, sampling_period: float, use_adapter: bool = True,
                 forgetting_factor: float = 0.8, epsilon: float = 1e-8):
        self.lpf = FirstOrderLPF(sampling_period, lpf_cutoff)
        self.use_adapter = use_adapter

        # Adaptation state
        self.m2 = 0.0                   # EWMA of gradient_raw ** 2 (second moment)
        self.b2 = forgetting_factor     # weight on the previous m2
        self.epsilon = epsilon          # guards the division when m2 is ~0
        self.grad_mag = 0.0
        self.gradient_raw = 0.0

    def estimate(self, high_passed_cost: float, carrier_signal: float) -> float:
        """
        Run one estimation step.

        Args:
            high_passed_cost: HPF output of the cost signal.
            carrier_signal: sin(wt) of the current dither phase.

        Returns:
            The gradient estimate, RMS-normalised when ``use_adapter`` is set.
        """
        # 1. Demodulate
        demodulated = high_passed_cost * carrier_signal

        # 2. Filter to get the DC component (gradient). Store it on the
        #    instance so callers can log the real value instead of the initial 0.0.
        self.gradient_raw = self.lpf.apply(demodulated)

        # 3. Adaptation (normalisation)
        self.m2 = self.b2 * self.m2 + (1 - self.b2) * (self.gradient_raw ** 2)
        self.grad_mag = np.sqrt(self.m2)

        if self.use_adapter:
            return self.gradient_raw / (self.grad_mag + self.epsilon)

        return self.gradient_raw

class Integrator:
    """
    Integrates the estimated gradient to update the parameter theta.

        theta = initial_value + sum(gain * gradient * dt),  clamped to >= min_val

    Supports a two-level gain schedule and anti-windup. When the clamp
    engages, the integral is reset so that theta sits exactly on ``min_val``.
    It can then leave the bound as soon as the gradient reverses.

    Args:
        base_gain: integrator gain used by default.
        initial_value: theta at integral == 0.
        min_val: lower saturation limit for theta.
        dt: integration step.
        high_gain: gain used when scheduling is enabled and
            ``gradient_magnitude > threshold``.
        threshold: gradient-magnitude threshold for switching to ``high_gain``.
    """
    def __init__(self, base_gain: float, initial_value: float, min_val: float, dt: float,
                 high_gain: float = -0.025, threshold: float = 0.2):
        self.gain = base_gain
        self.integral = 0.0
        self.initial_value = initial_value
        self.min_val = min_val
        self.dt = dt

        # Gain scheduling parameters
        self.high_gain = high_gain
        self.threshold = threshold
        self.use_gain_scheduling = False  # Can be enabled via setter

    def set_gain_scheduling(self, enabled: bool):
        """Enable or disable switching to ``high_gain`` on large gradient magnitudes."""
        self.use_gain_scheduling = enabled

    def step(self, gradient: float, gradient_magnitude: float = 0.0) -> float:
        """Integrate one gradient sample and return the saturated setpoint."""
        # 1. Determine gain
        current_gain = self.gain
        if self.use_gain_scheduling and gradient_magnitude > self.threshold:
            current_gain = self.high_gain

        # 2. Integrate: theta = integral(gain * grad)
        self.integral += current_gain * gradient * self.dt

        # 3. Raw setpoint
        setpoint_raw = self.initial_value + self.integral

        # 4. Saturate + anti-windup. Test the RAW value; the clamped one can never
        #    be below min_val, so a post-clamp check would never fire.
        if setpoint_raw < self.min_val:
            self.integral = self.min_val - self.initial_value
            return self.min_val

        return setpoint_raw

In [6]:
class ESController:
    """
    Extremum-seeking controller. It chains a high-pass filter on the cost
    (FirstOrderHPF), a sinusoidal dither (PhaseGenerator), a demodulating
    gradient estimator (GradientEstimator) and a saturating integrator
    (Integrator).

    Each ``update`` consumes one cost sample J. It returns the next command
    theta_hat + dither, together with intermediate signals for logging.
    """
    def __init__(self, 
                 sampling_period: float,
                 perturb_freq: float,
                 perturb_mag: float,
                 integrator_gain: float,
                 initial_val: float,
                 hpf_cutoff: float,
                 lpf_cutoff: float,
                 use_adapter: bool):

        self.hpf = FirstOrderHPF(sampling_period, hpf_cutoff)
        self.perturbation = PhaseGenerator(perturb_freq, perturb_mag, sampling_period)
        self.grad_estimator = GradientEstimator(lpf_cutoff, sampling_period, use_adapter)
        self.integrator = Integrator(integrator_gain, initial_val, min_val=0.0, dt=sampling_period)

        # Gain scheduling is switched on exactly when the RMS adapter is on.
        self.integrator.set_gain_scheduling(use_adapter)

    def update(self, cost: float) -> Dict[str, float]:
        """
        Advance the ESC loop by one sample.

        Returns:
            dict with the control output and the intermediate values
            (setpoint, dither, gradient, HPF output, RMS magnitude, LPF output).
        """
        filtered_cost = self.hpf.apply(cost)
        carrier, dither = self.perturbation.step()

        estimator = self.grad_estimator
        gradient = estimator.estimate(filtered_cost, carrier)
        magnitude = estimator.grad_mag  # the integrator's gain schedule reads this

        theta_hat = self.integrator.step(gradient, magnitude)

        return dict(
            output=theta_hat + dither,
            theta_hat=theta_hat,
            dither=dither,
            gradient=gradient,
            cost_hpf=filtered_cost,
            grad_mag=magnitude,
            lpf_output=estimator.gradient_raw,
        )

In [7]:
class ESCallback(Callback):
    """
    BenchMARL Callback that wraps the modular ESController.

    After each evaluation it:
      1. averages the summed episode reward of ``control_group`` over rollouts,
         clips it to ``[-reward_clip, reward_clip]`` and negates it into a cost;
      2. steps the extremum-seeking controller once;
      3. writes ``max(0, output)`` into ``model.desired_snd`` and logs the ESC signals.

    Args:
        control_group: group name whose policy/model is controlled.
        initial_snd: starting diversity target (also the integrator's initial value).
        dither_magnitude, dither_frequency_rad_s: sinusoidal perturbation.
        integral_gain: ESC integrator gain.
        high_pass_cutoff_rad_s, low_pass_cutoff_rad_s: filter cutoffs.
        use_adapter: enable RMS normalisation and gain scheduling.
        sampling_period: ESC time step per evaluation.
        reward_clip: symmetric clip applied to the mean reward (default 4.0, as before).
    """
    def __init__(
        self,
        control_group: str,
        initial_snd: float,
        dither_magnitude: float,
        dither_frequency_rad_s: float,
        integral_gain: float,
        high_pass_cutoff_rad_s: float,
        low_pass_cutoff_rad_s: float,
        use_adapter: bool = True,
        sampling_period: float = 1.0,
        reward_clip: float = 4.0,
    ):
        super().__init__()
        self.control_group = control_group
        self.initial_snd = initial_snd
        self.reward_clip = reward_clip

        # Save params to initialise the controller in on_setup
        self.esc_config = {
            "sampling_period": sampling_period,
            "perturb_freq": dither_frequency_rad_s,
            "perturb_mag": dither_magnitude,
            "integrator_gain": integral_gain,
            "initial_val": initial_snd,
            "hpf_cutoff": high_pass_cutoff_rad_s,
            "lpf_cutoff": low_pass_cutoff_rad_s,
            "use_adapter": use_adapter
        }

        self.model = None
        self.controller = None

    def on_setup(self):
        """Initialise the controller if the group uses a HetControlMlpEmpirical model."""
        if self.control_group not in self.experiment.group_policies:
            print(f"WARNING: Group '{self.control_group}' not found.")
            return

        policy = self.experiment.group_policies[self.control_group]
        # get_het_model must be in scope (presumably from the het_control star-imports).
        self.model = get_het_model(policy)

        if isinstance(self.model, HetControlMlpEmpirical):
            print(f"✅ ESC Initialized for '{self.control_group}'")
            self.controller = ESController(**self.esc_config)
            self.model.desired_snd[:] = float(self.initial_snd)
        else:
            print(f"WARNING: Compatible model not found for '{self.control_group}'.")

    def on_evaluation_end(self, rollouts: List[TensorDictBase]):
        """Run one ESC step from the evaluation rewards and push the new SND target to the model."""
        if self.model is None or self.controller is None:
            return

        # 1. Cost = -(clipped mean episode reward). Rollouts without the reward key
        #    are skipped rather than counted as 0.0, which would bias the cost.
        reward_key = ('next', self.control_group, 'reward')
        episode_rewards = []
        with torch.no_grad():
            for r in rollouts:
                if reward_key in r.keys(include_nested=True):
                    episode_rewards.append(r.get(reward_key).sum().item())

        if not episode_rewards:
            print(f"WARNING: No {reward_key} entries in evaluation rollouts; ESC step skipped.")
            return

        reward_mean = float(np.clip(np.mean(episode_rewards), -self.reward_clip, self.reward_clip))
        cost = -reward_mean

        # 2. Update controller
        results = self.controller.update(cost)

        # 3. Apply the new parameter to the PyTorch model. The dither can push the
        #    raw output below zero, so clamp before writing.
        previous_snd = self.model.desired_snd.item()
        final_snd_val = max(0.0, results["output"])
        self.model.desired_snd[:] = float(final_snd_val)

        print(f"[ESC] Updated SND: {final_snd_val:.4f} (Reward: {reward_mean:.3f})")

        # 4. Log metrics. update_step is the change actually applied to the model;
        #    the raw controller output is logged separately as diversity_output.
        logs = {
            "esc/mean_reward": reward_mean,
            "esc/cost": cost,
            "esc/diversity_output": results["output"],
            "esc/diversity_setpoint": results["theta_hat"],
            "esc/gradient_estimate": results["gradient"],
            "esc/hpf_output": results["cost_hpf"],
            "esc/lpf_output": results["lpf_output"],
            "esc/m2_sqrt": results["grad_mag"],
            "esc/update_step": final_snd_val - previous_snd
        }
        self.experiment.logger.log(logs, step=self.experiment.n_iters_performed)

## Env Setup
---

In [8]:
# 1. EXPERIMENT LOGIC

def setup(task_name):
    """
    Register the HetControl empirical model config with BenchMARL. For
    ``vmas/navigation``, also install the case-study render callback on VmasTask.
    """
    benchmarl.models.model_config_registry.update(
        {"hetcontrolmlpempirical": HetControlMlpEmpiricalConfig}
    )
    if task_name != "vmas/navigation":
        return
    # Set the render callback for the navigation case study
    VmasTask.render_callback = render_callback

def _build_callbacks(cfg: DictConfig, task_name: str) -> list:
    """Assemble the experiment callbacks; TagCurriculum is added only for vmas/simple_tag."""
    callbacks = [
        SndCallback(),
        ESCallback(
            control_group="agents",
            initial_snd=0.0,
            dither_magnitude=0.25,
            dither_frequency_rad_s=1.0,
            integral_gain=-0.01,
            high_pass_cutoff_rad_s=1.0,
            low_pass_cutoff_rad_s=1.0,
            sampling_period=1.0,
            use_adapter=True,
        ),
        SNDVisualizerCallback(),
        # TrajectorySNDLoggerCallback(control_group="agents") is available but disabled.
        NormLoggerCallback(),
        ActionSpaceLoss(
            use_action_loss=cfg.use_action_loss, action_loss_lr=cfg.action_loss_lr
        ),
    ]
    if task_name == "vmas/simple_tag":
        callbacks.append(
            TagCurriculum(
                cfg.simple_tag_freeze_policy_after_frames,
                cfg.simple_tag_freeze_policy,
            )
        )
    return callbacks


def get_experiment(cfg: DictConfig) -> Experiment:
    """
    Build a BenchMARL Experiment from the active Hydra config.

    The task and algorithm names come from Hydra's runtime choices. For the
    stochastic algorithms (MAPPO/IPPO/MASAC/ISAC), the std-dev scale mapping
    is moved from the algorithm to the model.
    """
    choices = HydraConfig.get().runtime.choices
    task_name, algorithm_name = choices.task, choices.algorithm

    setup(task_name)

    print(f"\nAlgorithm: {algorithm_name}, Task: {task_name}")

    algorithm_config = load_algorithm_config_from_hydra(cfg.algorithm)
    experiment_config = load_experiment_config_from_hydra(cfg.experiment)
    task_config = load_task_config_from_hydra(cfg.task, task_name)
    critic_model_config = load_model_config_from_hydra(cfg.critic_model)
    model_config = load_model_config_from_hydra(cfg.model)

    is_stochastic = isinstance(algorithm_config, (MappoConfig, IppoConfig, MasacConfig, IsacConfig))
    model_config.probabilistic = is_stochastic
    if is_stochastic:
        # The model applies the std-dev scaling itself: hand it the algorithm's
        # mapping and leave a plain "relu" on the algorithm side.
        model_config.scale_mapping = algorithm_config.scale_mapping
        algorithm_config.scale_mapping = "relu"

    return Experiment(
        task=task_config,
        algorithm_config=algorithm_config,
        model_config=model_config,
        critic_model_config=critic_model_config,
        seed=cfg.seed,
        config=experiment_config,
        callbacks=_build_callbacks(cfg, task_name),
    )

## Runner Code
---

In [None]:
# Paths default to the original machine-specific locations. Export
# AD2C_CONFIG_PATH / AD2C_SAVE_PATH to run on another machine without editing.
ABS_CONFIG_PATH = os.environ.get(
    "AD2C_CONFIG_PATH",
    "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf",
)
CONFIG_NAME = "navigation_ippo"  # Make sure 'navigation_ippo.yaml' exists in the folder above!
SAVE_PATH = os.environ.get(
    "AD2C_SAVE_PATH",
    "/home/grad/doc/2027/spatel2/AD2C_testBed/model_checkpoints/navigation_ippo_esc/",
)

save_interval = 600000   # -> experiment.checkpoint_interval
desired_snd = 0.0        # -> model.desired_snd (ESCallback.on_setup then writes its own initial_snd)
max_frame = 12000000     # -> experiment.max_n_frames

if not os.path.exists(SAVE_PATH):
    print(f"Creating missing directory: {SAVE_PATH}")
    os.makedirs(SAVE_PATH, exist_ok=True)

# Clear any Hydra instance left over from a previous run of this cell in the same kernel.
GlobalHydra.instance().clear()

# hydra.main parses sys.argv, so inject the overrides as fake CLI arguments.
sys.argv = [
    "dummy.py",
    f"model.desired_snd={desired_snd}",
    f"experiment.max_n_frames={max_frame}",
    f"experiment.checkpoint_interval={save_interval}",
    f"experiment.save_folder={SAVE_PATH}",
    "task.agents_with_same_goal=2",
    "task.n_agents=2",
]

# 3. Define the Hydra wrapper
@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def hydra_experiment(cfg: DictConfig) -> None:
    """Hydra entry point: close any stale WandB run, build and run the experiment, then finish WandB."""
    print(f"Config loaded from: {ABS_CONFIG_PATH}")
    if wandb.run is not None:
        print("Finishing previous WandB run...")
        wandb.finish()

    print(f"Running with SND: {cfg.model.desired_snd}")

    experiment = get_experiment(cfg=cfg)
    experiment.run()
    wandb.finish()


# 4. Execute safely
if __name__ == "__main__":
    import traceback

    try:
        hydra_experiment()
    except SystemExit as exit_exc:
        # Hydra catches exceptions raised in the task function, prints an error
        # report and calls sys.exit(1). A non-zero SystemExit is therefore a FAILURE.
        if exit_exc.code in (None, 0):
            print("Experiment finished successfully.")
        else:
            print(f"Experiment aborted (exit code {exit_exc.code}); see the error report above.")
    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()
    else:
        print("Experiment finished successfully.")

Config loaded from: /home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf
Running with SND: 0.0

Algorithm: ippo, Task: vmas/navigation


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.



--- INTERCEPTING WANDB INIT ---
Forced ID: AD2C_Eval_1764922658
-------------------------------



[34m[1mwandb[0m: Currently logged in as: [33msvarp[0m ([33msvarp-university-of-massachusetts-lowell[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


✅ ESC Initialized for 'agents'






[ESC] Updated SND: 0.0000 (Reward: -0.326)




[ESC] Updated SND: 0.2327 (Reward: 0.541)




[ESC] Updated SND: 0.2994 (Reward: 1.253)




[ESC] Updated SND: 0.1294 (Reward: 1.430)




[ESC] Updated SND: 0.0000 (Reward: -0.202)




[ESC] Updated SND: 0.0000 (Reward: 0.447)




[ESC] Updated SND: 0.0750 (Reward: 1.122)




[ESC] Updated SND: 0.3332 (Reward: 1.823)




[ESC] Updated SND: 0.4277 (Reward: 1.548)




[ESC] Updated SND: 0.2889 (Reward: 1.576)




[ESC] Updated SND: 0.0520 (Reward: 1.464)




[ESC] Updated SND: 0.0000 (Reward: -2.495)




[ESC] Updated SND: 0.1365 (Reward: -2.405)




[ESC] Updated SND: 0.3956 (Reward: -0.416)




[ESC] Updated SND: 0.5695 (Reward: 1.014)




[ESC] Updated SND: 0.4990 (Reward: 0.600)




[ESC] Updated SND: 0.2686 (Reward: 0.935)




[ESC] Updated SND: 0.1066 (Reward: 0.510)




[ESC] Updated SND: 0.1705 (Reward: -0.024)




[ESC] Updated SND: 0.4024 (Reward: 0.837)




[ESC] Updated SND: 0.6147 (Reward: 1.602)




[ESC] Updated SND: 0.5958 (Reward: 0.840)




[ESC] Updated SND: 0.3845 (Reward: 1.232)




[ESC] Updated SND: 0.1635 (Reward: 1.623)




[ESC] Updated SND: 0.1404 (Reward: 1.587)




[ESC] Updated SND: 0.3297 (Reward: 1.691)




[ESC] Updated SND: 0.5530 (Reward: 1.684)


