## Imports

---

In [1]:
import sys
import os
import hydra
import wandb
import sys
import hydra
import time

from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf

# Benchmarl & Project Imports
import benchmarl.models
from benchmarl.algorithms import *
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment
from benchmarl.hydra_config import (
    load_algorithm_config_from_hydra,
    load_experiment_config_from_hydra,
    load_task_config_from_hydra,
    load_model_config_from_hydra,
)

# Custom Callbacks
from het_control.callback import *
from het_control.environments.vmas import render_callback
from het_control.models.het_control_mlp_empirical import HetControlMlpEmpiricalConfig
from het_control.callbacks.esc_callback import ExtremumSeekingController
from het_control.callbacks.sndESLogger import TrajectorySNDLoggerCallback

import numpy as np
import torch
import matplotlib.pyplot as plt
import networkx as nx
import wandb
from tensordict import TensorDict
from typing import List
from benchmarl.experiment.callback import Callback



## SND Visualizers
---

In [2]:
class SNDHeatmapVisualizer:
    """Renders a pairwise-distance (SND) matrix as an annotated heatmap.

    The finished figure is wrapped in a ``wandb.Image`` and returned in a
    one-entry dict keyed by ``key_name``.
    """

    def __init__(self, key_name="Visuals/SND_Heatmap"):
        # Key under which the image is returned by generate().
        self.key_name = key_name

    def generate(self, snd_matrix, step_count):
        """Build the heatmap for one distance matrix.

        Args:
            snd_matrix: Square array-like of pairwise distances, one row and
                one column per agent. The input is copied, never modified.
            step_count: Value printed after "Step" in the figure title.

        Returns:
            ``{key_name: wandb.Image}``, or ``{}`` when the matrix describes
            fewer than two agents (no pair exists, so there is no SND value to
            report; this matches the bar-chart and graph visualizers).
        """
        # Private float copy so the caller's data stays untouched and plain
        # nested lists are accepted as well as arrays.
        snd_matrix = np.array(snd_matrix, dtype=float)
        n_agents = snd_matrix.shape[0]

        # Without at least one pair the mean below would be taken over an
        # empty selection and evaluate to NaN.
        if n_agents < 2:
            return {}

        # --- Fix diagonal: self-distance is shown as exactly zero ---
        np.fill_diagonal(snd_matrix, 0.0)

        # --- Enforce symmetry by averaging with the transpose ---
        snd_matrix = (snd_matrix + snd_matrix.T) / 2.0

        agent_labels = [f"Agent {i+1}" for i in range(n_agents)]

        # --- SND = mean over the strict upper triangle (each pair i < j once) ---
        upper = np.triu_indices(n_agents, k=1)
        snd_value = float(np.mean(snd_matrix[upper]))

        fig, ax = plt.subplots(figsize=(6, 5))

        im = ax.imshow(
            snd_matrix,
            cmap="viridis",
            interpolation="nearest",
            vmin=0, vmax=3
        )

        ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}")

        ticks = np.arange(n_agents)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(agent_labels)
        ax.set_yticklabels(agent_labels)
        plt.setp(ax.get_xticklabels(), rotation=30, ha="right")

        fig.colorbar(im, ax=ax, label="Distance")

        # Cell labels: white text for values below 1.0, black otherwise.
        for (row, col), val in np.ndenumerate(snd_matrix):
            text_color = "white" if val < 1.0 else "black"
            ax.text(
                col, row, f"{val:.2f}",
                ha="center", va="center",
                color=text_color,
                fontsize=9, fontweight="bold"
            )

        plt.tight_layout()
        img = wandb.Image(fig)
        # Close explicitly so repeated evaluations do not accumulate open figures.
        plt.close(fig)
        return {self.key_name: img}


class SNDBarChartVisualizer:
    """Plots the distance of every unordered agent pair as one bar."""

    def __init__(self, key_name="Visuals/SND_BarChart"):
        # Key under which the image is returned by generate().
        self.key_name = key_name

    def generate(self, snd_matrix, step_count):
        """Build the bar chart for one distance matrix.

        Args:
            snd_matrix: Square NumPy array of pairwise distances; it is copied
                before being cleaned, so the caller's array is not modified.
            step_count: Value printed after "Step" in the figure title.

        Returns:
            ``{key_name: wandb.Image}``, or ``{}`` when no agent pair exists.
        """
        agent_count = snd_matrix.shape[0]

        # Clean a private copy: zero the diagonal, then average with the
        # transpose so entry (i, j) equals entry (j, i).
        cleaned = snd_matrix.copy()
        np.fill_diagonal(cleaned, 0.0)
        cleaned = (cleaned + cleaned.T) / 2.0

        # One label and one bar height per pair with first index < second index.
        labels = []
        heights = []
        for first in range(agent_count):
            for second in range(first + 1, agent_count):
                labels.append(f"A{first+1}-A{second+1}")
                heights.append(float(cleaned[first, second]))

        if not labels:
            return {}

        # SND is the mean of the pairwise distances.
        mean_distance = float(np.mean(heights))

        fig, ax = plt.subplots(figsize=(8, 5))
        bar_container = ax.bar(labels, heights, color="teal")

        ax.set_title(f"SND: {mean_distance:.3f}  –  Step {step_count}")
        ax.set_ylabel("Distance")
        ax.set_ylim(0, 3)
        ax.tick_params(axis="x", rotation=45)

        # Print each value just above its bar.
        ax.bar_label(bar_container, fmt="%.2f", padding=3)

        plt.tight_layout()
        image = wandb.Image(fig)
        plt.close(fig)
        return {self.key_name: image}


class SNDGraphVisualizer:
    """Draws agents as graph nodes joined by edges colored and labeled by distance."""

    def __init__(self, key_name="Visuals/SND_NetworkGraph"):
        # Key under which the image is returned by generate().
        self.key_name = key_name

    def generate(self, snd_matrix, step_count):
        """Build the network-graph figure for one distance matrix.

        Args:
            snd_matrix: Square NumPy array of pairwise distances; it is copied
                before being cleaned, so the caller's array is not modified.
            step_count: Value printed after "Step" in the figure title.

        Returns:
            ``{key_name: wandb.Image}``, or ``{}`` when no agent pair exists.
        """
        agent_count = snd_matrix.shape[0]

        # Clean a private copy: zero the diagonal, then average with the
        # transpose so entry (i, j) equals entry (j, i).
        distances = snd_matrix.copy()
        np.fill_diagonal(distances, 0.0)
        distances = (distances + distances.T) / 2.0

        # Distance of every pair with first index < second index, in row-major order.
        edge_distance = {
            (first, second): float(distances[first, second])
            for first in range(agent_count)
            for second in range(first + 1, agent_count)
        }
        if not edge_distance:
            return {}

        # SND is the mean of the pairwise distances.
        mean_distance = float(np.mean(list(edge_distance.values())))

        fig = plt.figure(figsize=(7, 7))
        graph = nx.Graph()
        for (first, second), dist in edge_distance.items():
            graph.add_edge(first, second, weight=dist)

        # Fixed seed keeps node placement identical from one call to the next.
        layout = nx.spring_layout(graph, seed=42)

        # Per-edge weights, in the graph's own edge order, drive the edge colors.
        edge_weights = [weight for _, _, weight in graph.edges(data="weight")]

        nx.draw_networkx_nodes(
            graph, layout, node_size=750, node_color='lightblue'
        )

        # Nodes are shown as A1, A2, ...
        node_names = {index: f"A{index+1}" for index in range(agent_count)}
        nx.draw_networkx_labels(
            graph, layout, labels=node_names, font_size=12, font_weight='bold'
        )

        drawn_edges = nx.draw_networkx_edges(
            graph, layout,
            edge_color=edge_weights,
            edge_cmap=plt.cm.viridis,
            width=2,
            edge_vmin=0,
            edge_vmax=3
        )

        # Numeric distance printed on each edge.
        edge_text = {pair: f"{dist:.2f}" for pair, dist in edge_distance.items()}
        nx.draw_networkx_edge_labels(
            graph,
            layout,
            edge_labels=edge_text,
            font_color='black',
            font_size=9,
            font_weight='bold'
        )

        plt.colorbar(drawn_edges, label='Distance')

        plt.title(f"SND: {mean_distance:.3f}  –  Step {step_count}", fontsize=14)
        plt.axis('off')

        image = wandb.Image(fig)
        plt.close(fig)
        return {self.key_name: image}


class SNDVisualizationManager:
    """
    Holds one instance of each SND visualizer and runs them as a group.
    """
    def __init__(self):
        heatmap = SNDHeatmapVisualizer()
        bar_chart = SNDBarChartVisualizer()
        graph = SNDGraphVisualizer()
        self.visualizers = [heatmap, bar_chart, graph]

    def generate_all(self, snd_matrix, step_count):
        """Run every visualizer and merge their result dicts into one.

        A visualizer that raises is reported with ``print`` and skipped, so
        the remaining visualizers still contribute their plots.
        """
        merged = {}
        for viz in self.visualizers:
            try:
                merged.update(viz.generate(snd_matrix, step_count))
            except Exception as err:
                print(f"Error generating {type(viz).__name__}: {err}")
        return merged


In [3]:
class SNDVisualizerCallback(Callback):
    """
    Computes the SND matrix and uses the Manager to log visualizations.

    On setup the callback picks the first policy group of the experiment and
    extracts its model with ``get_het_model``. At the end of every evaluation
    it computes pairwise behavioral distances on the first rollout, renders
    them through ``SNDVisualizationManager`` and sends the result to the
    experiment logger.
    """
    def __init__(self):
        super().__init__()
        # Both stay None until on_setup() succeeds; a None model disables the callback.
        self.control_group = None
        self.model = None
        # Manager that owns the three plot classes.
        self.viz_manager = SNDVisualizationManager()

    def on_setup(self):
        """Auto-detects the agent group and initializes the model wrapper."""
        group_policies = self.experiment.group_policies
        if not group_policies:
            print("\nWARNING: No group policies found. SND Visualizer disabled.\n")
            return

        # The first group in the mapping is the one that gets visualized.
        self.control_group = next(iter(group_policies))
        policy = group_policies[self.control_group]

        # 'get_het_model' must be available in this scope (not defined in this cell).
        self.model = get_het_model(policy)

        if self.model is None:
             print(f"\nWARNING: Could not extract HetModel for group '{self.control_group}'. Visualizer disabled.\n")

    def _get_agent_actions_for_rollout(self, rollout):
        """Run the model once per agent index on the rollout's observations.

        Returns:
            A list with one entry per agent index, each being the value stored
            under ``self.model.out_key`` by that forward pass.
        """
        obs = rollout.get((self.control_group, "observation"))
        actions = []
        for i in range(self.model.n_agents):
            # A fresh TensorDict per call, so every forward pass starts from
            # the same untouched input.
            temp_td = TensorDict(
                {(self.control_group, "observation"): obs},
                batch_size=obs.shape[:-1]
            )
            action_td = self.model._forward(temp_td, agent_index=i, compute_estimate=False)
            actions.append(action_td.get(self.model.out_key))
        return actions

    def on_evaluation_end(self, rollouts: List[TensorDict]):
        """Runs at the end of evaluation to compute SND and log plots.

        Only the first rollout is used. Nothing happens when the callback is
        disabled (no model) or when there are no rollouts.
        """
        if self.model is None or not rollouts:
            return

        with torch.no_grad():
            agent_actions = self._get_agent_actions_for_rollout(rollouts[0])

            # 'compute_behavioral_distance' must be available in this scope.
            # NOTE(review): the visualizers index the result as an
            # (n_agents, n_agents) matrix - confirm that this helper returns
            # that layout when called with just_mean=False.
            pairwise_distances = compute_behavioral_distance(agent_actions, just_mean=False)

            # Average away every leading dimension, not just one, so the
            # visualizers always receive a 2-D matrix.
            while pairwise_distances.ndim > 2:
                pairwise_distances = pairwise_distances.mean(dim=0)

            snd_matrix = pairwise_distances.cpu().numpy()

        step = self.experiment.n_iters_performed
        visual_logs = self.viz_manager.generate_all(
            snd_matrix=snd_matrix,
            step_count=step
        )

        # Skip the logger entirely when no visualizer produced an image.
        if visual_logs:
            self.experiment.logger.log(visual_logs, step=step)

## Setup

---

In [4]:
# 1. EXPERIMENT LOGIC

def setup(task_name):
    """Register the project's model config and apply task-specific hooks.

    Adds ``HetControlMlpEmpiricalConfig`` to BenchMARL's model-config registry
    under the key "hetcontrolmlpempirical". For the "vmas/navigation" task it
    also installs ``render_callback`` on ``VmasTask``.
    """
    custom_model_configs = {"hetcontrolmlpempirical": HetControlMlpEmpiricalConfig}
    benchmarl.models.model_config_registry.update(custom_model_configs)

    if task_name != "vmas/navigation":
        return

    # Render callback used by the navigation case study.
    VmasTask.render_callback = render_callback

def get_experiment(cfg: DictConfig) -> Experiment:
    """Assemble the BenchMARL ``Experiment`` described by a composed Hydra config.

    The task and algorithm names are read from Hydra's runtime choices, the
    individual config sections are converted with the ``benchmarl.hydra_config``
    loaders, and the project's callbacks are attached.

    Args:
        cfg: Composed Hydra config. This function reads ``algorithm``,
            ``experiment``, ``task``, ``critic_model``, ``model``, ``seed``,
            ``use_action_loss`` and ``action_loss_lr``, plus the two
            ``simple_tag_*`` entries when the task is "vmas/simple_tag".

    Returns:
        The constructed ``Experiment``.
    """
    runtime_choices = HydraConfig.get().runtime.choices
    task_name, algorithm_name = runtime_choices.task, runtime_choices.algorithm

    setup(task_name)

    print(f"\nAlgorithm: {algorithm_name}, Task: {task_name}")
    # To inspect the full config, uncomment:
    # print("\nLoaded config:\n")
    # print(OmegaConf.to_yaml(cfg))

    algorithm_config = load_algorithm_config_from_hydra(cfg.algorithm)
    experiment_config = load_experiment_config_from_hydra(cfg.experiment)
    task_config = load_task_config_from_hydra(cfg.task, task_name)
    critic_model_config = load_model_config_from_hydra(cfg.critic_model)
    model_config = load_model_config_from_hydra(cfg.model)

    # For these four algorithm configs the model is marked probabilistic and
    # takes over the algorithm's scale mapping.
    stochastic_algorithms = (MappoConfig, IppoConfig, MasacConfig, IsacConfig)
    is_stochastic = isinstance(algorithm_config, stochastic_algorithms)
    model_config.probabilistic = is_stochastic
    if is_stochastic:
        model_config.scale_mapping = algorithm_config.scale_mapping
        # The scaling of std_dev will be done in the model
        algorithm_config.scale_mapping = "relu"

    seed = cfg.seed

    # Callbacks shared by every task, in the order they are handed to the Experiment.
    # Currently disabled alternatives:
    #   ExtremumSeekingController(
    #       control_group="agents",
    #       initial_snd=0.3,
    #       dither_magnitude=0.1,
    #       dither_frequency_rad_s=1.0,
    #       integral_gain=-0.001,
    #       high_pass_cutoff_rad_s=1.0,
    #       low_pass_cutoff_rad_s=1.0,
    #       sampling_period=1.0,
    #   )                                                    (after SndCallback)
    #   TrajectorySNDLoggerCallback(control_group="agents")  (after SNDVisualizerCallback)
    callbacks = [
        SndCallback(),
        SNDVisualizerCallback(),
        NormLoggerCallback(),
        ActionSpaceLoss(
            use_action_loss=cfg.use_action_loss, action_loss_lr=cfg.action_loss_lr
        ),
    ]

    # The tag task gets one additional curriculum callback at the end.
    if task_name == "vmas/simple_tag":
        callbacks.append(
            TagCurriculum(
                cfg.simple_tag_freeze_policy_after_frames,
                cfg.simple_tag_freeze_policy,
            )
        )

    return Experiment(
        task=task_config,
        algorithm_config=algorithm_config,
        model_config=model_config,
        critic_model_config=critic_model_config,
        seed=seed,
        config=experiment_config,
        callbacks=callbacks,
    )

## Monkey Patch

In [5]:

# # --- STEP 1: Define Unique ID ---
# unique_id = f"AD2C_Eval_{int(time.time())}"

# # --- STEP 2: The "Safe" Monkey Patch ---
# # We store the original function inside the wandb module itself.
# # This ensures that even if you re-run the cell, we never lose the real function.

# if not hasattr(wandb, "_custom_orig_init"):
#     print("Saving original WandB init function...")
#     wandb._custom_orig_init = wandb.init

# def forced_wandb_init(*args, **kwargs):
#     print(f"\n--- INTERCEPTING WANDB INIT ---")
    
#     # Force the new ID and Name
#     kwargs['id'] = unique_id
#     kwargs['name'] = unique_id
    
#     # Force "New Run" behavior
#     kwargs['resume'] = "allow" 
#     kwargs['reinit'] = True
    
#     print(f"Forced ID: {unique_id}")
#     print(f"-------------------------------\n")
    
#     # We always call the SAVED original function, not the current one
#     return wandb._custom_orig_init(*args, **kwargs)

# # Apply the patch
# wandb.init = forced_wandb_init

# # --- STEP 3: Setup Environment ---
# os.environ["WANDB_MODE"] = "online"
# os.environ["WANDB_RUN_ID"] = unique_id
# os.environ["WANDB_NAME"] = unique_id
# os.environ["WANDB_INIT_TIMEOUT"] = "300"
# os.environ["WANDB_SILENT"] = "true"

# # Setup Paths
# ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
# CONFIG_NAME = "navigation_ippo"
# CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd00.pt"

# GlobalHydra.instance().clear()

# sys.argv = [
#     "eval_script.py",
#     f"experiment.restore_file={CHECKPOINT_PATH}",
#     "experiment.evaluation_episodes=10",
#     "experiment.render=True",
#     "experiment.evaluation_deterministic_actions=True",
#     "experiment.save_folder=null",
#     "model.desired_snd=0.3",
#     f"+experiment.name={unique_id}",
#     f"+logger.id={unique_id}"
# ]

# @hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
# def eval_experiment(cfg: DictConfig) -> None:
#     if wandb.run is not None:
#         wandb.finish()

#     OmegaConf.set_struct(cfg, False)
#     cfg.logger.id = unique_id
    
#     print(f"Loading model from: {cfg.experiment.restore_file}")
    
#     experiment = get_experiment(cfg=cfg)
    
#     print("Model loaded. Starting Evaluation...")
#     experiment._evaluation_loop()
#     print("Evaluation Complete.")
    
#     experiment.close()
#     wandb.finish()

# if __name__ == "__main__":
#     try:
#         eval_experiment()
#     except SystemExit:
#         pass
#     except Exception as e:
#         print(f"An error occurred: {e}")
#         if wandb.run is not None:
#             wandb.finish()

## Proper init Fix

---

In [None]:
# Keep a reference to the genuine wandb.init on the wandb module itself.
# The hasattr guard means re-running this cell never stores the wrapper as
# the "original".
if not hasattr(wandb, "_custom_orig_init"):
    print("Saving original WandB init function...")
    wandb._custom_orig_init = wandb.init


def forced_wandb_init(*args, **kwargs):
    """Replacement for ``wandb.init`` that pins the run id and name.

    Whatever the caller passed, ``id`` and ``name`` are overwritten with the
    module-level ``unique_id`` (looked up at call time), ``resume`` is set to
    "allow" and ``reinit`` to True. The call is then forwarded to the saved
    original ``wandb.init`` and its return value is passed back.
    """
    print("\n--- INTERCEPTING WANDB INIT ---")

    # Overwrite identity and resume behavior in one step.
    kwargs.update(
        id=unique_id,
        name=unique_id,
        resume="allow",
        reinit=True,
    )

    print(f"Forced ID: {unique_id}")
    print("-------------------------------\n")

    # Always delegate to the SAVED original, never to whatever wandb.init is now.
    original_init = wandb._custom_orig_init
    return original_init(*args, **kwargs)


# Install the wrapper.
wandb.init = forced_wandb_init

# --- STEP 1: Run identity ---
# Second-resolution timestamp; a new id is produced each time this cell runs.
unique_id = f"AD2C_Eval_{int(time.time())}"

# --- STEP 2: Point WandB at this exact run through its environment variables ---
os.environ.update(
    {
        "WANDB_MODE": "online",
        "WANDB_RUN_ID": unique_id,
        "WANDB_NAME": unique_id,
        "WANDB_INIT_TIMEOUT": "300",
        "WANDB_SILENT": "true",
    }
)

# wandb.init() is deliberately not called by hand here; the patched wandb.init
# applies `unique_id` whenever it is invoked.

# --- STEP 3: Hydra setup ---
ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"
CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd06.pt"

# Drop any Hydra state left over from an earlier run of this cell.
GlobalHydra.instance().clear()

# hydra.main reads its overrides from sys.argv, so a command line is fabricated:
# a placeholder program name followed by the overrides.
_hydra_overrides = [
    f"experiment.restore_file={CHECKPOINT_PATH}",
    "experiment.evaluation_episodes=10",
    "experiment.render=True",
    "experiment.evaluation_deterministic_actions=True",
    "experiment.save_folder=null",
    "model.desired_snd=0.3",
    f"+experiment.name={unique_id}",
    f"+logger.id={unique_id}",
]
sys.argv = ["eval_script.py", *_hydra_overrides]

@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def eval_experiment(cfg: DictConfig) -> None:
    """Restore the experiment described by ``cfg`` and run one evaluation pass.

    The experiment is closed and the WandB run finished even when building the
    experiment or evaluating it raises, so a failed run does not leave
    resources or an open WandB run behind.

    Args:
        cfg: Config composed by Hydra from ``CONFIG_NAME`` plus the overrides
            placed in ``sys.argv``.
    """
    # Struct mode is switched off so that the logger id can be assigned below.
    OmegaConf.set_struct(cfg, False)
    cfg.logger.id = unique_id

    print(f"Loading model from: {cfg.experiment.restore_file}")

    experiment = None
    try:
        experiment = get_experiment(cfg=cfg)

        print("Model loaded. Starting Evaluation...")
        experiment._evaluation_loop()
        print("Evaluation Complete.")
    finally:
        # Cleanup runs on success and on failure alike. The nested `finally`
        # guarantees wandb.finish() even if closing the experiment raises.
        try:
            if experiment is not None:
                experiment.close()
        finally:
            wandb.finish()

if __name__ == "__main__":
    try:
        eval_experiment()
    except SystemExit as exit_signal:
        # A Hydra entry point can end by raising SystemExit. That is caught
        # here so the exit request does not propagate into the notebook, but
        # a non-zero status is reported rather than silently discarded.
        if exit_signal.code not in (None, 0):
            print(f"Evaluation exited with status {exit_signal.code}; see the error output above.")
            wandb.finish()
    except Exception as e:
        print(f"An error occurred: {e}")
        # Close the WandB run so a failed evaluation does not leave it open.
        wandb.finish()


Saving original WandB init function...

--- INTERCEPTING WANDB INIT ---
Forced ID: AD2C_Eval_1764394288
-------------------------------

Initialized WandB with run id: AD2C_Eval_1764394288
Loading model from: /home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd06.pt

Algorithm: ippo, Task: vmas/navigation


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.



--- INTERCEPTING WANDB INIT ---
Forced ID: AD2C_Eval_1764394288
-------------------------------

Model loaded. Starting Evaluation...




Evaluation Complete.
