# 🥱 Functional lazy initialization

Lazy initialization is useful when the size of an input feature is not known in advance. For example, the number of input features of a linear layer that follows a flattened image may not be known up front (**Example 1**). Similarly, the flattened output size of a convolutional block can be tedious to calculate by hand (**Example 2**).
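Without lazy initialization, this size has to be worked out ahead of time. One way to do this in plain JAX is to trace the feature extractor with `jax.eval_shape`, which computes output shapes and dtypes without allocating arrays. The sketch below uses a hand-written extractor built from `jax.lax` primitives: a 3×3 `VALID` convolution with 10 output channels, ReLU, and 2×2 max pooling with stride 2. These padding and pooling settings are illustrative assumptions and need not match serket's `Conv2D` and `MaxPool2D` defaults.

```python
import jax
import jax.numpy as jnp


def features(x, w):
    # x: a single image (1, 28, 28); w: conv kernel in OIHW layout (10, 1, 3, 3)
    y = jax.lax.conv(x[None], w, window_strides=(1, 1), padding="VALID")  # (1, 10, 26, 26)
    y = jax.nn.relu(y)
    # 2x2 max pooling with stride 2 -> (1, 10, 13, 13)
    y = jax.lax.reduce_window(y, -jnp.inf, jax.lax.max, (1, 1, 2, 2), (1, 1, 2, 2), "VALID")
    return jnp.ravel(y)


# Abstract inputs: only shape and dtype are described, no array memory is allocated
image_spec = jax.ShapeDtypeStruct((1, 28, 28), jnp.float32)
kernel_spec = jax.ShapeDtypeStruct((10, 1, 3, 3), jnp.float32)

out_spec = jax.eval_shape(features, image_spec, kernel_spec)
in_features = out_spec.shape[0]
print(in_features)  # 1690
```

This works, but the shape computation has to be kept in sync with the model by hand. A lazy `Linear(None, ...)` layer instead reads the size from the actual output of the preceding layers.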

With lazy initialization, a layer declares the unknown dimension as `None` and defers creating the corresponding parameters until it receives an actual input. In the examples below, the layer is materialized by calling it once on a single sample, which fixes the missing shapes. The materialized layer can then be applied to a whole batch with `jax.vmap`. This removes the need to compute intermediate shapes by hand. It also lets the same model definition be reused for inputs of a different size, provided it is materialized again for that size.
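The mechanism can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not serket's implementation. The hypothetical `LazyLinear` record stores `None` for its parameters until the first call. The hypothetical `call` function returns both the output and a new, materialized layer, mirroring the `(output, layer)` pair unpacked from `layer.at["__call__"]` in the examples.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp


class LazyLinear(NamedTuple):
    # Hypothetical lazy layer: parameters stay None until the first input is seen
    out_features: int
    weight: Optional[jax.Array] = None
    bias: Optional[jax.Array] = None


def call(layer, x, key):
    # Hypothetical functional call: infer in_features from x on first use and
    # return a new layer instead of mutating the old one
    if layer.weight is None:
        in_features = x.shape[-1]
        weight = jax.random.normal(key, (in_features, layer.out_features)) / jnp.sqrt(in_features)
        bias = jnp.zeros((layer.out_features,))
        layer = layer._replace(weight=weight, bias=bias)
    return x @ layer.weight + layer.bias, layer


lazy = LazyLinear(out_features=10)
y, layer = call(lazy, jnp.ones(784), jax.random.PRNGKey(0))
print(layer.weight.shape)  # (784, 10)
```

Returning a new layer keeps the original `lazy` object unchanged, which fits JAX's functional style where parameters are immutable pytrees.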

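**Example 1**: inferring the input size of a linear layer that follows a flattened image.

```python
import jax
import serket as sk

# 5 images from MNIST, each of shape (channels, height, width)
x = jax.numpy.ones([5, 1, 28, 28])

layer = sk.nn.Sequential(
    jax.numpy.ravel,
    # pass `None` as in_features to have it inferred lazily
    sk.nn.Linear(None, 10),
    jax.nn.relu,
    sk.nn.Linear(10, 10),
    jax.nn.softmax,
)
# materialize the layer by calling it once on a single image;
# the second returned value is the materialized layer
_, layer = layer.at["__call__"](x[0])
# apply the materialized layer on the whole batch
y = jax.vmap(layer)(x)
y.shape
```

```
(5, 10)
```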
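Example 1 materializes the layer with `x[0]` rather than the whole batch because `jax.numpy.ravel` flattens everything it receives. Materializing with the batch would therefore infer the input size from all 5 images at once. The plain-JAX sketch below illustrates the same per-sample pattern; `init_params` and `forward` are hypothetical helpers, not part of serket.

```python
import jax
import jax.numpy as jnp


def init_params(sample, out_features, key):
    # Hypothetical helper: infer in_features from ONE unbatched sample, as Example 1 does with x[0]
    in_features = jnp.ravel(sample).shape[0]
    w = jax.random.normal(key, (in_features, out_features)) / jnp.sqrt(in_features)
    return w, jnp.zeros((out_features,))


def forward(params, sample):
    # Hypothetical per-sample forward pass: flatten, project, softmax
    w, b = params
    return jax.nn.softmax(jnp.ravel(sample) @ w + b)


x = jnp.ones([5, 1, 28, 28])
params = init_params(x[0], 10, jax.random.PRNGKey(0))  # weight shape (784, 10)

# Share params across the batch (in_axes=None) and map over images (axis 0)
y = jax.vmap(forward, in_axes=(None, 0))(params, x)
print(y.shape)  # (5, 10)

# Materializing from the whole batch would instead infer 5 * 784 = 3920 features,
# which does not match a single flattened image inside vmap
bad_in = jnp.ravel(x).shape[0]
print(bad_in)  # 3920
```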

**Example 2**: inferring the input size of a linear layer that follows a convolutional block.

```python
import jax
import serket as sk

# 5 images from MNIST, each of shape (channels, height, width)
x = jax.numpy.ones([5, 1, 28, 28])

layer = sk.nn.Sequential(
    sk.nn.Conv2D(1, 10, 3),
    jax.nn.relu,
    sk.nn.MaxPool2D(2),
    jax.numpy.ravel,
    # the linear input size is inferred from
    # the previous layer's output
    sk.nn.Linear(None, 10),
    jax.nn.softmax,
)

# materialize the layer by calling it once on a single image
_, layer = layer.at["__call__"](x[0])

# apply the materialized layer on the whole batch
y = jax.vmap(layer)(x)

y.shape
```

```
(5, 10)
```