# AD2C : Test Bed

This Jupter Notebook aims to Experiment with different Component of AD2C framework Like task, ESC, Callback and Loggers. This will also include advance experimental setup to test the trainied model. 


----


## Imports

In [1]:
import sys
import os
import hydra
import wandb
import sys
import hydra
import time

from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf

# Benchmarl & Project Imports
import benchmarl.models
from benchmarl.algorithms import *
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment
from benchmarl.hydra_config import (
    load_algorithm_config_from_hydra,
    load_experiment_config_from_hydra,
    load_task_config_from_hydra,
    load_model_config_from_hydra,
)

# Custom Callbacks
from het_control.callback import *
from het_control.environments.vmas import render_callback
from het_control.models.het_control_mlp_empirical import HetControlMlpEmpiricalConfig
from het_control.callbacks.esc_callback import ExtremumSeekingController
from het_control.callbacks.sndESLogger import TrajectorySNDLoggerCallback

import numpy as np
import torch
import matplotlib.pyplot as plt
import networkx as nx
import wandb
from tensordict import TensorDict
from typing import List
from benchmarl.experiment.callback import Callback



## Wandb Fix
---

In [2]:
if not hasattr(wandb, "_custom_orig_init"):
    print("Saving original WandB init function...")
    wandb._custom_orig_init = wandb.init

def forced_wandb_init(*args, **kwargs):
    print(f"\n--- INTERCEPTING WANDB INIT ---")
    
    # Force the new ID and Name
    kwargs['id'] = unique_id
    kwargs['name'] = unique_id
    
    # Force "New Run" behavior
    kwargs['resume'] = "allow" 
    kwargs['reinit'] = True
    
    print(f"Forced ID: {unique_id}")
    print(f"-------------------------------\n")
    
    # We always call the SAVED original function, not the current one
    return wandb._custom_orig_init(*args, **kwargs)

# Apply the patch
wandb.init = forced_wandb_init


Saving original WandB init function...


## SND Ploting Callback
This allow us to display plots for the snd. from the eval run. 

---

In [5]:
class SNDHeatmapVisualizer:
    def __init__(self, key_name="Visuals/SND_Heatmap"):
        self.key_name = key_name

    def generate(self, snd_matrix, step_count):
        # --- Fix diagonal ---
        snd_matrix = snd_matrix.copy()
        np.fill_diagonal(snd_matrix, 0.0)

        # --- Enforce symmetry ---
        snd_matrix = (snd_matrix + snd_matrix.T) / 2.0

        n_agents = snd_matrix.shape[0]
        agent_labels = [f"Agent {i+1}" for i in range(n_agents)]
        
        # --- Compute SND = average pairwise distance ---
        iu = np.triu_indices(n_agents, k=1)     # upper triangle (i < j)
        snd_value = float(np.mean(snd_matrix[iu]))

        fig, ax = plt.subplots(figsize=(6, 5))

        im = ax.imshow(
            snd_matrix,
            cmap="viridis",
            interpolation="nearest",
            vmin=0, vmax=3
        )

        # --- Updated title with SND ---
        ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}")

        ax.set_xticks(np.arange(n_agents))
        ax.set_yticks(np.arange(n_agents))
        ax.set_xticklabels(agent_labels)
        ax.set_yticklabels(agent_labels)
        plt.setp(ax.get_xticklabels(), rotation=30, ha="right")

        fig.colorbar(im, ax=ax, label="Distance")

        # Cell labels
        for i in range(n_agents):
            for j in range(n_agents):
                val = snd_matrix[i, j]
                text_color = "white" if val < 1.0 else "black"

                ax.text(
                    j, i, f"{val:.2f}",
                    ha="center", va="center",
                    color=text_color,
                    fontsize=9, fontweight="bold"
                )

        plt.tight_layout()
        img = wandb.Image(fig)
        plt.close(fig)
        return {self.key_name: img}


class SNDBarChartVisualizer:
    def __init__(self, key_name="Visuals/SND_BarChart"):
        self.key_name = key_name

    def generate(self, snd_matrix, step_count):
        n_agents = snd_matrix.shape[0]

        # --- Fix diagonal ---
        snd_matrix = snd_matrix.copy()
        np.fill_diagonal(snd_matrix, 0.0)

        # --- Enforce symmetry ---
        snd_matrix = (snd_matrix + snd_matrix.T) / 2.0

        # --- Create agent pairs i < j ---
        pairs = [(i, j) for i in range(n_agents) for j in range(i + 1, n_agents)]
        if not pairs:
            return {}

        pair_values = [float(snd_matrix[i, j]) for i, j in pairs]
        pair_labels = [f"A{i+1}-A{j+1}" for i, j in pairs]

        # --- Compute SND (mean of pairwise distances) ---
        snd_value = float(np.mean(pair_values))

        # --- Plot ---
        fig, ax = plt.subplots(figsize=(8, 5))
        bars = ax.bar(pair_labels, pair_values, color="teal")

        ax.set_title(f"SND: {snd_value:.3f}  –  Step {step_count}")
        ax.set_ylabel("Distance")
        ax.set_ylim(0, 3)
        ax.tick_params(axis="x", rotation=45)

        # Add value labels above bars
        ax.bar_label(bars, fmt="%.2f", padding=3)

        plt.tight_layout()
        img = wandb.Image(fig)
        plt.close(fig)
        return {self.key_name: img}


class SNDGraphVisualizer:
    def __init__(self, key_name="Visuals/SND_NetworkGraph"):
        self.key_name = key_name

    def generate(self, snd_matrix, step_count):
        n_agents = snd_matrix.shape[0]

        # --- Fix diagonal ---
        snd_matrix = snd_matrix.copy()
        np.fill_diagonal(snd_matrix, 0.0)

        # --- Enforce symmetry ---
        snd_matrix = (snd_matrix + snd_matrix.T) / 2.0

        # --- Create edges only for i < j ---
        pairs = [(i, j) for i in range(n_agents) for j in range(i + 1, n_agents)]
        if not pairs:
            return {}

        # Distances for each pair
        pair_values = [float(snd_matrix[i, j]) for i, j in pairs]

        # --- Compute SND (avg distance) ---
        snd_value = float(np.mean(pair_values))

        # Build graph
        fig = plt.figure(figsize=(7, 7))
        G = nx.Graph()

        # Add edges with weights
        for i, j in pairs:
            G.add_edge(i, j, weight=float(snd_matrix[i, j]))

        # Layout
        pos = nx.spring_layout(G, seed=42)

        # Edge weights for coloring
        weights = [G[u][v]['weight'] for u, v in G.edges()]

        # --- Draw Nodes ---
        nx.draw_networkx_nodes(
            G, pos, node_size=750, node_color='lightblue'
        )

        # Label nodes as Agent 1, Agent 2, ...
        label_mapping = {i: f"A{i+1}" for i in range(n_agents)}
        nx.draw_networkx_labels(
            G, pos, labels=label_mapping, font_size=12, font_weight='bold'
        )

        # --- Draw edges ---
        edges = nx.draw_networkx_edges(
            G, pos,
            edge_color=weights,
            edge_cmap=plt.cm.viridis,
            width=2,
            edge_vmin=0,
            edge_vmax=3
        )

        # --- Draw edge labels ---
        edge_labels = {(i, j): f"{snd_matrix[i, j]:.2f}" for i, j in pairs}

        nx.draw_networkx_edge_labels(
            G,
            pos,
            edge_labels=edge_labels,
            font_color='black',
            font_size=9,
            font_weight='bold'
        )

        # Colorbar
        plt.colorbar(edges, label='Distance')

        # Title with SND value
        plt.title(f"SND: {snd_value:.3f}  –  Step {step_count}", fontsize=14)
        plt.axis('off')

        img = wandb.Image(fig)
        plt.close(fig)
        return {self.key_name: img}


class SNDVisualizationManager:
    """
    Manages the individual visualizers.
    """
    def __init__(self):
        self.visualizers = [
            SNDHeatmapVisualizer(),
            SNDBarChartVisualizer(),
            SNDGraphVisualizer()
        ]

    def generate_all(self, snd_matrix, step_count):
        all_plots = {}
        for visualizer in self.visualizers:
            try:
                plots = visualizer.generate(snd_matrix, step_count)
                all_plots.update(plots)
            except Exception as e:
                print(f"Error generating {visualizer.__class__.__name__}: {e}")
        return all_plots


In [6]:
class SNDVisualizerCallback(Callback):
    """
    Computes the SND matrix and uses the Manager to log visualizations.
    """
    def __init__(self):
        super().__init__()
        self.control_group = None
        self.model = None
        # Initialize the manager that holds the 3 plot classes
        self.viz_manager = SNDVisualizationManager()

    def on_setup(self):
        """Auto-detects the agent group and initializes the model wrapper."""
        if not self.experiment.group_policies:
            print("\nWARNING: No group policies found. SND Visualizer disabled.\n")
            return

        self.control_group = list(self.experiment.group_policies.keys())[0]
        policy = self.experiment.group_policies[self.control_group]
        
        # Ensure 'get_het_model' is imported or available in this scope
        self.model = get_het_model(policy)

        if self.model is None:
             print(f"\nWARNING: Could not extract HetModel for group '{self.control_group}'. Visualizer disabled.\n")

    def _get_agent_actions_for_rollout(self, rollout):
        """Helper to run the forward pass and get actions for SND computation."""
        obs = rollout.get((self.control_group, "observation"))
        actions = []
        for i in range(self.model.n_agents):
            temp_td = TensorDict(
                {(self.control_group, "observation"): obs},
                batch_size=obs.shape[:-1]
            )
            action_td = self.model._forward(temp_td, agent_index=i, compute_estimate=False)
            actions.append(action_td.get(self.model.out_key))
        return actions

    def on_evaluation_end(self, rollouts: List[TensorDict]):
        """Runs at the end of evaluation to compute SND and log plots."""
        if self.model is None:
            return

        logs_to_push = {}
        first_rollout_snd_matrix = None

        with torch.no_grad():
            for i, r in enumerate(rollouts):
                # We only need the matrix from the first rollout for clean visualization
                if i > 0: 
                    break

                agent_actions = self._get_agent_actions_for_rollout(r)
                
                # Ensure 'compute_behavioral_distance' is imported/available
                pairwise_distances_tensor = compute_behavioral_distance(agent_actions, just_mean=False)
                
                if pairwise_distances_tensor.ndim > 2:
                    pairwise_distances_tensor = pairwise_distances_tensor.mean(dim=0)

                first_rollout_snd_matrix = pairwise_distances_tensor.cpu().numpy()

        # Generate and Log Visualizations via the Manager
        if first_rollout_snd_matrix is not None:
            visual_logs = self.viz_manager.generate_all(
                snd_matrix=first_rollout_snd_matrix, 
                step_count=self.experiment.n_iters_performed
            )
            logs_to_push.update(visual_logs)
            
            # Update the logger
            self.experiment.logger.log(logs_to_push, step=self.experiment.n_iters_performed)

## Env Setup

This code block compiles all the different section of MARl env together to Run the experiment

---

In [None]:
# 1. EXPERIMENT LOGIC

def setup(task_name):
    benchmarl.models.model_config_registry.update(
        {
            "hetcontrolmlpempirical": HetControlMlpEmpiricalConfig,
        }
    )
    if task_name == "vmas/navigation":
        # Set the render callback for the navigation case study
        VmasTask.render_callback = render_callback

def get_experiment(cfg: DictConfig) -> Experiment:
    hydra_choices = HydraConfig.get().runtime.choices
    task_name = hydra_choices.task
    algorithm_name = hydra_choices.algorithm

    setup(task_name)

    print(f"\nAlgorithm: {algorithm_name}, Task: {task_name}")
    # print("\nLoaded config:\n") # Optional: Commented out to reduce clutter
    # print(OmegaConf.to_yaml(cfg))

    algorithm_config = load_algorithm_config_from_hydra(cfg.algorithm)
    experiment_config = load_experiment_config_from_hydra(cfg.experiment)
    task_config = load_task_config_from_hydra(cfg.task, task_name)
    critic_model_config = load_model_config_from_hydra(cfg.critic_model)
    model_config = load_model_config_from_hydra(cfg.model)

    if isinstance(algorithm_config, (MappoConfig, IppoConfig, MasacConfig, IsacConfig)):
        model_config.probabilistic = True
        model_config.scale_mapping = algorithm_config.scale_mapping
        algorithm_config.scale_mapping = (
            "relu"  # The scaling of std_dev will be done in the model
        )
    else:
        model_config.probabilistic = False

    experiment = Experiment(
        task=task_config,
        algorithm_config=algorithm_config,
        model_config=model_config,
        critic_model_config=critic_model_config,
        seed=cfg.seed,
        config=experiment_config,
        callbacks=[
            SndCallback(),
            ExtremumSeekingController(
                        control_group="agents",
                        initial_snd=0.6,
                        dither_magnitude=0.2,
                        dither_frequency_rad_s=1.0,
                        integral_gain=-0.1,
                        high_pass_cutoff_rad_s=1.0,
                        low_pass_cutoff_rad_s=1.0,
                        sampling_period=1.0
            ),
            SNDVisualizerCallback(),
            # TrajectorySNDLoggerCallback(control_group="agents"),
            NormLoggerCallback(),
            ActionSpaceLoss(
                use_action_loss=cfg.use_action_loss, action_loss_lr=cfg.action_loss_lr
            ),
        ]
        + (
            [
                TagCurriculum(
                    cfg.simple_tag_freeze_policy_after_frames,
                    cfg.simple_tag_freeze_policy,
                )
            ]
            if task_name == "vmas/simple_tag"
            else []
        ),
    )
    return experiment

## Training Code

Trains the model for 100 episodes. 

In [None]:
ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"  # Make sure 'navigation_ippo.yaml' exists in the folder above!
SAVE_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/model_checkpoints/navigation_ippo_esc/"

save_interval = 600000
desired_snd = 1.0
max_frame = 6000000

if not os.path.exists(SAVE_PATH):
    print(f"Creating missing directory: {SAVE_PATH}")
    os.makedirs(SAVE_PATH, exist_ok=True)

GlobalHydra.instance().clear()

sys.argv = [
    "dummy.py",
    f"model.desired_snd={desired_snd}",
    f"experiment.max_n_frames={max_frame}",
    f"experiment.checkpoint_interval={save_interval}",
    f"experiment.save_folder={SAVE_PATH}", # FIXED: Removed space after '='
]

# 3. Define the Hydra wrapper
@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def hydra_experiment(cfg: DictConfig) -> None:
    print(f"Config loaded from: {ABS_CONFIG_PATH}")
    if wandb.run is not None:
        print("Finishing previous WandB run...")
        wandb.finish()
    
    print(f"Running with SND: {cfg.model.desired_snd}")
    
    experiment = get_experiment(cfg=cfg)
    experiment.run()
    wandb.finish()

# 4. Execute safely
if __name__ == "__main__":
    try:
        hydra_experiment()
    except SystemExit:
        print("Experiment finished successfully.")
    except Exception as e:
        print(f"An error occurred: {e}")

Config loaded from: /home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf
Running with SND: 1.0

Algorithm: ippo, Task: vmas/navigation






mean return = -0.2842506170272827:   1%|          | 1/100 [00:39<1:05:36, 39.76s/it]

mean return = -0.12061435729265213:   2%|▏         | 2/100 [01:19<1:05:06, 39.86s/it]

mean return = 0.06278912723064423:   3%|▎         | 3/100 [01:57<1:02:38, 38.75s/it] 

mean return = 0.3085945248603821:   4%|▍         | 4/100 [02:36<1:02:27, 39.03s/it] 

mean return = -1.3374804258346558:   5%|▌         | 5/100 [03:18<1:03:16, 39.96s/it]

mean return = -1.2115898132324219:   6%|▌         | 6/100 [03:57<1:02:17, 39.76s/it]

mean return = -0.7913774847984314:   7%|▋         | 7/100 [04:34<1:00:11, 38.83s/it]

mean return = -0.37020793557167053:   8%|▊         | 8/100 [05:12<59:00, 38.48s/it] 

mean return = -0.3419821858406067:   9%|▉         | 9/100 [05:52<59:06, 38.97s/it] 

mean return = -0.7258304357528687:  10%|█         | 10/100 [06:32<58:53, 39.26s/it]

mean return = -0.356292188167572:  11%|█         | 11/100 [07:12<58:55, 39.73s/it] 

mean return = -0.005251131020486355:  12%|█▏        | 12/100 [07:57<1:00:12, 41.05s/it]

mean return = -0.04384047910571098:  13%|█▎        | 13/100 [08:34<58:07, 40.08s/it]   

mean return = 0.36396729946136475:  14%|█▍        | 14/100 [09:13<56:46, 39.61s/it] 

mean return = -0.9160506129264832:  15%|█▌        | 15/100 [09:54<56:41, 40.02s/it]

mean return = 0.2671414911746979:  16%|█▌        | 16/100 [10:34<55:57, 39.97s/it] 

mean return = -0.20932984352111816:  17%|█▋        | 17/100 [11:13<54:49, 39.64s/it]

mean return = 0.27600908279418945:  18%|█▊        | 18/100 [11:55<55:14, 40.42s/it] 

mean return = 0.2526856064796448:  19%|█▉        | 19/100 [12:36<54:54, 40.67s/it] 

mean return = -1.683846116065979:  20%|██        | 20/100 [13:18<54:46, 41.09s/it]

mean return = -1.610228180885315:  21%|██        | 21/100 [13:59<53:58, 41.00s/it]

mean return = -1.3977614641189575:  22%|██▏       | 22/100 [14:38<52:43, 40.56s/it]

mean return = -1.212141990661621:  23%|██▎       | 23/100 [15:21<52:37, 41.01s/it] 

mean return = -0.9942213296890259:  24%|██▍       | 24/100 [16:03<52:21, 41.34s/it]

mean return = -0.8850975632667542:  25%|██▌       | 25/100 [16:45<52:03, 41.64s/it]

mean return = -0.7868425250053406:  26%|██▌       | 26/100 [17:27<51:33, 41.81s/it]

mean return = -0.7432085871696472:  27%|██▋       | 27/100 [18:07<50:06, 41.19s/it]

mean return = -0.7048777937889099:  28%|██▊       | 28/100 [18:45<48:25, 40.36s/it]

mean return = -0.7007088661193848:  29%|██▉       | 29/100 [19:25<47:22, 40.04s/it]

mean return = -0.647488534450531:  30%|███       | 30/100 [20:06<47:06, 40.37s/it] 

mean return = -0.6305897235870361:  31%|███       | 31/100 [20:48<47:01, 40.88s/it]

mean return = -0.6276352405548096:  32%|███▏      | 32/100 [21:26<45:18, 39.98s/it]

mean return = -0.6621479392051697:  33%|███▎      | 33/100 [22:05<44:30, 39.85s/it]

mean return = -0.6578906178474426:  34%|███▍      | 34/100 [22:48<44:37, 40.57s/it]

mean return = -0.6201854348182678:  35%|███▌      | 35/100 [23:28<43:47, 40.43s/it]

mean return = -0.6114463210105896:  36%|███▌      | 36/100 [24:04<41:57, 39.33s/it]

mean return = -0.6134855151176453:  37%|███▋      | 37/100 [24:45<41:38, 39.65s/it]

mean return = -0.5852857828140259:  38%|███▊      | 38/100 [25:25<41:15, 39.92s/it]

mean return = -0.5829774141311646:  39%|███▉      | 39/100 [26:06<40:48, 40.14s/it]

mean return = -0.5868636965751648:  40%|████      | 40/100 [26:48<40:50, 40.84s/it]

mean return = -0.5644099116325378:  41%|████      | 41/100 [27:29<39:55, 40.59s/it]

mean return = -0.6008509397506714:  42%|████▏     | 42/100 [28:08<39:03, 40.41s/it]

mean return = -0.5615042448043823:  43%|████▎     | 43/100 [28:51<38:52, 40.92s/it]

mean return = -0.5826094746589661:  44%|████▍     | 44/100 [29:30<37:38, 40.32s/it]

mean return = -0.5277758836746216:  45%|████▌     | 45/100 [30:07<36:06, 39.38s/it]

mean return = -0.5691078901290894:  46%|████▌     | 46/100 [30:45<35:01, 38.91s/it]

mean return = -0.585959792137146:  47%|████▋     | 47/100 [31:25<34:42, 39.28s/it] 

mean return = -0.4881935715675354:  48%|████▊     | 48/100 [32:08<35:09, 40.57s/it]

mean return = -0.545187771320343:  49%|████▉     | 49/100 [32:54<35:48, 42.12s/it] 

mean return = -0.5338388681411743:  50%|█████     | 50/100 [33:36<35:00, 42.00s/it]

mean return = -0.5915013551712036:  51%|█████     | 51/100 [34:22<35:16, 43.20s/it]

mean return = -0.6400428414344788:  52%|█████▏    | 52/100 [35:08<35:20, 44.18s/it]

mean return = -0.6746166944503784:  53%|█████▎    | 53/100 [35:49<33:48, 43.16s/it]

mean return = -0.6268823742866516:  54%|█████▍    | 54/100 [36:37<34:06, 44.48s/it]

mean return = -0.6256210803985596:  55%|█████▌    | 55/100 [37:35<36:37, 48.82s/it]

mean return = -0.5703659057617188:  56%|█████▌    | 56/100 [38:29<36:53, 50.30s/it]

mean return = -0.6798580884933472:  57%|█████▋    | 57/100 [39:12<34:31, 48.18s/it]

mean return = -0.655034065246582:  58%|█████▊    | 58/100 [39:52<31:56, 45.64s/it] 

mean return = -0.6068065762519836:  59%|█████▉    | 59/100 [40:29<29:27, 43.12s/it]

mean return = -0.666559100151062:  60%|██████    | 60/100 [41:11<28:29, 42.74s/it] 

mean return = -0.6611549258232117:  61%|██████    | 61/100 [41:58<28:39, 44.09s/it]

mean return = -0.6468213200569153:  62%|██████▏   | 62/100 [42:41<27:34, 43.55s/it]

mean return = -0.6857307553291321:  63%|██████▎   | 63/100 [43:21<26:15, 42.57s/it]

mean return = -0.7061131000518799:  64%|██████▍   | 64/100 [44:11<26:51, 44.77s/it]

mean return = -0.7006723880767822:  65%|██████▌   | 65/100 [45:00<26:46, 45.91s/it]

mean return = -0.6245374083518982:  66%|██████▌   | 66/100 [46:13<30:38, 54.06s/it]

mean return = -0.6481888890266418:  67%|██████▋   | 67/100 [48:00<38:26, 69.91s/it]

mean return = -0.6733251810073853:  68%|██████▊   | 68/100 [50:17<48:09, 90.31s/it]

mean return = -0.6577150821685791:  69%|██████▉   | 69/100 [52:33<53:42, 103.94s/it]

mean return = -0.618562638759613:  70%|███████   | 70/100 [54:26<53:17, 106.60s/it] 

mean return = -0.641441822052002:  71%|███████   | 71/100 [57:34<1:03:16, 130.92s/it]

mean return = -0.6700270771980286:  72%|███████▏  | 72/100 [1:01:57<1:19:37, 170.63s/it]

mean return = -0.6492499709129333:  73%|███████▎  | 73/100 [1:05:11<1:19:56, 177.64s/it]

mean return = -0.6325398683547974:  74%|███████▍  | 74/100 [1:06:27<1:03:48, 147.24s/it]

mean return = -0.6668291687965393:  75%|███████▌  | 75/100 [1:08:33<58:40, 140.81s/it]  

mean return = -0.6201265454292297:  76%|███████▌  | 76/100 [1:11:36<1:01:19, 153.32s/it]

mean return = -0.6522907018661499:  77%|███████▋  | 77/100 [1:15:03<1:04:57, 169.44s/it]

mean return = -0.662007212638855:  78%|███████▊  | 78/100 [1:16:14<51:21, 140.07s/it]   

mean return = -0.628177285194397:  79%|███████▉  | 79/100 [1:18:22<47:47, 136.54s/it]

mean return = -0.6931182146072388:  80%|████████  | 80/100 [1:21:41<51:40, 155.04s/it]

mean return = -0.5893171429634094:  81%|████████  | 81/100 [1:25:14<54:37, 172.53s/it]

mean return = -0.6325810551643372:  82%|████████▏ | 82/100 [1:26:09<41:09, 137.21s/it]

mean return = -0.6400985717773438:  83%|████████▎ | 83/100 [1:28:29<39:09, 138.18s/it]

mean return = -0.5972829461097717:  84%|████████▍ | 84/100 [1:31:21<39:33, 148.34s/it]

mean return = -0.6537256240844727:  85%|████████▌ | 85/100 [1:34:59<42:19, 169.28s/it]

mean return = -0.617300808429718:  86%|████████▌ | 86/100 [1:36:04<32:10, 137.91s/it] 

mean return = -0.6039918065071106:  87%|████████▋ | 87/100 [1:37:43<27:19, 126.10s/it]

## Eval Run
---
Single Step eval. from the check point

In [None]:

wandb.init = forced_wandb_init

# --- STEP 1: Define Unique ID ---
unique_id = f"AD2C_Eval_{int(time.time())}"

# --- STEP 2: Force WandB to use this exact run ---
os.environ["WANDB_MODE"] = "online"
os.environ["WANDB_RUN_ID"] = unique_id
os.environ["WANDB_NAME"] = unique_id
os.environ["WANDB_INIT_TIMEOUT"] = "300"
os.environ["WANDB_SILENT"] = "true"

# --- STEP 3: Hydra Setup ---
ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"
CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd06.pt"

GlobalHydra.instance().clear()

sys.argv = [
    "eval_script.py",
    f"experiment.restore_file={CHECKPOINT_PATH}",
    "experiment.evaluation_episodes=10",
    "experiment.render=True",
    "experiment.evaluation_deterministic_actions=True",
    "experiment.save_folder=null",
    "model.desired_snd=0.6",
    f"+experiment.name={unique_id}",
    f"+logger.id={unique_id}"
]

@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def eval_experiment(cfg: DictConfig) -> None:

    OmegaConf.set_struct(cfg, False)
    cfg.logger.id = unique_id

    print(f"Loading model from: {cfg.experiment.restore_file}")

    experiment = get_experiment(cfg=cfg)

    print("Model loaded. Starting Evaluation...")
    experiment._evaluation_loop()
    print("Evaluation Complete.")

    experiment.close()
    wandb.finish()

if __name__ == "__main__":
    try:
        eval_experiment()
    except SystemExit:
        pass
    except Exception as e:
        print(f"An error occurred: {e}")
        wandb.finish()

Saving original WandB init function...

--- INTERCEPTING WANDB INIT ---
Forced ID: AD2C_Eval_1764394529
-------------------------------

Initialized WandB with run id: AD2C_Eval_1764394529
Loading model from: /home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd06.pt

Algorithm: ippo, Task: vmas/navigation





--- INTERCEPTING WANDB INIT ---
Forced ID: AD2C_Eval_1764394529
-------------------------------

Model loaded. Starting Evaluation...




Evaluation Complete.


In [None]:
# !wandb login --relogin b3c2b62655aa322e8ab1d1ab07287749ce03ff8

## Testing Env with different task config
---

Changing the task config to check the adaptibility of learned policy. 

In [None]:
wandb.init = forced_wandb_init

# --- STEP 1: Define Unique ID ---
unique_id = f"AD2C_Eval_{int(time.time())}"

# --- STEP 2: Force WandB to use this exact run ---
os.environ["WANDB_MODE"] = "online"
os.environ["WANDB_RUN_ID"] = unique_id
os.environ["WANDB_NAME"] = unique_id
os.environ["WANDB_INIT_TIMEOUT"] = "300"
os.environ["WANDB_SILENT"] = "true"

# --- STEP 3: Hydra Setup ---
ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"
CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd06.pt"

GlobalHydra.instance().clear()

sys.argv = [
    "eval_script.py",
    f"experiment.restore_file={CHECKPOINT_PATH}",
    "experiment.evaluation_episodes=10",
    "experiment.render=True",
    "experiment.evaluation_deterministic_actions=True",
    "experiment.save_folder=null",
    "model.desired_snd=0.6",

    "task.agents_with_same_goal=1" ,

    f"+experiment.name={unique_id}",
    f"+logger.id={unique_id}"
]

@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def eval_experiment(cfg: DictConfig) -> None:

    OmegaConf.set_struct(cfg, False)
    cfg.logger.id = unique_id

    print(f"Loading model from: {cfg.experiment.restore_file}")

    experiment = get_experiment(cfg=cfg)

    print("Model loaded. Starting Evaluation...")
    experiment._evaluation_loop()
    print("Evaluation Complete.")

    experiment.close()
    wandb.finish()

if __name__ == "__main__":
    try:
        eval_experiment()
    except SystemExit:
        pass
    except Exception as e:
        print(f"An error occurred: {e}")
        wandb.finish()


Loading model from: /home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd06.pt

Algorithm: ippo, Task: vmas/navigation

--- INTERCEPTING WANDB INIT ---
Forced ID: AD2C_Eval_1764394624
-------------------------------





Model loaded. Starting Evaluation...




Evaluation Complete.


## Transfer Lerning
---

Finetune the policy to make it work with other task that it could not. 

In [None]:
wandb.init = forced_wandb_init

# --- STEP 1: Define Unique ID ---
unique_id = f"AD2C_Eval_{int(time.time())}"

# --- STEP 2: Force WandB to use this exact run ---
os.environ["WANDB_MODE"] = "online"
os.environ["WANDB_RUN_ID"] = unique_id
os.environ["WANDB_NAME"] = unique_id
os.environ["WANDB_INIT_TIMEOUT"] = "300"
os.environ["WANDB_SILENT"] = "true"


ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"
CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd06.pt"

# ==========================================
# 2. RUN LOGIC
# ==========================================
new_max_frames = 18000000 
desired_snd = 0.6

GlobalHydra.instance().clear()

sys.argv = [
    "run_script.py",
    f"model.desired_snd={desired_snd}",
    f"experiment.restore_file={CHECKPOINT_PATH}",
    f"experiment.max_n_frames={new_max_frames}",
    
    # --- TASK CONFIGURATION ---
    "task.agents_with_same_goal=1", 
    "experiment.save_folder=null"
]

@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def hydra_experiment(cfg: DictConfig) -> None:
    print(f"Resuming with SND: {cfg.model.desired_snd}")
    print(f"Agents sharing a goal: {cfg.task.agents_with_same_goal}")
    
    experiment = get_experiment(cfg=cfg)
    experiment.run()

if __name__ == "__main__":
    try:
        hydra_experiment()
    except SystemExit:
        print("Experiment finished successfully.")
    except Exception as e:
        print(f"An error occurred: {e}")

Resuming with SND: 0.6
Agents sharing a goal: 1

Algorithm: ippo, Task: vmas/navigation

--- INTERCEPTING WANDB INIT ---
Forced ID: AD2C_Eval_1764396731
-------------------------------


✅ SUCCESS: Extremum Seeking Controller initialized for group 'agents'.




[ESC] Updated SND: 0.6000000238418579 (Reward: 1.908, Update Step: -0.0000)




[ESC] Updated SND: 0.7078927159309387 (Reward: 1.942, Update Step: 0.1079)




[ESC] Updated SND: 0.722448468208313 (Reward: 1.693, Update Step: 0.0146)




[ESC] Updated SND: 0.6476646065711975 (Reward: 1.563, Update Step: -0.0748)




[ESC] Updated SND: 0.567611813545227 (Reward: 1.327, Update Step: -0.0801)




[ESC] Updated SND: 0.5559406280517578 (Reward: 1.298, Update Step: -0.0117)




[ESC] Updated SND: 0.6274011135101318 (Reward: 1.308, Update Step: 0.0715)




[ESC] Updated SND: 0.7182595729827881 (Reward: 1.171, Update Step: 0.0909)




[ESC] Updated SND: 0.7441697716712952 (Reward: 1.069, Update Step: 0.0259)




[ESC] Updated SND: 0.6808924674987793 (Reward: 0.955, Update Step: -0.0633)




[ESC] Updated SND: 0.5834177136421204 (Reward: 1.002, Update Step: -0.0975)




[ESC] Updated SND: 0.5351721048355103 (Reward: 1.051, Update Step: -0.0482)




[ESC] Updated SND: 0.5769445300102234 (Reward: 1.195, Update Step: 0.0418)




[ESC] Updated SND: 0.6754726767539978 (Reward: 1.398, Update Step: 0.0985)




[ESC] Updated SND: 0.7435112595558167 (Reward: 1.551, Update Step: 0.0680)




[ESC] Updated SND: 0.7074368596076965 (Reward: 1.249, Update Step: -0.0361)




[ESC] Updated SND: 0.6101542115211487 (Reward: 1.547, Update Step: -0.0973)




[ESC] Updated SND: 0.5414140224456787 (Reward: 1.475, Update Step: -0.0687)




[ESC] Updated SND: 0.566544234752655 (Reward: 1.329, Update Step: 0.0251)




[ESC] Updated SND: 0.6594111323356628 (Reward: 1.579, Update Step: 0.0929)




[ESC] Updated SND: 0.7453837990760803 (Reward: 1.737, Update Step: 0.0860)




[ESC] Updated SND: 0.7457050085067749 (Reward: 1.786, Update Step: 0.0003)




[ESC] Updated SND: 0.6640634536743164 (Reward: 1.787, Update Step: -0.0816)




[ESC] Updated SND: 0.5784136652946472 (Reward: 1.856, Update Step: -0.0856)




[ESC] Updated SND: 0.5772435069084167 (Reward: 1.685, Update Step: -0.0012)




[ESC] Updated SND: 0.6555934548377991 (Reward: 1.863, Update Step: 0.0783)




[ESC] Updated SND: 0.7514919638633728 (Reward: 2.007, Update Step: 0.0959)




[ESC] Updated SND: 0.783291757106781 (Reward: 2.194, Update Step: 0.0318)




[ESC] Updated SND: 0.7209947109222412 (Reward: 2.258, Update Step: -0.0623)




[ESC] Updated SND: 0.6292668581008911 (Reward: 2.230, Update Step: -0.0917)




[ESC] Updated SND: 0.6055794358253479 (Reward: 2.028, Update Step: -0.0237)




[ESC] Updated SND: 0.666426956653595 (Reward: 2.138, Update Step: 0.0608)




[ESC] Updated SND: 0.7684371471405029 (Reward: 2.363, Update Step: 0.1020)




[ESC] Updated SND: 0.8234175443649292 (Reward: 2.467, Update Step: 0.0550)




[ESC] Updated SND: 0.7822811603546143 (Reward: 2.505, Update Step: -0.0411)




[ESC] Updated SND: 0.6871809363365173 (Reward: 2.558, Update Step: -0.0951)




[ESC] Updated SND: 0.6321253180503845 (Reward: 2.504, Update Step: -0.0551)




[ESC] Updated SND: 0.6713494062423706 (Reward: 2.368, Update Step: 0.0392)




[ESC] Updated SND: 0.768584668636322 (Reward: 2.546, Update Step: 0.0972)




[ESC] Updated SND: 0.8420714139938354 (Reward: 2.640, Update Step: 0.0735)




[ESC] Updated SND: 0.8237938284873962 (Reward: 2.629, Update Step: -0.0183)




[ESC] Updated SND: 0.7339723110198975 (Reward: 2.733, Update Step: -0.0898)




[ESC] Updated SND: 0.6619102954864502 (Reward: 2.603, Update Step: -0.0721)




[ESC] Updated SND: 0.6746892333030701 (Reward: 2.550, Update Step: 0.0128)




[ESC] Updated SND: 0.7611922025680542 (Reward: 2.540, Update Step: 0.0865)




[ESC] Updated SND: 0.8508076071739197 (Reward: 2.712, Update Step: 0.0896)




[ESC] Updated SND: 0.8639097213745117 (Reward: 2.809, Update Step: 0.0131)




[ESC] Updated SND: 0.7896491289138794 (Reward: 2.876, Update Step: -0.0743)




[ESC] Updated SND: 0.7035004496574402 (Reward: 2.784, Update Step: -0.0861)




[ESC] Updated SND: 0.6878795027732849 (Reward: 2.757, Update Step: -0.0156)




[ESC] Updated SND: 0.7585951685905457 (Reward: 2.726, Update Step: 0.0707)




[ESC] Updated SND: 0.8529530763626099 (Reward: 2.760, Update Step: 0.0944)




[ESC] Updated SND: 0.8899185657501221 (Reward: 2.872, Update Step: 0.0370)




[ESC] Updated SND: 0.834480881690979 (Reward: 2.933, Update Step: -0.0554)




[ESC] Updated SND: 0.7387382984161377 (Reward: 2.968, Update Step: -0.0957)




[ESC] Updated SND: 0.6981923580169678 (Reward: 2.858, Update Step: -0.0405)




[ESC] Updated SND: 0.7484593391418457 (Reward: 2.835, Update Step: 0.0503)




[ESC] Updated SND: 0.8458560705184937 (Reward: 2.892, Update Step: 0.0974)


