# FLAX Convolution Neural Network Example - Interactive API

Run this jupyter notebook on a virtual environment.

In [2]:
import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.8'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

In [None]:
# !pip install --upgrade -q "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# !pip install --upgrade -q git+https://github.com/google/flax.git

In [None]:
# !pip install flax==0.4.1 ml_collections optax jax==0.3.13 -q
# # jaxlib==0.3.10 -q

GPU version of JAX. Pick the jax version compatible with the CUDA and cuDNN pre-installed.

In [None]:
# !pip install --upgrade pip # Careful with the pip upgrade, it may cause a package dependency related problems during OpenFL workflow execution.

# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
# !pip install -U jaxlib==0.3.10+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
# !pip install -U jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
# !pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

In [None]:
# Without either of the below flags, JAX XLA raised CUDA_OUT_OF_MEMORY exception.
# JAX XLA pre-allocates 90% of the GPU at start

# Below flag to restrict max GPU allocation to 80%
# %env XLA_PYTHON_CLIENT_MEM_FRACTION=.8

# OR

# set XLA_PYTHON_CLIENT_PREALLOCATE to false to incrementally allocate GPU memory as and when required. But can take entire GPU by the end.
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false


In [None]:
# %pip install tensorflow==2.8.1


In [None]:
# %pip install tensorflow_datasets ml_collections


In [3]:
%env

{'SHELL': '/bin/bash',
 'COLORTERM': 'truecolor',
 'no_proxy': 'localhost,127.0.0.0/8,10.0.0.0/8,.intel.com',
 'TERM_PROGRAM_VERSION': '1.69.2',
 'LANGUAGE': 'en_IN:en',
 'PWD': '/home/sunilach/openfl/forked-intel-openfl/openfl-tutorials/interactive_api/Flax_CNN_CIFAR',
 'LOGNAME': 'sunilach',
 'XDG_SESSION_TYPE': 'tty',
 'ftp_proxy': 'ftp://proxy.iind.intel.com:1080/',
 'MOTD_SHOWN': 'pam',
 'HOME': '/home/sunilach',
 'LANG': 'en_US.UTF-8',
 'LS_COLORS': 'rs=0:di=01;34:ln=01;36:mh=00:pi=40;33:so=01;35:do=01;35:bd=40;33;01:cd=40;33;01:or=40;31;01:mi=00:su=37;41:sg=30;43:ca=30;41:tw=30;42:ow=34;42:st=37;44:ex=01;32:*.tar=01;31:*.tgz=01;31:*.arc=01;31:*.arj=01;31:*.taz=01;31:*.lha=01;31:*.lz4=01;31:*.lzh=01;31:*.lzma=01;31:*.tlz=01;31:*.txz=01;31:*.tzo=01;31:*.t7z=01;31:*.zip=01;31:*.z=01;31:*.dz=01;31:*.gz=01;31:*.lrz=01;31:*.lz=01;31:*.lzo=01;31:*.xz=01;31:*.zst=01;31:*.tzst=01;31:*.bz2=01;31:*.bz=01;31:*.tbz=01;31:*.tbz2=01;31:*.tz=01;31:*.deb=01;31:*.rpm=01;31:*.jar=01;31:*.war=01;31

In [4]:
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state
import jax
import jax.numpy as jnp
import ml_collections
import optax
import tensorflow_datasets as tfds
import logging

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# # Both the MSE function are optimal and accurate in terms of correctness.

# # Calculate MSE approach 1
# def mse_loss_function1(W, X, y):
#     y_pred = jnp.dot(X, W)
#     mse_error = y_pred - y
#     return jnp.mean(jnp.square(mse_error))

# # Calculate MSE approach 2
# def mse_loss_function2(W, X, Y):
#     def squared_error(x, y):
#         y_pred = jnp.dot(x, W)
#         return jnp.inner(y-y_pred, y-y_pred)|
#     vectorized_square_error = jax.vmap(squared_error)
#     return jnp.mean(vectorized_square_error(X, Y), axis=0)

# # Weight update, JAX compiled function. Consequent executions are way faster!!!.
# def update(W, x, y, lr):
#     W = W - lr * jax.grad(mse_loss_function1)(W, x, y)
#     return W

In [5]:
editor_relpaths = ('configs/default.py', 'train.py')

In [6]:
@jax.jit
def apply_model(state, images, labels):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, images)
        one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy


@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)

In [307]:
def train_epoch(state, train_ds, batch_size, rng):
    """Train for a single epoch."""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds['image']))
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))

    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        batch_images = train_ds['image'][perm, ...]
        batch_labels = train_ds['label'][perm, ...]
        grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
        print(grads)
        break
        state = update_model(state, grads)
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)
        
    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy

In [308]:
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> train_state.TrainState:
    """Execute model training and evaluation loop.
    Args:
        config: Hyperparameter configuration for training and evaluation.
        workdir: Directory where the tensorboard summaries are written to.
    Returns:
        The train state (which includes the `.params`).
    """
    
    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(0)

    summary_writer = tensorboard.SummaryWriter(workdir)
    summary_writer.hparams(dict(config))

    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, config)

    for epoch in range(1, config.num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state, train_loss, train_accuracy = train_epoch(state, train_ds,
                                                        config.batch_size,
                                                        input_rng)
        _, test_loss, test_accuracy = apply_model(state, test_ds['image'],
                                                  test_ds['label'])

        logging.info(
            'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'
            % (epoch, train_loss, train_accuracy * 100, test_loss,
               test_accuracy * 100))

        summary_writer.scalar('train_loss', train_loss, epoch)
        summary_writer.scalar('train_accuracy', train_accuracy, epoch)
        summary_writer.scalar('test_loss', test_loss, epoch)
        summary_writer.scalar('test_accuracy', test_accuracy, epoch)

    summary_writer.flush()
    return state

In [9]:
# #Using jit decorator for GPU acceleration for entire function
# @jax.jit
# def train_step(optimizer, batch):
#     def loss_fn(model):
#         preds = model(batch['image'])
#         loss = cross_entropy_loss(preds, batch['label'])
#         return loss, preds
#     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
#     (_, preds), grad = grad_fn(optimizer.target)
#     optimizer = optimizer.apply_gradient(grad)
#     return optimizer

# @jax.jit
# def eval_step(model, batch):
#     preds = model(batch['image'])
#     return compute_metrics(preds, batch['label'])

# def eval_model(model, test_ds):
#     metrics = eval_step(model, test_ds)
#     metrics = jax.device_get(metrics)
#     summary = jax.tree_map(lambda x: x.item(), metrics)
#     return summary['loss'], summary['accuracy']

In [309]:
def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds

In [338]:
# rng = jax.random.PRNGKey(0)
# rng, init_rng = jax.random.split(rng)
# cnn = CNN()
# params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'].unfreeze()

In [341]:
def create_train_state(rng, config):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'].unfreeze()
    tx = optax.sgd(config.learning_rate, config.momentum)
    return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)

In [342]:
class CNN(nn.Module):
    """A simple CNN model."""
    
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

In [13]:
train_ds, test_ds = get_datasets()

Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


2022-08-09 15:09:37.394065: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


In [14]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# # Helper functions for images.

# def show_img(img, ax=None, title=None):
#     """Shows a single image."""
#     if ax is None:
#         ax = plt.gca()
#         ax.imshow(img[..., 0], cmap='gray')
#         ax.set_xticks([])
#         ax.set_yticks([])
#     if title:
#         ax.set_title(title)

# def show_img_grid(imgs, titles):
#     """Shows a grid of images."""
#     n = int(np.ceil(len(imgs)**.5))
#     _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))
#     for i, (img, title) in enumerate(zip(imgs, titles)):
#         show_img(img, axs[i // n][i % n], title)

In [None]:
# show_img_grid(
#     [train_ds['image'][idx] for idx in range(25)],
#     [f'label={train_ds["label"][idx]}' for idx in range(25)],
# )

In [343]:
from configs import default as config_lib
config = config_lib.get_config()

In [None]:
# cnn = CNN()
# params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']

In [344]:
config.num_epochs = 3
models = {}
for momentum in (0.8, 0.9, 0.95):
    name = f'momentum={momentum}'
    config.momentum = momentum
    state = train_and_evaluate(config, workdir=f'./models/{name}')
    break
    models[name] = state.params

{'Conv_0': {'bias': DeviceArray([-2.5294956e-03,  1.0617384e-03,  1.0871092e-03,
             -7.5664517e-04, -1.8678233e-04, -4.9029361e-04,
             -4.4167665e-04, -7.9249972e-03, -9.1656536e-04,
             -1.8964002e-03, -2.1935222e-03,  7.8311907e-03,
              2.1288046e-03, -1.6735045e-03, -3.3306577e-03,
             -1.0311848e-03,  7.7862656e-03,  2.1576872e-03,
             -7.2822906e-03, -1.5483726e-03, -6.3303817e-04,
             -4.6299753e-04,  4.3941289e-03, -5.3220720e-04,
             -8.3184766e-04,  7.8265844e-03, -3.4356546e-03,
             -3.6323036e-05, -2.5688871e-03,  1.1922442e-03,
             -5.2049058e-04, -4.6177106e-03], dtype=float32), 'kernel': DeviceArray([[[[-6.59753452e-04,  9.93206748e-04,  4.25191654e-04,
                -1.60608336e-03, -1.48017437e-03,  1.05801504e-03,
                -2.38798209e-04, -2.88684689e-03, -9.36191296e-04,
                -1.06081145e-03, -3.68145120e-04,  4.18583397e-03,
                -1.70943735e-0

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


{'Conv_0': {'bias': DeviceArray([-0.00429211,  0.00014817,  0.00244336,  0.00265356,
             -0.00806753,  0.00110937, -0.00016437, -0.00045536,
             -0.00082569,  0.00109708,  0.00296827,  0.00832468,
              0.00184548, -0.00093308, -0.00029506, -0.00064374,
              0.0170198 ,  0.00206642, -0.00162658, -0.00632473,
             -0.00038288, -0.00032191,  0.00537447,  0.00046672,
             -0.00196092,  0.00701693, -0.00064962, -0.00046556,
             -0.00377867,  0.00299433, -0.00460592, -0.00390522],            dtype=float32), 'kernel': DeviceArray([[[[-1.74214155e-03,  4.18937649e-04,  2.15712716e-04,
                -9.75224306e-04, -2.56598252e-03,  1.26618624e-03,
                -2.05858014e-04,  5.63990034e-04, -2.27882992e-08,
                 8.98707483e-04,  3.82827688e-03,  4.17585997e-03,
                -9.91377165e-05,  2.81037384e-04,  4.01121855e-04,
                -4.87552752e-04,  7.82713573e-03,  2.50424747e-03,
                -1.2

In [326]:
new_model.apply_gradients(grads=uf_params)

ValueError: Expected dict, got FrozenDict({
    Conv_0: {
        bias: DeviceArray([ 1.0808865e-31, -4.5855012e-02,  2.8046457e-02,
                      5.4478245e-32,  1.4003498e-31,  9.9151448e-06,
                      7.7850878e-02,  2.9201466e-31,  1.7564533e-31,
                      3.2798093e-02,  1.6265796e-31,  1.7444642e-31,
                      1.7386822e-02,  4.9014372e-31,  9.8282407e-32,
                      2.8425191e-02,  2.1398419e-31,  7.5427423e-32,
                      7.2153812e-32,  3.7461914e-02, -1.4826582e-02,
                      7.1077035e-03, -3.4116551e-02,  1.0147342e-31,
                      4.1000829e-31,  9.8198112e-32, -1.5755799e-03,
                     -3.2646570e-02, -3.2302715e-02, -5.4963474e-04,
                      5.0304699e-32, -3.0913241e-03], dtype=float32),
        kernel: DeviceArray([[[[ 1.88760765e-32, -1.06703984e-02,  1.37033593e-02,
                         9.39276337e-33,  4.76551170e-32, -3.07930395e-06,
                        -6.38741907e-03,  6.00778471e-32,  9.45506677e-33,
                         1.03639625e-02,  3.98432194e-32,  3.23849575e-32,
                         2.30746623e-03,  1.97585060e-31,  1.36012749e-32,
                         7.36412778e-03,  4.38277720e-32,  1.40749550e-32,
                         2.59687890e-32, -1.68234576e-03,  2.04338343e-03,
                        -3.77814780e-04, -9.07193962e-03,  1.56974076e-32,
                         1.10642183e-31,  2.22435064e-32, -2.62265708e-02,
                        -1.25320628e-04, -6.56265626e-03, -3.29226162e-03,
                         1.90065608e-32, -7.96138123e-03]],
        
                      [[ 1.75537541e-32, -5.11189317e-03,  1.77772492e-02,
                         6.25111512e-33,  5.03651017e-32, -1.19744084e-06,
                         4.21682326e-03,  5.83606203e-32,  1.09854892e-32,
                        -6.99534267e-03,  3.85376243e-32,  3.14009218e-32,
                         9.59832687e-03,  1.73702013e-31,  1.33813413e-32,
                         2.72393115e-02,  4.24502543e-32,  1.28540484e-32,
                         2.30628245e-32,  2.49417312e-03,  1.28552341e-03,
                         7.63133401e-04, -8.47547408e-03,  1.54377996e-32,
                         1.23537638e-31,  2.17163060e-32, -1.81277879e-02,
                         5.10839163e-04, -5.90922125e-03,  1.82963093e-03,
                         1.49170395e-32, -2.97121517e-03]],
        
                      [[ 1.62452585e-32, -1.26466947e-02,  2.20823772e-02,
                        -4.01341102e-33,  5.08877676e-32, -1.64414438e-08,
                         6.86563458e-03,  5.99846950e-32,  7.52986710e-33,
                        -1.54449092e-02,  3.68295693e-32,  3.25629861e-32,
                         1.34102628e-02,  1.48982530e-31,  1.59024975e-32,
                         3.39684635e-02,  4.38228349e-32,  1.19342931e-32,
                         2.22313988e-32,  2.09489986e-02,  3.51920142e-04,
                        -2.25777295e-03, -7.02692196e-03,  1.60227212e-32,
                         1.15111871e-31,  2.20974189e-32,  5.43883396e-03,
                         2.03639967e-04, -2.82242429e-03, -4.41902434e-04,
                         1.22419831e-32,  3.91736254e-03]]],
        
        
                     [[[ 2.02526612e-32, -2.98805274e-02,  4.53310926e-03,
                         1.18504319e-32,  5.49777239e-32,  1.00575189e-05,
                        -6.85883779e-03,  5.77905467e-32,  2.52192880e-32,
                         2.15671156e-02,  3.53051412e-32,  3.64840298e-32,
                         1.23566180e-03,  2.04512719e-31,  1.52200569e-32,
                        -9.44643840e-03,  4.34643650e-32,  1.45866638e-32,
                         2.69898528e-32,  8.88581271e-05,  1.10417383e-03,
                        -1.82714430e-05, -4.07742895e-03,  1.90136270e-32,
                         1.29183726e-31,  2.22417637e-32, -1.61921717e-02,
                        -2.80353753e-03, -9.56237037e-03,  6.21392741e-04,
                         1.69886999e-32, -5.51198050e-03]],
        
                      [[ 2.14794130e-32, -3.12967598e-02,  8.19312967e-03,
                         5.05629087e-33,  5.49085519e-32,  1.86354664e-05,
                         7.80786911e-04,  5.58499407e-32,  2.80271561e-32,
                         1.73810031e-03,  3.58531478e-32,  3.49926507e-32,
                         1.04123242e-02,  1.85995309e-31,  1.56224536e-32,
                         1.31919179e-02,  4.39095394e-32,  1.51995754e-32,
                         2.62341922e-32,  1.49283782e-02,  2.40249094e-03,
                         2.83300993e-04, -6.28831331e-03,  1.78228307e-32,
                         1.42217114e-31,  2.25022724e-32, -2.37629749e-03,
                        -1.07910694e-03, -1.59482490e-02,  3.43064289e-03,
                         1.28626133e-32, -2.32797489e-03]],
        
                      [[ 2.07728439e-32, -2.91103423e-02,  1.45034026e-02,
                        -3.24532060e-33,  5.58059125e-32,  1.12391754e-05,
                         1.00401649e-03,  5.72589705e-32,  2.70371518e-32,
                        -9.17699654e-03,  3.61638104e-32,  3.59247854e-32,
                         1.61960460e-02,  1.60538580e-31,  1.75453714e-32,
                         2.43985988e-02,  4.57183460e-32,  1.45188496e-32,
                         2.50762450e-32,  3.33447605e-02,  1.97158195e-04,
                        -6.06852351e-04, -8.44884012e-03,  1.81301152e-32,
                         1.15253365e-31,  2.36912820e-32,  2.08975747e-02,
                        -5.32232085e-03, -1.69059392e-02,  8.74940306e-04,
                         1.12411893e-32,  1.30783324e-03]]],
        
        
                     [[[ 2.25709359e-32, -3.60307768e-02, -3.18310980e-04,
                         9.70342007e-33,  5.63030937e-32, -1.84576649e-07,
                        -1.13286935e-02,  5.37584247e-32,  3.81144904e-32,
                         2.56528184e-02,  2.87330198e-32,  3.49251597e-32,
                        -7.08714360e-04,  2.12542050e-31,  1.65294563e-32,
                        -1.51337506e-02,  4.43293760e-32,  1.54038528e-32,
                         2.82186059e-32,  4.13856609e-03, -5.83265349e-03,
                        -4.28673578e-03, -1.63070876e-02,  1.84614444e-32,
                         1.46710723e-31,  2.27888300e-32,  1.17624905e-02,
                        -1.25518832e-02, -8.25445727e-03,  7.18970317e-03,
                         1.42115665e-32, -4.29227017e-03]],
        
                      [[ 2.46675550e-32, -3.57003249e-02,  1.42756884e-03,
                         7.30573045e-34,  5.52290328e-32,  1.56635906e-06,
                        -2.99157086e-03,  5.40685495e-32,  4.04269111e-32,
                        -1.01008974e-02,  2.96996200e-32,  3.44802233e-32,
                         2.98276218e-03,  1.94704935e-31,  1.65187299e-32,
                         5.14872419e-03,  4.45338092e-32,  1.59029442e-32,
                         2.62273449e-32,  2.28536502e-02, -4.95431339e-03,
                        -2.05106172e-03, -2.54069827e-02,  1.76643241e-32,
                         1.54683713e-31,  2.09739225e-32,  2.21123230e-02,
                        -1.40414303e-02, -2.22542509e-02,  9.24549624e-03,
                         1.00495628e-32, -4.27999394e-03]],
        
                      [[ 2.40194947e-32, -2.70901360e-02,  1.07428730e-02,
                        -2.35700576e-33,  5.61231667e-32,  2.43128616e-05,
                         3.80234444e-03,  5.58449977e-32,  4.49171056e-32,
                        -2.82089189e-02,  3.11002304e-32,  3.66501682e-32,
                         1.31476652e-02,  1.63966427e-31,  1.64783046e-32,
                         1.21509824e-02,  4.38118999e-32,  1.62723169e-32,
                         2.34609585e-32,  3.03830057e-02, -7.87956826e-03,
                        -8.16675834e-03, -3.03711779e-02,  1.87898481e-32,
                         1.21256262e-31,  2.07846547e-32,  3.18143852e-02,
                        -2.25854963e-02, -2.88177952e-02,  7.86592253e-03,
                         9.73208671e-33,  7.64041324e-04]]]], dtype=float32),
    },
    Conv_1: {
        bias: DeviceArray([ 7.2362374e-32,  4.3423921e-02, -3.2792657e-11,
                      1.1282176e-31,  6.0194356e-32,  4.3301480e-32,
                      6.0269047e-32,  1.6132905e-31,  2.1355933e-07,
                     -2.6829765e-08, -8.1259329e-03, -7.0907641e-05,
                      1.1180537e-31, -4.1502616e-03,  4.2984208e-32,
                     -2.5609383e-04,  6.3052618e-32,  2.0556299e-03,
                      2.9264822e-32,  6.3099637e-32, -6.3015718e-04,
                      2.3913330e-03, -8.2194638e-03, -2.4244960e-02,
                      6.2192921e-02,  3.1647896e-03,  7.6166262e-32,
                      1.6307342e-03, -3.6509786e-02,  5.4633804e-32,
                      2.0081217e-31, -5.1490813e-03, -4.5157220e-02,
                     -1.3541744e-22,  3.8570526e-32, -6.9861137e-03,
                      4.9329232e-02,  7.4679856e-32,  1.4692178e-02,
                     -1.5012160e-02,  1.3628905e-31,  2.3997447e-32,
                      3.7825742e-32, -1.5029975e-03,  3.3741928e-32,
                      1.0799532e-31,  8.4676847e-32, -2.6299991e-17,
                      8.2792630e-32, -6.8202093e-03,  2.6623955e-31,
                      1.2483622e-31, -1.2146547e-02,  4.7126203e-32,
                      6.9751960e-32,  5.5513839e-04,  2.8393987e-32,
                      3.7111582e-32,  3.3369472e-32,  1.0118496e-03,
                      3.1109426e-32,  5.2568501e-32,  9.3120513e-03,
                      1.0045990e-31], dtype=float32),
        kernel: DeviceArray([[[[ 5.16660514e-33,  8.72831647e-34,  7.22479537e-34, ...,
                         4.89801129e-33,  3.12169514e-34,  3.10468262e-34],
                       [ 1.39643552e-34,  3.59744672e-03, -9.61801888e-16, ...,
                         5.07812320e-34, -2.37903488e-03,  7.68304164e-34],
                       [ 9.49177709e-34,  6.28543831e-03,  6.11239568e-34, ...,
                         7.21326956e-34,  4.01173439e-03,  2.83446953e-33],
                       ...,
                       [ 1.66276431e-33,  9.74386465e-04, -5.28283743e-13, ...,
                         1.45966129e-33, -6.03575201e-04,  5.89998982e-33],
                       [ 9.98494362e-33,  2.25412962e-33,  1.58179589e-33, ...,
                         6.59922675e-33,  3.36227430e-34,  7.06095120e-34],
                       [ 4.49088837e-33,  9.06452886e-04,  2.54892652e-33, ...,
                         2.83779967e-33,  7.73492793e-04,  1.89520046e-32]],
        
                      [[ 5.06363000e-33,  3.67381910e-34,  1.10508426e-33, ...,
                         4.99033388e-33,  9.78429611e-35,  3.64036458e-34],
                       [ 2.28274345e-34,  8.75685085e-03, -1.32461576e-11, ...,
                         4.11173873e-34, -2.02565361e-03,  9.78549823e-34],
                       [ 1.13398067e-33,  2.62255245e-03, -8.75944384e-15, ...,
                         8.52597348e-34,  1.79070607e-03,  2.07517989e-33],
                       ...,
                       [ 1.30135582e-33,  1.33752718e-03, -1.18095055e-13, ...,
                         6.75296112e-34,  9.48218512e-04,  7.22490493e-33],
                       [ 1.07123917e-32,  3.70204153e-34,  2.59596374e-33, ...,
                         8.04174861e-33, -3.96826000e-35,  4.73984118e-34],
                       [ 4.73119174e-33,  5.65278810e-04,  1.91032452e-33, ...,
                         1.38174916e-33,  1.86420185e-03,  2.20431140e-32]],
        
                      [[ 6.78733010e-33,  6.09619727e-34,  1.57974925e-33, ...,
                         5.69251102e-33, -1.58041255e-34, -1.62897551e-34],
                       [ 3.09481352e-34,  6.34378754e-03, -1.60753702e-11, ...,
                         2.98022164e-34,  1.40427845e-03,  1.22019675e-33],
                       [ 1.30113964e-33,  1.61401310e-03, -7.35946355e-15, ...,
                         8.31812496e-34,  8.92492535e-04,  5.03098863e-34],
                       ...,
                       [ 1.36893168e-33,  6.20536646e-03, -8.92528344e-17, ...,
                         8.97137197e-34,  1.89593725e-03,  8.86259190e-33],
                       [ 1.34203971e-32,  7.12544222e-34,  3.27687968e-33, ...,
                         9.57936354e-33, -8.07845360e-34, -9.94568505e-34],
                       [ 5.05433698e-33,  1.18793109e-02,  1.85132352e-33, ...,
                         2.01688915e-33,  3.86982528e-03,  2.48990098e-32]]],
        
        
                     [[[ 7.34348806e-33,  2.87048374e-34,  3.03825598e-33, ...,
                         5.56492287e-33,  1.82396396e-34, -4.23220182e-35],
                       [ 2.90028229e-34,  1.84943870e-04, -1.37785764e-11, ...,
                         5.58362801e-34, -7.30571453e-04,  5.23686223e-34],
                       [ 7.33205050e-34,  9.70160030e-03, -8.29524708e-14, ...,
                         6.24957082e-34,  3.51056596e-03,  2.63996011e-33],
                       ...,
                       [ 2.32200873e-33, -4.47895116e-04,  3.31392111e-17, ...,
                         1.45270962e-33,  1.98447990e-04,  5.20445642e-33],
                       [ 1.13915350e-32,  7.46473673e-34,  2.96875029e-33, ...,
                         7.52809137e-33,  6.03624201e-34, -4.72055894e-34],
                       [ 4.32291537e-33, -1.58410243e-04,  5.69450652e-34, ...,
                         2.30635188e-33, -6.49258785e-04,  1.79712706e-32]],
        
                      [[ 5.76632840e-33, -3.70972288e-34,  3.04478181e-33, ...,
                         5.12343437e-33,  2.61829382e-34, -1.49048299e-34],
                       [ 2.66159678e-34, -1.94534205e-03, -6.39037998e-12, ...,
                         4.90082283e-34,  1.48366124e-03,  5.69582573e-34],
                       [ 9.47557455e-34,  4.49621258e-03,  5.64135772e-18, ...,
                         7.63514208e-34,  4.61676111e-03,  1.94015912e-33],
                       ...,
                       [ 1.79629790e-33, -2.04860524e-04, -9.52121975e-12, ...,
                         6.19866273e-34,  2.47474469e-04,  6.79570476e-33],
                       [ 1.00462449e-32, -8.50443990e-34,  3.84315315e-33, ...,
                         8.33716650e-33,  2.91574762e-34, -1.00513895e-33],
                       [ 4.83189635e-33,  4.72944404e-04, -3.97244114e-13, ...,
                         1.34665983e-33,  9.27922083e-05,  2.09745279e-32]],
        
                      [[ 5.71120543e-33,  7.03903971e-34,  2.39766409e-33, ...,
                         5.42370926e-33, -4.27295066e-35, -2.16374485e-34],
                       [ 2.47079729e-34, -3.30616091e-03,  1.62490119e-17, ...,
                         3.19398598e-34,  1.03440939e-03,  7.47274570e-34],
                       [ 1.20869894e-33,  2.08018185e-03,  4.63920968e-34, ...,
                         7.05805701e-34, -1.17018772e-03,  1.29156459e-33],
                       ...,
                       [ 1.42359538e-33,  9.67374071e-03, -7.16823779e-12, ...,
                         1.15566376e-33,  2.28069816e-03,  8.62177572e-33],
                       [ 1.08142130e-32,  1.73098210e-33,  3.28859238e-33, ...,
                         9.00591185e-33, -3.66020518e-34, -1.87891678e-33],
                       [ 5.07020799e-33,  2.20814347e-02, -3.98493218e-12, ...,
                         2.99776020e-33,  2.51931231e-03,  2.07802084e-32]]],
        
        
                     [[[ 1.08398241e-32,  3.59838060e-34,  5.04925921e-33, ...,
                         7.10047518e-33,  2.04986286e-34, -2.23284601e-34],
                       [ 5.76402140e-34,  3.96381249e-04, -3.77349959e-15, ...,
                         4.84187592e-34, -1.00021818e-04,  1.75028446e-34],
                       [ 7.54697863e-34,  5.08471113e-03, -5.45096216e-12, ...,
                         5.51226954e-34,  3.06925038e-03,  2.18107761e-33],
                       ...,
                       [ 3.18599891e-33, -1.92081032e-04,  1.39429472e-33, ...,
                         1.35244041e-33, -7.69458362e-04,  4.63508075e-33],
                       [ 1.56375029e-32,  8.65698417e-34,  6.64169002e-33, ...,
                         1.02317056e-32,  6.79582488e-34, -1.03247039e-33],
                       [ 4.54585375e-33,  8.55025777e-04,  6.81127402e-35, ...,
                         1.86550990e-33, -1.69464410e-03,  1.35792211e-32]],
        
                      [[ 9.02815808e-33,  3.11518200e-35,  4.66014376e-33, ...,
                         6.74649563e-33,  2.54814092e-34, -7.07404051e-34],
                       [ 6.06859749e-34,  1.55769975e-03, -1.98761862e-12, ...,
                         3.92863069e-34,  2.10208382e-04,  8.07363614e-35],
                       [ 6.70488340e-34,  3.02247703e-03, -1.01299732e-11, ...,
                         5.33448887e-34,  4.42312984e-03,  1.90072441e-33],
                       ...,
                       [ 2.93607155e-33, -2.65759998e-04, -1.31269214e-12, ...,
                         1.29199512e-33, -2.11198436e-04,  5.87791146e-33],
                       [ 1.35078642e-32,  9.24237852e-34,  6.32755238e-33, ...,
                         1.03447331e-32,  4.54565043e-34, -1.97623780e-33],
                       [ 4.87366937e-33,  6.66087773e-03,  6.36460029e-34, ...,
                         2.49053311e-33, -5.51473000e-04,  1.47460271e-32]],
        
                      [[ 7.18495722e-33,  4.60129907e-34,  3.42456255e-33, ...,
                         5.95905731e-33,  1.50355967e-34, -3.15595989e-34],
                       [ 4.98718126e-34,  1.07275392e-03, -1.60635377e-11, ...,
                         2.56292712e-34,  7.16920593e-04,  1.19419493e-34],
                       [ 8.68249423e-34,  8.35605664e-04, -2.09388154e-14, ...,
                         5.47690920e-34, -6.71895337e-04,  2.02213901e-33],
                       ...,
                       [ 2.29824592e-33,  1.02452794e-02, -2.75087193e-11, ...,
                         1.78385437e-33,  5.97893144e-04,  6.53469285e-33],
                       [ 1.17613691e-32,  3.24205346e-33,  4.50850169e-33, ...,
                         9.41841265e-33,  1.83255345e-34, -1.78775426e-33],
                       [ 4.71684851e-33,  2.44701840e-02, -1.20149220e-11, ...,
                         3.83458637e-33,  9.75530944e-04,  1.28060720e-32]]]],            dtype=float32),
    },
    Dense_0: {
        bias: DeviceArray([ 1.01703573e-02,  5.99808882e-05,  9.55130626e-03,
                      3.04929035e-05,  4.30238603e-32,  3.45537152e-33,
                      4.65148331e-33,  4.24136619e-33,  4.28053507e-13,
                     -1.02158193e-09,  1.66617376e-32,  3.65041080e-04,
                      1.92211444e-32,  8.36720876e-03,  1.82623025e-02,
                      5.14209853e-04, -5.23955578e-07, -6.05069450e-04,
                     -1.51621194e-07, -1.07350005e-02,  2.85517757e-13,
                      8.47246663e-33,  3.20250976e-32, -5.32590598e-03,
                     -1.27136163e-05,  6.66677952e-05,  8.19738097e-11,
                      4.38755544e-33,  4.53771970e-33,  5.67810936e-03,
                      2.25419080e-05,  2.52435599e-02, -6.20762631e-03,
                      8.71772063e-05,  4.13120936e-32,  7.38564056e-33,
                      3.39134736e-03,  3.38260218e-33, -4.01938194e-03,
                      1.32401430e-33,  2.10565559e-04, -1.89103861e-03,
                      1.17704665e-04,  2.21359295e-32, -4.83605685e-03,
                     -8.91372096e-03,  2.61325784e-33,  1.04736560e-03,
                      2.95243855e-03,  5.76212608e-18,  5.38529007e-33,
                      1.03429705e-03,  5.72646875e-03, -1.15213413e-02,
                      3.88636506e-32,  6.48714509e-03,  6.85124980e-33,
                      9.89897922e-03,  2.21768250e-33,  1.35594066e-02,
                     -3.12413205e-03, -1.39312877e-03,  3.15260054e-33,
                     -7.89280330e-06,  5.30344443e-33,  1.65717109e-04,
                      4.19745791e-28,  3.91075803e-32, -1.66994585e-06,
                     -1.14290295e-02,  2.81406901e-33,  3.48943705e-32,
                      5.36308903e-03, -4.63112304e-03,  1.84506757e-06,
                      9.62018083e-19, -2.44257133e-03, -1.15036684e-07,
                     -6.38324228e-17,  3.72263230e-03,  4.03445665e-06,
                     -2.34187595e-04, -1.35071296e-02,  2.11333489e-32,
                      6.77898875e-04,  1.54639599e-06,  3.55283525e-10,
                      1.37949901e-33,  5.01221314e-33, -8.83733388e-04,
                      2.03620214e-02,  8.80872892e-12, -6.14980189e-03,
                      7.34670387e-18, -2.85355048e-03, -3.69201833e-03,
                      4.41327831e-03, -3.14070189e-06, -4.89034133e-15,
                      6.89034117e-03, -4.21927171e-03,  9.51401781e-33,
                      8.96642258e-33, -7.11248394e-09,  1.82399871e-32,
                     -1.90336607e-07,  5.54403176e-33,  4.04103595e-08,
                      1.62419602e-02, -9.78302793e-04,  3.04815530e-32,
                     -3.44636175e-03,  2.75004956e-33, -4.15745674e-34,
                     -2.10962514e-03,  8.06986864e-33,  3.58684763e-32,
                      2.43059406e-32,  3.35664861e-03,  2.93033991e-33,
                      1.75595493e-32, -6.21314044e-04,  9.43966654e-07,
                      3.50578172e-32, -1.02336562e-04,  2.22073452e-32,
                      2.38420413e-33,  1.90696795e-03,  2.14496936e-32,
                      9.12412544e-33, -1.13130827e-03,  3.34614165e-32,
                      4.74392902e-04, -1.40524411e-03,  3.06402771e-32,
                     -2.79411888e-05,  4.21874380e-33,  9.82114021e-03,
                     -8.60426668e-03,  3.24725789e-32,  3.84135126e-32,
                     -3.33187683e-03,  3.66587192e-05,  2.82166402e-05,
                     -4.56465903e-04, -2.05339916e-06,  5.40548936e-05,
                      2.67100558e-32,  3.16858763e-33, -6.20762818e-03,
                      4.85221835e-06,  1.64141344e-32, -4.85267909e-03,
                      6.79855142e-03,  9.77414427e-04, -7.07316212e-03,
                      3.19159916e-03,  3.10355561e-33,  5.77431735e-33,
                      4.38001721e-33,  2.20132782e-03,  5.20531343e-33,
                      4.12327126e-02,  2.65931382e-32,  2.24305114e-32,
                      1.58009247e-03,  3.45818903e-33,  3.86114835e-32,
                      1.02751923e-32,  3.02830432e-03,  3.27801146e-03,
                      5.86235930e-33,  1.62855210e-03,  1.08323880e-02,
                     -5.37389347e-14, -5.78694325e-03, -8.31814532e-07,
                     -1.04779648e-02,  1.55105768e-02,  5.20596048e-04,
                      3.05511894e-10, -2.49014391e-07,  5.58388310e-05,
                     -1.03518541e-03, -8.67732987e-03, -3.42808873e-03,
                      1.13862995e-20,  2.07039600e-33,  4.46842657e-03,
                      3.77023127e-04,  1.68904825e-03,  2.23367033e-03,
                      5.34148806e-32, -2.65359848e-12,  6.37734044e-33,
                     -7.31587829e-03, -1.02802981e-02,  2.46214345e-32,
                     -1.97132118e-03,  3.76066416e-32,  7.63285847e-04,
                     -5.66401333e-03, -1.09764130e-07,  1.47738481e-32,
                      1.60278417e-02,  5.28343860e-03,  3.16928081e-33,
                      1.94687687e-03, -6.56406485e-09, -1.18566602e-02,
                      5.69459444e-05,  4.66139750e-33,  6.43157922e-33,
                      5.06975079e-32,  1.02619117e-03, -7.57842354e-05,
                      1.94587426e-32,  9.79289828e-28,  6.32853906e-33,
                      4.20722208e-05,  5.41840301e-33,  7.61336414e-03,
                      1.53141876e-32,  3.62121733e-03, -3.66104953e-03,
                     -2.70130066e-03,  3.54149742e-33,  6.15753696e-33,
                      4.80689380e-32,  3.58464336e-03,  1.29225703e-33,
                      2.11454288e-04,  5.62761503e-04,  2.62150913e-03,
                      2.90290387e-07, -5.36223035e-03, -1.00542326e-03,
                     -4.23786882e-03,  2.73819272e-32,  2.16575479e-03,
                     -4.99137957e-03,  1.24744615e-02,  5.64042293e-03,
                     -1.88381250e-06, -8.50463845e-03, -4.62390699e-06,
                     -1.03131495e-02, -3.85973230e-03,  4.87079933e-33,
                     -1.16005521e-02, -1.96431556e-05, -1.30271428e-06,
                      2.77548969e-32,  4.08529543e-33,  2.23131049e-20,
                      3.95821175e-03], dtype=float32),
        kernel: DeviceArray([[ 1.95309814e-36,  2.37601226e-36,  3.58959269e-35, ...,
                      -5.65117905e-36,  2.17434040e-35, -3.45978788e-36],
                     [-6.11466695e-30,  6.27598902e-32, -4.07856479e-31, ...,
                       1.86428109e-35,  3.43079190e-35,  1.45839418e-30],
                     [-4.28965755e-36,  8.97212395e-38,  2.03908364e-35, ...,
                       2.17002671e-35,  2.77544835e-35, -4.58905676e-35],
                     ...,
                     [-1.95994771e-35,  2.87401391e-36,  9.79940534e-35, ...,
                       1.23964167e-34,  9.46841586e-35, -1.52746847e-34],
                     [-1.13898986e-34,  5.73729807e-35,  8.55015836e-34, ...,
                       1.02545111e-34,  5.77545916e-35, -3.89402478e-34],
                     [-8.78509836e-35,  3.60062099e-35,  2.04327238e-33, ...,
                       9.08154621e-35,  4.91905585e-35,  1.01528825e-33]],            dtype=float32),
    },
    Dense_1: {
        bias: DeviceArray([ 0.0238293 ,  0.00255279,  0.00165561, -0.00595134,
                      0.00641471, -0.00226816, -0.00470199, -0.01119075,
                      0.00604223, -0.01638225], dtype=float32),
        kernel: DeviceArray([[ 5.1968065e-03, -3.5570948e-03, -4.8827650e-03, ...,
                       1.3550627e-03,  1.1933317e-02, -5.2302587e-03],
                     [ 2.6878531e-04,  9.1034671e-08, -1.9705308e-06, ...,
                       9.9719164e-06,  1.4086970e-06, -1.3356811e-05],
                     [ 1.1172998e-02, -2.7297743e-04, -6.4533832e-03, ...,
                      -1.3111686e-04,  1.6451434e-03, -7.6837902e-04],
                     ...,
                     [-2.1089845e-32, -3.5258281e-33, -2.4732194e-33, ...,
                       5.6005101e-34, -2.5262963e-32, -3.9286369e-33],
                     [ 7.7069369e-26,  5.9369065e-25, -9.4798079e-22, ...,
                       9.2960429e-22,  5.8415279e-26,  7.3898588e-24],
                     [ 1.8079231e-04,  7.8052897e-03,  1.1503263e-03, ...,
                      -6.9765705e-03, -2.8172573e-03, -7.5913011e-04]],            dtype=float32),
    },
}).

In [346]:
param =jax.tree_util.tree_map(np.array, state.params)

In [44]:
!pip install flatdict



In [45]:
import flatdict

In [66]:
d = flatdict.FlatDict(param, delimiter='.')

In [96]:
'.'.join(filter(None, ['', 'Conv_0', 'kernel', '', 'Lol']))

'Conv_0.kernel.Lol'

In [350]:
weight_dict = dict()

In [351]:
alias = 'param'
delim = '.'

In [352]:
param = state.params

In [353]:
for layer_name, param_obj in param.items():
    for param_name, value in param_obj.items():
        key = delim.join([alias, layer_name, param_name])
        weight_dict[key] = value

In [354]:
for k, v in weight_dict.items():
    print(type(k), type(v))
    break

<class 'str'> <class 'jaxlib.xla_extension.DeviceArray'>


In [355]:
model = state

In [356]:
def _get_weights_dict(obj, prefix='', suffix=''):
    """
    Get the dictionary of weights.

    Parameters
    ----------
    obj : Model or Optimizer
        The target object that we want to get the weights.

    Returns
    -------
    dict
        The weight dictionary.
    """
    weights_dict = dict()
    delim = '.'
    for layer_name, param_obj in obj.items():
        for param_name, value in param_obj.items():
            key = delim.join(filter(None, [prefix, layer_name, param_name, suffix]))
            weights_dict[key] = value

    return weights_dict

In [362]:
model_params = jax.tree_util.tree_map(np.array, model.params)
model_opt_state = jax.tree_util.tree_map(np.array, model.opt_state)[0][0]

In [364]:
opt_dict = _get_weights_dict(model_opt_state, 'opt', suffix)

In [366]:
suffix = ''

In [367]:
model_params = jax.tree_util.tree_map(np.array, model.params)
params_dict = _get_weights_dict(model_params, 'param', suffix)

if model.opt_state is not None:
    model_opt_state = jax.tree_util.tree_map(np.array, model.opt_state)[0][0]
    opt_dict = _get_weights_dict(model_opt_state, 'opt', suffix)
    params_dict.update(opt_dict)

In [368]:
for k, v in params_dict.items():
    print(k)

param.Conv_0.bias
param.Conv_0.kernel
param.Conv_1.bias
param.Conv_1.kernel
param.Dense_0.bias
param.Dense_0.kernel
param.Dense_1.bias
param.Dense_1.kernel
opt.Conv_0.bias
opt.Conv_0.kernel
opt.Conv_1.bias
opt.Conv_1.kernel
opt.Dense_0.bias
opt.Dense_0.kernel
opt.Dense_1.bias
opt.Dense_1.kernel


In [369]:
new_model = state

In [370]:
from flax.core.frozen_dict import freeze, unfreeze

In [250]:
{'bias': new_model.params['Conv_0']['bias'].at[0].set(1.10101)}

{'bias': DeviceArray([ 1.10101   , -0.09015119, -0.13249867, -1.272739  ,
              -2.9447565 , -0.42616746,  0.52555424, -6.1838517 ,
              -1.3973562 , -0.52935064, -3.4368992 , -3.6981347 ,
              -0.25468838, -4.1671534 , -2.1030016 , -0.14081043,
              -4.5040197 , -1.5975306 , -1.5151801 , -2.4758625 ,
              -0.09489761,  0.02707327, -0.52436036, -2.1598485 ,
              -3.6401956 , -2.0718012 , -1.2368054 , -0.13675065,
              -0.36861733, -0.40148488, -1.0966164 , -0.83448386],            dtype=float32)}

In [271]:
grads

NameError: name 'grads' is not defined

In [259]:
params_ = jax.tree_util.tree_map(np.array, model.params)

In [266]:
params_ = unfreeze(params_)

In [267]:
params_['Conv_0']['bias']= 1.10101

In [270]:
model.params = freeze(params_)

FrozenInstanceError: cannot assign to field 'params'

In [235]:
new_model.params['Conv_0']['bias'] = unfreeze(new_model.params['Conv_0']['bias'].at[0].set(1.10101)

ValueError: FrozenDict is immutable.

In [375]:
new_model

TrainState(step=0, apply_fn=<bound method Module.apply of CNN()>, params={'Conv_0': {'bias': DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0.], dtype=float32), 'kernel': DeviceArray([[[[ 3.32822055e-01, -2.82645106e-01, -1.80686682e-01,
                -6.79024830e-02, -7.08380640e-01,  2.77156025e-01,
                 2.83575952e-01,  5.28992653e-01, -3.09524149e-01,
                 5.76644480e-01,  3.28146875e-01, -7.18251020e-02,
                -4.51019973e-01, -5.43700039e-01, -3.07050906e-02,
                -4.92292970e-01,  7.26195991e-01,  1.11178890e-01,
                 3.07978690e-01, -7.01449215e-01, -7.72421658e-02,
                -4.72714245e-01,  1.41713738e-01, -1.60780087e-01,
                -1.76688954e-01,  1.37713805e-01, -5.79014122e-02,
                 1.17106706e-01,  3.02598029e-01,  2.02198133e-01,
                 3.30374956e-01,  7.1303147

In [380]:
params_dict = jax.tree_util.tree_map(jnp.array, params_dict)

In [381]:
params_dict

{'opt.Conv_0.bias': DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0.], dtype=float32),
 'opt.Conv_0.kernel': DeviceArray([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                 0., 0., 0., 0.]],
 
               [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                 0., 0., 0., 0.]],
 
               [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                 0., 0., 0., 0.]]],
 
 
              [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                 0., 0., 0., 0.]],
 
               [[0., 0., 0., 0., 0., 0

In [389]:
ja = jnp.array([0.80, 0.80, 0.88, 0.887, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0.])

In [390]:
params_dict['opt.Conv_0.bias'] = ja

In [397]:
prefix = 'param'
delim = '.'
for layer_name, param_obj in model_params.items():
    for param_name, value in param_obj.items():
        key = delim.join(filter(None, [prefix, layer_name, param_name]))
        if key in params_dict:
            model_params[layer_name][param_name] = params_dict[key]

if model.opt_state is not None:
    prefix = 'opt'
    for layer_name, param_obj in model_opt_state.items():
        for param_name, value in param_obj.items():
            key = delim.join(filter(None, [prefix, layer_name, param_name]))
            if key in params_dict:
                model_opt_state[layer_name][param_name] = params_dict[key]

# new_model.params = jax.tree_util.tree_map(jnp.array, model_params)
new_model.opt_state[0][0] = freeze(jax.tree_util.tree_map(jnp.array, model_opt_state))

TypeError: 'TraceState' object does not support item assignment

In [None]:
optstate1 = jax.tree_util.tree_map(np.array, state.opt_state)

In [None]:
jax.tree_util.tree_map(jnp.array, optstate1)

In [None]:
for i in state.params:
    print(i)

In [None]:
model_weight_names = [weight for weight in state.params]

In [None]:
import copy

In [None]:
tensor_dict = copy.deepcopy(state.params)

In [None]:
model_weight_names

In [None]:
model_weights_dict = { name: tensor_dict[name] for name in model_weight_names }

In [None]:
for name in model_weight_names:
    print(name)

In [None]:
state1 = jax.tree_map(np.array, state)

In [None]:
# Find all mistakes in testset.
logits = CNN().apply({'params': state.params}, test_ds['image'])


In [None]:
error_idxs, = jnp.where(test_ds['label'] != logits.argmax(axis=1))
len(error_idxs) / len(logits)

In [None]:
show_img_grid(
    [test_ds['image'][idx] for idx in error_idxs[:25]],
    [f'pred={logits[idx].argmax()}' for idx in error_idxs[:25]],
)

In [None]:
# class LinearRegression:
#     def __init__(self, n_feat: int) -> None:
#         self.weights = jnp.ones(n_feat)
    
#     def mse(self, X, y) -> float:
#         return mse_loss_function1(self.weights, X, y)
 
#     def predict(self, X):
#         return jnp.dot(X, self.weights)
    
#     def fit(self, X, Y, n_epochs : int, learning_rate : int, silent : bool) -> None:
        
#         # Speed up weight updates with consecutive calls to jitted `update` function.
#         update_weights = jax.jit(update)
        
#         start_time = time.time()
#         print('Training Loss at start: ', self.mse(X, Y))
#         for i in range(n_epochs):
#             self.weights = update_weights(self.weights, X, Y, learning_rate)
#             if i % int(n_epochs/10) == 0 and not silent:
#                 print(str(i), 'Training Loss: ', self.mse(X, Y))

#         print("--- %s seconds ---" % (time.time() - start_time))

    

In [None]:
_, init_params = CNN().init_by_shape(random.PRNGKey(0), [((1, 32, 32, 3), jnp.float32)])
model = nn.Model(CNN, init_params)
optimizer = optim.Adam(learning_rate=learning_rate, beta1=beta, beta2 = beta_2).create(model)

#Jit Pre-compilation

In [None]:
train_step(optimizer, train_ds)

In [None]:
eval_model(optimizer.target, test_ds)

In [None]:
def train(train_ds, test_ds, model, optimizer):

    batch_size = 128
    num_epochs = 10
    learning_rate = 0.001
    beta = 0.9
    beta_2 = 0.999
    loss = 0
    accuracy = 0
    
    start_time = time.monotonic()
    
    for epoch in range(1, num_epochs + 1):
        train_time = 0
        start_time_3 = time.monotonic()
        batch_gen = tfds.as_numpy(train_ds)
        for batch in batch_gen:
            start_time_step = time.monotonic()
            optimizer = train_step(optimizer, batch)
            train_time += time.monotonic() - start_time_step
            
        flax_step = time.monotonic() - start_time_3
        
        start_time_2 = time.monotonic()
        loss, accuracy = eval_model(optimizer.target, test_ds)
        flax_inf = time.monotonic() - start_time_2
        
        print('eval epoch: %d, epoch: %.2fs, actual_training: %.2fs, validation: %.2fs, loss: %.4f, accuracy: %.2f' % 
              (epoch, flax_step, train_time, flax_inf, loss, accuracy * 100))
        
    flax_time = time.monotonic() - start_time
    return optimizer, flax_time, accuracy, flax_inf

In [None]:
_, flax_time, flax_acc, flax_inf = train(train_ds, test_ds, model, optimizer)

# JAX Linear Regression with federation

## Connect to a Federation

In [None]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# please use the same identificator that was used in signed certificate
client_id = 'frontend'
director_node_fqdn = 'localhost'
director_port = 50050

federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False
)

In [None]:
shard_registry = federation.get_shard_registry()
shard_registry

### Initialize Data Interface

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

class LinearRegressionDataSet(DataInterface):
    def __init__(self, **kwargs):
        """Initialize DataLoader."""
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """Return shard descriptor."""
        return self._shard_descriptor
    
    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.
        
        This method will be called during a collaborator initialization.
        Local shard_descriptor  will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor
        self.train_set = shard_descriptor.get_dataset('train')
        self.val_set = shard_descriptor.get_dataset('val')
        
    def get_train_loader(self, **kwargs):
        """Output of this method will be provided to tasks with optimizer in contract."""
        return self.train_set

    def get_valid_loader(self, **kwargs):
        """Output of this method will be provided to tasks without optimizer in contract."""
        return self.val_set

    def get_train_data_size(self):
        """Information for aggregation."""
        return len(self.train_set)

    def get_valid_data_size(self):
        """Information for aggregation."""
        return len(self.val_set)
    
lin_reg_dataset = LinearRegressionDataSet()

### Define Model Interface

In [None]:
framework_adapter = 'custom_adapter.CustomFrameworkAdapter'

# LinearRegression class accepts a parameter n_features. Should be same as `sample_shape` from `director_config.yaml`
fed_model = LinearRegression(1)
MI = ModelInterface(model=fed_model, optimizer=None, framework_plugin=framework_adapter)

# Save the initial model state
initial_model = LinearRegression(1)

### Register Tasks
We need to employ a trick reporting metrics. OpenFL decides which model is the best based on an *increasing* metric.

In [None]:
TI = TaskInterface()

@TI.add_kwargs(**{'lr': 0.01,
                   'epochs': 101})
@TI.register_fl_task(model='my_model', data_loader='train_data', \
                     device='device', optimizer='optimizer')     
def train(my_model, train_data, optimizer, device, lr, epochs):
    X, Y = train_data[:,:-1], train_data[:,-1]
    my_model.fit(X, Y, epochs, lr, silent=False)
    return {'train_MSE': my_model.mse(X, Y),}

@TI.register_fl_task(model='my_model', data_loader='val_data', device='device')
def validate(my_model, val_data, device):
    X, Y = val_data[:,:-1], val_data[:,-1] 
    return {'validation_MSE': my_model.mse(X, Y),}

### Run the federation

In [None]:
experiment_name = 'jax_linear_regression_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
fl_experiment.start(model_provider=MI,
                    task_keeper=TI,
                    data_loader=lin_reg_dataset,
                    rounds_to_train=2)

In [None]:
fl_experiment.stream_metrics()

# JAX Linear Regression without federation (Optional Simulation)

In [None]:
!pip install matplotlib scikit-learn -q

In [None]:
# Imports for running JAX Linear Regression example without OpenFL.

import matplotlib.pyplot as plt

%matplotlib inline
from matplotlib.pylab import rcParams
rcParams['figure.figsize'] = 7, 5

from jax import make_jaxpr
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

#### Simple Linear Regression
<img src="https://www.analyticsvidhya.com/wp-content/uploads/2016/01/eq5-1.png" width="500">



In [None]:
# create a dataset with n_features
X, y = make_regression(n_samples=100, n_features=1, noise=10, random_state=42)

# Train test split - Default 0.75/0.25
X, X_test, y, y_test = train_test_split(X, y, random_state=42)

Visualize data distribution

In [None]:
_ = plt.scatter(X, y)

In [None]:
_ = plt.scatter(X_test, y_test)

In [None]:
# JAX logical execution plan
print(jax.make_jaxpr(update)(jnp.ones(X.shape[1]), X, y, 0.01))

In [None]:
# X.shape -> (n_samples, n_features)

lr_model = LinearRegression(X.shape[1])
lr = 0.01
epochs = 101

print(f"Initial Test MSE: {lr_model.mse(X_test,y_test)}")

# silent: logging verbosity
lr_model.fit(X,y, epochs, lr, silent=False)

print(f"Final Test MSE: {lr_model.mse(X_test,y_test)}")

print(f"Final parameters: {lr_model.weights}")