In [None]:
import os, sys
import torch

# Make the project root importable so the `src.*` imports below resolve.
# NOTE(review): assumes the working directory sits two levels below the
# project root (root/<dir1>/<dir2>) — TODO confirm the repository layout.
dir2 = os.path.abspath('')      # current working directory
dir1 = os.path.dirname(dir2)    # one level above dir2
dir0 = os.path.dirname(dir1)    # one level above dir1: the directory put on sys.path

# Test membership of the same path we append. Previously this checked `dir1`,
# which is never added, so each re-run of the cell appended another `dir0`.
if dir0 not in sys.path:
    sys.path.append(dir0)

from src.config import PPOConfig, EmbeddingStrategy
from src.experiments import ExperimentSuite
from src.utils import ExperimentUtils

# Default embedding widths swept by `tune_width_wrt_depth`.
cells = [16, 32, 48, 64]


def tune_width_wrt_depth(strategy, file_name, depth=1, list_of_cells=cells,
                         k=10, device=None, suite_name="test_all"):
    """Sweep the embedding width for one embedding strategy/depth and save the results.

    Builds a 5-agent ``balance`` ``PPOConfig`` with decentralized execution,
    ``max_steps=200``, mean pooling and ``mlp_core_num_cells=256``. It then
    runs an ``ExperimentSuite`` over every value in ``list_of_cells`` for
    ``embedding_num_cells`` via ``run_all_confidence(k=k)``. Finally it saves
    the suite dataframe to ``saved_experiments/<file_name>``, plots it and
    prints the confidence table.

    Args:
        strategy: Embedding strategy (e.g. ``EmbeddingStrategy.MLP``),
            forwarded to ``PPOConfig(strategy=...)``.
        file_name: Name of the output file, placed under the relative
            ``saved_experiments`` directory.
        depth: Forwarded to ``PPOConfig(embedding_depth=...)``. The calling
            cells note that with ``depth=None`` the width is irrelevant.
        list_of_cells: Widths to sweep; defaults to the module-level ``cells``.
            A fresh list is handed to the suite so the shared default is never
            mutated by downstream code.
        k: Forwarded to ``suite.run_all_confidence(k=...)`` (presumably the
            number of repeated runs per setting — confirm). Defaults to 10.
        device: Device passed to ``ExperimentSuite``; ``None`` means
            ``torch.device("cpu")``, the previous hard-coded choice.
        suite_name: ``name`` given to the ``ExperimentSuite``; defaults to
            ``"test_all"`` as before.

    Returns:
        None. Results are written to disk, plotted and printed.

    Raises:
        ValueError: If ``list_of_cells`` is empty.
    """
    widths = list(list_of_cells)
    if not widths:
        raise ValueError("list_of_cells must contain at least one embedding width")

    out_path = os.path.join("saved_experiments", file_name)
    # Create the output directory up front so a missing folder cannot make the
    # save step fail after the (long) confidence runs have finished.
    # NOTE(review): assumes ExperimentUtils resolves `path` relative to the CWD.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    base_config = PPOConfig(
        n_agents=5, scenario_name='balance', decentralized_execution=True,
        max_steps=200, pooling_method='mean', strategy=strategy,
        embedding_depth=depth, mlp_core_num_cells=256,
    )
    param_grid = {"embedding_num_cells": widths}
    run_device = torch.device("cpu") if device is None else device

    suite = ExperimentSuite(base_config=base_config, param_grid=param_grid,
                            name=suite_name, device=run_device)
    suite.run_all_confidence(k=k)

    suite_utils = ExperimentUtils(path=out_path, experiment_suite=suite)
    suite_utils.save_df_to_file()
    suite_utils.plot_experiment_suite_df()
    print(suite_utils.create_table_with_confidence())


In [None]:
# MLP embedding, depth=None: embedding_num_cells is irrelevant, so one width suffices.
tune_width_wrt_depth(strategy=EmbeddingStrategy.MLP, file_name="mlp_depth_none.csv", depth=None, list_of_cells=[16])

In [None]:
# MLP embedding, depth 1: sweep every width in the default `cells` list.
tune_width_wrt_depth(strategy=EmbeddingStrategy.MLP, file_name="mlp_depth_1.csv", depth=1)

In [None]:
# MLP embedding, depth 2: sweep every width in the default `cells` list.
tune_width_wrt_depth(strategy=EmbeddingStrategy.MLP, file_name="mlp_depth_2.csv", depth=2)

In [None]:
# MLP_LOCAL embedding, depth=None: embedding_num_cells is irrelevant, so one width suffices.
tune_width_wrt_depth(strategy=EmbeddingStrategy.MLP_LOCAL, file_name="mlp_local_depth_none.csv", depth=None, list_of_cells=[16])

In [None]:
# MLP_LOCAL embedding, depth 1: sweep every width in the default `cells` list.
tune_width_wrt_depth(strategy=EmbeddingStrategy.MLP_LOCAL, file_name="mlp_local_depth_1.csv", depth=1)

In [None]:
# MLP_LOCAL embedding, depth 2: sweep every width in the default `cells` list.
tune_width_wrt_depth(strategy=EmbeddingStrategy.MLP_LOCAL, file_name="mlp_local_depth_2.csv", depth=2)

In [None]:
# MLP_GLOBAL embedding, depth=None: embedding_num_cells is irrelevant, so one width suffices.
tune_width_wrt_depth(strategy=EmbeddingStrategy.MLP_GLOBAL, file_name="mlp_global_depth_none.csv", depth=None, list_of_cells=[16])

In [None]:
# MLP_GLOBAL embedding, depth 1: sweep every width in the default `cells` list.
tune_width_wrt_depth(strategy=EmbeddingStrategy.MLP_GLOBAL, file_name="mlp_global_depth_1.csv", depth=1)

In [None]:
# MLP_GLOBAL embedding, depth 2: sweep every width in the default `cells` list.
tune_width_wrt_depth(strategy=EmbeddingStrategy.MLP_GLOBAL, file_name="mlp_global_depth_2.csv", depth=2)