# 📙 `serket.nn` layers overview

## `serket` general design features

**Handling weight initialization**

Layers that take a `weight_init` or `bias_init` argument accept:
- One of the following strings:
    - `he_normal`
    - `he_uniform`
    - `glorot_normal`
    - `glorot_uniform`
    - `lecun_normal`
    - `lecun_uniform`
    - `normal`
    - `uniform`
    - `ones`
    - `zeros`
    - `xavier_normal`
    - `xavier_uniform`
    - `orthogonal`


- A function with the signature `(key: jax.random.KeyArray, shape: tuple[int, ...], dtype)` that returns the initialized array.
- `None`, to indicate no initialization (e.g. no bias for layers that have a `bias_init` argument).
- A string registered with `sk.def_init_entry("my_init", ....)`, which maps the name to a custom init function.

In [1]:
import serket as sk
import jax
import math

# 1) linear layer with no bias
linear = sk.nn.Linear(1, 10, weight_init="he_normal", bias_init=None)


# 2) linear layer with a custom initialization function
# (this deterministic example ignores `key` and fills the array with 0, 1, 2, ...)
def init_func(key, shape, dtype=jax.numpy.float32):
    return jax.numpy.arange(math.prod(shape), dtype=dtype).reshape(shape)


linear = sk.nn.Linear(1, 10, weight_init=init_func, bias_init=None)
print(linear.weight)
# [[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]

# 3) linear layer with a custom initialization function registered under a string key
sk.def_init_entry("my_init", init_func)
linear = sk.nn.Linear(1, 10, weight_init="my_init", bias_init=None)
print(linear.weight)

[[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]
[[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]
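
The function signature described above can also be exercised outside of any layer. The following is a minimal sketch in plain JAX, not part of `serket`: `scaled_normal_init` is a hypothetical name chosen for illustration, and the comparison with `jax.nn.initializers.he_normal()` only shows that JAX's built-in initializers are called with the same `(key, shape, dtype)` arguments.

```python
import jax
import jax.numpy as jnp


def scaled_normal_init(key, shape, dtype=jnp.float32):
    # hypothetical initializer: standard normal samples scaled by 0.01
    return 0.01 * jax.random.normal(key, shape, dtype)


key = jax.random.PRNGKey(0)

# call the custom initializer directly with (key, shape, dtype)
w_custom = scaled_normal_init(key, (1, 10), jnp.float32)

# JAX's built-in initializers are called with the same arguments
he_normal = jax.nn.initializers.he_normal()
w_he = he_normal(key, (1, 10), jnp.float32)

print(w_custom.shape, w_custom.dtype)
print(w_he.shape, w_he.dtype)
```

Unlike the deterministic `init_func` above, this sketch actually consumes `key`, so different keys give different weights.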


**Handling activation functions**

Layers that take an `act_func` argument accept:
- One of the following strings:
    - `adaptive_leaky_relu`
    - `adaptive_relu`
    - `adaptive_sigmoid`
    - `adaptive_tanh`
    - `celu`
    - `elu`
    - `gelu`
    - `glu`
    - `hard_shrink`
    - `hard_sigmoid`
    - `hard_swish`
    - `hard_tanh`
    - `leaky_relu`
    - `log_sigmoid`
    - `log_softmax`
    - `mish`
    - `prelu`
    - `relu`
    - `relu6`
    - `selu`
    - `sigmoid`
    - `snake`
    - `softplus`
    - `softshrink`
    - `softsign`
    - `squareplus`
    - `swish`
    - `tanh`
    - `tanh_shrink`
    - `thresholded_relu`


- A function that takes a single `jax.Array` and returns a `jax.Array`.
- A string registered with `sk.def_act_entry("my_act", ....)`, which maps the name to a custom activation class with a `__call__` method.

In [2]:
import serket as sk
import jax

# 1) activation function with a string
linear = sk.nn.FNN([1, 1], act_func="relu")

# 2) activation function with a function
linear = sk.nn.FNN([1, 1], act_func=jax.nn.relu)


@sk.autoinit
class MyTrainableActivation(sk.TreeClass):
    my_param: float = 10.0

    def __call__(self, x):
        return x * self.my_param


# 3) activation function with a class
linear = sk.nn.FNN([1, 1], act_func=MyTrainableActivation())

# 4) activation function with a registered class
sk.def_act_entry("my_act", MyTrainableActivation)
linear = sk.nn.FNN([1, 1], act_func="my_act")
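
As a complement to the cell above, here is a minimal sketch of a hand-written activation that satisfies the "single `jax.Array` in, single `jax.Array` out" requirement. `soft_clip` is a hypothetical name used only for illustration; the sketch uses plain JAX and makes no assumption about `serket` internals.

```python
import jax
import jax.numpy as jnp


def soft_clip(x: jax.Array) -> jax.Array:
    # hypothetical activation: smoothly squashes values into the range (-2, 2)
    return 2.0 * jnp.tanh(x / 2.0)


x = jnp.array([-10.0, 0.0, 10.0])
y = soft_clip(x)

# the function is elementwise, so the output shape matches the input shape
print(y.shape)

# it is also differentiable, which is what a layer needs during training
grads = jax.grad(lambda v: soft_clip(v).sum())(x)
print(grads.shape)
```

A function like this would be passed in the same way as `jax.nn.relu` in case 2) above, i.e. as the `act_func` argument.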

**Handling dtype**

Layers that take a `dtype` argument accept any valid `numpy.dtype` variant. This is useful when a mixed-precision policy is desired; see the example on mixed precision.


In [3]:
import serket as sk
import jax

linear = sk.nn.Linear(10, 5, dtype=jax.numpy.float16)
linear
# note that the dtype is f16 (float16) in the repr output

Linear(
  in_features=(10), 
  out_features=5, 
  weight_init=he_normal, 
  bias_init=ones, 
  weight=f16[10,5](μ=0.08, σ=0.43, ∈[-1.01,0.87]), 
  bias=f16[5](μ=1.00, σ=0.00, ∈[1.00,1.00])
)
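
To illustrate why the parameter dtype matters for mixed precision, the following plain-JAX sketch mimics a linear layer with `float16` parameters shaped like the ones in the repr above. It is an assumption-laden illustration, not how `serket` computes its forward pass: the arrays `weight`, `bias`, and `x` are created by hand here, and the two casting strategies are only examples of what a mixed-precision policy might do.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# hand-made float16 parameters with the same shapes as the Linear(10, 5) above
weight = jax.random.normal(key, (10, 5), dtype=jnp.float16)
bias = jnp.ones((5,), dtype=jnp.float16)

# a float32 input batch of 3 samples
x = jnp.ones((3, 10), dtype=jnp.float32)

# option A: cast the input down and compute entirely in float16
y_half = x.astype(jnp.float16) @ weight + bias

# option B: cast the parameters up and compute in float32
y_full = x @ weight.astype(jnp.float32) + bias.astype(jnp.float32)

print(y_half.dtype, y_full.dtype)
```

Storing parameters in `float16` halves their memory footprint relative to `float32`, while the choice between the two options determines the precision of the computation itself.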