In [18]:
import jax
import jax.numpy as jnp

import flax
import flax.linen as nn
import optax
from clu import metrics


from molnet import train_state

In [21]:
@flax.struct.dataclass
class M(metrics.Collection):
    """Minimal clu metrics collection holding a single field ``l``.

    clu ``Collection`` subclasses must be decorated with
    ``flax.struct.dataclass``. Without the decorator, ``l`` is only a bare
    class annotation rather than a dataclass field. In that case:
      * the inherited (field-less) ``__init__`` rejects ``M([1.0])`` with a
        ``TypeError``;
      * the class is not registered as a JAX pytree, so it cannot travel
        inside a train state through ``flax.jax_utils.replicate`` or ``pmap``.
    """

    # NOTE(review): `l` is an ambiguous single-letter name (PEP 8 E741); it is
    # kept because it is this collection's public field name.
    l: metrics.Array

In [22]:
class Foo(nn.Module):
    """A bias-free 1x1 convolution followed by batch normalisation.

    Attributes:
        train: controls BatchNorm's ``use_running_average`` (set to
            ``not train``). When False, BatchNorm normalises with its stored
            running averages. When True, it normalises with the statistics
            of the current batch.
        filters: number of output features of the 1x1 convolution.
    """

    train: bool
    filters: int

    @nn.compact
    def __call__(self, x):
        # Build and apply the convolution first, then the normalisation layer,
        # mirroring the order in which the submodules are created and called.
        conv = nn.Conv(
            features=self.filters,
            kernel_size=(1, 1),
            use_bias=False,
            dtype=jnp.float32,
        )
        projected = conv(x)

        norm = nn.BatchNorm(
            use_running_average=not self.train,
            momentum=0.9,
            epsilon=1e-5,
            dtype=jnp.float32,
        )
        return norm(projected)


In [23]:
# Fixed PRNG key and a dummy rank-4 input of ones. Flax convolutions are
# channels-last, so with a (1, 1) kernel this is (batch, height, width, channels).
key = jax.random.PRNGKey(0)
x = jnp.ones((5, 4, 4, 3))

# Construct the module in training mode and run `init` to obtain its initial
# variable collections: trainable weights under 'params' and BatchNorm's
# running statistics under 'batch_stats'.
foo = Foo(train=True, filters=7)
foo_vars = foo.init(key, x)
params, batch_stats = foo_vars['params'], foo_vars['batch_stats']

In [24]:
# Bundle the model's apply function, its initial variables and an Adam
# optimiser (learning rate 1e-3) into the project's TrainState.
state = train_state.TrainState.create(
    apply_fn=foo.apply,
    tx=optax.adam(1e-3),
    params=params,
    batch_stats=batch_stats,
    # "Best" bookkeeping starts from the initial parameters at step 0 with no
    # recorded metrics (presumably updated during training — confirm in molnet).
    best_params=params,
    step_for_best_params=0,
    metrics_for_best_params={},
    train_metrics={},
)


In [25]:
# Copy the whole train state onto every local device (for use with `pmap`);
# each leaf gains a leading device axis.
# NOTE: this rebinds `state`, so re-running this cell replicates an already
# replicated state and adds a second device axis. Re-run the TrainState cell
# first instead.
state = flax.jax_utils.replicate(tree=state)

In [32]:
# Record the current parameters as the "best" ones, together with a
# placeholder metrics value.
# `state` was replicated across devices earlier, so every leaf carries a
# leading device axis. The new metrics must be replicated as well. Otherwise
# its bare Python float leaf breaks `pmap` and `flax.jax_utils.unreplicate`,
# which indexes `[0]` on every leaf.
state = state.replace(
    best_params=state.params,  # already replicated; no extra work needed
    metrics_for_best_params=flax.jax_utils.replicate(M([1.0])),
)