In [None]:
import os
from math import log2, ceil, floor

import tensorflow as tf
import numpy as np

# Make NumPy raise on every floating-point error category instead of warning,
# so that problems in the manual arithmetic below surface immediately.
np.seterr(all='raise')

In [None]:
# The .tflite files are expected in a "models" directory under the current
# working directory.
HOME_DIR = os.getcwd()
MODEL_DIR = os.path.join(HOME_DIR, "models")

In [None]:
def quantize_nearest(x, scale, zero, qtype):
    """Quantize `x` to 8 bits: round `x / scale`, add `zero`, then saturate.

    Args:
        x: float scalar or array to quantize.
        scale: quantization scale (`x` is divided by it before rounding).
        zero: quantization zero point, added after rounding.
        qtype: target type; must be `np.int8` or `np.uint8`.

    Returns:
        The quantized value(s), cast to `qtype`.

    Raises:
        ValueError: if `qtype` is neither `np.int8` nor `np.uint8`.
    """
    if qtype not in {np.int8, np.uint8}:
        raise ValueError("Only quantization to int8 or uint8 is supported")

    # Named q_min / q_max so that the built-in min() / max() are not shadowed.
    (q_min, q_max) = (-128, 127) if qtype == np.int8 else (0, 255)

    # np.rint rounds ties to even; np.clip makes the final cast saturate
    # instead of wrapping around.
    return np.clip(np.rint(x / scale) + zero, q_min, q_max).astype(qtype)

def fc_and_requantize(input_tensor, weights, bias, q_i, q_w, q_o):
    """Quantized fully-connected layer followed by requantization to int8.

    NOTE: the function currently ALWAYS raises once the accumulator has been
    computed (see the guard in step 3). The guard was left in deliberately as
    a reminder that the `np.rint`-based requantization below is only an
    approximation; everything after it is unreachable until it is removed.

    Args:
        input_tensor: int8 array holding the quantized input.
        weights: int8 array; it is transposed before the matrix product.
        bias: int32 array added to the product.
        q_i: `(scale, zero_point)` of the input.
        q_w: `(scale, zero_point)` of the weights; the zero point must be 0.
        q_o: `(scale, zero_point)` of the output.

    Returns:
        int8 array with the requantized, saturated result (once the guard is
        removed).

    Raises:
        Exception: on a wrong dtype, a non-zero weight zero point, or — for
            now, unconditionally — at the requantization step.
    """
    if input_tensor.dtype != np.int8:
        raise Exception("Input must be of type int8")

    if weights.dtype != np.int8:
        raise Exception("Weights must be of type int8")

    if bias.dtype != np.int32:
        # The message used to blame "input and weights" for a bias problem.
        raise Exception("Bias must be of type int32")

    (s_i, z_i), (s_w, z_w), (s_o, z_o) = q_i, q_w, q_o

    if z_w != 0:
        raise Exception("Expected zero point of weights to be 0")

    # Combined rescaling factor from the accumulator scale to the output scale
    s = s_i * s_w / s_o

    # 1) shift input tensor (in int32, so that the subtraction cannot wrap)
    input_tensor_32 = input_tensor.astype(np.int32) - z_i
    weights_32 = weights.astype(np.int32)

    # 2) compute the bmm
    bmm = np.matmul(input_tensor_32, weights_32.transpose()) + bias

    # 3) requantize
    # Deliberate tripwire: floating-point rounding is not the accurate
    # integer-only rounding this notebook is after.
    raise Exception("Change to accurate rounding")
    rq = np.rint(s * bmm) + z_o

    # 4) saturating cast
    output = np.clip(rq, -128, 127).astype(np.int8)

    return output

In [None]:
mnist = tf.keras.datasets.mnist

# Load the train/test split and convert the images to float32 scaled by 1/255
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255.0, x_test.astype(np.float32) / 255.0

# Full model

(not ready; go to "Simple model" below)

In [None]:
# Load the quantised MNIST model and allocate its tensors so that they can be
# read and written below.
interpreter = tf.lite.Interpreter(os.path.join(MODEL_DIR, "mnist_model_quant.tflite"))
interpreter.allocate_tensors()

In [None]:
# Snapshot of tensors 0-10 BEFORE inference (compare with the post-inference
# snapshot taken further down). Entries whose data cannot be read stay None.
initial_tensors = [None] * 11

for i in range(11):
    try:
        initial_tensors[i] = interpreter.get_tensor(i).copy()
    except Exception:
        # Best effort: some tensors have no readable value. Catching Exception
        # (rather than a bare except) leaves KeyboardInterrupt and SystemExit
        # alone.
        pass

In [None]:
# Metadata of the model's first input and first output tensor; the
# "quantization", "index" and "dtype" entries are used below.
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

print(input_details)
print(input_details["dtype"])

print(output_details)
print(output_details["dtype"])


In [None]:
# Index into x_test / y_test of the image used for the single-image experiment
chosen_image = 22

In [None]:
test_image = x_test[chosen_image]

# Need to quantize the inputs outside the model!
input_scale, input_zero_point = input_details["quantization"]

# Quantise the float32 image to uint8, then add the leading batch dimension
test_image = quantize_nearest(test_image, input_scale, input_zero_point, np.uint8)
test_image = np.expand_dims(test_image, axis=0)

interpreter.set_tensor(input_details["index"], test_image)
interpreter.invoke()
output = interpreter.get_tensor(output_details["index"])[0]
output_prediction = output.argmax()

print(output)
# Prediction first, ground-truth label second (the two used to be swapped)
print("{} (correct: {})".format(output_prediction, y_test[chosen_image]))

**Important**: to perform a saturating cast, one must use np.clip. Otherwise problematic things happen - for instance, from f32 to u8, it seems first the floor is applied followed by % 256 (which is not what we want).

In [None]:
# In-range value: the plain cast is unproblematic
arr_200 = np.array([200], dtype=np.float32)
print(arr_200.astype(np.uint8))

# Below the uint8 range: plain cast vs. clip-then-cast (saturating)
arr_m42 = np.array([-42], dtype=np.float32)
print(arr_m42.astype(np.uint8))
print(np.clip(arr_m42, 0, 255).astype(np.uint8))

# Above the uint8 range: plain cast vs. clip-then-cast (saturating)
arr_422 = np.array([422], dtype=np.float32)
print(arr_422.astype(np.uint8))
print(np.clip(arr_422, 0, 255).astype(np.uint8))

### Model

The tensors in the interpreter (cf. next cell) should be interpreted as follows:
- 0: `serving_default_flatten_3_input:0`: it simply holds the (already quantised) input tensor (u8, initialised to 0)
- 1: `sequential_6/flatten_3/Const`: it stores, as a constant, the shape that the input should be flattened to by the Reshape node (cf. 7), to which it is an input (i32, does not change)
- 2: `sequential_6/dense_7/BiasAdd/ReadVariableOp`: it holds the bias for the second FC layer (identified by `dense_7`), and it consists of 10 `i32`s
- 3: `sequential_6/dense_7/MatMul`: this is the vec-by-matrix multiplication for the second FC layer. It holds the matrix coefficients, with entries in `i8`. The vector has entries in ???. The two are multiplied together in `i32` precision to avoid overflows.
- 4: `sequential_6/dense_6/BiasAdd/ReadVariableOp`: it holds the bias for the first FC layer  (cf. 2)
- 5: `sequential_6/dense_6/MatMul`: this is the vec-by-matrix multiplication for the first FC layer (cf. 3)
- 6: `tfl.quantize`: this has the exact same quantisation scale as the input node, but the zero point is -128 as opposed to 0. Also, I am unsure what it does, since input quantisation needs to be performed externally by the user... (it changes during inference!)
- 7: `sequential_6/flatten_3/Reshape`: it flattens the 28 x 28 image into a flat 784-element vector (no value)
- 8: `sequential_6/dense_6/MatMul;sequential_6/activation_3/Relu;sequential_6/dense_6/BiasAdd`: this performs BMM and ReLU (no value)
- 9: `StatefulPartitionedCall:01`: ??? (i8, initial value: 0)
- 10: `StatefulPartitionedCall:0`: holds the actual output (u8, initial value: 0)

In [None]:
# Snapshot of tensors 0-10 AFTER inference. Entries whose data cannot be read
# stay None.
final_tensors = [None] * 11

for i in range(11):
    try:
        final_tensors[i] = interpreter.get_tensor(i).copy()
    except Exception:
        # Best effort: some tensors have no readable value. Catching Exception
        # (rather than a bare except) leaves KeyboardInterrupt and SystemExit
        # alone.
        pass

In [None]:
# Weights ("mat") and biases of the two FC layers, copied out of the
# interpreter. Tensor indices follow the list in the "Model" section:
# 2/3 belong to the second layer, 4/5 to the first.
params = {
    "bias2": interpreter.get_tensor(2).copy(),
    "mat2": interpreter.get_tensor(3).copy(),
    "bias1": interpreter.get_tensor(4).copy(),
    "mat1": interpreter.get_tensor(5).copy(),
}

In [None]:
# Quantisation scales (first element of each "quantization" pair) for the
# first FC layer: input (tensor 0), weights (tensor 5), output (tensor 8).
s_i1 = interpreter.get_tensor_details()[0]["quantization"][0]
s_w1 = interpreter.get_tensor_details()[5]["quantization"][0]
s_o1 = interpreter.get_tensor_details()[8]["quantization"][0]
s_i1, s_w1, s_o1

In [None]:
# Same for the second FC layer: input (tensor 8, i.e. the first layer's
# output), weights (tensor 3), output (tensor 9).
s_i2 = interpreter.get_tensor_details()[8]["quantization"][0]
s_w2 = interpreter.get_tensor_details()[3]["quantization"][0]
s_o2 = interpreter.get_tensor_details()[9]["quantization"][0]
s_i2, s_w2, s_o2

In [None]:
input_tensor = test_image
flattened_input = input_tensor.flatten()
# Move from the uint8 input domain to int8 by subtracting 128 in int32.
# A direct uint8 -> int8 cast is NOT equivalent: it wraps values >= 128
# (e.g. 200 -> -56) instead of mapping them to 200 - 128 = 72.
precision_input = flattened_input.astype(np.int32)
quantised_input = precision_input - 128
finalised_input = quantised_input.astype(np.int8)

# fc_and_requantize unpacks every q_* argument as a (scale, zero_point) pair,
# so the full "quantization" tuples are passed rather than the bare scales.
# NOTE(review): tensor 6 (tfl.quantize) is taken as the int8 input of the first
# layer because it is described above as having the input's scale with zero
# point -128 — confirm the tensor indices against the model.
tensor_details = interpreter.get_tensor_details()
q_i1 = tensor_details[6]["quantization"]
q_w1 = tensor_details[5]["quantization"]
q_o1 = tensor_details[8]["quantization"]
q_w2 = tensor_details[3]["quantization"]
q_o2 = tensor_details[9]["quantization"]

fc1 = fc_and_requantize(finalised_input, params["mat1"], params["bias1"], q_i1, q_w1, q_o1)

# Applying ReLU to i8 input
relu = fc1.clip(0, 127)

# The second layer's input is the first layer's output, hence q_o1
fc2 = fc_and_requantize(relu, params["mat2"], params["bias2"], q_o1, q_w2, q_o2)

In [None]:
# Output of the manually computed second FC layer
fc2

In [None]:
# First FC layer recomputed by hand with a plain floating-point rescale
# (no zero-point shift of the input, no rounding).
v = (np.matmul(finalised_input.astype(np.int32), params["mat1"].astype(np.int32).transpose()) + params["bias1"]) * s_i1 * s_w1 / s_o1

In [None]:
# Saturate to the int8 range, cast to int8 (the cast drops the fractional
# part rather than rounding), then apply ReLU.
v2 = v.clip(-128, 127).astype(np.int8).clip(0, 127)

In [None]:
# Second FC layer on top of v2, again with a plain floating-point rescale
(np.matmul(v2.astype(np.int32), params["mat2"].astype(np.int32).transpose()) + params["bias2"]) * s_i2 * s_w2 / s_o2

In [None]:
# Show the interpreter's tensor 9 next to the manual result. As two separate
# expression statements only the last one would be displayed, so they are
# combined into a single tuple.
interpreter.get_tensor(9), fc2

In [None]:
# Scale and zero point of tensor 9
fs, fz = interpreter.get_tensor_details()[9]["quantization"]

In [None]:
# Display the scale / zero point read above
fs, fz

In [None]:
# Divide the manual result by tensor 9's scale and add its zero point
fc2/fs + fz

In [None]:
# Element-wise difference between tensor 9 and the output tensor 10
# (listed above as i8 and u8 respectively).
interpreter.get_tensor(9) - interpreter.get_tensor(10)

# Simple model

In [None]:
# Replace the interpreter with the simple model. experimental_preserve_all_tensors
# is enabled so that intermediate tensors can be inspected after invoke()
# (they are read back via get_tensor further down).
interpreter = tf.lite.Interpreter(os.path.join(MODEL_DIR, "simple_model_quant.tflite"), experimental_preserve_all_tensors=True)
interpreter.allocate_tensors()

In [None]:
# Metadata of the simple model's first input and first output tensor
# (rebinding the names used for the full model above).
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

print(input_details)
print(input_details["dtype"])

print(output_details)
print(output_details["dtype"])


In [None]:
# Index into x_test / y_test of the image used for the simple-model experiment
chosen_image = 150

In [None]:
test_image = x_test[chosen_image]

# Need to quantize the inputs outside the model!
input_scale, input_zero_point = input_details["quantization"]
# Quantise to uint8, then add the leading batch dimension
input_tensor = quantize_nearest(test_image, input_scale, input_zero_point, np.uint8)
input_tensor = np.expand_dims(input_tensor, axis=0)

# Run the model
interpreter.set_tensor(input_details["index"], input_tensor)
interpreter.invoke()
# [0] drops the batch dimension; the predicted class is the arg-max entry
tflite_output = interpreter.get_tensor(output_details["index"])[0]
output_prediction = tflite_output.argmax()

print(tflite_output)
# Prediction first, ground-truth label second
print("{} (correct: {})".format(output_prediction, y_test[chosen_image]))

In [None]:
# Metadata of every tensor in the simple model (includes the "name" and
# "quantization" entries used below).
interpreter.get_tensor_details()

In [None]:
# Compact listing: position in the details list and tensor name
for i, t in enumerate(interpreter.get_tensor_details()):
    print(i, ": ", t["name"], sep="")

In [None]:
# NOTE(review): _get_ops_details is a private (underscore-prefixed) method of
# the interpreter and may not be stable across TensorFlow versions.
interpreter._get_ops_details()

In [None]:
# Tensor indices of the simple model's FC layer.
# NOTE(review): these are hard-coded for this particular model — presumably
# read off the tensor listing above; re-check them if the model is regenerated.
input_idx = 5
bias_idx = 2
weight_idx = 3
output_idx = 6

# Weights, bias and the (scale, zero_point) pairs of input / weights / output
w = interpreter.get_tensor(weight_idx)
b = interpreter.get_tensor(bias_idx)
q_i = interpreter.get_tensor_details()[input_idx]["quantization"]
q_w = interpreter.get_tensor_details()[weight_idx]["quantization"]
q_o = interpreter.get_tensor_details()[output_idx]["quantization"]

In [None]:
# 1) flatten input (the target shape is stored in the model as tensor 1)
flattened_input = input_tensor.reshape(interpreter.get_tensor(1)) # [-1, 784]

# 2) shift input tensor by -128 to switch from input type (uint8) to TF Lite internal type (int8)
#    (done in int32 so that the subtraction cannot wrap)
finalised_input = flattened_input.astype(np.int32)
finalised_input = finalised_input - 128
finalised_input = finalised_input.astype(np.int8)

# 3) run fully-connected layer
fc1 = fc_and_requantize(finalised_input, w, b, q_i, q_w, q_o)

# 4) undo the shift to switch from TF Lite internal type (int8) to output type (uint8)
manual_output = fc1.astype(np.int32)
manual_output = manual_output + 128
manual_output = manual_output.astype(np.uint8)

In [None]:
# True iff the manual computation reproduces the TF Lite output exactly
(manual_output == tflite_output).all()

Slightly more optimised version of the simple model (TF Lite and manual) to meaningfully compare execution times

In [None]:
# I_S, I_Z = input_details["quantization"]
# RESHAPE = interpreter.get_tensor(1)

# W_32 = interpreter.get_tensor(3).transpose().astype(np.int32)
# B_32 = interpreter.get_tensor(2).astype(np.int32)
# (S_I, Z_I) = interpreter.get_tensor_details()[5]["quantization"]
# (S_W, Z_W) = interpreter.get_tensor_details()[3]["quantization"]
# (S_O, Z_O) = interpreter.get_tensor_details()[6]["quantization"]
# S = S_I * S_W / S_O

# def quantise_input(x):
#     x_q = quantize_nearest(x, I_S, I_Z, np.uint8)
#     return np.expand_dims(x_q, axis=0)

# def manual_model(x):
#     x = (x.reshape(RESHAPE).astype(np.int32) - 128).astype(np.int8)
#     x = x.astype(np.int32) - Z_I
#     x = np.matmul(x, W_32) + B_32
#     x = np.clip(np.rint(S * x) + Z_O, -128, 127)
#     x = (x + 128).astype(np.uint8)

#     return x

# quantised_test_x = [quantise_input(x) for x in x_test]

In [None]:
# Model constants, read from the interpreter once so that the timing
# comparison below does not include these look-ups.
I_S, I_Z = input_details["quantization"]
RESHAPE = interpreter.get_tensor(1)

# Weights are transposed and widened to int32 up front
W_32 = interpreter.get_tensor(3).transpose().astype(np.int32)
B_32 = interpreter.get_tensor(2).astype(np.int32)
(S_I, Z_I) = interpreter.get_tensor_details()[5]["quantization"]
(S_W, Z_W) = interpreter.get_tensor_details()[3]["quantization"]
(S_O, Z_O) = interpreter.get_tensor_details()[6]["quantization"]
# Floating-point rescaling factor used by the naive re-quantisation
S = S_I * S_W / S_O

def quantise_input(x):
    """Quantise one float image to uint8 and add a leading batch dimension."""
    x_q = quantize_nearest(x, I_S, I_Z, np.uint8)
    return np.expand_dims(x_q, axis=0)

def manual_model(x):
    """Manual FC layer with naive (floating-point, np.rint) re-quantisation.

    This is the model timed in the comparison below; it used to exist only
    in a commented-out cell, so the timing cell failed with a NameError.

    Args:
        x: quantised input as produced by `quantise_input`.

    Returns:
        The uint8 output of the layer.
    """
    # uint8 -> int8 domain, then remove the input zero point in int32
    x = (x.reshape(RESHAPE).astype(np.int32) - 128).astype(np.int8)
    x = x.astype(np.int32) - Z_I
    x = np.matmul(x, W_32) + B_32
    # naive re-quantisation followed by saturation to the int8 range
    x = np.clip(np.rint(S * x) + Z_O, -128, 127)
    # back to the uint8 output domain
    x = (x + 128).astype(np.uint8)

    return x

quantised_test_x = [quantise_input(x) for x in x_test]

In [None]:
# re-scaling computation

# Fun fact: changing to the following makes S exactly equal to S_UINT / (2**S_SHIFT)
# ROUNDING_PRECISION = 64
# APP_S_TYPE = np.uint64

ROUNDING_PRECISION = 32
APP_S_TYPE = np.int32 # as in gemmlowp's SaturatingRoundingDoublingHighMul

def approximate_rescaling_factor(s):
    """Approximate `s` by a fixed-point pair such that `s ~= int_s / 2**shift`.

    Args:
        s: rescaling factor, expected in (0, 1].

    Returns:
        `(shift, int_s)`, with `int_s` of type `APP_S_TYPE`.

    Raises:
        ValueError: if `s` is not strictly positive (0 used to slip through
            the check and fail inside `log2` with an unrelated message).
        Exception: if `s > 1`, which is not handled yet.
    """
    # non-positive scale, aside from theoretically never happening, would break our rounding assumption below
    if s <= 0:
        raise ValueError("s must be positive")
    if s > 1:
        raise Exception("Make sure s > 1 is handled correctly")

    c = ceil(log2(s) + 1)

    # Keep one extra bit below the target precision and use it to round
    scaled_s = floor(2**(ROUNDING_PRECISION - c) * s)
    rounding_bit = scaled_s & 1

    int_s = (scaled_s >> 1) + rounding_bit

    return (ROUNDING_PRECISION - 1 - c, APP_S_TYPE(int_s))

def round_float_half_away_from_zero(f):
    """Round to the nearest integer value, ties away from zero.

    Works on scalars and, element-wise, on arrays. The result stays
    floating-point.
    """
    f_abs = np.abs(f)
    f_abs_floor = np.floor(f_abs)
    # Boolean (0 / 1) added to the floored magnitude; written as a comparison
    # rather than a Python conditional so that arrays are supported too
    round_up = (f_abs - f_abs_floor) >= 0.5

    return np.sign(f) * (f_abs_floor + round_up)

def new_approximate_rescaling_factor(s1, s2, s3):
    """Quantise the multiplier `s1 * s2 / s3` as `q_signif * 2**(exp - 31)`.

    The product is computed in double precision and split into a significand
    in [0.5, 1) and a power-of-two exponent; the significand is then stored
    as a 32-bit integer scaled by 2**31.

    Args:
        s1, s2, s3: the three scales.

    Returns:
        `(exp, q_signif)`: `exp` is a Python int, `q_signif` an `np.int32`.
        `(0, 0)` (plain ints) if `s1` or `s2` is zero.

    Raises:
        Exception: if the multiplier is negative or greater than 1.
    """
    # TODO we are omitting some of the checks

    if s1 == 0 or s2 == 0:
        print("Warning: Rescaling multiplier equal to 0 found")
        return 0, 0

    s1, s2, s3 = np.float64(s1), np.float64(s2), np.float64(s3)

    s = s1 * s2 / s3

    # negative scale, aside from theoretically never happening, would break our rounding assumption below
    if s < 0:
        raise Exception("s must be positive")
    if s > 1:
        raise Exception("Make sure s > 1 is handled correctly")

    # assuming TFLITE_EMULATE_FLOAT = false, since our system can actually run floating-point arithmetic
    # frexp yields s == signif * 2**exp exactly, with signif in [0.5, 1). It
    # also covers s == 1 (exp == 1), for which the former `1 << -exp` raised
    # "negative shift count".
    signif, exp = np.frexp(s)
    exp = int(exp)

    q_signif = np.int64(round_float_half_away_from_zero(signif * (1 << 31)))

    # Happens when the significand rounds up to 1.0; integer division keeps
    # q_signif an integer (`/=` silently turned it into a float)
    if q_signif == (1 << 31):
        q_signif //= 2
        exp += 1

    if exp < -31:
        exp = 0
        q_signif = 0

    # I have no idea if our build has single rounding
    # #if TFLITE_SINGLE_ROUNDING
    #    // Single-rounding MultiplyByQuantizedMultiplier doesn't support a shift > 30,
    #    // saturate it.
    #    if (*shift > 30) {
    #    *shift = 30;
    #    q_fixed = (1LL << 31) - 1;
    #    }
    # #endif

    # np.int32(...) instead of .astype: q_signif may be a plain int here
    return exp, np.int32(q_signif)

# Quantised multiplier of the simple model's FC layer: S_REL_SHIFT is the
# power-of-two exponent, S_UINT the integer significand.
S_REL_SHIFT, S_UINT = new_approximate_rescaling_factor(S_I, S_W, S_O)

In [None]:
def round_nearest_half_up(n, shift):
    """Divide `n` by `2**shift`, rounding to nearest with ties towards +inf.

    Args:
        n: integer (or integer array) to scale down.
        shift: number of bits to shift right; must be >= 0.

    Returns:
        The rounded quotient; `n` itself when `shift` is 0.

    Raises:
        ValueError: if `shift` is negative.
    """
    if shift < 0:
        raise ValueError("shift must be non-negative")
    if shift == 0:
        # Nothing to round; `1 << (shift - 1)` would be a negative shift
        return n

    return (n + (1 << (shift - 1))) >> shift

In [None]:
# TODO one could wrap this in type checks for good measure (one per tensor, not per element)
ROUNDING = round_nearest_half_up

# def requantise_half_away_from_zero(x):
#     # TODO control overflows here?
#     abs_a_s_int = np.abs(x) * S_UINT
#     rounding_bit = (abs_a_s_int >> (S_SHIFT - 1)) & 1
#     sh = (abs_a_s_int >> S_SHIFT)

#     return np.sign(x) * (sh + rounding_bit)

# TODO there's probably a more elegant way to do this
def requantise(x):
    """Rescale `x` by the quantised multiplier with a single rounding step.

    The multiplier is `S_UINT * 2**(S_REL_SHIFT - 31)`, so the product with
    `S_UINT` is shifted right by `31 - S_REL_SHIFT` bits in total. (The name
    `S_SHIFT` used here previously is not defined anywhere in the notebook.)
    """
    # TODO control overflows here or in the ROUNDING function?
    # The product is formed in 64 bits: int32 * int32 does not fit in 32 bits
    return ROUNDING(np.int64(x) * np.int64(S_UINT), 31 - S_REL_SHIFT)

# requantise_tensor = np.vectorize(requantise)

In [None]:
# inline int32 MultiplyByQuantizedMultiplier( int32 x, 
#                                             int32 quantized_multiplier,
#                                             int shift) {
#   using gemmlowp::RoundingDivideByPOT;
#   using gemmlowp::SaturatingRoundingDoublingHighMul;


#   int left_shift = shift > 0 ? shift : 0;
#   int right_shift = shift > 0 ? 0 : -shift;
  
#   return RoundingDivideByPOT(
    
#             SaturatingRoundingDoublingHighMul(
#                                  x * (1 << left_shift), quantized_multiplier
#             ),

#         right_shift);
# }

# The shift arg above is to be understood as: to the left (by the two ternary assignments)

# In our case, the shift is always to the right, so:
#     - right_shift is set to the additive inverse of our shift
#     - left_shift is set to 0

#   return RoundingDivideByPOT(
    
#         SaturatingRoundingDoublingHighMul(
#                                 x * 1, quantized_multiplier
#         ),

#     right_shift);


# https://github.com/google/gemmlowp/blob/master/fixedpoint/fixedpoint.h#L302
# https://github.com/google/gemmlowp/blob/master/fixedpoint/fixedpoint.h#L340
# // This function implements the same computation as the ARMv7 NEON VQRDMULH
# // instruction.
# inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
#                                                       std::int32_t b) {
#   bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
#   std::int64_t a_64(a);
#   std::int64_t b_64(b);
#   std::int64_t ab_64 = a_64 * b_64;
#   std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
#   std::int32_t ab_x2_high32 =
#       static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
#   return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
# }

# inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
#   assert(exponent >= 0);
#   assert(exponent <= 31);
#   const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
#   const IntegerType zero = Dup<IntegerType>(0);
#   const IntegerType one = Dup<IntegerType>(1);
#   const IntegerType remainder = BitAnd(x, mask);
#   const IntegerType threshold =
#       Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
#   return Add(ShiftRight(x, exponent),
#              BitAnd(MaskIfGreaterThan(remainder, threshold), one));


# I think this is the line where they call the floating-point-multiplier computation:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/fully_connected.cc#L418
# This is the function that is actually called
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/kernel_util.cc#L329
# It does the same thing we are doing, with the small caveat that the product is computed in double precision
#
# Right after that, they call the quantisation function for that multiplier in this line:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/fully_connected.cc#L421
# I think this is the function that's called, although there are five defined functions with that same name
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/quantization_util.cc#L53



In [None]:
# 64-bit copy of the multiplier's significand, used for the products in the
# requantisation routines below.
S_INT_64 = np.int64(S_UINT)

def gemmlowp_requantize(x):
    """Rescale one accumulator value following gemmlowp's two-step scheme.

    Step 1 mirrors SaturatingRoundingDoublingHighMul (without the saturation
    case), step 2 mirrors RoundingDivideByPOT; both are quoted in the comment
    cell above. Only right shifts (S_REL_SHIFT <= 0) are supported.

    Args:
        x: a single integer accumulator value.

    Returns:
        The rescaled integer value.
    """
    # TODO overflows are disregarded

    # TODO is the cast necessary or induced from the type of S_UINT_64
    x_s_int_64 = np.int64(x) * S_INT_64
    # The C code adds a sign-dependent nudge and then divides with truncation
    # towards zero. `>>` floors instead, and combined with the negative nudge
    # that made every negative product come out one too low. Adding 2**30
    # unconditionally and flooring yields exactly the C result for both signs.
    nudged = ((x_s_int_64 + (1 << 30)) >> 31).astype(np.int32)

    # funny (worrying?): nudge can be more than 1 away from the actual float-computed product

    # S_REL_SHIFT is a (non-positive) exponent; the amount to shift right by
    # is its additive inverse
    right_shift = int(-S_REL_SHIFT)
    mask = (1 << right_shift) - 1
    remainder = nudged & mask
    threshold = (mask >> 1) + (1 if nudged < 0 else 0)

    return (nudged >> right_shift) + (1 if remainder > threshold else 0)

def arm_requantize(x):
    """Rescale one accumulator value: rounding doubling high multiply, then a
    round-half-up right shift by `-S_REL_SHIFT` bits.

    Only right shifts (S_REL_SHIFT <= 0) are supported.

    Args:
        x: a single integer accumulator value.

    Returns:
        The rescaled integer value.
    """
    # TODO overflows are disregarded

    # TODO is the cast necessary or induced from the type of S_UINT_64
    x_s_int_64 = np.int64(x) * S_INT_64
    # Round-half-up high multiply: add 2**30, then floor-shift by 31. The
    # former sign-dependent nudge belongs with a truncating division; combined
    # with the flooring `>>` it made every negative product one too low.
    nudged = ((x_s_int_64 + (1 << 30)) >> 31).astype(np.int32)

    # funny (worrying?): nudge can be more than 1 away from the actual float-computed product

    right_shift = int(-S_REL_SHIFT)
    # Widened so that adding the rounding constant cannot overflow 32 bits
    nudged_64 = np.int64(nudged)
    if right_shift == 0:
        # No second rounding step; `1 << -1` would raise
        return nudged_64

    return (nudged_64 + (1 << (right_shift - 1))) >> right_shift

# Element-wise wrapper so that the scalar routine can be applied to a tensor
requantise_tensor = np.vectorize(arm_requantize)

In [None]:
def manual_model_accurate(x):
    """Manual FC layer using the integer-only re-quantisation.

    Args:
        x: quantised input as produced by `quantise_input`.

    Returns:
        The uint8 output of the layer.
    """
    # Input domain -> int8 domain, then remove the input zero point in int32
    shifted = (x.reshape(RESHAPE).astype(np.int32) - 128).astype(np.int8)
    centred = shifted.astype(np.int32) - Z_I

    # int32 accumulator of the fully-connected layer
    accumulator = np.matmul(centred, W_32) + B_32

    # this is the correct, specification-exact way to do it; in the 10000 sample images, it always coincides with np.rint(x * S)
    rescaled = requantise_tensor(accumulator)

    # Add the output zero point, saturate to int8, then move to the uint8 domain
    saturated = np.clip(rescaled + Z_O, -128, 127)

    return (saturated + 128).astype(np.uint8)

In [None]:
%%timeit

# Baseline: TF Lite interpreter on every pre-quantised test image
for x in quantised_test_x:
    interpreter.set_tensor(input_details["index"], x)
    interpreter.invoke()
    interpreter.get_tensor(output_details["index"])[0]


In [None]:
%%timeit

# Manual model with naive re-quantisation on the same inputs
for x in quantised_test_x:
    manual_model(x)

In [None]:
%%timeit

# Manual model with integer-only re-quantisation on the same inputs
for x in quantised_test_x:
    manual_model_accurate(x)

### Simple results

- The cumulative execution time of the TF Lite model on the 10000 test images is ~18 ms (average over several runs)
- The cumulative execution time of the manual model with naive re-quantisation on the 10000 test images is ~116 ms (idem)
- The cumulative execution time of the manual model with specification-exact re-quantisation on the 10000 test images is ~240 ms (idem)

In [None]:
# Indices of the test images on which the manual model and TF Lite disagree
# in at least one output entry.
discrepancies = []

for (i, x) in enumerate(quantised_test_x):

    # TF Lite model
    interpreter.set_tensor(input_details["index"], x)
    interpreter.invoke()
    tflite_output = interpreter.get_tensor(output_details["index"])[0]

    # Manual model
    manual_output = manual_model_accurate(x)

    if not (tflite_output == manual_output).all():
        discrepancies.append(i)

In [None]:
# Number of test images with a mismatch
len(discrepancies)

In [None]:
# Re-run both models on the first mismatching image.
# NOTE(review): raises IndexError when `discrepancies` is empty.
ip = quantised_test_x[discrepancies[0]]

interpreter.set_tensor(input_details["index"], ip)
interpreter.invoke()
out_l = interpreter.get_tensor(output_details["index"])[0]

out_m = manual_model_accurate(ip)

In [None]:
def compare(v1, v2):
    """Print both values, then whether all of their elements are equal."""
    for value in (v1, v2):
        print(value)

    all_equal = (v1 == v2).all()
    print(all_equal)

In [None]:
# TF Lite output vs. manual output; [0] drops the manual result's leading axis
compare(out_l, out_m[0])

In [None]:
# The steps of manual_model_accurate, one variable per stage, so that each
# stage can be compared with the interpreter's intermediate tensors.
x1 = (ip.reshape(RESHAPE).astype(np.int32) - 128).astype(np.int8)
x2 = x1.astype(np.int32) - Z_I
x3 = np.matmul(x2, W_32) + B_32
x4 = requantise_tensor(x3).astype(np.int32)
x5 = np.clip(x4 + Z_O, -128, 127)
x6 = (x5 + 128).astype(np.uint8)

In [None]:
# FC input: the manually shifted input must equal the interpreter's tensor 5
(x1 == interpreter.get_tensor(5)).all()

In [None]:
# FC output: manual saturated result vs. the interpreter's tensor 6
compare(x5, interpreter.get_tensor(6))

# Back of the envelope