
ASEM000/serket


The ✨Magical✨ JAX ML Library.

*Serket is the goddess of magic in Egyptian mythology.

Installation | Description | Documentation | Quick Example


🛠️ Installation

Install the development version:

pip install git+https://github.com/ASEM000/serket

📖 Description and motivation

  • serket aims to be the most intuitive and easy-to-use machine learning library in JAX.
  • serket is fully transparent to JAX transformations (e.g. vmap, grad, jit, ...).
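"Transparent to JAX transformations" means a layer can be handed directly to jax.jit, jax.vmap, or jax.grad. The sketch below shows why that works when a layer is a pytree whose arrays are its children, assuming that is how serket layers are built. `Linear` here is a hypothetical minimal stand-in written with plain `jax.tree_util`, not serket's `sk.nn.Linear`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    """Hypothetical minimal layer: weight and bias are pytree children, so JAX sees them."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return x @ self.weight + self.bias

    def tree_flatten(self):
        # children (traced by JAX), aux data (static, none here)
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


layer = Linear(jax.random.normal(jax.random.PRNGKey(0), (3, 2)), jnp.zeros(2))

# jit: the layer is passed as an ordinary (pytree) argument
y = jax.jit(lambda layer, x: layer(x))(layer, jnp.ones(3))

# vmap: map the layer over a batch of 5 inputs, sharing the same parameters
ys = jax.vmap(layer)(jnp.ones((5, 3)))

# grad: the gradient comes back as a Linear with the same structure
grads = jax.grad(lambda layer, x: jnp.sum(layer(x)))(layer, jnp.ones(3))
```

Because the gradient has the same pytree structure as the layer, an update such as `jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, layer, grads)` needs no layer-specific code.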

📙 Documentation

🏃 Quick example

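import functools as ft

import jax
import jax.numpy as jnp
import serket as sk

x_train, y_train = ..., ...  # placeholder: batches of (batch, 28, 28) images and integer labels
k1, k2 = jax.random.split(jax.random.PRNGKey(0))

# sk.tree_mask freezes non-differentiable leaves (here the jnp.ravel and
# jax.nn.relu functions) so jax.grad and jax.jit only trace the float parameters
net = sk.tree_mask(sk.Sequential(
    jnp.ravel,
    sk.nn.Linear(28 * 28, 64, key=k1),
    jax.nn.relu,
    sk.nn.Linear(64, 10, key=k2),
))

def softmax_cross_entropy(logits, onehot):
    # per-example cross-entropy between softmax(logits) and one-hot targets
    return -jnp.sum(onehot * jax.nn.log_softmax(logits), axis=-1)

def accuracy_func(logits, y):
    # fraction of examples whose highest logit matches the label
    return jnp.mean(jnp.argmax(logits, axis=-1) == y)

@ft.partial(jax.grad, has_aux=True)
def loss_func(net, x, y):
    # unmask inside the transformed function, then vmap the per-example net over the batch
    logits = jax.vmap(sk.tree_unmask(net))(x)
    onehot = jax.nn.one_hot(y, 10)
    loss = jnp.mean(softmax_cross_entropy(logits, onehot))
    return loss, (loss, logits)

@jax.jit
def train_step(net, x, y):
    grads, (loss, logits) = loss_func(net, x, y)
    # plain SGD update with learning rate 1e-3
    net = jax.tree_util.tree_map(lambda p, g: p - g * 1e-3, net, grads)
    return net, (loss, logits)

for xb, yb in zip(x_train, y_train):
    net, (loss, logits) = train_step(net, xb, yb)
    accuracy = accuracy_func(logits, yb)

# unmask once training is done to get a regular callable network back
net = sk.tree_unmask(net)
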
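The example wraps the network in `sk.tree_mask` because `jax.grad` only accepts array leaves, and a function leaf such as `jax.nn.relu` typically makes it fail with a `TypeError`. The sketch below illustrates the idea with hypothetical `Frozen`, `mask`, and `unmask` helpers: non-float leaves are moved into a pytree node's static aux data so JAX never traces them. It is an assumption-level illustration of the concept, not serket's implementation.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


@register_pytree_node_class
class Frozen:
    """Hypothetical wrapper: stores a value as static aux data, so it has no pytree children."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(aux)


def is_float_array(x):
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.inexact)


def mask(tree):
    # freeze every leaf that is not a floating-point array
    return tree_map(lambda x: x if is_float_array(x) else Frozen(x), tree)


def unmask(tree):
    # stop at Frozen nodes and unwrap them
    return tree_map(
        lambda x: x.value if isinstance(x, Frozen) else x,
        tree,
        is_leaf=lambda x: isinstance(x, Frozen),
    )


tree = {"fn": jax.nn.relu, "w": jnp.array([1.0, -2.0])}


def loss(masked):
    t = unmask(masked)  # unmask inside the transformed function
    return jnp.sum(t["fn"](t["w"]))


grads = jax.grad(loss)(mask(tree))  # the frozen function is carried through untouched
```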
📚 Layers catalog

🔗 Common API

Group Layers
Containers - Sequential, Random{Choice}
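The containers compose other layers. A plain-Python sketch of the assumed semantics follows: `Sequential` feeds each layer's output into the next, and `RandomChoice` applies one of its layers picked uniformly at random from a PRNG key. Both classes are hypothetical stand-ins, and the `key` keyword and uniform choice are assumptions rather than serket's documented behavior. `jax.lax.switch` is used so the random branch also works under `jax.jit`.

```python
import jax
import jax.numpy as jnp


class Sequential:
    """Hypothetical stand-in: apply layers in order."""

    def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class RandomChoice:
    """Hypothetical stand-in: apply one layer chosen uniformly at random."""

    def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x, *, key):
        index = jax.random.randint(key, (), 0, len(self.layers))
        # lax.switch selects a branch from a traced index, so this stays jit-compatible
        return jax.lax.switch(index, self.layers, x)


seq = Sequential(jnp.ravel, lambda x: x * 2.0)
out = seq(jnp.ones((2, 2)))

choice = RandomChoice(lambda x: x + 1.0, lambda x: x - 1.0)
picked = choice(jnp.zeros(3), key=jax.random.PRNGKey(0))
```

All branches given to `jax.lax.switch` must return the same shape and dtype, which is a reasonable constraint for layers meant to be interchangeable.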

🧠 Neural network package: serket.nn

Group Layers
Attention - MultiHeadAttention
Convolution - {FFT,_}Conv{1D,2D,3D}
- {FFT,_}Conv{1D,2D,3D}Transpose
- Depthwise{FFT,_}Conv{1D,2D,3D}
- Separable{FFT,_}Conv{1D,2D,3D}
- Conv{1D,2D,3D}Local
- SpectralConv{1D,2D,3D}
Dropout - Dropout
- Dropout{1D,2D,3D}
- RandomCutout{1D,2D,3D}
Linear - Linear, MLP, Identity
Normalization - {Layer,Instance,Group,Batch}Norm
Pooling - {Avg,Max,LP}Pool{1D,2D,3D}
- Global{Avg,Max}Pool{1D,2D,3D}
- Adaptive{Avg,Max}Pool{1D,2D,3D}
Reshaping - Upsample{1D,2D,3D}
- {Random,Center}Crop{1D,2D,3D}
Recurrent cells - {SimpleRNN,LSTM,GRU,Dense}Cell
- {Conv,FFTConv}{LSTM,GRU}{1D,2D,3D}Cell
Activations - Adaptive{LeakyReLU,ReLU,Sigmoid,Tanh}
- CeLU,ELU,GELU,GLU
- Hard{SILU,Shrink,Sigmoid,Swish,Tanh}
- Soft{Plus,Sign,Shrink}
- LeakyReLU,LogSigmoid,LogSoftmax,Mish,PReLU
- ReLU,ReLU6,SeLU,Sigmoid
- Swish,Tanh,TanhShrink,ThresholdedReLU,Snake
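The `{FFT,_}` prefix in the table separates FFT-based layers from direct ones. By the convolution theorem both compute the same result, and the FFT route is typically cheaper for large kernels. A 1D sketch with `jax.numpy` follows; `fft_conv1d` is a hypothetical helper, not serket's implementation. It computes a full linear convolution by zero-padding both signals to `len(x) + len(k) - 1` before multiplying their spectra.

```python
import jax.numpy as jnp


def fft_conv1d(x, k):
    """Full 1D linear convolution via the FFT (hypothetical helper)."""
    n = x.shape[0] + k.shape[0] - 1      # output length of a full convolution
    X = jnp.fft.rfft(x, n=n)             # zero-pad to n, then transform
    K = jnp.fft.rfft(k, n=n)
    return jnp.fft.irfft(X * K, n=n)     # pointwise product in frequency domain


x = jnp.array([1.0, 2.0, 3.0, 4.0])
k = jnp.array([0.5, -1.0, 2.0])

direct = jnp.convolve(x, k, mode="full")  # → [0.5, 0.0, 1.5, 3.0, 2.0, 8.0]
via_fft = fft_conv1d(x, k)                # same values up to float rounding
```

Deep-learning "convolution" layers usually compute cross-correlation (no kernel flip); flipping `k` first gives that variant, and the FFT equivalence still holds.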

🖼️ Image package: serket.image

Group Layers
Filter - {FFT,_}{Avg,Box,Gaussian,Motion}Blur2D
- {JointBilateral,Bilateral,Median}Blur2D
- {FFT,_}{UnsharpMask}2D
- {FFT,_}{Sobel,Laplacian}2D
- {FFT,_}BlurPool2D
Augment - Adjust{Sigmoid,Log}2D
- {Adjust,Random}{Brightness,Contrast,Hue,Saturation}2D
- RandomJigSaw2D,PixelShuffle2D
- Pixelate2D,Posterize2D,Solarize2D
- FourierDomainAdapt2D
Geometric - {Random,_}{Horizontal,Vertical}{Translate,Flip,Shear}2D
- {Random,_}{Rotate}2D
- RandomPerspective2D
- {FFT,_}ElasticTransform2D
Color - RGBToGrayscale2D, GrayscaleToRGB2D
- RGBToHSV2D, HSVToRGB2D
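Color layers such as `RGBToGrayscale2D` reduce the channel axis with a weighted sum. The sketch below assumes a channel-first `(3, H, W)` image and the ITU-R BT.601 luma weights (0.299, 0.587, 0.114). Both the layout and the coefficients are assumptions, since serket's exact choices are not stated here, and `rgb_to_grayscale` is a hypothetical helper.

```python
import jax.numpy as jnp

# assumed ITU-R BT.601 luma weights; they sum to 1, so white stays white
RGB_WEIGHTS = jnp.array([0.299, 0.587, 0.114])


def rgb_to_grayscale(image):
    """Map a channel-first (3, H, W) RGB image to a (1, H, W) grayscale image."""
    gray = jnp.tensordot(RGB_WEIGHTS, image, axes=([0], [0]))  # contract channel axis
    return gray[None]  # keep an explicit single channel axis


white = jnp.ones((3, 4, 5))
red = jnp.zeros((3, 2, 2)).at[0].set(1.0)

gray_white = rgb_to_grayscale(white)
gray_red = rgb_to_grayscale(red)
```

Keeping the singleton channel axis means the output can feed straight into other channel-first layers.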