In [1]:
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.datasets import mnist



In [2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize to [0, 1]
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Reshape
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

In [3]:
model = Sequential([
    Flatten(input_shape=(28, 28, 1)),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

loss, accuracy = model.evaluate(x_test, y_test)
print(f'Baseline Test Accuracy: {accuracy}')

model.save("./saved_unquantized_model")

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Baseline Test Accuracy: 0.9767000079154968
INFO:tensorflow:Assets written to: ./saved_unquantized_model/assets


INFO:tensorflow:Assets written to: ./saved_unquantized_model/assets


In [4]:
def get_min_max(weights):
    min_val = np.min(weights)
    max_val = np.max(weights)
    return min_val, max_val

weight_min_max = {}

for i, layer in enumerate(model.layers):
    if len(layer.get_weights()) > 0: #the flatten layer doesn't have weigts
        weights = layer.get_weights()[0]
        print("weigths: ", len(weights)) #28x28 input flattened into a vector
        bias = layer.get_weights()[1] # number of neurons - output neurons in second layer
        print("biases: ", len(bias))
        weight_min_max[layer.name] = {
            'weights_min_max': get_min_max(weights),
            'bias_min_max': get_min_max(bias)
        }

print(weight_min_max)

weigths:  784
biases:  128
weigths:  128
biases:  10
{'dense': {'weights_min_max': (-0.92499673, 0.5262218), 'bias_min_max': (-0.16946252, 0.2295666)}, 'dense_1': {'weights_min_max': (-0.9400578, 0.7109961), 'bias_min_max': (-0.095205545, 0.1373023)}}


In [5]:
def quantize(weights, min_val, max_val, num_bits=8): #[0,255]
    scale = (max_val - min_val) / (2 ** num_bits - 1) #255 max
    zero_point = np.round(-min_val / scale) #0 = round(min_val / scale + zero_point)
    quantized_weights = np.round(weights / scale + zero_point)
    return quantized_weights, scale, zero_point

def symmetrically_quantize(weights, min_val, max_val, num_bits=8): #[-128, 127]
    scale = (max_val - min_val) / (2 ** num_bits - 1)    
    zero_point = 0
    quantized_weights = np.round(weights / scale)    
    quantized_weights = np.clip(quantized_weights, -(2 ** (num_bits - 1)), (2 ** (num_bits - 1)) - 1) #due to scaling and rounding, some quantized value falls outside this range    
    return quantized_weights, scale, zero_point

In [6]:
# Quantize the weights
quantized_weights = {}
scales = {}
zero_points = {}

for layer in model.layers:
    if len(layer.get_weights()) > 0: #the flatten layer doesn't have weigts
        weights = layer.get_weights()[0]
        bias = layer.get_weights()[1]
        
        w_min, w_max = weight_min_max[layer.name]['weights_min_max']
        b_min, b_max = weight_min_max[layer.name]['bias_min_max']
        
        #q_weights, w_scale, w_zero_point = quantize(weights, w_min, w_max)
        #q_bias, b_scale, b_zero_point = quantize(bias, b_min, b_max)
        
        q_weights, w_scale, w_zero_point = symmetrically_quantize(weights, w_min, w_max)
        q_bias, b_scale, b_zero_point = symmetrically_quantize(bias, b_min, b_max)

        quantized_weights[layer.name] = {
            'quantized_weights': q_weights,
            'quantized_bias': q_bias
        }
        scales[layer.name] = {
            'weights_scale': w_scale,
            'bias_scale': b_scale
        }
        zero_points[layer.name] = {
            'weights_zero_point': w_zero_point,
            'bias_zero_point': b_zero_point
        }

print("Quantized weights keys:", quantized_weights.keys())
print("Scales keys:", scales.keys())

Quantized weights keys: dict_keys(['dense', 'dense_1'])
Scales keys: dict_keys(['dense', 'dense_1'])


In [7]:
def dequantize(q_weights, scale, zero_point):
    return (q_weights - zero_point) * scale


In [10]:
class QuantizedDense(tf.keras.layers.Layer):
    def __init__(self, units, quantized_weights, scales, zero_points, activation=None):
        super(QuantizedDense, self).__init__()
        self.units = units
        self.quantized_weights = quantized_weights
        self.scales = scales
        self.zero_points = zero_points
        self.activation = activation
    
    def build(self, input_shape):
        self.quantized_bias = self.add_weight(name='quantized_bias', shape=(self.units,), initializer='zeros', trainable=False)

    def call(self, inputs):
        q_weights = self.quantized_weights
        w_scale = self.scales['weights_scale']
        w_zero_point = self.zero_points['weights_zero_point']
        
        w = dequantize(q_weights, w_scale, w_zero_point)
        b = dequantize(self.quantized_bias, self.scales['bias_scale'], self.zero_points['bias_zero_point'])
        
        output = tf.matmul(inputs, w) + b
        
        if self.activation is not None:
            output = self.activation(output)
        
        return output

In [11]:
# Create a new model with quantized layers
quantized_model = tf.keras.Sequential([
    Flatten(input_shape=(28, 28, 1)),
    QuantizedDense(128, quantized_weights['dense']['quantized_weights'], scales['dense'], zero_points['dense'], activation=tf.nn.relu),
    QuantizedDense(10, quantized_weights['dense_1']['quantized_weights'], scales['dense_1'], zero_points['dense_1'], activation=tf.nn.softmax)
])

quantized_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

quantized_model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

loss, accuracy = quantized_model.evaluate(x_test, y_test)
print(f'Quantized Model Test Accuracy: {accuracy}')

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Quantized Model Test Accuracy: 0.9772999882698059
