In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from base import ModelInstance

In [3]:
from flax import linen as nn
import jax
import optax
from jax import numpy as jnp, random
from functools import partial

import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.1'

In [4]:
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'notebook_connected'

In [5]:
class BatchNorm(nn.Module):
    """Wrapper around ``nn.BatchNorm`` with momentum fixed at 0.9 and epsilon at 1e-5."""

    # True -> the wrapped layer is built with use_running_average=False;
    # False (the default) -> use_running_average=True.
    is_training: bool = False
    # Forwarded unchanged as the ``axis_name`` argument of nn.BatchNorm.
    axis_name: str = 'batch'

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        """Normalise ``x`` with the wrapped ``nn.BatchNorm`` layer and return the result."""
        use_stored_stats = not self.is_training
        norm_layer = nn.BatchNorm(
            use_running_average=use_stored_stats,
            momentum=0.9,
            epsilon=1e-5,
            axis_name=self.axis_name,
        )
        return norm_layer(x)

class SimpleModel(nn.Module):
    """Small MLP: three Dense -> relu -> BatchNorm stages (5, 7 and 6 units), then Dense(1)."""

    @nn.compact
    def __call__(self, x_batch: jnp.ndarray, is_training: bool=False, batch_name: str='batch'):
        """Run ``x_batch`` through the stack.

        ``is_training`` and ``batch_name`` are passed to every BatchNorm wrapper
        as its ``is_training`` and ``axis_name`` fields respectively.
        """
        layers = []
        # Layers are created in the same order as they are applied: hidden
        # stages first, output head last.
        for width in (5, 7, 6):
            layers.extend([
                nn.Dense(width),
                nn.relu,
                BatchNorm(is_training=is_training, axis_name=batch_name),
            ])
        layers.append(nn.Dense(1))
        return nn.Sequential(layers)(x_batch)

# Wrap the network in the project's ModelInstance helper, passing 'batch' as batch_name.
simple_model = SimpleModel()
model_instance = ModelInstance(simple_model, batch_name='batch')

In [6]:
# Synthetic 1-D regression task: targets are x^2 plus Gaussian noise scaled by 0.1.
SEED = 0
N_SAMPLES = 1000
NOISE_SCALE = 0.1

# Independent PRNG keys: key1 draws the inputs, key2 draws the noise.
key1, key2 = random.split(random.PRNGKey(SEED))

# Inputs: standard-normal draws arranged as an (N_SAMPLES, 1) column.
x_samples = random.normal(key1, shape=(N_SAMPLES, 1))

# Targets: squared input summed over the feature axis, plus one noise draw per row.
clean_targets = jnp.sum(x_samples ** 2, axis=1)
noise = NOISE_SCALE * random.normal(key2, shape=(x_samples.shape[0],))
y_samples = clean_targets + noise

In [7]:
# Sanity check: show the shapes of the inputs and of the targets.
sample_shapes = (x_samples.shape, y_samples.shape)
print(*sample_shapes)

(1000, 1) (1000,)


In [8]:
# Plot the noisy targets against the single input feature. The title and axis
# labels make the figure readable on its own when the notebook is skimmed.
px.scatter(
    x=x_samples[:, 0],
    y=y_samples,
    title='Training data',
    labels={'x': 'input x', 'y': 'target y'},
)

In [9]:
# Set up the model for training: initialise it from the sample inputs, turn on
# the is_training config flag, then attach a plain SGD optimizer.
LEARNING_RATE = 0.02
training_config = {'is_training': True}

# `intitialize` is spelled exactly as the method is called in this notebook;
# presumably that is its name in base.ModelInstance — verify before renaming.
model_instance.intitialize(x_samples)
model_instance.update_configs(training_config)

sgd_optimizer = optax.sgd(LEARNING_RATE)
model_instance.attach_optimizer(sgd_optimizer)

In [10]:
# Plot the model's predictions before any optimisation step has run, as a
# baseline for the post-training figure. The title and axis labels make the
# figure readable on its own.
px.scatter(
    x=x_samples[:, 0],
    y=model_instance(x_samples)[:, 0],
    title='Model predictions before training',
    labels={'x': 'input x', 'y': 'predicted y'},
)

In [11]:
forward_pass = model_instance.forward_fn


@jax.jit
def loss(params, state, x_batch, y_batch):
    """Mean squared error of the model on one batch.

    Returns a ``(mse, new_state)`` pair, where ``new_state`` is the second
    value returned by ``forward_pass``.
    """
    predictions, updated_state = forward_pass(params, state, x_batch)
    # Flatten both sides so the subtraction is elementwise; presumably the
    # predictions are (N, 1) while the targets are (N,) — confirm.
    residuals = predictions.flatten() - y_batch.flatten()
    mse = jnp.mean(residuals ** 2)
    return mse, updated_state


# Gradient of the loss w.r.t. its first argument (params); with has_aux=True
# the call returns (grads, new_state).
grad_fn = jax.jit(jax.grad(loss, has_aux=True))

In [12]:
# Full-batch training: every step uses all of x_samples / y_samples.
NUM_STEPS = 1000
LOG_EVERY = 100

for step in range(NUM_STEPS):
    # Read the current parameters and state before the update, so the loss
    # logged below is evaluated at the same point the gradients were taken.
    params, state = model_instance.parameters_, model_instance.state_
    grads, new_state = grad_fn(params, state, x_samples, y_samples)
    model_instance.manual_step_with_optimizer(grads, new_state)

    # Only report on every LOG_EVERY-th step (including step 0).
    if step % LOG_EVERY:
        continue
    loss_value = loss(params, state, x_samples, y_samples)[0]
    print(f'loss: {loss_value}')

loss: 3.7212629318237305
loss: 0.0324089452624321
loss: 0.02732883393764496
loss: 0.024536438286304474
loss: 0.021715892478823662
loss: 0.018379464745521545
loss: 0.016348853707313538
loss: 0.015178583562374115
loss: 0.013827169314026833
loss: 0.01278134249150753


In [13]:
# Plot the model's predictions after the training loop. The title and axis
# labels make the figure readable on its own.
# NOTE(review): is_training was set to True before training and is not visibly
# reset, so these predictions presumably use batch statistics rather than the
# running averages — confirm, and switch to evaluation mode first if intended.
px.scatter(
    x=x_samples[:, 0],
    y=model_instance(x_samples)[:, 0],
    title='Model predictions after training',
    labels={'x': 'input x', 'y': 'predicted y'},
)

In [14]:
# Inspect the model state after training. The recorded output is a FrozenDict
# whose `batch_stats` entry holds a `mean` and a `var` array for each BatchNorm layer.
# NOTE(review): one unit of the first BatchNorm shows mean/var of ~5.6e-45,
# which looks like a ReLU unit that never activates — confirm.
# NOTE(review): the training loop reads `state_` while this cell reads `state`;
# presumably ModelInstance exposes both — verify against base.py.
model_instance.state

FrozenDict({
    batch_stats: {
        BatchNorm_0: {
            BatchNorm_0: {
                mean: DeviceArray([1.8353546e-01, 6.5854925e-01, 6.7443680e-03, 5.6051939e-45,
                             1.2852257e-02], dtype=float32),
                var: DeviceArray([8.3229408e-02, 9.5335805e-01, 6.3567434e-04, 5.6051939e-45,
                             1.7828426e-03], dtype=float32),
            },
        },
        BatchNorm_1: {
            BatchNorm_0: {
                mean: DeviceArray([0.208325  , 0.48652145, 0.11190458, 0.267204  , 0.4529868 ,
                             0.6916883 , 0.40137354], dtype=float32),
                var: DeviceArray([0.5545493 , 1.2655706 , 0.08869918, 0.58227545, 0.1235225 ,
                             0.2031821 , 1.1190659 ], dtype=float32),
            },
        },
        BatchNorm_2: {
            BatchNorm_0: {
                mean: DeviceArray([0.18760596, 0.49629655, 0.24809857, 0.20999658, 0.5986507 ,
                             0.

In [15]:
# Inspect the parameters after training. The visible part of the recorded
# output is a FrozenDict with `bias` and `scale` arrays for each BatchNorm layer.
# NOTE(review): the training loop reads `parameters_` while this cell reads
# `parameters`; presumably ModelInstance exposes both — verify against base.py.
model_instance.parameters

FrozenDict({
    BatchNorm_0: {
        BatchNorm_0: {
            bias: DeviceArray([-0.03187943, -0.10190659, -0.07571849, -0.07529528,
                         -0.0776908 ], dtype=float32),
            scale: DeviceArray([0.982472  , 1.028007  , 0.97847664, 1.0122045 , 0.9673883 ],            dtype=float32),
        },
    },
    BatchNorm_1: {
        BatchNorm_0: {
            bias: DeviceArray([ 0.050379  ,  0.02354307, -0.10578769, -0.00530372,
                         -0.02527744,  0.10789678, -0.16058823], dtype=float32),
            scale: DeviceArray([0.9960681, 0.9395276, 1.0251638, 0.9395023, 1.0857949,
                         0.9906519, 0.9613953], dtype=float32),
        },
    },
    BatchNorm_2: {
        BatchNorm_0: {
            bias: DeviceArray([ 0.1680399 ,  0.1223714 ,  0.01548062,  0.18134804,
                         -0.27217585, -0.15057394], dtype=float32),
            scale: DeviceArray([1.1892052 , 0.9680278 , 0.99010545, 1.1084329 , 1.0332512 ,
         