### Set Up the Environment

In [8]:
import flax
import hydra # it is just there to read the config files for us, hydra has some logic, please read its tutorial
import jax
import jax.numpy as jp
import jax.random as rax
from jax import config # in order to be able to disable jit so we can debug, or to debug nans (not always on for performance slowdown it causes)
from omegaconf import DictConfig, OmegaConf # essentially hydra things
# Debug-friendly JAX settings: turn jit off so the code can be stepped through,
# and turn NaN debugging on. Both settings cost performance, so set them to
# False for real training runs.
for jax_flag in ('jax_disable_jit', 'jax_debug_nans'):
    config.update(jax_flag, True)

### Load the hyperparameters (`hp`) with OmegaConf

In [9]:
# Read the YAML file directly with OmegaConf rather than through a hydra entry point.
config_path = "../conf/coin_conf/coin_config.yaml"  # Adjust path as necessary
cfg = OmegaConf.load(config_path)
# Resolve the `hp` node into plain Python containers, then wrap it in a flax
# FrozenDict so the hyperparameters are read-only for the rest of the notebook.
hp = flax.core.FrozenDict(OmegaConf.to_container(cfg.hp, resolve=True))
print(f'Hyperparameters: {hp}')

Hyperparameters: FrozenDict({
    seed: 45,
    reward_discount: 0.96,
    batch_size: 128,
    agent_0: 'loqa',
    agent_1: 'loqa',
    eval_every: 100,
    op_softmax_temp: 1.0,
    game: {
        height: 3,
        width: 3,
        gnumactions: 4,
        game_length: 50,
    },
    save_dir: './experiments',
    save_every: 100,
})


### Initialize the environment, agents, replay buffers, and optimizers

In [ ]:
from three_player_coin_game import ThreePlayerCoinGame

# This walkthrough only covers the non-self-play setup: refuse a config that
# asks for self play, then pin the flag the code below branches on.
assert hp['just_self_play'] == False, 'This notebook does not support self play, because it is complex, and this needs to be a simple example'
just_self_play = False

# Build a throwaway environment and an all-zero episode of `game_length` steps;
# further down they only serve to produce a dummy observation sequence.
# NOTE(review): `coin_game_params` and `make_zero_episode` are not imported in
# any cell visible here - confirm they are in scope before this cell runs.
dummy_env, _ = ThreePlayerCoinGame.init(rng=rax.PRNGKey(hp['seed']), **coin_game_params(hp))
dummy_episode = make_zero_episode(
    trace_length=hp['game']['game_length'],
    coin_game=dummy_env,
)


# Split the seed key three ways: `rng` is kept in the training state, `rng1`
# is the key used to initialise agent parameters.
# NOTE(review): `rng2` is never read in the visible code - confirm whether it
# was meant to seed one of the other agents.
rng, rng1, rng2 = rax.split(rax.PRNGKey(hp['seed']), 3)

# Mutable training state; agents, replay buffers and optimizers are added to
# this dict further down.
state = {'rng': rng, 'step': 0}

# One network definition is shared by every player; only the parameters are
# created per player.
# NOTE(review): `GRUCoinAgent` is not imported in any cell visible here.
agent_module = GRUCoinAgent(
    hidden_size_actor=hp['actor']['hidden_size'],
    hidden_size_qvalue=hp['qvalue']['hidden_size'],
    layers_before_gru_actor=hp['actor']['layers_before_gru'],
    layers_before_gru_qvalue=hp['qvalue']['layers_before_gru'],
)

# Dummy inputs for `agent_module.init`: a fixed key, plus the observations at
# index 0 of the second axis with everything after the first axis flattened.
dummy_rng = rax.PRNGKey(0)
dummy_obs_seq = dummy_episode['obs'][:, 0].reshape(dummy_episode['obs'].shape[0], -1)

def set_up_agent_nn(player_id):
    """Create the agent for `player_id` and store it in `state[f'agent{player_id}']`.

    Parameters are initialised by calling `agent_module.init` with the key
    `rng1` and the dummy inputs, then wrapped in a `CoinAgent` together with
    the shared `agent_module`.

    NOTE(review): every call uses the same `rng1`, so all players are
    initialised from the same key - confirm this is intended.
    """
    init_inputs = {'obs_seq': dummy_obs_seq, 'rng': dummy_rng, 't': 0}
    state[f'agent{player_id}'] = CoinAgent(
        params=agent_module.init(rng1, init_inputs),
        model=agent_module,
        player=player_id,
    )

# Player 0 always gets a network. Players 1 and 2 get their own networks
# unless the run is pure self play.
set_up_agent_nn(player_id=0)
if not just_self_play:
    set_up_agent_nn(player_id=1)
    set_up_agent_nn(player_id=2)
    agent1, agent2 = state['agent1'], state['agent2']
agent0 = state['agent0']

# --- defining replay buffers ---
def create_rb_agent_params(player_id: int):
    """Pre-fill one agent's parameter replay buffer with copies of its current parameters.

    Each leaf of `state[f'agent{player_id}'].params` is stacked `capacity`
    times along a new leading axis, and the resulting pytree is stored in
    `state[f'rb_agent{player_id}_params']`. `state['min_valid_index_rb']` is
    set to the capacity, which per the original comment marks the buffer as
    not yet valid.

    Args:
        player_id: index of the agent whose buffer is created; the agent must
            already be stored in `state`.

    Raises:
        ValueError: if the configured capacity is smaller than 1.
    """
    rb_size = hp['agent_replay_buffer']['capacity']
    if rb_size < 1:
        raise ValueError(f'agent_replay_buffer.capacity must be at least 1, got {rb_size}')
    agent_params = state[f'agent{player_id}'].params
    # `jax.tree_util.tree_map` works on every JAX version; the top-level
    # `jax.tree_map` alias is deprecated and missing from recent releases.
    # Stacking per leaf avoids materialising `rb_size` copies of the whole pytree.
    state[f'rb_agent{player_id}_params'] = jax.tree_util.tree_map(
        lambda leaf: jp.stack([leaf] * rb_size, axis=0),
        agent_params,
    )
    state['min_valid_index_rb'] = rb_size  # first, the buffer is not valid

# Parameter replay buffers are only built when the config asks for them.
# NOTE(review): `use_rb` is not defined in any cell visible here.
if use_rb(hp):
    create_rb_agent_params(player_id=0)
    if not just_self_play:
        create_rb_agent_params(player_id=1)
        create_rb_agent_params(player_id=2)

# --- defining ema ---
# Each `agentN_ema` slot starts out as the very same object as the agent
# (presumably an exponential moving average of its parameters - confirm).
state.update(agent0_ema=agent0)
if not just_self_play:
    state.update(agent1_ema=agent1, agent2_ema=agent2)

# --- defining optimizers ---
# NOTE(review): `optax` and `Optimizer` are not imported in any cell visible
# here - confirm they are in scope before this cell runs.

# Choose the optax factory used for every actor optimizer.
actor_opt_name = hp['actor']['train']['optimizer']
if actor_opt_name not in ('adam', 'sgd'):
    raise ValueError(f"Unknown optimizer: {hp['actor']['train']['optimizer']}")
actor_opt_module = optax.adam if actor_opt_name == 'adam' else optax.sgd

# 'enabled' keeps one optimizer per actor loss (agent / opponent), each with
# its own learning rate; 'disabled' uses one optimizer for the actor loss.
actor_train_separate = hp['actor']['train']['separate']
if actor_train_separate not in ('enabled', 'disabled'):
    raise ValueError(f"Unknown separate: {hp['actor']['train']['separate']}")

if actor_train_separate == 'enabled':
    actor_agent_lr = hp['actor']['train']['lr_loss_actor_agent']
    actor_opponent_lr = hp['actor']['train']['lr_loss_actor_opponent']
    actor_opt_agent = actor_opt_module(learning_rate=actor_agent_lr)
    actor_opt_opponent = actor_opt_module(learning_rate=actor_opponent_lr)

    def setup_actor_optimizer(player_id):
        """Store the two actor optimizers (agent loss, opponent loss) of `player_id` in `state`."""
        agent = state[f'agent{player_id}']
        agent_loss_opt_state = actor_opt_agent.init(agent)
        state[f'agent{player_id}_opt_actor_loss_agent'] = Optimizer(actor_opt_agent, agent_loss_opt_state)
        opponent_loss_opt_state = actor_opt_opponent.init(agent)
        state[f'agent{player_id}_opt_actor_loss_opponent'] = Optimizer(actor_opt_opponent, opponent_loss_opt_state)

else:
    lr = hp['actor']['train']['lr_loss_actor']
    actor_opt = actor_opt_module(learning_rate=lr)

    def setup_actor_optimizer(player_id):
        """Store the single actor optimizer of `player_id` in `state`."""
        agent = state[f'agent{player_id}']
        state[f'agent{player_id}_opt_actor_loss'] = Optimizer(actor_opt, actor_opt.init(agent))

# Player 0 always trains; players 1 and 2 only outside pure self play.
setup_actor_optimizer(player_id=0)
if not just_self_play:
    setup_actor_optimizer(player_id=1)
    setup_actor_optimizer(player_id=2)

# One optax optimizer definition is shared by all q-value (critic) losses.
# NOTE(review): `optax` and `Optimizer` are not imported in any cell visible here.
critic_lr = hp['qvalue']['train']['lr_loss_qvalue']
qvalue_opt_name = hp['qvalue']['train']['optimizer']
if qvalue_opt_name not in ('adam', 'sgd'):
    raise ValueError(f"Unknown optimizer: {hp['qvalue']['train']['optimizer']}")
qvalue_opt = (optax.adam if qvalue_opt_name == 'adam' else optax.sgd)(learning_rate=critic_lr)

# Only the 'disabled' q-value replay-buffer mode is handled here; anything
# else is rejected.
if hp['qvalue']['replay_buffer']['mode'] != 'disabled':
    raise ValueError(f'Unknown replay buffer mode: {hp["qvalue"]["replay_buffer"]["mode"]}')

# Each agent gets its own optimizer state, initialised from that agent.
state['agent0_opt_qvalue'] = Optimizer(qvalue_opt, qvalue_opt.init(agent0))
if not just_self_play:
    state['agent1_opt_qvalue'] = Optimizer(qvalue_opt, qvalue_opt.init(agent1))
    state['agent2_opt_qvalue'] = Optimizer(qvalue_opt, qvalue_opt.init(agent2))

# Collect the initial carries of every agent: one for the actor and one for
# the q-value network, as returned by `get_initial_carries()`.
c_0 = agent0.get_initial_carries()
c_0_actor = c_0['carry_actor']
c_0_qvalue = c_0['carry_qvalue']

if not just_self_play:
    c_1 = agent1.get_initial_carries()
    c_1_actor = c_1['carry_actor']
    c_1_qvalue = c_1['carry_qvalue']
    c_2 = agent2.get_initial_carries()
    c_2_actor = c_2['carry_actor']
    c_2_qvalue = c_2['carry_qvalue']
else:
    # Pure self play: players 1 and 2 reuse player 0's initial carries.
    c_1_actor = c_0_actor
    c_1_qvalue = c_0_qvalue
    c_2_actor = c_0_actor
    c_2_qvalue = c_0_qvalue

carries = {'c_0_actor': c_0_actor,
           'c_0_qvalue': c_0_qvalue,
           'c_1_actor': c_1_actor,
           'c_1_qvalue': c_1_qvalue,
           'c_2_actor': c_2_actor,
           'c_2_qvalue': c_2_qvalue,
           }

# This code runs at cell level, not inside a function, so a `return` here is
# a SyntaxError that stops the whole cell from compiling. `state` and
# `carries` simply stay in the notebook namespace for the following cells.