In [18]:
import jax
import jax.numpy as jnp
from flax import struct
import flax.linen as nn
from flax.training import train_state
from clu import metrics
from optax import softmax_cross_entropy, adamw
from functools import partial

In [2]:
@struct.dataclass
class MyMetrics(metrics.Collection):
    """Metric collection with two entries: an accuracy metric and a "loss" average."""

    # Accuracy metric provided by clu.
    accuracy: metrics.Accuracy
    # Average built via `from_output("loss")`; the annotation is a call result
    # rather than a plain type, hence the type-checker suppression.
    loss: metrics.Average.from_output("loss")  # type: ignore


class CustomState(train_state.TrainState):
    """`TrainState` extended with one extra field holding a `MyMetrics` collection."""

    # Supplied at creation time (the notebook passes `MyMetrics.empty()`).
    metrics: MyMetrics

In [3]:
class Model(nn.Module):
    """Minimal network: one `Dense(2)` layer followed by a ReLU."""

    @nn.compact
    def __call__(self, x):
        # The inline Dense layer is the only submodule; the notebook addresses
        # its parameters under the key "Dense_0".
        projected = nn.Dense(2)(x)
        return nn.relu(projected)

class OverModel(nn.Module):
    """Wrapper that feeds the same input to two separately constructed `Model`s."""

    def setup(self):
        # Two distinct submodule instances, registered as `model1` and `model2`.
        self.model1 = Model()
        self.model2 = Model()

    def __call__(self, x):
        # Evaluate model1 first, then model2, and return both outputs as a pair.
        return self.model1(x), self.model2(x)

In [4]:
# Fixed PRNG key used for parameter initialisation.
key = jax.random.PRNGKey(0)
model = Model()
# All-ones input of shape (1, 2) and a (1, 2) target that is 1 at index [0, 1].
test_x = jnp.ones((1, 2))
test_labels = jnp.zeros((1, 2)).at[0, 1].set(1)
params = model.init(key, test_x)
# Overwrite the initialised Dense kernel with ones (in place) so the forward
# pass can be checked by hand.
params["params"]["Dense_0"]["kernel"] = jnp.array([[1.0, 1.0], [1.0, 1.0]])
y = model.apply(params, test_x)
loss = softmax_cross_entropy(y, test_labels).mean()
# Last expression of the cell: displayed as its output.
loss

Array(0.6931472, dtype=float32)

In [5]:
# Plain TrainState built from the model's apply function, the params, and an
# AdamW optimizer with learning rate 1e-3.
base_state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=adamw(1e-3)
)

# Same configuration, plus an empty metrics collection for the extra field.
state = CustomState.create(
    apply_fn=model.apply,
    params=params,
    tx=adamw(1e-3),
    metrics=MyMetrics.empty(),
)

# Forward pass routed through the state's apply_fn and params; the cell shows
# the loss together with the model output.
y = state.apply_fn(state.params, test_x)
loss = softmax_cross_entropy(y, test_labels).mean()
loss, y

(Array(0.6931472, dtype=float32), Array([[2., 2.]], dtype=float32))

In [6]:
# Display the parameter pytree (its Dense_0 kernel was overwritten with ones above).
params

{'params': {'Dense_0': {'kernel': Array([[1., 1.],
          [1., 1.]], dtype=float32),
   'bias': Array([0., 0.], dtype=float32)}}}

In [7]:
def loss_fn1(params, state, batch):
    """Mean softmax cross-entropy with the parameters passed explicitly.

    Args:
        params: Parameter pytree forwarded to ``state.apply_fn``. Being the
            first positional argument, it is what ``jax.grad`` differentiates
            by default.
        state: Object exposing ``apply_fn(params, x)``; ``state.params`` is
            not read here.
        batch: ``(x, labels)`` pair.

    Returns:
        ``(loss, logits)``: the mean loss and the logits computed in this call.
    """
    x, labels = batch
    logits = state.apply_fn(params, x)
    loss = softmax_cross_entropy(logits, labels).mean()
    # Return this call's logits as the auxiliary value, not the notebook-level
    # global `y`, which is stale for other inputs and undefined on a fresh kernel.
    return loss, logits

def loss_fn2(state, batch):
    """Mean softmax cross-entropy using the parameters stored on the state.

    Args:
        state: Object exposing ``params`` and ``apply_fn(params, x)``.
        batch: ``(x, labels)`` pair.

    Returns:
        ``(loss, logits)``: the mean loss and the logits computed in this call.
    """
    x, labels = batch
    logits = state.apply_fn(state.params, x)
    loss = softmax_cross_entropy(logits, labels).mean()
    # Return this call's logits as the auxiliary value, not the notebook-level
    # global `y`, which is stale for other inputs and undefined on a fresh kernel.
    return loss, logits

In [8]:
# Evaluate both loss formulations on the same state and batch so their
# (loss, aux) results can be compared side by side in the cell output.
l1 = loss_fn1(state.params, state, (test_x, test_labels))
l2 = loss_fn2(state, (test_x, test_labels))
l1, l2

((Array(0.6931472, dtype=float32), Array([[2., 2.]], dtype=float32)),
 (Array(0.6931472, dtype=float32), Array([[2., 2.]], dtype=float32)))

In [24]:
# Gradient of loss_fn1 with respect to its explicit `params` argument; the
# state and batch are bound up front so `params` is the only positional input.
grad_fn1 = jax.grad(
    partial(loss_fn1, state=state, batch=(test_x, test_labels)), has_aux=True
)
# Gradient of loss_fn2 with respect to the params carried on the state.
# loss_fn2 has no separate params argument, so the candidate params are swapped
# into a copy of the state before the call. (Previously g2 was produced by
# grad_fn1 as well, which made the comparison below trivially true.)
grad_fn2 = jax.grad(
    lambda p: loss_fn2(state.replace(params=p), (test_x, test_labels)),
    has_aux=True,
)
g1, _ = grad_fn1(state.params)
g2, _ = grad_fn2(state.params)
# g1, g2

In [27]:
# Compare the two gradient pytrees leaf by leaf (exact equality), then reduce
# the per-leaf results to a single truth value.
jax.tree.all(jax.tree.map(lambda left, right: jnp.all(left == right), g1, g2))

True