In [1]:
import os
from math import log2, ceil, floor
from collections import OrderedDict
import json

import tensorflow as tf
import numpy as np

# Make NumPy raise on every floating-point error category instead of
# warning or ignoring it. The call returns the previous settings, which
# the notebook displays as this cell's output.
np.seterr(all='raise')

{'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}

In [None]:
# Paths are resolved relative to the directory the kernel was started in;
# the .tflite model loaded later is looked up inside "models".
HOME_DIR = os.getcwd()
MODEL_DIR = os.path.join(HOME_DIR, "models")

In [None]:
def tensor_to_dict(t):
    """Serialize an array into an ordered mapping of plain Python lists.

    The result has three keys, in this order:
      * "f": the elements of ``t`` flattened into a list,
      * "s": the shape of ``t`` as a list,
      * "c": cumulative products of the trailing dimensions, ending in 1
             (for shape [2, 3, 4] this is [12, 4, 1]).
    """
    shape = list(t.shape)

    # Build the products right-to-left: the last entry is always 1 and each
    # earlier entry multiplies in the dimension that follows it. The first
    # dimension never contributes.
    products = [1]
    for dim in reversed(shape[1:]):
        products.insert(0, products[0] * dim)

    result = OrderedDict()
    result["f"] = t.flatten().tolist()
    result["s"] = shape
    result["c"] = products
    return result

In [None]:
def quantize_nearest(x, scale, zero, qtype):
    """Quantize ``x`` to an 8-bit integer type with round-to-nearest.

    Computes ``rint(x / scale) + zero`` and saturates the result to the
    representable range of the target type before casting.

    Parameters
    ----------
    x : array-like or scalar
        Real-valued data to quantize.
    scale : float
        Quantization scale (real value of one integer step).
    zero : int
        Quantization zero point, added after rounding.
    qtype : dtype-like
        Target type. Anything ``np.dtype`` resolves to int8 or uint8 is
        accepted (``np.int8``, ``np.dtype("uint8")``, ``"int8"``, ...).

    Returns
    -------
    The quantized values, cast to the requested 8-bit type.

    Raises
    ------
    ValueError
        If ``qtype`` does not denote int8 or uint8.
    """
    unsupported = "Only quantization to int8 or uint8 is supported"

    # Normalize first so that equivalent spellings of the same dtype
    # (scalar type, dtype instance, string name) are all recognized.
    try:
        target = np.dtype(qtype)
    except TypeError as err:
        raise ValueError(unsupported) from err

    if target not in (np.dtype(np.int8), np.dtype(np.uint8)):
        raise ValueError(unsupported)

    # Take the saturation bounds from the dtype itself rather than
    # hard-coding them (and avoid shadowing the built-ins min/max).
    bounds = np.iinfo(target)

    return np.clip(np.rint(x / scale) + zero, bounds.min, bounds.max).astype(target)

def fc_and_requantize(input_tensor, weights, bias, q_i, q_w, q_o):
    """Integer fully-connected layer followed by requantization to int8.

    The accumulation ``(input_tensor - z_i) @ weights.T + bias`` is done in
    int32, then rescaled by ``s_i * s_w / s_o``, rounded to nearest, offset
    by ``z_o`` and saturated to the int8 range.

    Parameters
    ----------
    input_tensor : np.ndarray of int8
        Quantized activations.
    weights : np.ndarray of int8
        Quantized weights; transposed before the product, so for a 2-D
        matrix the layout is (outputs, inputs).
    bias : np.ndarray of int32
        Bias added to the int32 accumulator.
    q_i, q_w, q_o : (scale, zero_point)
        Quantization parameters of the input, the weights and the output.
        The weight zero point must be 0.

    Returns
    -------
    np.ndarray of int8
        The requantized layer output.

    Raises
    ------
    TypeError
        If ``input_tensor``/``weights`` are not int8 or ``bias`` is not int32.
    ValueError
        If the weight zero point is not 0.
    """
    if input_tensor.dtype != np.int8:
        raise TypeError("Input must be of type int8")

    if weights.dtype != np.int8:
        raise TypeError("Weights must be of type int8")

    if bias.dtype != np.int32:
        # The original message wrongly blamed the input and the weights.
        raise TypeError("Bias must be of type int32")

    (s_i, z_i), (s_w, z_w), (s_o, z_o) = q_i, q_w, q_o

    if z_w != 0:
        raise ValueError("Expected zero point of weights to be 0")

    # Combined multiplier taking the int32 accumulator to the output scale.
    s = s_i * s_w / s_o

    # 1) widen to int32 and remove the input zero point
    input_tensor_32 = input_tensor.astype(np.int32) - z_i
    weights_32 = weights.astype(np.int32)

    # 2) integer matrix product plus bias
    bmm = np.matmul(input_tensor_32, weights_32.transpose()) + bias

    # 3) requantize: rescale, round to nearest, add the output zero point
    rq = np.rint(s * bmm) + z_o

    # 4) saturating cast to int8
    return np.clip(rq, -128, 127).astype(np.int8)

In [2]:
mnist = tf.keras.datasets.mnist

# Load the train/test splits, then convert the images to float32 and
# divide by 255.0. Labels are left as loaded.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step


In [None]:
# Load the quantized model from MODEL_DIR.
# NOTE(review): experimental_preserve_all_tensors presumably keeps
# intermediate tensors readable after invoke() — confirm against TF docs.
interpreter = tf.lite.Interpreter(
    os.path.join(MODEL_DIR, "two_layer_perceptron_frozen.tflite"),
    experimental_preserve_all_tensors=True,
)
interpreter.allocate_tensors()

In [None]:
# Metadata of the model's first input and first output tensor. Later cells
# read the "quantization" and "index" entries of these dicts.
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# Show the full metadata and, separately, the element dtype of each.
print(input_details)
print(input_details["dtype"])

print(output_details)
print(output_details["dtype"])

In [None]:
# Display the metadata of every tensor in the model (the cell's last
# expression is rendered as its output).
interpreter.get_tensor_details()

In [None]:
# Compact listing: position in the tensor-details list and tensor name.
for i, t in enumerate(interpreter.get_tensor_details()):
    print(f"{i}: {t['name']}")

In [None]:
# Display the interpreter's per-operator details.
# NOTE(review): _get_ops_details is underscore-prefixed, i.e. not part of
# the public interpreter API — it may change between TensorFlow versions.
interpreter._get_ops_details()

In [None]:
# Tensor indices of the first fully-connected layer.
# NOTE(review): presumably read off the tensor/operator listings shown in
# the preceding cells — re-check them if the model file is regenerated.
input_1_idx, weight_1_idx, bias_1_idx, output_1_idx = 7, 5, 4, 8

# Weight and bias values of layer 1.
w_1, b_1 = (interpreter.get_tensor(idx) for idx in (weight_1_idx, bias_1_idx))

# Quantization entries of layer 1's input, weights and output.
q_1_i, q_1_w, q_1_o = (
    interpreter.get_tensor_details()[idx]["quantization"]
    for idx in (input_1_idx, weight_1_idx, output_1_idx)
)

In [None]:
# Tensor indices of the second fully-connected layer; its input index is
# the same as the first layer's output index.
# NOTE(review): presumably read off the tensor/operator listings shown in
# the preceding cells — re-check them if the model file is regenerated.
input_2_idx, weight_2_idx, bias_2_idx, output_2_idx = 8, 3, 2, 9

# Weight and bias values of layer 2.
w_2, b_2 = (interpreter.get_tensor(idx) for idx in (weight_2_idx, bias_2_idx))

# Quantization entries of layer 2's input, weights and output.
q_2_i, q_2_w, q_2_o = (
    interpreter.get_tensor_details()[idx]["quantization"]
    for idx in (input_2_idx, weight_2_idx, output_2_idx)
)

In [None]:
# Second element of layer 1's output quantization pair, i.e. its zero
# point; manual_model uses it as the lower bound of the ReLU.
relu_zero_point = q_1_o[1]

In [None]:
def manual_model(input_tensor):
    """Run the two-layer perceptron by hand on an already-quantized input.

    Mirrors the TF Lite graph using the module-level weights, biases and
    quantization parameters: shift to int8, flatten, fully-connected,
    ReLU, fully-connected, shift back to uint8.

    ``input_tensor`` is expected to be the uint8-quantized input batch;
    the returned array is uint8.
    """
    # 1) shift by -128 to switch from the input type (uint8) to the
    #    TF Lite internal type (int8); go through int32 so the
    #    subtraction cannot wrap
    centered = (input_tensor.astype(np.int32) - 128).astype(np.int8)

    # 2) flatten, using the target shape stored in tensor 1 of the model
    flat_shape = interpreter.get_tensor(1)  # [-1, 784]
    flattened = centered.reshape(flat_shape)

    # 3) first fully-connected layer
    hidden = fc_and_requantize(flattened, w_1, b_1, q_1_i, q_1_w, q_1_o)

    # 4) relu: in the quantized domain, real zero is the zero point
    activated = np.maximum(hidden, relu_zero_point)

    # 5) second fully-connected layer
    fc2_out = fc_and_requantize(activated, w_2, b_2, q_2_i, q_2_w, q_2_o)

    # 6) undo the shift to switch from the TF Lite internal type (int8)
    #    back to the output type (uint8)
    return (fc2_out.astype(np.int32) + 128).astype(np.uint8)

## Execution

In [None]:
def compare_models(input_tensor, verbose):
    """Run one input through TF Lite and the manual model and report parity.

    ``input_tensor`` is a single unquantized sample. It is quantized to
    uint8 with the model's input quantization parameters and given a
    leading batch axis before being fed to both implementations.

    Prints "Models match" when every output element agrees, otherwise
    "Mismatch!". With ``verbose`` set, the manual model's output is printed
    first. Returns nothing.
    """
    # Input quantization is not part of the model and is done here.
    scale, zero_point = input_details["quantization"]
    quantized = quantize_nearest(input_tensor, scale, zero_point, np.uint8)
    batch = np.expand_dims(quantized, axis=0)

    # Reference result from the TF Lite interpreter
    interpreter.set_tensor(input_details["index"], batch)
    interpreter.invoke()
    tflite_output = interpreter.get_tensor(output_details["index"])[0]

    # Result from the hand-written NumPy model
    manual_output = manual_model(batch)[0]

    if verbose:
        print("Manual model output:\t{}".format(manual_output))

    if (manual_output == tflite_output).all():
        print("Models match")
    else:
        print("Mismatch!")

In [None]:
# Index into x_test of the single image used for a verbose spot check.
CHOSEN_IMAGE = 150

# Prints the manual model's output followed by the match/mismatch verdict.
compare_models(x_test[CHOSEN_IMAGE], True)

In [None]:
def run_manual_model(input_tensor):
    """Quantize a single unquantized sample and run it through manual_model.

    The sample is quantized to uint8 with the model's input quantization
    parameters and given a leading batch axis; the output for that single
    batch entry is returned.
    """
    # Input quantization is not part of the model and is done here.
    scale, zero_point = input_details["quantization"]
    batch = np.expand_dims(
        quantize_nearest(input_tensor, scale, zero_point, np.uint8),
        axis=0,
    )

    return manual_model(batch)[0]
    

In [None]:
# Fixed set of x_test indices used for the batch comparison and the JSON
# export below. The list appears to have been drawn once with the snippet
# below and then hard-coded so that runs are repeatable:
# from random import randrange
# [randrange(0, len(x_test)) for _ in range(10)]

INDICES = [6393, 1894, 5978, 6120, 817, 3843, 7626, 9272, 498, 4622]

In [None]:
# Compare the manual model against TF Lite on every selected test image;
# one "Models match" / "Mismatch!" line is printed per image.
for i in INDICES:
    compare_models(x_test[i], False)

In [None]:
# For each selected image, serialize the unquantized input and the manual
# model's output, then split the (input, output) pairs into two tuples.
inputs, outputs = zip(*(
    (tensor_to_dict(x_test[i]), tensor_to_dict(run_manual_model(x_test[i])))
    for i in INDICES
))

In [None]:
# Write the serialized inputs and manual-model outputs to disk. The files
# are opened in context managers so they are flushed and closed as soon as
# each dump finishes, instead of whenever the handle is garbage-collected.
with open("10_test_inputs.json", "w") as inputs_file:
    json.dump(inputs, inputs_file)

with open("10_test_outputs.json", "w") as outputs_file:
    json.dump(outputs, outputs_file)