In [1]:
import jax
import jax.numpy as jnp
import chex
from functools import partial
from utils import tree_slice
from ipd_squared import IPDSquaredGenerator, IPDSquared

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def unpack(state) -> None:
    """Print each field name of ``state`` followed by its value.

    ``state`` must expose ``keys()`` and ``to_tuple()``; the two are assumed to
    list fields in the same order (presumably a chex dataclass — confirm).
    Every name/value pair is followed by a blank spacer (``print("\\n")``).
    """
    field_names = state.keys()
    field_values = state.to_tuple()
    for name, value in zip(field_names, field_values):
        print(name)
        print(value)
        print("\n")

In [3]:
# Single fixed seed so that Restart & Run All reproduces every output below.
SEED = 0
key = jax.random.key(SEED)

In [16]:
# Build the IPD^2 environment and reset it with the seeded key.
# epsilon_min / epsilon_max are later read back (env.epsilon_min/max) as the
# bounds of the uniform power perturbation drawn below. scaling_factor is passed
# straight to IPDSquared; the outputs further down are consistent with a power
# shift of about payoff / scaling_factor (4 / 100 = 0.04) — TODO confirm in ipd_squared.
generator = IPDSquaredGenerator()
env_config = dict(
    epsilon_min=-0.1,
    epsilon_max=0.1,
    scaling_factor=100,
)
env = IPDSquared(generator, **env_config)
state, timestep = env.reset(key)
# To inspect the initial state, call unpack(state) or unpack(timestep) here.

In [17]:
# Hand-picked inner actions laid out as a 2 x 2 grid (presumably rows = games,
# columns = the two inner agents of each game — TODO confirm).
inner_actions = jnp.array([0, 1, 0, 1]).reshape(2, 2)

# Draw one epsilon per row from [epsilon_min, epsilon_max] and apply it with
# opposite signs to the two columns, so that each row reads [-e, +e].
next_key, epsilon_key = jax.random.split(state.key, num=2)
epsilons = jnp.array([-1, 1]) * jax.random.uniform(
    epsilon_key, (2, 1), minval=env.epsilon_min, maxval=env.epsilon_max
)

epsilons

Array([[-0.01531014,  0.01531014],
       [ 0.04016557, -0.04016557]], dtype=float32)

In [None]:
# Perturb the current power split with the antisymmetric epsilons, then
# renormalise each row with a softmax over the last axis.
# NOTE(review): softmax of an already-normalised split pulls it toward uniform,
# e.g. softmax([0.45, 0.55]) ~= [0.475, 0.525]; confirm this mirrors the
# environment's own step logic.
print(state.power)
power = jax.nn.softmax(
    state.power + epsilons,
    axis=-1,
)
print(power)

[[0.5 0.5]
 [0.5 0.5]]
[[0.4923455 0.5076545]
 [0.520072  0.479928 ]]


In [20]:
# In every row, take the inner action from the column holding the larger power
# share (argmax breaks ties toward the first column). keepdims keeps the result
# 2-D, giving a (2, 1) column of outer actions.
outer_actions = jnp.take_along_axis(
    inner_actions,
    jnp.argmax(power, axis=-1, keepdims=True),
    axis=-1,
)
print(outer_actions)

[[1]
 [0]]


In [21]:
# Outer payoffs: row 0 is PAYOFF_MATRIX[a0, a1] and row 1 is
# PAYOFF_MATRIX[a1, a0], where a0 and a1 are the two outer actions (each of
# shape (1,)). The stacked result has shape (2, 1), so it broadcasts row-wise
# against `power` when rewards are computed.
outer_payoffs = jnp.stack(
    [
        env.PAYOFF_MATRIX[outer_actions[0], outer_actions[1]],
        env.PAYOFF_MATRIX[outer_actions[1], outer_actions[0]],
    ]
)
print(outer_payoffs)

[[ 4]
 [-4]]


In [22]:
# Apply env._update_power row by row (vmapped over the leading axis) using each
# row's power split, inner actions and outer payoff.
# The pre-update split is recomputed from its inputs rather than read from the
# current `power` binding. Previously `power = f(power)` applied the update a
# second time whenever this cell was re-run.
# NOTE: `power` is still rebound to the updated split. If the outer-action cell
# is re-run after this one, it will select using the updated values.
power = jax.nn.softmax(state.power + epsilons, axis=-1)
print(power)
power = jax.vmap(env._update_power)(power, inner_actions, outer_payoffs).squeeze()
print(power)

[[0.4923455 0.5076545]
 [0.520072  0.479928 ]]
[[0.45234552 0.5476545 ]
 [0.480072   0.519928  ]]


In [24]:
# Rewards: each row of `power` is scaled by that row's outer payoff
# ((2, 2) * (2, 1) broadcast), then flattened row-major into one entry per
# inner agent.
rewards = jnp.ravel(power * outer_payoffs)

# One copy of the flattened joint inner action per agent (env.num_agents rows).
history = jnp.tile(jnp.ravel(inner_actions), (env.num_agents, 1))

# Advance the step counter and flag termination once the time limit is reached.
steps = 1 + state.step_count
done = env.time_limit <= steps

(rewards, history)

(Array([ 1.8093821,  2.190618 , -1.920288 , -2.079712 ], dtype=float32),
 Array([[0, 1, 0, 1],
        [0, 1, 0, 1],
        [0, 1, 0, 1],
        [0, 1, 0, 1]], dtype=int32))