# [Quantization aware training in Keras example](https://www.tensorflow.org/model_optimization/guide/quantization/training_example)

## Summary
In this tutorial, you will:

1. Train a Keras model for MNIST from scratch.
2. Fine-tune the model by applying the quantization aware training API, check its accuracy, and export a quantization aware model.
3. Use the model to create an actually quantized model for the TFLite backend.
4. Confirm that accuracy is preserved in TFLite and that the model is 4x smaller. To see the latency benefits on mobile, try out the TFLite examples in the [TFLite app repository](https://www.tensorflow.org/lite/models).
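The "4x smaller" figure comes from storage width: a float32 weight takes 4 bytes and an int8 weight takes 1. The following rough sketch shows the arithmetic, using the parameter counts of the small convolutional model defined later in this tutorial. It is an estimate under simplifying assumptions: it ignores file-format overhead and the fact that quantized models typically keep biases as int32, so real TFLite file sizes will differ somewhat.

```python
# Back-of-the-envelope estimate of the size reduction from int8 weights.
# Parameter counts are for this tutorial's model:
# Conv2D(12 filters, 3x3 kernel, 1 input channel) and Dense(10) on 2028 features.
conv_params = 3 * 3 * 1 * 12 + 12      # kernel + bias = 120
dense_params = 2028 * 10 + 10          # kernel + bias = 20290
total_params = conv_params + dense_params

float32_bytes = total_params * 4       # 4 bytes per float32 weight
int8_bytes = total_params * 1          # 1 byte per int8 weight

print(f"parameters: {total_params}")
print(f"float32: {float32_bytes} bytes, int8: {int8_bytes} bytes")
print(f"ratio: {float32_bytes / int8_bytes:.1f}x")
```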

In [None]:
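!pip install -q tensorflow
!pip install -q tensorflow-model-optimization
# tfmot relies on the Keras 2 API, provided by the tf_keras package in recent TensorFlow releases.
!pip install -q tf_keras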

In [None]:
import tempfile
import os

import tensorflow as tf
from tensorflow import keras

## Train a model for MNIST without quantization aware training

The **MNIST** database (Modified National Institute of Standards and Technology database) is a large collection of handwritten digits that is widely used for training and testing image classification systems. It contains 60,000 training images and 10,000 test images, each a 28x28 grayscale image of a digit from 0 to 9.

## Load the MNIST dataset

In [None]:
# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

## Normalize the dataset

In [None]:
# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
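Here is a tiny pure-Python illustration of the scaling above, with made-up pixel values: MNIST pixels are 8-bit integers from 0 to 255, so dividing by 255.0 maps them into the range [0.0, 1.0].

```python
raw_pixels = [0, 51, 128, 255]               # hypothetical uint8 pixel values
normalized = [p / 255.0 for p in raw_pixels]  # same scaling as the cell above
print(normalized)
```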

## Define the model architecture

In [None]:
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])
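As a sanity check on the architecture, you can work out the output shapes and parameter counts by hand. This sketch assumes the Keras defaults that the model relies on: `padding='valid'` and stride 1 for `Conv2D`, and a stride equal to `pool_size` for `MaxPooling2D`. The helper names `conv2d_valid` and `pool_valid` are hypothetical and exist only for this illustration.

```python
def conv2d_valid(size, kernel, stride=1):
    """Spatial output size of a convolution with padding='valid'."""
    return (size - kernel) // stride + 1


def pool_valid(size, pool):
    """Spatial output size of max pooling with padding='valid', stride == pool."""
    return (size - pool) // pool + 1


side = 28        # input images are 28x28
filters = 12     # Conv2D(filters=12, ...)

conv_side = conv2d_valid(side, kernel=3)       # 28 - 3 + 1 = 26
pool_side = pool_valid(conv_side, pool=2)      # (26 - 2) // 2 + 1 = 13
flat = pool_side * pool_side * filters         # 13 * 13 * 12 = 2028

conv_params = 3 * 3 * 1 * filters + filters    # kernel + bias = 120
dense_params = flat * 10 + 10                  # kernel + bias = 20290

print("Conv2D:", (conv_side, conv_side, filters), "params:", conv_params)
print("MaxPooling2D:", (pool_side, pool_side, filters))
print("Flatten:", flat)
print("Dense params:", dense_params)
```

These values match the `(None, 26, 26, 12)`, `(None, 13, 13, 12)` and `(None, 2028)` output shapes reported by `q_aware_model.summary()` further down.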



## Train the digit classification model

In [None]:
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
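The final `Dense(10)` layer has no softmax activation, so the model outputs raw logits, and `from_logits=True` tells the loss to apply the softmax itself. "Sparse" means the labels are integer class indices (0 to 9) rather than one-hot vectors. The following minimal pure-Python sketch computes the per-example loss, -log(softmax(logits)[label]). The function name is hypothetical, and the Keras implementation differs in details such as batching and loss reduction.

```python
import math


def sparse_ce_from_logits(logits, label):
    """Cross-entropy for one example, given raw logits and an integer label.

    Computes -log(softmax(logits)[label]) = logsumexp(logits) - logits[label],
    subtracting the max logit first for numerical stability.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


# Hypothetical 3-class logits (the tutorial's model outputs 10 of them).
logits = [2.0, 1.0, 0.1]
print(sparse_ce_from_logits(logits, label=0))
```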

In [None]:
model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)

1688/1688 ━━━━━━━━━━━━━━━━━━━━ 9s 3ms/step - accuracy: 0.8490 - loss: 0.5670 - val_accuracy: 0.9685 - val_loss: 0.1103


<keras.src.callbacks.history.History at 0x79ac99083a90>
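The `1688/1688` in the progress bar is the number of batches per epoch. Assuming the Keras default `batch_size=32` (the `fit` call above does not set one), you can reproduce the count from the 60,000 MNIST training images minus the last 10% that `validation_split=0.1` holds out:

```python
import math

num_train_images = 60_000                 # size of the MNIST training set
num_validation = num_train_images // 10   # validation_split=0.1 holds out the last 10%
num_fit_images = num_train_images - num_validation
batch_size = 32                           # Keras default when fit() gets no batch_size

# The final batch is partial, so round up.
steps_per_epoch = math.ceil(num_fit_images / batch_size)
print(num_fit_images, steps_per_epoch)
```

The last batch holds the remaining 16 images (54,000 - 1,687 x 32), which is why the count rounds up to 1688.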

## Clone and fine-tune the pre-trained model with quantization aware training
### Define the model

Apply quantization aware training to the whole model. The model summary confirms this: every layer is now wrapped and prefixed with "quant", and a `quantize_layer` is added at the input.

Note that the resulting model is quantization aware but not quantized (for example, the weights are still float32 rather than int8). The following sections show how to create a quantized model from the quantization aware one.
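To make "quantization aware but not quantized" concrete: during training, the wrapped layers emulate int8 rounding while keeping every value in float32, a technique often called fake quantization. The sketch below assumes simple per-tensor affine quantization over a known `[x_min, x_max]` range. The `fake_quantize` helper is hypothetical, and tfmot's own quantizers differ in details such as how ranges are tracked during training.

```python
def fake_quantize(x, x_min, x_max, num_bits=8):
    """Quantize x to an integer grid, then map it back to float."""
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (x_max - x_min) / (qmax - qmin)
    # Integer that represents real 0.0, so zero survives quantization exactly.
    zero_point = round(qmin - x_min / scale)
    # Quantize: snap to the integer grid and clamp to the representable range.
    q = min(max(round(x / scale) + zero_point, qmin), qmax)
    # Dequantize: map back to float, so downstream layers still see float values.
    return (q - zero_point) * scale


weights = [-0.731, -0.1, 0.0, 0.2549, 0.9]
fake_quantized = [fake_quantize(w, x_min=-1.0, x_max=1.0) for w in weights]
print(fake_quantized)
```

Each output differs from its input by at most half a grid step. This rounding error is what the model learns to tolerate during quantization aware training.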

The [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md) shows how to quantize only some layers, which can help preserve model accuracy.

In [None]:
import tensorflow_model_optimization as tfmot
# In recent TensorFlow releases, `tensorflow.keras` is Keras 3, while tfmot
# works with the Keras 2 API (the tf_keras package). This import rebinds the
# name `keras` to the Keras 2 API that tfmot expects.
from tensorflow_model_optimization.python.core.keras.compat import keras

# quantize_model cannot wrap the Keras 3 `model` trained above, so rebuild
# the same architecture with the Keras 2 API.
def create_quantizable_model():
    return keras.Sequential([
        keras.layers.InputLayer(input_shape=(28, 28)),
        keras.layers.Reshape(target_shape=(28, 28, 1)),
        keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(10)
    ])

quantizable_model = create_quantizable_model()

# Copy the pre-trained weights so that fine-tuning starts from the trained
# model rather than from a fresh random initialization.
quantizable_model.set_weights(model.get_weights())

# Wrap the whole model with quantization emulation.
q_aware_model = tfmot.quantization.keras.quantize_model(quantizable_model)

# `quantize_model` returns a new model that must be compiled again.
q_aware_model.compile(optimizer='adam',
                      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

q_aware_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer (QuantizeLa  (None, 28, 28)            3         
 yer)                                                            
                                                                 
 quant_reshape (QuantizeWra  (None, 28, 28, 1)         1         
 pperV2)                                                         
                                                                 
 quant_conv2d (QuantizeWrap  (None, 26, 26, 12)        147       
 perV2)                                                          
                                                                 
 quant_max_pooling2d (Quant  (None, 13, 13, 12)        1         
 izeWrapperV2)                                                   
                                                                 
 quant_flatten (QuantizeWra  (None, 2028)              1