In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from base import *
from flax import linen as nn

import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'notebook_connected'

In [3]:
class BatchNorm(nn.Module):
    is_training: bool = False
    axis_name: str='batch'

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        return nn.BatchNorm(use_running_average=not self.is_training, momentum=0.9, epsilon=1e-5, axis_name=self.axis_name)(x)

class SimpleModel(nn.Module):
    @nn.compact
    def __call__(self, x_batch: jnp.ndarray, is_training: bool=False, batch_name: str='batch'):
        return nn.Sequential([
        nn.Dense(5),
        nn.relu,
        BatchNorm(is_training=is_training, axis_name=batch_name),
        nn.Dense(7),
        nn.relu,
        BatchNorm(is_training=is_training, axis_name=batch_name),
        nn.Dense(6),
        nn.relu,
        BatchNorm(is_training=is_training, axis_name=batch_name),
        nn.Dense(1)
    ])(x_batch)

model_instance = ModelInstance(SimpleModel(), batch_name='batch')

In [4]:

key1, key2 = random.split(random.PRNGKey(0))

x_samples = random.uniform(key1, shape=(1000, 1), minval=-1, maxval=1)
y_samples = jnp.sum(x_samples ** 2, axis=1) + 0.1 * random.normal(key2, shape=(x_samples.shape[0],))

In [5]:
px.scatter(x=x_samples[:, 0], y=y_samples)

In [6]:
key3, key4 = random.split(random.PRNGKey(1))

x_samples_new = random.uniform(key3, shape=(500, 1), minval=-1, maxval=1)
y_samples_new = (x_samples_new ** 3).flatten() + 0.1 * random.normal(key4, shape=(x_samples_new.shape[0],))

In [7]:
px.scatter(x=x_samples_new[:, 0], y=y_samples_new)

In [8]:
class NewSimpleModel(nn.Module):
    @nn.compact
    def __call__(self, x_batch: jnp.ndarray, is_training: bool=False, batch_name: str='batch'):
        return nn.Sequential([
        nn.Dense(5),
        nn.relu,
        BatchNorm(is_training=is_training, axis_name=batch_name),
        nn.Dense(6),
        nn.relu,
        BatchNorm(is_training=is_training, axis_name=batch_name),
        nn.Dense(1)
    ])(x_batch)

model_instance_new = ModelInstance(NewSimpleModel(), batch_name='batch')

In [9]:
model_instance.initialize(x_samples)
model_instance_new.initialize(x_samples_new)

In [10]:
px.scatter(x=x_samples[:, 0], y=model_instance(x_samples)[:, 0])

In [11]:
px.scatter(x=x_samples_new[:, 0], y=model_instance_new(x_samples_new)[:, 0])

In [12]:
model_instance.compile(lambda x, y: (x - y) ** 2, need_vmap=True)
model_instance.attach_optimizer(optax.sgd(0.02))
model_instance.update_configs({'is_training': True})
model_instance_new.compile(lambda x, y: (x - y) ** 2, need_vmap=True)
model_instance_new.attach_optimizer(optax.sgd(0.02))
model_instance_new.update_configs({'is_training': True})

In [13]:
for i in range(1000):
    model_instance.step(x_samples, y_samples)
    model_instance_new.step(x_samples_new, y_samples_new)
    if i % 100 == 0:
        print(f'iteration {i}\nloss_1: {model_instance.compute_loss(x_samples, y_samples)}\nloss_2: {model_instance_new.compute_loss(x_samples_new, y_samples_new)}\n')

iteration 0
loss_1: 0.47364625334739685
loss_2: 0.4223848283290863

iteration 100
loss_1: 0.014027847908437252
loss_2: 0.017484677955508232

iteration 200
loss_1: 0.012430809438228607
loss_2: 0.014853582717478275

iteration 300
loss_1: 0.011907036416232586
loss_2: 0.013009863905608654

iteration 400
loss_1: 0.011600138619542122
loss_2: 0.011378384195268154

iteration 500
loss_1: 0.011410947889089584
loss_2: 0.010924562811851501

iteration 600
loss_1: 0.01125779002904892
loss_2: 0.010726905427873135

iteration 700
loss_1: 0.011145276017487049
loss_2: 0.01059980969876051

iteration 800
loss_1: 0.011047253385186195
loss_2: 0.010509643703699112

iteration 900
loss_1: 0.010966053232550621
loss_2: 0.010454313829541206



In [14]:
px.scatter(x=x_samples[:, 0], y=model_instance(x_samples)[:, 0])

In [15]:
px.scatter(x=x_samples_new[:, 0], y=model_instance_new(x_samples_new)[:, 0])

In [16]:
system = DifferentiableLearningSystem(({'a': model_instance, 'b': model_instance_new}))

In [17]:
submodules = system.submodules

In [18]:
submodules['a'].run_configs

{'is_training': True}

In [19]:
submodules['a'].update_configs({'test': 1})

In [20]:
model_instance.run_configs

{'is_training': True, 'test': 1}

In [21]:
system.submodules['a'].run_configs

{'is_training': True, 'test': 1}

In [22]:
submodules['b'] = 'cute'

In [23]:
submodules

{'a': <base.ModelInstance at 0x7fa39066dbd0>, 'b': 'cute'}

In [24]:
system.submodules

{'a': <base.ModelInstance at 0x7fa39066dbd0>,
 'b': <base.ModelInstance at 0x7fa36446fe20>}

In [25]:
system.submodules['b'].run_configs

{'is_training': True}

In [26]:
submodules['a'].reset_configs({'is_training': True})

In [27]:
system.submodules['a'].run_configs

{'is_training': True}

In [28]:
submodules

{'a': <base.ModelInstance at 0x7fa39066dbd0>, 'b': 'cute'}

In [30]:
system.save_states('test_system')

{'a': b'\x82\xa9variables\x82\xa6params\x87\xabBatchNorm_0\x81\xabBatchNorm_0\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14\xda:\xde;\xcfT\x14=\x10\xd0\xbe<\x84\x10\xdb\xba\xd2z\xed;\xa5scale\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14\x16\xe6z?\x8dF|?:\xae}?\x86I\x80?j\xdf\x84?\xabBatchNorm_1\x81\xabBatchNorm_0\x82\xa4bias\xc7)\x01\x93\x91\x07\xa7float32\xc4\x1c7y\x80\xbc\xb76\xcd<\xe9s\x10\xbd@Q\x8c\xbc3r\x87;*\xe5E=\xa8\xaf\x08\xbd\xa5scale\xc7)\x01\x93\x91\x07\xa7float32\xc4\x1c\x16\xd7\x7f?\xcc\xef\x86?\xe0\xce\x80?\x01\xffk?\x86.\x7f?\x0b\xf8\x80?\xd4\x99\x80?\xabBatchNorm_2\x81\xabBatchNorm_0\x82\xa4bias\xc7%\x01\x93\x91\x06\xa7float32\xc4\x18;\xfc\x8b<>$\x7f=\x93\x04\r=\xcfY/=0\xe5\xef\xbd/m6\xbd\xa5scale\xc7%\x01\x93\x91\x06\xa7float32\xc4\x18\xe6\x90\x82?\x8f\xd1x?R\xb0\x80?\x8aHx?\x9d\xdc\x83?\xca\x02s?\xa7Dense_0\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14q00\xbdl\t\x85;}\xb7}=\x99Q\x84\xbd?\x07j=\xa6kernel\xc7"\x01\x93\x92\x01\x05\xa7float32\xc4\x14G@\r\xbf\xa4\r

In [31]:
system_new = DifferentiableLearningSystem({'a': ModelInstance(SimpleModel()), 'b': ModelInstance(NewSimpleModel())})

In [32]:
system_new.submodules['a'].initialize(x_samples)
system_new.submodules['b'].initialize(x_samples_new)

In [33]:
px.scatter(x=x_samples[:, 0], y=system_new.submodules['a'](x_samples)[:, 0])

In [34]:
px.scatter(x=x_samples_new[:, 0], y=system_new.submodules['b'](x_samples_new)[:, 0])

In [35]:
system_new.load_states('test_system')

In [36]:
px.scatter(x=x_samples[:, 0], y=system_new.submodules['a'](x_samples)[:, 0])

In [37]:
px.scatter(x=x_samples_new[:, 0], y=system_new.submodules['b'](x_samples_new)[:, 0])