## Imports

---

In [None]:
import sys
import os
import time
import pickle
import io

# Hydra
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf

# WandB / Logging
import wandb

# BenchMARL
import benchmarl.models
from benchmarl.algorithms import *
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment
from benchmarl.hydra_config import (
    load_algorithm_config_from_hydra,
    load_experiment_config_from_hydra,
    load_task_config_from_hydra,
    load_model_config_from_hydra,
)
from benchmarl.experiment.callback import Callback

# Het-Control
from het_control.callback import *
from het_control.environments.vmas import render_callback
from het_control.models.het_control_mlp_empirical import (
    HetControlMlpEmpiricalConfig,
    HetControlMlpEmpirical,
)
from het_control.callbacks.sndESLogger import TrajectorySNDLoggerCallback
from het_control.callbacks.utils import *
from het_control.snd import compute_behavioral_distance

# Scientific
import numpy as np
import torch
from tensordict import TensorDict, TensorDictBase
from typing import List, Dict, Any, Callable, Union

# Visualization
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image

## SND Visualization Plot
---

In [None]:
class SNDHeatmapVisualizer:
    """Render a pairwise distance matrix as an annotated heatmap image."""

    def __init__(self, key_name="Visuals/SND_Heatmap"):
        # Key under which the rendered image is returned by ``generate``.
        self.key_name = key_name

    def generate(self, snd_matrix, step_count):
        """Build the heatmap for one step.

        Args:
            snd_matrix: Square 2-D NumPy array of pairwise distances. It is
                copied before being modified, so the caller's array is
                left untouched.
            step_count: Value displayed after "Step" in the plot title.

        Returns:
            ``{self.key_name: wandb.Image}``, or an empty dict when there are
            fewer than two agents (no pair exists, so the mean pairwise
            distance is undefined).
        """
        n_agents = snd_matrix.shape[0]
        if n_agents < 2:
            # Without at least one pair np.mean would receive an empty slice
            # and yield NaN plus a RuntimeWarning.
            return {}

        # Work on a copy: zero the diagonal and average with the transpose so
        # the plotted matrix is symmetric.
        snd_matrix = snd_matrix.copy()
        np.fill_diagonal(snd_matrix, 0.0)
        snd_matrix = (snd_matrix + snd_matrix.T) / 2.0

        agent_labels = [f"Agent {i+1}" for i in range(n_agents)]

        # SND shown in the title = mean over the strict upper triangle (i < j).
        upper = np.triu_indices(n_agents, k=1)
        snd_value = float(np.mean(snd_matrix[upper]))

        fig, ax = plt.subplots(figsize=(6, 5))
        try:
            im = ax.imshow(
                snd_matrix,
                cmap="viridis",
                interpolation="nearest",
                vmin=0, vmax=3
            )

            ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}")

            ticks = np.arange(n_agents)
            ax.set_xticks(ticks)
            ax.set_yticks(ticks)
            ax.set_xticklabels(agent_labels)
            ax.set_yticklabels(agent_labels)
            plt.setp(ax.get_xticklabels(), rotation=30, ha="right")

            fig.colorbar(im, ax=ax, label="Distance")

            # Per-cell numeric labels; small values get white text, the rest
            # black.
            for i in range(n_agents):
                for j in range(n_agents):
                    val = snd_matrix[i, j]
                    text_color = "white" if val < 1.0 else "black"
                    ax.text(
                        j, i, f"{val:.2f}",
                        ha="center", va="center",
                        color=text_color,
                        fontsize=9, fontweight="bold"
                    )

            fig.tight_layout()
            return {self.key_name: wandb.Image(fig)}
        finally:
            # Always release the figure, even when drawing or the image
            # conversion raises; otherwise pyplot keeps it alive.
            plt.close(fig)


class SNDBarChartVisualizer:
    """Render the pairwise distances as one bar per agent pair."""

    def __init__(self, key_name="Visuals/SND_BarChart"):
        # Key under which the rendered image is returned by ``generate``.
        self.key_name = key_name

    def generate(self, snd_matrix, step_count):
        """Build the bar chart for one step.

        Args:
            snd_matrix: Square 2-D NumPy array of pairwise distances. It is
                copied before being modified.
            step_count: Value displayed after "Step" in the plot title.

        Returns:
            ``{self.key_name: wandb.Image}``, or an empty dict when no agent
            pair exists (fewer than two agents).
        """
        n_agents = snd_matrix.shape[0]

        # Work on a copy: zero the diagonal and average with the transpose.
        snd_matrix = snd_matrix.copy()
        np.fill_diagonal(snd_matrix, 0.0)
        snd_matrix = (snd_matrix + snd_matrix.T) / 2.0

        # Unordered agent pairs (i < j).
        pairs = [(i, j) for i in range(n_agents) for j in range(i + 1, n_agents)]
        if not pairs:
            return {}

        pair_values = [float(snd_matrix[i, j]) for i, j in pairs]
        pair_labels = [f"A{i+1}-A{j+1}" for i, j in pairs]

        # SND shown in the title = mean of the pairwise distances.
        snd_value = float(np.mean(pair_values))

        fig, ax = plt.subplots(figsize=(8, 5))
        try:
            bars = ax.bar(pair_labels, pair_values, color="teal")

            ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}")
            ax.set_ylabel("Distance")
            ax.set_ylim(0, 3)
            ax.tick_params(axis="x", rotation=45)

            # Numeric value above each bar.
            ax.bar_label(bars, fmt="%.2f", padding=3)

            fig.tight_layout()
            return {self.key_name: wandb.Image(fig)}
        finally:
            # Always release the figure, even when drawing or the image
            # conversion raises; otherwise pyplot keeps it alive.
            plt.close(fig)


class SNDGraphVisualizer:
    """Render the pairwise distances as a weighted, fully connected graph."""

    def __init__(self, key_name="Visuals/SND_NetworkGraph"):
        # Key under which the rendered image is returned by ``generate``.
        self.key_name = key_name

    def generate(self, snd_matrix, step_count):
        """Build the network-graph image for one step.

        Args:
            snd_matrix: Square 2-D NumPy array of pairwise distances. It is
                copied before being modified.
            step_count: Value displayed after "Step" in the plot title.

        Returns:
            ``{self.key_name: wandb.Image}``, or an empty dict when no agent
            pair exists (fewer than two agents).
        """
        n_agents = snd_matrix.shape[0]

        # Work on a copy: zero the diagonal and average with the transpose.
        snd_matrix = snd_matrix.copy()
        np.fill_diagonal(snd_matrix, 0.0)
        snd_matrix = (snd_matrix + snd_matrix.T) / 2.0

        # One undirected edge per unordered agent pair (i < j).
        pairs = [(i, j) for i in range(n_agents) for j in range(i + 1, n_agents)]
        if not pairs:
            return {}

        # SND shown in the title = mean of the pairwise distances.
        snd_value = float(np.mean([float(snd_matrix[i, j]) for i, j in pairs]))

        G = nx.Graph()
        for i, j in pairs:
            G.add_edge(i, j, weight=float(snd_matrix[i, j]))

        # Fixed seed keeps the node placement reproducible between calls.
        pos = nx.spring_layout(G, seed=42)

        # Edge weights, in the graph's own edge order, drive the edge colors.
        weights = [G[u][v]['weight'] for u, v in G.edges()]
        node_labels = {i: f"A{i+1}" for i in range(n_agents)}
        edge_labels = {(i, j): f"{snd_matrix[i, j]:.2f}" for i, j in pairs}

        fig, ax = plt.subplots(figsize=(7, 7))
        try:
            # Every draw call receives the axes explicitly so nothing depends
            # on pyplot's implicit "current figure/axes" state.
            nx.draw_networkx_nodes(
                G, pos, node_size=750, node_color='lightblue', ax=ax
            )
            nx.draw_networkx_labels(
                G, pos, labels=node_labels, font_size=12, font_weight='bold',
                ax=ax
            )
            edges = nx.draw_networkx_edges(
                G, pos,
                edge_color=weights,
                edge_cmap=plt.cm.viridis,
                width=2,
                edge_vmin=0,
                edge_vmax=3,
                ax=ax
            )
            nx.draw_networkx_edge_labels(
                G,
                pos,
                edge_labels=edge_labels,
                font_color='black',
                font_size=9,
                font_weight='bold',
                ax=ax
            )

            fig.colorbar(edges, ax=ax, label='Distance')

            ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}", fontsize=14)
            ax.axis('off')

            return {self.key_name: wandb.Image(fig)}
        finally:
            # Always release the figure, even when drawing or the image
            # conversion raises; otherwise pyplot keeps it alive.
            plt.close(fig)


class SNDVisualizationManager:
    """
    Owns the individual SND visualizers and merges their outputs.
    """
    def __init__(self):
        # Heatmap, bar chart and network graph, run in this order.
        self.visualizers = [
            SNDHeatmapVisualizer(),
            SNDBarChartVisualizer(),
            SNDGraphVisualizer(),
        ]

    def generate_all(self, snd_matrix, step_count):
        """
        Run every visualizer and return the union of their result dicts.

        A visualizer that raises is reported on stdout and skipped, so one
        failing plot does not prevent the others from being produced.
        """
        merged = {}
        for viz in self.visualizers:
            try:
                merged.update(viz.generate(snd_matrix, step_count))
            except Exception as err:
                print(f"Error generating {type(viz).__name__}: {err}")
        return merged


In [None]:
class SNDVisualizerCallback(Callback):
    """
    Evaluation-time callback: derives a pairwise distance matrix from the
    first evaluation rollout and logs the plots built by the manager.
    """
    def __init__(self):
        super().__init__()
        # Both stay None until on_setup succeeds; a None model disables the callback.
        self.control_group = None
        self.model = None
        # Holds the three plot classes.
        self.viz_manager = SNDVisualizationManager()

    def on_setup(self):
        """Pick the first policy group and extract its model via get_het_model."""
        policies = self.experiment.group_policies
        if not policies:
            print("\nWARNING: No group policies found. SND Visualizer disabled.\n")
            return

        self.control_group = next(iter(policies.keys()))
        # 'get_het_model' is not defined in this cell; it must be in scope.
        self.model = get_het_model(policies[self.control_group])

        if self.model is None:
            print(f"\nWARNING: Could not extract HetModel for group '{self.control_group}'. Visualizer disabled.\n")

    def _get_agent_actions_for_rollout(self, rollout):
        """Run one forward pass per agent index and collect the model outputs."""
        obs_key = (self.control_group, "observation")
        observations = rollout.get(obs_key)
        per_agent_outputs = []
        for agent_idx in range(self.model.n_agents):
            # A fresh input TensorDict is built for every agent index.
            model_input = TensorDict(
                {obs_key: observations},
                batch_size=observations.shape[:-1],
            )
            model_output = self.model._forward(
                model_input, agent_index=agent_idx, compute_estimate=False
            )
            per_agent_outputs.append(model_output.get(self.model.out_key))
        return per_agent_outputs

    def on_evaluation_end(self, rollouts: List[TensorDict]):
        """Compute the distance matrix for the first rollout and log the plots."""
        if self.model is None:
            return

        # Only the first rollout is visualized; nothing is logged without one.
        first_rollout = next(iter(rollouts), None)
        if first_rollout is None:
            return

        with torch.no_grad():
            actions = self._get_agent_actions_for_rollout(first_rollout)

            # 'compute_behavioral_distance' must be in scope.
            distances = compute_behavioral_distance(actions, just_mean=False)

            # Anything beyond 2 dimensions is averaged over the leading one.
            if distances.ndim > 2:
                distances = distances.mean(dim=0)

            snd_matrix = distances.cpu().numpy()

        plots = self.viz_manager.generate_all(
            snd_matrix=snd_matrix,
            step_count=self.experiment.n_iters_performed,
        )
        self.experiment.logger.log(dict(plots), step=self.experiment.n_iters_performed)

## ESC Controller
---

In [None]:
class esc:
    """Discrete-time extremum-seeking controller with a sinusoidal dither.

    Each call to ``update`` high-pass filters the cost, multiplies it by
    ``sin(wt)``, low-pass filters the product, integrates the result and adds
    the dither back onto the (clamped) setpoint.

    Cutoff and disturbance frequencies are in rad/s; ``sampling_period`` is in
    seconds.
    """

    # Running RMS of the low-pass output above which the adapter swaps in
    # ADAPTER_HIGH_GAIN for the configured integrator gain.
    ADAPTER_THRESHOLD = 0.2
    ADAPTER_HIGH_GAIN = -0.025

    def __init__(
        self,
        sampling_period,
        disturbance_frequency,
        disturbance_magnitude,
        integrator_gain,
        initial_search_value,
        high_pass_cutoff_frequency,
        low_pass_cutoff_frequency,
        use_adapter,
    ):
        """Store the parameters and reset all internal state.

        Args:
            sampling_period: Time step ``dt`` in seconds.
            disturbance_frequency: Dither frequency in rad/s.
            disturbance_magnitude: Dither amplitude.
            integrator_gain: Integrator gain; negative for gradient descent.
            initial_search_value: Setpoint before any integration.
            high_pass_cutoff_frequency: High-pass cutoff in rad/s.
            low_pass_cutoff_frequency: Low-pass cutoff in rad/s.
            use_adapter: When true, ``update`` switches to
                ``ADAPTER_HIGH_GAIN`` while the running RMS of the low-pass
                output exceeds ``ADAPTER_THRESHOLD``.
        """
        self.dt = sampling_period
        self.disturbance_frequency = disturbance_frequency
        self.disturbance_magnitude = disturbance_magnitude
        self.integrator_gain = integrator_gain
        self.initial_search_value = initial_search_value
        self.use_adapter = use_adapter

        # NOTE(review): the trailing zeros look like initial filter states;
        # confirm against the filter class definitions.
        self.high_pass_filter = High_pass_filter_first_order(
            sampling_period, high_pass_cutoff_frequency, 0, 0
        )
        self.low_pass_filter = Low_pass_filter_first_order(
            sampling_period, low_pass_cutoff_frequency, 0
        )

        # Current phase of the perturbation.
        self.wt = 0

        # Lower bound applied to the setpoint.
        self.min_setpoint = 0.0

        # Integrator output.
        self.integral = 0
        # Exponential moving average of the squared low-pass output and its
        # decay factor.
        self.m2 = 0
        self.b2 = 0.8
        # Small constant kept for division safety; not used by ``update``.
        self.epsilon = 1e-8

    def update(self, cost):
        """Advance the controller by one sample.

        Args:
            cost: Scalar cost measured for the previously emitted output.

        Returns:
            A 6-tuple ``(output, high_pass_output, low_pass_output,
            gradient_mag, low_pass_output, setpoint)``. ``output`` is the
            clamped setpoint plus the dither; ``gradient_mag`` is the running
            RMS of the low-pass output. The low-pass output appears twice
            (positions 3 and 5) and is kept that way for existing callers.
        """
        high_pass_output = self.high_pass_filter.apply(cost)
        low_pass_input = high_pass_output * np.sin(self.wt)
        low_pass_output = self.low_pass_filter.apply(low_pass_input)

        # Running second moment of the demodulated signal.
        self.m2 = self.b2 * self.m2 + (1 - self.b2) * np.power(low_pass_output, 2)
        gradient_mag = np.sqrt(self.m2)

        if self.use_adapter and gradient_mag > self.ADAPTER_THRESHOLD:
            gain = self.ADAPTER_HIGH_GAIN
        else:
            gain = self.integrator_gain

        self.integral += gain * low_pass_output * self.dt

        unclamped_setpoint = self.initial_search_value + self.integral
        setpoint = max(unclamped_setpoint, self.min_setpoint)

        # Anti-windup: the test must use the *unclamped* value. Comparing the
        # already-clamped setpoint can never be true, which would let the
        # integrator keep drifting below the bound.
        if unclamped_setpoint < self.min_setpoint:
            self.integral = self.min_setpoint - self.initial_search_value

        output = self.disturbance_magnitude * np.sin(self.wt) + setpoint

        # Advance the dither phase and wrap it after a full period.
        self.wt += self.disturbance_frequency * self.dt
        if self.wt > 2 * np.pi:
            self.wt -= 2 * np.pi

        return (
            output,
            high_pass_output,
            low_pass_output,
            gradient_mag,
            low_pass_output,
            setpoint,
        )

In [None]:
class ExtremumSeekingController(Callback):
    """
    Callback that runs an extremum-seeking controller (``esc``) after every
    evaluation and writes its output into the model's ``desired_snd``.

    The cost fed to the controller is the negative mean of the per-rollout
    summed rewards of ``control_group``.
    """
    def __init__(
        self,
        control_group: str,
        initial_snd: float,

        # ESC parameters
        dither_magnitude: float,
        dither_frequency_rad_s: float,
        integral_gain: float,
        high_pass_cutoff_rad_s: float,
        low_pass_cutoff_rad_s: float,
        use_adapter: bool = True,
        sampling_period: float = 1.0
        ):
        """
        Args:
            control_group: Name of the policy group whose model is controlled.
            initial_snd: Value written to ``desired_snd`` at setup; also the
                controller's initial search value.
            dither_magnitude: Amplitude of the sinusoidal dither.
            dither_frequency_rad_s: Dither frequency in rad/s.
            integral_gain: Integrator gain passed to ``esc``.
            high_pass_cutoff_rad_s: High-pass cutoff in rad/s.
            low_pass_cutoff_rad_s: Low-pass cutoff in rad/s.
            use_adapter: Forwarded to ``esc`` (enables its gain switching).
            sampling_period: Controller time step passed to ``esc``.
        """
        super().__init__()
        self.control_group = control_group

        self.initial_snd = initial_snd

        # Keyword arguments for ``esc``; also logged as hyperparameters.
        self.esc_params = {
            "sampling_period": sampling_period,
            "disturbance_frequency": dither_frequency_rad_s,
            "disturbance_magnitude": dither_magnitude,
            "integrator_gain": integral_gain,
            "initial_search_value": initial_snd,
            "high_pass_cutoff_frequency": high_pass_cutoff_rad_s,
            "low_pass_cutoff_frequency": low_pass_cutoff_rad_s,
            "use_adapter": use_adapter,
        }
        # Both stay None until on_setup finds a compatible model.
        self.model = None
        self.controller = None

    def on_setup(self):
        """Log hyperparameters and create the controller for a compatible model."""
        hparams = {
            "control_group": self.control_group,
            **self.esc_params
        }
        self.experiment.logger.log_hparams(**hparams)

        if self.control_group not in self.experiment.group_policies:
            print(f"\nWARNING: Controller group '{self.control_group}' not found. Disabling controller.\n")
            return

        policy = self.experiment.group_policies[self.control_group]
        self.model = get_het_model(policy)

        if isinstance(self.model, HetControlMlpEmpirical):
            print(f"\n✅ SUCCESS: Extremum Seeking Controller initialized for group '{self.control_group}'.")
            self.controller = esc(**self.esc_params)
            self.model.desired_snd[:] = float(self.initial_snd)
        else:
            print(f"\nWARNING: A compatible model was not found for group '{self.control_group}'. Disabling controller.\n")
            self.model = None

    def on_evaluation_end(self, rollouts: List[TensorDictBase]):
        """Step the controller with the evaluation cost and log the result."""
        if self.model is None or self.controller is None:
            return
        logs_to_push = {}

        # 1. One summed reward per rollout; a rollout without the reward key
        #    contributes 0.
        reward_key = ('next', self.control_group, 'reward')
        episode_rewards = []
        with torch.no_grad():
            for r in rollouts:
                total_reward = r.get(reward_key).sum().item() if reward_key in r.keys(include_nested=True) else 0
                episode_rewards.append(total_reward)

        if not episode_rewards:
            print("\nWARNING: No episode rewards found. Cannot update controller.\n")
            self.experiment.logger.log(logs_to_push, step=self.experiment.n_iters_performed)
            return

        reward_mean = np.mean(episode_rewards)
        # The controller minimizes, so maximizing reward means cost = -reward.
        cost = -reward_mean

        # 2. Step the controller.
        (
            uk,
            hpf_out,
            lpf_out,
            m2_sqrt,
            gradient,
            setpoint
        ) = self.controller.update(cost)

        # 3. Write the new target, clamped at 0.
        previous_snd = self.model.desired_snd.item()
        self.model.desired_snd[:] = torch.clamp(torch.tensor(uk), min=0.0)
        applied_snd = self.model.desired_snd.item()

        # Report the change that was actually applied: when ``uk`` is negative
        # the stored value is 0, so ``uk - previous_snd`` would misstate it.
        update_step = applied_snd - previous_snd

        print(f"[ESC] Updated SND: {applied_snd} "
              f"(Reward: {reward_mean:.3f}, Update Step: {update_step:.4f})")

        # 4. Logging ("esc/diversity_output" stays the raw, unclamped output).
        logs_to_push.update({
            "esc/mean_reward": reward_mean,
            "esc/cost": cost,
            "esc/diversity_output": uk,
            "esc/diversity_setpoint": setpoint,
            "esc/gradient_estimate": gradient,
            "esc/hpf_output": hpf_out,
            "esc/lpf_output": lpf_out,
            "esc/m2_sqrt": m2_sqrt,
            "esc/update_step": update_step
        })
        self.experiment.logger.log(logs_to_push, step=self.experiment.n_iters_performed)


## Env Setup
---

In [None]:
# 1. EXPERIMENT LOGIC

def setup(task_name):
    """Register the custom model config and, for navigation, the render hook."""
    extra_model_configs = {"hetcontrolmlpempirical": HetControlMlpEmpiricalConfig}
    benchmarl.models.model_config_registry.update(extra_model_configs)

    if task_name != "vmas/navigation":
        return
    # Navigation case study: install the custom render callback on the task class.
    VmasTask.render_callback = render_callback

def get_experiment(cfg: DictConfig) -> Experiment:
    """Assemble the BenchMARL ``Experiment`` described by the Hydra config.

    The task and algorithm names are read from Hydra's runtime choices; the
    callback list is fixed here, with ``TagCurriculum`` appended only for the
    ``vmas/simple_tag`` task.
    """
    choices = HydraConfig.get().runtime.choices
    task_name, algorithm_name = choices.task, choices.algorithm

    setup(task_name)

    print(f"\nAlgorithm: {algorithm_name}, Task: {task_name}")

    algorithm_config = load_algorithm_config_from_hydra(cfg.algorithm)
    experiment_config = load_experiment_config_from_hydra(cfg.experiment)
    task_config = load_task_config_from_hydra(cfg.task, task_name)
    critic_model_config = load_model_config_from_hydra(cfg.critic_model)
    model_config = load_model_config_from_hydra(cfg.model)

    uses_probabilistic_model = isinstance(
        algorithm_config, (MappoConfig, IppoConfig, MasacConfig, IsacConfig)
    )
    model_config.probabilistic = uses_probabilistic_model
    if uses_probabilistic_model:
        # The model takes over the std-dev scaling, so the algorithm's own
        # mapping is handed to the model and replaced by "relu".
        model_config.scale_mapping = algorithm_config.scale_mapping
        algorithm_config.scale_mapping = "relu"

    callbacks = [
        SndCallback(),
        ExtremumSeekingController(
            control_group="agents",
            initial_snd=0.3,
            dither_magnitude=0.2,
            dither_frequency_rad_s=1.0,
            integral_gain=-0.01,
            high_pass_cutoff_rad_s=1.0,
            low_pass_cutoff_rad_s=1.0,
            sampling_period=1.0,
        ),
        SNDVisualizerCallback(),
        NormLoggerCallback(),
        ActionSpaceLoss(
            use_action_loss=cfg.use_action_loss, action_loss_lr=cfg.action_loss_lr
        ),
    ]
    if task_name == "vmas/simple_tag":
        callbacks.append(
            TagCurriculum(
                cfg.simple_tag_freeze_policy_after_frames,
                cfg.simple_tag_freeze_policy,
            )
        )

    return Experiment(
        task=task_config,
        algorithm_config=algorithm_config,
        model_config=model_config,
        critic_model_config=critic_model_config,
        seed=cfg.seed,
        config=experiment_config,
        callbacks=callbacks,
    )

## Runner Code
---

In [None]:
# Hydra config directory, the config file to load from it (without ".yaml"),
# and the folder that receives the checkpoints.
ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"  # 'navigation_ippo.yaml' must exist in ABS_CONFIG_PATH
SAVE_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/model_checkpoints/navigation_ippo_esc/"

# Values forwarded to Hydra as command-line overrides below.
save_interval = 600_000
desired_snd = 1.0
max_frame = 6_000_000

if not os.path.exists(SAVE_PATH):
    print(f"Creating missing directory: {SAVE_PATH}")
    os.makedirs(SAVE_PATH, exist_ok=True)

# Drop any Hydra state left over from a previous run of this cell.
GlobalHydra.instance().clear()

# hydra.main parses sys.argv, so the overrides are injected as fake CLI args.
_hydra_overrides = [
    f"model.desired_snd={desired_snd}",
    f"experiment.max_n_frames={max_frame}",
    f"experiment.checkpoint_interval={save_interval}",
    f"experiment.save_folder={SAVE_PATH}",
]
sys.argv = ["dummy.py", *_hydra_overrides]

# 3. Define the Hydra wrapper
@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def hydra_experiment(cfg: DictConfig) -> None:
    """Hydra entry point: build the experiment from ``cfg`` and run it."""
    print(f"Config loaded from: {ABS_CONFIG_PATH}")

    # A WandB run left open by an earlier execution is closed first.
    leftover_run = wandb.run
    if leftover_run is not None:
        print("Finishing previous WandB run...")
        wandb.finish()

    print(f"Running with SND: {cfg.model.desired_snd}")

    get_experiment(cfg=cfg).run()
    wandb.finish()

# 4. Execute safely
if __name__ == "__main__":
    try:
        hydra_experiment()
    except SystemExit as exit_signal:
        # SystemExit is caught so it does not propagate out of the notebook
        # cell. Only a None/0 exit code means success; reporting every exit as
        # success would mislabel a failed job that exited with a non-zero code.
        if exit_signal.code in (None, 0):
            print("Experiment finished successfully.")
        else:
            print(f"Experiment failed (exit code: {exit_signal.code}).")
    except Exception as e:
        print(f"An error occurred: {e}")

Config loaded from: /home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf
Running with SND: 1.0

Algorithm: ippo, Task: vmas/navigation


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
[34m[1mwandb[0m: Currently logged in as: [33msvarp[0m ([33msvarp-university-of-massachusetts-lowell[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin






Experiment was closed gracefully




[34m[1mwandb[0m: [32m[41mERROR[0m Problem finishing run
Error executing job with overrides: ['model.desired_snd=1.0', 'experiment.max_n_frames=6000000', 'experiment.checkpoint_interval=600000', 'experiment.save_folder=/home/grad/doc/2027/spatel2/AD2C_testBed/model_checkpoints/navigation_ippo_esc/']


Experiment finished successfully.


Traceback (most recent call last):
  File "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/BenchMARL/benchmarl/experiment/experiment.py", line 519, in run
    self._collection_loop()
  File "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/BenchMARL/benchmarl/experiment/experiment.py", line 567, in _collection_loop
    training_tds.append(self._optimizer_loop(group))
  File "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/BenchMARL/benchmarl/experiment/experiment.py", line 658, in _optimizer_loop
    optimizer.step()
  File "/home/grad/doc/2027/spatel2/miniconda3/envs/ad2c/lib/python3.9/site-packages/torch/optim/optimizer.py", line 140, in wrapper
    out = func(*args, **kwargs)
  File "/home/grad/doc/2027/spatel2/miniconda3/envs/ad2c/lib/python3.9/site-packages/torch/optim/optimizer.py", line 23, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/grad/doc/2027/spatel2/miniconda3/envs/ad2c/lib/python3.9/site-packages/torch/optim/adam.py", line 234, in step
    adam(params_with_g