# Goals of the notebook
- Get familiar with Flax.
- Be able to instantiate modules with activations given by a non-linearity.
- Try to train on MNIST using this setup.
  

In [2]:
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
from flax import nnx
import optax
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm.notebook import tqdm
from collections import defaultdict

# Plot styling: use the sans-serif family and prefer Arial within it.
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
})


In [3]:
## define the activation and custom gradients

def ternary_activation(x, thresholds, noise_sd, key=None):
    """Noisy ternary activation: map each element of ``x`` to -1, 0 or +1.

    Gaussian noise scaled by ``noise_sd`` is added to ``x`` and the noisy
    value is then compared against the two thresholds.

    Args:
        x: Input array (must expose ``.shape``; the noise is drawn with it).
        thresholds: Pair ``(t1, t2)``. Noisy values below ``t1`` map to
            ``-1.0``, values above ``t2`` map to ``1.0`` and everything else
            maps to ``0.0``.
        noise_sd: Standard deviation of the injected Gaussian noise.
        key: JAX PRNG key used to draw the noise. When omitted,
            ``jax.random.key(0)`` is used, which means every call without an
            explicit key draws exactly the same noise. Pass a fresh key per
            call to get independent noise samples.

    Returns:
        Array holding the values ``-1.0``, ``0.0`` or ``1.0``.
    """
    # Build the fallback key lazily: a call in the default argument would be
    # evaluated once at definition time and initialise the JAX backend there.
    if key is None:
        key = jax.random.key(0)

    # Keep the original split so the noise stream is unchanged; only the
    # second sub-key is consumed.
    _, noise_key = jax.random.split(key, 2)

    # Inject zero-mean Gaussian noise before thresholding.
    noise = jax.random.normal(key=noise_key, shape=x.shape) * noise_sd
    noisy_x = x + noise

    # Three-way threshold: below t1 -> -1, above t2 -> +1, otherwise 0.
    t1, t2 = thresholds
    return jnp.where(
        noisy_x < t1,
        -1.0,
        jnp.where(noisy_x > t2, 1.0, 0.0),
    )

## helper function
@jax.jit
def gaussian_cdf(x, mu, sigma):
    """Evaluate the normal CDF with mean ``mu`` and scale ``sigma`` at ``x``."""
    normal = jax.scipy.stats.norm
    return normal.cdf(x, loc=mu, scale=sigma)

@jax.jit
def gaussian_pdf(x, mu, sigma):
    """Evaluate the normal PDF with mean ``mu`` and scale ``sigma`` at ``x``."""
    normal = jax.scipy.stats.norm
    return normal.pdf(x, loc=mu, scale=sigma)

@jax.jit
def expected_state(x, thresholds, noise_sd):
    """Return P(noisy x > t2) - P(noisy x < t1) for zero-mean Gaussian noise.

    ``thresholds`` is the pair ``(t1, t2)`` and ``noise_sd`` the standard
    deviation of the noise added to ``x``.
    """
    lower, upper = thresholds
    # Probability that the noisy input ends up above the upper threshold.
    prob_above = 1 - gaussian_cdf(x=upper - x, mu=0, sigma=noise_sd)
    # Probability that the noisy input ends up below the lower threshold.
    prob_below = gaussian_cdf(x=lower - x, mu=0, sigma=noise_sd)
    return prob_above - prob_below

# custom gradients
@jax.custom_vjp
def custom_ternary_gradient(x, thresholds, noise_sd):
    """Ternary activation wrapped in ``jax.custom_vjp``.

    The primal computation is just ``ternary_activation``; the forward and
    backward rules are attached to this function through ``defvjp``.
    """
    activation = ternary_activation(
        x=x,
        thresholds=thresholds,
        noise_sd=noise_sd,
    )
    return activation

def custom_ternary_gradient_fwd(x, thresholds, noise_sd):
    """Forward rule: return the primal output and the residuals for backward."""
    primal_out = custom_ternary_gradient(x, thresholds, noise_sd)
    # The backward rule needs all three inputs, so keep them as residuals.
    saved = (x, thresholds, noise_sd)
    return primal_out, saved

def custom_ternary_gradients_bwd(residuals, grads):
    """Backward rule: scale the incoming cotangent by a smooth surrogate slope.

    The slope is the sum of the Gaussian densities (zero mean, scale
    ``noise_sd``) evaluated at ``t1 - x`` and ``t2 - x``. Only ``x`` receives
    a cotangent; ``thresholds`` and ``noise_sd`` get ``None``.
    """
    x, (lower, upper), noise_sd = residuals  # values saved by the forward rule
    density_lower = gaussian_pdf(x=lower - x, mu=0, sigma=noise_sd)
    density_upper = gaussian_pdf(x=upper - x, mu=0, sigma=noise_sd)
    slope = density_lower + density_upper
    return (slope * grads, None, None)

# Register the forward and backward rules as the custom VJP of `custom_ternary_gradient`.
custom_ternary_gradient.defvjp(custom_ternary_gradient_fwd, custom_ternary_gradients_bwd)