# Basic simulation of an EIC core-based network on MNIST
- The code simulates a pseudo-feedforward network with constraints of a RRAM-based CIM chip.
- The goal is to achieve (near) SOTA on MNIST
- The most notable constraint is that the chip can only apply a non-linearity to one 256-bit-wide vector block at a time. This poses a challenge: we cannot directly implement layers of arbitrary size.
- Most MNIST benchmark models use multiple hidden layers, often with more than 1000 neurons each.
- To tackle this issue, we define a set of cores with learnable parameters that accumulate the individual 256-bit-long vectors.
- We call these modules `EICDense` and `Accumulator`

In [3]:
import os
from functools import partial
import jax
import jax.numpy as jnp
import optax
import flax
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from flax import linen as nn
from flax.linen import summary
from flax.training import train_state


%load_ext autoreload
%autoreload 2

## Helper Functions

- These functions implement a noisy binary activation with a straight-through estimator (STE) for the gradient.

In [5]:
import jax
import jax.numpy as jnp
from jax import vmap, jit

# print("Modified custom grad to STE")

## Binary thresholding function: output states {0, 1}
def binary_activation(x, threshold, noise_sd, key):
    """
    Noisy binary (step) activation.

    Zero-mean Gaussian noise with standard deviation ``noise_sd`` is added to
    ``x``; elements whose noisy value is below ``threshold`` become 0.0 and all
    remaining elements become 1.0.

    Args:
    x: jnp.ndarray, pre-activation values
    threshold: float, switching point of the step
    noise_sd: float, standard deviation of the injected Gaussian noise
    key: jax.random key used to draw the noise

    Returns:
    jnp.ndarray with the same shape as x, holding only 0.0 and 1.0
    """
    perturbation = jax.random.normal(key, shape=x.shape) * noise_sd
    noisy_x = x + perturbation

    # "below threshold" -> 0.0, everything else (including NaN) -> 1.0
    below = noisy_x < threshold
    return jnp.where(below, 0.0, 1.0)

## helper function
@jax.jit
def gaussian_cdf(x, mu, sigma):
    """Cumulative distribution function of N(mu, sigma**2) evaluated at ``x`` (jit-compiled)."""
    normal_dist = jax.scipy.stats.norm
    probability = normal_dist.cdf(x, loc=mu, scale=sigma)
    return probability

@jax.jit
def gaussian_pdf(x, mu, sigma):
    """Probability density of N(mu, sigma**2) evaluated at ``x`` (jit-compiled)."""
    normal_dist = jax.scipy.stats.norm
    density = normal_dist.pdf(x, loc=mu, scale=sigma)
    return density

@jax.jit
def bin_expected_state(x, threshold, noise_sd):
    """
    Expected output of the noisy binary activation.

    With noise ~ N(0, noise_sd**2), the probability that ``x + noise`` is at
    or above ``threshold`` equals the Gaussian CDF (mean 0, scale noise_sd)
    of ``x - threshold``; that probability is returned.
    """
    margin = x - threshold
    return gaussian_cdf(x=margin, mu=0, sigma=noise_sd)

@jax.custom_vjp
def custom_binary_gradient(x, threshold, noise_sd, key):
    """
    Noisy binary activation with a straight-through-estimator (STE) gradient.

    Forward: same as ``binary_activation`` (outputs 0.0 / 1.0).
    Backward: the incoming cotangent is passed through unchanged where
    ``|x - threshold| <= 1`` and zeroed elsewhere. ``threshold``, ``noise_sd``
    and ``key`` receive no gradient.
    """
    return binary_activation(x = x, threshold = threshold, noise_sd = noise_sd, key = key)

def custom_binary_gradient_fwd(x, threshold, noise_sd, key):
    """Forward rule: compute the binary output and keep (x, threshold, noise_sd) as residuals."""
    return custom_binary_gradient(x, threshold, noise_sd, key), (x, threshold, noise_sd)

def custom_binary_gradient_bwd(residuals, gradients):
    """
    Backward rule: clipped straight-through estimator.

    Args:
    residuals: tuple (x, threshold, noise_sd) saved by the forward rule
    gradients: cotangent of the activation output

    Returns:
    tuple of cotangents for (x, threshold, noise_sd, key); only x gets one.
    """
    x, threshold, _noise_sd = residuals
    # The forward step switches at `threshold`, so the unit-width pass-through
    # window is centred there (not at 0); for threshold == 0 this is the
    # classic |x| <= 1 STE. Alternative kept for experiments:
    # gaussian_pdf(x = x - threshold, mu = 0, sigma = noise_sd*10)
    pass_through = jnp.where(jnp.abs(x - threshold) <= 1.0, 1.0, 0.0)
    return (pass_through * gradients, None, None, None)

custom_binary_gradient.defvjp(custom_binary_gradient_fwd, custom_binary_gradient_bwd)

## Modules

In [6]:
class EICDense(nn.Module):
    """
    Pseudo-dense layer using EIC Cores.

    The input feature axis is split into 256-wide blocks and every
    (output block, input block) pair owns its own 256x256 weight matrix
    (one core). Partial products are NOT summed across input blocks here;
    that is left to a downstream module (e.g. ``Accumulator``).

    Attributes:
    in_size: int, number of input neurons. When it is not a multiple of 256,
        the input is zero-padded up to the next multiple of 256.
    out_size: int, number of output neurons; the layer produces
        max(out_size // 256, 1) output blocks of 256 values.

    Returns (from __call__):
    y: jnp.ndarray of shape (batch_size, out_blocks, in_blocks, 256)
    """

    in_size: int
    out_size: int

    def setup(self):
        """
        Set up dependent parameters
        """
        # number of 256-wide blocks at the output (floor, at least one)
        self.out_blocks = max(self.out_size // 256, 1)
        # number of 256-wide blocks needed to cover the whole input (ceil, at
        # least one); a partial last block is zero-padded in __call__
        self.in_blocks = max((self.in_size + 255) // 256, 1)

        self.num_cores = self.out_blocks * self.in_blocks  # number of cores required
        self.W = self.param(
            "weights",
            nn.initializers.xavier_normal(),
            (self.out_blocks, self.in_blocks, 256, 256)
        )

    def __call__(self, x):
        """
        Forward pass of the layer
        Args:
        x: jnp.ndarray (batch_size, in_size), input to the layer

        Returns:
        y: jnp.ndarray (batch_size, out_blocks, in_blocks, 256), per-core outputs
        """

        assert x.shape[-1] == self.in_size, f"Input shape is incorrect. Got {x.shape[-1]}, expected {self.in_size}"

        # zero-pad the feature axis so it splits evenly into 256-wide blocks;
        # padded entries contribute nothing to the products below
        pad = self.in_blocks * 256 - self.in_size
        if pad:
            pad_width = [(0, 0)] * (x.ndim - 1) + [(0, pad)]
            x = jnp.pad(x, pad_width)

        x_reshaped = x.reshape(x.shape[0], self.in_blocks, 256)  # organize x into blocks of 256 for every batch

        # experiment hooks (currently disabled):
        #   positive weights: jax.nn.softplus(self.W)
        #   quantization:     quantize_params(W, bits = 8)
        W = self.W

        # y[b, i, j, k] = sum_l W[i, j, k, l] * x[b, j, l]
        y = jnp.einsum("ijkl,bjl->bijk", W, x_reshaped)

        return y
    

# define the accumulator module
class Accumulator(nn.Module):
    """
    Accumulating the EICDense outputs.
    Since the EICDense generates pseudo-feedforward outputs, we use a learnable accumulation matrix that minimizes error
    between the true feedforward output and the EIC output.

    The partial products of every output block are first summed over the
    input blocks, then each output block is multiplied by its own 256x256
    accumulation matrix.

    Args:
        in_block_size: int, number of 256-sized blocks. This must equal x.shape[1]
            of the EICDense output (x.shape[0] is the batch dimension).
    """

    in_block_size: int

    def setup(self):
        """
        Set up the weights for the accumulator
        """

        self.W = self.param(
            "weights",
            nn.initializers.xavier_normal(),
            (self.in_block_size, 256, 256)
        )

    def __call__(self, x):
        """
        Forward pass of the accumulator
        Args:
        x: jnp.ndarray (batch_size, in_block_size, n_partial, 256), e.g. the
           EICDense output (batch_size, out_blocks, in_blocks, 256)

        Returns:
        y: jnp.ndarray (batch_size, in_block_size * 256), output of the accumulator
        """

        assert x.shape[1] == self.in_block_size, "Input shape is incorrect"

        # experiment hooks (currently disabled):
        #   positive weights: jax.nn.softplus(self.W)
        #   quantization:     quantize_params(W, bits = 8)
        W = self.W

        # sum partial products over input blocks: (b, i, j, k) -> (b, i, k)
        x = jnp.einsum("bijk->bik", x)
        # per-block 256x256 matrix-vector product, same (out, in) orientation as
        # EICDense: y[b, i, j] = sum_k W[i, j, k] * x[b, i, k].
        # (The previous "->bik" summed W over j, collapsing the matrix into an
        # elementwise scale by its column sums.)
        y = jnp.einsum("ijk,bik->bij", W, x)

        # flatten y before returning
        y = y.reshape((y.shape[0], -1))  # (batch_size, out_size)

        return y

class PermuteBlock(nn.Module):
    """
    Shuffles the input block-wise with two soft permutation matrices (pos and neg)
    and returns the difference of the two shuffled copies.

    Every core input of `core_input_size` elements is viewed as `num_slots`
    slots of `permute_block_size` contiguous elements. Two random permutations
    of the slots (drawn once from `seed`) are relaxed into soft permutation
    matrices via softmax(tau * P) along each row, where `tau` is a learnable
    temperature initialised to 10.

    Attributes:
    input_size: int, length of the input feature axis; must be a multiple of core_input_size
    permute_block_size: int, number of contiguous elements moved together
    core_input_size: int, width of one core input; must be a multiple of permute_block_size
    seed: int, PRNG seed for the two slot permutations (default reproduces the previous fixed seed)
    """

    input_size: int
    permute_block_size: int = 16 # previously 64
    core_input_size: int = 256
    seed: int = 1245

    def setup(self):
        """
        Set up permutation matrices
        """
        # slots per core input, e.g. 256 // 16 = 16
        self.num_slots = self.core_input_size // self.permute_block_size
        # core inputs per sample, e.g. input_size = 1024 -> 4
        self.num_subvectors = self.input_size // self.core_input_size

        self.tau = self.param(
            'tau',
            nn.initializers.constant(10),
            ()
        ) # temperature parameter

        # generate two independent permutation sequences
        key1, key2 = jax.random.split(jax.random.key(self.seed))
        p1 = jax.random.permutation(key1, self.num_slots)
        p2 = jax.random.permutation(key2, self.num_slots)

        # row-permuted, temperature-scaled identity, softened row-wise
        scaled_eye = jnp.eye(self.num_slots) * self.tau
        self.Ppos = jax.nn.softmax(scaled_eye[p1], axis = -1)
        self.Pneg = jax.nn.softmax(scaled_eye[p2], axis = -1)

    def __call__(self, x):
        """
        Apply permutations and return (xpos - xneg)
        Args:
        x: jnp.ndarray, input vector. Shape: (batch_size, input_size) e.g. (32, 2048)
        Returns:
        xpos - xneg: jnp.ndarray, difference of permuted inputs. Shape: (batch_size, input_size)
        """

        assert x.shape[-1] == self.input_size, f"Input shape is incorrect. Got {x.shape[-1]}, expected {self.input_size}"
        assert self.num_subvectors * self.num_slots * self.permute_block_size == self.input_size, (
            f"Inconsistent sizes: input_size={self.input_size} must be a multiple of "
            f"core_input_size={self.core_input_size}, which must be a multiple of "
            f"permute_block_size={self.permute_block_size}"
        )

        batch = x.shape[0]  # first dimension must be the batch size
        x = x.reshape(batch, self.num_subvectors, self.num_slots, self.permute_block_size)

        # mix slots within each core input: out[..., i, :] = sum_j P[i, j] * x[..., j, :]
        xpos = jnp.einsum('ij,bsjp->bsip', self.Ppos, x)
        xneg = jnp.einsum('ij,bsjp->bsip', self.Pneg, x)

        return (xpos - xneg).reshape((batch, self.input_size))
    
