In [52]:
import os
from math import log2, floor

import tensorflow as tf
import numpy as np

In [2]:
HOME_DIR = os.getcwd()
MODEL_DIR = os.path.join(HOME_DIR, "models")

In [3]:
def quantize_nearest(x, scale, zero, qtype):
    if qtype not in {np.int8, np.uint8}:
        raise Exception("Only quantization to int8 or uint8 is supported")
    
    (min, max) = (-128, 127) if qtype == np.int8 else (0, 255)

    return np.clip(np.rint(x / scale) + zero, min, max).astype(qtype)

def fc_and_requantize(input_tensor, weights, bias, q_i, q_w, q_o):
    
    if input_tensor.dtype != np.int8:
        raise Exception("Input must be of type int8")
    
    if weights.dtype != np.int8:
        raise Exception("Weights must be of type int8")
    
    if bias.dtype != np.int32:
        raise Exception("Input and weights must be of type int32")
    
    (s_i, z_i), (s_w, z_w), (s_o, z_o) = q_i, q_w, q_o
    
    if z_w != 0:
        raise Exception("Expected zero point of weights to be 0")

    s = s_i * s_w / s_o

    # 1) shift input tensor
    input_tensor_32 = input_tensor.astype(np.int32) - z_i
    weights_32 = weights.astype(np.int32)

    # 2) compute the bmm
    bmm = np.matmul(input_tensor_32, weights_32.transpose()) + bias

    # 3) requantize
    rq = np.rint(s * bmm) + z_o

    # 4) saturating cast
    output = np.clip(rq, -128, 127).astype(np.int8)

    return output

In [4]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255.0, x_test.astype(np.float32) / 255.0

# Full model

(not ready; go to "Simple model" below)

In [5]:
interpreter = tf.lite.Interpreter(os.path.join(MODEL_DIR, "mnist_model_quant.tflite"))
interpreter.allocate_tensors()

INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


In [6]:
initial_tensors = [None] * 11

for i in range(11):
    try:
        initial_tensors[i] = interpreter.get_tensor(i).copy()
    except:
        pass

In [7]:
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

print(input_details)
print(input_details["dtype"])

print(output_details)
print(output_details["dtype"])


{'name': 'serving_default_flatten_3_input:0', 'index': 0, 'shape': array([ 1, 28, 28], dtype=int32), 'shape_signature': array([-1, 28, 28], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.003921568859368563, 0), 'quantization_parameters': {'scales': array([0.00392157], dtype=float32), 'zero_points': array([0], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}
<class 'numpy.uint8'>
{'name': 'StatefulPartitionedCall:0', 'index': 10, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([-1, 10], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.2266918420791626, 143), 'quantization_parameters': {'scales': array([0.22669184], dtype=float32), 'zero_points': array([143], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}
<class 'numpy.uint8'>


In [78]:
chosen_image = 22

In [79]:
test_image = x_test[chosen_image]

# Need to quantize the inputs outside the model!
input_scale, input_zero_point = input_details["quantization"]

# quantisation transformation as float32 first
test_image = quantize_nearest(test_image, input_scale, input_zero_point, np.uint8)
test_image = np.expand_dims(test_image, axis=0)

interpreter.set_tensor(input_details["index"], test_image)
interpreter.invoke()
output = interpreter.get_tensor(output_details["index"])[0]
output_prediction = output.argmax()

print(output)
print("{} (correct: {})".format(y_test[chosen_image], output_prediction))

[131 128 120 123 152 117 191 125 130 109]
6 (correct: 6)


**Important**: to perform a saturating cast, one must use np.clip. Otherwise problematic things happen - for instance, from f32 to u8, it seems first the floor is applied followed by % 256 (which is not what we want).

In [80]:
arr_200 = np.array([200], dtype=np.float32)
print(arr_200.astype(np.uint8))

arr_m42 = np.array([-42], dtype=np.float32)
print(arr_m42.astype(np.uint8))
print(np.clip(arr_m42, 0, 255).astype(np.uint8))

arr_422 = np.array([422], dtype=np.float32)
print(arr_422.astype(np.uint8))
print(np.clip(arr_422, 0, 255).astype(np.uint8))

[200]
[214]
[0]
[166]
[255]


### Model

The tensors in the interpreter (cf. next cell) should be interpreted as follows:
- 0: `serving_default_flatten_3_input:0`: it simply holds the (already quantised) input tensor (u8, initialised to 0)
- 1: `sequential_6/flatten_3/Const`: it stores, as a constant, the shape that the input should be flattened to by the Reshape node (cf. 7), to which it is an input (i32, does not change)
- 2: `sequential_6/dense_7/BiasAdd/ReadVariableOp`: it holds the bias for the second FC layer (identified by `dense_7`), and it consists of 10 `i32`s
- 3: `sequential_6/dense_7/MatMul`: this is the vec-by-matrix multiplication for the second FC layer. It holds the matrix coefficients as with entries in `i8`. The vector's has entries in ???. The two are multiplied together in `i32` precision to avoid overflows.
- 4: `sequential_6/dense_6/BiasAdd/ReadVariableOp`: it holds the bias for the first FC layer  (cf. 2)
- 5: `sequential_6/dense_6/MatMul`: this is the vec-by-matrix multiplication for the first FC layer (cf. 3)
- 6: `tfl.quantize`: this has the exact same quantisation scale as the input node, but the zero point is -128 as opposed to 0. Also, I am unsure what it does, since input quantisation needs to be performed externally by the user... (it changes during inference!)
- 7: `sequential_6/flatten_3/Reshape`: it flattens the 28 x 28 image into a flat 784-element vector (no value)
- 8: `sequential_6/dense_6/MatMul;sequential_6/activation_3/Relu;sequential_6/dense_6/BiasAdd`: this performs BMM and ReLU (no value)
- 9: `StatefulPartitionedCall:01`: ??? (i8, initial value: 0)
- 10: `StatefulPartitionedCall:0`: holds the actual output (u8, initial value: 0)

In [81]:
final_tensors = [None] * 11

for i in range(11):
    try:
        final_tensors[i] = interpreter.get_tensor(i).copy()
    except:
        pass

In [133]:
params = {
    "bias2": interpreter.get_tensor(2).copy(),
    "mat2": interpreter.get_tensor(3).copy(),
    "bias1": interpreter.get_tensor(4).copy(),
    "mat1": interpreter.get_tensor(5).copy(),
}

ValueError: Tensor data is null. Run allocate_tensors() first

In [134]:
s_i1 = interpreter.get_tensor_details()[0]["quantization"][0]
s_w1 = interpreter.get_tensor_details()[5]["quantization"][0]
s_o1 = interpreter.get_tensor_details()[8]["quantization"][0]
s_i1, s_w1, s_o1

IndexError: list index out of range

In [84]:
s_i2 = interpreter.get_tensor_details()[8]["quantization"][0]
s_w2 = interpreter.get_tensor_details()[3]["quantization"][0]
s_o2 = interpreter.get_tensor_details()[9]["quantization"][0]
s_i2, s_w2, s_o2

(0.038711175322532654, 0.007592691574245691, 0.2266918420791626)

In [85]:
input_tensor = test_image
flattened_input = input_tensor.flatten()
# TODO I think the next two lines do the same as the third on its own
precision_input = flattened_input.astype(np.int32)
quantised_input = precision_input - 128
finalised_input = quantised_input.astype(np.int8)

fc1 = fc_and_requantize(finalised_input, params["mat1"], params["bias1"], s_i1, s_w1, s_o1)

# Applying ReLU to i8 input
relu = fc1.clip(0, 127)

fc2 = fc_and_requantize(relu, params["mat2"], params["bias2"], s_i2, s_w2, s_o2)

In [86]:
fc2

array([ -61,  -28, -102, -128,   59,  -37,  -70,  -31,  -63,  -56],
      dtype=int8)

In [73]:
v = (np.matmul(finalised_input.astype(np.int32), params["mat1"].astype(np.int32).transpose()) + params["bias1"]) * s_i1 * s_w1 / s_o1

In [76]:
v2 = v.clip(-128, 127).astype(np.int8).clip(0, 127)

In [77]:
(np.matmul(v2.astype(np.int32), params["mat2"].astype(np.int32).transpose()) + params["bias2"]) * s_i2 * s_w2 / s_o2

array([ -72.16972182,  -64.82853816, -116.54355962, -113.76111877,
         41.17908744,  -23.50034508, -141.97709171,  -48.27132951,
        -50.38862957,    1.92151787])

In [24]:
interpreter.get_tensor(9)
fc2

array([ -73,  -65, -116, -114,   41,  -24, -128,  -48,  -51,    2],
      dtype=int8)

In [26]:
fs, fz = interpreter.get_tensor_details()[9]["quantization"]

In [27]:
fs, fz

(0.2266918420791626, 15)

In [29]:
fc2/fs + fz

array([-307.02305707, -271.73285904, -496.70787151, -487.88532201,
        195.86226493,  -90.87059411, -549.64316857, -196.74118821,
       -209.97501248,   23.82254951])

In [49]:
interpreter.get_tensor(9) - interpreter.get_tensor(10)

array([[-128, -128, -128, -128, -128, -128, -128, -128, -128, -128]],
      dtype=int16)

# Simple model

In [5]:
interpreter = tf.lite.Interpreter(os.path.join(MODEL_DIR, "simple_model_quant.tflite"), experimental_preserve_all_tensors=True)
interpreter.allocate_tensors()

In [6]:
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

print(input_details)
print(input_details["dtype"])

print(output_details)
print(output_details["dtype"])


{'name': 'serving_default_flatten_2_input:0', 'index': 0, 'shape': array([ 1, 28, 28], dtype=int32), 'shape_signature': array([-1, 28, 28], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.003921568859368563, 0), 'quantization_parameters': {'scales': array([0.00392157], dtype=float32), 'zero_points': array([0], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}
<class 'numpy.uint8'>
{'name': 'StatefulPartitionedCall:0', 'index': 7, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([-1, 10], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.1573459506034851, 175), 'quantization_parameters': {'scales': array([0.15734595], dtype=float32), 'zero_points': array([175], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}
<class 'numpy.uint8'>


In [7]:
chosen_image = 150

In [8]:
test_image = x_test[chosen_image]

# Need to quantize the inputs outside the model!
input_scale, input_zero_point = input_details["quantization"]
input_tensor = quantize_nearest(test_image, input_scale, input_zero_point, np.uint8)
input_tensor = np.expand_dims(input_tensor, axis=0)

# Run the model
interpreter.set_tensor(input_details["index"], input_tensor)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details["index"])[0]
output_prediction = tflite_output.argmax()

print(tflite_output)
print("{} (correct: {})".format(output_prediction, y_test[chosen_image]))

[135 109 152 161 187 157 159 151 173 202]
9 (correct: 9)


In [9]:
interpreter.get_tensor_details()

[{'name': 'serving_default_flatten_2_input:0',
  'index': 0,
  'shape': array([ 1, 28, 28], dtype=int32),
  'shape_signature': array([-1, 28, 28], dtype=int32),
  'dtype': numpy.uint8,
  'quantization': (0.003921568859368563, 0),
  'quantization_parameters': {'scales': array([0.00392157], dtype=float32),
   'zero_points': array([0], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}},
 {'name': 'sequential_3/flatten_2/Const',
  'index': 1,
  'shape': array([2], dtype=int32),
  'shape_signature': array([2], dtype=int32),
  'dtype': numpy.int32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}},
 {'name': 'sequential_3/dense_2/BiasAdd/ReadVariableOp',
  'index': 2,
  'shape': array([10], dtype=int32),
  'shape_signature': array([10], dtype=int32),
  'dtype': numpy.int32,
  'quantization': (4.8770318244351074e-05, 0),
  'quan

In [10]:
for i, t in enumerate(interpreter.get_tensor_details()):
    print(i, ": ", t["name"], sep="")

0: serving_default_flatten_2_input:0
1: sequential_3/flatten_2/Const
2: sequential_3/dense_2/BiasAdd/ReadVariableOp
3: sequential_3/dense_2/MatMul
4: tfl.quantize
5: sequential_3/flatten_2/Reshape
6: StatefulPartitionedCall:01
7: StatefulPartitionedCall:0


In [11]:
interpreter._get_ops_details()

[{'index': 0,
  'op_name': 'QUANTIZE',
  'inputs': array([0], dtype=int32),
  'outputs': array([4], dtype=int32)},
 {'index': 1,
  'op_name': 'RESHAPE',
  'inputs': array([4, 1], dtype=int32),
  'outputs': array([5], dtype=int32)},
 {'index': 2,
  'op_name': 'FULLY_CONNECTED',
  'inputs': array([5, 3, 2], dtype=int32),
  'outputs': array([6], dtype=int32)},
 {'index': 3,
  'op_name': 'QUANTIZE',
  'inputs': array([6], dtype=int32),
  'outputs': array([7], dtype=int32)}]

In [12]:
input_idx = 5
bias_idx = 2
weight_idx = 3
output_idx = 6

w = interpreter.get_tensor(weight_idx)
b = interpreter.get_tensor(bias_idx)
q_i = interpreter.get_tensor_details()[input_idx]["quantization"]
q_w = interpreter.get_tensor_details()[weight_idx]["quantization"]
q_o = interpreter.get_tensor_details()[output_idx]["quantization"]

In [13]:
# 1) flatten input
flattened_input = input_tensor.reshape(interpreter.get_tensor(1)) # [-1, 784]

# 2) shift input tensor by -128 to sitch from input type (uint8) to TF Lite internal type (int8) 
finalised_input = flattened_input.astype(np.int32)
finalised_input = finalised_input - 128
finalised_input = finalised_input.astype(np.int8)

# 3) run fully-connected layer
fc1 = fc_and_requantize(finalised_input, w, b, q_i, q_w, q_o)

# 4) undo the shift to switch from TF Lite internal type (int8) to output type (uint8)
manual_output = fc1.astype(np.int32)
manual_output = manual_output + 128
manual_output = manual_output.astype(np.uint8)

In [14]:
(manual_output == tflite_output).all()

True

Slightly more optimised version of the simple model (TF Lite and manual) to meaningfully compare execution times

In [15]:
I_S, I_Z = input_details["quantization"]
RESHAPE = interpreter.get_tensor(1)

W_32 = interpreter.get_tensor(3).transpose().astype(np.int32)
B_32 = interpreter.get_tensor(2).astype(np.int32)
(S_I, Z_I) = interpreter.get_tensor_details()[5]["quantization"]
(S_W, Z_W) = interpreter.get_tensor_details()[3]["quantization"]
(S_O, Z_O) = interpreter.get_tensor_details()[6]["quantization"]
S = S_I * S_W / S_O

def quantise_input(x):
    x_q = quantize_nearest(x, I_S, I_Z, np.uint8)
    return np.expand_dims(x_q, axis=0)

def manual_model(x):
    x = (x.reshape(RESHAPE).astype(np.int32) - 128).astype(np.int8)
    x = x.astype(np.int32) - Z_I
    x = np.matmul(x, W_32) + B_32
    x = np.clip(np.rint(S * x) + Z_O, -128, 127)
    x = (x + 128).astype(np.uint8)

    return x


quantised_test_x = [quantise_input(x) for x in x_test]

In [16]:
%%timeit

for x in quantised_test_x:
    interpreter.set_tensor(input_details["index"], x)
    interpreter.invoke()
    interpreter.get_tensor(output_details["index"])[0]


21.2 ms ± 83.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [17]:
%%timeit

for x in quantised_test_x:
    manual_model(x)

114 ms ± 197 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Simple results

- The cumulative execution time of the TF Lite model on the 10000 test images is ~18 ms (average over several runs)
- The cumulative execution time of the manual model on the 10000 test images is ~116 ms (idem)

In [41]:
discrepancies = []

for (i, x) in enumerate(quantised_test_x):

    # TF Lite model
    interpreter.set_tensor(input_details["index"], x)
    interpreter.invoke()
    tflite_output = interpreter.get_tensor(output_details["index"])[0]

    # Manual model
    manual_output = manual_model(x)

    if not (tflite_output == manual_output).all():
        discrepancies.append(i)

In [43]:
len(discrepancies)/len(quantised_test_x)

0.0023

In [71]:
interpreter.set_tensor(input_details["index"], ip)
interpreter.invoke()
outl = interpreter.get_tensor(output_details["index"])[0]

In [72]:
out_m = manual_model(ip)

In [81]:
(outl == out_m).all()

True

In [82]:
np.array_equal(out_m, outl)

False

In [84]:
problematic = 77

In [86]:
x = quantised_test_x[problematic]

interpreter.set_tensor(input_details["index"], x)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details["index"])[0]

manual_output = manual_model(x)


In [87]:
print(tflite_output)
print(manual_output)

[144 156 185 150 149 159 153 173 163 173]
[[144 156 185 150 149 159 153 173 162 173]]


In [29]:
x1 = (x.reshape(RESHAPE).astype(np.int32) - 128).astype(np.int8)
x2 = x1.astype(np.int32) - Z_I
x3 = np.matmul(x2, W_32) + B_32
x4 = np.clip(np.rint(S * x3).astype(np.int32) + Z_O, -128, 127)
x5 = (x4 + 128).astype(np.uint8)

In [22]:
(x1 == interpreter.get_tensor(5)).all()

True

In [30]:
interpreter.get_tensor(6)

array([[16, 28, 57, 22, 21, 31, 25, 45, 35, 45]], dtype=int8)

In [31]:
x4

array([[16, 28, 57, 22, 21, 31, 25, 45, 34, 45]], dtype=int32)

In [28]:
np.rint(S * x3).dtype

dtype('float64')

In [32]:
tflite_round: np.vectorize(lambda x: np.floor(x + np.copysign(0.5, x)))

In [34]:
np.round(1.3)

1.0

In [35]:
ROUNDING_PRECISION = 32

# TODO perhaps this can be done slightly more elegantly assuming S > 0
# negative S, aside from theoretically never happening, would break our rounding assumption below
if S < 0:
    raise Exception("S must be positive")

scaled_m = 2**ROUNDING_PRECISION * S




In [40]:
round(-0.5)

0

In [45]:
type(S)

float