# End-to-End FINN Flow for MobileNet-V1
-------------------------------------------------------------


In [2]:
from PIL import Image
import numpy as np
import brevitas.onnx as bo
import torch

# Load one test image and shape it into the (1, 3, 224, 224) network input.
source_image = Image.open("/workspace/finn/tests/brevitas/king_charles.jpg")
resized_image = source_image.resize((224, 224))
# Widen to int32 so the mean subtraction below can go negative, then move the
# channel axis (last in the decoded image array) to the front.
hwc_pixels = np.array(resized_image, dtype=np.int32)
chw_pixels = np.transpose(hwc_pixels, (2, 0, 1))
# The network was trained on BGR rather than RGB images, so the order of the
# channel axis has to be reversed.
bgr_pixels = chw_pixels[::-1, :, :].copy()
# Subtract the per-channel mean pixel intensity the network was trained with
# (one value per channel, broadcast over height and width).
channel_means = np.array([104, 117, 123], dtype=np.int32).reshape(3, 1, 1)
img = (bgr_pixels - channel_means).reshape(1, 3, 224, 224)
input_tensor = torch.from_numpy(img).float()
assert input_tensor.shape == (1, 3, 224, 224)

In [3]:
from finn.util.test import get_test_model_trained
# Fetch the trained "mobilenet" test model from finn.util.test.
# NOTE(review): the two 4s are presumably the weight and activation bit widths
# (the exported file is named "..._4b") — confirm against get_test_model_trained.
mobilenet = get_test_model_trained("mobilenet", 4, 4)

In [4]:
# Golden reference: run the forward pass in PyTorch/Brevitas and keep the five
# highest-scoring class indices (best first) together with their raw scores.
expected = mobilenet.forward(input_tensor).detach().numpy()
expected_topk = expected.flatten()
# argsort is ascending, so reverse it and take the leading five entries
expected_top5 = np.argsort(expected_topk)[::-1][:5]
expected_top5_prob = [expected_topk[index] for index in expected_top5]

In [5]:
from finn.core.modelwrapper import ModelWrapper

# Export the Brevitas model through export_finn_onnx, then wrap the file that
# was just written in a ModelWrapper for the transformations that follow.
exported_onnx_path = "quant_mobilenet_v1_4b.onnx"
bo.export_finn_onnx(
    mobilenet,
    (1, 3, 224, 224),
    exported_onnx_path,
    input_t=input_tensor,
)
model = ModelWrapper(exported_onnx_path)

In [8]:
from finn.util.visualization import showInNetron
# Serve the freshly exported ONNX file for inspection in Netron.
showInNetron("quant_mobilenet_v1_4b.onnx")

Serving 'quant_mobilenet_v1_4b.onnx' at http://0.0.0.0:8081


In [6]:
# tidy-up transformations
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.insert_topk import InsertTopK

# Tidy-up passes, applied strictly in this order; each pass object is created
# right before it runs. InferShapes is listed twice: once at the start and once
# more after InsertTopK has been applied.
tidy_up_passes = (
    InferShapes,
    FoldConstants,
    InsertTopK,
    InferShapes,
    InferDataTypes,
    GiveUniqueNodeNames,
    GiveReadableTensorNames,
)
for make_pass in tidy_up_passes:
    model = model.transform(make_pass())
model.save("quant_mobilenet_v1_4b_before.onnx")

In [9]:
from finn.transformation.general import ConvertDivToMul
from finn.transformation.batchnorm_to_affine import BatchNormToAffine
from finn.transformation.streamline.absorb import AbsorbAddIntoMultiThreshold, AbsorbMulIntoMultiThreshold, FactorOutMulSignMagnitude
from finn.transformation.streamline.collapse_repeated import CollapseRepeatedMul
from finn.transformation.streamline.reorder import MoveScalarMulPastConv, MoveAddPastMul, MoveScalarMulPastMatMul
from finn.transformation.double_to_single_float import DoubleToSingleFloat

# Start from the tidied-up model written to disk by the tidy-up cell.
model = ModelWrapper("quant_mobilenet_v1_4b_before.onnx")

# Streamlining passes in the order they are applied. CollapseRepeatedMul is
# listed three times on purpose: it is re-applied after other passes have run.
streamlining_passes = (
    ConvertDivToMul,
    BatchNormToAffine,
    AbsorbAddIntoMultiThreshold,
    CollapseRepeatedMul,
    MoveScalarMulPastConv,
    CollapseRepeatedMul,
    FactorOutMulSignMagnitude,
    AbsorbMulIntoMultiThreshold,
    MoveScalarMulPastMatMul,
    MoveAddPastMul,
    CollapseRepeatedMul,
    DoubleToSingleFloat,
)
# Run after every streamlining pass: rename nodes, rename tensors, and redo
# datatype inference.
cleanup_passes = (GiveUniqueNodeNames, GiveReadableTensorNames, InferDataTypes)

for streamlining_pass in streamlining_passes:
    model = model.transform(streamlining_pass())
    for cleanup_pass in cleanup_passes:
        model = model.transform(cleanup_pass())

streamlined_onnx_path = "quant_mobilenet_v1_4b_streamlined.onnx"
model.save(streamlined_onnx_path)
showInNetron(streamlined_onnx_path)


Stopping http://0.0.0.0:8081
Serving 'quant_mobilenet_v1_4b_streamlined.onnx' at http://0.0.0.0:8081


In [10]:
from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
from finn.transformation.streamline.absorb import AbsorbTransposeIntoMultiThreshold

# Start from the streamlined model on disk.
model = ModelWrapper("quant_mobilenet_v1_4b_streamlined.onnx")

# Lower the convolutions to matmuls, then absorb transposes into the
# MultiThreshold nodes; after each pass rename nodes and tensors and redo
# datatype inference.
lowering_passes = (LowerConvsToMatMul, AbsorbTransposeIntoMultiThreshold)
cleanup_passes = (GiveUniqueNodeNames, GiveReadableTensorNames, InferDataTypes)

for lowering_pass in lowering_passes:
    model = model.transform(lowering_pass())
    for cleanup_pass in cleanup_passes:
        model = model.transform(cleanup_pass())

lowered_onnx_path = "quant_mobilenet_v1_4b.onnx"
model.save(lowered_onnx_path)
showInNetron(lowered_onnx_path)


Stopping http://0.0.0.0:8081
Serving 'quant_mobilenet_v1_4b.onnx' at http://0.0.0.0:8081


In [15]:
# Reload the model from quant_mobilenet_v1_4b.onnx.
# The commented-out lines are a disabled experiment: an
# AbsorbSucceedMulIntoMultiThreshold pass followed by
# AbsorbTransposeIntoMultiThreshold, each with the usual renaming and datatype
# inference, saved to experiment.onnx. As written, this cell only performs the
# reload.
#from finn.transformation.streamline.absorb import AbsorbSucceedMulIntoMultiThreshold
model = ModelWrapper("quant_mobilenet_v1_4b.onnx")
#model = model.transform(AbsorbSucceedMulIntoMultiThreshold())
#model = model.transform(GiveUniqueNodeNames())
#model = model.transform(GiveReadableTensorNames())
#model = model.transform(InferDataTypes())

#model = model.transform(AbsorbTransposeIntoMultiThreshold())
#model = model.transform(GiveUniqueNodeNames())
#model = model.transform(GiveReadableTensorNames())
#model = model.transform(InferDataTypes())
#model.save("experiment.onnx")
#showInNetron("experiment.onnx")

In [16]:
import finn.core.onnx_exec as oxe

# Execute the transformed ONNX graph on the preprocessed image (cast to float32).
graph_input_name = model.graph.input[0].name
idict = {graph_input_name: img.astype(np.float32)}
# True is passed positionally as the third argument — presumably requesting the
# full execution context, since an intermediate tensor is read from the result
# below; confirm against execute_onnx.
odict = oxe.execute_onnx(model, idict, True)
graph_output_name = model.graph.output[0].name
produced = odict[graph_output_name]
produced_prob = odict["TopK_0_out0"]

In [17]:
# Show the golden PyTorch/Brevitas top-5 class indices next to the top-5
# indices produced by executing the transformed ONNX graph.
print(expected_top5)
print(produced)
# Fail loudly if the two rankings diverge instead of relying on a visual
# comparison of the printed arrays; `produced` carries a leading batch axis,
# hence the flatten.
assert np.array_equal(np.asarray(produced).flatten(), expected_top5), (
    "top-5 mismatch between Brevitas golden output and ONNX execution"
)

[219 220 213 365 156]
[[219 220 213 365 156]]


In [8]:
# Serve quant_mobilenet_v1_4b.onnx for inspection in Netron.
# NOTE(review): the output recorded for this cell is a NameError for
# showInNetron, so it was presumably run in a kernel where the
# `from finn.util.visualization import showInNetron` cell had not been executed
# yet — re-run it after that import.
showInNetron("quant_mobilenet_v1_4b.onnx")

NameError: name 'showInNetron' is not defined

In [10]:
from finn.transformation.streamline import Streamline
# Apply the Streamline transformation to the current model and save the result
# to quant_mobilenet_v1_4b.onnx.
streamline_pass = Streamline()
model = model.transform(streamline_pass)
model.save("quant_mobilenet_v1_4b.onnx")

In [11]:
# Serve the saved quant_mobilenet_v1_4b.onnx for inspection in Netron.
showInNetron("quant_mobilenet_v1_4b.onnx")


Stopping http://0.0.0.0:8081
Serving 'quant_mobilenet_v1_4b.onnx' at http://0.0.0.0:8081


In [13]:
from finn.transformation.double_to_single_float import DoubleToSingleFloat
from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul


# Convert doubles to single floats, then lower convolutions to matmuls, in
# that order; save the result and open it in Netron.
for make_pass in (DoubleToSingleFloat, LowerConvsToMatMul):
    model = model.transform(make_pass())
lowered_onnx_path = "quant_mobilenet_v1_4b.onnx"
model.save(lowered_onnx_path)
showInNetron(lowered_onnx_path)


Stopping http://0.0.0.0:8081
Serving 'quant_mobilenet_v1_4b.onnx' at http://0.0.0.0:8081


In [15]:
import finn.transformation.streamline.absorb as absorb
# Absorb transposes into the MultiThreshold nodes, then save and view the result.
transpose_absorber = absorb.AbsorbTransposeIntoMultiThreshold()
model = model.transform(transpose_absorber)
absorbed_onnx_path = "quant_mobilenet_v1_4b.onnx"
model.save(absorbed_onnx_path)
showInNetron(absorbed_onnx_path)


Stopping http://0.0.0.0:8081
Serving 'quant_mobilenet_v1_4b.onnx' at http://0.0.0.0:8081
