# 👀 Layers overview

```shell
pip install git+https://github.com/ASEM000/serket --quiet
```

## `serket` general design features

### Lazy shape inference

**Marking the layer lazy**

```python
import jax
import serket as sk
import jax.random as jr

# dummy batch of 5 MNIST-shaped images: (channels, height, width) = (1, 28, 28)
x = jax.numpy.ones([5, 1, 28, 28])

layer = sk.Sequential(
    jax.numpy.ravel,
    # pass `None` as in_features to infer it lazily from the first input
    sk.nn.Linear(None, 10, key=jr.PRNGKey(0)),
    jax.nn.relu,
    sk.nn.Linear(10, 10, key=jr.PRNGKey(1)),
    jax.nn.softmax,
)
```

**Materialization by functional call**

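```python
# materialize the lazy layer by calling it once on a single image;
# `sk.value_and_tree` returns the call's output (discarded here) and the materialized layer
_, layer = sk.value_and_tree(lambda layer: layer(x[0]))(layer)
# apply the materialized layer to the whole batch
y = jax.vmap(layer)(x)
y.shape  # (5, 10)
```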
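The lazy-then-materialize pattern can be illustrated without serket. The following is a minimal, hypothetical sketch: `LazyLinear` and `materialize` are illustrative names, not serket's API or implementation. The layer stores `None` for `in_features`, and a sample input produces a *new*, initialized layer instead of mutating the old one, in keeping with JAX's immutable pytrees. The uniform initializer with bound `1/sqrt(in_features)` is an assumption made for the example.

```python
from dataclasses import dataclass, replace
from typing import Optional

import jax
import jax.numpy as jnp
import jax.random as jr


@dataclass(frozen=True, eq=False)
class LazyLinear:
    """Hypothetical lazy linear layer (illustration only, not serket's code)."""

    in_features: Optional[int]
    out_features: int
    key: jax.Array
    weight: Optional[jax.Array] = None
    bias: Optional[jax.Array] = None

    def materialize(self, x: jax.Array) -> "LazyLinear":
        # Infer in_features from the trailing axis of a sample input and return
        # a NEW, fully initialized layer; the lazy original is left untouched.
        in_features = x.shape[-1]
        bound = in_features ** -0.5  # assumed init scheme, for illustration
        weight = jr.uniform(
            self.key, (in_features, self.out_features), minval=-bound, maxval=bound
        )
        bias = jnp.zeros((self.out_features,))
        return replace(self, in_features=in_features, weight=weight, bias=bias)

    def __call__(self, x: jax.Array) -> jax.Array:
        if self.weight is None:
            raise RuntimeError("layer is lazy; call `materialize` on a sample input first")
        return x @ self.weight + self.bias


lazy = LazyLinear(None, 10, key=jr.PRNGKey(0))
layer = lazy.materialize(jnp.ones(784))  # in_features inferred as 784
y = jax.vmap(layer)(jnp.ones((5, 784)))  # apply to a batch of 5 flattened images
print(layer.in_features, y.shape)  # 784 (5, 10)
```

Returning a new object from `materialize` mirrors why the serket example above reassigns `layer` from the result of `sk.value_and_tree` rather than relying on in-place mutation.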

## FFT variant

`serket` offers an FFT variant for most of its convolution layers. The FFT variant computes the convolution as an element-wise product in the frequency domain, which can make it faster than direct convolution for large kernel sizes. The figure below compares the speed of the two implementations.
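As a minimal sketch of why an FFT variant helps: by the convolution theorem, a spatial convolution equals an element-wise product of Fourier transforms, so its cost is roughly `O(HW log HW)` after padding instead of growing with the kernel area as direct convolution's `O(HW k^2)` does. The pure-JAX snippet below is not serket's `FFTConv2D` code; `fft_conv2d_full` is a hypothetical helper, and the check is for a single-channel input against `jax.scipy.signal.convolve2d`.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.signal import convolve2d


def fft_conv2d_full(x, k):
    """Hypothetical helper: full 2D linear convolution via the convolution theorem."""
    # Zero-pad both operands to the full output size so the FFT's circular
    # convolution equals the linear one (no wrap-around).
    s = (x.shape[0] + k.shape[0] - 1, x.shape[1] + k.shape[1] - 1)
    # Convolution in space == element-wise product in frequency.
    return jnp.fft.irfft2(jnp.fft.rfft2(x, s=s) * jnp.fft.rfft2(k, s=s), s=s)


x = jr.normal(jr.PRNGKey(0), (32, 32))
k = jr.normal(jr.PRNGKey(1), (9, 9))

direct = convolve2d(x, k, mode="full", precision=jax.lax.Precision.HIGHEST)
via_fft = fft_conv2d_full(x, k)
# shape is (40, 40); the difference is float32 round-off
print(direct.shape, float(jnp.max(jnp.abs(direct - via_fft))))
```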
    

<img src="../_static/fft_bench.svg" width="600" align="center">

The benchmark compares `FFTConv2D` against `Conv2D` with `in_features=3`, `out_features=64`, and an input of size `(10, 3, 128, 128)`.
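The benchmark setup described above can be approximated in pure JAX. This is a hedged sketch, not serket's benchmark script or its `FFTConv2D` implementation. `direct_conv2d`, `fft_conv2d`, and `mean_runtime` are hypothetical helpers, `'valid'` padding is an assumption chosen for simplicity (serket's defaults may differ), and the kernel sizes are arbitrary. Timings depend on hardware and backend, so no numbers are shown.

```python
import time

import jax
import jax.numpy as jnp
import jax.random as jr


def direct_conv2d(x, w):
    """'Valid' cross-correlation (the deep-learning "convolution"); NCHW input, OIHW kernel."""
    return jax.lax.conv_general_dilated(
        x, w, window_strides=(1, 1), padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )


def fft_conv2d(x, w):
    """The same 'valid' cross-correlation as `direct_conv2d`, computed via the FFT."""
    h, wd = x.shape[-2:]
    kh, kw = w.shape[-2:]
    s = (h + kh - 1, wd + kw - 1)  # full linear-convolution size (no circular wrap-around)
    w_flipped = w[..., ::-1, ::-1]  # cross-correlation == convolution with a flipped kernel
    xf = jnp.fft.rfft2(x, s=s)  # (N, C_in, s0, s1 // 2 + 1)
    wf = jnp.fft.rfft2(w_flipped, s=s)  # (C_out, C_in, s0, s1 // 2 + 1)
    yf = jnp.einsum("nchw,ochw->nohw", xf, wf)  # pointwise product, summed over C_in
    y = jnp.fft.irfft2(yf, s=s)  # full convolution, (N, C_out, s0, s1)
    return y[..., kh - 1 : h, kw - 1 : wd]  # crop to the 'valid' region


def mean_runtime(fn, *args, repeats=3):
    """Average wall-clock time of a jitted call; JAX dispatch is async, so block on results."""
    fn = jax.jit(fn)
    fn(*args).block_until_ready()  # compile and warm up outside the timed loop
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


key_x, key_w = jr.split(jr.PRNGKey(0))
x = jr.normal(key_x, (10, 3, 128, 128))  # input size used in the benchmark above
for k in (3, 15):
    w = jr.normal(key_w, (64, 3, k, k))  # out_features=64, in_features=3
    t_direct = mean_runtime(direct_conv2d, x, w)
    t_fft = mean_runtime(fft_conv2d, x, w)
    print(f"kernel {k}x{k}: direct {t_direct * 1e3:.2f} ms, fft {t_fft * 1e3:.2f} ms")
```

Both helpers compute the same cross-correlation (the FFT version flips the kernel), so they can be checked with `jnp.allclose` before timing. Extending the tuple of kernel sizes shows where the crossover falls on a given machine.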