### relu/silu/quantization post-training
* post-training quantization  
* next steps: implement further steps towards hardware/jax block compatibility **quantization-aware training** 

In [4]:
import jax
import jax.numpy as jnp
import optax
import flax
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from flax import linen as nn
from EICDense import *
from ShuffleBlocks import *
from Accumulator import *
#from PseudoFFNet import *
#from EICNet import *
#from HelperFunctions.activations import *
from mnist_dataloader import *
#from HelperFunctions.metric_functions import *

In [4]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

data_path = '/Users/rairo/trident data'

# Load MNIST data
print("Preparing the MNIST dataset...")
(train_images, train_labels), (val_images, val_labels), (test_images, test_labels) = load_and_process_mnist(data_path, binarize=True)

# Define the ANN model
class silu_net(nn.Module):
    activation_fn: callable

    @nn.compact
    def __call__(self, x):
        x = x.reshape(x.shape[0], -1)  # Flatten input
        x = self.noisy_dense(x, 2048)
        x = self.noisy_dense(x, 256)
        x = nn.Dense(10)(x)
        return x

    def noisy_dense(self, x, features):
        dense = nn.Dense(features)(x)
        # Add beta noise
        noise = jax.random.beta(jax.random.PRNGKey(0), 2, 5, shape=dense.shape)
        noisy_output = dense + noise
        return self.activation_fn(noisy_output)

# Initialize model parameters and optimizer
input_size = 28 * 28
activation_fn = jax.nn.silu

model = silu_net(activation_fn=activation_fn)

# Training state
class TrainState(train_state.TrainState):
    pass

def create_train_state(rng, model, learning_rate):
    params = model.init(rng, jnp.ones([1, input_size]))['params']
    tx = optax.adam(learning_rate)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

rng = jax.random.PRNGKey(0)
state = create_train_state(rng, model, learning_rate=0.001)

# Training step
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch["images"])
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"]).mean()
        return loss

    grads = jax.grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state

# Evaluation
def evaluate(state, images, labels):
    logits = state.apply_fn({'params': state.params}, images)
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(predictions == labels)
    return accuracy * 100

# Training loop
num_epochs = 10
batch_size = 64
for epoch in range(num_epochs):
    for images, labels in get_train_batches(train_images, train_labels, batch_size):
        batch = {
            "images": images,
            "labels": labels
        }
        state = train_step(state, batch)

    val_accuracy = evaluate(state, val_images, val_labels)
    print(f"Epoch {epoch + 1} completed. Validation Accuracy: {val_accuracy:.2f}%")

# Test evaluation
test_accuracy = evaluate(state, test_images, test_labels)
print(f"Test Accuracy: {test_accuracy:.2f}%")

# Post-training quantization
def quantize_params(params, num_bits=4):
    scale = 2 ** (num_bits - 1) - 1
    quantized_params = jax.tree_map(
        lambda p: jnp.round(p * scale) / scale, params
    )
    return quantized_params

# Quantize and evaluate
quantized_params = quantize_params(state.params)
test_accuracy_quantized = evaluate(state.replace(params=quantized_params), test_images, test_labels)
print(f"Test Accuracy after Quantization: {test_accuracy_quantized:.2f}%")


Preparing the MNIST dataset...
Epoch 1 completed. Validation Accuracy: 96.68%
Epoch 2 completed. Validation Accuracy: 97.72%
Epoch 3 completed. Validation Accuracy: 97.68%
Epoch 4 completed. Validation Accuracy: 98.10%
Epoch 5 completed. Validation Accuracy: 97.74%
Epoch 6 completed. Validation Accuracy: 97.82%
Epoch 7 completed. Validation Accuracy: 98.14%
Epoch 8 completed. Validation Accuracy: 97.82%
Epoch 9 completed. Validation Accuracy: 97.90%
Epoch 10 completed. Validation Accuracy: 98.26%
Test Accuracy: 97.83%


  quantized_params = jax.tree_map(


Test Accuracy after Quantization: 97.12%


In [25]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

data_path = '/Users/rairo/trident data'

# Load MNIST data

(train_images, train_labels), (val_images, val_labels), (test_images, test_labels) = load_and_process_mnist(data_path, binarize=True)

# Define the ANN model
class relu_net(nn.Module):
    activation_fn: callable

    @nn.compact
    def __call__(self, x):
        x = x.reshape(x.shape[0], -1)  # Flatten input (now 256 dimensions for 16x16 images)
        x = self.noisy_dense(x, 512)  # Adjusted dimensions to match smaller input size
        x = self.noisy_dense(x, 256)
        x = nn.Dense(10)(x)  # Output layer for 10 classes
        return x

    def noisy_dense(self, x, features):
        dense = nn.Dense(features)(x)
        # Add beta noise
        noise = jax.random.beta(jax.random.PRNGKey(0), 2, 5, shape=dense.shape)
        noisy_output = dense + noise
        return self.activation_fn(noisy_output)

# Initialize model parameters and optimizer
input_size = 16 * 16  # Updated input size for 16x16 images
activation_fn = jax.nn.relu

model = relu_net(activation_fn=activation_fn)

# Training state
class TrainState(train_state.TrainState):
    pass

def create_train_state(rng, model, learning_rate):
    params = model.init(rng, jnp.ones([1, input_size]))['params']
    tx = optax.adam(learning_rate)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

rng = jax.random.PRNGKey(0)
state = create_train_state(rng, model, learning_rate=0.001)

# Training step
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch["images"])
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"]).mean()
        return loss

    grads = jax.grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state

# Evaluation
def evaluate(state, images, labels):
    logits = state.apply_fn({'params': state.params}, images)
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(predictions == labels)
    return accuracy * 100

# Training loop
num_epochs = 10
batch_size = 64
for epoch in range(num_epochs):
    for images, labels in get_train_batches(train_images, train_labels, batch_size):
        batch = {
            "images": images,
            "labels": labels
        }
        state = train_step(state, batch)

    val_accuracy = evaluate(state, val_images, val_labels)
    print(f"Epoch {epoch + 1} completed. Validation Accuracy: {val_accuracy:.2f}%")

# Test evaluation
test_accuracy = evaluate(state, test_images, test_labels)
print(f"Test Accuracy: {test_accuracy:.2f}%")

# Post-training quantization
def quantize_params(params, num_bits=4):
    scale = 2 ** (num_bits - 1) - 1
    quantized_params = jax.tree_map(
        lambda p: jnp.round(p * scale) / scale, params
    )
    return quantized_params

# Quantize and evaluate
quantized_params = quantize_params(state.params)
test_accuracy_quantized = evaluate(state.replace(params=quantized_params), test_images, test_labels)
print(f"Test Accuracy after Quantization: {test_accuracy_quantized:.2f}%")


ScopeParamShapeError: Initializer expected to generate shape (256, 512) but got shape (784, 512) instead for parameter "kernel" in "/Dense_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

In [5]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

data_path = '/Users/rairo/trident data'

# Load MNIST data
print("Preparing the MNIST dataset...")
(train_images, train_labels), (val_images, val_labels), (test_images, test_labels) = load_and_process_mnist(data_path, binarize=True)

# Define the ANN model
class silu_net(nn.Module):
    activation_fn: callable

    @nn.compact
    def __call__(self, x):
        x = x.reshape(x.shape[0], -1)  # Flatten input
        x = self.noisy_dense(x, 2048)
        x = self.noisy_dense(x, 256)
        x = nn.Dense(10)(x)
        return x

    def noisy_dense(self, x, features):
        dense = nn.Dense(features)(x)
        # Add beta noise
        noise = jax.random.beta(jax.random.PRNGKey(0), 2, 5, shape=dense.shape)
        noisy_output = dense + noise
        return self.activation_fn(noisy_output)

# Initialize model parameters and optimizer
input_size = 28 * 28
activation_fn = jax.nn.silu

model = silu_net(activation_fn=activation_fn)

# Training state
class TrainState(train_state.TrainState):
    pass

def create_train_state(rng, model, learning_rate):
    params = model.init(rng, jnp.ones([1, input_size]))['params']
    tx = optax.adam(learning_rate)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

rng = jax.random.PRNGKey(0)
state = create_train_state(rng, model, learning_rate=0.001)

# Training step
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch["images"])
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"]).mean()
        return loss

    grads = jax.grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state

# Evaluation
def evaluate(state, images, labels):
    logits = state.apply_fn({'params': state.params}, images)
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(predictions == labels)
    return accuracy * 100

# Training loop
num_epochs = 10
batch_size = 64
for epoch in range(num_epochs):
    for images, labels in get_train_batches(train_images, train_labels, batch_size):
        batch = {
            "images": images,
            "labels": labels
        }
        state = train_step(state, batch)

    val_accuracy = evaluate(state, val_images, val_labels)
    print(f"Epoch {epoch + 1} completed. Validation Accuracy: {val_accuracy:.2f}%")

# Test evaluation
test_accuracy = evaluate(state, test_images, test_labels)
print(f"Test Accuracy: {test_accuracy:.2f}%")

# Post-training quantization
def quantize_params(params, num_bits=8):
    scale = 2 ** (num_bits - 1) - 1
    quantized_params = jax.tree_map(
        lambda p: jnp.round(p * scale) / scale, params
    )
    return quantized_params

# Quantize and evaluate
quantized_params = quantize_params(state.params)
test_accuracy_quantized = evaluate(state.replace(params=quantized_params), test_images, test_labels)
print(f"Test Accuracy after Quantization: {test_accuracy_quantized:.2f}%")


Preparing the MNIST dataset...
Epoch 1 completed. Validation Accuracy: 96.46%
Epoch 2 completed. Validation Accuracy: 97.56%
Epoch 3 completed. Validation Accuracy: 98.18%
Epoch 4 completed. Validation Accuracy: 97.68%
Epoch 5 completed. Validation Accuracy: 98.24%
Epoch 6 completed. Validation Accuracy: 98.12%
Epoch 7 completed. Validation Accuracy: 97.96%
Epoch 8 completed. Validation Accuracy: 97.92%
Epoch 9 completed. Validation Accuracy: 97.68%
Epoch 10 completed. Validation Accuracy: 97.94%
Test Accuracy: 97.81%
Test Accuracy after Quantization: 97.84%


# Jax blocks compatible version of silu-net? 
* reshaping & dense layers to approximate the tensor block level operations of jnp.einsum 
* silu 
* dense layer operation to aggregate outputs 
* permutations across features (static)


In [None]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

data_path = '/Users/rairo/trident data'

# Load MNIST data
print("Preparing the MNIST dataset...")
(train_images, train_labels), (val_images, val_labels), (test_images, test_labels) = load_and_process_mnist(data_path, binarize=True)

# Define the ANN model
class silu_net(nn.Module):
    activation_fn: callable

    @nn.compact
    def __call__(self, x):
        x = x.reshape(x.shape[0], -1)  # Flatten input
        x = self.noisy_dense(x, 2048)
        x = self.noisy_dense(x, 256)
        x = nn.Dense(10)(x)
        return x

    def noisy_dense(self, x, features):
        dense = nn.Dense(features)(x)
        # Add beta noise
        noise = jax.random.beta(jax.random.PRNGKey(0), 2, 5, shape=dense.shape)
        noisy_output = dense + noise
        return self.activation_fn(noisy_output)

# Initialize model parameters and optimizer
input_size = 28 * 28
activation_fn = jax.nn.silu

model = silu_net(activation_fn=activation_fn)

# Training state
class TrainState(train_state.TrainState):
    pass

def create_train_state(rng, model, learning_rate):
    params = model.init(rng, jnp.ones([1, input_size]))['params']
    tx = optax.adam(learning_rate)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

rng = jax.random.PRNGKey(0)
state = create_train_state(rng, model, learning_rate=0.001)

# Training step
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch["images"])
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"]).mean()
        return loss

    grads = jax.grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state

# Evaluation
def evaluate(state, images, labels):
    logits = state.apply_fn({'params': state.params}, images)
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(predictions == labels)
    return accuracy * 100

# Training loop
num_epochs = 10
batch_size = 64
for epoch in range(num_epochs):
    for images, labels in get_train_batches(train_images, train_labels, batch_size):
        batch = {
            "images": images,
            "labels": labels
        }
        state = train_step(state, batch)

    val_accuracy = evaluate(state, val_images, val_labels)
    print(f"Epoch {epoch + 1} completed. Validation Accuracy: {val_accuracy:.2f}%")

# Test evaluation
test_accuracy = evaluate(state, test_images, test_labels)
print(f"Test Accuracy: {test_accuracy:.2f}%")

# Post-training quantization
def quantize_params(params, num_bits=8):
    scale = 2 ** (num_bits - 1) - 1
    quantized_params = jax.tree_map(
        lambda p: jnp.round(p * scale) / scale, params
    )
    return quantized_params

# Quantize and evaluate
quantized_params = quantize_params(state.params)
test_accuracy_quantized = evaluate(state.replace(params=quantized_params), test_images, test_labels)
print(f"Test Accuracy after Quantization: {test_accuracy_quantized:.2f}%")


In [23]:
def reshape_to_blocks(x, block_size=256):
    # Calculate the required padding
    total_size = x.shape[1]
    pad_size = block_size - (total_size % block_size) if total_size % block_size != 0 else 0
    
    # Apply padding
    x_padded = jnp.pad(x, ((0, 0), (0, pad_size)))  # Pad along the feature dimension

    # Reshape to blocks
    num_blocks = x_padded.shape[1] // block_size
    return x_padded.reshape(x.shape[0], num_blocks, block_size)


def constrained_dense(x, features):
    dense = nn.Dense(features)(x)
    constrained_weights = jax.nn.relu(dense)
    return constrained_weights

class Shuffle(nn.Module):
    def __call__(self, x):
        # Random permutation of features
        perm = jax.random.permutation(jax.random.PRNGKey(0), x.shape[1])
        return x[:, perm]

def accumulate(x, features):
    # Simple accumulation block
    dense = nn.Dense(features)(x)
    return nn.relu(dense)  # Simulates accumulation


In [26]:
class silu_net_blocks(nn.Module):
    activation_fn: callable

    @nn.compact
    def __call__(self, x):
        x = x.reshape(x.shape[0], -1)  # Flatten input

        # Reshape to blocks with padding
        x = reshape_to_blocks(x)

        # First noisy dense layer
        x = self.noisy_dense(x, 2048)

        # Shuffle layer
        x = Shuffle()(x)

        # Accumulate outputs
        x = accumulate(x, 256)

        # Second dense layer
        x = self.noisy_dense(x, 256)

        # Final dense layer and reshape logits
        x = nn.Dense(10)(x)
        x = x.reshape(x.shape[0], -1)  # Remove extra dimensions

        return x

    def noisy_dense(self, x, features):
        dense = constrained_dense(x, features)
        noise = jax.random.beta(jax.random.PRNGKey(0), 2, 5, shape=dense.shape)
        return self.activation_fn(dense + noise)


In [42]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from HelperFunctions.mnist_dataloader import * 
data_path = '/Users/rairo/trident data'

# Load MNIST data
print("Preparing the MNIST dataset...")
(train_images, train_labels), (val_images, val_labels), (test_images, test_labels) = load_and_process_mnist(data_path, binarize=True)


Preparing the MNIST dataset...


In [27]:
# Initialize model parameters and optimizer
input_size = 28 * 28
activation_fn = jax.nn.silu

model = silu_net_blocks(activation_fn=activation_fn)

# Training state
class TrainState(train_state.TrainState):
    pass

def create_train_state(rng, model, learning_rate):
    params = model.init(rng, jnp.ones([1, input_size]))['params']
    tx = optax.adam(learning_rate)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

rng = jax.random.PRNGKey(0)
state = create_train_state(rng, model, learning_rate=0.001)

# Training step
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch["images"])
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"]).mean()
        return loss

    grads = jax.grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state

# Evaluation
def evaluate(state, images, labels):
    logits = state.apply_fn({'params': state.params}, images)
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(predictions == labels)
    return accuracy * 100

# Training loop
num_epochs = 10
batch_size = 64
for epoch in range(num_epochs):
    for images, labels in get_train_batches(train_images, train_labels, batch_size):
        batch = {
            "images": images,
            "labels": labels
        }
        state = train_step(state, batch)

    val_accuracy = evaluate(state, val_images, val_labels)
    print(f"Epoch {epoch + 1} completed. Validation Accuracy: {val_accuracy:.2f}%")

# Test evaluation
test_accuracy = evaluate(state, test_images, test_labels)
print(f"Test Accuracy: {test_accuracy:.2f}%")

# Post-training quantization
def quantize_params(params, num_bits=8):
    scale = 2 ** (num_bits - 1) - 1
    quantized_params = jax.tree_map(
        lambda p: jnp.round(p * scale) / scale, params
    )
    return quantized_params

# Quantize and evaluate
quantized_params = quantize_params(state.params)
test_accuracy_quantized = evaluate(state.replace(params=quantized_params), test_images, test_labels)
print(f"Test Accuracy after Quantization: {test_accuracy_quantized:.2f}%")


Epoch 1 completed. Validation Accuracy: 93.10%
Epoch 2 completed. Validation Accuracy: 95.06%
Epoch 3 completed. Validation Accuracy: 95.24%
Epoch 4 completed. Validation Accuracy: 95.36%
Epoch 5 completed. Validation Accuracy: 95.54%
Epoch 6 completed. Validation Accuracy: 96.27%
Epoch 7 completed. Validation Accuracy: 95.72%
Epoch 8 completed. Validation Accuracy: 96.45%
Epoch 9 completed. Validation Accuracy: 96.04%
Epoch 10 completed. Validation Accuracy: 96.33%
Test Accuracy: 96.47%
Test Accuracy after Quantization: 96.40%


In [None]:
from EICDense import *
from ShuffleBlocks import *
from Accumulator import *

In [20]:
class BlockSiluNet(nn.Module):
    activation_fn: callable

    def setup(self):
        # Layer 1: EICDense 
        self.fc1 = EICDense(
            in_size=784, 
            out_size=2048, 
            threshold=0.0, 
            noise_sd=0.05,  
            activation=self.activation_fn
        )
        self.ac1 = Accumulator(
            in_block_size=2048 // 256, 
            threshold=0.0, 
            noise_sd=0.05, 
            activation=self.activation_fn
        )
        # Shuffle block 
        self.shuffle1 = ShuffleBlocks(subvector_len=256, slot_len=64, key=jax.random.PRNGKey(0))
        
        # Layer 2: EICDense 
        self.fc2 = EICDense(
            in_size=2048, 
            out_size=512,  
            threshold=0.0, 
            noise_sd=0.05, 
            activation=self.activation_fn
        )
        self.ac2 = Accumulator(
            in_block_size=512 // 256, 
            threshold=0.0, 
            noise_sd=0.05, 
            activation=self.activation_fn
        )
        self.shuffle2 = ShuffleBlocks(subvector_len=256, slot_len=64, key=jax.random.PRNGKey(1))
        
        # Layer 3: EICDense with reduced output size
        self.fc3 = EICDense(
            in_size=512, 
            out_size=10, 
            threshold=0.0, 
            noise_sd=0.05, 
            activation=self.activation_fn
        )
        self.ac3 = Accumulator(
            in_block_size=10 // 256, 
            threshold=0.0, 
            noise_sd=0.05, 
            activation=self.activation_fn
        )

    def __call__(self, x):
        # Flatten input
        x = x.reshape(x.shape[0], -1)

        # First layer: EICDense -> Accumulator -> Shuffle
        x = self.fc1(x)
        x = self.ac1(x)
        x = self.shuffle1(x)

        # Second layer: EICDense -> Accumulator -> Shuffle
        x = self.fc2(x)
        x = self.ac2(x)
        x = self.shuffle2(x)

        # Third layer: EICDense -> Accumulator (final output)
        x = self.fc3(x)
        x = self.ac3(x)
        return x



In [24]:
# Initialize model parameters and optimizer
input_size = 28 * 28
activation_fn = jax.nn.silu
data_path = '/Users/rairo/trident data'

# Load MNIST data
print("Preparing the MNIST dataset...")
(train_images, train_labels), (val_images, val_labels), (test_images, test_labels) = load_and_process_mnist(data_path, binarize=True)

model = BlockSiluNet(activation_fn=activation_fn)

def prune_weights(params, threshold=0.01):
    return jax.tree_map(lambda p: jnp.where(jnp.abs(p) > threshold, p, 0), params)

# Training state
class TrainState(train_state.TrainState):
    pass

def create_train_state(rng, model, learning_rate):
    params = model.init(rng, jnp.ones([input_size]))['params']
    tx = optax.adam(learning_rate)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

rng = jax.random.PRNGKey(0)
state = create_train_state(rng, model, learning_rate=0.001)

# Post-training pruning
pruned_params = prune_weights(state.params, threshold=0.01)
state = state.replace(params=pruned_params)

# Training step
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch["images"])
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"]).mean()
        return loss

    grads = jax.grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state

# Evaluation
def evaluate(state, images, labels):
    logits = state.apply_fn({'params': state.params}, images)
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(predictions == labels)
    return accuracy * 100

# Training loop
num_epochs = 10
batch_size = 64
for epoch in range(num_epochs):
    for images, labels in get_train_batches(train_images, train_labels, batch_size):
        batch = {
            "images": images,
            "labels": labels
        }
        state = train_step(state, batch)

    val_accuracy = evaluate(state, val_images, val_labels)
    print(f"Epoch {epoch + 1} completed. Validation Accuracy: {val_accuracy:.2f}%")

# Test evaluation
test_accuracy = evaluate(state, test_images, test_labels)
print(f"Test Accuracy: {test_accuracy:.2f}%")

# Post-training quantization
def quantize_params(params, num_bits=8):
    scale = 2 ** (num_bits - 1) - 1
    quantized_params = jax.tree_map(
        lambda p: jnp.round(p * scale) / scale, params
    )
    return quantized_params

# Quantize and evaluate
quantized_params = quantize_params(state.params)
test_accuracy_quantized = evaluate(state.replace(params=quantized_params), test_images, test_labels)
print(f"Test Accuracy after Quantization: {test_accuracy_quantized:.2f}%")

Preparing the MNIST dataset...


SetAttributeFrozenModuleError: Can't set in_blocks=1 for Module of type EICDense: Module instance is frozen outside of setup method. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.SetAttributeFrozenModuleError)