In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras.models import load_model
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.utils import to_categorical

# 1. Load Fashion MNIST dataset
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

# Normalize and reshape
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0
X_train = np.expand_dims(X_train, -1)  # shape: (60000, 28, 28, 1)
X_test = np.expand_dims(X_test, -1)

# One-hot encode labels (because pruning loss expects categorical)
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

# 2. Load baseline model (make sure it's compiled if needed)
baseline_model = load_model('baseline_model1.h5')
print("✅ Baseline model loaded")

✅ Baseline model loaded


In [2]:
# Step 4: Apply Quantization-Aware Training
quantized_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantized_model(baseline_model)

print(q_aware_model.summary())

# Step 5: Compile and train the quantization-aware model
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

q_aware_model.fit(X_train, y_train, batch_size=500, epochs=2, validation_split=0.1)


Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer (QuantizeLa  (None, 28, 28, 1)         3         
 yer)                                                            
                                                                 
 quant_conv2d (QuantizeWrap  (None, 26, 26, 32)        387       
 perV2)                                                          
                                                                 
 quant_max_pooling2d (Quant  (None, 13, 13, 32)        1         
 izeWrapperV2)                                                   
                                                                 
 quant_conv2d_1 (QuantizeWr  (None, 11, 11, 64)        18627     
 apperV2)                                                        
                                                                 
 quant_max_pooling2d_1 (Qua  (None, 5, 5, 64)          1

  output, from_logits = _get_logits(


Epoch 2/2


<keras.src.callbacks.History at 0x1504e53f0>

In [6]:
# Step 6: Convert the QAT model to a quantized TFLite model
import pathlib


converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model_qat = converter.convert()

# Step 7: Save the quantized model to your desired directory
tflite_models_dir = pathlib.Path('/Users/oscarpatrikminj/Documents/IITR/FMNIST/tflite_models/')
tflite_models_dir.mkdir(exist_ok=True, parents=True)
tflite_model_file = tflite_models_dir / 'model_qat.tflite'
tflite_model_file.write_bytes(tflite_model_qat)

print(f"✅ QAT TFLite model saved at: {tflite_model_file}")

INFO:tensorflow:Assets written to: /var/folders/bs/x0lj933d1hv0py0d4w2ypdp40000gn/T/tmpa8hi8xrd/assets


INFO:tensorflow:Assets written to: /var/folders/bs/x0lj933d1hv0py0d4w2ypdp40000gn/T/tmpa8hi8xrd/assets


✅ QAT TFLite model saved at: /Users/oscarpatrikminj/Documents/IITR/FMNIST/tflite_models/model_qat.tflite


2025-06-05 21:18:42.464203: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2025-06-05 21:18:42.464218: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2025-06-05 21:18:42.464434: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /var/folders/bs/x0lj933d1hv0py0d4w2ypdp40000gn/T/tmpa8hi8xrd
2025-06-05 21:18:42.467155: I tensorflow/cc/saved_model/reader.cc:91] Reading meta graph with tags { serve }
2025-06-05 21:18:42.467166: I tensorflow/cc/saved_model/reader.cc:132] Reading SavedModel debug info (if present) from: /var/folders/bs/x0lj933d1hv0py0d4w2ypdp40000gn/T/tmpa8hi8xrd
2025-06-05 21:18:42.476615: I tensorflow/cc/saved_model/loader.cc:231] Restoring SavedModel bundle.
2025-06-05 21:18:42.562753: I tensorflow/cc/saved_model/loader.cc:215] Running initialization op on SavedModel bundle at path: /var/folders/bs/x0lj933d1hv0py0d4w2ypdp40000gn/T/tmpa8hi8xrd
2025-06-

In [8]:
# Step 8: Evaluate the TFLite QAT model on test data
interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]['index']
output_index = interpreter.get_output_details()[0]['index']

pred_list = []
for image in X_test:
    input_data = np.array(image, dtype=np.float32)
    input_data = input_data.reshape(1, input_data.shape[0], input_data.shape[1], 1)
    interpreter.set_tensor(input_index, input_data)
    interpreter.invoke()
    prediction = interpreter.get_tensor(output_index)
    pred_list.append(np.argmax(prediction))

# Step 9: Compute Accuracy
accurate_count = sum([1 for i in range(len(pred_list)) if pred_list[i] == np.argmax(y_test[i])])
accuracy = accurate_count / len(pred_list)
print(f" Accuracy of QAT TFLite model: {accuracy:.4f}")

# Step 10: Calculate Model Size in KB
size_kb = os.path.getsize(tflite_model_file) / 1024
print(f" Model size (uncompressed): {size_kb:.2f} KB")

 Accuracy of QAT TFLite model: 0.9202
 Model size (uncompressed): 183.39 KB
