# 👀 Layers overview

In [1]:
!pip install git+https://github.com/ASEM000/serket --quiet

## `serket` general design features

### Handling weight initialization



Layers that contain `weight_init` or `bias_init` can accept:

- A string: 
    - `he_normal`
    - `he_uniform`
    - `glorot_normal`
    - `glorot_uniform`
    - `lecun_normal`
    - `lecun_uniform`
    - `normal`
    - `uniform`
    - `ones`
    - `zeros`
    - `xavier_normal`
    - `xavier_uniform`
    - `orthogonal`
- A function with the signature `(key: jax.Array, shape: tuple[int, ...], dtype) -> jax.Array`.
- `None` to indicate no initialization (e.g., no bias for layers that have a `bias_init` argument).
- A string registered with `sk.def_init_entry("my_init", ....)`, which maps to a custom init function.
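The three kinds of values above (string, callable, `None`) can be understood as a small name-to-function lookup. The sketch below is a hypothetical plain-Python model of that idea, not serket's code: `INIT_REGISTRY`, `register_init`, `resolve_init`, and `init_param` are illustrative names, and values come back as flat lists instead of arrays to keep it dependency-free.

```python
import math

# Hypothetical stand-in for the name -> function table behind `sk.def_init_entry`.
INIT_REGISTRY = {}


def register_init(name, func):
    """Map a string name to an init function with signature (key, shape, dtype)."""
    INIT_REGISTRY[name] = func


def zeros(key, shape, dtype=float):
    # `key` is unused by deterministic initializers; values are a flat list for simplicity
    return [dtype(0) for _ in range(math.prod(shape))]


def ones(key, shape, dtype=float):
    return [dtype(1) for _ in range(math.prod(shape))]


def resolve_init(spec):
    """Turn a weight_init/bias_init value (str, callable or None) into a function or None."""
    if spec is None:
        return None  # in this sketch the parameter is simply not created (e.g. no bias)
    if isinstance(spec, str):
        if spec not in INIT_REGISTRY:
            raise ValueError(f"unknown init {spec!r}; known: {sorted(INIT_REGISTRY)}")
        return INIT_REGISTRY[spec]
    if callable(spec):
        return spec
    raise TypeError(f"expected str, callable or None, got {type(spec).__name__}")


def init_param(spec, key, shape, dtype=float):
    func = resolve_init(spec)
    return None if func is None else func(key, shape, dtype)


register_init("zeros", zeros)
register_init("ones", ones)


def arange_init(key, shape, dtype=float):
    # custom initializer: 0, 1, 2, ... like `init_func` in the cell below
    return [dtype(i) for i in range(math.prod(shape))]


register_init("my_init", arange_init)
print(init_param("my_init", 0, (1, 4)))  # [0.0, 1.0, 2.0, 3.0]
print(init_param(arange_init, 0, (2,)))  # [0.0, 1.0]
print(init_param(None, 0, (4,)))         # None
```

Keeping strings and callables interchangeable means a registered name and the function itself produce identical parameters.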

In [2]:
import serket as sk
import jax
import math
import jax.random as jr

# 1) linear layer with no bias
linear = sk.nn.Linear(1, 10, weight_init="he_normal", bias_init=None, key=jr.PRNGKey(0))


# 2) linear layer with a custom initialization function
def init_func(key, shape, dtype=jax.numpy.float32):
    return jax.numpy.arange(math.prod(shape), dtype=dtype).reshape(shape)


linear = sk.nn.Linear(1, 10, weight_init=init_func, bias_init=None, key=jr.PRNGKey(0))
print(linear.weight)
# [[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]

# 3) linear layer with a custom initialization function registered under a name
sk.def_init_entry("my_init", init_func)
linear = sk.nn.Linear(1, 10, weight_init="my_init", bias_init=None, key=jr.PRNGKey(0))
print(linear.weight)

[[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]
[[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]


### Handling activation functions

Layers with an `act` argument accept:

- A string: 
    - `adaptive_leaky_relu`
    - `adaptive_relu`
    - `adaptive_sigmoid`
    - `adaptive_tanh`
    - `celu`
    - `elu`
    - `gelu`
    - `glu`
    - `hard_shrink`
    - `hard_sigmoid`
    - `hard_swish`
    - `hard_tanh`
    - `leaky_relu`
    - `log_sigmoid`
    - `log_softmax`
    - `mish`
    - `prelu`
    - `relu`
    - `relu6`
    - `selu`
    - `sigmoid`
    - `snake`
    - `softplus`
    - `softshrink`
    - `softsign`
    - `squareplus`
    - `swish`
    - `tanh`
    - `tanh_shrink`
    - `thresholded_relu`
- A callable (e.g., a function or an instance of a class with a `__call__` method) that takes a single `jax.Array` and returns a `jax.Array`.
- A string registered with `sk.def_act_entry("my_act", ....)`, which maps to such a custom activation.
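Several names in the list are less familiar than `relu` or `tanh`. The scalar sketch below writes out common textbook definitions of a few of them, plus a hypothetical `get_act` helper that accepts either a name or a callable. The parameter defaults (`b=4` for squareplus, `a=1` for snake, `lambd=0.5` for softshrink) are assumptions for illustration; serket's own implementations and defaults may differ.

```python
import math


def relu6(x):
    return min(max(x, 0.0), 6.0)


def hard_sigmoid(x):
    # piecewise-linear approximation of the sigmoid: relu6(x + 3) / 6
    return relu6(x + 3.0) / 6.0


def hard_swish(x):
    return x * hard_sigmoid(x)


def squareplus(x, b=4.0):
    # smooth ReLU-like curve: (x + sqrt(x^2 + b)) / 2
    return (x + math.sqrt(x * x + b)) / 2.0


def snake(x, a=1.0):
    # periodic activation: x + sin^2(a * x) / a
    return x + math.sin(a * x) ** 2 / a


def softshrink(x, lambd=0.5):
    # shrink toward zero by lambd, zeroing the band [-lambd, lambd]
    if x > lambd:
        return x - lambd
    if x < -lambd:
        return x + lambd
    return 0.0


# hypothetical lookup table keyed by the names used in the list above
ACTIVATIONS = {
    "relu6": relu6,
    "hard_sigmoid": hard_sigmoid,
    "hard_swish": hard_swish,
    "squareplus": squareplus,
    "snake": snake,
    "softshrink": softshrink,
}


def get_act(spec):
    """Accept a registered name or any callable, mirroring the string/callable options."""
    return ACTIVATIONS[spec] if isinstance(spec, str) else spec


print([get_act(name)(0.0) for name in ACTIVATIONS])  # [0.0, 0.5, 0.0, 1.0, 0.0, 0.0]
```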

In [3]:
import serket as sk
import jax
import jax.random as jr

# 1) activation function with a string
linear = sk.nn.FNN([1, 1], act="relu", key=jr.PRNGKey(0))

# 2) activation function with a function
linear = sk.nn.FNN([1, 1], act=jax.nn.relu, key=jr.PRNGKey(0))


@sk.autoinit
class MyTrainableActivation(sk.TreeClass):
    my_param: float = 10.0

    def __call__(self, x):
        return x * self.my_param


# 3) activation function with a class instance (holding a trainable parameter)
linear = sk.nn.FNN([1, 1], act=MyTrainableActivation(), key=jr.PRNGKey(0))

# 4) activation function with a registered class
sk.def_act_entry("my_act", MyTrainableActivation())
linear = sk.nn.FNN([1, 1], act="my_act", key=jr.PRNGKey(0))

### Handling dtype

Layers with a `dtype` argument accept any valid `numpy.dtype` variant. This is useful when a mixed-precision policy is desired. For more, see the example on mixed precision training.
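What a lower-precision `dtype` gives up can be seen without serket or JAX: Python's standard `struct` module can pack IEEE 754 half precision (`"e"`, i.e. float16) and single precision (`"f"`, i.e. float32). The `round_trip` helper below is a hypothetical illustration that stores a value in the narrower format and reads it back.

```python
import struct


def round_trip(value, fmt):
    """Store a Python float in a narrower IEEE 754 format and read it back."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


w = 0.1
w16 = round_trip(w, "e")  # half precision (float16)
w32 = round_trip(w, "f")  # single precision (float32)
print(w16, w32)  # 0.0999755859375 0.10000000149011612

# 65504 is the largest finite float16; larger values cannot be packed with "e"
print(round_trip(65504.0, "e"))  # 65504.0
```

float16 keeps only about three significant decimal digits and has a small range, which is why mixed-precision setups commonly keep a float32 copy of the parameters or scale the loss.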


In [4]:
import serket as sk
import jax
import jax.random as jr

linear = sk.nn.Linear(10, 5, dtype=jax.numpy.float16, key=jr.PRNGKey(0))
linear
# note that the dtype is f16 (float16) in the repr output

Linear(
  in_features=(10), 
  out_features=5, 
  weight_init=glorot_uniform, 
  bias_init=zeros, 
  weight=f16[10,5](μ=0.07, σ=0.35, ∈[-0.63,0.60]), 
  bias=f16[5](μ=0.00, σ=0.00, ∈[0.00,0.00])
)
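The statistics in the repr can be sanity-checked against the standard Glorot/Xavier uniform rule, which samples from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)). This assumes serket's `glorot_uniform` follows that rule (as `jax.nn.initializers.glorot_uniform` does).

```python
import math

fan_in, fan_out = 10, 5  # Linear(10, 5) from the cell above
limit = math.sqrt(6.0 / (fan_in + fan_out))  # Glorot uniform bound
expected_std = limit / math.sqrt(3.0)        # std of U(-a, a) is a / sqrt(3)
print(f"limit={limit:.3f}, expected std={expected_std:.3f}")  # limit=0.632, expected std=0.365
```

The bound of about 0.632 is consistent with the observed range `∈[-0.63,0.60]`, and the observed `σ=0.35` is close to the expected 0.365 for only 50 samples stored in float16.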

### Lazy shape inference

Lazy initialization is useful in scenarios where the dimensions of certain input features are not known in advance. For instance, the number of input features for a flattened image may be uncertain, or the output shape of a flattened convolutional layer may not be straightforward to calculate. In such cases, lazy initialization defers layer materialization until the layer receives its first input.

In `serket`, simply replace `in_features` with `None` to mark the layer as lazy, then materialize the layer by calling it functionally. Recall that a functional call, via `.at[method_name](*args, **kwargs)`, _always_ returns a tuple of the method output and a _new_ instance. `sk.value_and_tree`, used in the cells below, likewise returns the function output together with the updated layer.

**Marking the layer lazy**

In [1]:
import jax
import serket as sk
import jax.random as jr

# 5 images from MNIST
x = jax.numpy.ones([5, 1, 28, 28])

layer = sk.Sequential(
    jax.numpy.ravel,
    # pass `None` as in_features for lazy shape inference
    sk.nn.Linear(None, 10, key=jr.PRNGKey(0)),
    jax.nn.relu,
    sk.nn.Linear(10, 10, key=jr.PRNGKey(1)),
    jax.nn.softmax,
)

**Materialization by functional call**

In [2]:
# materialize the layer with single image
_, layer = sk.value_and_tree(lambda layer: layer(x[0]))(layer)
# apply on batch
y = jax.vmap(layer)(x)
y.shape

(5, 10)
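The lazy pattern above (infer `in_features` from the first input, and return a new materialized layer instead of mutating the original) can be modeled in plain Python. `ToyLazyLinear` and `functional_call` are hypothetical names for this sketch only; it is not how serket's `Linear` or `value_and_tree` are implemented.

```python
import random
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyLazyLinear:
    """Hypothetical toy layer (not serket's Linear): in_features=None defers weight creation."""
    in_features: Optional[int]
    out_features: int
    seed: int = 0
    weight: Optional[Tuple[Tuple[float, ...], ...]] = None  # shape (in_features, out_features)

    def __call__(self, x):
        if self.weight is None:
            raise RuntimeError("lazy layer: materialize it with functional_call first")
        # y[j] = sum_i x[i] * weight[i][j]
        return [sum(xi * self.weight[i][j] for i, xi in enumerate(x)) for j in range(self.out_features)]


def functional_call(layer, x):
    """Return (output, layer); a lazy layer is materialized on a copy, never mutated in place."""
    if layer.weight is None:
        rng = random.Random(layer.seed)
        n_in = len(x)  # infer in_features from the first input
        weight = tuple(
            tuple(rng.uniform(-1.0, 1.0) for _ in range(layer.out_features)) for _ in range(n_in)
        )
        layer = replace(layer, in_features=n_in, weight=weight)
    return layer(x), layer


# a flattened 1x28x28 image has 1 * 28 * 28 = 784 features
image = [1.0] * (1 * 28 * 28)
lazy = ToyLazyLinear(None, 10)
y, layer = functional_call(lazy, image)
print(lazy.in_features, layer.in_features, len(y))  # None 784 10
```

Because the original `lazy` instance is left untouched, the caller must keep the returned layer, just as the serket cell rebinds `layer` from the output of `sk.value_and_tree`.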

## FFT variant

`serket` offers an FFT variant for most of its convolution layers. The FFT convolution variant is useful in many cases; in particular, it can be faster than the direct implementation for larger kernel sizes. The following figure compares the speed of both implementations.


<img src="../_static/fft_bench.svg" width="600" align="center">

The benchmark compares `FFTConv2D` against `Conv2D` with `in_features=3`, `out_features=64`, and `input_size=(10, 3, 128, 128)`.
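The FFT variant rests on the convolution theorem: a linear convolution equals the inverse Fourier transform of the product of the zero-padded transforms. Direct 1D convolution costs O(N·K) for signal length N and kernel length K, while the FFT route costs roughly O((N + K) log(N + K)) regardless of K, which is why it pays off for large kernels. The sketch below is a 1D illustration, not serket's `FFTConv` implementation. It uses a naive O(n²) DFT for readability, so it demonstrates correctness, not speed. Note also that deep-learning "convolution" layers usually compute cross-correlation, which equals convolution with a flipped kernel.

```python
import cmath


def direct_conv(x, k):
    """Full linear convolution y[n] = sum_m x[m] * k[n - m]; cost O(len(x) * len(k))."""
    y = [0.0] * (len(x) + len(k) - 1)
    for i, xi in enumerate(x):
        for j, kj in enumerate(k):
            y[i + j] += xi * kj
    return y


def dft(a, inverse=False):
    """Naive O(n^2) discrete Fourier transform; a real library would use an O(n log n) FFT."""
    n = len(a)
    sign = 1 if inverse else -1
    out = [sum(a[t] * cmath.exp(sign * 2j * cmath.pi * f * t / n) for t in range(n)) for f in range(n)]
    return [v / n for v in out] if inverse else out


def fft_conv(x, k):
    """Zero-pad both inputs to len(x) + len(k) - 1, multiply spectra, transform back."""
    n_out = len(x) + len(k) - 1
    X = dft(list(x) + [0.0] * (n_out - len(x)))
    K = dft(list(k) + [0.0] * (n_out - len(k)))
    return [v.real for v in dft([a * b for a, b in zip(X, K)], inverse=True)]


x = [1.0, 2.0, 3.0, 4.0]
k = [1.0, 0.0, -1.0]
print(direct_conv(x, k))  # [1.0, 2.0, 2.0, 2.0, -3.0, -4.0]
print([round(v, 6) for v in fft_conv(x, k)])
```

Zero-padding to `len(x) + len(k) - 1` matters: without it the DFT product gives a circular convolution, whose tail wraps around onto the start of the output.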