In [1]:
import os
from math import log2, ceil, floor

import tensorflow as tf
import numpy as np

# Make NumPy raise FloatingPointError for every floating-point error category
# (divide, over, under, invalid) instead of warning or ignoring it.
# np.seterr returns the *previous* settings, which is the dict shown as this
# cell's output.
np.seterr(all='raise')

{'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}

In [2]:
# Models are loaded from a "models" directory under the current working directory.
HOME_DIR = os.getcwd()
MODEL_DIR = os.path.join(HOME_DIR, "models")

In [3]:
def quantize_nearest(x, scale, zero, qtype):
    """Affine-quantize real values to int8 or uint8 with round-to-nearest.

    Computes ``clip(rint(x / scale) + zero, qmin, qmax)`` and casts the result
    to ``qtype``, where ``[qmin, qmax]`` is the full range of ``qtype``.

    Args:
        x: Real-valued array (or scalar) to quantize.
        scale: Quantization scale, i.e. the real value of one integer step.
        zero: Integer zero point added after scaling.
        qtype: ``np.int8`` or ``np.uint8``. Any dtype-like that names one of
            them (e.g. ``np.dtype("uint8")`` or ``"int8"``) is accepted too.

    Returns:
        Array (or NumPy scalar) of type ``qtype``, saturated to that type's range.

    Raises:
        Exception: If ``qtype`` is not int8 or uint8.

    Note:
        ``np.rint`` rounds exact ties to the nearest even integer.
    """
    # Normalise dtype-likes (np.dtype objects, strings) to the scalar type so the
    # membership test below does not depend on how the caller spelled the type.
    try:
        qtype = np.dtype(qtype).type
    except (TypeError, ValueError):
        qtype = None  # not a dtype at all: report it as unsupported below
    if qtype not in (np.int8, np.uint8):
        raise Exception("Only quantization to int8 or uint8 is supported")

    bounds = np.iinfo(qtype)  # int8 -> [-128, 127], uint8 -> [0, 255]
    return np.clip(np.rint(x / scale) + zero, bounds.min, bounds.max).astype(qtype)

def fc_and_requantize(input_tensor, weights, bias, q_i, q_w, q_o):
    """Integer fully-connected layer followed by requantization to int8.

    Computes ``(input - z_i) @ weights.T + bias`` in int32. It then rescales the
    accumulator by ``s_i * s_w / s_o``, rounds, adds the output zero point and
    saturates to int8.

    Args:
        input_tensor: int8 array whose last dimension is the number of input
            features.
        weights: int8 array of shape (out_features, in_features). It is
            transposed before the matmul.
        bias: int32 array broadcastable to the matmul result, or None for a
            layer without bias.
        q_i, q_w, q_o: (scale, zero_point) pairs of the input, the weights and
            the output. The weight zero point must be 0.

    Returns:
        int8 array of shape (..., out_features) in the output's quantization.

    Raises:
        Exception: On an unexpected dtype or a non-zero weight zero point.

    Note:
        The rescale here is done in float64 with round-half-to-even (np.rint).
        NOTE(review): TF Lite kernels presumably use a fixed-point multiplier,
        so rare off-by-one differences on exact ties may be possible.
    """
    if input_tensor.dtype != np.int8:
        raise Exception("Input must be of type int8")

    if weights.dtype != np.int8:
        raise Exception("Weights must be of type int8")

    if bias is not None and bias.dtype != np.int32:
        raise Exception("Bias must be of type int32")

    (s_i, z_i), (s_w, z_w), (s_o, z_o) = q_i, q_w, q_o

    if z_w != 0:
        raise Exception("Expected zero point of weights to be 0")

    # The int32 accumulator has scale s_i * s_w; this factor maps it to the output scale.
    s = s_i * s_w / s_o

    # 1) remove the input zero point; widen to int32 so products and sums cannot overflow
    input_32 = input_tensor.astype(np.int32) - z_i
    weights_32 = weights.astype(np.int32)

    # 2) int32 matmul (+ bias)
    acc = input_32 @ weights_32.T
    if bias is not None:
        acc = acc + bias

    # 3) requantize into the output's scale and zero point
    rq = np.rint(s * acc) + z_o

    # 4) saturating cast to int8
    int8_range = np.iinfo(np.int8)
    return np.clip(rq, int8_range.min, int8_range.max).astype(np.int8)

In [4]:
mnist = tf.keras.datasets.mnist

# Load MNIST and rescale the 8-bit pixel values to float32 in [0, 1].
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255.0, x_test.astype(np.float32) / 255.0

In [5]:
# experimental_preserve_all_tensors=True keeps intermediate tensors readable via
# get_tensor() after invoke(), so the graph's internals can be inspected and
# compared with the manual re-implementation below.
interpreter = tf.lite.Interpreter(os.path.join(MODEL_DIR, "simple_model_quant.tflite"), experimental_preserve_all_tensors=True)
interpreter.allocate_tensors()

In [6]:
# The model has a single input and a single output. Their "quantization"
# (scale, zero_point) pairs are needed to quantize the input image and to
# interpret the raw output values.
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

print(input_details)
print(input_details["dtype"])

print(output_details)
print(output_details["dtype"])

{'name': 'serving_default_flatten_2_input:0', 'index': 0, 'shape': array([ 1, 28, 28], dtype=int32), 'shape_signature': array([-1, 28, 28], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.003921568859368563, 0), 'quantization_parameters': {'scales': array([0.00392157], dtype=float32), 'zero_points': array([0], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}
<class 'numpy.uint8'>
{'name': 'StatefulPartitionedCall:0', 'index': 7, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([-1, 10], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.1573459506034851, 175), 'quantization_parameters': {'scales': array([0.15734595], dtype=float32), 'zero_points': array([175], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}
<class 'numpy.uint8'>


In [7]:
# Metadata (name, index, shape, dtype, quantization) for every tensor in the graph.
interpreter.get_tensor_details()

[{'name': 'serving_default_flatten_2_input:0',
  'index': 0,
  'shape': array([ 1, 28, 28], dtype=int32),
  'shape_signature': array([-1, 28, 28], dtype=int32),
  'dtype': numpy.uint8,
  'quantization': (0.003921568859368563, 0),
  'quantization_parameters': {'scales': array([0.00392157], dtype=float32),
   'zero_points': array([0], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}},
 {'name': 'sequential_3/flatten_2/Const',
  'index': 1,
  'shape': array([2], dtype=int32),
  'shape_signature': array([2], dtype=int32),
  'dtype': numpy.int32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}},
 {'name': 'sequential_3/dense_2/BiasAdd/ReadVariableOp',
  'index': 2,
  'shape': array([10], dtype=int32),
  'shape_signature': array([10], dtype=int32),
  'dtype': numpy.int32,
  'quantization': (4.8770318244351074e-05, 0),
  'quan

In [8]:
# Compact "position: name" listing of all tensors.
for i, t in enumerate(interpreter.get_tensor_details()):
    print(i, ": ", t["name"], sep="")

0: serving_default_flatten_2_input:0
1: sequential_3/flatten_2/Const
2: sequential_3/dense_2/BiasAdd/ReadVariableOp
3: sequential_3/dense_2/MatMul
4: tfl.quantize
5: sequential_3/flatten_2/Reshape
6: StatefulPartitionedCall:01
7: StatefulPartitionedCall:0


In [9]:
# Operator graph: which tensor indices each op reads ("inputs") and writes ("outputs").
# NOTE(review): _get_ops_details is underscore-prefixed, i.e. not public
# Interpreter API; presumably it may change between TF versions.
interpreter._get_ops_details()

[{'index': 0,
  'op_name': 'QUANTIZE',
  'inputs': array([0], dtype=int32),
  'outputs': array([4], dtype=int32)},
 {'index': 1,
  'op_name': 'RESHAPE',
  'inputs': array([4, 1], dtype=int32),
  'outputs': array([5], dtype=int32)},
 {'index': 2,
  'op_name': 'FULLY_CONNECTED',
  'inputs': array([5, 3, 2], dtype=int32),
  'outputs': array([6], dtype=int32)},
 {'index': 3,
  'op_name': 'QUANTIZE',
  'inputs': array([6], dtype=int32),
  'outputs': array([7], dtype=int32)}]

In [10]:
# Locate the FULLY_CONNECTED op and take its tensor indices from the op graph
# instead of hard-coding numbers that can change when the model is re-converted.
# Its inputs are ordered (input, weights, bias), as the ops listing above shows.
fc_op = next(op for op in interpreter._get_ops_details() if op["op_name"] == "FULLY_CONNECTED")
input_idx, weight_idx, bias_idx = (int(i) for i in fc_op["inputs"])
output_idx = int(fc_op["outputs"][0])

# Look tensor metadata up by its "index" field rather than by list position,
# and fetch the (potentially large) details list only once.
tensor_details = {t["index"]: t for t in interpreter.get_tensor_details()}

w = interpreter.get_tensor(weight_idx)
b = interpreter.get_tensor(bias_idx)
q_i = tensor_details[input_idx]["quantization"]
q_w = tensor_details[weight_idx]["quantization"]
q_o = tensor_details[output_idx]["quantization"]

In [11]:
def manual_model(input_tensor):
    """Re-run the quantized TF Lite graph in NumPy, one op at a time.

    Follows the graph's op sequence QUANTIZE -> RESHAPE -> FULLY_CONNECTED ->
    QUANTIZE. Reads the notebook globals `w`, `b`, `q_i`, `q_w` and `q_o`
    (the fully-connected layer's weights, bias and quantization parameters),
    so the cell that extracts them must run first.

    Args:
        input_tensor: uint8 batch exactly as fed to the interpreter,
            e.g. shape (1, 28, 28).

    Returns:
        uint8 array of shape (batch, out_features), directly comparable with
        the interpreter's output tensor.
    """
    # 1) shift by -128 to switch from the input type (uint8) to TF Lite's internal type (int8).
    #    NOTE(review): this stands in for the first QUANTIZE op. It assumes that op keeps the
    #    input scale and only moves the zero point by -128; confirm against tensor 4's
    #    quantization parameters.
    shifted_input = (input_tensor.astype(np.int32) - 128).astype(np.int8)

    # 2) flatten to (batch, in_features), i.e. the graph's [-1, 784] RESHAPE.
    #    The width comes from the weights (shape (out_features, in_features)) instead of
    #    a shape constant read from a hard-coded interpreter tensor index.
    flattened_input = shifted_input.reshape(-1, w.shape[1])

    # 3) fully-connected layer, requantized to the FC output's int8 quantization
    fc1 = fc_and_requantize(flattened_input, w, b, q_i, q_w, q_o)

    # 4) undo the shift to switch from TF Lite's internal type (int8) to the output type
    #    (uint8). This makes the same assumption as step 1, here for the final QUANTIZE op.
    return (fc1.astype(np.int32) + 128).astype(np.uint8)

## Execution

In [12]:
# Index into x_test / y_test of the image used for the comparison below.
chosen_image = 150

In [50]:
test_image = x_test[chosen_image]

# Need to quantize the inputs outside the model! The model's input tensor is
# uint8, so the float image is quantized with that tensor's (scale, zero_point).
input_scale, input_zero_point = input_details["quantization"]
input_tensor = quantize_nearest(test_image, input_scale, input_zero_point, np.uint8)
input_tensor = np.expand_dims(input_tensor, axis=0)  # add batch dim: (28, 28) -> (1, 28, 28)

# Run the TF Lite model
interpreter.set_tensor(input_details["index"], input_tensor)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details["index"])[0]

# Run the manual model
manual_output = manual_model(input_tensor)[0]

# Both outputs are raw uint8 quantized values (not dequantized).
print("Test image {}\nTF Lite output:\t\t{}\nManual model output:\t{}\nCorrect label: {}".format(chosen_image, tflite_output, manual_output, y_test[chosen_image]))

Test image 150
TF Lite output:		[135 109 152 161 187 157 159 151 173 202]
Manual model output:	[135 109 152 161 187 157 159 151 173 202]
Correct label: 9


In [22]:
# Bit-exact agreement check. np.array_equal also requires matching shapes, and the
# assert makes a mismatch fail a Restart & Run All instead of just displaying False.
outputs_match = np.array_equal(manual_output, tflite_output)
assert outputs_match, "Manual model output differs from TF Lite output"
outputs_match

True

In [33]:
# np.set_printoptions(precision=30, suppress=True)
# x_test[chosen_image]

In [60]:
def print_2d(t):
    """Print the first 2-D slice of `t` (that is, ``t[0]``), one row per line.

    Each row is written as a bracketed, comma-separated list followed by a
    trailing comma (e.g. ``[0, 18, 130],``), so the output can be pasted as
    the body of a nested-list literal.
    """
    for row in t[0]:
        print("[" + ", ".join(map(str, row)) + "],")

In [61]:
# Dump the quantized uint8 input image (first batch element), one pixel row per line.
print_2d(input_tensor)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 18, 130, 188, 182, 79, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 46, 171, 254, 254, 254, 254, 254, 142, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 85, 239, 253, 201, 158, 175, 253, 254, 251, 20, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 113, 239, 230, 104, 0, 0, 0, 107, 234, 103, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 16, 218, 242, 32, 0, 0, 0, 0, 0, 22, 22, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 2, 174, 254, 153, 0, 0, 0, 0, 0, 0, 0, 15, 103, 7, 0, 0