In [3]:
import traceback
import os, sys
import pandas as pd
import torch
import random


# Make the directory two levels above the current working directory importable,
# so the `src.*` imports below resolve (presumably the project root containing
# `src/` — confirm if the notebook is moved).
dir2 = os.path.abspath('')        # current working directory
dir1 = os.path.dirname(dir2)      # its parent
dir0 = os.path.dirname(dir1)      # its grandparent — the path actually added

# Check the same path we append, so re-running this cell does not keep
# appending duplicate copies of dir0 to sys.path.
if dir0 not in sys.path:
    sys.path.append(dir0)

from src.config import PPOConfig, EmbeddingStrategy
from src.experiments import ExperimentSuite
from src.utils import ExperimentUtils


def regular_train_high(scenario, strategy, testing_agent, steps, all_iters, my_device, url_train):
    """Baseline run: train one strategy with ``testing_agent`` agents for every iteration.

    Builds a ``PPOConfig`` where both ``n_agents`` and ``max_agents`` equal
    ``testing_agent`` and ``n_iters`` equals ``all_iters``. It then runs a
    single-strategy ``ExperimentSuite`` via ``run_all_confidence(k=10,
    profile_once=False)`` and saves the suite's results through
    ``ExperimentUtils`` to ``url_train + str(strategy) + '.csv'``.

    Args:
        scenario: scenario name forwarded to ``PPOConfig(scenario_name=...)``.
        strategy: embedding strategy placed in the suite's param grid.
        testing_agent: agent count used for both ``n_agents`` and ``max_agents``.
        steps: forwarded as ``max_steps``.
        all_iters: forwarded as ``n_iters``.
        my_device: device handed to ``ExperimentSuite``.
        url_train: path prefix; the strategy name and ``.csv`` are appended directly
            (no separator is inserted).
    """
    high_config = PPOConfig(
        scenario_name=scenario,
        max_agents=testing_agent,
        n_agents=testing_agent,
        max_steps=steps,
        n_iters=all_iters,
        use_strategy_defaults=True,
    )
    suite = ExperimentSuite(
        base_config=high_config,
        param_grid={"strategy": [strategy]},
        name="test_all",
        device=my_device,
    )
    suite.run_all_confidence(k=10, profile_once=False)

    csv_path = "".join([url_train, str(strategy), '.csv'])
    ExperimentUtils(experiment_suite=suite, path=csv_path).save_df_to_file()

def _split_iterations(all_iters, percentage):
    """Split ``all_iters`` into ``(low_iters, high_iters)`` that sum exactly to ``all_iters``.

    The low-agent share is ``percentage * all_iters`` rounded to the nearest integer.
    The high-agent share is the remainder. Truncating both halves independently with
    ``int()`` drops iterations to float error: ``int((1 - 0.9) * 80)`` is 7, not 8.
    """
    low_iters = round(percentage * all_iters)
    return low_iters, all_iters - low_iters


def run_mostly_low(strategies, scenario, file_name, training_agent, testing_agent, all_iters=80, steps=200, percentage=0.9, do_high=True):
    """Train each strategy mostly with few agents, then finish with more agents.

    For every strategy in ``strategies``:

    1. Train for ``percentage`` of ``all_iters`` iterations with ``training_agent``
       agents (``max_agents`` is already set to ``testing_agent``).
    2. Run the remaining iterations with ``n_agents=testing_agent``. This uses
       ``create_and_run_experiments_with_updated_config(..., create_new=False)``,
       presumably continuing the phase-1 experiments — confirm in ``ExperimentSuite``.
    3. Save the suite's results to ``"saved_experiments/" + file_name + str(strategy) + ".csv"``.
    4. If ``do_high`` is true, also run ``regular_train_high`` for the full
       ``all_iters`` with ``testing_agent`` agents. Its results go under the prefix
       ``"saved_experiments/" + file_name + "_just_high"``.

    The two phases always add up to exactly ``all_iters`` iterations.

    Args:
        strategies: iterable of embedding strategies to evaluate, one suite each.
        scenario: scenario name forwarded to ``PPOConfig``.
        file_name: output file prefix inside ``saved_experiments/``.
        training_agent: agent count for the first (low) phase.
        testing_agent: agent count for the second (high) phase; must be >= training_agent.
        all_iters: total training iterations across both phases.
        steps: forwarded as ``max_steps``.
        percentage: fraction of ``all_iters`` spent in the low phase, in [0, 1].
        do_high: whether to also run the all-high-agents baseline.

    Raises:
        AssertionError: if ``training_agent > testing_agent`` or ``percentage`` is
            outside [0, 1].
    """
    assert training_agent <= testing_agent, "training_agent must not exceed testing_agent"
    assert 0.0 <= percentage <= 1.0, "percentage must be within [0, 1]"

    my_device = torch.device("cpu")
    low_iters, high_iters = _split_iterations(all_iters, percentage)

    url = "saved_experiments" + "/" + file_name
    url_train_high = url + '_just_high'

    # In the future it could be handled better with simple normalization in the MLP.
    # NOTE(review): "it" presumably refers to sizing max_agents for testing_agent while
    # training with fewer agents — confirm.
    for strategy in strategies:
        low_config = PPOConfig(
            scenario_name=scenario,
            max_agents=testing_agent,
            use_strategy_defaults=True,
            max_steps=steps,
            n_agents=training_agent,
            n_iters=low_iters,
        )
        param_grid = {
            "strategy": [strategy],
        }

        # Phase 1: `percentage` of the iterations with `training_agent` agents.
        suite = ExperimentSuite(base_config=low_config, param_grid=param_grid, name="test_all", device=my_device)
        suite.run_all_confidence(k=10, profile_once=False, update=True)

        # Phase 2: the remaining iterations with `testing_agent` agents.
        change_to_config = {"n_agents": testing_agent, "n_iters": high_iters}
        suite.create_and_run_experiments_with_updated_config(change_to_config, create_new=False, k=10)

        path_to_strategy_trained = url + str(strategy) + '.csv'
        suite_utils = ExperimentUtils(experiment_suite=suite, path=path_to_strategy_trained)
        suite_utils.save_df_to_file()

        if do_high:
            # Baseline: 100% of the iterations with `testing_agent` agents.
            regular_train_high(scenario, strategy, testing_agent, steps, all_iters, my_device, url_train_high)

In [4]:
# file_name='2_balance_mostly_low'
# training_agent = 5
# strategies = [
#     EmbeddingStrategy.CONCAT,
#     EmbeddingStrategy.MLP,
#     EmbeddingStrategy.MLP_LOCAL,
#     EmbeddingStrategy.MLP_GLOBAL,
#     EmbeddingStrategy.GRAPH_SAGE,
#     EmbeddingStrategy.GRAPH_GAT,
#     EmbeddingStrategy.GRAPH_GAT_v2,
#     EmbeddingStrategy.SET_TRANSFORMER_INV,
#     EmbeddingStrategy.SAB_TRANSFORMER,
#     EmbeddingStrategy.ISAB_TRANSFORMER
# ]
# 
# run_mostly_low(strategies, scenario='balance', file_name=file_name, training_agent=training_agent, testing_agent=20, all_iters=80)

In [5]:
# 'balance' scenario: 75% of 80 iterations with 5 agents, the rest with 20 agents.
# do_high=False skips the all-20-agents baseline run for these strategies.
file_name = '2_balance_75_low'
training_agent = 5

# Not run in this cell: CONCAT, MLP, MLP_LOCAL, MLP_GLOBAL, GRAPH_SAGE, GRAPH_GAT.
strategies = [
    EmbeddingStrategy.GRAPH_GAT_v2,
    EmbeddingStrategy.SET_TRANSFORMER_INV,
    EmbeddingStrategy.SAB_TRANSFORMER,
    EmbeddingStrategy.ISAB_TRANSFORMER,
]

run_mostly_low(
    strategies,
    scenario='balance',
    file_name=file_name,
    training_agent=training_agent,
    testing_agent=20,
    all_iters=80,
    percentage=0.75,
    do_high=False,
)

2025-08-20 20:09:11,428 [torchrl][INFO] check_env_specs succeeded!
2025-08-20 20:09:11,482 [torchrl][INFO] check_env_specs succeeded!

episode_reward_mean = 0:   0%|          | 0/60 [00:00<?, ?it/s][A
episode_reward_mean = -9.69063663482666:   2%|▏         | 1/60 [00:01<01:51,  1.88s/it][A
episode_reward_mean = -9.85783863067627:   3%|▎         | 2/60 [00:03<01:40,  1.74s/it][A
episode_reward_mean = -8.959911346435547:   5%|▌         | 3/60 [00:05<01:39,  1.74s/it][A
episode_reward_mean = -10.561514854431152:   7%|▋         | 4/60 [00:07<01:39,  1.77s/it][A
episode_reward_mean = -8.402442932128906:   8%|▊         | 5/60 [00:08<01:34,  1.72s/it] [A
episode_reward_mean = -5.497735977172852:  10%|█         | 6/60 [00:10<01:30,  1.68s/it][A
episode_reward_mean = -4.109395503997803:  12%|█▏        | 7/60 [00:12<01:29,  1.68s/it][A
episode_reward_mean = -1.1521426439285278:  13%|█▎        | 8/60 [00:13<01:26,  1.66s/it][A
episode_reward_mean = -1.9453825950622559:  15%|█▌        | 9

In [6]:
# file_name='2_navigation_mostly_low'
# training_agent = 5
# strategies = [
#     EmbeddingStrategy.CONCAT,
#     EmbeddingStrategy.MLP,
#     EmbeddingStrategy.MLP_LOCAL,
#     EmbeddingStrategy.MLP_GLOBAL,
#     EmbeddingStrategy.GRAPH_SAGE,
#     EmbeddingStrategy.GRAPH_GAT,
#     EmbeddingStrategy.GRAPH_GAT_v2,
#     EmbeddingStrategy.SET_TRANSFORMER_INV,
#     EmbeddingStrategy.SAB_TRANSFORMER,
#     EmbeddingStrategy.ISAB_TRANSFORMER
# ]
#
# run_mostly_low(strategies, scenario='navigation', file_name=file_name, training_agent=training_agent, testing_agent=20, all_iters=80, steps=100)