In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, BatchNormalization

# Function to binarize weights with a scaling factor (as per the paper's method)
def binarize_weights(weights):
    alpha = np.mean(np.abs(weights))  # Computing layer-wise scaling factor
    binary_weights = np.sign(weights)  # Applying sign function to binarize
    return alpha * binary_weights  # Scaling binarized weights

In [2]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train.reshape(-1, 28*28) / 255.0, x_test.reshape(-1, 28*28) / 255.0
y_train, y_test = tf.keras.utils.to_categorical(y_train, 10), tf.keras.utils.to_categorical(y_test, 10)

# A standard Neural Network model (No binarization in training, only during evaluation)
model = Sequential([
    tf.keras.Input(shape=(28*28,)),
    Dense(256, activation='tanh'),
    BatchNormalization(),
    Dense(128, activation='tanh'),
    BatchNormalization(),
    Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=3, batch_size=64, verbose=1)

# Evaluating the Original Model Before Any Changes
loss, original_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"\n Original Model Accuracy (Floating Point): {original_acc * 100:.2f}%")

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Epoch 1/3
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 11ms/step - accuracy: 0.8926 - loss: 0.3554
Epoch 2/3
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 9ms/step - accuracy: 0.9654 - loss: 0.1145
Epoch 3/3
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 9ms/step - accuracy: 0.9789 - loss: 0.0697

 Original Model Accuracy (Floating Point): 97.45%


In [3]:
# Extracting original floating-point weights after evaluating original model
dense_layers = [layer for layer in model.layers if isinstance(layer, Dense)]
original_weights = [layer.get_weights()[0] for layer in dense_layers]

# Extracting Batch Normalization parameters
bn_layers = [layer for layer in model.layers if isinstance(layer, BatchNormalization)]
bn_params = [layer.get_weights() for layer in bn_layers]  # Extract full BN state

# Binarizing the Weights
binarized_weights = [binarize_weights(w) for w in original_weights]

In [4]:
# Function to set model weights (including BN layers)
def set_model_weights(model, weight_list, bn_param_list=None):
    dense_idx = 0
    bn_idx = 0

    for layer in model.layers:
        if isinstance(layer, Dense):
            layer.set_weights([weight_list[dense_idx], layer.get_weights()[1]])  # Updating weights, keeping bias same
            dense_idx += 1
        elif isinstance(layer, BatchNormalization) and bn_param_list is not None and bn_idx < len(bn_param_list):
            layer.set_weights(bn_param_list[bn_idx])  # Restoring full BN state (beta, gamma, moving mean, moving variance)
            bn_idx += 1

# Recalibrating BN Layers after Reset
def recalibrate_bn(model, x_train):
    _ = model.predict(x_train[:1000])  # Forcing BN layers to recalculate running mean/variance

In [5]:
# Evaluating the Binarized Model (No Inversion, No Swapping)
set_model_weights(model, binarized_weights, bn_params)
recalibrate_bn(model, x_train)
loss, binarized_acc = model.evaluate(x_test, y_test, verbose=0)

# Results
print(f"\n Binarized Model Accuracy (No Swapping/Inversion): {binarized_acc * 100:.2f}% ")


[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step

 Binarized Model Accuracy (No Swapping/Inversion): 84.63% 
