# Flax

Flax is a neural network library built on JAX with high-speed computation in mind. This notebook explores some of the examples from Flax as well as 
transformer implementations, specifically in the fields of NLP and Computer Vision.

## Installation

```bash
pip install --upgrade pip # To support manylinux2010 wheels.
pip install --upgrade jax jaxlib # CPU-only
```

In [2]:
%%bash
pip install --upgrade pip # To support manylinux2010 wheels.
pip install --upgrade jax jaxlib # CPU-only
# Quote the requirement: unquoted, the shell may treat "[all]" as a glob
# character class and expand it against files in the working directory.
pip install "flax[all]"

Collecting pip
  Using cached pip-23.1.2-py3-none-any.whl (2.1 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.0.1
    Uninstalling pip-23.0.1:
      Successfully uninstalled pip-23.0.1
Successfully installed pip-23.1.2
Collecting jax
  Using cached jax-0.4.10-py3-none-any.whl
Collecting jaxlib
  Using cached jaxlib-0.4.10-cp311-cp311-manylinux2014_x86_64.whl (69.9 MB)
Installing collected packages: jaxlib, jax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.3.25
    Uninstalling jaxlib-0.3.25:
      Successfully uninstalled jaxlib-0.3.25
Successfully installed jax-0.4.10 jaxlib-0.4.10






In [5]:
## Sample MLP
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  """Stack of Dense layers with ReLU between them.

  One `nn.Dense` layer is created per entry of `features`. Every layer except
  the last is followed by `nn.relu`; the last layer's output is returned
  without an activation.
  """
  # Output width of each Dense layer, in order; the final entry is the
  # width of the returned tensor's last layer.
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    """Applies the layer stack to `x` and returns the last layer's output."""
    hidden = x
    # Hidden layers: affine transform followed by a ReLU.
    for width in self.features[:-1]:
      hidden = nn.Dense(width)(hidden)
      hidden = nn.relu(hidden)
    # Output layer: affine transform only.
    return nn.Dense(self.features[-1])(hidden)

# model = MLP([12, 8, 4])
# batch = jnp.ones((32, 10))
# variables = model.init(jax.random.PRNGKey(0), batch)
# output = model.apply(variables, batch)

# Sample CNN
class CNN(nn.Module):
  """Small convolutional classifier that returns log-probabilities.

  Two stages of 3x3 convolution -> ReLU -> 2x2 average pooling (32 then 64
  features), followed by a flatten, a 256-unit Dense layer with ReLU, a
  10-unit Dense layer and a log-softmax.
  """

  @nn.compact
  def __call__(self, x):
    """Runs the network on `x`.

    Axis 0 of `x` is preserved by the flatten step, so it is treated as the
    batch axis. NOTE(review): the usage example in this notebook feeds
    (N, H, W, C) input - confirm that layout for other callers.
    """
    # Both convolutional stages share the same structure and differ only in
    # their number of output features.
    for channels in (32, 64):
      x = nn.Conv(features=channels, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    # Collapse everything except axis 0 into a single feature axis.
    flat = x.reshape((x.shape[0], -1))
    hidden = nn.relu(nn.Dense(features=256)(flat))
    logits = nn.Dense(features=10)(hidden)
    return nn.log_softmax(logits)

# model = CNN()
# batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C) format
# variables = model.init(jax.random.PRNGKey(0), batch)
# output = model.apply(variables, batch)


# Sample AutoEncoder
class AutoEncoder(nn.Module):
  """MLP encoder/decoder pair operating on flattened inputs.

  The encoder is an `MLP` built from `encoder_widths`. The decoder is an `MLP`
  built from `decoder_widths` with one extra final layer whose width is the
  product of `input_shape`, so its output can be reshaped back to the input's
  per-sample shape.
  """
  # Layer widths handed to the encoder MLP.
  encoder_widths: Sequence[int]
  # Layer widths of the decoder MLP, before the appended output layer.
  # NOTE(review): `setup` concatenates this with a tuple, so it must support
  # `+` with a tuple by the time `setup` runs - confirm for list arguments.
  decoder_widths: Sequence[int]
  # Per-sample shape of the input (everything after axis 0); compared against
  # `x.shape[1:]` in `encode`.
  input_shape: Sequence[int]

  def setup(self):
    """Creates the `encoder` and `decoder` submodules."""
    # Number of scalar values in one sample.
    flat_dim = np.prod(self.input_shape)
    self.encoder = MLP(self.encoder_widths)
    # The decoder ends with a layer wide enough to rebuild a flattened sample.
    decoder_layout = self.decoder_widths + (flat_dim,)
    self.decoder = MLP(decoder_layout)

  def __call__(self, x):
    """Encodes `x` and decodes the result."""
    latent = self.encode(x)
    return self.decode(latent)

  def encode(self, x):
    """Flattens each sample of `x` and passes it through the encoder."""
    assert x.shape[1:] == self.input_shape
    batch_size = x.shape[0]
    flattened = jnp.reshape(x, (batch_size, -1))
    return self.encoder(flattened)

  def decode(self, z):
    """Decodes `z`, applies a sigmoid and restores the per-sample shape."""
    activated = nn.sigmoid(self.decoder(z))
    restored_shape = (activated.shape[0],) + self.input_shape
    return jnp.reshape(activated, restored_shape)

# Build a small autoencoder for flat 12-feature samples. The widths are passed
# as tuples, matching `input_shape`: AutoEncoder.setup concatenates
# `decoder_widths` with a tuple, so a tuple is the type that works as written
# without relying on any conversion of list arguments.
model = AutoEncoder(encoder_widths=(20, 10, 5),
                    decoder_widths=(5, 10, 20),
                    input_shape=(12,))
batch = jnp.ones((16, 12))  # 16 identical all-ones samples
# Initialise the model variables from a fixed PRNG key.
variables = model.init(jax.random.PRNGKey(0), batch)
# Run the two halves separately by choosing which method `apply` invokes.
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)
decoded

Array([[0.5158798 , 0.48319402, 0.5002328 , 0.51826   , 0.47074836,
        0.53025323, 0.5709436 , 0.4605159 , 0.4656811 , 0.46918797,
        0.5084141 , 0.5405306 ],
       [0.5158798 , 0.48319402, 0.5002328 , 0.51826   , 0.47074836,
        0.53025323, 0.5709436 , 0.4605159 , 0.4656811 , 0.46918797,
        0.5084141 , 0.5405306 ],
       [0.5158798 , 0.48319402, 0.5002328 , 0.51826   , 0.47074836,
        0.53025323, 0.5709436 , 0.4605159 , 0.4656811 , 0.46918797,
        0.5084141 , 0.5405306 ],
       [0.5158798 , 0.48319402, 0.5002328 , 0.51826   , 0.47074836,
        0.53025323, 0.5709436 , 0.4605159 , 0.4656811 , 0.46918797,
        0.5084141 , 0.5405306 ],
       [0.5158798 , 0.48319402, 0.5002328 , 0.51826   , 0.47074836,
        0.53025323, 0.5709436 , 0.4605159 , 0.4656811 , 0.46918797,
        0.5084141 , 0.5405306 ],
       [0.5158798 , 0.48319402, 0.5002328 , 0.51826   , 0.47074836,
        0.53025323, 0.5709436 , 0.4605159 , 0.4656811 , 0.46918797,
        0.5084141 ,