## Imports

---

In [None]:
import sys
import os
import time
import pickle
import io

# Hydra
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf

# WandB / Logging
import wandb

# BenchMARL
import benchmarl.models
from benchmarl.algorithms import *
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment
from benchmarl.hydra_config import (
    load_algorithm_config_from_hydra,
    load_experiment_config_from_hydra,
    load_task_config_from_hydra,
    load_model_config_from_hydra,
)
from benchmarl.experiment.callback import Callback

# Het-Control
from het_control.callback import *
from het_control.environments.vmas import render_callback
from het_control.models.het_control_mlp_empirical import (
    HetControlMlpEmpiricalConfig,
    HetControlMlpEmpirical,
)
from het_control.callbacks.sndESLogger import TrajectorySNDLoggerCallback
from het_control.callbacks.utils import *
from het_control.snd import compute_behavioral_distance

# Scientific
import numpy as np
import torch
from tensordict import TensorDict, TensorDictBase
from typing import List, Dict, Any, Callable, Union

# Visualization
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image



## SND Visualization Plots
---

In [None]:
class SNDHeatmapVisualizer:
    """
    Render a pairwise behavioral-distance matrix as an annotated heatmap.

    Before plotting, the diagonal is zeroed and the matrix is symmetrised;
    the title reports SND, the mean distance over agent pairs i < j.
    """

    def __init__(self, key_name="Visuals/SND_Heatmap", vmax=3.0):
        """
        Args:
            key_name: Key under which the image is returned.
            vmax: Top of the colour scale (bottom is 0). Cell text switches
                from white to black at one third of this value.
        """
        self.key_name = key_name
        self.vmax = vmax

    def generate(self, snd_matrix, step_count):
        """
        Build the heatmap image.

        Args:
            snd_matrix: Square (n_agents, n_agents) array-like of distances.
            step_count: Iteration number shown in the title.

        Returns:
            ``{key_name: wandb.Image}``, or ``{}`` when there are fewer than
            two agents (no pairs, so SND is undefined).
        """
        # Float copy: never mutate the caller's array (also accepts array-likes).
        snd_matrix = np.array(snd_matrix, dtype=float)
        np.fill_diagonal(snd_matrix, 0.0)
        # Average with the transpose so cells (i, j) and (j, i) agree.
        snd_matrix = (snd_matrix + snd_matrix.T) / 2.0

        n_agents = snd_matrix.shape[0]
        # Same contract as the bar-chart/graph visualizers: without pairs the
        # mean below would be NaN (and emit a RuntimeWarning).
        if n_agents < 2:
            return {}

        agent_labels = [f"Agent {i+1}" for i in range(n_agents)]

        # SND = mean over the strict upper triangle (i < j).
        upper = np.triu_indices(n_agents, k=1)
        snd_value = float(np.mean(snd_matrix[upper]))

        fig, ax = plt.subplots(figsize=(6, 5))
        im = ax.imshow(
            snd_matrix,
            cmap="viridis",
            interpolation="nearest",
            vmin=0,
            vmax=self.vmax,
        )

        ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}")
        ax.set_xticks(np.arange(n_agents))
        ax.set_yticks(np.arange(n_agents))
        ax.set_xticklabels(agent_labels)
        ax.set_yticklabels(agent_labels)
        plt.setp(ax.get_xticklabels(), rotation=30, ha="right")

        fig.colorbar(im, ax=ax, label="Distance")

        # viridis is dark at the low end: white text there, black above.
        text_threshold = self.vmax / 3.0
        for i in range(n_agents):
            for j in range(n_agents):
                val = snd_matrix[i, j]
                ax.text(
                    j, i, f"{val:.2f}",
                    ha="center", va="center",
                    color="white" if val < text_threshold else "black",
                    fontsize=9, fontweight="bold",
                )

        fig.tight_layout()
        img = wandb.Image(fig)
        plt.close(fig)
        return {self.key_name: img}


class SNDBarChartVisualizer:
    """
    Plot one bar per agent pair (i < j) showing their behavioral distance.

    The matrix is sanitised first (diagonal zeroed, symmetrised); the title
    reports SND, the mean of the plotted pair distances.
    """

    def __init__(self, key_name="Visuals/SND_BarChart", ymax=3.0):
        """
        Args:
            key_name: Key under which the image is returned.
            ymax: Default top of the y-axis. It is raised automatically when
                a bar (plus room for its label) would not fit.
        """
        self.key_name = key_name
        self.ymax = ymax

    def _y_upper_limit(self, values):
        """Return the y-axis top: ``ymax`` unless the tallest bar needs more room."""
        # 15% headroom keeps the value label drawn above the bar inside the axes.
        return max(self.ymax, max(values) * 1.15)

    def generate(self, snd_matrix, step_count):
        """
        Build the bar-chart image.

        Args:
            snd_matrix: Square (n_agents, n_agents) array-like of distances.
            step_count: Iteration number shown in the title.

        Returns:
            ``{key_name: wandb.Image}``, or ``{}`` when there are no pairs.
        """
        # Float copy so the caller's array is untouched.
        snd_matrix = np.array(snd_matrix, dtype=float)
        np.fill_diagonal(snd_matrix, 0.0)
        snd_matrix = (snd_matrix + snd_matrix.T) / 2.0
        n_agents = snd_matrix.shape[0]

        pairs = [(i, j) for i in range(n_agents) for j in range(i + 1, n_agents)]
        if not pairs:
            return {}

        pair_values = [float(snd_matrix[i, j]) for i, j in pairs]
        pair_labels = [f"A{i+1}-A{j+1}" for i, j in pairs]
        snd_value = float(np.mean(pair_values))

        fig, ax = plt.subplots(figsize=(8, 5))
        bars = ax.bar(pair_labels, pair_values, color="teal")

        ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}")
        ax.set_ylabel("Distance")
        # A fixed (0, 3) range silently clipped bars above 3.
        ax.set_ylim(0, self._y_upper_limit(pair_values))
        ax.tick_params(axis="x", rotation=45)
        ax.bar_label(bars, fmt="%.2f", padding=3)

        fig.tight_layout()
        img = wandb.Image(fig)
        plt.close(fig)
        return {self.key_name: img}


class SNDGraphVisualizer:
    """
    Draw agents as nodes of a complete graph whose edges are coloured and
    labelled by the pairwise behavioral distance.

    The matrix is sanitised first (diagonal zeroed, symmetrised); the title
    reports SND, the mean over pairs i < j.
    """

    def __init__(self, key_name="Visuals/SND_NetworkGraph", vmax=3.0):
        """
        Args:
            key_name: Key under which the image is returned.
            vmax: Top of the edge colour scale (bottom is 0).
        """
        self.key_name = key_name
        self.vmax = vmax

    def generate(self, snd_matrix, step_count):
        """
        Build the network-graph image.

        Args:
            snd_matrix: Square (n_agents, n_agents) array-like of distances.
            step_count: Iteration number shown in the title.

        Returns:
            ``{key_name: wandb.Image}``, or ``{}`` when there are no pairs.
        """
        # Float copy so the caller's array is untouched.
        snd_matrix = np.array(snd_matrix, dtype=float)
        np.fill_diagonal(snd_matrix, 0.0)
        snd_matrix = (snd_matrix + snd_matrix.T) / 2.0
        n_agents = snd_matrix.shape[0]

        pairs = [(i, j) for i in range(n_agents) for j in range(i + 1, n_agents)]
        if not pairs:
            return {}

        pair_values = [float(snd_matrix[i, j]) for i, j in pairs]
        snd_value = float(np.mean(pair_values))

        G = nx.Graph()
        G.add_nodes_from(range(n_agents))
        for (i, j), dist in zip(pairs, pair_values):
            G.add_edge(i, j, weight=dist)

        # spring_layout treats 'weight' as attraction strength, which would
        # pull the MOST different agents closest together. A circular layout
        # is deterministic and leaves edge colour/labels to encode distance.
        pos = nx.circular_layout(G)

        # Same edge order as draw_networkx_edges uses (G.edges()).
        weights = [G[u][v]["weight"] for u, v in G.edges()]

        fig, ax = plt.subplots(figsize=(7, 7))

        nx.draw_networkx_nodes(
            G, pos, ax=ax, node_size=750, node_color="lightblue"
        )
        nx.draw_networkx_labels(
            G, pos, ax=ax,
            labels={i: f"A{i+1}" for i in range(n_agents)},
            font_size=12, font_weight="bold",
        )
        edges = nx.draw_networkx_edges(
            G, pos, ax=ax,
            edge_color=weights,
            edge_cmap=plt.cm.viridis,
            width=2,
            edge_vmin=0,
            edge_vmax=self.vmax,
        )
        nx.draw_networkx_edge_labels(
            G, pos, ax=ax,
            edge_labels={(i, j): f"{d:.2f}" for (i, j), d in zip(pairs, pair_values)},
            font_color="black",
            font_size=9,
            font_weight="bold",
        )

        fig.colorbar(edges, ax=ax, label="Distance")
        ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}", fontsize=14)
        ax.set_axis_off()

        img = wandb.Image(fig)
        plt.close(fig)
        return {self.key_name: img}


class SNDVisualizationManager:
    """
    Owns a fixed set of SND visualizers and merges what they produce.

    Every visualizer exposes ``generate(snd_matrix, step_count) -> dict``.
    If one of them raises, the error is printed and that visualizer is
    skipped, so a single broken plot never blocks the others.
    """

    def __init__(self):
        heatmap = SNDHeatmapVisualizer()
        bar_chart = SNDBarChartVisualizer()
        network = SNDGraphVisualizer()
        self.visualizers = [heatmap, bar_chart, network]

    def generate_all(self, snd_matrix, step_count):
        """Return the union of every visualizer's output dict (best effort)."""
        merged = {}
        for viz in self.visualizers:
            try:
                merged.update(viz.generate(snd_matrix, step_count))
            except Exception as err:
                print(f"Error generating {viz.__class__.__name__}: {err}")
        return merged


In [None]:
class SNDVisualizerCallback(Callback):
    """
    Log SND visualizations (via ``SNDVisualizationManager``) after evaluation.

    The pairwise distance matrix is computed from the first evaluation
    rollout only: every agent's policy head is re-run on that rollout's
    observations and the resulting actions are passed to
    ``compute_behavioral_distance``.
    """

    def __init__(self):
        super().__init__()
        self.control_group = None  # first group in experiment.group_policies
        self.model = None  # het model extracted from that group's policy
        self.viz_manager = SNDVisualizationManager()

    def on_setup(self):
        """Pick the first agent group and extract its het-control model."""
        if not self.experiment.group_policies:
            print("\nWARNING: No group policies found. SND Visualizer disabled.\n")
            return

        self.control_group = list(self.experiment.group_policies.keys())[0]
        policy = self.experiment.group_policies[self.control_group]
        self.model = get_het_model(policy)

        if self.model is None:
            print(f"\nWARNING: Could not extract HetModel for group '{self.control_group}'. Visualizer disabled.\n")

    def _get_agent_actions_for_rollout(self, rollout):
        """Return one action tensor per agent, computed from the rollout's observations."""
        obs = rollout.get((self.control_group, "observation"))
        actions = []
        for i in range(self.model.n_agents):
            # Fresh TensorDict per agent so one forward pass cannot leak
            # outputs into the next.
            temp_td = TensorDict(
                {(self.control_group, "observation"): obs},
                batch_size=obs.shape[:-1]
            )
            action_td = self.model._forward(temp_td, agent_index=i, compute_estimate=False)
            actions.append(action_td.get(self.model.out_key))
        return actions

    def on_evaluation_end(self, rollouts: List[TensorDict]):
        """Compute the SND matrix from the first rollout and log the plots."""
        if self.model is None or not rollouts:
            return

        # Only the first rollout is visualized, keeping plots comparable
        # across iterations.
        with torch.no_grad():
            agent_actions = self._get_agent_actions_for_rollout(rollouts[0])
            distances = compute_behavioral_distance(agent_actions, just_mean=False)
            # NOTE(review): assumes any extra leading dim is a batch dim to be
            # averaged out, leaving (n_agents, n_agents) — confirm against
            # compute_behavioral_distance.
            if distances.ndim > 2:
                distances = distances.mean(dim=0)
            snd_matrix = distances.cpu().numpy()

        step = self.experiment.n_iters_performed
        visual_logs = self.viz_manager.generate_all(
            snd_matrix=snd_matrix,
            step_count=step,
        )
        # Nothing to push if every visualizer failed or returned nothing.
        if visual_logs:
            self.experiment.logger.log(visual_logs, step=step)

## ESC Controller
---

In [None]:
import numpy as np
from collections import deque

class ESCStabilityMonitor:
    """
    Track ESC internal signals over a sliding window and classify stability.

    Call :meth:`update` once per ESC step; it returns
    ``(stability_score, stability_flag, osc_strength)``.
    """

    # Below this, the median SND is treated as zero and the ratio-based
    # divergence test is skipped (any positive value would exceed 2.5 * 0).
    _MEDIAN_EPS = 1e-6
    # Floor for the baseline std in the spike test, so a perfectly flat
    # history plus float round-off does not register as a spike.
    _STD_EPS = 1e-8

    def __init__(self, window=20):
        """
        Args:
            window: Number of most recent samples kept per signal.
        """
        self.grad_hist = deque(maxlen=window)
        self.snd_hist = deque(maxlen=window)
        self.hpf_hist = deque(maxlen=window)
        self.lpf_hist = deque(maxlen=window)

    # ------------------------------------------------------------------
    def update(self, snd, grad, hpf, lpf):
        """Record one sample of each signal and return :meth:`compute_stability`."""
        self.snd_hist.append(float(snd))
        self.grad_hist.append(float(grad))
        self.hpf_hist.append(float(hpf))
        self.lpf_hist.append(float(lpf))

        return self.compute_stability()

    # ------------------------------------------------------------------
    @classmethod
    def _is_spike(cls, series, k=5.0):
        """True when the newest sample deviates from the preceding samples'
        mean by more than ``k`` of their standard deviations.

        The newest value is excluded from the baseline: when it is included,
        a single outlier in a 20-sample window can never exceed 5 std.
        """
        history = series[:-1]
        deviation = abs(series[-1] - history.mean())
        return deviation > k * max(history.std(), cls._STD_EPS)

    # ------------------------------------------------------------------
    def compute_stability(self):
        """
        Returns:
            stability_score ∈ [0,1]
            stability_flag ∈ {"stable","oscillating","diverging","stalled","dither_spike"}
            osc_strength = mean abs 2nd diff of SND

        Fewer than 5 samples is always reported as ``(1.0, "stable", 0.0)``.
        Checks run in priority order; the first match wins.
        """
        if len(self.snd_hist) < 5:
            return 1.0, "stable", 0.0

        snd = np.asarray(self.snd_hist, dtype=float)
        grad = np.asarray(self.grad_hist, dtype=float)
        hpf = np.asarray(self.hpf_hist, dtype=float)
        lpf = np.asarray(self.lpf_hist, dtype=float)

        # 1. Oscillation: curvature of the SND trace.
        osc_strength = float(np.abs(np.diff(snd, n=2)).mean())
        if osc_strength > 0.8:
            return 0.3, "oscillating", osc_strength

        # 2. Divergence: latest SND far above the window median. Only
        #    meaningful when the median is non-zero.
        median_snd = float(np.median(snd))
        if median_snd > self._MEDIAN_EPS and snd[-1] > 2.5 * median_snd:
            return 0.1, "diverging", osc_strength

        # 3. Stall: gradient estimate ~0 while SND still wobbles.
        if np.mean(np.abs(grad[-5:])) < 1e-4 and osc_strength > 0.3:
            return 0.4, "stalled", osc_strength

        # 4. Sudden jump in either filter output relative to its recent past.
        if self._is_spike(hpf) or self._is_spike(lpf):
            return 0.5, "dither_spike", osc_strength

        return 1.0, "stable", osc_strength


In [None]:
import numpy as np

class ExtremumSeekingCore:
    """
    Discrete-time sinusoidal-perturbation extremum-seeking controller
    (minimises cost).

    Each :meth:`update` call consumes the cost measured while the value
    returned by the PREVIOUS call was applied, and returns the next value:

    1. high-pass filter the cost to strip its slowly varying level;
    2. demodulate: multiply by the dither that was applied when that cost
       was produced (this product correlates with the local gradient);
    3. low-pass filter the product -> gradient estimate;
    4. integrate: ``u <- max(u - k_i * gradient, 0)``;
    5. output ``max(u + a*sin(phase), 0)``, i.e. the estimate plus a fresh dither.

    ``self.u`` is the dither-free search estimate. The returned ``uk`` is the
    value to apply.
    """

    def __init__(
        self,
        sampling_period: float,
        disturbance_frequency: float,
        disturbance_magnitude: float,
        integrator_gain: float,
        initial_search_value: float,
        high_pass_cutoff_frequency: float,
        low_pass_cutoff_frequency: float,
        use_adapter: bool = True,
    ):
        self.T = sampling_period
        self.omega = disturbance_frequency
        self.a = disturbance_magnitude
        self.k_i = integrator_gain
        self.u = initial_search_value  # dither-free search estimate
        self.use_adapter = use_adapter  # kept for config compatibility; unused by update()

        # First-order discrete filter poles: alpha = exp(-cutoff * T).
        self.hpf_alpha = np.exp(-high_pass_cutoff_frequency * self.T)
        self.lpf_alpha = np.exp(-low_pass_cutoff_frequency * self.T)

        # States
        self.prev_cost = 0.0
        self.hpf_prev = 0.0
        self.lpf_prev = 0.0
        self.phase = 0.0
        # Perturbation actually contained in the last returned output. The
        # next cost was measured under it, so it is the demodulation signal.
        # Zero until the first output is issued (the initial value is undithered).
        self.applied_dither = 0.0
        # Seeds prev_cost on the first call so the HPF does not see a
        # spurious step from 0 to the first measured cost.
        self._cost_seen = False

    # ------------------------------------------------------------------
    @property
    def integrator_gain(self) -> float:
        """Alias of ``k_i``; assigning it changes the gain used by :meth:`update`."""
        return self.k_i

    @integrator_gain.setter
    def integrator_gain(self, value: float) -> None:
        self.k_i = float(value)

    # ------------------------------------------------------------------
    def update(self, cost: float):
        """
        Advance the controller by one sample.

        Args:
            cost: Cost measured with the previously returned ``uk`` applied.

        Returns:
            ``(uk, hpf_out, lpf_out, m2_sqrt, gradient, setpoint)``:
            ``uk`` is the next value to apply (estimate + dither, clamped >= 0),
            ``hpf_out``/``lpf_out`` are the filter outputs, ``m2_sqrt`` is the
            magnitude |a*sin(phase)| of the new dither, ``gradient`` equals
            ``lpf_out``, and ``setpoint`` is the estimate before this update.
        """
        setpoint = float(self.u)

        if not self._cost_seen:
            self.prev_cost = cost
            self._cost_seen = True

        # High-pass filter: y[k] = alpha * y[k-1] + (x[k] - x[k-1])
        hpf_out = (cost - self.prev_cost) + self.hpf_alpha * self.hpf_prev
        self.prev_cost = cost
        self.hpf_prev = hpf_out

        # Demodulate with the perturbation that produced this cost.
        gradient_raw = hpf_out * self.applied_dither

        # Low-pass filter the demodulated signal -> gradient estimate.
        lpf_out = (1 - self.lpf_alpha) * gradient_raw + self.lpf_alpha * self.lpf_prev
        self.lpf_prev = lpf_out
        gradient = lpf_out

        # Integrate along the negative gradient (minimise cost); never negative.
        self.u = float(max(self.u - self.k_i * gradient, 0.0))

        # New dither, injected into the output so the next cost carries it.
        self.phase += self.omega * self.T
        m = self.a * np.sin(self.phase)
        uk = float(max(self.u + m, 0.0))
        # After clamping at 0, the effective perturbation can differ from m.
        self.applied_dither = uk - self.u

        m2_sqrt = abs(m)

        return uk, hpf_out, lpf_out, m2_sqrt, gradient, setpoint


In [None]:
class ExtremumSeekingController(Callback):
    """
    BenchMARL callback that tunes the model's ``desired_snd`` by extremum seeking.

    After every evaluation it:
      * converts the mean summed reward of ``control_group`` into a cost
        (``cost = -reward``) and feeds it to :class:`ExtremumSeekingCore`;
      * writes the returned value into ``model.desired_snd``;
      * schedules the integrator gain from current reward / best reward seen;
      * runs a windowed oscillation detector on the SND search estimate;
      * logs everything under ``esc/``. The original key names are kept
        unchanged for dashboard compatibility.
    """

    def __init__(
        self,
        control_group: str,
        initial_snd: float,
        dither_magnitude: float,
        dither_frequency_rad_s: float,
        integral_gain: float,
        high_pass_cutoff_rad_s: float,
        low_pass_cutoff_rad_s: float,
        use_adapter: bool = True,
        sampling_period: float = 1.0,

        # --- Stability-monitor tuning ---
        stability_window: int = 20,
        amplitude_threshold: float = 0.15,
        signflip_threshold: int = 6,

        # --- Adaptive-gain tuning ---
        adaptive_gain_alpha: float = 0.7,
        gain_min: float = 0.1,
        gain_max: float = 3.0,
    ):
        """
        Args:
            control_group: Agent group whose policy holds the het-control model.
            initial_snd: Starting value written to ``model.desired_snd``.
            dither_magnitude, dither_frequency_rad_s, integral_gain,
            high_pass_cutoff_rad_s, low_pass_cutoff_rad_s, use_adapter,
            sampling_period: Forwarded to :class:`ExtremumSeekingCore`.
            stability_window: Samples used by the oscillation detector.
            amplitude_threshold: Peak-to-peak SND swing that counts as oscillation.
            signflip_threshold: Derivative sign flips within the window that
                count as oscillation.
            adaptive_gain_alpha: Maximum fractional gain reduction at zero
                reward ratio.
            gain_min, gain_max: Clamp for the scheduled gain, as multiples of
                ``integral_gain``.
        """
        super().__init__()

        self.control_group = control_group
        self.initial_snd = initial_snd

        # Unscaled gain; the adaptive schedule scales this value.
        self.base_gain = integral_gain

        self.adaptive_gain_alpha = adaptive_gain_alpha
        self.gain_min = gain_min
        self.gain_max = gain_max

        self.max_reward = -np.inf  # best evaluation reward seen so far

        self.esc_params = {
            "sampling_period": sampling_period,
            "disturbance_frequency": dither_frequency_rad_s,
            "disturbance_magnitude": dither_magnitude,
            "integrator_gain": integral_gain,
            "initial_search_value": initial_snd,
            "high_pass_cutoff_frequency": high_pass_cutoff_rad_s,
            "low_pass_cutoff_frequency": low_pass_cutoff_rad_s,
            "use_adapter": use_adapter,
        }

        self.controller = None
        self.model = None

        # Oscillation detector state
        self.snd_history = []
        self.window = stability_window
        self.amp_thresh = amplitude_threshold
        self.signflip_thresh = signflip_threshold
        self.prev_val = None
        self.prev_deriv_sign = 0
        self.sign_flips = 0  # flips within the last `window` steps
        self._flip_history = []  # 1 where a flip occurred, aligned with snd_history

    # ----------------------------------------------------------------------
    def _update_adaptive_gain(self, reward_mean):
        """
        Scale the integrator gain by how close the current reward is to the best.

        ``gain = base * (1 - alpha * (1 - reward / max_reward))``, clamped to
        ``[gain_min, gain_max] * base``. Returns ``base_gain`` until a
        positive best reward has been seen. The result is written into the
        ESC core and returned.
        """
        if reward_mean > self.max_reward:
            self.max_reward = reward_mean

        if self.max_reward <= 1e-6:
            # The ratio is meaningless for non-positive rewards.
            gain = float(self.base_gain)
        else:
            reward_ratio = float(np.clip(reward_mean / self.max_reward, 0.0, 1.0))
            gain = self.base_gain * (1 - self.adaptive_gain_alpha * (1 - reward_ratio))
            # Sort the bounds: with a negative base gain, gain_min*base >
            # gain_max*base, and np.clip would then always return the upper bound.
            lo, hi = sorted((self.gain_min * self.base_gain,
                             self.gain_max * self.base_gain))
            gain = float(np.clip(gain, lo, hi))

        # ExtremumSeekingCore.update reads `k_i`; writing that attribute is
        # what makes the schedule take effect.
        self.controller.k_i = gain

        return gain

    # ----------------------------------------------------------------------
    def _update_stability_monitor(self, val):
        """
        Online oscillation detector over the last ``window`` samples.

        Returns 1 when the peak-to-peak swing exceeds ``amp_thresh`` or the
        number of derivative sign flips in the window reaches
        ``signflip_thresh``, else 0. Always returns 0 until the window is full.
        """
        self.snd_history.append(val)
        if len(self.snd_history) > self.window:
            self.snd_history.pop(0)

        flipped = 0
        if self.prev_val is not None:
            deriv = val - self.prev_val
            sign = 1 if deriv > 0 else -1 if deriv < 0 else 0
            if self.prev_deriv_sign != 0 and sign != self.prev_deriv_sign:
                flipped = 1
            self.prev_deriv_sign = sign

        self.prev_val = val

        # Count flips over the same window as the amplitude check. A lifetime
        # counter would latch the flag on permanently after the threshold.
        self._flip_history.append(flipped)
        if len(self._flip_history) > self.window:
            self._flip_history.pop(0)
        self.sign_flips = sum(self._flip_history)

        if len(self.snd_history) < self.window:
            return 0

        peak_to_peak = max(self.snd_history) - min(self.snd_history)
        amp_flag = peak_to_peak > self.amp_thresh
        flip_flag = self.sign_flips >= self.signflip_thresh

        return 1 if (amp_flag or flip_flag) else 0

    # ----------------------------------------------------------------------
    def on_setup(self):
        """Log ESC hyper-parameters, attach to the group's model, and seed desired_snd."""
        self.experiment.logger.log_hparams(**self.esc_params)

        if self.control_group not in self.experiment.group_policies:
            print(f"WARNING: ESC group '{self.control_group}' not found.")
            return

        policy = self.experiment.group_policies[self.control_group]
        self.model = get_het_model(policy)

        if isinstance(self.model, HetControlMlpEmpirical):
            print(f"✅ ESC attached to group '{self.control_group}'.")
            self.controller = ExtremumSeekingCore(**self.esc_params)
            self.model.desired_snd[:] = float(self.initial_snd)
        else:
            print(f"WARNING: ESC disabled for '{self.control_group}'.")
            self.model = None

    # ----------------------------------------------------------------------
    def on_evaluation_end(self, rollouts):
        """Run one ESC step from this evaluation's rewards and log the result."""
        if self.model is None or self.controller is None:
            return

        # Total reward per rollout (summed over all reward entries of the group).
        episode_rewards = []
        key = ("next", self.control_group, "reward")

        for r in rollouts:
            if key in r.keys(include_nested=True):
                episode_rewards.append(r.get(key).sum().item())

        if len(episode_rewards) == 0:
            print("WARNING: No rewards found this eval.")
            return

        reward_mean = float(np.mean(episode_rewards))
        cost = -reward_mean

        adaptive_gain = self._update_adaptive_gain(reward_mean)

        (
            uk,
            hpf_out,
            lpf_out,
            m2_sqrt,
            gradient,
            setpoint
        ) = self.controller.update(cost)

        previous_snd = self.model.desired_snd.item()
        self.model.desired_snd[:] = uk

        # Watch the controller's dither-free search estimate, so an
        # intentional dither is not mistaken for instability.
        stability_flag = self._update_stability_monitor(float(self.controller.u))

        # Original WandB keys — names must stay stable.
        logs_to_push = {
            "esc/mean_reward": reward_mean,
            "esc/cost": cost,
            "esc/diversity_output": uk,
            "esc/diversity_setpoint": setpoint,
            "esc/gradient_estimate": gradient,
            "esc/hpf_output": hpf_out,
            "esc/lpf_output": lpf_out,
            "esc/m2_sqrt": m2_sqrt,
            "esc/update_step": uk - previous_snd,
        }

        logs_to_push.update({
            "esc/stability_flag": stability_flag,
            "esc/max_reward_seen": self.max_reward,
            "esc/adaptive_gain": adaptive_gain,
        })

        self.experiment.logger.log(
            logs_to_push,
            step=self.experiment.n_iters_performed
        )

        print(
            f"[ESC] SND {previous_snd:.4f} → {uk:.4f} | "
            f"grad {gradient:.4f} | reward {reward_mean:.3f} | "
            f"gain {adaptive_gain:.4f} | stable={1 - stability_flag}"
        )


## Env Setup
---

In [None]:
# 1. EXPERIMENT LOGIC

def setup(task_name):
    """
    Register het-control extensions with BenchMARL before an experiment is built.

    Always adds ``HetControlMlpEmpiricalConfig`` to BenchMARL's model-config
    registry under the name ``"hetcontrolmlpempirical"``. For the
    ``"vmas/navigation"`` task only, it also installs the het-control
    ``render_callback`` on ``VmasTask``.
    """
    extra_models = {"hetcontrolmlpempirical": HetControlMlpEmpiricalConfig}
    benchmarl.models.model_config_registry.update(extra_models)

    if task_name != "vmas/navigation":
        return
    # The navigation case study uses a custom renderer.
    VmasTask.render_callback = render_callback

def get_experiment(cfg: DictConfig) -> Experiment:
    """
    Build a BenchMARL ``Experiment`` from a Hydra config.

    Reads the task/algorithm names selected in Hydra's runtime choices and
    loads each config section. For MAPPO/IPPO/MASAC/ISAC it moves the
    algorithm's ``scale_mapping`` into the (probabilistic) model and sets
    the algorithm's mapping to ``"relu"``. It then attaches the het-control
    callbacks, plus ``TagCurriculum`` for ``vmas/simple_tag``.
    """
    hydra_choices = HydraConfig.get().runtime.choices
    task_name = hydra_choices.task
    algorithm_name = hydra_choices.algorithm

    setup(task_name)

    print(f"\nAlgorithm: {algorithm_name}, Task: {task_name}")

    algorithm_config = load_algorithm_config_from_hydra(cfg.algorithm)
    experiment_config = load_experiment_config_from_hydra(cfg.experiment)
    task_config = load_task_config_from_hydra(cfg.task, task_name)
    critic_model_config = load_model_config_from_hydra(cfg.critic_model)
    model_config = load_model_config_from_hydra(cfg.model)

    if isinstance(algorithm_config, (MappoConfig, IppoConfig, MasacConfig, IsacConfig)):
        model_config.probabilistic = True
        model_config.scale_mapping = algorithm_config.scale_mapping
        # The scaling of std_dev is done in the model instead.
        algorithm_config.scale_mapping = "relu"
    else:
        model_config.probabilistic = False

    callbacks = [
        SndCallback(),
        ExtremumSeekingController(
            control_group="agents",
            # Start the search from the configured desired SND. A hard-coded
            # 0.0 silently overrode cfg.model.desired_snd in on_setup.
            initial_snd=float(cfg.model.get("desired_snd", 0.0)),
            dither_magnitude=0.2,
            dither_frequency_rad_s=1.0,
            # Positive gain: the ESC uses cost = -reward and descends cost
            # (u -= k * gradient). A negative gain would drive SND toward
            # LOWER reward.
            integral_gain=0.1,
            high_pass_cutoff_rad_s=1.0,
            low_pass_cutoff_rad_s=1.0,
            sampling_period=1.0,
        ),
        SNDVisualizerCallback(),
        # TrajectorySNDLoggerCallback(control_group="agents"),
        NormLoggerCallback(),
        ActionSpaceLoss(
            use_action_loss=cfg.use_action_loss, action_loss_lr=cfg.action_loss_lr
        ),
    ]
    if task_name == "vmas/simple_tag":
        callbacks.append(
            TagCurriculum(
                cfg.simple_tag_freeze_policy_after_frames,
                cfg.simple_tag_freeze_policy,
            )
        )

    return Experiment(
        task=task_config,
        algorithm_config=algorithm_config,
        model_config=model_config,
        critic_model_config=critic_model_config,
        seed=cfg.seed,
        config=experiment_config,
        callbacks=callbacks,
    )

## Runner Code
---

In [None]:
ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"  # Make sure 'navigation_ippo.yaml' exists in the folder above!
SAVE_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/model_checkpoints/navigation_ippo_esc/"

save_interval = 600000
desired_snd = 1.0
max_frame = 6000000

if not os.path.exists(SAVE_PATH):
    print(f"Creating missing directory: {SAVE_PATH}")
    os.makedirs(SAVE_PATH, exist_ok=True)

# Hydra refuses to initialise twice in one process (e.g. when this notebook
# cell is re-run), so reset its global state first.
GlobalHydra.instance().clear()

# hydra.main parses sys.argv, so inject the command-line overrides here.
sys.argv = [
    "dummy.py",
    f"model.desired_snd={desired_snd}",
    f"experiment.max_n_frames={max_frame}",
    f"experiment.checkpoint_interval={save_interval}",
    f"experiment.save_folder={SAVE_PATH}",
]


@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def hydra_experiment(cfg: DictConfig) -> None:
    """Build and run one experiment, closing the WandB run even on failure."""
    print(f"Config loaded from: {ABS_CONFIG_PATH}")
    if wandb.run is not None:
        print("Finishing previous WandB run...")
        wandb.finish()

    print(f"Running with SND: {cfg.model.desired_snd}")

    experiment = get_experiment(cfg=cfg)
    try:
        experiment.run()
    finally:
        # Without this, a crash leaves the WandB run open for the next cell.
        wandb.finish()


if __name__ == "__main__":
    try:
        hydra_experiment()
    except SystemExit as exit_exc:
        # Hydra reports failures by printing the error and calling
        # sys.exit(1), so only a zero/None exit code means success.
        if exit_exc.code in (None, 0):
            print("Experiment finished successfully.")
        else:
            print(f"Experiment exited with status {exit_exc.code}.")
    except Exception:
        import traceback

        # Keep the notebook kernel alive but show the full traceback.
        traceback.print_exc()
    else:
        print("Experiment finished successfully.")

Config loaded from: /home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf
Running with SND: 1.0

Algorithm: ippo, Task: vmas/navigation


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
[34m[1mwandb[0m: Currently logged in as: [33msvarp[0m ([33msvarp-university-of-massachusetts-lowell[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


✅ ESC attached to group 'agents'.






[ESC] SND 0.0000 → 0.0000 (grad -0.0412, reward 0.387, stable=1)




[ESC] SND 0.0000 → 0.0000 (grad -0.1049, reward 1.026, stable=1)




[ESC] SND 0.0000 → 0.0000 (grad -0.0513, reward 1.449, stable=1)




[ESC] SND 0.0000 → 0.0025 (grad 0.0247, reward 1.644, stable=1)




[ESC] SND 0.0025 → 0.0070 (grad 0.0457, reward 1.778, stable=1)




[ESC] SND 0.0070 → 0.0096 (grad 0.0252, reward 1.904, stable=1)




[ESC] SND 0.0096 → 0.0102 (grad 0.0062, reward 1.854, stable=1)




[ESC] SND 0.0102 → 0.0087 (grad -0.0152, reward 1.980, stable=1)




[ESC] SND 0.0087 → 0.0086 (grad -0.0009, reward 1.839, stable=1)




[ESC] SND 0.0086 → 0.0105 (grad 0.0196, reward 2.162, stable=1)




[ESC] SND 0.0105 → 0.0107 (grad 0.0013, reward 2.008, stable=1)




[ESC] SND 0.0107 → 0.0112 (grad 0.0057, reward 2.102, stable=1)




[ESC] SND 0.0112 → 0.0116 (grad 0.0032, reward 2.053, stable=1)


