# Goals of the notebook
- Get familirized with flax.
- Be able to instantiate modules with activations given by a non-linearity
- Make a simple hierarchical parameter data structure whcih can be ported to MEC conectivity.

## Iteration 1:
1. Consider two cores and a lookup table.
2. The cores each have one fully connected layer.
3. The lookup table provides connectivity between cores.
4. Neurons in each core are binary stochastic.
5. Try to train MNIST on these cores.


In [1]:
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
from flax import nnx
import optax
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm.notebook import tqdm
from collections import defaultdict

plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'


In [1]:
## define binary thresholding function: states [0, 1]
def binary_activation(x, threshold, noise_sd, key):
    """
    Binary activation function
    """
    # key, key2 = jax.random.split(key, 2)

    # generate noise
    noise = jax.random.normal(key, shape = x.shape) * noise_sd

    # inject noise
    x = x + noise

    s = jnp.where(
        x < threshold, 0.0, 1.0
    )

    return s

## helper function
@jax.jit
def gaussian_cdf(x, mu, sigma):
    return jax.scipy.stats.norm.cdf(x, loc = mu, scale = sigma)

@jax.jit
def gaussian_pdf(x, mu, sigma):
    return jax.scipy.stats.norm.pdf(x, loc = mu, scale = sigma)

@jax.jit
def expected_state(x, thresholds, noise_sd):
    t1, t2 = thresholds
    e = ((1 - gaussian_cdf(x = t2 - x, mu = 0, sigma = noise_sd)) - gaussian_cdf(x = t1 - x, mu = 0, sigma = noise_sd))
    return e


## cuatom gradient for binary activation
@jax.custom_vjp
def custom_binary_gradient(x, threshold, noise_sd, key):
    return binary_activation(x = x, threshold = threshold, noise_sd = noise_sd, key = key)

def custom_binary_gradient_fwd(x, threshold, noise_sd, key):
    return custom_binary_gradient(x, threshold, noise_sd, key), (x, threshold, noise_sd)

def custom_binary_gradient_bwd(residuals, gradients):
    x, threshold, noise_sd = residuals
    grad = gaussian_pdf(x = x - threshold, mu = 0, sigma = noise_sd)
    return (grad*gradients, None, None, None)

custom_binary_gradient.defvjp(custom_binary_gradient_fwd, custom_binary_gradient_bwd)

NameError: name 'jax' is not defined

In [33]:
## define the dense layer module
class Dense(nnx.Module):
    """
    Define a single core.
    For a start this is one layer.
    """

    def __init__(self, num_inputs: int,
                num_outputs: int, 
                key: jax.random.key):
        
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.key = key

        ## define the weights and biases
        self.weights = nnx.Param(jax.random.normal(key, (num_inputs, num_outputs)) * jnp.sqrt(2/(num_inputs + num_outputs)))

        self.biases = nnx.Param(jnp.zeros((num_outputs,)))

    def __call__(self, x: jnp.ndarray):
        # assert x.shape[-1] == self.num_inputs, "Input shape does not match the number of inputs"
        return x @ self.weights + self.biases
    

class Core(nnx.Module):
    """
    Define a single core. One layer with binary activation
    """

    def __init__(self, num_inputs: int,
                 num_neurons: int,
                 threshold: float,
                 noise_sd: float,
                key: jax.random.key
                 ):
        
        self.num_inputs = num_inputs
        self.num_neurons = num_neurons
        self.threshold = threshold
        self.noise_sd = noise_sd
        self.key = key
        self.dense = Dense(self.num_inputs, self.num_neurons, self.key)


    def __call__(self, 
                 x: jnp.array,
                 ):
        
        ## pass through the dense layer
        x = self.dense(x)
        self.key, self.subkey = jax.random.split(self.key, 2)
        x = custom_binary_gradient(x, self.threshold, self.noise_sd, self.subkey)
        return x
    



## TESTING..
    
# ## test if the dense layer works
# key = jax.random.PRNGKey(0)
# d1 = Dense(2, 2, key)
# x = jnp.ones((2,))#jax.random.normal(key, (2,))
# print(d1.weights.value)
# print(x)
# print(d1(x))

# plt.figure()
# plt.imshow(d1.weights.value, cmap = 'coolwarm')
# plt.colorbar()

# plt.figure()
# plt.plot(d1.biases.value)

## core
key = jax.random.PRNGKey(0)
c1 = Core(num_inputs=2, num_neurons=2, threshold=-1, noise_sd=0.1, key=key)
x = jax.random.normal(key, (2,))
print(x)
print(c1.dense.weights.value)
print(c1(x))


    
    






[-0.78476596  0.85644484]
[[ 1.2841669  -0.5337844 ]
 [ 0.24033786 -0.3781857 ]]
[1. 1.]
