In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
train_ds = tfds.load('mnist', split='train')
test_ds = tfds.load('mnist', split='test')

def data_normalize(ds):
    return ds.map(lambda sample: {
        'image': tf.cast(sample['image'], tf.float32) / 255.,
        'label': sample['label']
    })

train_ds = data_normalize(train_ds).shuffle(buffer_size=10, seed=42).batch(100).prefetch(1).take(1000)
test_ds = data_normalize(test_ds).shuffle(buffer_size=10, seed=42).batch(100).prefetch(1).take(1000)

total_batch = train_ds.cardinality().numpy()
total_tbatch = test_ds.cardinality().numpy()




In [3]:
import jax
import jax.numpy as jnp


class BatchNormLayer:

    def __init__(self, dims: int) -> None:
        self.gamma = jnp.ones((1, dims), dtype="float32")
        self.bias = jnp.zeros((1, dims), dtype="float32")

        self.running_mean_x = jnp.zeros(0)
        self.running_var_x = jnp.zeros(0)

        # forward params
        self.var_x = jnp.zeros(0)
        self.stddev_x = jnp.zeros(0)
        self.x_minus_mean = jnp.zeros(0)
        self.standard_x = jnp.zeros(0)
        self.num_examples = 0
        self.mean_x = jnp.zeros(0)
        self.running_avg_gamma = 0.9
        self.epsilon = 1e-6

        # backward params
        self.gamma_grad = jnp.zeros(0)
        self.bias_grad = jnp.zeros(0)

    def update_running_variables(self) -> None:
        is_mean_empty = jnp.array_equal(jnp.zeros(0), self.running_mean_x)
        is_var_empty = jnp.array_equal(jnp.zeros(0), self.running_var_x)
        if is_mean_empty != is_var_empty:
            raise ValueError("Mean and Var running averages should be "
                             "initilizaded at the same time")
        if is_mean_empty:
            self.running_mean_x = self.mean_x
            self.running_var_x = self.var_x
        else:
            gamma = self.running_avg_gamma
            self.running_mean_x = gamma * self.running_mean_x + \
                                  (1.0 - gamma) * self.mean_x
            self.running_var_x = gamma * self.running_var_x + \
                                 (1. - gamma) * self.var_x

    def forward(self, x: jnp.ndarray, train: bool = True) -> jnp.ndarray:
        self.num_examples = x.shape[0]
        if train:
            self.mean_x = jnp.mean(x, axis=0, keepdims=True)
            self.var_x = jnp.mean((x - self.mean_x) ** 2, axis=0, keepdims=True)
            self.update_running_variables()
        else:
            self.mean_x = self.running_mean_x.copy()
            self.var_x = self.running_var_x.copy()

        self.var_x += self.epsilon
        self.stddev_x = jnp.sqrt(self.var_x)
        self.x_minus_mean = x - self.mean_x
        self.standard_x = self.x_minus_mean / self.stddev_x
        return self.gamma * self.standard_x + self.bias

    def backward(self, grad_input: jnp.ndarray) -> jnp.ndarray:
        standard_grad = grad_input * self.gamma

        var_grad = jnp.sum(standard_grad * self.x_minus_mean * -0.5 * self.var_x ** (-3/2),
                          axis=0, keepdims=True)
        stddev_inv = 1 / self.stddev_x
        aux_x_minus_mean = 2 * self.x_minus_mean / self.num_examples

        mean_grad = (jnp.sum(standard_grad * -stddev_inv, axis=0,
                            keepdims=True) +
                            var_grad * jnp.sum(-aux_x_minus_mean, axis=0,
                            keepdims=True))

        self.gamma_grad = jnp.sum(grad_input * self.standard_x, axis=0,
                                 keepdims=True)
        self.bias_grad = jnp.sum(grad_input, axis=0, keepdims=True)

        return standard_grad * stddev_inv + var_grad * aux_x_minus_mean + \
               mean_grad / self.num_examples

    def apply_gradients(self, learning_rate: float) -> None:
        self.gamma -= learning_rate * self.gamma_grad
        self.bias -= learning_rate * self.bias_grad

In [10]:
bn = BatchNormLayer(dims=4)

for batch in train_ds.as_numpy_iterator():
    x = batch['image']
    y = batch['label']
    break

o = bn.forward(x)
g = bn.backward(jnp.ones_like(o))
o.shape

(100, 28, 28, 4)

In [30]:
import optax
import flax
import flax.linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x, train=True):
        x = nn.Conv(16, (3, 3), strides=2, padding='SAME')(x)
        return x

def create_state(lr):
    convnet = CNN()
    params = convnet.init(jax.random.key(42), x)['params']
    tx = optax.sgd(lr)
    return flax.training.train_state.TrainState.create(apply_fn=convnet.apply, params=params, tx=tx)

vmapped = nn.vmap(create_state, in_axes=(None, 0), out_axes=0, variable_axes={'params': None})
vmapped


<function __main__.create_state(lr)>