In [1]:
import os, sys
import pandas as pd
import torch


# Make the directory two levels above the current working directory importable,
# so the `src.*` imports below resolve. This assumes the notebook is launched from
# <project_root>/<dir1>/<dir2> — TODO confirm against the repository layout.
dir2 = os.path.abspath('')      # absolute path of the current working directory
dir1 = os.path.dirname(dir2)
dir0 = os.path.dirname(dir1)

# Check the same path that is appended, so re-running this cell does not keep
# pushing duplicate copies of dir0 onto sys.path.
if dir0 not in sys.path:
    sys.path.append(dir0)

from src.config import PPOConfig, EmbeddingStrategy
from src.experiments import ExperimentSuite
from src.utils import ExperimentUtils


def rewards_of_suite(suite, agents):
    """Roll out every experiment of ``suite`` at each agent count and gather rewards.

    For each ``n`` in ``agents`` the suite is rolled out with the config override
    ``{"n_agents": n}`` via ``suite.rollout_all_get_rewards``. The frame it returns
    is expected to contain an ``"experiment"`` column and to hold the reward in its
    second column (position 1); that column is renamed to ``"rewards"``.

    Args:
        suite: object exposing ``rollout_all_get_rewards(change_to_config) -> pd.DataFrame``.
        agents: iterable of agent counts to evaluate.

    Returns:
        pd.DataFrame with columns ``["strategy", "agents", "rewards"]`` (``strategy`` is
        the suite's ``experiment`` label) and a fresh RangeIndex. When ``agents`` is
        empty, the frame is empty but still has those columns.
    """
    frames = []
    for n_agents in agents:
        rollout = suite.rollout_all_get_rewards({"n_agents": n_agents})

        # The reward column's name varies between rollouts; assume it is the second column.
        reward_col = rollout.columns[1]
        frames.append(
            rollout.rename(columns={reward_col: "rewards"})
            .assign(agents=n_agents)[["experiment", "agents", "rewards"]]
        )

    if not frames:
        return pd.DataFrame(columns=["strategy", "agents", "rewards"])

    # Concatenate once rather than growing a frame inside the loop (quadratic copying).
    return (
        pd.concat(frames, ignore_index=True)
        .rename(columns={"experiment": "strategy"})
    )


def run_generalizability(strategies, file_name, training_agents, testing_agents, iters,
                         k=10, out_dir="saved_experiments"):
    """Train each strategy on some agent counts and roll it out on others.

    For every entry of ``strategies`` this builds a fresh balance-scenario
    ``PPOConfig`` (``max_agents=max(testing_agents)``, ``n_iters=iters``) and a
    single-strategy ``ExperimentSuite``. It trains the suite once per count in
    ``training_agents`` (config override ``{"n_agents": n}``), then collects
    rollout rewards at every count in ``testing_agents`` with ``rewards_of_suite``.

    After each strategy, two CSV files are (re)written under ``out_dir``:

    * ``<file_name>_rollout<strategy>.csv`` -- rollout rewards
    * ``<file_name><strategy>.csv`` -- ``ExperimentUtils(experiment_suite=suite).df``
      (presumably per-experiment training results — confirm in ``src.utils``)

    Both files are cumulative: they contain every strategy processed so far. The
    pair named after the last strategy therefore holds the complete results, and
    an interrupted run keeps what already finished. ``<strategy>`` is ``str(strategy)``.

    Args:
        strategies: iterable of strategies (e.g. ``EmbeddingStrategy`` members).
        file_name: prefix for the output CSV names.
        training_agents: agent counts to train on, in order.
        testing_agents: agent counts to evaluate on; must be non-empty because
            ``max`` is taken over it.
        iters: passed as ``n_iters`` to ``PPOConfig``.
        k: forwarded to ``create_and_run_experiments_with_updated_config``
            (default 10, the previously hard-coded value).
        out_dir: directory for the CSVs; created if missing.

    Returns:
        None. Results are written to disk.
    """
    device = torch.device("cpu")
    # DataFrame.to_csv does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

    rollout_frames = []
    train_frames = []

    for strategy in strategies:
        # Fresh config per strategy, sized for the largest evaluation population.
        base_config = PPOConfig(
            scenario_name='balance', max_agents=max(testing_agents), use_strategy_defaults=True, n_iters=iters
        )
        suite = ExperimentSuite(
            base_config=base_config,
            param_grid={"strategy": [strategy]},
            name="test_all",
            device=device,
        )

        for n_agents in training_agents:
            suite.create_and_run_experiments_with_updated_config({"n_agents": n_agents}, create_new=False, k=k)

        rollout_frames.append(rewards_of_suite(suite, testing_agents))
        train_frames.append(ExperimentUtils(experiment_suite=suite).df)

        # DataFrame.append was removed in pandas 2.0; concatenate the collected frames instead.
        pd.concat(rollout_frames, ignore_index=True).to_csv(
            os.path.join(out_dir, f"{file_name}_rollout{strategy!s}.csv"), index=False
        )
        pd.concat(train_frames, ignore_index=True).to_csv(
            os.path.join(out_dir, f"{file_name}{strategy!s}.csv"), index=False
        )

In [2]:
# High-variance run: train on eight agent counts (4, 6, ..., 18) and evaluate
# on 5, 10 and 20 agents. The 80 iterations appear to be an overall training
# budget that is divided evenly between the training counts.
file_name = '1_high_variance'
high_variance = list(range(4, 19, 2))
TOTAL_ITERS = 80
n_iters = TOTAL_ITERS // len(high_variance)

strategies = [
    EmbeddingStrategy.CONCAT,
    EmbeddingStrategy.MLP,
    EmbeddingStrategy.MLP_LOCAL,
    EmbeddingStrategy.MLP_GLOBAL,
    EmbeddingStrategy.GRAPH_SAGE,
    EmbeddingStrategy.GRAPH_GAT,
    EmbeddingStrategy.GRAPH_GAT_v2,
    EmbeddingStrategy.SET_TRANSFORMER_INV,
    EmbeddingStrategy.SAB_TRANSFORMER,
    EmbeddingStrategy.ISAB_TRANSFORMER,
]

run_generalizability(
    strategies,
    file_name=file_name,
    training_agents=high_variance,
    testing_agents=[5, 10, 20],
    iters=n_iters,
)

In [3]:
# Low-variance run: train on four agent counts (6, 10, 14, 18) and evaluate
# on 5, 10 and 20 agents. The 80-iteration budget is split evenly across the
# training counts (20 each, matching the 20/20 progress bars in the log).
file_name = '1_low_variance_remaining'
low_variance = list(range(6, 19, 4))
TOTAL_ITERS = 80
n_iters = TOTAL_ITERS // len(low_variance)

strategies = [
    EmbeddingStrategy.CONCAT,
    EmbeddingStrategy.MLP,
    EmbeddingStrategy.MLP_LOCAL,
    EmbeddingStrategy.MLP_GLOBAL,
    EmbeddingStrategy.GRAPH_SAGE,
    EmbeddingStrategy.GRAPH_GAT,
    EmbeddingStrategy.GRAPH_GAT_v2,
    EmbeddingStrategy.SET_TRANSFORMER_INV,
    EmbeddingStrategy.SAB_TRANSFORMER,
    EmbeddingStrategy.ISAB_TRANSFORMER,
]

run_generalizability(
    strategies,
    file_name=file_name,
    training_agents=low_variance,
    testing_agents=[5, 10, 20],
    iters=n_iters,
)

2025-08-13 15:30:00,394 [torchrl][INFO] check_env_specs succeeded!
2025-08-13 15:30:00,594 [torchrl][INFO] check_env_specs succeeded!
episode_reward_mean = 36.83730697631836: 100%|██████████| 20/20 [00:40<00:00,  2.01s/it]    
2025-08-13 15:30:40,803 [torchrl][INFO] Training time: 25.59 seconds
2025-08-13 15:30:40,811 [torchrl][INFO] macs: 57.87 MMac  Params: 35.47 k
2025-08-13 15:30:41,604 [torchrl][INFO] check_env_specs succeeded!
2025-08-13 15:30:41,643 [torchrl][INFO] check_env_specs succeeded!
episode_reward_mean = 51.45933151245117: 100%|██████████| 20/20 [00:40<00:00,  2.01s/it]  
2025-08-13 15:31:21,938 [torchrl][INFO] Training time: 25.54 seconds
2025-08-13 15:31:21,945 [torchrl][INFO] macs: 57.87 MMac  Params: 35.47 k
2025-08-13 15:31:22,741 [torchrl][INFO] check_env_specs succeeded!
2025-08-13 15:31:22,777 [torchrl][INFO] check_env_specs succeeded!
episode_reward_mean = 37.926387786865234: 100%|██████████| 20/20 [00:41<00:00,  2.07s/it]
2025-08-13 15:32:04,215 [torchrl][INFO

In [4]:
# No-variance baseline: train on a single agent count (10) and evaluate on
# 5, 10 and 20 agents. With one training count, the whole 80-iteration budget
# goes to it (matching the 80/80 progress bars in the log).
file_name = '1_no_variance'
no_variance = [10]
TOTAL_ITERS = 80
n_iters = TOTAL_ITERS // len(no_variance)

strategies = [
    EmbeddingStrategy.CONCAT,
    EmbeddingStrategy.MLP,
    EmbeddingStrategy.MLP_LOCAL,
    EmbeddingStrategy.MLP_GLOBAL,
    EmbeddingStrategy.GRAPH_SAGE,
    EmbeddingStrategy.GRAPH_GAT,
    EmbeddingStrategy.GRAPH_GAT_v2,
    EmbeddingStrategy.SET_TRANSFORMER_INV,
    EmbeddingStrategy.SAB_TRANSFORMER,
    EmbeddingStrategy.ISAB_TRANSFORMER,
]

run_generalizability(
    strategies,
    file_name=file_name,
    training_agents=no_variance,
    testing_agents=[5, 10, 20],
    iters=n_iters,
)

2025-08-13 19:10:23,982 [torchrl][INFO] check_env_specs succeeded!
2025-08-13 19:10:24,043 [torchrl][INFO] check_env_specs succeeded!
episode_reward_mean = 94.7380142211914: 100%|██████████| 80/80 [03:07<00:00,  2.34s/it]  
2025-08-13 19:13:31,103 [torchrl][INFO] Training time: 80.71 seconds
2025-08-13 19:13:31,106 [torchrl][INFO] macs: 58.12 MMac  Params: 14.53 k
2025-08-13 19:13:32,556 [torchrl][INFO] check_env_specs succeeded!
2025-08-13 19:13:32,617 [torchrl][INFO] check_env_specs succeeded!
episode_reward_mean = 95.87168884277344: 100%|██████████| 80/80 [03:04<00:00,  2.30s/it] 
2025-08-13 19:16:36,626 [torchrl][INFO] Training time: 78.19 seconds
2025-08-13 19:16:36,630 [torchrl][INFO] macs: 58.12 MMac  Params: 14.53 k
2025-08-13 19:16:38,071 [torchrl][INFO] check_env_specs succeeded!
2025-08-13 19:16:38,133 [torchrl][INFO] check_env_specs succeeded!
episode_reward_mean = 105.7174072265625: 100%|██████████| 80/80 [03:14<00:00,  2.43s/it]  
2025-08-13 19:19:52,647 [torchrl][INFO] T