In [24]:
import pgx
import jax
import jax.numpy as jnp
from flax import nnx

# Connect Four environment from pgx; later cells initialise it and inspect the
# shape of its observation tensor before feeding it to the network.
env = pgx.make("connect_four")

In [25]:
# Initial game state, created from a fixed PRNG key (seed 0) so re-running the
# notebook yields the same state.
state = env.init(jax.random.key(0))

In [26]:
# Shape of the observation once a leading batch axis of size 1 is added —
# the batched layout the network is called with further down.
jnp.expand_dims(state.observation, axis=0).shape

(1, 6, 7, 2)

In [27]:
from modeling.connect_four import ConnectFourNetwork

# Instantiate the project's Connect Four network with a fixed seed (Rngs(0)).
# NOTE(review): architecture is defined in modeling.connect_four (not shown here).
network = ConnectFourNetwork(rngs=nnx.Rngs(0))

In [28]:
# Render a structural view of `network` in the notebook output. This is a
# display side effect; the cell produces no Out[] value.
nnx.visualization.display(network)

In [29]:
# Determinism check: run the SAME random batch through the network twice and
# confirm the policy outputs (`.pi`) agree. Building the batch once makes it
# explicit that the inputs are identical, so a mismatch could only come from
# the network itself (e.g. an active dropout layer).
# Batch of 256 random "observations" shaped like the env's observation.
obs_batch = jax.random.normal(jax.random.key(0), (256, *state.observation.shape))
jnp.allclose(network(obs_batch).pi, network(obs_batch).pi)

Array(True, dtype=bool)

In [30]:
class Model(nnx.Module):
    """Toy 1 -> 32 -> 1 MLP with a dropout layer between the two Linear layers.

    Used below to check how nnx module state (including whatever RNG state the
    dropout layer holds) behaves when a model is split/merged inside
    ``jax.lax.scan``.

    Args:
        rngs: ``nnx.Rngs`` used to initialise both ``Linear`` layers; it is
            also handed to ``Dropout``, which presumably draws its masks from it.
    """

    def __init__(self, *, rngs):
        # Layers are constructed in the original order (Linear, Dropout,
        # Linear) so keys are drawn from ``rngs`` in the same sequence.
        layers = (
            nnx.Linear(1, 32, rngs=rngs),
            nnx.Dropout(0.5, rngs=rngs),  # dropout rate 0.5
            nnx.Linear(32, 1, rngs=rngs),
        )
        self.model = nnx.Sequential(*layers)

    def __call__(self, x):
        """Apply the layer stack to ``x`` (callers here pass shape ``(1, 1)``)."""
        out = self.model(x)
        return out

# Fresh model (seed 0) evaluated once on a single 1-feature input drawn from
# key 0. The scan cell below re-creates this exact setup, so its first output
# should match this one.
model = Model(rngs=nnx.Rngs(0))
model(jax.random.normal(key=jax.random.key(0), shape=(1, 1)))

Array([[4.08513]], dtype=float32)

In [31]:
N_SCAN_STEPS = 42  # number of forward passes unrolled by lax.scan


def f(carry, _):
    """One ``jax.lax.scan`` step: rebuild the model, run it once, re-split it.

    Args:
        carry: ``((graphdef, nnx_state), x)`` where ``(graphdef, nnx_state)``
            is the output of ``nnx.split(model)`` and ``x`` is the model input,
            which is passed through unchanged on every step.
        _: per-step scan input; unused (the scan is driven by ``length``).

    Returns:
        ``(new_carry, y)``. ``new_carry`` holds the model re-split AFTER the
        forward pass, so state mutated during the call is threaded into the
        next step. ``y`` is this step's model output.
    """
    # `nnx_state` (not `state`) so it does not shadow the notebook's pgx `state`.
    (graphdef, nnx_state), x = carry
    step_model = nnx.merge(graphdef, nnx_state)
    y = step_model(x)
    # Re-split after the call. Carrying the incoming (graphdef, nnx_state)
    # unchanged would presumably replay the same dropout randomness each step.
    # The varying outputs below show the RNG state does advance when re-split.
    return (nnx.split(step_model), x), y


# Fresh model with the same seed and input as the previous cell, so step 0
# reproduces that cell's output.
model = Model(rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(0), (1, 1))
# Don't bind the result to `_` at top level: in IPython that disables the
# automatic `_`/`__`/`___` output history for the rest of the session.
# The split/merge is functional, so the global `model` is NOT advanced; the
# updated state lives only in `final_carry`.
final_carry, out = jax.lax.scan(f, (nnx.split(model), x), length=N_SCAN_STEPS)
out.reshape(-1)

Array([ 4.08513   ,  1.2717068 ,  0.16650593,  1.3415939 ,  3.414431  ,
        3.5085955 ,  1.0200549 ,  4.9219737 ,  2.4660082 , -0.0425849 ,
        1.372101  , -0.57328045,  4.3436627 ,  2.5402184 ,  3.386818  ,
        4.504349  ,  0.5937252 ,  2.8924809 ,  5.0068913 ,  3.0534568 ,
        3.3834524 ,  0.44345856,  6.1559176 ,  3.325729  ,  2.6489692 ,
        8.944305  ,  0.7456118 ,  6.0325255 ,  3.2266455 ,  3.649506  ,
        3.2939844 ,  1.1744883 ,  3.741819  ,  0.5944227 ,  3.5099874 ,
        3.582304  ,  1.2485955 ,  4.3441916 , -0.08865237,  4.02142   ,
        6.486699  ,  0.7796018 ], dtype=float32)