In [1]:
from qdax_es.utils.restart import FixedGens
from flax.struct import PyTreeNode

class EmitterState(PyTreeNode):
    restart_state: None

In [4]:
restarter = FixedGens(5)
restart_state = restarter.init()
state = EmitterState(restart_state=restart_state)

s = []
for i in range(12):
    state = restarter.update(state, None)
    s.append(state)
    bool = restarter.restart_criteria(state, None)
    print(i, bool)

0 False
1 False
2 False
3 False
4 True
5 False
6 False
7 False
8 False
9 False
10 True
11 False


In [11]:
import jax
import jax.numpy as jnp

def use_state(states, index):
    return states[index]

indices = jnp.arange(len(s))
jax.vmap(
    use_state,
    in_axes=(None, 0)
)(s, indices)

TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[].
This BatchTracer with object id 140304706562384 was created on line:
  /tmp/ipykernel_966253/2590179247.py:8 (<module>)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [12]:
a = jnp.arange(10)
jax.lax.dynamic_slice(a, (0,), (5,))

Array([0, 1, 2, 3, 4], dtype=int32)

In [13]:
jnp.split(a, 5)

[Array([0, 1], dtype=int32),
 Array([2, 3], dtype=int32),
 Array([4, 5], dtype=int32),
 Array([6, 7], dtype=int32),
 Array([8, 9], dtype=int32)]

In [20]:
# Make a tree with 5 copies of state
jax.tree_util.tree_map(
    lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), 5),
    state
    )

EmitterState(restart_state=RestartState(generations=Array([0, 0, 0, 0, 0], dtype=int32, weak_type=True)))

In [19]:
from qdax_es.utils.restart import FixedGens
from flax.struct import PyTreeNode
import jax
import jax.numpy as jnp



class TestState(PyTreeNode):
    index: None
    value: int

# Create a tree with 5 states, each with a generatiosn number from a

def make_state(i):
    return TestState(index=i, value = jnp.eye(3)*i)

def make_batch(n):
    return jax.vmap(make_state)(jnp.arange(5) + n)

states = make_batch(0)
states

TestState(index=Array([0, 1, 2, 3, 4], dtype=int32), value=Array([[[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]],

       [[2., 0., 0.],
        [0., 2., 0.],
        [0., 0., 2.]],

       [[3., 0., 0.],
        [0., 3., 0.],
        [0., 0., 3.]],

       [[4., 0., 0.],
        [0., 4., 0.],
        [0., 0., 4.]]], dtype=float32))

In [20]:
def update(s, i):
    jax.debug.print('s {} | i {}', s, i)
    return i
    # return state.index + i

data = jnp.split(jnp.arange(10)*10, 5)
data = jnp.array(data)
# data = jnp.arange(5)*10
print(data.shape)

jax.vmap(
    lambda i, state: update(state, i),
    in_axes=(0, 0),
    out_axes=0
    )(data, states)

(5, 2)
s TestState(index=Array(0, dtype=int32), value=Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)) | i [ 0 10]
s TestState(index=Array(1, dtype=int32), value=Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)) | i [20 30]
s TestState(index=Array(2, dtype=int32), value=Array([[2., 0., 0.],
       [0., 2., 0.],
       [0., 0., 2.]], dtype=float32)) | i [40 50]
s TestState(index=Array(3, dtype=int32), value=Array([[3., 0., 0.],
       [0., 3., 0.],
       [0., 0., 3.]], dtype=float32)) | i [60 70]
s TestState(index=Array(4, dtype=int32), value=Array([[4., 0., 0.],
       [0., 4., 0.],
       [0., 0., 4.]], dtype=float32)) | i [80 90]


Array([[ 0, 10],
       [20, 30],
       [40, 50],
       [60, 70],
       [80, 90]], dtype=int32)

In [12]:
def net_shape(net):
    return jax.tree_map(lambda x: x.shape, net)

state_groups = jax.vmap(
    make_batch,
    in_axes=0
)(jnp.arange(4)*10)

net_shape(state_groups)

TestState(index=(4, 5), value=(4, 5, 3, 3))

In [13]:
# Flatten the first dim
states = jax.tree_map(
        lambda x: jnp.concatenate(x, axis=0),
        state_groups
    )
net_shape(states)

TestState(index=(20,), value=(20, 3, 3))

In [102]:
jax.vmap(
    lambda state, i: update(state, i),
    in_axes=(None, 0),
    out_axes=0
    )(states, data)

s [0 1 2 3 4] | i [Array(0, dtype=int32), Array(20, dtype=int32), Array(40, dtype=int32), Array(60, dtype=int32), Array(80, dtype=int32)]
s [0 1 2 3 4] | i [Array(10, dtype=int32), Array(30, dtype=int32), Array(50, dtype=int32), Array(70, dtype=int32), Array(90, dtype=int32)]


[Array([ 0, 10], dtype=int32),
 Array([20, 30], dtype=int32),
 Array([40, 50], dtype=int32),
 Array([60, 70], dtype=int32),
 Array([80, 90], dtype=int32)]

In [16]:
def net_shape(net):
    return jax.tree_map(lambda x: x.shape, net)

net_shape(jnp.array(data))

(5, 2)

In [1]:
import os 
os.environ["JAX_ENABLE_X64"] = "True"

In [2]:
import jax
from jax import numpy as jnp
jax.devices()

[cuda(id=0)]

In [3]:
a = jnp.array([1, 2, 3])
a

Array([1, 2, 3], dtype=int64)

In [4]:
a.astype(jnp.float64)

Array([1., 2., 3.], dtype=float64)