In [1]:
# Enable IPython's autoreload extension in mode 2, so edited project modules are
# re-imported automatically before each cell runs (no kernel restart needed).
%load_ext autoreload
%autoreload 2

In [None]:
from use_case.baseline import * 
from tests.eval import *

# Seed NumPy's global RNG so the randomly drawn payoff matrix is identical on
# every Restart & Run All; without this each run plays a different game.
SEED = 0
np.random.seed(SEED)

# Square payoff matrix for player i, indexed by action; entries are drawn
# uniformly from [-10, 10).
N_ACTIONS = 10
payoff_i = np.random.uniform(-10, 10, (N_ACTIONS, N_ACTIONS))
# Player j's matrix is the transpose of player i's. np.transpose returns a view,
# so .copy() gives payoff_j its own memory.
payoff_j = np.transpose(payoff_i).copy()

# Initialize environment
N_AGENTS = 1000
env = BaselineEnvironment(N_AGENTS, payoff_i, payoff_j, total_games=5)

# Actual Run

In [3]:
from models.model import *
from models.trainer import *


In [4]:
# Network settings, gathered in one mapping and unpacked into ParameterSettings.
network_config = dict(
    n_agents=N_AGENTS,
    d_action=N_ACTIONS,
    d_obs=env.obs_size,
    d_traits=1,
    d_beliefs=1,
)
parameters = ParameterSettings(**network_config)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    parameters.device = "cuda"
else:
    parameters.device = "cpu"

model = Model(parameters)

In [5]:
# Look up the pure equilibria of the two payoff matrices and print one line per
# equilibrium: its first component as-is, followed by the mean of the first two
# entries of its second component.
# NOTE(review): the first component looks like an action pair and the second
# like the two players' payoffs -- confirm against find_pure_equilibria.
equilibriua = find_pure_equilibria(payoff_i, payoff_j)

for profile, payoffs in equilibriua:
    mean_payoff = (payoffs[0] + payoffs[1]) / 2
    print(profile, mean_payoff)

(7, 7) 5.440956222987843
(8, 8) 5.627273900747461


In [6]:
# Evaluate the model before train_model is called, as a reference point for the
# figures printed during training.
# NOTE(review): the third argument (10) looks like the number of evaluation
# runs -- the recorded "Total returns" is 10x the "Average Return" -- confirm.
evaluate_policy(model, env, 10)

Average Return: 0.2541386602529051
Total returns: 2.541386602529051
Action Dist, (array([  0,  61,   0,   1,   0, 781, 129,  14,  14,   0]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))


In [None]:
# Training configuration, passed to TrainingParameters as keyword settings.
# NOTE(review): the recorded progress bar runs to 1000 steps per epoch, which
# presumably corresponds to actor_training_loops -- confirm.
training_parameters = TrainingParameters(
    **{
        "actor_training_loops": 1000,
        "outer_loops": 100,
        "learning_rate": 0.01,
        "experience_buffer_size": 3,
    }
)

# Launch training of the model in the environment with the settings above.
train_model(model, env, training_parameters)
        

Epoch 0


Actor-Critic Loop: 100%|██████████| 1000/1000 [00:47<00:00, 21.00it/s]


Average Return: -0.08284865473789528
Total returns: -0.8284865473789528
Action Dist, (array([  4,  60,   7,   1,   1, 688, 192,  31,  10,   6]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 1


Actor-Critic Loop: 100%|██████████| 1000/1000 [02:08<00:00,  7.77it/s]


Average Return: 0.15499366486948318
Total returns: 1.5499366486948318
Action Dist, (array([  0,  74,   2,   4,   3, 774, 130,   7,   6,   0]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 2


Actor-Critic Loop: 100%|██████████| 1000/1000 [02:26<00:00,  6.84it/s]


Average Return: -0.206109175271615
Total returns: -2.06109175271615
Action Dist, (array([  0, 105,   9,   2,   2, 678, 187,   8,   6,   3]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 3


Actor-Critic Loop: 100%|██████████| 1000/1000 [02:25<00:00,  6.87it/s]


Average Return: -0.2137664176346647
Total returns: -2.137664176346647
Action Dist, (array([  5, 105,   8,   2,   1, 678, 174,  16,   9,   2]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 4


Actor-Critic Loop: 100%|██████████| 1000/1000 [02:23<00:00,  6.96it/s]


Average Return: -0.18290721184993255
Total returns: -1.8290721184993255
Action Dist, (array([  2,  93,  12,   2,   3, 686, 180,  16,   4,   2]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 5


Actor-Critic Loop: 100%|██████████| 1000/1000 [02:27<00:00,  6.79it/s]


Average Return: -0.12500804732332446
Total returns: -1.2500804732332447
Action Dist, (array([  3,  89,  10,   0,   4, 700, 161,  22,  11,   0]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 6


Actor-Critic Loop: 100%|██████████| 1000/1000 [02:23<00:00,  6.95it/s]


Average Return: -0.13932608993522883
Total returns: -1.3932608993522881
Action Dist, (array([  3,  91,   9,   2,   1, 684, 182,  14,   8,   6]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 7


Actor-Critic Loop: 100%|██████████| 1000/1000 [02:27<00:00,  6.77it/s]


Average Return: -0.21914269850950427
Total returns: -2.1914269850950427
Action Dist, (array([  5, 104,  11,   2,   1, 673, 172,  21,   7,   4]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 8


Actor-Critic Loop: 100%|██████████| 1000/1000 [02:19<00:00,  7.19it/s]


Average Return: -0.5692074130982628
Total returns: -5.692074130982628
Action Dist, (array([ 53,  92,  22,   2,   3, 546, 225,  38,   8,  11]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 9


Actor-Critic Loop: 100%|██████████| 1000/1000 [02:24<00:00,  6.94it/s]


Average Return: -0.54152678936074
Total returns: -5.4152678936074
Action Dist, (array([ 28,  99,  17,   2,   6, 562, 233,  40,   7,   6]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 10


Actor-Critic Loop: 100%|██████████| 1000/1000 [02:26<00:00,  6.81it/s]


Average Return: -0.16022747798293133
Total returns: -1.6022747798293133
Action Dist, (array([  2,  86,  10,   4,   0, 695, 182,  16,   4,   1]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 11


Actor-Critic Loop: 100%|██████████| 1000/1000 [02:13<00:00,  7.48it/s]


Average Return: -0.1828313073797905
Total returns: -1.828313073797905
Action Dist, (array([  4,  96,  12,   0,   1, 719, 144,  17,   5,   2]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 12


Actor-Critic Loop: 100%|██████████| 1000/1000 [02:16<00:00,  7.33it/s]


Average Return: -0.4905372176074003
Total returns: -4.905372176074003
Action Dist, (array([ 14, 117,  28,   1,   2, 598, 197,  26,   6,  11]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 13


Actor-Critic Loop: 100%|██████████| 1000/1000 [02:14<00:00,  7.41it/s]


Average Return: -0.16585351569420576
Total returns: -1.6585351569420574
Action Dist, (array([  0,  91,   8,   2,   3, 700, 172,  20,   3,   1]), array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
Epoch 14


Actor-Critic Loop:  81%|████████  | 806/1000 [00:41<00:10, 18.81it/s]