# Radio Modulation with FINN - Notebook #3 of 5
This notebook walks you through performing network surgery on the ONNX model. In this notebook, we introduce:
1. Usage of `ModelWrapper()` and `showInNetron()`
2. Getting a node in the model by its op type
3. Getting a node through the input of its succeeding node
4. Removing a node from the model
5. Changing the model input and its datatype
6. Adding a top-K node to the model

## Converting QONNX format to FINN-ONNX format

Brevitas exports models in the QONNX format. To run FINN transformations on the model, we first need to convert it to the FINN-ONNX format.

Whenever we want to load an ONNX model, we can wrap it with `ModelWrapper(filepath)`.

### Load the Brevitas QONNX model

In [1]:
from qonnx.core.modelwrapper import ModelWrapper
from finn.util.visualization import showInNetron
# Path to the QONNX model exported by Brevitas
brevitas_model_pth='27ml_rf/models/radio_27ml_export.onnx'

model=ModelWrapper(brevitas_model_pth)

# Check that the Vivado and Vitis HLS paths are set in the environment
!echo $VIVADO_PATH
!echo $HLS_PATH





/home/phu/Vivado/Vitis_HLS/2024.1


### Visualizing the Brevitas ONNX model using `showInNetron`

In [2]:
host_machine_ip='localhost'
# The assert message must be a string; print() would return None
assert host_machine_ip!='your host machine IP', 'host_machine_ip not set'
showInNetron(brevitas_model_pth,localhost_url=host_machine_ip, port=8081)

Serving '27ml_rf/models/radio_27ml_export.onnx' at http://0.0.0.0:8081


### Converting the model to a FINN-ONNX model

In [3]:
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
model = model.transform(ConvertQONNXtoFINN())

finn_model_pth='27ml_rf/models/radio_27ml_finn.onnx'
model.save(finn_model_pth)



### Visualizing the FINN-ONNX model
Notice that it is much easier to read than the Brevitas ONNX model.

In [4]:
showInNetron(finn_model_pth,localhost_url=host_machine_ip, port=8081)

Stopping http://0.0.0.0:8081
Serving '27ml_rf/models/radio_27ml_finn.onnx' at http://0.0.0.0:8081


## Network Surgery
From the graph above, you can see that the `input` goes through a `MultiThreshold` node, then an `Add` node, then a `Conv` node. The problem is that FINN is unable to streamline this first `MultiThreshold` node, which is why we need to remove it manually.

Originally:

`input`--> `MultiThreshold`--> `Add`--> `Conv` -->...

After network surgery:

(former `Add` output as the new input) --> `Conv` --> ...

However, the new input (the tensor that the `Add` node used to produce) expects data that has already been quantized by `MultiThreshold`, because that is what the network saw during training. Therefore, instead of passing raw data to the new model, we perform the quantization that `MultiThreshold` (followed by `Add`) applied to the raw data ourselves, outside the model, and then pass the result to the new model as input.
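To make the "quantize outside the model" step concrete, here is a minimal NumPy sketch of what a `MultiThreshold` node computes: each output is the number of thresholds the input is greater than or equal to, then scaled by `out_scale` and shifted by `out_bias`. It assumes one set of thresholds shared by all channels (the real node can hold one row of thresholds per channel). The helper names, the `scale` value and the INT8 threshold layout are hypothetical and not read from this model, and `out_bias=-128` merely stands in for a constant offset like the one the following `Add` node applies. For the real values you would read the thresholds from the model, for example with `finn_model.get_initializer(...)` on the `MultiThreshold` node's second input.

```python
import numpy as np

def multithreshold(x, thresholds, out_scale=1.0, out_bias=0.0):
    """Sketch of MultiThreshold semantics with per-tensor thresholds (assumption).

    Returns out_scale * (number of thresholds <= x) + out_bias, elementwise.
    """
    t = np.sort(np.asarray(thresholds, dtype=np.float64))
    # searchsorted(side="right") counts how many sorted thresholds are <= each element
    counts = np.searchsorted(t, np.asarray(x, dtype=np.float64), side="right")
    return out_scale * counts + out_bias

def int8_thresholds(scale):
    """Hypothetical thresholds for round-to-nearest INT8 quantization with step `scale`.

    The 255 thresholds sit halfway between adjacent integer levels -128..127.
    """
    k = np.arange(-128, 127)  # -128 .. 126, i.e. 255 values
    return (k + 0.5) * scale

scale = 0.05  # assumed quantization step, not taken from the model
raw = np.array([-10.0, -0.26, 0.0, 0.12, 3.3, 10.0])
q = multithreshold(raw, int8_thresholds(scale), out_bias=-128.0)
print(q)  # [-128.   -5.    0.    2.   66.  127.]
```

With thresholds placed halfway between levels, thresholding is equivalent to rounding `x / scale` and clipping to [-128, 127], except that values landing exactly on a threshold round up here.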


Further information can be found in this FINN discussion: https://github.com/Xilinx/finn/discussions/420

**Notice**: This network surgery only works for this specific model architecture (VGG-10). If a different architecture is used, the network surgery might be different or not required at all. Ultimately, the goal is to allow the model to be fully streamlined before generating hardware layers.

### Tidying up the model
This step does not change any parameters in the model. It infers tensor shapes and datatypes and gives nodes and tensors unique, readable names, purely for readability.

In [5]:
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.insert_topk import InsertTopK
from qonnx.transformation.general import (
    GiveReadableTensorNames,
    GiveUniqueNodeNames,
)
import onnx
from qonnx.core.datatype import DataType

# Tidy up: infer shapes and datatypes, then give nodes and tensors unique, readable names
finn_model = ModelWrapper(finn_model_pth)

finn_model = finn_model.transform(InferShapes())
finn_model = finn_model.transform(InferDataTypes())
finn_model = finn_model.transform(GiveUniqueNodeNames())
finn_model = finn_model.transform(GiveReadableTensorNames())
finn_model = finn_model.cleanup()

pre_net_surgery_pth='27ml_rf/models/radio_27ml_pre_nw_surgery.onnx'
finn_model.save(pre_net_surgery_pth)

showInNetron(pre_net_surgery_pth,localhost_url=host_machine_ip, port=8081)

Stopping http://0.0.0.0:8081
Serving '27ml_rf/models/radio_27ml_pre_nw_surgery.onnx' at http://0.0.0.0:8081
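As a rough illustration of the renaming in the tidy-up cell, here is a pure-Python sketch of the naming scheme that, as far as we can tell, `GiveUniqueNodeNames` applies: each node is renamed `<prefix><op_type>_<n>`, where `n` counts earlier nodes of the same op type. The function and the op-type list are hypothetical stand-ins for `finn_model.graph.node`; check the qonnx source for the exact behaviour of your version.

```python
from collections import defaultdict

def unique_node_names(op_types, prefix=""):
    """Mimic the assumed GiveUniqueNodeNames scheme: <prefix><op_type>_<n>."""
    counts = defaultdict(int)
    names = []
    for op in op_types:
        names.append(f"{prefix}{op}_{counts[op]}")
        counts[op] += 1  # the next node of this op type gets the next number
    return names

# Hypothetical op-type sequence resembling the start of this model
ops = ["MultiThreshold", "Add", "Conv", "MultiThreshold", "Conv", "MultiThreshold"]
print(unique_node_names(ops))
# ['MultiThreshold_0', 'Add_0', 'Conv_0', 'MultiThreshold_1', 'Conv_1', 'MultiThreshold_2']
```

Names of this form make it easy to spot, for example, the first `Conv` layer in Netron.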


### Find the first `Conv` node, its input tensor, and the original input tensor

In [None]:
finn_model=ModelWrapper(pre_net_surgery_pth)

# Find the first 'Conv' node and store it in 'new_input_node'
new_input_node = finn_model.get_nodes_by_op_type("Conv")[0]
# Get the value info of that 'Conv' node's input tensor (in this model, the output of the 'Add' node)
new_input_tensor = finn_model.get_tensor_valueinfo(new_input_node.input[0])

# Get the original input tensor of the model
old_input_tensor = finn_model.graph.input[0]
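Getting a node through the input of its succeeding node works because, in an ONNX graph, nodes do not point at each other directly: a tensor name links the producer's output to the consumer's input. The sketch below shows that lookup in plain Python. The `Node` class, the tensor names and both helper functions are hypothetical stand-ins for `finn_model.graph.node`; with a real `ModelWrapper`, qonnx provides `find_producer(tensor_name)` for the same lookup.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    """Hypothetical stand-in for an ONNX NodeProto."""
    name: str
    op_type: str
    input: list = field(default_factory=list)
    output: list = field(default_factory=list)

# Toy graph mirroring input -> MultiThreshold -> Add -> Conv
nodes = [
    Node("MultiThreshold_0", "MultiThreshold", ["global_in", "thresh_0"], ["t0"]),
    Node("Add_0", "Add", ["t0", "add_param"], ["t1"]),
    Node("Conv_0", "Conv", ["t1", "conv_w"], ["t2"]),
]

def get_nodes_by_op_type(nodes, op_type):
    """Return all nodes with the given op type, in graph order."""
    return [n for n in nodes if n.op_type == op_type]

def find_producer(nodes, tensor_name):
    """Return the node whose outputs include tensor_name, or None for graph inputs."""
    for n in nodes:
        if tensor_name in n.output:
            return n
    return None

first_conv = get_nodes_by_op_type(nodes, "Conv")[0]
producer = find_producer(nodes, first_conv.input[0])
print(first_conv.input[0], "is produced by", producer.name)  # t1 is produced by Add_0
```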

### Remove the original input and every node before the first `Conv`

In [None]:
# Replace the old graph input with the tensor that feeds the first 'Conv' (the 'Add' output)
finn_model.graph.input.remove(old_input_tensor)
finn_model.graph.input.append(new_input_tensor)

# Find the index of the first 'Conv' node and delete every node before it.
# Graph inputs are not nodes, so here indices 0 and 1 are the 'MultiThreshold' and 'Add' nodes.
# Afterwards, the first 'Conv' node has index 0 and reads directly from the new graph input.
new_input_index = finn_model.get_node_index(new_input_node)
del finn_model.graph.node[0:new_input_index]
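A toy, pure-Python model of the graph shows what the slice deletion does. ONNX stores nodes in topological order, and in this model's linear chain everything before the first `Conv` is upstream of it, so deleting the slice removes both `MultiThreshold` and `Add`, and the tensor `Add` produced becomes the graph input. The `Node` class and tensor names are hypothetical; the real code edits protobuf fields of the model.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    """Hypothetical stand-in for an ONNX NodeProto (op type plus tensor names)."""
    op_type: str
    input: list = field(default_factory=list)
    output: list = field(default_factory=list)

# Toy graph: global_in -> MultiThreshold -> Add -> Conv -> MultiThreshold -> ...
graph_inputs = ["global_in"]
nodes = [
    Node("MultiThreshold", ["global_in", "thresh_0"], ["t0"]),
    Node("Add", ["t0", "add_param"], ["t1"]),
    Node("Conv", ["t1", "conv_w"], ["t2"]),
    Node("MultiThreshold", ["t2", "thresh_1"], ["t3"]),
]

# Locate the first Conv and the tensor that feeds it
conv_index = next(i for i, n in enumerate(nodes) if n.op_type == "Conv")
new_input = nodes[conv_index].input[0]

# Swap the graph input, then drop every node that comes before the Conv
graph_inputs.remove("global_in")
graph_inputs.append(new_input)
del nodes[0:conv_index]

print(graph_inputs)                # ['t1']
print([n.op_type for n in nodes])  # ['Conv', 'MultiThreshold']
```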

### Remove any `Softmax` node and insert a `TopK` node

In [None]:
# Postprocessing: remove the final Softmax node left over from training, if there is one.
# FINN uses TopK instead of Softmax. The Softmax was already removed when building this model,
# so the check below is expected to be a no-op here (it no longer deletes an arbitrary last node).
last_node = finn_model.graph.node[-1]
if last_node.op_type == "Softmax":
    softmax_in_tensor = finn_model.get_tensor_valueinfo(last_node.input[0])
    softmax_out_tensor = finn_model.get_tensor_valueinfo(last_node.output[0])
    finn_model.graph.output.remove(softmax_out_tensor)
    finn_model.graph.output.append(softmax_in_tensor)
    finn_model.graph.node.remove(last_node)

# Insert a TopK node in place of the final Softmax node.
# TopK plays a similar role to Softmax for classification:
# with k=1 it outputs only the index of the class with the highest score.
finn_model = finn_model.transform(InsertTopK(k=1))
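TopK can stand in for Softmax at inference because Softmax is strictly increasing in each logit: it never changes which class scores highest, and `k=1` only needs the index of that class. The NumPy sketch below checks this on made-up logits; it illustrates the reasoning and is not FINN's TopK implementation.

```python
import numpy as np

def softmax(z):
    """Row-wise softmax over the last axis."""
    # Subtracting the row max avoids overflow and does not change the result
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def top_k_indices(scores, k=1):
    """Indices of the k largest scores along the last axis, largest first."""
    return np.argsort(-scores, axis=-1, kind="stable")[..., :k]

# Made-up logits: 3 examples, 5 classes
logits = np.array([
    [1.2, -0.3, 4.1, 0.0, 2.2],
    [-1.0, -2.0, -0.5, -3.0, -0.7],
    [0.1, 0.4, 0.3, 0.2, 0.0],
])
probs = softmax(logits)

print(top_k_indices(logits, k=1).ravel())  # [2 2 1]
# Same ranking before and after Softmax, so TopK on the logits gives the same answer
print(np.array_equal(top_k_indices(logits, k=2), top_k_indices(probs, k=2)))  # True
```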

### Handling compatibility and setting the input datatype

In [None]:
# Remove redundant value_info entries for the primary input/output;
# otherwise, newer FINN versions will not accept the model
if finn_model.graph.input[0] in finn_model.graph.value_info:
    finn_model.graph.value_info.remove(finn_model.graph.input[0])
if finn_model.graph.output[0] in finn_model.graph.value_info:
    finn_model.graph.value_info.remove(finn_model.graph.output[0])

# Manually set the input datatype to INT8 (Brevitas does not do this yet)
finnonnx_in_tensor_name = finn_model.graph.input[0].name
finnonnx_model_in_shape = finn_model.get_tensor_shape(finnonnx_in_tensor_name)
finn_model.set_tensor_datatype(finnonnx_in_tensor_name, DataType["INT8"])
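Setting the datatype only annotates the tensor; it does not convert any data. From here on, the FINN compiler assumes every value fed to this input is an integer in [-128, 127], held in a float32 container as FINN does for quantized tensors. A small guard like the hypothetical `check_int8_input` below, applied to the preprocessed data before inference, can catch inputs that were not quantized as expected. It is a NumPy sketch under that assumption, not part of FINN or qonnx.

```python
import numpy as np

INT8_MIN, INT8_MAX = -128, 127

def check_int8_input(x):
    """Raise ValueError unless x holds only integer values within the INT8 range.

    Returns the data as float32, the container type FINN uses for quantized tensors.
    """
    x = np.asarray(x, dtype=np.float32)
    if not np.all(np.isfinite(x)):
        raise ValueError("input contains NaN or inf")
    if not np.array_equal(x, np.round(x)):
        raise ValueError("input contains non-integer values")
    if x.min() < INT8_MIN or x.max() > INT8_MAX:
        raise ValueError(f"input outside [{INT8_MIN}, {INT8_MAX}]")
    return x

ok = check_int8_input([[-128, 0, 5, 127]])
print(ok.dtype, ok.shape)  # float32 (1, 4)
```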

### Export the model

In [None]:
print("Input tensor name: %s" % finnonnx_in_tensor_name)
print("Input tensor shape: %s" % str(finnonnx_model_in_shape))
print("Input tensor datatype: %s" % str(finn_model.get_tensor_datatype(finnonnx_in_tensor_name)))

# save modified model that is now ready for the FINN compiler
tidy_model_pth='27ml_rf/models/radio_27ml_tidy.onnx'
finn_model.save(tidy_model_pth)
print("Modified FINN-ready model saved to %s" % tidy_model_pth)

### Visualizing the new model
Notice that the tensor previously produced by the `Add` node is now the model input: every node from the old input up to the first `Conv` has been removed.

In [None]:
# Visualize the new model
showInNetron(tidy_model_pth,localhost_url=host_machine_ip, port=8081)