In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt




In [2]:
# Prepare the data
# Load the MNIST train/test splits (images plus integer class labels).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Cast to float32 and rescale the 8-bit pixel intensities from [0, 255] to [0, 1].
# Feeding raw 0-255 values into a network whose weights are fake-quantized into
# [-1, 1] produces very large pre-activations and unstable training.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Append a trailing channel axis so each image matches the model's
# Input(shape=(28, 28, 1)); the network itself is a dense MLP, not a CNN.
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)


In [3]:
class QuantizedDense(tf.keras.layers.Layer):
    """Dense layer whose kernel and bias pass through fake quantization.

    On every forward pass the weight matrix and the bias are run through
    ``tf.quantization.fake_quant_with_min_max_args`` before being used, which
    simulates fixed-point weights while the underlying variables stay float.
    Values outside ``[quant_min, quant_max]`` are clamped by that op.

    Args:
        units: Number of output features.
        activation: Anything accepted by ``tf.keras.activations.get``
            (name, callable or ``None`` for no activation).
        quant_min: Lower bound of the simulated quantization range.
        quant_max: Upper bound of the simulated quantization range.
        num_bits: Bit width of the simulated quantization.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``, ...).

    Raises:
        ValueError: If ``quant_min`` is not strictly less than ``quant_max``.
    """

    def __init__(self, units, activation=None, quant_min=-1.0, quant_max=1.0,
                 num_bits=8, **kwargs):
        # Forward the base-layer kwargs so callers can pass name=, dtype=, etc.
        super().__init__(**kwargs)
        if not quant_min < quant_max:
            raise ValueError(
                f"quant_min must be smaller than quant_max, got "
                f"quant_min={quant_min}, quant_max={quant_max}"
            )
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.num_bits = num_bits

    def build(self, input_shape):
        """Create the kernel ``w`` and bias ``b`` once the input width is known."""
        # Explicit names keep the two variables distinguishable in weight files
        # that are keyed by variable name.
        self.w = self.add_weight(
            name="w",
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            name="b",
            shape=(self.units,),
            initializer="random_normal",
            trainable=True,
        )

    def _fake_quantize(self, tensor):
        """Apply the layer's fake-quantization settings to ``tensor``."""
        return tf.quantization.fake_quant_with_min_max_args(
            tensor, min=self.quant_min, max=self.quant_max, num_bits=self.num_bits
        )

    def call(self, inputs):
        """Compute ``activation(inputs @ quantize(w) + quantize(b))``."""
        # Simulate quantization with fake quantization
        quantized_w = self._fake_quantize(self.w)
        quantized_b = self._fake_quantize(self.b)
        output = tf.matmul(inputs, quantized_w) + quantized_b
        if self.activation is not None:
            output = self.activation(output)
        return output

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created.

        Needed for saving/cloning models that contain this layer; when loading,
        the class still has to be supplied (e.g. through ``custom_objects``).
        """
        config = super().get_config()
        config.update({
            "units": self.units,
            "activation": tf.keras.activations.serialize(self.activation),
            "quant_min": self.quant_min,
            "quant_max": self.quant_max,
            "num_bits": self.num_bits,
        })
        return config


In [4]:
# Seed Python, NumPy and TensorFlow before any weights are created so that
# weight initialisation and data shuffling repeat across "Restart & Run All".
# NOTE(review): some GPU kernels can still be non-deterministic; confirm whether
# bit-exact reproducibility is required.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Architecture: 28x28x1 image -> flatten -> 128-unit ReLU -> 10-way softmax,
# with both dense layers using fake-quantized weights.
input_layer = Input(shape=(28, 28, 1))
flatten_layer = Flatten()(input_layer)
quantized_dense_layer_1 = QuantizedDense(128, activation='relu')(flatten_layer)
output_layer = QuantizedDense(10, activation='softmax')(quantized_dense_layer_1)

quantized_model = Model(inputs=input_layer, outputs=output_layer)

# The labels are integer class ids and the model outputs probabilities, hence
# the sparse categorical cross-entropy loss.
quantized_model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)




In [5]:
# Train for 20 epochs, reporting loss/accuracy on the test split after each one.
# The call is left as the cell's final expression, so the returned History
# object is still echoed as the cell output.
quantized_model.fit(
    x=x_train,
    y=y_train,
    epochs=20,
    validation_data=(x_test, y_test),
)


Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.src.callbacks.History at 0x1314ca160>

In [None]:
# Score the trained model on the test split. evaluate() returns the loss
# followed by the metric requested at compile time (accuracy).
loss, accuracy = quantized_model.evaluate(x=x_test, y=y_test)
print(f'Quantized Model Test Accuracy: {accuracy}')
