# Baseline Model with Partial Sums
- Construct a baseline model with partial sums to compare against ScRRAMBLe.
- Refer to older scripts.

_Created on: 08/06/2025_

In [1]:
import jax
import math
import jax.numpy as jnp
import optax
import flax
from flax import nnx
from flax.nnx.nn import initializers
from typing import Callable
import json
import os
import pickle
import numpy as np
from collections import defaultdict
from functools import partial
from tqdm import tqdm
from datetime import date

import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
# mpl.use('Agg')  # Use a non-interactive backend for matplotlib.

from models import ScRRAMBLeCapsLayer

from utils.activation_functions import quantized_relu_ste, squash, qrelu
# from utils.loss_functions import margin_loss
from utils import ScRRAMBLe_routing, intercore_connectivity, load_and_augment_mnist


import tensorflow_datasets as tfds  # TFDS to download MNIST.
import tensorflow as tf  # TensorFlow / `tf.data` operations.

%load_ext autoreload
%autoreload 2

2025-08-06 19:02:29.423407: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754532149.435797 1922137 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754532149.439827 1922137 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754532149.450166 1922137 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754532149.450178 1922137 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754532149.450180 1922137 computation_placer.cc:177] computation placer alr

In [2]:
class PartialSumsLayer(nnx.Module):
    """
    Single partial-sums layer built from fixed-size cores.

    A feedforward layer of size ``in_size -> out_size`` is tiled into
    ``out_blocks x in_blocks`` cores, where each core is a trainable
    ``columns_per_core x columns_per_core`` weight block. For every output
    block the contributions (partial sums) of all input blocks are
    accumulated, and the activation function is applied to the result.

    Args:
        in_size: Number of input features of the layer.
        out_size: Number of output features of the layer.
        rngs: ``nnx.Rngs`` container; its ``params`` stream seeds the weights.
        activation_function: Callable applied to each output block.
        columns_per_core: Side length of one (square) core.
    """

    def __init__(self,
                 in_size: int,
                 out_size: int,
                 rngs: nnx.Rngs,
                 activation_function: Callable,
                 columns_per_core: int = 256
                 ):

        self.in_size = in_size
        self.out_size = out_size
        self.activation_function = activation_function
        self.columns_per_core = columns_per_core
        self.rngs = rngs

        # number of cores required along the input / output dimension
        # (rounded up, so the last core may be only partially used)
        self.in_blocks = math.ceil(in_size / columns_per_core)
        self.out_blocks = math.ceil(out_size / columns_per_core)

        # initialize parameters: one (columns_per_core x columns_per_core)
        # weight block per (output block, input block) pair
        initializer = initializers.glorot_normal()
        self.W = nnx.Param(
            initializer(self.rngs.params(), (self.out_blocks, self.in_blocks, self.columns_per_core, self.columns_per_core))
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        """
        Forward pass for a single example (no batch dimension; use vmap).

        Args:
            x: Input holding either ``in_size`` elements or
               ``in_blocks * columns_per_core`` elements (e.g. the block-shaped
               output of a previous layer). Any shape is accepted; the input is
               flattened first.

        Returns:
            Array of shape ``(out_blocks, columns_per_core)``.
        """

        x_flat = x.reshape(-1)

        # Zero-pad the input up to a whole number of cores when in_size is not
        # a multiple of columns_per_core. Inputs of any other size fall through
        # to the reshape below and fail there, exactly as before.
        padded_size = self.in_blocks * self.columns_per_core
        if x_flat.shape[0] == self.in_size and padded_size > self.in_size:
            x_flat = jnp.pad(x_flat, (0, padded_size - self.in_size))

        x_reshape = x_flat.reshape(self.in_blocks, self.columns_per_core)

        # compute the partial sums: for each output block i, sum over the input
        # blocks j (and the input columns l) of W[i, j] @ x[j]
        y = jnp.einsum('ijkl,jl->ik', self.W, x_reshape)

        # apply the activation function to each output block
        y = jax.vmap(self.activation_function, in_axes=(0,))(y)

        return y






In [3]:
# Smoke test: run a random batch through a single partial-sums layer.
rngs = nnx.Rngs(params=0, activations=1, permute=5, default=345)
x_test = jax.random.normal(rngs.default(), (10, 1024))

# Swap in `partial(qrelu, bits=4)` here to try the quantized activation.
activation_function = nnx.relu

# The layer's input width is taken from the last axis of the test batch.
test_layer = PartialSumsLayer(
    x_test.shape[-1],
    512,
    rngs,
    activation_function,
)

# The layer has no batch dimension of its own, so map it over axis 0.
y_test = jax.vmap(test_layer, in_axes=0)(x_test)
print(f"Output shape: {y_test.shape}")
print(f"Output: {y_test[0, 0, :10]}")

# The layer was only needed for this check.
del test_layer

Output shape: (10, 2, 256)
Output: [0.         0.7686344  0.         0.         0.33890092 0.12463952
 0.         0.03248509 0.26785067 0.18344072]


In [4]:
# define a network with partial sums
class PartialSumsNetwork(nnx.Module):
    """
    Feedforward network made of stacked partial-sums layers.

    Consecutive entries of ``layer_sizes`` define the (in_size, out_size) of
    each ``PartialSumsLayer``. The output of the last layer is read out as a
    population code: the first ``num_classes * population_size`` units are
    split into ``num_classes`` groups and each group is averaged into one
    class score.

    Note that the activation function is applied after every layer, including
    the last one, so the class scores are averages of activated units.

    Args:
        layer_sizes: Feature sizes, input first, e.g. ``[1024, 2048, 512, 256]``.
        rngs: ``nnx.Rngs`` container shared by all layers.
        activation_function: Callable used by every layer.
        columns_per_core: Side length of one (square) core.
        num_classes: Number of output classes (population groups).
        population_size: Number of output units averaged per class.
        image_size: Side length the input images are resized to before
            flattening; ``image_size ** 2`` should match ``layer_sizes[0]``.
    """

    def __init__(self,
                 layer_sizes: list,
                 rngs: nnx.Rngs,
                 activation_function: Callable,
                 columns_per_core: int = 256,
                 num_classes: int = 10,
                 population_size: int = 25,
                 image_size: int = 32
                ):

        self.layer_sizes = layer_sizes
        self.activation_function = activation_function
        self.columns_per_core = columns_per_core
        self.rngs = rngs
        self.num_classes = num_classes
        self.population_size = population_size
        self.image_size = image_size

        # initialize the layers: one per consecutive (in, out) size pair
        self.layers = [
            PartialSumsLayer(
                in_size=i,
                out_size=o,
                rngs=rngs,
                activation_function=activation_function,
                columns_per_core=columns_per_core
            )
            for i, o in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x: jax.Array) -> jax.Array:
        """
        Forward pass through the network. Assume that x has a batch dimension!

        Args:
            x: Batch of single-channel images, 4-D with the batch on axis 0
               (the resize below targets ``(batch, image_size, image_size, 1)``).

        Returns:
            Array of shape ``(batch, num_classes)`` with one score per class.
        """

        batch_size = x.shape[0]

        # resize the image to (image_size, image_size); 32x32 for MNIST by default
        x = jax.image.resize(
            x, (batch_size, self.image_size, self.image_size, 1), method='nearest'
        )

        # flatten everything except the batch dimension
        x = jnp.reshape(x, (batch_size, -1))

        # each layer works on a single example, so map it over the batch
        for layer in self.layers:
            x = jax.vmap(layer, in_axes=(0,))(x)

        # at the final layer apply population code: keep the first
        # num_classes * population_size units and average them per class
        x = x.reshape(batch_size, -1)
        x = x[:, :self.num_classes * self.population_size]
        x = x.reshape(batch_size, self.num_classes, -1)
        x = jnp.mean(x, axis=-1)

        return x

        

In [5]:
# Smoke test: build the full network and run a random image batch through it.
layer_sizes = [1024, 2048, 512, 256]
activation_function = nnx.relu

# Reuses the `rngs` container created for the single-layer test.
test_network = PartialSumsNetwork(
    layer_sizes,
    rngs,
    activation_function,
    columns_per_core=256,
)

# Ten random single-channel 32x32 "images".
x_test = jax.random.normal(rngs.default(), (10, 32, 32, 1))
y_test = test_network(x_test)

print(f"Output shape: {y_test.shape}")
print(f"Output: {y_test[0, :]}")

# Drop the throwaway network once the check is done.
del test_network

Output shape: (10, 10)
Output: [0.02932509 0.06052509 0.05852911 0.05839609 0.04864245 0.05797828
 0.05379442 0.01932956 0.04287183 0.08704461]


## Setting up a training pipeline

In [6]:
# dataset loading
# The dataset location can be overridden with the SCRRAMBLE_DATA_DIR
# environment variable so the notebook is not tied to one machine; the default
# is the original local path.
data_dir = os.environ.get("SCRRAMBLE_DATA_DIR", "/local_disk/vikrant/datasets")
dataset_dict = {
    'batch_size': 64, # 64 is a good batch size for MNIST
    'train_steps': int(2e4), # run for longer, 20000 is good!
    'binarize': True,
    'greyscale': True,
    'data_dir': data_dir,
    'seed': 101,
    'shuffle_buffer': 1024,
    'threshold' : 0.5, # binarization threshold, not to be confused with the threshold in the model
    'eval_every': 1000, # validation interval (in steps) read by the training loop
}

# loading the dataset
# NOTE(review): 'binarize', 'greyscale' and 'threshold' are recorded in
# dataset_dict but are not forwarded to load_and_augment_mnist here -- confirm
# whether the loader applies them by default.
train_ds, valid_ds, test_ds = load_and_augment_mnist(
    batch_size=dataset_dict['batch_size'],
    train_steps=dataset_dict['train_steps'],
    data_dir=dataset_dict['data_dir'],
    seed=dataset_dict['seed'],
    shuffle_buffer=dataset_dict['shuffle_buffer'],
)


W0000 00:00:1754532170.340388 1922137 gpu_device.cc:2341] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [12]:
# Training step functions
def loss_fn(model: PartialSumsNetwork, batch):
    """
    Mean softmax cross-entropy of the model on one batch.

    Args:
        model: Network called on ``batch['image']`` to produce class scores.
        batch: Mapping with an ``'image'`` entry and integer ``'label'`` entry.

    Returns:
        Tuple ``(loss, logits)``: the batch-averaged loss and the raw model
        outputs (returned as auxiliary data for metric updates).
    """
    scores = model(batch['image'])
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=scores,
        labels=batch['label'],
    )
    mean_loss = per_example_loss.mean()
    return mean_loss, scores

@nnx.jit
def train_step(model: PartialSumsNetwork, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch, loss_fn: Callable = loss_fn):
    """
    Train for a single step.

    Computes the loss and its gradients on ``batch``, records the loss and
    predictions in ``metrics`` and applies one optimizer update. The model,
    optimizer and metrics are all updated in place; nothing is returned.

    Args:
        model: Network being trained.
        optimizer: Optimizer holding the update rule and its state.
        metrics: Metric accumulator updated with loss, logits and labels.
        batch: Mapping with ``'image'`` and ``'label'`` entries.
        loss_fn: Callable ``(model, batch) -> (loss, logits)``.
    """
    # has_aux=True: loss_fn returns (loss, logits) and only the loss is differentiated.
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place updates.
    optimizer.update(grads)  # In-place updates.

@nnx.jit
def eval_step(model: PartialSumsNetwork, metrics: nnx.MultiMetric, batch, loss_fn: Callable = loss_fn):
    """
    Evaluate one batch and fold the result into ``metrics`` (in place).

    No gradients are computed and the model is not modified; nothing is
    returned.
    """
    batch_loss, batch_logits = loss_fn(model, batch)
    metrics.update(
        loss=batch_loss,
        logits=batch_logits,
        labels=batch['label'],
    )

@nnx.jit
def pred_step(model: PartialSumsNetwork, batch):
    """
    Predict a class index for every example in ``batch``.

    Returns the position of the largest model output along axis 1.
    """
    return model(batch['image']).argmax(axis=1)


In [8]:
# Split one root PRNG key into four, one per named RNG stream.
key = jax.random.key(10)
key1, key2, key3, key4 = jax.random.split(key, num=4)
rngs = nnx.Rngs(params=key1, activations=key2, permute=key3, default=key4)

# Initialize the model
model = PartialSumsNetwork(
    [1024, 2048, 512, 256],
    rngs,
    nnx.relu,
    columns_per_core=256,
)

# Optimizer hyperparameters.
# NOTE: 'momentum' is recorded here but is not passed to the optimizer below.
hyperparameters = dict(
    learning_rate=0.8e-4,  # 1e-3 seems to work well
    momentum=0.9,
    weight_decay=1e-4,
)

# AdamW with the learning rate and weight decay configured above.
optimizer = nnx.Optimizer(
    model,
    optax.adamw(
        learning_rate=hyperparameters['learning_rate'],
        weight_decay=hyperparameters['weight_decay'],
    ),
)

# Running accuracy and average loss, shared by training and evaluation.
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)


In [9]:
# History of logged metrics: one independent, initially empty list per series.
metrics_history = {
    series: []
    for series in (
        'train_loss',
        'train_accuracy',
        'test_loss',
        'valid_loss',
        'valid_accuracy',
        'test_accuracy',
        'step',
    )
}

def train_scrramble_capsnet_mnist(
        model: PartialSumsNetwork = model,
        optimizer: nnx.Optimizer = optimizer,
        train_ds: tf.data.Dataset = train_ds,
        valid_ds: tf.data.Dataset = valid_ds,
        dataset_dict: dict = dataset_dict,
        save_model_flag: bool = False,
        save_metrics_flag: bool = False,
):
    """
    Train the model, validating periodically, then evaluate on the test set.

    Every ``dataset_dict['eval_every']`` steps (and on the last step) the
    accumulated training metrics are logged and a full pass over ``valid_ds``
    is run. After training, the model is evaluated once on the test set.

    Besides its arguments, this function relies on notebook-level globals:
    ``metrics`` (metric accumulator), ``metrics_history`` (appended to in
    place, so it keeps growing across repeated calls) and ``test_ds``.

    Args:
        model: Network to train (updated in place).
        optimizer: Optimizer bound to ``model`` (updated in place).
        train_ds: Training dataset; one element is consumed per step.
        valid_ds: Validation dataset.
        dataset_dict: Needs the keys ``'eval_every'`` and ``'train_steps'``.
        save_model_flag: Currently ignored (the saving code is disabled).
        save_metrics_flag: Currently ignored (the saving code is disabled).

    Returns:
        The same ``model`` object that was passed in, after training.
    """

    eval_every = dataset_dict['eval_every']
    train_steps = dataset_dict['train_steps']

    for step, batch in enumerate(train_ds.as_numpy_iterator()):
        # Run the optimization for one step and make a stateful update to the following:
        # - The train state's model parameters
        # - The optimizer state
        # - The training loss and accuracy batch metrics

        train_step(model, optimizer, metrics, batch)

        # Validate every `eval_every` steps and on the final training step.
        if step > 0 and (step % eval_every == 0 or step == train_steps - 1):
            metrics_history['step'].append(step)  # Record the step.
            # Log the training metrics accumulated since the last reset.
            for metric, value in metrics.compute().items():  # Compute the metrics.
                metrics_history[f'train_{metric}'].append(float(value))  # Record the metrics.
            metrics.reset()  # Reset the metrics before the validation pass.

            # Compute the metrics on the full validation set.
            for valid_batch in valid_ds.as_numpy_iterator():
                eval_step(model, metrics, valid_batch)

            # Log the validation metrics.
            for metric, value in metrics.compute().items():
                metrics_history[f'valid_{metric}'].append(float(value))
            metrics.reset()  # Reset the metrics for the next training interval.

            print(f"Step {step}: Valid loss: {metrics_history['valid_loss'][-1]}, Accuracy: {metrics_history['valid_accuracy'][-1]}")

    # Very short runs may never reach a validation point; max() on an empty
    # list would raise, so only report the best accuracy when one exists.
    if metrics_history['valid_accuracy']:
        best_accuracy = max(metrics_history['valid_accuracy'])
        print(f"Best accuracy: {best_accuracy}")
    else:
        print("Best accuracy: n/a (no validation pass was run)")

    # Start the test evaluation from clean metrics; otherwise training batches
    # seen since the last validation point would leak into the test numbers.
    metrics.reset()

    # find the test set accuracy
    for test_batch in test_ds.as_numpy_iterator():
        eval_step(model, metrics, test_batch)
    # record the test metrics
    for metric, value in metrics.compute().items():
        metrics_history[f'test_{metric}'].append(float(value))
    metrics.reset()  # Leave the metrics clean for whatever runs next.

    print("="*50)
    print(f"Test loss: {metrics_history['test_loss'][-1]}, Test accuracy: {metrics_history['test_accuracy'][-1]}")
    print("="*50)

    # NOTE: saving is disabled, so save_model_flag / save_metrics_flag have no
    # effect. The code below is kept for reference; it reads
    # `model.input_eff_capsules`, which PartialSumsNetwork does not define, so
    # the filename needs adapting before re-enabling.
    # if save_model_flag:
    #     today = date.today().isoformat()
    #     filename = f"sscamble_mnist_capsnet_recon_capsules{(sum(model.layer_sizes)-model.input_eff_capsules):d}_acc_{metrics_history['test_accuracy'][-1]*100:.0f}_{today}.pkl"
    #     graphdef, state = nnx.split(model)
    #     save_model(state, filename)

    # if save_metrics_flag:
    #     today = date.today().isoformat()
    #     filename = f"sscamble_mnist_capsnet_recon_capsules{(sum(model.layer_sizes)-model.input_eff_capsules):d}_acc_{metrics_history['test_accuracy'][-1]*100:.0f}_{today}.pkl"
    #     save_metrics(metrics_history, filename)

    return model


In [10]:
# Run the full training loop; the returned object is the trained network.
trained_model = train_scrramble_capsnet_mnist(
    model,
    optimizer,
    train_ds,
    valid_ds,
    dataset_dict,
    save_model_flag=False,  # Set to True if you want to save the model.
    save_metrics_flag=False,  # Set to True if you want to save the metrics.
)

2025-08-06 19:02:51.381742: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-08-06 19:02:51.381770: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-08-06 19:02:51.381835: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-08-06 19:02:51.381846: W external/xla/xla/service/gpu/au

Step 1000: Valid loss: 0.5676942467689514, Accuracy: 0.8343870043754578
Step 2000: Valid loss: 0.33052974939346313, Accuracy: 0.9018086194992065
Step 3000: Valid loss: 0.2367028146982193, Accuracy: 0.9331986308097839
Step 4000: Valid loss: 0.19924293458461761, Accuracy: 0.9418213963508606
Step 5000: Valid loss: 0.17907705903053284, Accuracy: 0.9481034278869629
Step 6000: Valid loss: 0.15969376266002655, Accuracy: 0.9534050822257996
Step 7000: Valid loss: 0.15318016707897186, Accuracy: 0.9547855257987976
Step 8000: Valid loss: 0.13666708767414093, Accuracy: 0.9592269659042358
Step 9000: Valid loss: 0.12477154284715652, Accuracy: 0.9638484716415405
Step 10000: Valid loss: 0.12560532987117767, Accuracy: 0.9622079133987427
Step 11000: Valid loss: 0.1087472140789032, Accuracy: 0.9675696492195129
Step 12000: Valid loss: 0.1167786493897438, Accuracy: 0.9648887515068054
Step 13000: Valid loss: 0.11031541973352432, Accuracy: 0.967409610748291
Step 14000: Valid loss: 0.0993165448307991, Accuracy

In [11]:
# Remove the notebook-level name `model`. This does not free the network:
# `trained_model` refers to the same object (the training function returns the
# model it was given), and that function's default argument also references it.
del model

In [17]:
# Extract the trained parameters/state from the trained network.
graphdef, trained_state = nnx.split(trained_model)

# Build a network with the same architecture but a quantized-ReLU activation.
# NOTE(review): the meaning of `bits` / `max_value` is defined by `qrelu` in
# utils.activation_functions -- confirm the intended settings there.
qmodel = PartialSumsNetwork(
    layer_sizes=[1024, 2048, 512, 256],
    rngs=nnx.Rngs(params=0, activations=1, permute=5, default=345),
    activation_function=partial(qrelu, bits=32, max_value=2.0),
    columns_per_core=256
)

# Keep the quantized network's structure, but load the trained state into it.
qgraphdef, _ = nnx.split(qmodel)
qmodel = nnx.merge(qgraphdef, trained_state)

qmodel.eval() # Switch to evaluation mode.

accuracies = []   # per-batch accuracies, kept for inspection
num_correct = 0   # correct predictions over the whole test set
num_examples = 0  # test examples seen

# evaluate the quantized model
for test_batch in test_ds.as_numpy_iterator():
    preds = pred_step(qmodel, test_batch)
    true_labels = test_batch['label']
    matches = preds == true_labels
    accuracies.append(jnp.mean(matches))
    num_correct += int(jnp.sum(matches))
    num_examples += int(preds.shape[0])

# Weight every example equally: averaging the per-batch means would over-weight
# a smaller final batch.
accuracy = num_correct / num_examples if num_examples else float('nan')
print(f"Quantized model test accuracy: {accuracy:.4f}")

Quantized model test accuracy: 0.6991
