In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from base import ModelInstance

In [3]:
from flax import linen as nn
import jax
from jax import numpy as jnp, random
from functools import partial

import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.1'

In [4]:
class BatchNorm(nn.Module):
    is_training: bool = False
    axis_name: str='batch'

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        return nn.BatchNorm(use_running_average=not self.is_training, momentum=0.9, epsilon=1e-5, axis_name=self.axis_name)(x)

class SimpleModel(nn.Module):
    @nn.compact
    def __call__(self, x_batch: jnp.ndarray, is_training: bool=False, batch_name: str='batch'):
        return nn.Sequential([
        nn.Dense(5),
        nn.relu,
        BatchNorm(is_training=is_training, axis_name=batch_name),
        nn.Dense(7),
        nn.relu,
        BatchNorm(is_training=is_training, axis_name=batch_name),
        nn.Dense(6),
        nn.relu,
        BatchNorm(is_training=is_training, axis_name=batch_name),
        nn.Dense(1)
    ])(x_batch)

model_instance = ModelInstance(SimpleModel(), batch_name='cute')

In [5]:
key1, key2 = random.split(random.PRNGKey(0))

x_samples = random.normal(key1, shape=(1000, 1))
y_samples = jnp.sum(x_samples ** 2, axis=1) + 0.1 * random.normal(key2, shape=(x_samples.shape[0],))

In [6]:
print(x_samples.shape, y_samples.shape)

(1000, 1) (1000,)


In [7]:
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'notebook_connected'

In [8]:
px.scatter(x=x_samples[:, 0], y=y_samples)

In [9]:
try:
    model_instance(x_samples)
except Exception as e:
    print(e)

This model is not initialized! Please call "initialize" first.


In [10]:
try:
    model_instance.variables
except Exception as e:
    print(e)

This model is not initialized! Please call "initialize" first.


In [11]:
model_instance.intitialize(x_samples)

In [12]:
temp_1 = model_instance.variables

In [13]:
model_instance(x_samples)

DeviceArray([[-1.41045362e-01],
             [ 3.73392999e-02],
             [-6.78812265e-02],
             [ 3.41803581e-03],
             [-8.35109651e-02],
             [ 1.12453990e-01],
             [ 7.18076825e-02],
             [ 4.84998245e-03],
             [ 8.76955315e-02],
             [ 6.57335809e-03],
             [-9.43522081e-02],
             [-5.62527217e-03],
             [ 4.89484519e-02],
             [-1.72702372e-02],
             [ 1.12695405e-02],
             [-1.20161884e-02],
             [-2.38939226e-02],
             [ 5.64822517e-02],
             [ 2.38720085e-02],
             [-4.30890592e-03],
             [ 2.11295784e-02],
             [ 2.28722058e-02],
             [-9.52074528e-02],
             [ 3.52375954e-02],
             [ 1.69138294e-02],
             [ 2.93592550e-03],
             [ 2.49864608e-02],
             [-1.06786653e-01],
             [ 1.24650657e-01],
             [ 1.82463266e-02],
             [-1.13107435e-01],
        

In [14]:
px.scatter(x=x_samples[:, 0], y=model_instance(x_samples)[:, 0])

In [15]:
temp_2 = model_instance.variables

In [16]:
from jax import tree_util

tmp = tree_util.tree_map(lambda x, y: jnp.allclose(x, y), temp_1, temp_2)

for val in tree_util.tree_leaves(tmp):
    assert val == jnp.array(True), "Error: variables were mutated!"

In [17]:
print(temp_2)

{'batch_stats': FrozenDict({
    BatchNorm_0: {
        BatchNorm_0: {
            mean: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
            var: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
        },
    },
    BatchNorm_1: {
        BatchNorm_0: {
            mean: DeviceArray([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
            var: DeviceArray([1., 1., 1., 1., 1., 1., 1.], dtype=float32),
        },
    },
    BatchNorm_2: {
        BatchNorm_0: {
            mean: DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32),
            var: DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32),
        },
    },
}), 'params': FrozenDict({
    BatchNorm_0: {
        BatchNorm_0: {
            bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
            scale: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
        },
    },
    BatchNorm_1: {
        BatchNorm_0: {
            bias: DeviceArray([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
            scale: D

In [18]:
try:
    model_instance.eval_gradients(x_samples, y_samples)
except Exception as e:
    print(e)

The gradient function is not compiled! Please call "compile" first.


In [19]:
model_instance.compile(lambda y_pred, y_true: (y_pred - y_true) ** 2, True)

In [20]:
try:
    model_instance.step(x_samples, y_samples)
except Exception as e:
    print(e)

This model has no optimizer attached to it! Please call "attach_optimizer" first.


In [21]:
import optax

In [22]:

model_instance.attach_optimizer(optimizer=optax.sgd(0.02))

In [23]:
for i in range(100):
    model_instance.step(x_samples, y_samples)
    if i % 10 == 0:
        print(f'current loss: {model_instance.compute_loss(x_samples, y_samples)}')

current loss: 2.912132501602173
current loss: 1.9188743829727173
current loss: 1.5457590818405151
current loss: 1.3780467510223389
current loss: 1.1942375898361206
current loss: 0.9315579533576965
current loss: 0.5991763472557068
current loss: 0.3469927906990051
current loss: 0.2166314423084259
current loss: 0.15786048769950867


In [24]:
model_instance.variables

{'batch_stats': FrozenDict({
     BatchNorm_0: {
         BatchNorm_0: {
             mean: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
             var: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
         },
     },
     BatchNorm_1: {
         BatchNorm_0: {
             mean: DeviceArray([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
             var: DeviceArray([1., 1., 1., 1., 1., 1., 1.], dtype=float32),
         },
     },
     BatchNorm_2: {
         BatchNorm_0: {
             mean: DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32),
             var: DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32),
         },
     },
 }),
 'params': FrozenDict({
     BatchNorm_0: {
         BatchNorm_0: {
             bias: DeviceArray([-0.02758599, -0.15754293, -0.04459511,  0.05464069,
                          -0.11339177], dtype=float32),
             scale: DeviceArray([1.2220318, 1.1611606, 1.0342898, 1.0043111, 1.0356225], dtype=float32),
         },
     },
     B

In [25]:
y_pred = model_instance(x_samples)

In [26]:
y_pred.shape

(1000, 1)

In [27]:
px.scatter(x=x_samples[:, 0], y=y_pred[:, 0])

In [28]:
model_instance.update_configs({'training': True})

In [29]:
model_instance.run_configs

{'training': True}

In [30]:
model_instance.update_configs({'training': False})

In [31]:
model_instance.run_configs

{'training': False}

In [32]:
model_instance.reset_configs({'is_training': True})

In [33]:
model_instance.run_configs

{'is_training': True}

In [34]:
model_instance.update_configs({'axis_name': 'cute'})

In [35]:
model_instance.run_configs

{'is_training': True, 'axis_name': 'cute'}

In [36]:
model_instance.reset_configs({'is_training': True, 'batch_name': 'cute'})

In [37]:
model_instance.run_configs

{'is_training': True, 'batch_name': 'cute'}

In [38]:
try:
    for i in range(100):
        model_instance.step(x_samples, y_samples)
        if i % 10 == 0:
            print(f'current loss: {model_instance.compute_loss(x_samples, y_samples)}')
except Exception as e:
    print(e)

current loss: 1.2210227251052856
current loss: 0.13761088252067566
current loss: 0.05219889432191849
current loss: 0.042282212525606155
current loss: 0.04059835895895958
current loss: 0.039887335151433945
current loss: 0.03918471559882164
current loss: 0.03855922818183899
current loss: 0.038035087287425995
current loss: 0.037440620362758636


In [39]:
model_instance.variables

{'batch_stats': FrozenDict({
     BatchNorm_0: {
         BatchNorm_0: {
             mean: DeviceArray([0.33294827, 0.5751627 , 0.11376952, 0.08421738, 0.04051756],            dtype=float32),
             var: DeviceArray([0.24644703, 0.89029956, 0.03019902, 0.00411578, 0.00918032],            dtype=float32),
         },
     },
     BatchNorm_1: {
         BatchNorm_0: {
             mean: DeviceArray([0.33368668, 0.75055766, 0.20627376, 0.14809375, 0.39673132,
                          0.79097337, 0.5890024 ], dtype=float32),
             var: DeviceArray([0.58337003, 2.1068916 , 0.2516741 , 0.1266787 , 0.14710124,
                          0.55886173, 1.9449422 ], dtype=float32),
         },
     },
     BatchNorm_2: {
         BatchNorm_0: {
             mean: DeviceArray([0.5855001 , 0.31849164, 0.1949912 , 0.24666765, 0.39563817,
                          0.3523214 ], dtype=float32),
             var: DeviceArray([2.84505   , 0.1197658 , 0.2108169 , 0.07267924, 0.21498045,
     

In [40]:
px.scatter(x=x_samples[:, 0], y=model_instance(x_samples)[:, 0])

In [41]:
model = SimpleModel()

In [42]:
variables = model.init(random.PRNGKey(0), x_samples)

In [43]:
model.apply(variables, x_samples).shape

(1000, 1)

In [44]:
print(f'final loss: {model_instance.compute_loss(x_samples, y_samples)}')

final loss: 0.03699241578578949


In [45]:
from flax import serialization

In [46]:
serialization.to_bytes(model_instance._ModelInstance__optimizer_state)

b'\x82\xa10\x80\xa11\x80'

In [47]:
model_instance._ModelInstance__optimizer_state

(EmptyState(), EmptyState())

In [48]:
import os
print(os.getcwd())

/home/trent/college-files-fa22/cs182/CS182-282A-Final-Project/library/models


In [49]:
state_dict = model_instance.save_states()

In [50]:
model_instance.save_states('test_model')

{'variables': {'params': FrozenDict({
      BatchNorm_0: {
          BatchNorm_0: {
              bias: DeviceArray([-0.0344986 , -0.17745182, -0.05497848,  0.04635565,
                           -0.10639708], dtype=float32),
              scale: DeviceArray([1.2294991, 1.1678644, 1.0343363, 1.0153118, 1.0024687], dtype=float32),
          },
      },
      BatchNorm_1: {
          BatchNorm_0: {
              bias: DeviceArray([ 0.00630755, -0.15142591, -0.02285976, -0.10151136,
                           -0.0536942 , -0.07301318, -0.07914702], dtype=float32),
              scale: DeviceArray([1.0266856 , 1.2405134 , 1.066992  , 0.96714586, 0.9898587 ,
                           0.9937279 , 1.2213867 ], dtype=float32),
          },
      },
      BatchNorm_2: {
          BatchNorm_0: {
              bias: DeviceArray([ 0.19292898,  0.18033852,  0.1896456 ,  0.11285002,
                           -0.21252334, -0.09991998], dtype=float32),
              scale: DeviceArray([1.230273  , 1

In [51]:
new_instance = ModelInstance(SimpleModel())

In [52]:
try:
    new_instance.load_states('test_model')
except Exception as e:
    print(e)

Model is not initialized!


In [53]:
new_instance.intitialize(x_samples)

In [54]:
px.scatter(x=x_samples[:, 0], y=new_instance(x_samples)[:, 0])

In [55]:
new_instance.load_states('test_model')

In [56]:
px.scatter(x=x_samples[:, 0], y=new_instance(x_samples)[:, 0])

In [57]:
new_instance = ModelInstance(SimpleModel())

In [58]:
try:
    new_instance.load_states('test_model')
except Exception as e:
    print(e)

Model is not initialized!


In [59]:
new_instance.intitialize(x_samples)

In [60]:
px.scatter(x=x_samples[:, 0], y=new_instance(x_samples)[:, 0])

In [61]:
new_instance.load_states(state_dict)

In [62]:
px.scatter(x=x_samples[:, 0], y=new_instance(x_samples)[:, 0])


In [63]:
different_instance = ModelInstance(SimpleModel())

In [64]:
different_instance.intitialize(jnp.ones((100, 2)))

In [65]:
different_state_dict = different_instance.save_states()

In [66]:
different_state_dict

{'variables': {'params': FrozenDict({
      Dense_0: {
          kernel: DeviceArray([[ 0.2782076 ,  0.3343372 ,  0.18675362,  0.21081653,
                         0.79690367],
                       [-0.4492056 , -0.22952792,  0.1016606 ,  0.08951215,
                        -0.30509454]], dtype=float32),
          bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
      },
      BatchNorm_0: {
          BatchNorm_0: {
              scale: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
              bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
          },
      },
      Dense_1: {
          kernel: DeviceArray([[-0.2587283 ,  0.8625674 ,  0.05435297,  0.3570576 ,
                        -0.12405909, -0.55076146, -0.37448213],
                       [-0.17958921, -0.5078695 ,  0.01714774, -0.1306908 ,
                        -0.6551801 , -0.36190635,  0.5624816 ],
                       [ 0.05415251,  0.52317613,  0.07341513,  0.55867696,
                        -0.

In [67]:
try:
    new_instance.load_states(different_state_dict)
except Exception as e:
    print(e)

'NoneType' object is not subscriptable


In [68]:
class AnotherSimpleModel(nn.Module):
    @nn.compact
    def __call__(self, x_batch: jnp.ndarray, is_training: bool=False, batch_name: str='batch'):
        return nn.Sequential([
        nn.Dense(5),
        nn.relu,
        BatchNorm(is_training=is_training, axis_name=batch_name),
        nn.Dense(7),
        nn.relu,
        BatchNorm(is_training=is_training, axis_name=batch_name),
        nn.Dense(6),
    ])(x_batch)

In [69]:
another_instance = ModelInstance(AnotherSimpleModel())

In [70]:
another_instance.intitialize(x_samples)

In [71]:
another_state_dict = another_instance.save_states('test_model_new')

In [72]:
try:
    new_instance.load_states(another_state_dict)
except Exception as e:
    print(e)

'BatchNorm_2'


In [73]:
try:
    new_instance.load_states('test_model_new')
except Exception as e:
    print(e)

'BatchNorm_2'
