In [1]:
from aij_multiagent_rl.env import AijMultiagentEnv
import yaml
import flax.linen as nn
import numpy as np
import jax
from vg_lib.modules import nets
from vg_lib.utils.training import *

In [2]:
# Create the environment and capture its starting point.
# reset() gives the per-agent observations and an area state keyed by agent;
# state() gives the global world state that the critic cell below reads.
env = AijMultiagentEnv()
reset_result = env.reset()
initial_agents_state, initial_area_state = reset_result
initial_world_state = env.state()

In [3]:
# Load the run configuration. Parse errors are deliberately NOT caught: every
# later cell reads `config`, so printing the error and continuing would only
# postpone the failure to a confusing NameError further down.
with open("config.yaml", encoding="utf-8") as stream:
    config = yaml.safe_load(stream)
if not isinstance(config, dict):
    # safe_load returns None for an empty file (or a list/scalar for other
    # documents); the cells below index config by key, so fail early here.
    raise TypeError(
        f"config.yaml must contain a top-level mapping, got {type(config).__name__}"
    )

In [4]:
# Derived run sizes. NUM_AGENTS presumably comes from the
# `vg_lib.utils.training` star import and should equal env.num_agents — TODO confirm.
config["NUM_ACTORS"] = NUM_AGENTS * config["NUM_ENVS"]
config["NUM_UPDATES"] = (
    config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
)
config["MINIBATCH_SIZE"] = (
    config["NUM_ACTORS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
)
# Keep the clip epsilon exactly as read from config.yaml so that re-running this
# cell does not divide it by the number of agents a second (third, ...) time.
base_clip_eps = config.setdefault("BASE_CLIP_EPS", config["CLIP_EPS"])
config["CLIP_EPS"] = (
    base_clip_eps / env.num_agents if config["SCALE_CLIP_EPS"] else base_clip_eps
)


In [5]:
# Root PRNG key for the notebook; later cells split their keys off it.
seed = config["SEED"]
rng = jax.random.PRNGKey(seed)

In [26]:
from omegaconf import OmegaConf

# OmegaConf.to_container expects a DictConfig/ListConfig, not a file path;
# load the YAML into an OmegaConf object first, then convert to plain containers.
OmegaConf.to_container(OmegaConf.load("config.yaml"))

In [6]:
# Three-way split: the carried-forward key plus one key each for actor/critic init.
split_keys = jax.random.split(rng, 3)
rng, _rng_actor, _rng_critic = split_keys[0], split_keys[1], split_keys[2]

In [7]:
# All-zero dummy actor observation, used only to trace parameter shapes.
# Leading dims are (1, NUM_ENVS) — presumably (time, batch) for the scanned RNN;
# the trailing dims come from the first agent's image / proprio arrays.
first_agent_obs = initial_agents_state[AGENT_KEYS[0]]
leading_dims = (1, config["NUM_ENVS"])
actor_init_image = jnp.zeros(leading_dims + tuple(first_agent_obs['image'].shape))
actor_init_proprio = jnp.zeros(leading_dims + tuple(first_agent_obs['proprio'].shape))
actor_init_obs = (actor_init_image, actor_init_proprio)

In [8]:
# Actor init input is (observations, dones); the done mask is all zeros, shape (1, NUM_ENVS).
actor_init_dones = jnp.zeros((1, config["NUM_ENVS"]))
actor_init_x = (actor_init_obs, actor_init_dones)

In [9]:
# Recurrent actor: one output per discrete action of the first agent's action space.
n_actions = env.action_space(AGENT_KEYS[0]).n
actor = nets.ActorRNN(n_actions, config=config)
# Initial GRU carry sized for NUM_ENVS rows, matching the dummy input above.
ac_init_hstate = nets.ScannedRNN.initialize_carry(config["NUM_ENVS"], config["GRU_HIDDEN_DIM"])
actor_network_params = actor.init(_rng_actor, ac_init_hstate, actor_init_x)

In [10]:
# Dummy critic input, used only to trace parameter shapes.
# Reuse the world state captured right after reset rather than calling
# env.state() twice more.
critic_image = jnp.zeros((1, config["NUM_ENVS"], *initial_world_state['image'].shape))
critic_additional_obs = jnp.concatenate(
    [initial_world_state[key] for key in CENTR_OBS_KEYS], dtype=jnp.float32
)
area_obs = jnp.array(
    [list(initial_area_state[agent].values()) for agent in AGENT_KEYS],
    dtype=jnp.float32,
).reshape(-1)
# Feature vector of ONE env; shape (F,).
single_env_feats = jnp.concatenate([critic_additional_obs, area_obs])
# Broadcast (not reshape) to (1, NUM_ENVS, F): a reshape would split the F
# features across envs when NUM_ENVS > 1 and give the critic the wrong input dim.
critic_feats = jnp.broadcast_to(
    single_env_feats, (1, config["NUM_ENVS"], single_env_feats.shape[0])
)

In [11]:
# Critic init input mirrors the actor's: ((image, features), dones), zero dones of shape (1, NUM_ENVS).
cr_init_dones = jnp.zeros((1, config["NUM_ENVS"]))
cr_init_x = ((critic_image, critic_feats), cr_init_dones)

In [12]:
# Recurrent centralised critic, initialised on the dummy critic input.
hidden_dim = config["GRU_HIDDEN_DIM"]
critic = nets.CriticRNN(config=config)
cr_init_hstate = nets.ScannedRNN.initialize_carry(config["NUM_ENVS"], hidden_dim)
critic_network_params = critic.init(_rng_critic, cr_init_hstate, cr_init_x)

In [13]:
# One train state per network, built with create_train_state (presumably
# provided by the vg_lib.utils.training star import) from the shared config.
_modules_and_params = (
    (actor, actor_network_params),
    (critic, critic_network_params),
)
actor_train_state, critic_train_state = [
    create_train_state(module, params, config) for module, params in _modules_and_params
]

In [14]:
# Fresh reset for the rollout, seeded from the JAX key stream.
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
# NOTE(review): only the first of the NUM_ENVS keys is used, and a single env
# instance is reset; its first uint32 word becomes the integer seed.
env_seed = reset_rng[0, 0].item()
# env.reset returns (agents_state, area_state), as in the first env cell.
obsv, env_state = env.reset(seed=env_seed)
# Rollout hidden states have one row per actor (NUM_ACTORS = agents * envs).
hidden_dim = config["GRU_HIDDEN_DIM"]
ac_init_hstate = nets.ScannedRNN.initialize_carry(config["NUM_ACTORS"], hidden_dim)
cr_init_hstate = nets.ScannedRNN.initialize_carry(config["NUM_ACTORS"], hidden_dim)

In [19]:
# Stack the initial per-agent observations, then prepend a length-1 leading axis.
stacked_images = batchify_image(initial_agents_state, config)
stacked_proprio = batchify(initial_agents_state, 'proprio', config)
imgs = stacked_images[None, :]
proprios = stacked_proprio[None, :]

In [None]:
def _env_step(runner_state, unused=None):
    """Advance the environment(s) by one step and record a ``Transition``.

    Samples actions from the actor, evaluates the critic on the world state,
    steps the environment, and packs the results into a ``Transition``.

    Args:
        runner_state: tuple ``(train_states, env_state, last_obs, last_done,
            hstates, rng)``. ``train_states[0]`` / ``train_states[1]`` are the
            actor / critic train states; ``hstates`` is ``(actor_hstate,
            critic_hstate)``.
        unused: ignored. Accepted so the function matches the ``(carry, x)``
            signature of a ``jax.lax.scan`` body; existing single-argument
            calls still work.

    Returns:
        ``(runner_state, transition)``: the updated runner state (new env
        state, observations, done flags, hidden states and RNG key) and the
        ``Transition`` for this step.

    NOTE(review): this cell has never been executed (``In [None]``); see the
    TODOs below before relying on it.
    """
    train_states, env_state, last_obs, last_done, hstates, rng = runner_state

    # SELECT ACTION
    rng, _rng = jax.random.split(rng)
    # TODO(review): earlier cells call batchify(obs, 'proprio', config), and the
    # actor was initialised with ((image, proprio), dones) rather than a single
    # observation array; confirm which batchify signature vg_lib provides.
    obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
    ac_in = (
        obs_batch[np.newaxis, :],  # prepend the length-1 leading axis
        last_done[np.newaxis, :],
    )
    ac_hstate, pi = actor.apply(train_states[0].params, hstates[0], ac_in)
    action = pi.sample(seed=_rng)
    log_prob = pi.log_prob(action)
    env_act = unbatchify(
        action, env.agents, config["NUM_ENVS"], env.num_agents
    )

    # VALUE
    # output of wrapper is (num_envs, num_agents, world_state_size)
    # swap axes to (num_agents, num_envs, world_state_size) before reshaping to (num_actors, world_state_size)
    world_state = last_obs["world_state"].swapaxes(0, 1)
    world_state = world_state.reshape((config["NUM_ACTORS"], -1))
    cr_in = (
        world_state[None, :],
        last_done[np.newaxis, :],
    )
    # `critic` is the CriticRNN module built in the critic-init cell.
    cr_hstate, value = critic.apply(train_states[1].params, hstates[1], cr_in)

    # STEP ENV
    rng, _rng = jax.random.split(rng)
    rng_step = jax.random.split(_rng, config["NUM_ENVS"])
    # TODO(review): vmapping env.step assumes a functional JAX env with
    # step(key, state, action); above, AijMultiagentEnv is driven imperatively
    # (env.reset(seed=...)), so confirm this matches its API.
    obsv, env_state, reward, done, info = jax.vmap(
        env.step, in_axes=(0, 0, 0)
    )(rng_step, env_state, env_act)
    # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable spelling.
    info = jax.tree_util.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
    done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
    transition = Transition(
        jnp.tile(done["__all__"], env.num_agents),
        last_done,
        action.squeeze(),
        value.squeeze(),
        batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
        log_prob.squeeze(),
        obs_batch,
        world_state,
        info,
    )
    runner_state = (train_states, env_state, obsv, done_batch, (ac_hstate, cr_hstate), rng)
    return runner_state, transition