# 🔄 Serket + Keras 

In this example, a simple `serket` model is wrapped as a `keras` layer and trained with the new multi-backend `keras` (Keras 3) using the `jax` backend.

In [1]:
!pip install git+https://github.com/ASEM000/serket --quiet
!pip install keras --quiet

## Imports

In [1]:
import os

os.environ["KERAS_BACKEND"] = "jax"
import jax.numpy as jnp
import jax.tree_util as jtu
import jax.random as jr
import keras
import serket as sk
import jax

## `serket` -> `keras` conversion

In [2]:
# define a keras layer that wraps the serket layer
def is_trainable(leaf):
    """Return True only for jax arrays whose dtype is a sub-dtype of ``jnp.inexact``.

    Anything that is not a ``jax.Array`` (python scalars, strings, ...) and any
    array with another dtype yields False.
    """
    return isinstance(leaf, jax.Array) and bool(
        jnp.issubdtype(leaf.dtype, jnp.inexact)
    )


def serket_to_keras(layer: sk.TreeClass) -> keras.Layer:
    """Wrap a serket layer in a keras layer backed by keras variables.

    Args:
        layer: the serket layer to convert.

    Returns:
        A keras layer (already marked as built) that tracks one
        ``keras.Variable`` per leaf of the masked ``layer``. Calling it applies
        the serket layer rebuilt from the current variable values; its
        ``model`` property returns that rebuilt serket layer.
    """
    # extract the leaves from the masked serket layer
    # here leaves of a masked layer are the trainable variables
    # and treedef is the tree structure of the layer
    leaves, treedef = jtu.tree_flatten(sk.tree_mask(layer))

    class SerketToKeras(keras.Layer):
        def __init__(self, layer, name=None):
            """Converts a serket layer to a keras layer"""
            super().__init__(name=name)
            # register one keras variable per leaf so keras can see and update it
            for leaf in leaves:
                variable = keras.Variable(
                    initializer=leaf, trainable=is_trainable(leaf)
                )
                self._track_variable(variable)
            # mark the layer as built
            self.built = True

        @property
        def model(self):
            """The serket layer rebuilt from the current keras variable values."""
            # NOTE(review): only `trainable_variables` are read back, so this
            # assumes every leaf was registered as trainable; a non-trainable
            # leaf would make the leaf count differ from `treedef` - confirm.
            # convert the keras variables to jax arrays to be used in serket
            values = jtu.tree_map(jnp.array, self.trainable_variables)
            # unflatten the layer with the updated leaves, then unmask it
            return sk.tree_unmask(jtu.tree_unflatten(treedef, values))

        def call(self, x):
            """Applies the layer to the input"""
            # reuse the same rebuild path as `model` so the two cannot diverge
            return self.model(x)

    return SerketToKeras(layer)

### Define a simple `serket` layer

In [3]:
# lets define a simple model in serket
class Linear(sk.TreeClass):
    """Affine layer computing ``x @ weight + bias``."""

    def __init__(self, in_features, out_features, *, key):
        """Draw ``weight`` and ``bias`` from a normal distribution.

        Args:
            in_features: size of the input feature axis.
            out_features: size of the output feature axis.
            key: jax random key; split in two, one half for each parameter.
        """
        weight_key, bias_key = jr.split(key)
        self.in_features = in_features
        self.out_features = out_features
        weight_shape = (in_features, out_features)
        self.weight = jr.normal(weight_key, weight_shape)
        bias_shape = (out_features,)
        self.bias = jr.normal(bias_key, bias_shape)

    def __call__(self, x):
        """Apply the affine map to ``x``."""
        projected = x @ self.weight
        return projected + self.bias

### Train in `keras`

In [4]:
# serket network: two tanh-activated linear layers (1 -> 20 -> 20 features)
sk_model = sk.Sequential(
    Linear(1, 20, key=jr.key(0)),
    jax.nn.tanh,
    Linear(20, 20, key=jr.key(1)),
    jax.nn.tanh,
)

# use serket with keras model: the wrapped serket network followed by a keras Dense head
serket_block = serket_to_keras(sk_model)
output_head = keras.layers.Dense(1)
model = keras.Sequential([serket_block, output_head])

adam = keras.optimizers.Adam(learning_rate=1e-2)
mse = keras.losses.MeanSquaredError()
model.compile(optimizer=adam, loss=mse)

# toy regression data: y = x^2 plus small gaussian noise on 100 points in [-1, 1]
x = jnp.linspace(-1, 1, 100)[:, None]
noise = jr.normal(jr.key(0), (100, 1)) * 0.01
y = x**2 + noise
model.fit(x, y, epochs=100)

Epoch 1/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step - loss: 1.3688
Epoch 2/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 482us/step - loss: 0.0929
Epoch 3/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 395us/step - loss: 0.2363
Epoch 4/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 389us/step - loss: 0.1136
Epoch 5/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 587us/step - loss: 0.0141
Epoch 6/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 378us/step - loss: 0.0579
Epoch 7/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 468us/step - loss: 0.0498
Epoch 8/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 380us/step - loss: 0.0191
Epoch 9/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 505us/step - loss: 0.0077
Epoch 10/100
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 356us/step - loss: 0.0138


<keras.src.callbacks.history.History at 0x15db37ad0>

### Extract trained layer

In [5]:
# the first keras layer is the serket wrapper; its `model` property
# rebuilds the serket layer from the current keras variable values
serket_wrapper = model.layers[0]
serket_wrapper.model

Sequential(
  layers=(
    Linear(
      in_features=1, 
      out_features=20, 
      weight=f32[1,20](μ=-0.19, σ=0.92, ∈[-1.65,1.91]), 
      bias=f32[20](μ=-0.03, σ=0.95, ∈[-2.12,1.93])
    ), 
    jit(tanh(x)), 
    Linear(
      in_features=20, 
      out_features=20, 
      weight=f32[20,20](μ=-0.04, σ=0.96, ∈[-2.73,2.65]), 
      bias=f32[20](μ=0.37, σ=0.82, ∈[-1.04,1.99])
    ), 
    jit(tanh(x))
  )
)