In [1]:
from flax import nnx
import jax
import jax.numpy as jnp
import copy

In [2]:
class MyModel(nnx.Module):
  def __init__(self, inSize: int, outSize: int, *, rngs: nnx.Rngs):
    intermediateSize = 64
    key = rngs.params()
    self.linear1 = nnx.Linear(inSize, intermediateSize, rngs=rngs)
    self.linear2 = nnx.Linear(intermediateSize, outSize, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    x = jax.nn.relu(x)
    x = self.linear2(x)
    return x

In [3]:
rngs = nnx.Rngs(jax.random.key(0))
model = MyModel(10, 5, rngs=rngs)

In [None]:
input = jnp.ones((10,))
print(input)
res = model(input)
nnx.display(model)
print(f'Original model result: {res}')

model2 = copy.deepcopy(model)
graph, params = nnx.split(model2)
params['linear2']['bias'] = params['linear2']['bias'].value.at[0].set(123.0)
nnx.display(params)
model2 = nnx.merge(graph, params)

print(f'Model copy res {model2(input)}')

nnx.display(model2)

res = model(input)
print(f'Original model result: {res}')


[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


Original model result: [ 0.3173416   0.26317954 -0.58215386  0.45295987 -0.45487443]


Model copy res [123.317345     0.26317954  -0.58215386   0.45295987  -0.45487443]


Original model result: [ 0.3173416   0.26317954 -0.58215386  0.45295987 -0.45487443]
