# 🔄 Serket + Keras 

In this example, a simple `serket` model is wrapped as a `keras` layer and trained with the new multi-backend `keras` (Keras 3) using the `jax` backend.

In [1]:
!pip install git+https://github.com/ASEM000/serket --quiet
!pip install keras --quiet

### Imports

In [1]:
import os

os.environ["KERAS_BACKEND"] = "jax"
import jax.numpy as jnp
import jax.tree_util as jtu
import jax.random as jr
import keras
import serket as sk
import jax
import matplotlib.pyplot as plt

### Conversion layer

In [2]:
# define a helper predicate and a keras layer that wraps the serket layer
def is_trainable(leaf):
    """Return ``True`` only for jax arrays whose dtype is inexact (floating or complex)."""
    # the dtype is inspected only once `leaf` is known to be a jax array
    return isinstance(leaf, jax.Array) and bool(jnp.issubdtype(leaf.dtype, jnp.inexact))


class Serket2Keras(keras.Layer):
    """Keras layer that wraps a serket layer so it can be trained by keras.

    The leaves of the masked serket layer are stored as keras variables; the
    serket layer is rebuilt from the current variable values whenever the
    layer is called or ``model`` is accessed.

    Args:
        layer: the serket layer to wrap.
        name: optional name forwarded to ``keras.Layer``.
    """

    def __init__(self, layer, name=None):
        """Converts a serket layer to a keras layer"""
        super().__init__(name=name)
        # flatten the masked layer: `leaves` become keras variables and
        # `treedef` keeps the structure needed to rebuild the layer later
        leaves, treedef = jtu.tree_flatten(sk.tree_mask(layer))
        self.treedef = treedef
        for leaf in leaves:
            # NOTE(review): `model` rebuilds the layer from `trainable_variables`
            # only, so every leaf here is assumed to pass `is_trainable` — confirm
            variable = keras.Variable(initializer=leaf, trainable=is_trainable(leaf))
            self._track_variable(variable)
        # mark the layer as built
        self.built = True

    def call(self, x):
        """Applies the wrapped serket layer to the input"""
        # rebuild the serket layer from the current variables, then apply it
        return self.model(x)

    @property
    def model(self):
        """The unmasked serket layer rebuilt from the current keras variables"""
        # convert the keras variables to jax arrays to be used in serket;
        # `jtu.tree_map` is used because the top-level `jax.tree_map` alias
        # is deprecated/removed in newer JAX releases
        leaves = jtu.tree_map(jnp.array, self.trainable_variables)
        # unflatten the layer with the updated leaves, then undo the masking
        layer = jtu.tree_unflatten(self.treedef, leaves)
        return sk.tree_unmask(layer)

### Define a simple `serket` layer

In [3]:
# let's define a simple model in serket
class Linear(sk.TreeClass):
    """Affine map ``x @ weight + bias`` with normally-initialised parameters."""

    def __init__(self, in_features, out_features, *, key):
        """Draws ``weight`` and ``bias`` from ``jr.normal`` using two sub-keys of ``key``."""
        weight_key, bias_key = jr.split(key)
        # attribute assignment order is kept as-is
        self.in_features = in_features
        self.out_features = out_features
        self.weight = jr.normal(weight_key, shape=(in_features, out_features))
        self.bias = jr.normal(bias_key, shape=(out_features,))

    def __call__(self, x):
        """Returns ``x @ weight + bias``."""
        projected = x @ self.weight
        return projected + self.bias

### Train in `keras`

In [4]:
# serket part of the network: Linear(1 -> 20), tanh, Linear(20 -> 15)
sk_model = sk.Sequential(
    Linear(1, 20, key=jr.PRNGKey(0)),
    jax.nn.tanh,
    Linear(20, 15, key=jr.PRNGKey(1)),
)

# wrap the serket model as a keras layer and put a keras Dense head after it
serket_layer = Serket2Keras(sk_model, name="serket")
model = keras.Sequential([serket_layer, keras.layers.Dense(1)])

model.compile(
    optimizer=keras.optimizers.Adam(1e-2),
    loss=keras.losses.MeanSquaredError(),
)

# toy regression data: y = x**2 plus small gaussian noise
x = jnp.linspace(-1, 1, 100)[:, None]
noise = jr.normal(jr.PRNGKey(0), (100, 1)) * 0.01
y = x**2 + noise
model.fit(x, y, epochs=100)

Epoch 1/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 51ms/step - loss: 11.4385
Epoch 2/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 451us/step - loss: 2.3039
Epoch 3/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 416us/step - loss: 2.5281
Epoch 4/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 367us/step - loss: 0.8995
Epoch 5/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 473us/step - loss: 1.5571
Epoch 6/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 486us/step - loss: 0.7577
Epoch 7/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 398us/step - loss: 0.2899
Epoch 8/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 344us/step - loss: 0.4642
Epoch 9/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 341us/step - loss: 0.2811
Epoch 10/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 391us/step - loss: 0.2261

<keras.src.callbacks.history.History at 0x13d5f74a0>

### Extract trained layer

In [5]:
# rebuild the serket layer from the keras wrapper (the layer named "serket")
model.get_layer("serket").model

Sequential(
  layers=(
    Linear(
      in_features=1, 
      out_features=20, 
      weight=f32[1,20](μ=-0.25, σ=0.94, ∈[-1.76,1.72]), 
      bias=f32[20](μ=-0.05, σ=0.91, ∈[-2.14,1.75])
    ), 
    jit(tanh(x)), 
    Linear(
      in_features=20, 
      out_features=15, 
      weight=f32[20,15](μ=0.01, σ=0.94, ∈[-2.31,2.37]), 
      bias=f32[15](μ=0.09, σ=1.16, ∈[-1.60,2.25])
    )
  )
)