# AD2C : Test Bed

This Jupter Notebook aims to Experiment with different Component of AD2C framework Like task, ESC, Callback and Loggers. This will also include advance experimental setup to test the trainied model. 


----


## Imports

In [3]:
import sys
import os
import hydra
import wandb
import sys
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf

# Benchmarl & Project Imports
import benchmarl.models
from benchmarl.algorithms import *
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment
from benchmarl.hydra_config import (
    load_algorithm_config_from_hydra,
    load_experiment_config_from_hydra,
    load_task_config_from_hydra,
    load_model_config_from_hydra,
)

# Custom Callbacks
from het_control.callback import *
from het_control.environments.vmas import render_callback
from het_control.models.het_control_mlp_empirical import HetControlMlpEmpiricalConfig
from het_control.callbacks.esc_callback import ExtremumSeekingController
from het_control.callbacks.sndESLogger import TrajectorySNDLoggerCallback

import numpy as np
import torch
import matplotlib.pyplot as plt
import networkx as nx
import wandb
from tensordict import TensorDict
from typing import List
from benchmarl.experiment.callback import Callback

## SND Ploting Callback
This allow us to display plots for the snd. from the eval run. 

---

In [4]:
def generate_snd_visualizations(snd_matrix, n_agents, step_count):
    """
    Generates 3 matplotlib figures (Heatmap, Bar, Graph) for a given SND matrix.
    Returns a dictionary of wandb.Image objects.
    Assumes snd_matrix is an N x N Symmetric Distance Matrix.
    """
    plots = {}
    
    # Define pairs (Upper triangle only, since matrix is symmetric: 1-2 is same as 2-1)
    pairs = [(i, j) for i in range(n_agents) for j in range(i + 1, n_agents)]
    
    # ==========================================
    # 1. HEATMAP with Cell Values
    # ==========================================
    fig_heat, ax_heat = plt.subplots(figsize=(6, 5))
    im = ax_heat.imshow(snd_matrix, cmap='viridis', interpolation='nearest', vmin=0, vmax=3)
    
    ax_heat.set_title(f'SND Matrix (Heatmap) - Step {step_count}')
    ax_heat.set_xlabel('Agent Index')
    ax_heat.set_ylabel('Agent Index')
    
    # Set ticks to be integers (Agent 0, Agent 1...)
    ax_heat.set_xticks(np.arange(n_agents))
    ax_heat.set_yticks(np.arange(n_agents))
    
    fig_heat.colorbar(im, ax=ax_heat, label='Distance')
    
    # Loop over data dimensions and create text annotations.
    for i in range(n_agents):
        for j in range(n_agents):
            val = snd_matrix[i, j]
            
            # Text color logic: White for dark background (low values), Black for light (high values)
            # Viridis: Low values (purple) -> White text. High values (yellow) -> Black text.
            # Scale is 0-3. Midpoint roughly 1.5.
            text_color = "white" if val < 1.0 else "black"
            
            # Print value centered in the cell
            ax_heat.text(j, i, f"{val:.2f}",
                         ha="center", va="center", color=text_color, 
                         fontsize=8, fontweight='bold')

    plt.tight_layout()
    plots["Visuals/SND_Heatmap"] = wandb.Image(fig_heat)
    plt.close(fig_heat)

    # ==========================================
    # 2. BAR CHART (Pairwise Values)
    # ==========================================
    if len(pairs) > 0:
        # Extract values for the unique pairs (upper triangle)
        pair_values = [snd_matrix[p[0], p[1]] for p in pairs]
        pair_labels = [f"A{p[0]}-A{p[1]}" for p in pairs]
        
        fig_bar, ax_bar = plt.subplots(figsize=(8, 5))
        bars = ax_bar.bar(pair_labels, pair_values, color='teal')
        
        ax_bar.set_title(f'Pairwise Distances - Step {step_count}')
        ax_bar.set_ylabel('Distance')
        ax_bar.set_ylim(0, 3) 
        ax_bar.tick_params(axis='x', rotation=45)
        
        # Add values on top of bars
        ax_bar.bar_label(bars, fmt='%.2f', padding=3)

        plt.tight_layout()
        plots["Visuals/SND_BarChart"] = wandb.Image(fig_bar)
        plt.close(fig_bar)

    # ==========================================
    # 3. NETWORK GRAPH (Topology)
    # ==========================================
    if len(pairs) > 0:
        fig_graph = plt.figure(figsize=(7, 7))
        G = nx.Graph()
        
        # Add edges for unique pairs
        for u, v in pairs:
            dist = snd_matrix[u, v]
            G.add_edge(u, v, weight=dist)

        pos = nx.spring_layout(G, seed=42)
        weights = [G[u][v]['weight'] for u, v in G.edges()]
        
        # Draw Nodes
        nx.draw_networkx_nodes(G, pos, node_size=600, node_color='lightblue')
        nx.draw_networkx_labels(G, pos, font_weight='bold')
        
        # Draw Edges
        edges = nx.draw_networkx_edges(G, pos, 
                               edge_color=weights, 
                               edge_cmap=plt.cm.viridis, 
                               width=2,
                               edge_vmin=0,
                               edge_vmax=3)
        
        # Draw Edge Labels (The distance values on the lines)
        edge_labels = {
            (u, v): f"{d['weight']:.2f}" 
            for u, v, d in G.edges(data=True)
        }
        nx.draw_networkx_edge_labels(
            G, pos, 
            edge_labels=edge_labels, 
            font_color='black', 
            font_size=8,
            font_weight='bold'
        )
        
        plt.colorbar(edges, label='Distance')
        plt.title(f'Interaction Graph - Step {step_count}')
        plt.axis('off')
        plots["Visuals/SND_NetworkGraph"] = wandb.Image(fig_graph)
        plt.close(fig_graph)

    return plots

In [5]:
class SNDVisualizerCallback(Callback):
    """
    A visualization-only callback that computes the SND (Behavioral Distance) matrix
    at evaluation time and logs Heatmap, Bar Chart, and Graph visualizations to WandB.
    """
    def __init__(self):
        super().__init__()
        self.control_group = None
        self.model = None

    def on_setup(self):
        """Auto-detects the agent group and initializes the model wrapper."""
        if not self.experiment.group_policies:
            print("\nWARNING: No group policies found. SND Visualizer disabled.\n")
            return

        # Auto-detect: Simply grab the first available control group
        self.control_group = list(self.experiment.group_policies.keys())[0]
        
        policy = self.experiment.group_policies[self.control_group]
        
        # We assume 'get_het_model' is available in your scope
        self.model = get_het_model(policy)

        if self.model is None:
             print(f"\nWARNING: Could not extract HetModel for group '{self.control_group}'. Visualizer disabled.\n")

    def _get_agent_actions_for_rollout(self, rollout):
        """Helper to run the forward pass and get actions for SND computation."""
        obs = rollout.get((self.control_group, "observation"))
        actions = []
        for i in range(self.model.n_agents):
            temp_td = TensorDict(
                {(self.control_group, "observation"): obs},
                batch_size=obs.shape[:-1]
            )
            # Ensure _forward exists and returns a TensorDict with the output key
            action_td = self.model._forward(temp_td, agent_index=i, compute_estimate=False)
            actions.append(action_td.get(self.model.out_key))
        return actions

    def on_evaluation_end(self, rollouts: List[TensorDict]):
        """Runs at the end of evaluation to compute SND and log plots."""
        if self.model is None:
            return

        logs_to_push = {}
        first_rollout_snd_matrix = None

        with torch.no_grad():
            for i, r in enumerate(rollouts):
                # We only need the matrix from the first rollout for clean visualization
                if i > 0: 
                    break

                agent_actions = self._get_agent_actions_for_rollout(r)
                
                # Compute behavioral distance (Assumed to be available in scope)
                # Must return N x N matrix
                pairwise_distances_tensor = compute_behavioral_distance(agent_actions, just_mean=False)
                
                # If the function returns (Time x N x N), average over Time to get (N x N)
                if pairwise_distances_tensor.ndim > 2:
                    pairwise_distances_tensor = pairwise_distances_tensor.mean(dim=0)

                first_rollout_snd_matrix = pairwise_distances_tensor.cpu().numpy()

        # Generate and Log Visualizations
        if first_rollout_snd_matrix is not None:
            visual_logs = generate_snd_visualizations(
                snd_matrix=first_rollout_snd_matrix, 
                n_agents=self.model.n_agents,
                step_count=self.experiment.n_iters_performed
            )
            logs_to_push.update(visual_logs)
            
            # Update the logger
            self.experiment.logger.log(logs_to_push, step=self.experiment.n_iters_performed)

## Env Setup

This code block compiles all the different section of MARl env together to Run the experiment

---

In [6]:
def setup(task_name):
    benchmarl.models.model_config_registry.update(
        {
            "hetcontrolmlpempirical": HetControlMlpEmpiricalConfig,
        }
    )
    if task_name == "vmas/navigation":
        # Set the render callback for the navigation case study
        VmasTask.render_callback = render_callback

def get_experiment(cfg: DictConfig) -> Experiment:
    hydra_choices = HydraConfig.get().runtime.choices
    task_name = hydra_choices.task
    algorithm_name = hydra_choices.algorithm

    setup(task_name)

    print(f"\nAlgorithm: {algorithm_name}, Task: {task_name}")
    # print("\nLoaded config:\n") # Optional: Commented out to reduce clutter
    # print(OmegaConf.to_yaml(cfg))

    algorithm_config = load_algorithm_config_from_hydra(cfg.algorithm)
    experiment_config = load_experiment_config_from_hydra(cfg.experiment)
    task_config = load_task_config_from_hydra(cfg.task, task_name)
    critic_model_config = load_model_config_from_hydra(cfg.critic_model)
    model_config = load_model_config_from_hydra(cfg.model)

    if isinstance(algorithm_config, (MappoConfig, IppoConfig, MasacConfig, IsacConfig)):
        model_config.probabilistic = True
        model_config.scale_mapping = algorithm_config.scale_mapping
        algorithm_config.scale_mapping = (
            "relu"  # The scaling of std_dev will be done in the model
        )
    else:
        model_config.probabilistic = False

    experiment = Experiment(
        task=task_config,
        algorithm_config=algorithm_config,
        model_config=model_config,
        critic_model_config=critic_model_config,
        seed=cfg.seed,
        config=experiment_config,
        callbacks=[
            SndCallback(),
            SNDVisualizerCallback(),
            # ExtremumSeekingController(
            #             control_group="agents",
            #             # initial_snd=0.0,
            #             dither_magnitude=0.1,
            #             dither_frequency_rad_s=1.0,
            #             integral_gain=-0.01,
            #             high_pass_cutoff_rad_s=1.0,
            #             low_pass_cutoff_rad_s=1.0,
            #             sampling_period=1.0
            # ),
            # TrajectorySNDLoggerCallback(control_group="agents"),
            NormLoggerCallback(),
            ActionSpaceLoss(
                use_action_loss=cfg.use_action_loss, action_loss_lr=cfg.action_loss_lr
            ),
        ]
        + (
            [
                TagCurriculum(
                    cfg.simple_tag_freeze_policy_after_frames,
                    cfg.simple_tag_freeze_policy,
                )
            ]
            if task_name == "vmas/simple_tag"
            else []
        ),
    )
    return experiment

## Training Code

Trains the model for 200 episodes. 

In [None]:

ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"  # Make sure 'navigation_ippo.yaml' exists in the folder above!
SAVE_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/model_checkpoints/navigation_ippo_esc/"

save_interval = 600000
desired_snd = 1.0
max_frame = 6000000

if not os.path.exists(SAVE_PATH):
    print(f"Creating missing directory: {SAVE_PATH}")
    os.makedirs(SAVE_PATH, exist_ok=True)

GlobalHydra.instance().clear()

sys.argv = [
    "dummy.py",
    f"model.desired_snd={desired_snd}",
    f"experiment.max_n_frames={max_frame}",
    f"experiment.checkpoint_interval={save_interval}",
    f"experiment.save_folder={SAVE_PATH}", # FIXED: Removed space after '='
]

# 3. Define the Hydra wrapper
@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def hydra_experiment(cfg: DictConfig) -> None:
    print(f"Config loaded from: {ABS_CONFIG_PATH}")
    if wandb.run is not None:
        print("Finishing previous WandB run...")
        wandb.finish()

    
    print(f"Running with SND: {cfg.model.desired_snd}")
    
    experiment = get_experiment(cfg=cfg)
    experiment.run()
    wandb.finish()

# 4. Execute safely
if __name__ == "__main__":
    try:
        hydra_experiment()
    except SystemExit:
        print("Experiment finished successfully.")
    except Exception as e:
        print(f"An error occurred: {e}")

Config loaded from: /home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf
Running with SND: 1.0

Algorithm: ippo, Task: vmas/navigation






mean return = -0.2842506170272827:   1%|          | 1/100 [00:39<1:05:36, 39.76s/it]

mean return = -0.12061435729265213:   2%|▏         | 2/100 [01:19<1:05:06, 39.86s/it]

mean return = 0.06278912723064423:   3%|▎         | 3/100 [01:57<1:02:38, 38.75s/it] 

mean return = 0.3085945248603821:   4%|▍         | 4/100 [02:36<1:02:27, 39.03s/it] 

mean return = -1.3374804258346558:   5%|▌         | 5/100 [03:18<1:03:16, 39.96s/it]

mean return = -1.2115898132324219:   6%|▌         | 6/100 [03:57<1:02:17, 39.76s/it]

mean return = -0.7913774847984314:   7%|▋         | 7/100 [04:34<1:00:11, 38.83s/it]

mean return = -0.37020793557167053:   8%|▊         | 8/100 [05:12<59:00, 38.48s/it] 

mean return = -0.3419821858406067:   9%|▉         | 9/100 [05:52<59:06, 38.97s/it] 

mean return = -0.7258304357528687:  10%|█         | 10/100 [06:32<58:53, 39.26s/it]

mean return = -0.356292188167572:  11%|█         | 11/100 [07:12<58:55, 39.73s/it] 

mean return = -0.005251131020486355:  12%|█▏        | 12/100 [07:57<1:00:12, 41.05s/it]

mean return = -0.04384047910571098:  13%|█▎        | 13/100 [08:34<58:07, 40.08s/it]   

mean return = 0.36396729946136475:  14%|█▍        | 14/100 [09:13<56:46, 39.61s/it] 

mean return = -0.9160506129264832:  15%|█▌        | 15/100 [09:54<56:41, 40.02s/it]

mean return = 0.2671414911746979:  16%|█▌        | 16/100 [10:34<55:57, 39.97s/it] 

mean return = -0.20932984352111816:  17%|█▋        | 17/100 [11:13<54:49, 39.64s/it]

mean return = 0.27600908279418945:  18%|█▊        | 18/100 [11:55<55:14, 40.42s/it] 

mean return = 0.2526856064796448:  19%|█▉        | 19/100 [12:36<54:54, 40.67s/it] 

mean return = -1.683846116065979:  20%|██        | 20/100 [13:18<54:46, 41.09s/it]

mean return = -1.610228180885315:  21%|██        | 21/100 [13:59<53:58, 41.00s/it]

mean return = -1.3977614641189575:  22%|██▏       | 22/100 [14:38<52:43, 40.56s/it]

mean return = -1.212141990661621:  23%|██▎       | 23/100 [15:21<52:37, 41.01s/it] 

mean return = -0.9942213296890259:  24%|██▍       | 24/100 [16:03<52:21, 41.34s/it]

mean return = -0.8850975632667542:  25%|██▌       | 25/100 [16:45<52:03, 41.64s/it]

mean return = -0.7868425250053406:  26%|██▌       | 26/100 [17:27<51:33, 41.81s/it]

mean return = -0.7432085871696472:  27%|██▋       | 27/100 [18:07<50:06, 41.19s/it]

mean return = -0.7048777937889099:  28%|██▊       | 28/100 [18:45<48:25, 40.36s/it]

mean return = -0.7007088661193848:  29%|██▉       | 29/100 [19:25<47:22, 40.04s/it]

mean return = -0.647488534450531:  30%|███       | 30/100 [20:06<47:06, 40.37s/it] 

mean return = -0.6305897235870361:  31%|███       | 31/100 [20:48<47:01, 40.88s/it]

mean return = -0.6276352405548096:  32%|███▏      | 32/100 [21:26<45:18, 39.98s/it]

## Eval Run
---
Single Step eval. from the check point

In [8]:
# CONFIGURATION
import shutil


ABS_CONFIG_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/AD2C/ControllingBehavioralDiversity/het_control/conf"
CONFIG_NAME = "navigation_ippo"



In [9]:
import time

# Your checkpoint path
CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd00.pt"

unique_id = f"AD2C_Eval_{int(time.time())}"

# FORCE WandB to use this ID, ignoring whatever is in your YAML/Hydra config
os.environ["WANDB_RUN_ID"] = unique_id
os.environ["WANDB_NAME"] = unique_id


# ==========================================
# EVALUATION LOGIC
# ==========================================
GlobalHydra.instance().clear()

sys.argv = [
    "eval_script.py",
    f"experiment.restore_file={CHECKPOINT_PATH}",
    "experiment.evaluation_episodes=10",
    "experiment.render=True",
    "experiment.evaluation_deterministic_actions=True",
    "experiment.save_folder=null",
    "model.desired_snd=0.3" ,
    f"+experiment.name={unique_id}"
]

@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def eval_experiment(cfg: DictConfig) -> None:
    print(f"Loading model from: {cfg.experiment.restore_file}")
    print(f"Initializing model with dummy SND: {cfg.model.desired_snd}")
    
    experiment = get_experiment(cfg=cfg)
    
    print("Model loaded. Starting Evaluation...")
    
    experiment._evaluation_loop()
    
    print("Evaluation Complete.")
    experiment.close()
    
if __name__ == "__main__":
    try:
        eval_experiment()
    except SystemExit:
        pass
    except Exception as e:
        print(f"An error occurred: {e}")

Loading model from: /home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd00.pt
Initializing model with dummy SND: 0.3

Algorithm: ippo, Task: vmas/navigation


Error executing job with overrides: ['experiment.restore_file=/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/snd00.pt', 'experiment.evaluation_episodes=10', 'experiment.render=True', 'experiment.evaluation_deterministic_actions=True', 'experiment.save_folder=null', 'model.desired_snd=0.3', '+experiment.name=AD2C_Eval_1764198683']
Traceback (most recent call last):
  File "/home/grad/doc/2027/spatel2/miniconda3/envs/ad2c/lib/python3.9/asyncio/locks.py", line 226, in wait
    await fut
asyncio.exceptions.CancelledError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/grad/doc/2027/spatel2/miniconda3/envs/ad2c/lib/python3.9/asyncio/tasks.py", line 490, in wait_for
    return fut.result()
asyncio.exceptions.CancelledError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/grad/doc/2027/spatel2/miniconda3/envs/ad2c/lib/python3.9/site-packages/wand

In [10]:
# !wandb login --relogin b3c2b62655aa322e8ab1d1ab07287749ce03ff8

## Testing Env with different task config
---

Changing the task config to check the adaptibility of learned policy. 

In [11]:
CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/model2.pt"

# ==========================================
# EVALUATION LOGIC
# ==========================================
GlobalHydra.instance().clear()

sys.argv = [
    "eval_script.py",
    f"experiment.restore_file={CHECKPOINT_PATH}",
    "experiment.evaluation_episodes=10",
    "experiment.render=True",
    "experiment.evaluation_deterministic_actions=True",
    "experiment.save_folder=null",
    "model.desired_snd=0.3",
    
    "task.agents_with_same_goal=1" 
]

@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def eval_experiment(cfg: DictConfig) -> None:
    print(f"Loading model from: {cfg.experiment.restore_file}")
    
    print(f"Evaluation Setup: {cfg.task.n_agents} Agents, {cfg.task.agents_with_same_goal} per goal.")
    
    experiment = get_experiment(cfg=cfg)
    print("Starting Evaluation...")
    experiment._evaluation_loop()
    print("Evaluation Complete.")
    experiment.close()

if __name__ == "__main__":
    try:
        eval_experiment()
    except SystemExit:
        pass
    except Exception as e:
        print(f"An error occurred: {e}")

Loading model from: /home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/model2.pt
Evaluation Setup: 3 Agents, 1 per goal.

Algorithm: ippo, Task: vmas/navigation




Error executing job with overrides: ['experiment.restore_file=/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/model2.pt', 'experiment.evaluation_episodes=10', 'experiment.render=True', 'experiment.evaluation_deterministic_actions=True', 'experiment.save_folder=null', 'model.desired_snd=0.3', 'task.agents_with_same_goal=1']
Traceback (most recent call last):
  File "/home/grad/doc/2027/spatel2/miniconda3/envs/ad2c/lib/python3.9/asyncio/locks.py", line 226, in wait
    await fut
asyncio.exceptions.CancelledError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/grad/doc/2027/spatel2/miniconda3/envs/ad2c/lib/python3.9/asyncio/tasks.py", line 490, in wait_for
    return fut.result()
asyncio.exceptions.CancelledError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/grad/doc/2027/spatel2/miniconda3/envs/ad2c/lib/python3.9/site-packages/wandb/sdk/ma

## Transfer Lerning
---

Finetune the policy to make it work with other task that it could not. 

In [12]:
CHECKPOINT_PATH = "/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/model2.pt"

# ==========================================
# 2. RUN LOGIC
# ==========================================
new_max_frames = 10000000 
desired_snd = 1.0

GlobalHydra.instance().clear()

sys.argv = [
    "run_script.py",
    f"model.desired_snd={desired_snd}",
    f"experiment.restore_file={CHECKPOINT_PATH}",
    f"experiment.max_n_frames={new_max_frames}",
    
    # --- TASK CONFIGURATION ---
    "task.agents_with_same_goal=1", 
    "experiment.save_folder=null"
]

@hydra.main(version_base=None, config_path=ABS_CONFIG_PATH, config_name=CONFIG_NAME)
def hydra_experiment(cfg: DictConfig) -> None:
    print(f"Resuming with SND: {cfg.model.desired_snd}")
    print(f"Agents sharing a goal: {cfg.task.agents_with_same_goal}")
    
    experiment = get_experiment(cfg=cfg)
    experiment.run()

if __name__ == "__main__":
    try:
        hydra_experiment()
    except SystemExit:
        print("Experiment finished successfully.")
    except Exception as e:
        print(f"An error occurred: {e}")

Resuming with SND: 1.0
Agents sharing a goal: 1

Algorithm: ippo, Task: vmas/navigation




Experiment finished successfully.


Error executing job with overrides: ['model.desired_snd=1.0', 'experiment.restore_file=/home/grad/doc/2027/spatel2/AD2C_testBed/saved_models/model2.pt', 'experiment.max_n_frames=10000000', 'task.agents_with_same_goal=1', 'experiment.save_folder=null']
Traceback (most recent call last):
  File "/home/grad/doc/2027/spatel2/miniconda3/envs/ad2c/lib/python3.9/asyncio/locks.py", line 226, in wait
    await fut
asyncio.exceptions.CancelledError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/grad/doc/2027/spatel2/miniconda3/envs/ad2c/lib/python3.9/asyncio/tasks.py", line 490, in wait_for
    return fut.result()
asyncio.exceptions.CancelledError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/grad/doc/2027/spatel2/miniconda3/envs/ad2c/lib/python3.9/site-packages/wandb/sdk/mailbox/response_handle.py", line 82, in wait_async
    await asyncio.wait_for(se