# AD2C : Test Bed

This Jupyter notebook experiments with different components of the AD2C framework, such as tasks, the ESC (Extremum Seeking Controller), callbacks, and loggers. It also includes an advanced experimental setup for testing the trained model.


----


## Imports

In [7]:
import sys
import os
import hydra
import wandb
import sys
import hydra
import time

from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf

# Benchmarl & Project Imports
import benchmarl.models
from benchmarl.algorithms import *
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment
from benchmarl.hydra_config import (
    load_algorithm_config_from_hydra,
    load_experiment_config_from_hydra,
    load_task_config_from_hydra,
    load_model_config_from_hydra,
)

# Custom Callbacks
from het_control.callback import *
from het_control.environments.vmas import render_callback
from het_control.models.het_control_mlp_empirical import HetControlMlpEmpiricalConfig
from het_control.callbacks.esc_callback import ExtremumSeekingController
from het_control.callbacks.sndESLogger import TrajectorySNDLoggerCallback

import numpy as np
import torch
import matplotlib.pyplot as plt
import networkx as nx
import wandb
from tensordict import TensorDict
from typing import List
from benchmarl.experiment.callback import Callback

## Wandb Fix
---

In [8]:
# Timestamp-based run identifier. forced_wandb_init looks this global up at
# call time, so re-assigning it later changes the ID used by the next run.
unique_id = f"AD2C_Eval_{int(time.time())}"

# Keep a reference to the untouched wandb.init exactly once, so that
# re-running this cell never ends up wrapping the wrapper itself.
if not hasattr(wandb, "_custom_orig_init"):
    print("Saving original WandB init function...")
    wandb._custom_orig_init = wandb.init


def forced_wandb_init(*args, **kwargs):
    """Drop-in replacement for ``wandb.init`` that pins the run identity.

    Whatever the caller passed, ``id`` and ``name`` are overwritten with the
    current module-level ``unique_id``, and ``resume="allow"`` /
    ``reinit=True`` are forced. All other arguments are forwarded unchanged
    to the original ``wandb.init`` saved in ``wandb._custom_orig_init``.

    Returns:
        Whatever the original ``wandb.init`` returns.
    """
    print("\n--- INTERCEPTING WANDB INIT ---")

    kwargs.update(
        id=unique_id,
        name=unique_id,
        resume="allow",
        reinit=True,
    )

    print(f"Forced ID: {unique_id}")
    print("-------------------------------\n")

    # Always delegate to the saved original, never to the current wandb.init
    # (which is this very function once the patch below is applied).
    return wandb._custom_orig_init(*args, **kwargs)


# Install the patch.
wandb.init = forced_wandb_init


## SND Plotting Callback
This allows us to display plots of the SND from the evaluation run.

---

In [9]:
import numpy as np
import matplotlib.pyplot as plt
import wandb
import torch
import networkx as nx  # Required for the Graph Visualizer

class SNDHeatmapVisualizer:
    """Render a pairwise-distance (SND) matrix as an annotated heatmap.

    Args:
        key_name: Dictionary key under which the rendered image is returned.
        vmin: Lower bound of the colour scale.
        vmax: Upper bound of the colour scale.
    """

    def __init__(self, key_name="Visuals/SND_Heatmap", vmin=0, vmax=3):
        self.key_name = key_name
        self.vmin = vmin
        self.vmax = vmax

    def generate(self, snd_matrix, step_count):
        """Build the heatmap image for one matrix.

        Args:
            snd_matrix: Square 2D NumPy array of pairwise distances.
            step_count: Value shown in the plot title as the step.

        Returns:
            ``{self.key_name: wandb.Image}``.
        """
        n_agents = snd_matrix.shape[0]
        agent_labels = [f"Agent {i+1}" for i in range(n_agents)]

        # The reported SND is the mean over the strict upper triangle, i.e.
        # every unordered agent pair exactly once; 0.0 if there is no pair.
        upper = np.triu_indices(n_agents, k=1)
        snd_value = float(np.mean(snd_matrix[upper])) if len(upper[0]) > 0 else 0.0

        # White text on cells in the lowest third of the colour range (dark in
        # viridis), black elsewhere. With the default 0-3 range the cutoff is
        # exactly 1.0.
        dark_cutoff = self.vmin + (self.vmax - self.vmin) / 3.0

        fig, ax = plt.subplots(figsize=(6, 5))
        try:
            im = ax.imshow(
                snd_matrix,
                cmap="viridis",
                interpolation="nearest",
                vmin=self.vmin,
                vmax=self.vmax,
            )

            ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}")

            ax.set_xticks(np.arange(n_agents))
            ax.set_yticks(np.arange(n_agents))
            ax.set_xticklabels(agent_labels)
            ax.set_yticklabels(agent_labels)
            plt.setp(ax.get_xticklabels(), rotation=30, ha="right")

            fig.colorbar(im, ax=ax, label="Distance")

            for i in range(n_agents):
                for j in range(n_agents):
                    val = snd_matrix[i, j]
                    ax.text(
                        j, i, f"{val:.2f}",
                        ha="center", va="center",
                        color="white" if val < dark_cutoff else "black",
                        fontsize=9, fontweight="bold",
                    )

            fig.tight_layout()
            return {self.key_name: wandb.Image(fig)}
        finally:
            # Close even when drawing or image conversion fails; otherwise the
            # figure stays registered in pyplot and leaks across evaluations.
            plt.close(fig)


class SNDBarChartVisualizer:
    """Render the pairwise distances of an SND matrix as a bar chart.

    Args:
        key_name: Dictionary key under which the rendered image is returned.
        y_max: Upper limit of the y axis (the lower limit is 0).
    """

    def __init__(self, key_name="Visuals/SND_BarChart", y_max=3):
        self.key_name = key_name
        self.y_max = y_max

    def generate(self, snd_matrix, step_count):
        """Build the bar-chart image for one matrix.

        Args:
            snd_matrix: Square 2D NumPy array of pairwise distances.
            step_count: Value shown in the plot title as the step.

        Returns:
            ``{self.key_name: wandb.Image}``, or an empty dict when the matrix
            contains fewer than two agents (no pair to plot).
        """
        n_agents = snd_matrix.shape[0]

        # Every unordered pair (i < j) exactly once.
        pairs = [(i, j) for i in range(n_agents) for j in range(i + 1, n_agents)]
        if not pairs:
            return {}

        pair_values = [float(snd_matrix[i, j]) for i, j in pairs]
        pair_labels = [f"A{i+1}-A{j+1}" for i, j in pairs]
        snd_value = float(np.mean(pair_values))

        fig, ax = plt.subplots(figsize=(8, 5))
        try:
            bars = ax.bar(pair_labels, pair_values, color="teal")

            ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}")
            ax.set_ylabel("Distance")
            ax.set_ylim(0, self.y_max)
            ax.tick_params(axis="x", rotation=45)

            ax.bar_label(bars, fmt="%.2f", padding=3)

            fig.tight_layout()
            return {self.key_name: wandb.Image(fig)}
        finally:
            # Close even when drawing or image conversion fails; otherwise the
            # figure stays registered in pyplot and leaks across evaluations.
            plt.close(fig)


class SNDGraphVisualizer:
    """Render an SND matrix as a fully connected graph with weighted edges.

    Args:
        key_name: Dictionary key under which the rendered image is returned.
        vmin: Lower bound of the edge colour scale.
        vmax: Upper bound of the edge colour scale.
    """

    def __init__(self, key_name="Visuals/SND_NetworkGraph", vmin=0, vmax=3):
        self.key_name = key_name
        self.vmin = vmin
        self.vmax = vmax

    def generate(self, snd_matrix, step_count):
        """Build the network-graph image for one matrix.

        Args:
            snd_matrix: Square 2D NumPy array of pairwise distances.
            step_count: Value shown in the plot title as the step.

        Returns:
            ``{self.key_name: wandb.Image}``, or an empty dict when the matrix
            contains fewer than two agents (no edge to draw).
        """
        n_agents = snd_matrix.shape[0]

        pairs = [(i, j) for i in range(n_agents) for j in range(i + 1, n_agents)]
        if not pairs:
            return {}

        snd_value = float(np.mean([float(snd_matrix[i, j]) for i, j in pairs]))

        G = nx.Graph()
        for i, j in pairs:
            G.add_edge(i, j, weight=float(snd_matrix[i, j]))

        # Fixed seed keeps the layout identical between calls.
        pos = nx.spring_layout(G, seed=42)
        weights = [G[u][v]["weight"] for u, v in G.edges()]

        fig, ax = plt.subplots(figsize=(7, 7))
        try:
            # Draw on an explicit Axes instead of pyplot's implicit "current
            # axes", so the output cannot land on an unrelated open figure.
            nx.draw_networkx_nodes(
                G, pos, ax=ax, node_size=750, node_color="lightblue"
            )

            label_mapping = {i: f"A{i+1}" for i in range(n_agents)}
            nx.draw_networkx_labels(
                G, pos, ax=ax, labels=label_mapping,
                font_size=12, font_weight="bold",
            )

            edges = nx.draw_networkx_edges(
                G, pos, ax=ax,
                edge_color=weights,
                edge_cmap=plt.cm.viridis,
                width=2,
                edge_vmin=self.vmin, edge_vmax=self.vmax,
            )

            edge_labels = {(i, j): f"{snd_matrix[i, j]:.2f}" for i, j in pairs}
            nx.draw_networkx_edge_labels(
                G, pos, ax=ax, edge_labels=edge_labels,
                font_color="black", font_size=9, font_weight="bold",
            )

            fig.colorbar(edges, ax=ax, label="Distance")
            ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}", fontsize=14)
            ax.axis("off")

            return {self.key_name: wandb.Image(fig)}
        finally:
            # Close even when drawing or image conversion fails; otherwise the
            # figure stays registered in pyplot and leaks across evaluations.
            plt.close(fig)


class SNDVisualizationManager:
    """Owns the individual SND visualizers and cleans their input centrally.

    ``generate_all`` normalises the incoming matrix once and then hands the
    same clean 2D array to every visualizer. Failures (in the cleaning step
    or in any single visualizer) are reported with ``print`` and skipped, so
    one bad plot never aborts the caller.
    """

    def __init__(self):
        self.visualizers = [
            SNDHeatmapVisualizer(),
            SNDBarChartVisualizer(),
            SNDGraphVisualizer(),
        ]

    def _prepare_matrix(self, snd_matrix):
        """Convert ``snd_matrix`` into a square, symmetric, zero-diagonal array.

        Args:
            snd_matrix: A tensor-like object exposing ``detach()``, a NumPy
                array, or anything ``np.asarray`` accepts (including a bare
                scalar).

        Returns:
            A new 2D float NumPy array; the input is never modified.

        Raises:
            ValueError: If the (reduced) input is a non-square 2D array that
                cannot be broadcast against its transpose.
        """
        # 1. Convert to NumPy. atleast_1d promotes a bare scalar to shape (1,)
        #    so that the shape logic below never sees a 0-d array.
        if hasattr(snd_matrix, "detach"):
            snd_matrix = snd_matrix.detach().cpu().numpy()
        snd_matrix = np.atleast_1d(np.asarray(snd_matrix))

        # 2. Peel leading dimensions until 2D: (1, 2, 2) -> (2, 2).
        while snd_matrix.ndim > 2:
            snd_matrix = snd_matrix[0]

        # 3. 1D input: reshape to a square when the length is a perfect
        #    square, otherwise treat it as a column vector (N, 1).
        if snd_matrix.ndim == 1:
            size = snd_matrix.shape[0]
            side = int(np.sqrt(size))
            if side * side == size:
                snd_matrix = snd_matrix.reshape(side, side)
            else:
                snd_matrix = snd_matrix[:, None]

        rows, cols = snd_matrix.shape
        if rows != cols and 1 not in (rows, cols):
            raise ValueError(
                f"Cannot build a square SND matrix from shape {snd_matrix.shape}"
            )

        # 4. Symmetrise. For (N, 1) or (1, N) input, broadcasting against the
        #    transpose also expands the array to (N, N). The expression
        #    allocates a new array, so no explicit copy is required.
        snd_matrix = (snd_matrix + snd_matrix.T) / 2.0

        # 5. The distance of an agent to itself is zero.
        if snd_matrix.size:
            np.fill_diagonal(snd_matrix, 0.0)

        return snd_matrix

    def generate_all(self, snd_matrix, step_count):
        """Run every visualizer on the cleaned matrix and merge the results.

        Args:
            snd_matrix: Raw matrix, in any form ``_prepare_matrix`` accepts.
            step_count: Step value forwarded to each visualizer.

        Returns:
            A dict merging the outputs of all visualizers that succeeded;
            empty if the matrix itself could not be prepared.
        """
        try:
            clean_matrix = self._prepare_matrix(snd_matrix)
        except (ValueError, TypeError, IndexError) as e:
            # Same best-effort policy as for the individual plots below: a
            # malformed matrix must not abort the caller.
            print(f"Error preparing SND matrix: {e}")
            return {}

        all_plots = {}
        for visualizer in self.visualizers:
            try:
                all_plots.update(visualizer.generate(clean_matrix, step_count))
            except Exception as e:
                print(f"Error generating {visualizer.__class__.__name__}: {e}")
                # The shape helps to debug the failure.
                print(f"Failed Matrix Shape: {clean_matrix.shape}")
        return all_plots

In [10]:
class SNDVisualizerCallback(Callback):
    """Experiment callback that plots the pairwise behavioural distances.

    At the end of each evaluation it recomputes every agent's action on the
    observations of the first rollout, derives the pairwise distance matrix
    and logs the plots produced by ``SNDVisualizationManager``.
    """

    def __init__(self):
        super().__init__()
        # Both are filled in by on_setup; model stays None when unavailable.
        self.control_group = None
        self.model = None
        # Holds the three plot classes.
        self.viz_manager = SNDVisualizationManager()

    def on_setup(self):
        """Pick the first policy group and extract its heterogeneous model."""
        policies = self.experiment.group_policies
        if not policies:
            print("\nWARNING: No group policies found. SND Visualizer disabled.\n")
            return

        self.control_group = next(iter(policies.keys()))

        # 'get_het_model' must be importable / available in this scope.
        self.model = get_het_model(policies[self.control_group])
        if self.model is None:
            print(f"\nWARNING: Could not extract HetModel for group '{self.control_group}'. Visualizer disabled.\n")

    def _get_agent_actions_for_rollout(self, rollout):
        """Return one model output per agent index for the rollout's observations."""
        obs_key = (self.control_group, "observation")
        observations = rollout.get(obs_key)

        per_agent_actions = []
        for agent_idx in range(self.model.n_agents):
            # A fresh TensorDict is built for every forward pass.
            agent_td = TensorDict(
                {obs_key: observations},
                batch_size=observations.shape[:-1],
            )
            out_td = self.model._forward(
                agent_td, agent_index=agent_idx, compute_estimate=False
            )
            per_agent_actions.append(out_td.get(self.model.out_key))
        return per_agent_actions

    def on_evaluation_end(self, rollouts: List[TensorDict]):
        """Compute the distance matrix of the first rollout and log its plots."""
        if self.model is None:
            return

        # Only the first rollout is visualised.
        first_rollout = next(iter(rollouts), None)
        if first_rollout is None:
            return

        with torch.no_grad():
            actions = self._get_agent_actions_for_rollout(first_rollout)

            # 'compute_behavioral_distance' must be importable / available.
            distances = compute_behavioral_distance(actions, just_mean=False)
            if distances.ndim > 2:
                # Average over the leading dimension.
                distances = distances.mean(dim=0)

            snd_matrix = distances.cpu().numpy()

        step = self.experiment.n_iters_performed
        visual_logs = self.viz_manager.generate_all(
            snd_matrix=snd_matrix,
            step_count=step,
        )
        self.experiment.logger.log(dict(visual_logs), step=step)

## Env Setup

This code block brings together the different parts of the MARL environment setup to run the experiment.

---

In [11]:
# 1. EXPERIMENT LOGIC

def setup(task_name):
    """Register the custom model config and apply task-specific hooks.

    Args:
        task_name: Hydra task choice, e.g. ``"vmas/navigation"``.
    """
    benchmarl.models.model_config_registry.update(
        hetcontrolmlpempirical=HetControlMlpEmpiricalConfig
    )

    if task_name != "vmas/navigation":
        return

    # The navigation case study uses a dedicated render callback.
    VmasTask.render_callback = render_callback

def get_experiment(cfg: DictConfig) -> Experiment:
    """Assemble a BenchMARL ``Experiment`` from a composed Hydra config.

    Args:
        cfg: Hydra config providing ``algorithm``, ``experiment``, ``task``,
            ``critic_model``, ``model``, ``seed``, ``use_action_loss`` and
            ``action_loss_lr`` (plus the ``simple_tag_*`` entries when the
            task is ``vmas/simple_tag``).

    Returns:
        The configured, not yet started, experiment.
    """
    choices = HydraConfig.get().runtime.choices
    task_name, algorithm_name = choices.task, choices.algorithm

    setup(task_name)

    print(f"\nAlgorithm: {algorithm_name}, Task: {task_name}")
    # Uncomment to dump the full config:
    # print("\nLoaded config:\n")
    # print(OmegaConf.to_yaml(cfg))

    algorithm_config = load_algorithm_config_from_hydra(cfg.algorithm)
    experiment_config = load_experiment_config_from_hydra(cfg.experiment)
    task_config = load_task_config_from_hydra(cfg.task, task_name)
    critic_model_config = load_model_config_from_hydra(cfg.critic_model)
    model_config = load_model_config_from_hydra(cfg.model)

    stochastic_algorithms = (MappoConfig, IppoConfig, MasacConfig, IsacConfig)
    is_probabilistic = isinstance(algorithm_config, stochastic_algorithms)
    model_config.probabilistic = is_probabilistic
    if is_probabilistic:
        # The scaling of std_dev is done in the model, so the model takes
        # over the algorithm's mapping and the algorithm falls back to "relu".
        model_config.scale_mapping = algorithm_config.scale_mapping
        algorithm_config.scale_mapping = "relu"

    callbacks = [
        SndCallback(),
        ExtremumSeekingController(
            control_group="agents",
            initial_snd=0.0,
            dither_magnitude=0.2,
            dither_frequency_rad_s=1.0,
            integral_gain=-0.1,
            high_pass_cutoff_rad_s=1.0,
            low_pass_cutoff_rad_s=1.0,
            sampling_period=1.0,
        ),
        SNDVisualizerCallback(),
        # TrajectorySNDLoggerCallback(control_group="agents"),
        NormLoggerCallback(),
        ActionSpaceLoss(
            use_action_loss=cfg.use_action_loss,
            action_loss_lr=cfg.action_loss_lr,
        ),
    ]
    if task_name == "vmas/simple_tag":
        callbacks.append(
            TagCurriculum(
                cfg.simple_tag_freeze_policy_after_frames,
                cfg.simple_tag_freeze_policy,
            )
        )

    return Experiment(
        task=task_config,
        algorithm_config=algorithm_config,
        model_config=model_config,
        critic_model_config=critic_model_config,
        seed=cfg.seed,
        config=experiment_config,
        callbacks=callbacks,
    )

## Training Code

Trains the model for 100 episodes. 

In [None]:
# Training cell: composes the Hydra config with command-line style overrides
# and runs a full training experiment, checkpointing into SAVE_PATH.
ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"  # Make sure 'navigation_ippo.yaml' exists in the folder above!
SAVE_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/model_checkpoints/navigation_ippo_esc/"

save_interval = 600000
desired_snd = 1.0
max_frame = 6000000

if not os.path.exists(SAVE_PATH):
    print(f"Creating missing directory: {SAVE_PATH}")
    os.makedirs(SAVE_PATH, exist_ok=True)

# Drop any Hydra state left over from a previously executed cell.
GlobalHydra.instance().clear()

# hydra.main reads its overrides from sys.argv, so fake a command line.
sys.argv = [
    "dummy.py",
    f"model.desired_snd={desired_snd}",
    f"experiment.max_n_frames={max_frame}",
    f"experiment.checkpoint_interval={save_interval}",
    f"experiment.save_folder={SAVE_PATH}",
    "task.n_agents=2",
]


@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def hydra_experiment(cfg: DictConfig) -> None:
    """Build the experiment from ``cfg`` and train it to completion."""
    print(f"Config loaded from: {ABS_CONFIG_PATH}")
    if wandb.run is not None:
        print("Finishing previous WandB run...")
        wandb.finish()

    print(f"Running with SND: {cfg.model.desired_snd}")

    try:
        experiment = get_experiment(cfg=cfg)
        experiment.run()
    finally:
        # Finish the WandB run even when setup or training raises, so a
        # failed run is not left open.
        wandb.finish()


# Execute safely
if __name__ == "__main__":
    try:
        hydra_experiment()
    except SystemExit as exit_signal:
        # SystemExit carries an exit status; only None or 0 is a clean exit.
        # Reporting success for any other status would hide a failed run.
        if exit_signal.code in (None, 0):
            print("Experiment finished successfully.")
        else:
            print(f"Experiment aborted (exit code {exit_signal.code}).")
    except Exception as e:
        print(f"An error occurred: {e}")

Config loaded from: /home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf
Running with SND: 1.0

Algorithm: ippo, Task: vmas/navigation


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.



--- INTERCEPTING WANDB INIT ---
Forced ID: AD2C_Eval_1764920995
-------------------------------



[34m[1mwandb[0m: Currently logged in as: [33msvarp[0m ([33msvarp-university-of-massachusetts-lowell[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



✅ SUCCESS: Extremum Seeking Controller initialized for group 'agents'.






[ESC] Updated SND: 0.0 (Reward: -0.326, Update Step: 0.0000)




[ESC] Updated SND: 0.1951899230480194 (Reward: 0.541, Update Step: 0.1952)




[ESC] Updated SND: 0.22066403925418854 (Reward: 1.273, Update Step: 0.0255)




[ESC] Updated SND: 0.0723969042301178 (Reward: 1.609, Update Step: -0.1483)




[ESC] Updated SND: 0.0 (Reward: 1.928, Update Step: -0.1819)




[ESC] Updated SND: 0.0 (Reward: 1.968, Update Step: -0.1623)




[ESC] Updated SND: 0.0 (Reward: 2.020, Update Step: -0.0324)




[ESC] Updated SND: 0.15251342952251434 (Reward: 1.973, Update Step: 0.1525)




[ESC] Updated SND: 0.22088252007961273 (Reward: 2.041, Update Step: 0.0684)




[ESC] Updated SND: 0.10564800351858139 (Reward: 1.991, Update Step: -0.1152)




[ESC] Updated SND: 0.0 (Reward: 1.966, Update Step: -0.1904)




[ESC] Updated SND: 0.0 (Reward: 2.011, Update Step: -0.1770)




[ESC] Updated SND: 0.0 (Reward: 2.031, Update Step: -0.0854)




[ESC] Updated SND: 0.10660509020090103 (Reward: 2.082, Update Step: 0.1066)




[ESC] Updated SND: 0.2202632874250412 (Reward: 2.045, Update Step: 0.1137)




[ESC] Updated SND: 0.13654768466949463 (Reward: 1.485, Update Step: -0.0837)




[ESC] Updated SND: 0.0 (Reward: 1.555, Update Step: -0.1920)




[ESC] Updated SND: 0.0 (Reward: 0.668, Update Step: -0.1541)




[ESC] Updated SND: 0.0 (Reward: 1.193, Update Step: -0.1057)




[ESC] Updated SND: 0.07923769950866699 (Reward: 1.498, Update Step: 0.0792)




[ESC] Updated SND: 0.25632768869400024 (Reward: 1.963, Update Step: 0.1771)




[ESC] Updated SND: 0.2606731653213501 (Reward: 2.066, Update Step: 0.0043)




[ESC] Updated SND: 0.09885510057210922 (Reward: 1.774, Update Step: -0.1618)




[ESC] Updated SND: 0.0 (Reward: 0.914, Update Step: -0.1586)




[ESC] Updated SND: 0.0 (Reward: -0.956, Update Step: -0.0475)




[ESC] Updated SND: 0.11584093421697617 (Reward: -0.134, Update Step: 0.1158)




[ESC] Updated SND: 0.3147627115249634 (Reward: 1.917, Update Step: 0.1989)




[ESC] Updated SND: 0.3691081404685974 (Reward: 2.041, Update Step: 0.0543)




[ESC] Updated SND: 0.23821188509464264 (Reward: 1.931, Update Step: -0.1309)




[ESC] Updated SND: 0.05403723940253258 (Reward: 1.812, Update Step: -0.1842)




[ESC] Updated SND: 0.012667101807892323 (Reward: -0.328, Update Step: -0.0414)




[ESC] Updated SND: 0.14055241644382477 (Reward: -0.180, Update Step: 0.1279)




[ESC] Updated SND: 0.34460437297821045 (Reward: 1.539, Update Step: 0.2041)




[ESC] Updated SND: 0.4438360929489136 (Reward: 1.482, Update Step: 0.0992)




[ESC] Updated SND: 0.34742459654808044 (Reward: 0.299, Update Step: -0.0964)




[ESC] Updated SND: 0.1622738242149353 (Reward: -0.937, Update Step: -0.1852)




[ESC] Updated SND: 0.06800756603479385 (Reward: -1.954, Update Step: -0.0943)




[ESC] Updated SND: 0.14809729158878326 (Reward: -1.979, Update Step: 0.0801)




[ESC] Updated SND: 0.33957669138908386 (Reward: -1.909, Update Step: 0.1915)




[ESC] Updated SND: 0.46964630484580994 (Reward: -2.335, Update Step: 0.1301)




[ESC] Updated SND: 0.42532601952552795 (Reward: -2.095, Update Step: -0.0443)




[ESC] Updated SND: 0.24393372237682343 (Reward: -1.866, Update Step: -0.1814)




[ESC] Updated SND: 0.09253598004579544 (Reward: -1.996, Update Step: -0.1514)




[ESC] Updated SND: 0.10886208713054657 (Reward: -1.902, Update Step: 0.0163)




[ESC] Updated SND: 0.27859896421432495 (Reward: -1.549, Update Step: 0.1697)




[ESC] Updated SND: 0.44621071219444275 (Reward: -1.560, Update Step: 0.1676)




[ESC] Updated SND: 0.45150062441825867 (Reward: -1.765, Update Step: 0.0053)




[ESC] Updated SND: 0.29043635725975037 (Reward: -2.405, Update Step: -0.1611)




[ESC] Updated SND: 0.09493562579154968 (Reward: -1.706, Update Step: -0.1955)




[ESC] Updated SND: 0.06079040467739105 (Reward: -2.092, Update Step: -0.0341)




[ESC] Updated SND: 0.20122337341308594 (Reward: -2.112, Update Step: 0.1404)




[ESC] Updated SND: 0.39650246500968933 (Reward: -1.798, Update Step: 0.1953)




[ESC] Updated SND: 0.48031091690063477 (Reward: -1.476, Update Step: 0.0838)




[ESC] Updated SND: 0.37953293323516846 (Reward: -1.026, Update Step: -0.1008)




[ESC] Updated SND: 0.19677622616291046 (Reward: -1.299, Update Step: -0.1828)




[ESC] Updated SND: 0.1244945228099823 (Reward: -1.579, Update Step: -0.0723)




[ESC] Updated SND: 0.22408370673656464 (Reward: -1.391, Update Step: 0.0996)




[ESC] Updated SND: 0.41793954372406006 (Reward: -1.373, Update Step: 0.1939)




[ESC] Updated SND: 0.5582385659217834 (Reward: -0.717, Update Step: 0.1403)




[ESC] Updated SND: 0.5057383179664612 (Reward: -0.640, Update Step: -0.0525)




[ESC] Updated SND: 0.32581478357315063 (Reward: -0.856, Update Step: -0.1799)




[ESC] Updated SND: 0.2026582807302475 (Reward: -0.965, Update Step: -0.1232)




[ESC] Updated SND: 0.2697395384311676 (Reward: -1.505, Update Step: 0.0671)




[ESC] Updated SND: 0.46220090985298157 (Reward: -0.859, Update Step: 0.1925)




[ESC] Updated SND: 0.6326430439949036 (Reward: -0.601, Update Step: 0.1704)




[ESC] Updated SND: 0.6285800933837891 (Reward: -0.529, Update Step: -0.0041)




[ESC] Updated SND: 0.4629945755004883 (Reward: -0.355, Update Step: -0.1656)




[ESC] Updated SND: 0.29280349612236023 (Reward: -0.262, Update Step: -0.1702)




[ESC] Updated SND: 0.3013450503349304 (Reward: -0.813, Update Step: 0.0085)




[ESC] Updated SND: 0.465116411447525 (Reward: -0.831, Update Step: 0.1638)




[ESC] Updated SND: 0.6527587175369263 (Reward: -0.547, Update Step: 0.1876)




[ESC] Updated SND: 0.7017248868942261 (Reward: 0.639, Update Step: 0.0490)




[ESC] Updated SND: 0.5664432048797607 (Reward: -0.086, Update Step: -0.1353)




[ESC] Updated SND: 0.3790142834186554 (Reward: 0.414, Update Step: -0.1874)




[ESC] Updated SND: 0.31891942024230957 (Reward: 0.083, Update Step: -0.0601)




[ESC] Updated SND: 0.44456011056900024 (Reward: -0.086, Update Step: 0.1256)




[ESC] Updated SND: 0.6511949896812439 (Reward: 0.562, Update Step: 0.2066)




[ESC] Updated SND: 0.7451424598693848 (Reward: 0.923, Update Step: 0.0939)




[ESC] Updated SND: 0.6606975793838501 (Reward: 0.831, Update Step: -0.0844)




[ESC] Updated SND: 0.47815582156181335 (Reward: 0.564, Update Step: -0.1825)




## Eval Run
---
Single-step evaluation from the checkpoint.

In [None]:

# Evaluation cell: restores a checkpoint and runs a single evaluation pass.
wandb.init = forced_wandb_init

# --- STEP 1: Define Unique ID ---
# forced_wandb_init reads this global when wandb.init is called.
unique_id = f"AD2C_Eval_{int(time.time())}"

# --- STEP 2: Force WandB to use this exact run ---
os.environ["WANDB_MODE"] = "online"
os.environ["WANDB_RUN_ID"] = unique_id
os.environ["WANDB_NAME"] = unique_id
os.environ["WANDB_INIT_TIMEOUT"] = "300"
os.environ["WANDB_SILENT"] = "true"

# --- STEP 3: Hydra Setup ---
ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"
CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd06.pt"

# Drop any Hydra state left over from a previously executed cell.
GlobalHydra.instance().clear()

# hydra.main reads its overrides from sys.argv, so fake a command line.
sys.argv = [
    "eval_script.py",
    f"experiment.restore_file={CHECKPOINT_PATH}",
    "experiment.evaluation_episodes=10",
    "experiment.render=True",
    "experiment.evaluation_deterministic_actions=True",
    "experiment.save_folder=null",
    "model.desired_snd=0.6",
    f"+experiment.name={unique_id}",
    f"+logger.id={unique_id}",
]


@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def eval_experiment(cfg: DictConfig) -> None:
    """Restore the checkpointed experiment and run one evaluation loop."""
    # Struct mode is switched off so that the assignment below is accepted.
    OmegaConf.set_struct(cfg, False)
    cfg.logger.id = unique_id

    print(f"Loading model from: {cfg.experiment.restore_file}")

    experiment = None
    try:
        experiment = get_experiment(cfg=cfg)

        print("Model loaded. Starting Evaluation...")
        experiment._evaluation_loop()
        print("Evaluation Complete.")
    finally:
        # Release the experiment and finish the WandB run even when loading
        # or evaluation raises; otherwise both are left open.
        if experiment is not None:
            experiment.close()
        wandb.finish()


if __name__ == "__main__":
    try:
        eval_experiment()
    except SystemExit as exit_signal:
        # A non-zero status means the evaluation did not complete; report it
        # instead of swallowing it silently.
        if exit_signal.code not in (None, 0):
            print(f"Evaluation aborted (exit code {exit_signal.code}).")
    except Exception as e:
        print(f"An error occurred: {e}")
        wandb.finish()

Saving original WandB init function...

--- INTERCEPTING WANDB INIT ---
Forced ID: AD2C_Eval_1764394529
-------------------------------

Initialized WandB with run id: AD2C_Eval_1764394529
Loading model from: /home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd06.pt

Algorithm: ippo, Task: vmas/navigation





--- INTERCEPTING WANDB INIT ---
Forced ID: AD2C_Eval_1764394529
-------------------------------

Model loaded. Starting Evaluation...




Evaluation Complete.


In [None]:
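# !wandb login --relogin <WANDB_API_KEY>  # key redacted: never commit credentials; supply it via the WANDB_API_KEY environment variable and revoke the key that was exposed here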

## Testing Env with different task config
---

Change the task config to check the adaptability of the learned policy.

In [None]:
# Evaluation cell (modified task): restores the same checkpoint but evaluates
# it with task.agents_with_same_goal=1.
wandb.init = forced_wandb_init

# --- STEP 1: Define Unique ID ---
# forced_wandb_init reads this global when wandb.init is called.
unique_id = f"AD2C_Eval_{int(time.time())}"

# --- STEP 2: Force WandB to use this exact run ---
os.environ["WANDB_MODE"] = "online"
os.environ["WANDB_RUN_ID"] = unique_id
os.environ["WANDB_NAME"] = unique_id
os.environ["WANDB_INIT_TIMEOUT"] = "300"
os.environ["WANDB_SILENT"] = "true"

# --- STEP 3: Hydra Setup ---
ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"
CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd06.pt"

# Drop any Hydra state left over from a previously executed cell.
GlobalHydra.instance().clear()

# hydra.main reads its overrides from sys.argv, so fake a command line.
sys.argv = [
    "eval_script.py",
    f"experiment.restore_file={CHECKPOINT_PATH}",
    "experiment.evaluation_episodes=10",
    "experiment.render=True",
    "experiment.evaluation_deterministic_actions=True",
    "experiment.save_folder=null",
    "model.desired_snd=0.6",

    # The task change under test.
    "task.agents_with_same_goal=1",

    f"+experiment.name={unique_id}",
    f"+logger.id={unique_id}",
]


@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def eval_experiment(cfg: DictConfig) -> None:
    """Restore the checkpointed experiment and run one evaluation loop."""
    # Struct mode is switched off so that the assignment below is accepted.
    OmegaConf.set_struct(cfg, False)
    cfg.logger.id = unique_id

    print(f"Loading model from: {cfg.experiment.restore_file}")

    experiment = None
    try:
        experiment = get_experiment(cfg=cfg)

        print("Model loaded. Starting Evaluation...")
        experiment._evaluation_loop()
        print("Evaluation Complete.")
    finally:
        # Release the experiment and finish the WandB run even when loading
        # or evaluation raises; otherwise both are left open.
        if experiment is not None:
            experiment.close()
        wandb.finish()


if __name__ == "__main__":
    try:
        eval_experiment()
    except SystemExit as exit_signal:
        # A non-zero status means the evaluation did not complete; report it
        # instead of swallowing it silently.
        if exit_signal.code not in (None, 0):
            print(f"Evaluation aborted (exit code {exit_signal.code}).")
    except Exception as e:
        print(f"An error occurred: {e}")
        wandb.finish()


Loading model from: /home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd06.pt

Algorithm: ippo, Task: vmas/navigation

--- INTERCEPTING WANDB INIT ---
Forced ID: AD2C_Eval_1764394624
-------------------------------





Model loaded. Starting Evaluation...




Evaluation Complete.


## Transfer Learning
---

Fine-tune the policy so that it works on tasks it previously could not solve.

In [None]:
wandb.init = forced_wandb_init

# --- STEP 1: Define Unique ID ---
unique_id = f"AD2C_Eval_{int(time.time())}"

# --- STEP 2: Force WandB to use this exact run ---
os.environ["WANDB_MODE"] = "online"
os.environ["WANDB_RUN_ID"] = unique_id
os.environ["WANDB_NAME"] = unique_id
os.environ["WANDB_INIT_TIMEOUT"] = "300"
os.environ["WANDB_SILENT"] = "true"


ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"
CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd06.pt"

# ==========================================
# 2. RUN LOGIC
# ==========================================
new_max_frames = 18000000 
desired_snd = 0.6

GlobalHydra.instance().clear()

sys.argv = [
    "run_script.py",
    f"model.desired_snd={desired_snd}",
    f"experiment.restore_file={CHECKPOINT_PATH}",
    f"experiment.max_n_frames={new_max_frames}",
    
    # --- TASK CONFIGURATION ---
    "task.agents_with_same_goal=3", 
    "experiment.save_folder=null"
]

@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def hydra_experiment(cfg: DictConfig) -> None:
    """Build the experiment described by ``cfg`` and run it.

    Echoes the configured target SND and the number of agents sharing a goal
    before handing the config to ``get_experiment`` and starting the run.

    Args:
        cfg: Configuration composed by Hydra from ``CONFIG_NAME`` plus the
            overrides supplied on ``sys.argv``.
    """
    target_snd = cfg.model.desired_snd
    print(f"Resuming with SND: {target_snd}")

    shared_goal_agents = cfg.task.agents_with_same_goal
    print(f"Agents sharing a goal: {shared_goal_agents}")

    get_experiment(cfg=cfg).run()

if __name__ == "__main__":
    # Entry point for the fine-tuning cell: launch the Hydra-wrapped run and
    # report how it ended.
    try:
        hydra_experiment()
    except SystemExit as exit_signal:
        # Hydra turns a failed task into sys.exit(1) (unless HYDRA_FULL_ERROR=1
        # is set), so SystemExit alone does not mean success. Only a clean
        # exit (None / 0) is reported as successful.
        if exit_signal.code in (None, 0):
            print("Experiment finished successfully.")
        else:
            print(f"Experiment exited with an error (exit code: {exit_signal.code}).")
    except Exception as e:
        print(f"An error occurred: {e}")
    else:
        # Normal return from hydra_experiment(): the run completed.
        print("Experiment finished successfully.")

Resuming with SND: 0.6
Agents sharing a goal: 3

Algorithm: ippo, Task: vmas/navigation





--- INTERCEPTING WANDB INIT ---
Forced ID: AD2C_Eval_1764406651
-------------------------------


✅ SUCCESS: Extremum Seeking Controller initialized for group 'agents'.




[ESC] Updated SND: 0.6000000238418579 (Reward: 1.148, Update Step: -0.0000)




[ESC] Updated SND: 0.7793900966644287 (Reward: 1.078, Update Step: 0.1794)




[ESC] Updated SND: 0.7636690139770508 (Reward: 0.103, Update Step: -0.0157)




[ESC] Updated SND: 0.5872311592102051 (Reward: -1.629, Update Step: -0.1764)




[ESC] Updated SND: 0.4164064824581146 (Reward: -2.317, Update Step: -0.1708)




[ESC] Updated SND: 0.3906967043876648 (Reward: -3.000, Update Step: -0.0257)




[ESC] Updated SND: 0.5327726602554321 (Reward: -2.880, Update Step: 0.1421)




[ESC] Updated SND: 0.7285916209220886 (Reward: -1.888, Update Step: 0.1958)




[ESC] Updated SND: 0.8129545450210571 (Reward: -0.775, Update Step: 0.0844)




[ESC] Updated SND: 0.7093358635902405 (Reward: -0.038, Update Step: -0.1036)




[ESC] Updated SND: 0.5182251334190369 (Reward: 0.298, Update Step: -0.1911)




[ESC] Updated SND: 0.43275371193885803 (Reward: -0.487, Update Step: -0.0855)




[ESC] Updated SND: 0.5322818160057068 (Reward: -1.146, Update Step: 0.0995)




[ESC] Updated SND: 0.7303656339645386 (Reward: 0.086, Update Step: 0.1981)




[ESC] Updated SND: 0.8521392941474915 (Reward: 0.268, Update Step: 0.1218)




[ESC] Updated SND: 0.7840874195098877 (Reward: -0.310, Update Step: -0.0681)




[ESC] Updated SND: 0.5924826264381409 (Reward: 1.137, Update Step: -0.1916)




[ESC] Updated SND: 0.45903846621513367 (Reward: 0.432, Update Step: -0.1334)




[ESC] Updated SND: 0.5045048594474792 (Reward: 0.152, Update Step: 0.0455)




[ESC] Updated SND: 0.6864054203033447 (Reward: 0.581, Update Step: 0.1819)




[ESC] Updated SND: 0.8442695736885071 (Reward: 0.550, Update Step: 0.1579)




[ESC] Updated SND: 0.8119148015975952 (Reward: -0.014, Update Step: -0.0324)




[ESC] Updated SND: 0.6361719369888306 (Reward: 1.110, Update Step: -0.1757)




[ESC] Updated SND: 0.4565834403038025 (Reward: 1.066, Update Step: -0.1796)




[ESC] Updated SND: 0.4717179536819458 (Reward: 0.152, Update Step: 0.0151)




[ESC] Updated SND: 0.6342374086380005 (Reward: 0.796, Update Step: 0.1625)




[ESC] Updated SND: 0.8322603106498718 (Reward: 1.175, Update Step: 0.1980)




[ESC] Updated SND: 0.8815199732780457 (Reward: 1.094, Update Step: 0.0493)




[ESC] Updated SND: 0.7547338008880615 (Reward: 1.630, Update Step: -0.1268)




[ESC] Updated SND: 0.5608043670654297 (Reward: 1.829, Update Step: -0.1939)




[ESC] Updated SND: 0.5025118589401245 (Reward: 1.479, Update Step: -0.0583)




[ESC] Updated SND: 0.6242393255233765 (Reward: 1.406, Update Step: 0.1217)




[ESC] Updated SND: 0.826641321182251 (Reward: 1.863, Update Step: 0.2024)




[ESC] Updated SND: 0.9225767254829407 (Reward: 1.776, Update Step: 0.0959)




[ESC] Updated SND: 0.8142543435096741 (Reward: 1.022, Update Step: -0.1083)




[ESC] Updated SND: 0.6128775477409363 (Reward: 1.529, Update Step: -0.2014)




[ESC] Updated SND: 0.4798089861869812 (Reward: 1.842, Update Step: -0.1331)




[ESC] Updated SND: 0.5379781126976013 (Reward: 1.856, Update Step: 0.0582)




[ESC] Updated SND: 0.7289766073226929 (Reward: 2.387, Update Step: 0.1910)




[ESC] Updated SND: 0.8694860339164734 (Reward: 2.339, Update Step: 0.1405)




[ESC] Updated SND: 0.7968385219573975 (Reward: 1.287, Update Step: -0.0726)




[ESC] Updated SND: 0.6056327223777771 (Reward: 1.595, Update Step: -0.1912)




[ESC] Updated SND: 0.4523775577545166 (Reward: 1.548, Update Step: -0.1533)




[ESC] Updated SND: 0.47998046875 (Reward: 1.245, Update Step: 0.0276)




[ESC] Updated SND: 0.6542676091194153 (Reward: 1.986, Update Step: 0.1743)




[ESC] Updated SND: 0.8471461534500122 (Reward: 2.461, Update Step: 0.1929)




[ESC] Updated SND: 0.8333594799041748 (Reward: 1.348, Update Step: -0.0138)




[ESC] Updated SND: 0.6689465641975403 (Reward: 1.651, Update Step: -0.1644)




[ESC] Updated SND: 0.48499351739883423 (Reward: 1.721, Update Step: -0.1840)




[ESC] Updated SND: 0.4598637521266937 (Reward: 1.348, Update Step: -0.0251)




[ESC] Updated SND: 0.5999032258987427 (Reward: 1.703, Update Step: 0.1400)




[ESC] Updated SND: 0.8019872307777405 (Reward: 2.152, Update Step: 0.2021)




[ESC] Updated SND: 0.8428605198860168 (Reward: 1.299, Update Step: 0.0409)




[ESC] Updated SND: 0.7137242555618286 (Reward: 1.360, Update Step: -0.1291)




[ESC] Updated SND: 0.5191119909286499 (Reward: 1.399, Update Step: -0.1946)




[ESC] Updated SND: 0.44026631116867065 (Reward: 0.484, Update Step: -0.0788)




[ESC] Updated SND: 0.5427061915397644 (Reward: 0.179, Update Step: 0.1024)




[ESC] Updated SND: 0.7429813742637634 (Reward: 1.739, Update Step: 0.2003)




[ESC] Updated SND: 0.8550353646278381 (Reward: 1.048, Update Step: 0.1121)




[ESC] Updated SND: 0.7873067259788513 (Reward: 1.610, Update Step: -0.0677)




[ESC] Updated SND: 0.6014970541000366 (Reward: 1.656, Update Step: -0.1858)




[ESC] Updated SND: 0.47899019718170166 (Reward: 1.370, Update Step: -0.1225)




[ESC] Updated SND: 0.5475239157676697 (Reward: 0.814, Update Step: 0.0685)




[ESC] Updated SND: 0.7425951957702637 (Reward: 1.768, Update Step: 0.1951)




[ESC] Updated SND: 0.901756227016449 (Reward: 2.279, Update Step: 0.1592)




[ESC] Updated SND: 0.8901187777519226 (Reward: 2.466, Update Step: -0.0116)




[ESC] Updated SND: 0.722015917301178 (Reward: 2.329, Update Step: -0.1681)




[ESC] Updated SND: 0.5616021156311035 (Reward: 1.834, Update Step: -0.1604)




[ESC] Updated SND: 0.559990644454956 (Reward: 1.481, Update Step: -0.0016)




[ESC] Updated SND: 0.718810498714447 (Reward: 1.931, Update Step: 0.1588)




[ESC] Updated SND: 0.9004638195037842 (Reward: 1.860, Update Step: 0.1817)




[ESC] Updated SND: 0.9599202275276184 (Reward: 2.416, Update Step: 0.0595)




[ESC] Updated SND: 0.8274074792861938 (Reward: 2.050, Update Step: -0.1325)




[ESC] Updated SND: 0.6495016813278198 (Reward: 1.911, Update Step: -0.1779)




[ESC] Updated SND: 0.6167809367179871 (Reward: 1.353, Update Step: -0.0327)




[ESC] Updated SND: 0.7486892938613892 (Reward: 1.452, Update Step: 0.1319)




[ESC] Updated SND: 0.9555116891860962 (Reward: 1.969, Update Step: 0.2068)




[ESC] Updated SND: 1.0626249313354492 (Reward: 2.153, Update Step: 0.1071)




[ESC] Updated SND: 0.9537951946258545 (Reward: 1.150, Update Step: -0.1088)




[ESC] Updated SND: 0.7591032981872559 (Reward: 1.380, Update Step: -0.1947)




[ESC] Updated SND: 0.6607710123062134 (Reward: 1.098, Update Step: -0.0983)




[ESC] Updated SND: 0.7376139760017395 (Reward: 1.208, Update Step: 0.0768)




[ESC] Updated SND: 0.9331179857254028 (Reward: 1.614, Update Step: 0.1955)




[ESC] Updated SND: 1.0469722747802734 (Reward: 0.995, Update Step: 0.1139)




[ESC] Updated SND: 1.0097177028656006 (Reward: 1.677, Update Step: -0.0373)




[ESC] Updated SND: 0.8298658728599548 (Reward: 1.718, Update Step: -0.1799)




[ESC] Updated SND: 0.6760863065719604 (Reward: 1.776, Update Step: -0.1538)




[ESC] Updated SND: 0.704460084438324 (Reward: 1.455, Update Step: 0.0284)




[ESC] Updated SND: 0.8789025545120239 (Reward: 1.574, Update Step: 0.1744)




[ESC] Updated SND: 1.0373082160949707 (Reward: 1.354, Update Step: 0.1584)




[ESC] Updated SND: 1.0675793886184692 (Reward: 2.112, Update Step: 0.0303)




[ESC] Updated SND: 0.9175883531570435 (Reward: 1.652, Update Step: -0.1500)




[ESC] Updated SND: 0.7381042838096619 (Reward: 1.887, Update Step: -0.1795)




[ESC] Updated SND: 0.700613260269165 (Reward: 1.906, Update Step: -0.0375)




[ESC] Updated SND: 0.838908314704895 (Reward: 1.973, Update Step: 0.1383)




[ESC] Updated SND: 1.0259952545166016 (Reward: 2.020, Update Step: 0.1871)




[ESC] Updated SND: 1.0302292108535767 (Reward: -3.406, Update Step: 0.0042)




[ESC] Updated SND: 0.8756323456764221 (Reward: -4.922, Update Step: -0.1546)




[ESC] Updated SND: 0.6705653071403503 (Reward: -3.480, Update Step: -0.2051)




[ESC] Updated SND: 0.5546724796295166 (Reward: -1.179, Update Step: -0.1159)


mean return = -1.5155255794525146: 100%|██████████| 300/300 [1:09:40<00:00, 41.80s/it]
