# FINN - ModelWrapper and Analysis passes
--------------------------------------
<font size="3"> This notebook is about the ModelWrapper class and analysis passes within FINN. 

The following showSrc function is used to print the source code of functions in the Jupyter notebook:</font>

In [2]:
import inspect

def showSrc(what):
    print("".join(inspect.getsourcelines(what)[0]))

## ModelWrapper
-------------------------
* <font size="3"> wrapper around ONNX ModelProto that exposes some utility
    functions for graph manipulation and exploration </font>
* <font size="3"> a ModelWrapper instance takes an ONNX model proto and a `make_deepcopy` flag as input </font>
* <font size="3"> the ONNX model proto can be either a string with the path to a stored .onnx file on disk, or serialized bytes </font>
* <font size="3"> `make_deepcopy` is False by default but can be set to True if a deep copy of the model should be created </font>
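To make these input options concrete, here is a simplified sketch of how a wrapper constructor could dispatch on the type of its argument. It is not FINN's actual code: `SimpleWrapper` and the injected `load_file`/`load_bytes` callables are hypothetical stand-ins for the ONNX loading functions, so that the sketch runs with the standard library only. It also assumes, as the `make_deepcopy` flag suggests, that an already-loaded model object can be passed in directly.

```python
import copy


class SimpleWrapper:
    """Hypothetical, simplified stand-in for a ModelWrapper-like class."""

    def __init__(self, model_proto, make_deepcopy=False,
                 load_file=None, load_bytes=None):
        if isinstance(model_proto, str):
            # a string is interpreted as the path to an .onnx file on disk
            self._model = load_file(model_proto)
        elif isinstance(model_proto, bytes):
            # bytes are interpreted as a serialized model
            self._model = load_bytes(model_proto)
        elif make_deepcopy:
            # an in-memory model is copied, so later edits do not affect the caller
            self._model = copy.deepcopy(model_proto)
        else:
            # otherwise the wrapper shares the caller's object
            self._model = model_proto

    @property
    def model(self):
        return self._model


original = {"graph": {"node": []}}
shared = SimpleWrapper(original)
copied = SimpleWrapper(original, make_deepcopy=True)
print(shared.model is original)  # True
print(copied.model is original)  # False
```

Without a deep copy, any graph manipulation done through the wrapper is also visible through the original object, which is why the flag matters when the same model is reused elsewhere.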

### Create a ModelWrapper instance

In [3]:
from finn.core.modelwrapper import ModelWrapper
onnx_model = ModelWrapper("LFCW1A1.onnx")

### Access the attributes of the model
<font size="3"> ModelWrapper allows easy access to the various components of the model </font>

In [4]:
# the ONNX model proto
model = onnx_model.model

# the graph
graph = onnx_model.graph

# the node list
nodes = onnx_model.graph.node

#### Tensors
<font size="3"> Every input and output of every node in the ONNX model is represented as a tensor with several properties (e.g. name, shape, data type). ModelWrapper provides some utility functions to work with these tensors. </font>

##### Tensor names

In [5]:
# get all tensor names
tensor_list = onnx_model.get_all_tensor_names()
print(tensor_list)

['0', 'features.3.weight', 'features.3.bias', 'features.3.running_mean', 'features.3.running_var', 'features.7.weight', 'features.7.bias', 'features.7.running_mean', 'features.7.running_var', 'features.11.weight', 'features.11.bias', 'features.11.running_mean', 'features.11.running_var', '20', '23', '28', '30', '33', '34', '41', '42', '49', '50', '57', '58', '60']


##### Producer and consumer of a tensor

In [6]:
# find the node producing the last tensor in the list and the node consuming the first one (both calls return a node)

tensor_name = tensor_list[25]
print(onnx_model.find_producer(tensor_name))

tensor_name = tensor_list[0]
print(onnx_model.find_consumer(tensor_name))


input: "59"
input: "58"
output: "60"
op_type: "Mul"

input: "0"
output: "21"
op_type: "Shape"
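
Both lookups can be understood as a linear scan over the node list. The following standard-library sketch uses a hypothetical `Node` type with only `op_type`, `input` and `output` fields, built from the two nodes printed above; it illustrates the idea and is not necessarily FINN's exact implementation. In this sketch only the first match is returned, so for a tensor read by several nodes only one consumer is reported.

```python
from collections import namedtuple

# hypothetical stand-in for an ONNX NodeProto: only the fields used here
Node = namedtuple("Node", ["op_type", "input", "output"])


def find_producer(nodes, tensor_name):
    """Return the first node that writes tensor_name, or None."""
    for node in nodes:
        if tensor_name in node.output:
            return node
    return None


def find_consumer(nodes, tensor_name):
    """Return the first node that reads tensor_name, or None."""
    for node in nodes:
        if tensor_name in node.input:
            return node
    return None


# tiny graph made of the two nodes printed above
nodes = [
    Node("Shape", ["0"], ["21"]),
    Node("Mul", ["59", "58"], ["60"]),
]
print(find_producer(nodes, "60").op_type)  # Mul
print(find_consumer(nodes, "0").op_type)   # Shape
print(find_producer(nodes, "0"))           # None: "0" is a graph input
```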



##### Tensor shape

In [7]:
# get the shape of the tensor (tensor_name still refers to tensor_list[0])

print(onnx_model.get_tensor_shape(tensor_name))

[1, 1, 28, 28]


<font size="3"> It is also possible to set the tensor shape with a helper function, using the following syntax:

`onnx_model.set_tensor_shape(tensor_name, tensor_shape)`

Optionally, the dtype of the tensor can be specified as a third argument. By default, it is set to TensorProto.FLOAT.

**Important:** this dtype should not be confused with the FINN data type, which specifies the quantization annotation.
</font>

##### Tensor (FINN) data type

<font size="3">FINN introduces its own data types because ONNX does not natively support precisions below 8 bits. Since FINN targets quantized neural networks, precisions such as 4, 3, 2 or 1 bit are of interest. To represent such data within FINN, float tensors are used together with an additional annotation that specifies the quantized data type of the tensor. The following helper functions work with this quantization annotation. </font>
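To illustrate what such an annotation means, the sketch below checks whether a float value is a valid element of a given integer type. The helpers `int_range`, `fits_int_type` and `fits_bipolar` are hypothetical and not FINN's DataType API; they assume the usual two's-complement range for signed types and, for `fits_bipolar`, the {-1, +1} encoding commonly used for 1-bit values in binarized networks.

```python
def int_range(bits, signed):
    """Return (min, max) of an integer type with the given bit width."""
    if signed:
        # two's-complement range, e.g. 4 bits -> [-8, 7]
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    # unsigned range, e.g. 2 bits -> [0, 3]
    return 0, 2 ** bits - 1


def fits_int_type(value, bits, signed):
    """True if the float value is an integer inside the type's range."""
    lo, hi = int_range(bits, signed)
    return float(value).is_integer() and lo <= value <= hi


def fits_bipolar(value):
    """True for -1 and +1, the two values of an (assumed) bipolar 1-bit type."""
    return value in (-1.0, 1.0)


# the data stays in ordinary floats; only the check uses the annotation
print(fits_int_type(3.0, bits=2, signed=False))  # True
print(fits_int_type(4.0, bits=2, signed=False))  # False
print(fits_int_type(-2.0, bits=2, signed=True))  # True
print(fits_bipolar(0.0))                          # False
```

With this view, a 2-bit unsigned annotation simply restricts a float tensor to the values 0.0, 1.0, 2.0 and 3.0.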

In [8]:
# get tensor data type (FINN data type)
print(onnx_model.get_tensor_datatype(tensor_name))

DataType.FLOAT32


<font size="3">In addition to the `get_tensor_datatype()` function, the (FINN) data type of a tensor can be set using the `set_tensor_datatype(tensor_name, datatype)` function. </font>
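Conceptually, the FINN data type is metadata kept next to the tensor rather than a change to the stored values. The class below is a hypothetical, dictionary-based sketch of such a get/set pair; `DataTypeAnnotations`, `get_datatype`, `set_datatype` and the tensor name `w0` are made up for illustration, and data types are plain strings here instead of FINN's `DataType` enum. Falling back to FLOAT32 for unannotated tensors matches the result printed above for the input tensor, assuming that tensor carries no annotation.

```python
class DataTypeAnnotations:
    """Hypothetical store of per-tensor FINN data type annotations (sketch only)."""

    DEFAULT = "FLOAT32"

    def __init__(self):
        # tensor name -> data type name; the tensor values live elsewhere
        self._annotations = {}

    def get_datatype(self, tensor_name):
        # tensors without an annotation are reported as 32-bit float
        return self._annotations.get(tensor_name, self.DEFAULT)

    def set_datatype(self, tensor_name, datatype):
        # only the annotation changes; the float data itself is not touched
        self._annotations[tensor_name] = datatype


annotations = DataTypeAnnotations()
print(annotations.get_datatype("w0"))  # FLOAT32
annotations.set_datatype("w0", "BIPOLAR")
print(annotations.get_datatype("w0"))  # BIPOLAR
```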

##### Tensor initializers
<font size="3">Some tensors have initializers, for example tensors that represent constants or the weights determined during training.

ModelWrapper contains two helper functions for this case: one to get the current initializer of a tensor and one to set it.</font>

In [11]:
# get tensor initializer
tensor_name = tensor_list[1]
print(onnx_model.get_initializer(tensor_name))

[0.10029524 0.0410215  0.09845579 ... 0.24390122 0.17647634 0.23984103]


<font size="3">As with the other tensor helper functions, there is also a corresponding `set_initializer(tensor_name, tensor_value)` function.</font>
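Because initializers hold constant values, checking for one is a simple way to separate parameters such as trained weights from tensors computed at run time. The function below is a hypothetical standard-library sketch that models the initializers as a plain dictionary; the tensor names `w0` and `b0` and their values are placeholders. With a ModelWrapper, the same test could be written by checking whether `get_initializer` returns a value for a tensor, assuming it returns `None` when no initializer exists.

```python
def split_by_initializer(tensor_names, initializers):
    """Split tensor names into (constants, runtime tensors).

    initializers is assumed to be a dict mapping tensor name -> values.
    """
    constants = [name for name in tensor_names if name in initializers]
    runtime = [name for name in tensor_names if name not in initializers]
    return constants, runtime


# hypothetical tensor names and placeholder values
names = ["0", "w0", "b0", "20"]
inits = {"w0": [0.5, -0.5], "b0": [0.0]}
print(split_by_initializer(names, inits))  # (['w0', 'b0'], ['0', '20'])
```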

<font size="3">ModelWrapper contains more useful functions; if you are interested, please have a look at the Python code directly. In the following, analysis passes are discussed in more detail.

A Jupyter notebook about transformation passes can be found in the folder notebooks/.</font>

## Analysis passes
-------------------------
* <font size="3">traverses the graph structure and produces information about certain properties</font>
* <font size="3">input: ModelWrapper</font>
* <font size="3">returns a dictionary of named properties that the analysis extracts</font>
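The contract in the list above (a model goes in, a dictionary comes out) can be expressed as a small helper. `run_analysis` and `node_count` are hypothetical names used only for this sketch, and `Model`/`Graph`/`Node` are minimal stand-ins exposing the same `graph.node` attribute path that a ModelWrapper provides; the example pass only reads the model.

```python
from collections import namedtuple

# hypothetical minimal stand-ins exposing the model.graph.node attribute path
Node = namedtuple("Node", ["op_type"])
Graph = namedtuple("Graph", ["node"])
Model = namedtuple("Model", ["graph"])


def run_analysis(model, analysis_fn):
    """Apply an analysis pass and check that it returns a dictionary."""
    result = analysis_fn(model)
    if not isinstance(result, dict):
        raise TypeError("an analysis pass must return a dict")
    return result


def node_count(model):
    """Example pass: report the total number of nodes in the graph."""
    return {"node_count": len(model.graph.node)}


model = Model(Graph([Node("Shape"), Node("Mul")]))
print(run_analysis(model, node_count))  # {'node_count': 2}
```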

In [12]:
import netron
netron.start('LFCW1A1.onnx', port=8081, host="0.0.0.0")


Stopping http://0.0.0.0:8081
Serving 'LFCW1A1.onnx' at http://0.0.0.0:8081


In [13]:
%%html
<iframe src="http://0.0.0.0:8081/" style="position: relative; width: 100%;" height="400"></iframe>

<font size="3">The ONNX model has to be wrapped in a ModelWrapper so that it can be processed by FINN. As described in the short introduction, this is the format an analysis pass takes as input.</font>

In [14]:
from finn.core.modelwrapper import ModelWrapper
onnx_model = ModelWrapper('LFCW1A1.onnx')

<font size="3">As an example, the following analysis pass counts all nodes that have the same operation type. The result should contain the operation types and the corresponding number of nodes that occur in the model. At the beginning, an empty dictionary is created; the function fills it and returns it to the user as the result at the end of the analysis.</font>

In [9]:
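def count_equal_nodes(model):
    """Count how many nodes of each operation type the model contains."""
    count_dict = {}
    for node in model.graph.node:
        if node.op_type in count_dict:
            # op type already seen: increment its counter
            count_dict[node.op_type] += 1
        else:
            # first occurrence of this op type
            count_dict[node.op_type] = 1
    return count_dict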

<font size="3">The function takes the model as input and iterates over its nodes. For each node, it checks whether the dictionary already contains an entry for the node's operation type. If not, an entry is created and set to `1`; if there is already an entry, it is incremented. Once all nodes in the model have been visited, the filled dictionary is returned.</font>
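The same counting logic can be written more compactly with `collections.Counter` from the Python standard library. The sketch below is an alternative formulation, not part of FINN; to keep it runnable without an ONNX file, it uses hypothetical stand-ins `Model`, `Graph` and `Node` that expose only `graph.node[i].op_type`, the single attribute the pass reads. Passing a ModelWrapper instead should work the same way, since the function uses the same attribute path.

```python
from collections import Counter, namedtuple

# hypothetical stand-ins: the pass only reads model.graph.node[i].op_type
Node = namedtuple("Node", ["op_type"])
Graph = namedtuple("Graph", ["node"])
Model = namedtuple("Model", ["graph"])


def count_equal_nodes_counter(model):
    """Same result as count_equal_nodes, built with collections.Counter."""
    return dict(Counter(node.op_type for node in model.graph.node))


model = Model(Graph([Node("Sign"), Node("MatMul"), Node("Sign")]))
print(count_equal_nodes_counter(model))  # {'Sign': 2, 'MatMul': 1}
```

Counter keeps the order in which op types are first seen, so the keys come out in the same order as with the explicit loop above.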

In [10]:
print(count_equal_nodes(onnx_model))

{'Shape': 1, 'Gather': 1, 'Unsqueeze': 5, 'Concat': 1, 'Reshape': 1, 'Mul': 5, 'Sub': 1, 'Sign': 4, 'MatMul': 4, 'BatchNormalization': 3, 'Squeeze': 3}
