# Running convolution

In [None]:
import jax
import jax.numpy as jnp
from jax.lax import conv, conv_general_dilated
import numpy as np


In [None]:
H, W = 6, 6
I, O = 1, 1
window_strides = (1, 1)

x = jnp.ones((10, H, W, I))       # NHWC
kernel = jnp.ones((3, 3, I, O)) * 0.1    # HWIO

x = jnp.transpose(x, [0, 3, 1, 2])  # NCHW
kernel = jnp.transpose(kernel, [3, 2, 0, 1])    # OIHW

output = conv(x, kernel, window_strides=window_strides, padding='VALID')  # NCHW
# print(output.shape)

output = jnp.transpose(output, [0, 2, 3, 1])  # NHWC
# print(output.shape)

# print(x[0].reshape((H, W)))
# print(kernel[0])
# print(output[0].reshape((H, W)))

# Gradient on conv

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
train_ds = tfds.load('mnist', split='train')
test_ds = tfds.load('mnist', split='test')

def data_normalize(ds):
    return ds.map(lambda sample: {
        'image': tf.cast(sample['image'], tf.float32) / 255.,
        'label': sample['label']
    })

train_ds = data_normalize(train_ds).shuffle(buffer_size=10, seed=42).batch(100).prefetch(1).take(1000)
test_ds = data_normalize(test_ds).shuffle(buffer_size=10, seed=42).batch(100).prefetch(1).take(1000)

total_batch = train_ds.cardinality().numpy()
total_tbatch = test_ds.cardinality().numpy()

In [81]:
import optax
import flax.linen as nn



def net(kernels, x):

    kernel1 = kernels[0]
    kernel2 = kernels[1]
    fc = kernels[2]

    strides = (1, 1)
    padding = 'SAME'

    x = jnp.transpose(x, [0, 3, 1, 2])
    kernel1 = jnp.transpose(kernel1, [3, 2, 0, 1])
    kernel2 = jnp.transpose(kernel2, [3, 2, 0, 1])

    x = conv(x, kernel1, window_strides=strides, padding=padding)
    x = jax.nn.relu(x)
    # (100, 16, 28, 28) = NCHW


    x = conv(x, kernel2, window_strides=strides, padding=padding)
    x = jax.nn.relu(x)
    x = jnp.transpose(x, [0, 2, 3, 1])

    x = nn.avg_pool(x, window_shape=(2, 2), padding='SAME', strides=(2, 2))

    x = x.reshape((x.shape[0], -1))
    x = jnp.dot(x, fc)
    return jax.nn.softmax(x)

def loss(kernels, x, y):
    x = net(kernels, x)
    l = optax.softmax_cross_entropy_with_integer_labels(jnp.clip(x, 1e-10, 1.), y)
    return l.mean()


lr = 0.001
num_epochs = 100

kernel1 = jax.nn.initializers.xavier_normal()(jax.random.key(42), (3, 3, 1, 16))
kernel2 = jax.nn.initializers.xavier_normal()(jax.random.key(43), (3, 3, 16, 32))
fc = jax.nn.initializers.xavier_normal()(jax.random.key(44), (32 * 14 * 14, 10))

kernels = (kernel1, kernel2, fc)


def update(kernels, bn, x, y):
    l = loss(kernels, x, y)



for i in range(num_epochs):
    print(f"\nEpoch {i} =======================================")
    for batch in train_ds.as_numpy_iterator():
        x = batch['image']
        y = batch['label']
        value, grad = jax.value_and_grad(loss)(kernels, x, y)

        kernel1 = kernel1 - lr * grad[0]
        kernel2 = kernel2 - lr * grad[1]
        fc = fc - lr * grad[2]
        kernels = (kernel1, kernel2, fc)

    print("\tkernel1: ", kernel1.min(), kernel1.max(), kernel1.mean())
    print("\tkernel2: ", kernel2.min(), kernel2.max(), kernel2.mean())
    print("\tfc: ", fc.min(), fc.max(), fc.mean())
    print("\tgrad1: ", grad[0].min(), grad[0].max(), grad[0].mean())
    print("\tgrad2: ", grad[1].min(), grad[1].max(), grad[1].mean())
    print("\tgrad3: ", grad[2].min(), grad[2].max(), grad[2].mean())
    print("\tloss: ", value)




	kernel1:  -0.2579745 0.24605475 -0.00069521513
	kernel2:  -0.1540713 0.15285209 0.0004105228
	fc:  -0.040572066 0.040740248 -4.7104633e-05
	grad1:  -0.0021712063 0.0014822792 -0.00013415546
	grad2:  -0.0022889 0.0011049211 -2.571689e-06
	grad3:  -0.0043050298 0.0030392187 -2.1796608e-13
	loss:  2.3015804



KeyboardInterrupt: 

In [None]:
import optax
import flax.linen as nn

def bn(x, gamma, beta, running_mean, running_var, train=True):
    if train:
        mean_x = jnp.mean(x, axis=0, keepdims=True)
        var_x = jnp.mean((x - mean_x) ** 2, axis=0, keepdims=True)
        


def net(kernels, x):

    kernel1 = kernels[0]
    gamma1 = kernels[1]
    beta1 = kernels[2]
    kernel2 = kernels[3]
    gamma2 = kernels[4]
    beta2 = kernels[5]
    fc = kernels[6]

    strides = (1, 1)
    padding = 'SAME'

    x = jnp.transpose(x, [0, 3, 1, 2])
    kernel1 = jnp.transpose(kernel1, [3, 2, 0, 1])
    kernel2 = jnp.transpose(kernel2, [3, 2, 0, 1])

    x = conv(x, kernel1, window_strides=strides, padding=padding)
    x = jax.nn.relu(x)
    x = conv(x, kernel2, window_strides=strides, padding=padding)
    x = jax.nn.relu(x)
    x = jnp.transpose(x, [0, 2, 3, 1])

    x = nn.avg_pool(x, window_shape=(2, 2), padding='SAME', strides=(2, 2))

    x = x.reshape((x.shape[0], -1))
    x = jnp.dot(x, fc)
    return jax.nn.softmax(x)

def loss(kernels, x, y):
    x = net(kernels, x)
    l = optax.softmax_cross_entropy_with_integer_labels(jnp.clip(x, 1e-10, 1.), y)
    return l.mean()


lr = 0.001
num_epochs = 100

kernel1 = jax.nn.initializers.xavier_normal()(jax.random.key(42), (3, 3, 1, 16))
gamma1 = jnp.ones((1, 16, 1, 1))
beta1 = jnp.zeros((1, 16, 1, 1))
kernel2 = jax.nn.initializers.xavier_normal()(jax.random.key(43), (3, 3, 16, 32))
gamma2 = jnp.ones((1, 32, 1, 1))
beta2 = jnp.zeros((1, 32, 1, 1))
fc = jax.nn.initializers.xavier_normal()(jax.random.key(44), (32 * 14 * 14, 10))

kernels = (kernel1, gamma1, beta1, kernel2, gamma2, beta2, fc)


def update(kernels, x, y):
    l = loss(kernels, x, y)



for i in range(num_epochs):
    print(f"\nEpoch {i} =======================================")
    for batch in train_ds.as_numpy_iterator():
        x = batch['image']
        y = batch['label']
        value, grad = jax.value_and_grad(loss)(kernels, x, y)

        kernel1 = kernel1 - lr * grad[0]
        kernel2 = kernel2 - lr * grad[1]
        fc = fc - lr * grad[2]
        kernels = (kernel1, kernel2, fc)

    print("\tkernel1: ", kernel1.min(), kernel1.max(), kernel1.mean())
    print("\tkernel2: ", kernel2.min(), kernel2.max(), kernel2.mean())
    print("\tfc: ", fc.min(), fc.max(), fc.mean())
    print("\tgrad1: ", grad[0].min(), grad[0].max(), grad[0].mean())
    print("\tgrad2: ", grad[1].min(), grad[1].max(), grad[1].mean())
    print("\tgrad3: ", grad[2].min(), grad[2].max(), grad[2].mean())
    print("\tloss: ", value)



In [77]:
def batch_norm(X, deterministic, gamma, beta, moving_mean, moving_var, eps,
               momentum):
    # Use `deterministic` to determine whether the current mode is training
    # mode or prediction mode
    if deterministic:
        # In prediction mode, use mean and variance obtained by moving average
        # `linen.Module.variables` have a `value` attribute containing the array
        X_hat = (X - moving_mean.value) / jnp.sqrt(moving_var.value + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # When using a fully connected layer, calculate the mean and
            # variance on the feature dimension
            mean = X.mean(axis=0)
            var = ((X - mean) ** 2).mean(axis=0)
        else:
            # When using a two-dimensional convolutional layer, calculate the
            # mean and variance on the channel dimension (axis=1). Here we
            # need to maintain the shape of `X`, so that the broadcasting
            # operation can be carried out later
            mean = X.mean(axis=(0, 2, 3), keepdims=True)
            var = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
        # In training mode, the current mean and variance are used
        X_hat = (X - mean) / jnp.sqrt(var + eps)
        # Update the mean and variance using moving average
        moving_mean.value = momentum * moving_mean.value + (1.0 - momentum) * mean
        moving_var.value = momentum * moving_var.value + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # Scale and shift
    return Y

class BatchNorm(nn.Module):
    # `num_features`: the number of outputs for a fully connected layer
    # or the number of output channels for a convolutional layer.
    # `num_dims`: 2 for a fully connected layer and 4 for a convolutional layer
    # Use `deterministic` to determine whether the current mode is training
    # mode or prediction mode
    num_features: int
    num_dims: int
    deterministic: bool = False

    @nn.compact
    def __call__(self, X):
        if self.num_dims == 2:
            shape = (1, self.num_features)
        else:
            shape = (1, 1, 1, self.num_features)

        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        gamma = self.param('gamma', jax.nn.initializers.ones, shape)
        beta = self.param('beta', jax.nn.initializers.zeros, shape)

        # The variables that are not model parameters are initialized to 0 and
        # 1. Save them to the 'batch_stats' collection
        moving_mean = self.variable('batch_stats', 'moving_mean', jnp.zeros, shape)
        moving_var = self.variable('batch_stats', 'moving_var', jnp.ones, shape)
        Y = batch_norm(X, self.deterministic, gamma, beta,
                       moving_mean, moving_var, eps=1e-5, momentum=0.9)

        return Y

a = jnp.array([jnp.ones((28, 28, 1))*1, jnp.ones((28, 28, 1))*2, jnp.ones((28, 28, 1))*3, jnp.ones((28, 28, 1))*4, jnp.ones((28, 28, 1))*5])

b = BatchNorm(1, 4, True)
v = b.init(jax.random.key(42), a)
v

{'params': {'gamma': Array([[[[1.]]]], dtype=float32),
  'beta': Array([[[[0.]]]], dtype=float32)},
 'batch_stats': {'moving_mean': Array([[[[0.]]]], dtype=float32),
  'moving_var': Array([[[[1.]]]], dtype=float32)}}

In [None]:
class BatchNormLayer:

    def __init__(self, dims: int) -> None:
        self.gamma = np.ones((1, dims), dtype="float32")
        self.bias = np.zeros((1, dims), dtype="float32")

        self.running_mean_x = np.zeros(0)
        self.running_var_x = np.zeros(0)

        # forward params
        self.var_x = np.zeros(0)
        self.stddev_x = np.zeros(0)
        self.x_minus_mean = np.zeros(0)
        self.standard_x = np.zeros(0)
        self.num_examples = 0
        self.mean_x = np.zeros(0)
        self.running_avg_gamma = 0.9
        self.epsilon = 1e-6

        # backward params
        self.gamma_grad = np.zeros(0)
        self.bias_grad = np.zeros(0)

    def update_running_variables(self) -> None:
        is_mean_empty = np.array_equal(np.zeros(0), self.running_mean_x)
        is_var_empty = np.array_equal(np.zeros(0), self.running_var_x)
        if is_mean_empty != is_var_empty:
            raise ValueError("Mean and Var running averages should be "
                             "initilizaded at the same time")
        if is_mean_empty:
            self.running_mean_x = self.mean_x
            self.running_var_x = self.var_x
        else:
            gamma = self.running_avg_gamma
            self.running_mean_x = gamma * self.running_mean_x + \
                                  (1.0 - gamma) * self.mean_x
            self.running_var_x = gamma * self.running_var_x + \
                                 (1. - gamma) * self.var_x

    def forward(self, x: np.ndarray, train: bool = True) -> np.ndarray:
        self.num_examples = x.shape[0]
        if train:
            self.mean_x = np.mean(x, axis=0, keepdims=True)
            self.var_x = np.mean((x - self.mean_x) ** 2, axis=0, keepdims=True)
            self.update_running_variables()
        else:
            self.mean_x = self.running_mean_x.copy()
            self.var_x = self.running_var_x.copy()

        self.var_x += self.epsilon
        self.stddev_x = np.sqrt(self.var_x)
        self.x_minus_mean = x - self.mean_x
        self.standard_x = self.x_minus_mean / self.stddev_x
        return self.gamma * self.standard_x + self.bias

    def backward(self, grad_input: np.ndarray) -> np.ndarray:
        standard_grad = grad_input * self.gamma

        var_grad = np.sum(standard_grad * self.x_minus_mean * -0.5 * self.var_x ** (-3/2),
                          axis=0, keepdims=True)
        stddev_inv = 1 / self.stddev_x
        aux_x_minus_mean = 2 * self.x_minus_mean / self.num_examples

        mean_grad = (np.sum(standard_grad * -stddev_inv, axis=0,
                            keepdims=True) +
                            var_grad * np.sum(-aux_x_minus_mean, axis=0,
                            keepdims=True))

        self.gamma_grad = np.sum(grad_input * self.standard_x, axis=0,
                                 keepdims=True)
        self.bias_grad = np.sum(grad_input, axis=0, keepdims=True)

        return standard_grad * stddev_inv + var_grad * aux_x_minus_mean + \
               mean_grad / self.num_examples

    def apply_gradients(self, learning_rate: float) -> None:
        self.gamma -= learning_rate * self.gamma_grad
        self.bias -= learning_rate * self.bias_grad

bn = BatchNormLayer(dims=4)

o = bn.forward(x)
bn.backward(np.ones_like(o))

In [None]:
class BatchNormLayer:

    def __init__(self, dims: int) -> None:
        self.gamma = jnp.ones((1, dims), dtype="float32")
        self.bias = jnp.zeros((1, dims), dtype="float32")

        self.running_mean_x = jnp.zeros(0)
        self.running_var_x = jnp.zeros(0)

        # forward params
        self.var_x = jnp.zeros(0)
        self.stddev_x = jnp.zeros(0)
        self.x_minus_mean = jnp.zeros(0)
        self.standard_x = jnp.zeros(0)
        self.num_examples = 0
        self.mean_x = jnp.zeros(0)
        self.running_avg_gamma = 0.9
        self.epsilon = 1e-6

        # backward params
        self.gamma_grad = jnp.zeros(0)
        self.bias_grad = jnp.zeros(0)

    def update_running_variables(self) -> None:
        is_mean_empty = jnp.array_equal(jnp.zeros(0), self.running_mean_x)
        is_var_empty = jnp.array_equal(jnp.zeros(0), self.running_var_x)
        if is_mean_empty != is_var_empty:
            raise ValueError("Mean and Var running averages should be "
                             "initilizaded at the same time")
        if is_mean_empty:
            self.running_mean_x = self.mean_x
            self.running_var_x = self.var_x
        else:
            gamma = self.running_avg_gamma
            self.running_mean_x = gamma * self.running_mean_x + \
                                  (1.0 - gamma) * self.mean_x
            self.running_var_x = gamma * self.running_var_x + \
                                 (1. - gamma) * self.var_x

    def forward(self, x: jnp.ndarray, train: bool = True) -> jnp.ndarray:
        self.num_examples = x.shape[0]
        if train:
            self.mean_x = jnp.mean(x, axis=0, keepdims=True)
            self.var_x = jnp.mean((x - self.mean_x) ** 2, axis=0, keepdims=True)
            self.update_running_variables()
        else:
            self.mean_x = self.running_mean_x.copy()
            self.var_x = self.running_var_x.copy()

        self.var_x += self.epsilon
        self.stddev_x = jnp.sqrt(self.var_x)
        self.x_minus_mean = x - self.mean_x
        self.standard_x = self.x_minus_mean / self.stddev_x
        return self.gamma * self.standard_x + self.bias

    def backward(self, grad_input: jnp.ndarray) -> jnp.ndarray:
        standard_grad = grad_input * self.gamma

        var_grad = jnp.sum(standard_grad * self.x_minus_mean * -0.5 * self.var_x ** (-3/2),
                          axis=0, keepdims=True)
        stddev_inv = 1 / self.stddev_x
        aux_x_minus_mean = 2 * self.x_minus_mean / self.num_examples

        mean_grad = (jnp.sum(standard_grad * -stddev_inv, axis=0,
                            keepdims=True) +
                            var_grad * jnp.sum(-aux_x_minus_mean, axis=0,
                            keepdims=True))

        self.gamma_grad = jnp.sum(grad_input * self.standard_x, axis=0,
                                 keepdims=True)
        self.bias_grad = jnp.sum(grad_input, axis=0, keepdims=True)

        return standard_grad * stddev_inv + var_grad * aux_x_minus_mean + \
               mean_grad / self.num_examples

    def apply_gradients(self, learning_rate: float) -> None:
        self.gamma -= learning_rate * self.gamma_grad
        self.bias -= learning_rate * self.bias_grad

bn = BatchNormLayer(dims=4)

o = bn.forward(x)
bn.backward(jnp.ones_like(o))