In [23]:
import jax
import jax.numpy as jnp
import optax
import flax
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from flax import linen as nn
from EICDense import EICDense
from ShuffleBlocks import ShuffleBlocks
from Accumulator import Accumulator

from mnist_dataloader import *


In [26]:
class EICDense(nn.Module):
    in_size: int
    out_size: int
    threshold: float
    noise_sd: float
    activation: callable

    def setup(self):
        """
        Set up dependent parameters.
        """
        # Ensure valid in_size and out_size
        assert self.in_size > 0, f"Invalid in_size: {self.in_size}"
        assert self.out_size > 0, f"Invalid out_size: {self.out_size}"

        # Calculate block counts with a minimum of 1
        self.in_blocks = max(self.in_size // 256, 1)
        self.out_blocks = max(self.out_size // 256, 1)

        # Total cores required
        self.num_cores = self.out_blocks * self.in_blocks

        # Initialize weights
        self.W = self.param(
            "weights",
            lambda key, shape: nn.initializers.xavier_normal()(key, shape),
            (self.out_blocks, self.in_blocks, 256, 256),
        )

    def __call__(self, x):
        
        assert x.shape[-1] == self.in_size, f"Input shape is incorrect. Got {x.shape}, expected (batch_size, {self.in_size})."

        # Reshape input into blocks
        x_reshaped = x.reshape((x.shape[0], self.in_blocks, 256))

        # Apply weights
        W_pos = jax.nn.silu(self.W)
        #Block-wise matrix multiplication followed by summation over the input block index
        y = jnp.einsum("ijkl,bjl->bik", W_pos, x_reshaped)

        # Apply activation
        return self.activation(y)


In [27]:

class Accumulator(nn.Module):
    in_block_size: int
    threshold: float
    noise_sd: float
    activation: callable = None

    def setup(self):
        self.W = self.param(
            "weights",
            nn.initializers.xavier_normal(),
            (256, 256),  
        )

    def __call__(self, x):
        #print(f"x.shape in Accumulator: {x.shape}, in_block_size: {self.in_block_size}")

        # Ensure weights are positive
        W_pos = jax.nn.silu(self.W)

        
        #Matrix multiplication between 256 dimension vectors & weights
        y = jnp.einsum("bij,jk->bik", x, W_pos)
        #print(f"Shape after einsum in Accumulator: {y.shape}")

        # Flatten the output
        y = y.reshape((y.shape[0], -1))  # Combine blocks
        #print(f"Shape after flattening in Accumulator: {y.shape}")

        return y


In [28]:
class ShuffleBlocks(nn.Module):
    subvector_len: int
    slot_len: int
    key: jax.random.PRNGKey

    def __call__(self, x):
        #print(f"Input shape to ShuffleBlocks: {x.shape}")

        batch_size = x.shape[0]
        feature_size = x.shape[1]

        # Calculate number of subvectors
        num_subvectors = feature_size // self.subvector_len
        if num_subvectors == 0:
            pad_size = self.subvector_len - feature_size
            #print(f"Padding input by {pad_size} to match subvector_len.")
            x = jnp.pad(x, ((0, 0), (0, pad_size)))
            num_subvectors = 1

        slots_per_input = self.subvector_len // self.slot_len

        # Reshape input into blocks
        x_reshaped = x.reshape((batch_size, num_subvectors, slots_per_input, self.slot_len))
        #print(f"x_reshaped shape: {x_reshaped.shape}")

        # Shuffle blocks
        key, subkey = jax.random.split(self.key)
        keys = jax.random.split(key, num_subvectors)

        shuffled_blocks = [
            x_reshaped[:, i, jax.random.permutation(keys[i], slots_per_input, independent=True)]
            for i in range(num_subvectors)
        ]

        x_shuffled = jnp.concatenate([blocks.reshape(batch_size, -1) for blocks in shuffled_blocks], axis=1)

        # Ensure output matches the original feature size
        if x_shuffled.shape[1] != feature_size:
            x_shuffled = x_shuffled[:, :feature_size]  
        #print(f"x_shuffled shape: {x_shuffled.shape}")

        return x_shuffled

In [30]:
class PseudoFFNet(nn.Module):
    activation_fn: callable

    def setup(self):
        self.fc1 = EICDense(in_size=1024, out_size=2048, threshold=0.0, noise_sd=0.1, activation=self.activation_fn)
        self.ac1 = Accumulator(in_block_size=8, threshold=0.0, noise_sd=0.1, activation=self.activation_fn)
        self.shuffle1 = ShuffleBlocks(subvector_len=256, slot_len=64, key=jax.random.PRNGKey(0))

        self.fc2 = EICDense(in_size=2048, out_size=256, threshold=0.0, noise_sd=0.1, activation=self.activation_fn)
        self.ac2 = Accumulator(in_block_size=1, threshold=0.0, noise_sd=0.1, activation=self.activation_fn)
        self.shuffle2 = ShuffleBlocks(subvector_len=256, slot_len=64, key=jax.random.PRNGKey(1))

        self.fc3 = EICDense(in_size=256, out_size=10, threshold=0.0, noise_sd=0.1, activation=self.activation_fn)
        self.ac3 = Accumulator(in_block_size=1, threshold=0.0, noise_sd=0.1, activation=None)

    def __call__(self, x):
        #print(f"Input shape: {x.shape}")

        # Layer 1
        x = self.fc1(x)
        #print(f"Shape after fc1: {x.shape}")
        x = x.reshape((x.shape[0], -1, 256))
        x = self.ac1(x)
        #print(f"Shape after ac1: {x.shape}")
        x = self.shuffle1(x)
        #print(f"Shape after shuffle1: {x.shape}")

        # Layer 2
        x = x.reshape((x.shape[0], 2048))  
        x = self.fc2(x)
        #print(f"Shape after fc2: {x.shape}")
        x = x.reshape((x.shape[0], -1, 256))
        x = self.ac2(x)
        #print(f"Shape after ac2: {x.shape}")
        x = self.shuffle2(x)
        #print(f"Shape after shuffle2: {x.shape}")

        # Layer 3
        x = self.fc3(x)
        #print(f"Shape after fc3: {x.shape}")
        x = self.ac3(x)
        #print(f"Final output shape: {x.shape}")

        return x


In [21]:
# Initialize model parameters and optimizer
image_size =(32,32)
activation_fn = jax.nn.silu
data_path = '/Users/rairo/trident data'

# Load MNIST data
print("Preparing the MNIST dataset...")
(train_images, train_labels), (val_images, val_labels), (test_images, test_labels) = load_and_process_mnist(data_path, binarize=True, input_size=image_size)


Preparing the MNIST dataset...


In [34]:
from tqdm import tqdm
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import numpy as np


def prune_weights(params, threshold=0.01):
    return jax.tree_map(lambda p: jnp.where(jnp.abs(p) > threshold, p, 0), params)

# Training state
class TrainState(train_state.TrainState):
    pass

# Create training state
def create_train_state(rng, model, learning_rate):
    params = model.init(rng, jnp.ones([64, input_size]))['params']
    tx = optax.adam(learning_rate)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# Training step
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch["images"])
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"]).mean()
        return loss

    # Compute gradients and loss
    loss, grads = jax.value_and_grad(loss_fn)(state.params)

    # Apply gradients to update state
    state = state.apply_gradients(grads=grads)
    return state, loss

# Evaluation function
def evaluate(state, images, labels):
    logits = state.apply_fn({'params': state.params}, images)
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(predictions == labels)
    return accuracy * 100

# Training loop 
num_epochs = 10
batch_size = 64
train_losses = []
val_accuracies = []

# Initialize model, RNG, and training state
rng = jax.random.PRNGKey(0)
input_size = 1024
model = PseudoFFNet(activation_fn=jax.nn.silu)
state = create_train_state(rng, model, learning_rate=0.001)

# Training loop
for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = 0

    # Progress bar for each epoch
    with tqdm(get_train_batches(train_images, train_labels, batch_size), desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as pbar:
        for images, labels in pbar:
            batch = {"images": images, "labels": labels}
            state, loss = train_step(state, batch)
            epoch_loss += loss
            num_batches += 1
            pbar.set_postfix({"Batch Loss": loss})

    # Average loss for the epoch
    epoch_loss /= num_batches
    train_losses.append(epoch_loss)

    # Validation accuracy
    val_accuracy = evaluate(state, val_images, val_labels)
    val_accuracies.append(val_accuracy)
    print(f"Epoch {epoch + 1} completed. Validation Accuracy: {val_accuracy:.2f}%")

# Test evaluation after training
test_accuracy = evaluate(state, test_images, test_labels)
print(f"Final Test Accuracy: {test_accuracy:.2f}%")

# Post-training quantization
def quantize_params(params, num_bits=8):
    scale = 2 ** (num_bits - 1) - 1
    quantized_params = jax.tree_map(lambda p: jnp.round(p * scale) / scale, params)
    return quantized_params

quantized_params = quantize_params(state.params)
test_accuracy_quantized = evaluate(state.replace(params=quantized_params), test_images, test_labels)
print(f"Test Accuracy after Quantization: {test_accuracy_quantized:.2f}%")


Epoch 1/10: 782batch [03:34,  3.65batch/s, Batch Loss=0.077030614]


Epoch 1 completed. Validation Accuracy: 94.23%


Epoch 2/10: 782batch [03:53,  3.36batch/s, Batch Loss=0.0044169524]


Epoch 2 completed. Validation Accuracy: 96.15%


Epoch 3/10: 782batch [03:50,  3.40batch/s, Batch Loss=0.22260378]  


Epoch 3 completed. Validation Accuracy: 96.55%


Epoch 4/10: 782batch [03:59,  3.26batch/s, Batch Loss=0.00039530403]


Epoch 4 completed. Validation Accuracy: 96.66%


Epoch 5/10: 782batch [02:54,  4.49batch/s, Batch Loss=0.0014377253] 


Epoch 5 completed. Validation Accuracy: 96.36%


Epoch 6/10: 782batch [02:28,  5.28batch/s, Batch Loss=6.0852253e-05]


Epoch 6 completed. Validation Accuracy: 96.53%


Epoch 7/10: 782batch [02:33,  5.10batch/s, Batch Loss=1.9750762e-05]


Epoch 7 completed. Validation Accuracy: 97.51%


Epoch 8/10: 782batch [02:58,  4.38batch/s, Batch Loss=0.0036040063] 


Epoch 8 completed. Validation Accuracy: 97.37%


Epoch 9/10: 782batch [02:56,  4.42batch/s, Batch Loss=0.00031952775]


Epoch 9 completed. Validation Accuracy: 96.98%


Epoch 10/10: 782batch [02:28,  5.25batch/s, Batch Loss=0.00052177964] 


Epoch 10 completed. Validation Accuracy: 97.15%
Final Test Accuracy: 97.44%
Test Accuracy after Quantization: 97.47%


* EICDense 1: For each (i,j) block pair, matrix multiplication between the 256×256 weight block W_pos[i,j,::] and the corresponding input blockx_reshaped[:,j,:] & sum the results across all input blocks (j) for each output block (i).
* for Accumulator 1: For each input block (i),matrix multiplication between x[:,i,:] (256-dimensional vector) and W_pos (256×256). ==> new 256-dimensional output for each block
* EICDense 2: Matrix multiplication between each block of W_pos and corresponding input blocks & sum over the input block index (j)
* Accumulator 2: Multiplyx[:,0,:] (256-dimensional) with W_pos
* EICDense 3: Matrix multiplication for the single block of weights W_pos[0,0,:,:] & the input x[:,0,:].
* Accumulator 3: Multiply x[:,0,:] with W_pos, accumulating the final output. 

