# AD2C : Test Bed

This Jupter Notebook aims to Experiment with different Component of AD2C framework Like task, ESC, Callback and Loggers. This will also include advance experimental setup to test the trainied model. 


----


## Imports

In [5]:
import sys
import os
import hydra
import wandb
import sys
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf

# Benchmarl & Project Imports
import benchmarl.models
from benchmarl.algorithms import *
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment
from benchmarl.hydra_config import (
    load_algorithm_config_from_hydra,
    load_experiment_config_from_hydra,
    load_task_config_from_hydra,
    load_model_config_from_hydra,
)

# Custom Callbacks
from het_control.callback import *
from het_control.environments.vmas import render_callback
from het_control.models.het_control_mlp_empirical import HetControlMlpEmpiricalConfig
from het_control.callbacks.esc_callback import ExtremumSeekingController
from het_control.callbacks.sndESLogger import TrajectorySNDLoggerCallback

import numpy as np
import torch
import matplotlib.pyplot as plt
import networkx as nx
import wandb
from tensordict import TensorDict
from typing import List
from benchmarl.experiment.callback import Callback

## SND Ploting Callback
This allow us to display plots for the snd. from the eval run. 

---

In [6]:
def generate_snd_visualizations(snd_matrix, n_agents, step_count):
    """
    Generates 3 matplotlib figures (Heatmap, Bar, Graph) for a given SND matrix.
    Returns a dictionary of wandb.Image objects.
    Assumes snd_matrix is an N x N Symmetric Distance Matrix.
    """
    plots = {}
    
    # Define pairs (Upper triangle only, since matrix is symmetric: 1-2 is same as 2-1)
    pairs = [(i, j) for i in range(n_agents) for j in range(i + 1, n_agents)]
    
    # ==========================================
    # 1. HEATMAP with Cell Values
    # ==========================================
    fig_heat, ax_heat = plt.subplots(figsize=(6, 5))
    im = ax_heat.imshow(snd_matrix, cmap='viridis', interpolation='nearest', vmin=0, vmax=3)
    
    ax_heat.set_title(f'SND Matrix (Heatmap) - Step {step_count}')
    ax_heat.set_xlabel('Agent Index')
    ax_heat.set_ylabel('Agent Index')
    
    # Set ticks to be integers (Agent 0, Agent 1...)
    ax_heat.set_xticks(np.arange(n_agents))
    ax_heat.set_yticks(np.arange(n_agents))
    
    fig_heat.colorbar(im, ax=ax_heat, label='Distance')
    
    # Loop over data dimensions and create text annotations.
    for i in range(n_agents):
        for j in range(n_agents):
            val = snd_matrix[i, j]
            
            # Text color logic: White for dark background (low values), Black for light (high values)
            # Viridis: Low values (purple) -> White text. High values (yellow) -> Black text.
            # Scale is 0-3. Midpoint roughly 1.5.
            text_color = "white" if val < 1.0 else "black"
            
            # Print value centered in the cell
            ax_heat.text(j, i, f"{val:.2f}",
                         ha="center", va="center", color=text_color, 
                         fontsize=8, fontweight='bold')

    plt.tight_layout()
    plots["Visuals/SND_Heatmap"] = wandb.Image(fig_heat)
    plt.close(fig_heat)

    # ==========================================
    # 2. BAR CHART (Pairwise Values)
    # ==========================================
    if len(pairs) > 0:
        # Extract values for the unique pairs (upper triangle)
        pair_values = [snd_matrix[p[0], p[1]] for p in pairs]
        pair_labels = [f"A{p[0]}-A{p[1]}" for p in pairs]
        
        fig_bar, ax_bar = plt.subplots(figsize=(8, 5))
        bars = ax_bar.bar(pair_labels, pair_values, color='teal')
        
        ax_bar.set_title(f'Pairwise Distances - Step {step_count}')
        ax_bar.set_ylabel('Distance')
        ax_bar.set_ylim(0, 3) 
        ax_bar.tick_params(axis='x', rotation=45)
        
        # Add values on top of bars
        ax_bar.bar_label(bars, fmt='%.2f', padding=3)

        plt.tight_layout()
        plots["Visuals/SND_BarChart"] = wandb.Image(fig_bar)
        plt.close(fig_bar)

    # ==========================================
    # 3. NETWORK GRAPH (Topology)
    # ==========================================
    if len(pairs) > 0:
        fig_graph = plt.figure(figsize=(7, 7))
        G = nx.Graph()
        
        # Add edges for unique pairs
        for u, v in pairs:
            dist = snd_matrix[u, v]
            G.add_edge(u, v, weight=dist)

        pos = nx.spring_layout(G, seed=42)
        weights = [G[u][v]['weight'] for u, v in G.edges()]
        
        # Draw Nodes
        nx.draw_networkx_nodes(G, pos, node_size=600, node_color='lightblue')
        nx.draw_networkx_labels(G, pos, font_weight='bold')
        
        # Draw Edges
        edges = nx.draw_networkx_edges(G, pos, 
                               edge_color=weights, 
                               edge_cmap=plt.cm.viridis, 
                               width=2,
                               edge_vmin=0,
                               edge_vmax=3)
        
        # Draw Edge Labels (The distance values on the lines)
        edge_labels = {
            (u, v): f"{d['weight']:.2f}" 
            for u, v, d in G.edges(data=True)
        }
        nx.draw_networkx_edge_labels(
            G, pos, 
            edge_labels=edge_labels, 
            font_color='black', 
            font_size=8,
            font_weight='bold'
        )
        
        plt.colorbar(edges, label='Distance')
        plt.title(f'Interaction Graph - Step {step_count}')
        plt.axis('off')
        plots["Visuals/SND_NetworkGraph"] = wandb.Image(fig_graph)
        plt.close(fig_graph)

    return plots

In [7]:
class SNDVisualizerCallback(Callback):
    """
    A visualization-only callback that computes the SND (Behavioral Distance) matrix
    at evaluation time and logs Heatmap, Bar Chart, and Graph visualizations to WandB.
    """
    def __init__(self):
        super().__init__()
        self.control_group = None
        self.model = None

    def on_setup(self):
        """Auto-detects the agent group and initializes the model wrapper."""
        if not self.experiment.group_policies:
            print("\nWARNING: No group policies found. SND Visualizer disabled.\n")
            return

        # Auto-detect: Simply grab the first available control group
        self.control_group = list(self.experiment.group_policies.keys())[0]
        
        policy = self.experiment.group_policies[self.control_group]
        
        # We assume 'get_het_model' is available in your scope
        self.model = get_het_model(policy)

        if self.model is None:
             print(f"\nWARNING: Could not extract HetModel for group '{self.control_group}'. Visualizer disabled.\n")

    def _get_agent_actions_for_rollout(self, rollout):
        """Helper to run the forward pass and get actions for SND computation."""
        obs = rollout.get((self.control_group, "observation"))
        actions = []
        for i in range(self.model.n_agents):
            temp_td = TensorDict(
                {(self.control_group, "observation"): obs},
                batch_size=obs.shape[:-1]
            )
            # Ensure _forward exists and returns a TensorDict with the output key
            action_td = self.model._forward(temp_td, agent_index=i, compute_estimate=False)
            actions.append(action_td.get(self.model.out_key))
        return actions

    def on_evaluation_end(self, rollouts: List[TensorDict]):
        """Runs at the end of evaluation to compute SND and log plots."""
        if self.model is None:
            return

        logs_to_push = {}
        first_rollout_snd_matrix = None

        with torch.no_grad():
            for i, r in enumerate(rollouts):
                # We only need the matrix from the first rollout for clean visualization
                if i > 0: 
                    break

                agent_actions = self._get_agent_actions_for_rollout(r)
                
                # Compute behavioral distance (Assumed to be available in scope)
                # Must return N x N matrix
                pairwise_distances_tensor = compute_behavioral_distance(agent_actions, just_mean=False)
                
                # If the function returns (Time x N x N), average over Time to get (N x N)
                if pairwise_distances_tensor.ndim > 2:
                    pairwise_distances_tensor = pairwise_distances_tensor.mean(dim=0)

                first_rollout_snd_matrix = pairwise_distances_tensor.cpu().numpy()

        # Generate and Log Visualizations
        if first_rollout_snd_matrix is not None:
            visual_logs = generate_snd_visualizations(
                snd_matrix=first_rollout_snd_matrix, 
                n_agents=self.model.n_agents,
                step_count=self.experiment.n_iters_performed
            )
            logs_to_push.update(visual_logs)
            
            # Update the logger
            self.experiment.logger.log(logs_to_push, step=self.experiment.n_iters_performed)

## Env Setup

This code block compiles all the different section of MARl env together to Run the experiment

---

In [12]:
# Env clean up for W&B

try:
    if wandb.run is not None:
        print("⚠️ Closing lingering W&B run from previous execution...")
        wandb.finish()
except Exception as e:
    print(f"W&B Cleanup Warning: {e}")

# Optional: Reset W&B settings to ensure it spawns a fresh process
os.environ["WANDB_START_METHOD"] = "thread"

In [9]:
def setup(task_name):
    benchmarl.models.model_config_registry.update(
        {
            "hetcontrolmlpempirical": HetControlMlpEmpiricalConfig,
        }
    )
    if task_name == "vmas/navigation":
        # Set the render callback for the navigation case study
        VmasTask.render_callback = render_callback

def get_experiment(cfg: DictConfig) -> Experiment:
    hydra_choices = HydraConfig.get().runtime.choices
    task_name = hydra_choices.task
    algorithm_name = hydra_choices.algorithm

    setup(task_name)

    print(f"\nAlgorithm: {algorithm_name}, Task: {task_name}")
    # print("\nLoaded config:\n") # Optional: Commented out to reduce clutter
    # print(OmegaConf.to_yaml(cfg))

    algorithm_config = load_algorithm_config_from_hydra(cfg.algorithm)
    experiment_config = load_experiment_config_from_hydra(cfg.experiment)
    task_config = load_task_config_from_hydra(cfg.task, task_name)
    critic_model_config = load_model_config_from_hydra(cfg.critic_model)
    model_config = load_model_config_from_hydra(cfg.model)

    if isinstance(algorithm_config, (MappoConfig, IppoConfig, MasacConfig, IsacConfig)):
        model_config.probabilistic = True
        model_config.scale_mapping = algorithm_config.scale_mapping
        algorithm_config.scale_mapping = (
            "relu"  # The scaling of std_dev will be done in the model
        )
    else:
        model_config.probabilistic = False

    experiment = Experiment(
        task=task_config,
        algorithm_config=algorithm_config,
        model_config=model_config,
        critic_model_config=critic_model_config,
        seed=cfg.seed,
        config=experiment_config,
        callbacks=[
            SndCallback(),
            SNDVisualizerCallback(),
            # ExtremumSeekingController(
            #             control_group="agents",
            #             # initial_snd=0.0,
            #             dither_magnitude=0.1,
            #             dither_frequency_rad_s=1.0,
            #             integral_gain=-0.01,
            #             high_pass_cutoff_rad_s=1.0,
            #             low_pass_cutoff_rad_s=1.0,
            #             sampling_period=1.0
            # ),
            # TrajectorySNDLoggerCallback(control_group="agents"),
            NormLoggerCallback(),
            ActionSpaceLoss(
                use_action_loss=cfg.use_action_loss, action_loss_lr=cfg.action_loss_lr
            ),
        ]
        + (
            [
                TagCurriculum(
                    cfg.simple_tag_freeze_policy_after_frames,
                    cfg.simple_tag_freeze_policy,
                )
            ]
            if task_name == "vmas/simple_tag"
            else []
        ),
    )
    return experiment

## Training Code

Trains the model for 200 episodes. 

In [None]:

ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"  # Make sure 'navigation_ippo.yaml' exists in the folder above!
SAVE_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/model_checkpoints/navigation_ippo_esc/"

save_interval = 600000
desired_snd = 0.0
max_frame = 6000000

if not os.path.exists(SAVE_PATH):
    print(f"Creating missing directory: {SAVE_PATH}")
    os.makedirs(SAVE_PATH, exist_ok=True)

GlobalHydra.instance().clear()

sys.argv = [
    "dummy.py",
    f"model.desired_snd={desired_snd}",
    f"experiment.max_n_frames={max_frame}",
    f"experiment.checkpoint_interval={save_interval}",
    f"experiment.save_folder={SAVE_PATH}", # FIXED: Removed space after '='
]

# 3. Define the Hydra wrapper
@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def hydra_experiment(cfg: DictConfig) -> None:
    print(f"Config loaded from: {ABS_CONFIG_PATH}")
        
    print(f"Running with SND: {cfg.model.desired_snd}")
    
    experiment = get_experiment(cfg=cfg)
    experiment.run()

# 4. Execute safely
if __name__ == "__main__":
    try:
        hydra_experiment()
    except SystemExit:
        print("Experiment finished successfully.")
    except Exception as e:
        print(f"An error occurred: {e}")

Config loaded from: /home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf
Running with SND: 0.0

Algorithm: ippo, Task: vmas/navigation






## Eval Run

Single Step eval. from the check point

In [13]:

# CONFIGURATION
ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"

# Your checkpoint path
CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/model2.pt"

# ==========================================
# EVALUATION LOGIC
# ==========================================
GlobalHydra.instance().clear()

sys.argv = [
    "eval_script.py",
    f"experiment.restore_file={CHECKPOINT_PATH}",
    "experiment.evaluation_episodes=10",
    "experiment.render=True",
    "experiment.evaluation_deterministic_actions=True",
    "experiment.save_folder=null",
    "model.desired_snd=0.3" 
]

@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def eval_experiment(cfg: DictConfig) -> None:
    print(f"Loading model from: {cfg.experiment.restore_file}")
    print(f"Initializing model with dummy SND: {cfg.model.desired_snd}")
    
    experiment = get_experiment(cfg=cfg)
    
    print("Model loaded. Starting Evaluation...")
    
    experiment._evaluation_loop()
    
    print("Evaluation Complete.")
    experiment.close()

if __name__ == "__main__":
    try:
        eval_experiment()
    except SystemExit:
        pass
    except Exception as e:
        print(f"An error occurred: {e}")

Loading model from: /home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/model2.pt
Initializing model with dummy SND: 0.3

Algorithm: ippo, Task: vmas/navigation


Model loaded. Starting Evaluation...






Evaluation Complete.


0,1
eval/agents/reward/episode_reward_max,▁
eval/agents/reward/episode_reward_mean,▁
eval/agents/reward/episode_reward_min,▁
eval/agents/snd,▁
eval/reward/episode_len_mean,▁
eval/reward/episode_reward_max,▁
eval/reward/episode_reward_mean,▁
eval/reward/episode_reward_min,▁
timers/evaluation_time,▁

0,1
collection/agents/estimated_snd,238.9796
collection/agents/info/agent_collisions,0
collection/agents/info/final_rew,0
collection/agents/info/pos_rew,0.00484
collection/agents/logits,0.05074
collection/agents/observation,0.00852
collection/agents/out_loc_norm,0
collection/agents/reward/episode_reward_max,2.16559
collection/agents/reward/episode_reward_mean,0.48618
collection/agents/reward/episode_reward_min,-1.10746


In [6]:
# Different Testing Environment Encoding

In [9]:
import sys
import hydra
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig

# ==========================================
# CONFIGURATION
# ==========================================
ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"
CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/model_checkpoints/navigation_ippo_esc/ippo_navigation_hetcontrolmlpempirical__8abb18cb_25_11_21-18_17_10/checkpoints/checkpoint_12000000.pt"

# ==========================================
# EVALUATION LOGIC
# ==========================================
GlobalHydra.instance().clear()

sys.argv = [
    "eval_script.py",
    f"experiment.restore_file={CHECKPOINT_PATH}",
    "experiment.evaluation_episodes=10",
    "experiment.render=True",
    "experiment.evaluation_deterministic_actions=True",
    "experiment.save_folder=null",
    "model.desired_snd=0.3",
    
    # --- THE CHANGE ---
    # Assuming you have 3 agents total. 
    # Setting this to 3 means all 3 agents go to the SAME goal (1 goal total).
    "task.agents_with_same_goal=1" 
]

@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def eval_experiment(cfg: DictConfig) -> None:
    print(f"Loading model from: {cfg.experiment.restore_file}")
    
    # Verification print
    print(f"Evaluation Setup: {cfg.task.n_agents} Agents, {cfg.task.agents_with_same_goal} per goal.")
    
    experiment = get_experiment(cfg=cfg)
    print("Starting Evaluation...")
    experiment._evaluation_loop()
    print("Evaluation Complete.")
    experiment.close()

if __name__ == "__main__":
    try:
        eval_experiment()
    except SystemExit:
        pass
    except Exception as e:
        print(f"An error occurred: {e}")

Loading model from: /home/grad/doc/2027/spatel2/AD2C_testBed/model_checkpoints/navigation_ippo_esc/ippo_navigation_hetcontrolmlpempirical__8abb18cb_25_11_21-18_17_10/checkpoints/checkpoint_12000000.pt
Evaluation Setup: 3 Agents, 1 per goal.

Algorithm: ippo, Task: vmas/navigation


Starting Evaluation...




Evaluation Complete.




0,1
eval/agents/reward/episode_reward_max,▁
eval/agents/reward/episode_reward_mean,▁
eval/agents/reward/episode_reward_min,▁
eval/agents/snd,▁
eval/reward/episode_len_mean,▁
eval/reward/episode_reward_max,▁
eval/reward/episode_reward_mean,▁
eval/reward/episode_reward_min,▁
timers/evaluation_time,▁

0,1
collection/agents/estimated_snd,203.30853
collection/agents/info/agent_collisions,0
collection/agents/info/final_rew,0
collection/agents/info/pos_rew,0.01827
collection/agents/logits,0.02486
collection/agents/observation,-0.00358
collection/agents/out_loc_norm,0.0
collection/agents/reward/episode_reward_max,2.48916
collection/agents/reward/episode_reward_mean,1.03933
collection/agents/reward/episode_reward_min,0.22199


In [10]:
# Train from another checkpoint with different Task. 

In [None]:

# ==========================================
# 1. CONFIGURATION
# ==========================================
ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"
CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/model_checkpoints/navigation_ippo_esc/ippo_navigation_hetcontrolmlpempirical__8abb18cb_25_11_21-18_17_10/checkpoints/checkpoint_12000000.pt"

# ==========================================
# 2. RUN LOGIC
# ==========================================
new_max_frames = 15000000 
desired_snd = 1.0

GlobalHydra.instance().clear()

sys.argv = [
    "run_script.py",
    f"model.desired_snd={desired_snd}",
    f"experiment.restore_file={CHECKPOINT_PATH}",
    f"experiment.max_n_frames={new_max_frames}",
    
    # --- TASK CONFIGURATION ---
    "task.agents_with_same_goal=1", 
    "experiment.save_folder=null"
]

@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def hydra_experiment(cfg: DictConfig) -> None:
    print(f"Resuming with SND: {cfg.model.desired_snd}")
    print(f"Agents sharing a goal: {cfg.task.agents_with_same_goal}")
    
    experiment = get_experiment(cfg=cfg)
    experiment.run()

if __name__ == "__main__":
    try:
        hydra_experiment()
    except SystemExit:
        print("Experiment finished successfully.")
    except Exception as e:
        print(f"An error occurred: {e}")



Resuming with SND: 1.0
Agents sharing a goal: 1

Algorithm: ippo, Task: vmas/navigation


[34m[1mwandb[0m: Currently logged in as: [33msvarp[0m ([33msvarp-university-of-massachusetts-lowell[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



✅ SUCCESS: Extremum Seeking Controller initialized for group 'agents'.

SUCCESS: Logger initialized for HetControlMlpEscSnd on group 'agents'.
Experiment finished successfully.


Error executing job with overrides: ['model.desired_snd=1.0', 'experiment.restore_file=/home/grad/doc/2027/spatel2/AD2C_testBed/model_checkpoints/navigation_ippo_esc/ippo_navigation_hetcontrolmlpempirical__e3667b5c_25_11_21-15_00_13/checkpoints/checkpoint_12000000.pt', 'experiment.max_n_frames=15000000', 'task.agents_with_same_goal=1', 'experiment.save_folder=null']
Traceback (most recent call last):
  File "/tmp/ipykernel_2439253/1461699133.py", line 32, in hydra_experiment
    experiment = get_experiment(cfg=cfg)
  File "/tmp/ipykernel_2439253/1144324724.py", line 39, in get_experiment
    experiment = Experiment(
  File "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/BenchMARL/benchmarl/experiment/experiment.py", line 332, in __init__
    self._load_experiment()
  File "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/BenchMARL/benchmarl/experiment/experiment.py", line 792, in _load_experiment
    loaded_dict: OrderedDict = torch.load(self.config.restore_file)
  File "/home/grad/doc/202