In [1]:
import tensorflow as tf
import tensorflow_model_optimization as tfmopt

In [2]:
# ✅ Define model using tf.keras.layers
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(28, 28)),       # modern Input, no InputLayer
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(100, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax")  # softmax for 10-class classification
])

# ✅ Compile
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",   # use categorical_crossentropy if one-hot labels
    metrics=["accuracy"]
)

# Example training data (MNIST)
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0

# Train
model.fit(X_train, y_train, epochs=5, batch_size=32, validation_split=0.1)

# ✅ Save in new format
model.save("./models/digit_classifier.keras")

# ✅ Load back safely
loaded_model = tf.keras.models.load_model("./models/digit_classifier.keras")

print("Reload successful 🎉")
print(loaded_model.summary())


Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Reload successful 🎉
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 100)               78500     
                                                                 
 dense_1 (Dense)             (None, 10)                1010      
                                                                 
Total params: 79510 (310.59 KB)
Trainable params: 79510 (310.59 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
None


### Method 1: Post Training Quantization

In [3]:
# With TF Lite conversion only and without quantization
converter = tf.lite.TFLiteConverter.from_saved_model("./models/saved_model")
tflite_model = converter.convert()

W0000 00:00:1757489004.244148 1965882 tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
W0000 00:00:1757489004.244162 1965882 tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2025-09-10 12:53:24.244336: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: ./models/saved_model
2025-09-10 12:53:24.244493: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-09-10 12:53:24.244496: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: ./models/saved_model
I0000 00:00:1757489004.245531 1965882 mlir_graph_optimization_pass.cc:437] MLIR V1 optimization pass is not enabled
2025-09-10 12:53:24.245688: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-09-10 12:53:24.251327: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: ./models/saved_model
2025-09-10 12:53:24.253330: I tensorflow/cc/saved_model/loader.cc:471]

In [4]:
len(tflite_model)

320676

In [5]:
# With TF Lite conversion and quantization
converter = tf.lite.TFLiteConverter.from_saved_model("./models/saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()

W0000 00:00:1757489006.287241 1965882 tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
W0000 00:00:1757489006.287253 1965882 tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2025-09-10 12:53:26.287372: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: ./models/saved_model
2025-09-10 12:53:26.287515: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-09-10 12:53:26.287518: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: ./models/saved_model
2025-09-10 12:53:26.288512: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-09-10 12:53:26.293832: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: ./models/saved_model
2025-09-10 12:53:26.295787: I tensorflow/cc/saved_model/loader.cc:471] SavedModel load for tags { serve }; Status: success: OK. Took 8416 microseconds.


In [6]:
len(tflite_quant_model)

86744

In [7]:
with open("./models/tflite_model.tflite", "wb") as f:
    f.write(tflite_model)

In [8]:
with open("./models/tflite_quant_model.tflite", "wb") as f:
    f.write(tflite_quant_model)

### Method 2: Quantization Aware Training

In [9]:
# Load original model
model = tf.keras.models.load_model("./models/digit_classifier.keras")

In [10]:
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 100)               78500     
                                                                 
 dense_1 (Dense)             (None, 10)                1010      
                                                                 
Total params: 79510 (310.59 KB)
Trainable params: 79510 (310.59 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [11]:
quantize_model = tfmopt.quantization.keras.quantize_model
q_aware_model = quantize_model(model)

In [12]:
q_aware_model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

In [13]:
q_aware_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer (QuantizeLa  (None, 28, 28)            3         
 yer)                                                            
                                                                 
 quant_flatten (QuantizeWra  (None, 784)               1         
 pperV2)                                                         
                                                                 
 quant_dense (QuantizeWrapp  (None, 100)               78505     
 erV2)                                                           
                                                                 
 quant_dense_1 (QuantizeWra  (None, 10)                1015      
 pperV2)                                                         
                                                                 
Total params: 79524 (310.64 KB)
Trainable params: 79510 

In [14]:
q_aware_model.fit(X_train, y_train, epochs=1)



<tf_keras.src.callbacks.History at 0x32675aef0>

In [15]:
q_aware_model.evaluate(X_test, y_test)



[0.07857106626033783, 0.9763000011444092]

In [16]:
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_qaware_model = converter.convert()

INFO:tensorflow:Assets written to: /var/folders/cn/w73m2p855wq0ys03tc29wqs40000gn/T/tmpwfpog3fb/assets


INFO:tensorflow:Assets written to: /var/folders/cn/w73m2p855wq0ys03tc29wqs40000gn/T/tmpwfpog3fb/assets
W0000 00:00:1757489014.794646 1965882 tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
W0000 00:00:1757489014.794654 1965882 tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2025-09-10 12:53:34.794744: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/cn/w73m2p855wq0ys03tc29wqs40000gn/T/tmpwfpog3fb
2025-09-10 12:53:34.795358: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-09-10 12:53:34.795362: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /var/folders/cn/w73m2p855wq0ys03tc29wqs40000gn/T/tmpwfpog3fb
2025-09-10 12:53:34.798526: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-09-10 12:53:34.813115: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /var/folders/cn/w73m2p85

In [17]:
with open("./models/tflite_qaware_model.tflite", "wb") as f:
    f.write(tflite_qaware_model)

In [18]:
len(tflite_qaware_model)

82688