In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from base import ModelInstance

In [3]:
from flax import linen as nn
import jax
from jax import numpy as jnp, random
from functools import partial

In [4]:
class BatchNorm(nn.Module):
    is_training: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        return nn.BatchNorm(use_running_average=not self.is_training, momentum=0.9, epsilon=1e-5, axis_name='batch')(x)

class SimpleModel(nn.Module):
    @nn.compact
    def __call__(self, x_batch: jnp.ndarray, is_training: bool=False, batch_name: str='batch'):
        return nn.Sequential([
        nn.Dense(5),
        BatchNorm(is_training=is_training),
        nn.Dense(7),
        BatchNorm(is_training=is_training),
        nn.Dense(6),
        BatchNorm(is_training=is_training),
        nn.Dense(1)
    ])(x_batch)

model_instance = ModelInstance(SimpleModel())

In [5]:
key1, key2 = random.split(random.PRNGKey(0))

x_samples = random.normal(key1, shape=(1000, 1))
y_samples = jnp.sum(x_samples ** 2, axis=1) + 0.1 * random.normal(key2, shape=(x_samples.shape[0],))

In [6]:
print(x_samples.shape, y_samples.shape)

(1000, 1) (1000,)


In [7]:
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'notebook_connected'

In [8]:
px.scatter(x=x_samples[:, 0], y=y_samples)

In [9]:
try:
    model_instance(x_samples)
except Exception as e:
    print(e)

This model is not initialized! Please call "initialize" first.


In [10]:
try:
    model_instance.variables
except Exception as e:
    print(e)

This model is not initialized! Please call "initialize" first.


In [11]:
model_instance.intitialize(x_samples)

In [12]:
temp_1 = model_instance.variables

In [13]:
model_instance(x_samples)

DeviceArray([[-9.84946012e-01],
             [ 5.15047848e-01],
             [-4.74046022e-01],
             [ 4.71352339e-02],
             [-5.83089650e-01],
             [ 1.55110002e+00],
             [ 9.89659250e-01],
             [ 6.68726489e-02],
             [ 1.20890772e+00],
             [ 9.06507075e-02],
             [-6.58767521e-01],
             [-3.92645970e-02],
             [ 6.74945116e-01],
             [-1.20533451e-01],
             [ 1.55406415e-01],
             [-8.39292482e-02],
             [-1.66883886e-01],
             [ 7.78452873e-01],
             [ 3.29093456e-01],
             [-3.00793331e-02],
             [ 2.91359067e-01],
             [ 3.15231889e-01],
             [-6.64605081e-01],
             [ 4.85885680e-01],
             [ 2.33225569e-01],
             [ 4.04779725e-02],
             [ 3.44446927e-01],
             [-7.45657742e-01],
             [ 1.71799731e+00],
             [ 2.51478583e-01],
             [-7.90008724e-01],
        

In [14]:
px.scatter(x=x_samples[:, 0], y=model_instance(x_samples)[:, 0])

In [15]:
temp_2 = model_instance.variables

In [16]:
from jax import tree_util

tmp = tree_util.tree_map(lambda x, y: jnp.allclose(x, y), temp_1, temp_2)

for val in tree_util.tree_leaves(tmp):
    assert val == jnp.array(True), "Error: variables were mutated!"

In [17]:
try:
    model_instance.eval_gradients(x_samples, y_samples)
except Exception as e:
    print(e)

The gradient function is not compiled! Please call "compile" first.


In [18]:
model_instance.compile(lambda y_pred, y_true: (y_pred - y_true) ** 2, True)

In [19]:
try:
    model_instance.step(x_samples, y_samples)
except Exception as e:
    print(e)

This model has no optimizer attached to it! Please call "attach_optimizer" first.


In [20]:
import optax

In [21]:

model_instance.attach_optimizer(optimizer=optax.sgd(0.02))

In [22]:
for i in range(100):
    model_instance.step(x_samples, y_samples)

In [23]:
model_instance.variables

{'batch_stats': FrozenDict({
     BatchNorm_0: {
         BatchNorm_0: {
             mean: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
             var: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
         },
     },
     BatchNorm_1: {
         BatchNorm_0: {
             mean: DeviceArray([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
             var: DeviceArray([1., 1., 1., 1., 1., 1., 1.], dtype=float32),
         },
     },
     BatchNorm_2: {
         BatchNorm_0: {
             mean: DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32),
             var: DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32),
         },
     },
 }),
 'params': FrozenDict({
     BatchNorm_0: {
         BatchNorm_0: {
             bias: DeviceArray([-0.03346549,  0.04640758, -0.06939359,  0.04920085,
                          -0.04048646], dtype=float32),
             scale: DeviceArray([0.99269396, 0.9641437 , 0.99851197, 1.0021048 , 1.0015484 ],            dtype=float32),
         }

In [24]:
y_pred = model_instance(x_samples)

In [25]:
y_pred.shape

(1000, 1)

In [26]:
px.scatter(x=x_samples[:, 0], y=y_pred[:, 0])

In [27]:
model_instance.update_configs({'training': True})

In [28]:
model_instance.run_configs

{'training': True}

In [29]:
model_instance.update_configs({'training': False})

In [30]:
model_instance.run_configs

{'training': False}

In [31]:
model_instance.reset_configs({'is_training': True})

In [32]:
model_instance.run_configs

{'is_training': True}

In [36]:
try:
    for i in range(100):
        model_instance.step(x_samples, y_samples)
except Exception as e:
    print(e)

unbound axis name: batch. The following axis names (e.g. defined by pmap) are available to collective operations: []


In [34]:
model_instance.variables

{'batch_stats': FrozenDict({
     BatchNorm_0: {
         BatchNorm_0: {
             mean: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
             var: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
         },
     },
     BatchNorm_1: {
         BatchNorm_0: {
             mean: DeviceArray([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
             var: DeviceArray([1., 1., 1., 1., 1., 1., 1.], dtype=float32),
         },
     },
     BatchNorm_2: {
         BatchNorm_0: {
             mean: DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32),
             var: DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32),
         },
     },
 }),
 'params': FrozenDict({
     BatchNorm_0: {
         BatchNorm_0: {
             bias: DeviceArray([-0.03346549,  0.04640758, -0.06939359,  0.04920085,
                          -0.04048646], dtype=float32),
             scale: DeviceArray([0.99269396, 0.9641437 , 0.99851197, 1.0021048 , 1.0015484 ],            dtype=float32),
         }