# Radio Modulation Classification with FINN - Notebook #3 of 5

### Overview 
In this notebook we will transform the Brevitas model from the previous notebook into a tidy.onnx file!


### FINN Pipeline Map
Throughout these notebooks, you will come to understand the FINN pipeline! In order, the pipeline is:
1. Dataset and Vanilla model
2. Brevitas Model
3. **Transforming the Brevitas Model to tidy.onnx** (you are here)
4. Transforming tidy.onnx to bitstream
5. Loading the bitstream on the FPGA!


In [6]:
from brevitas.export import export_qonnx
from pathlib import Path
import torch
import numpy as np

from torch import nn

# A qnn is a Brevitas version of pytorch's nn. nn stands for neural network.
import brevitas.nn as qnn
from brevitas.quant import Int8Bias
from brevitas.inject.enum import ScalingImplType
from brevitas.inject.defaults import Int8ActPerTensorFloatMinMaxInit

# Seed both random number generators up front so repeated runs are reproducible
np.random.seed(0)
torch.manual_seed(0)

# Adjustable hyperparameters
# -- quantization bit widths
input_bits = 8  # bit width used by the input quantizer
a_bits = 8  # bit width of the ReLU activations
w_bits = 8  # bit width of all the weights
# -- layer widths
filters_conv = 64  # channel count used by the convolutional layers
filters_dense = 128  # feature count used by the dense layers

class InputQuantizer(Int8ActPerTensorFloatMinMaxInit):
    # Quantizer for the network input; it is passed as `act_quant` to the
    # QuantHardTanh that forms the first layer of the model.

    # Bit width of the quantized input (input_bits, i.e. 8 bits)
    bit_width = input_bits

    # The dataset is already quantized to integer values before it goes through
    # the model, so this layer is only meant to convert the datatype (integer ->
    # float32 as seen by FINN) while leaving the values themselves unchanged.
    # NOTE(review): a signed 8-bit integer covers [-128, 127], not [-127, 128].
    #   The thresholds printed by the conversion cell start at -127.5 and step
    #   by 1, which suggests the effective levels are -128..127 with a unit
    #   scale -- TODO confirm against the Brevitas scaling rules.
    # For comparison, the notebook below (used for the 2018 dataset) shows how
    # this quantizer differs from the 2018 version:
    # (https://github.com/Xilinx/brevitas-radioml-challenge-21/blob/main/sandbox/notebooks/training_and_evaluation.ipynb)
    min_val = -127.0
    max_val = 128.0

    # Use a constant (non-learned) scale so the range stays fixed during training
    scaling_impl_type = ScalingImplType.CONST # Fix the quantization range to [min_val, max_val]

def _conv_stage(in_channels):
    """Return the layers of one convolutional stage, in forward order.

    A stage is: quantized Conv1d (kernel size 3, padding 1, no bias) mapping
    ``in_channels`` to ``filters_conv`` channels, BatchNorm1d, quantized ReLU,
    and a MaxPool1d that halves the sequence length.
    """
    return [
        qnn.QuantConv1d(in_channels, filters_conv, 3, padding=1, weight_bit_width=w_bits, bias=False),
        nn.BatchNorm1d(filters_conv),
        qnn.QuantReLU(bit_width=a_bits),
        nn.MaxPool1d(2),
    ]


# The layers are gathered in one flat list and unpacked into nn.Sequential, so
# every layer keeps the same position (index) it has in a hand-written listing.
_layers = [
    # Input quantization layer
    qnn.QuantHardTanh(act_quant=InputQuantizer),
]

# Seven convolutional stages: the first one reads the 2 input channels, the
# other six read the filters_conv channels produced by the previous stage.
for _stage_inputs in (2,) + (filters_conv,) * 6:
    _layers.extend(_conv_stage(_stage_inputs))

_layers.append(nn.Flatten())

# Dense head. filters_conv*8 assumes 8 time steps are left after the seven
# pooling layers (e.g. 1024 input samples / 2**7 = 8).
_layers.extend([
    qnn.QuantLinear(filters_conv*8, filters_dense, weight_bit_width=w_bits, bias=False),
    nn.BatchNorm1d(filters_dense),
    qnn.QuantReLU(bit_width=a_bits),

    qnn.QuantLinear(filters_dense, filters_dense, weight_bit_width=w_bits, bias=False),
    nn.BatchNorm1d(filters_dense),
    # This ReLU hands a quant tensor to the final layer
    qnn.QuantReLU(bit_width=a_bits, return_quant_tensor=True),

    # Output layer: 27 outputs, with a quantized (Int8Bias) bias
    qnn.QuantLinear(filters_dense, 27, weight_bit_width=w_bits, bias=True, bias_quant=Int8Bias),
])

model_class = nn.Sequential(*_layers)

# Location of the checkpoint that holds the trained, quantized parameters
model_name = '27ml_rf'
model_file_name = '27ml_rf_quantized.pth'

Path(model_name).mkdir(exist_ok=True)
chpt_path = model_name + '/' + model_file_name
# NOTE: the message says "saved", but this cell only reads the checkpoint.
print(f'Model parameters will be saved in {chpt_path}')

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# export also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the model back again, or change load_path to load a different model.
# Only the parameters are stored in the checkpoint, not the model structure,
# so they are filled into the architecture defined above.
load_path = chpt_path  # Change this to the path of a different model
model = model_class  # NOTE: same object as model_class, not a fresh copy
# map_location moves tensors saved from a GPU onto the device selected above
model.load_state_dict(torch.load(load_path, map_location=device))
model.to(device)
model.eval()

build_dir = "27ml_rf/models"  # Directory to save the exported model
# Ensure the directory (and any missing parent) exists
Path(build_dir).mkdir(parents=True, exist_ok=True)
# Full path of the exported model, with the tail _export.onnx
export_path = f"{build_dir}/radio_27ml_export.onnx"
# Export to QONNX by tracing the model with a dummy input of shape (1, 2, 1024)
export_qonnx(model, torch.randn(1, 2, 1024).to(device), export_path=export_path)

print(f'model is saved in {export_path}')

Model parameters will be saved in 27ml_rf/27ml_rf_quantized.pth
model is saved in 27ml_rf/models/radio_27ml_export.onnx


In [9]:
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_data_layouts import InferDataLayouts
from qonnx.transformation.insert_topk import InsertTopK
from qonnx.transformation.general import (
    ConvertSubToAdd,
    ConvertDivToMul,
    GiveReadableTensorNames,
    GiveUniqueNodeNames,
    SortGraph,
    RemoveUnusedTensors,
    GiveUniqueParameterTensors,
    RemoveStaticGraphInputs,
    ApplyConfig,
)
import onnx
from onnx import TensorProto, helper
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
import torch
import onnx
from brevitas.export import export_qonnx,export_brevitas_onnx
from qonnx.util.cleanup import cleanup as qonnx_cleanup
from qonnx.core.modelwrapper import ModelWrapper
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.fold_constants import FoldConstants
from qonnx.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames, RemoveStaticGraphInputs
from qonnx.util.cleanup import cleanup_model
import os 
import onnx



# Purpose of this cell: turn the QONNX export into a "tidy" FINN-ONNX model.
# The main goal is to skip the first MultiThreshold node, which is the node
# that quantizes the input data. On the FPGA the data is preprocessed before
# it is forwarded through the model, so that node is removed from the graph.

# Convert the exported QONNX model to the FINN dialect and save it to disk
model = ModelWrapper(export_path)
model = model.transform(ConvertQONNXtoFINN())
model.save(f"{build_dir}/radio_27ml_finn.onnx")

# tidy up: reload the converted model and annotate shapes, datatypes and names
finn_model = ModelWrapper(f"{build_dir}/radio_27ml_finn.onnx")

finn_model = finn_model.transform(InferShapes())
finn_model = finn_model.transform(InferDataTypes())
finn_model = finn_model.transform(GiveUniqueNodeNames())
finn_model = finn_model.transform(GiveReadableTensorNames())
# NOTE(review): the return value is discarded here -- presumably cleanup()
# modifies the wrapped model in place; confirm for the installed QONNX version.
finn_model.cleanup()

# extract input quantization thresholds for sw-based quantization
# (in case they were not fixed before training)
# The first MultiThreshold node is taken to be the input quantizer; its second
# input (index 1) is the threshold tensor.
input_mt_node = finn_model.get_nodes_by_op_type("MultiThreshold")[0]
input_mt_thresholds = finn_model.get_initializer(input_mt_node.input[1])
print("input quant thresholds")
print(input_mt_thresholds)

# preprocessing: remove input reshape/quantization from graph
# The first Conv node becomes the first node of the graph. To start the graph
# somewhere else, select a different node here: its first input tensor becomes
# the new graph input.
new_input_node = finn_model.get_nodes_by_op_type("Conv")[0]
# Value info of the tensor that feeds the node selected above
new_input_tensor = finn_model.get_tensor_valueinfo(new_input_node.input[0])
# Current graph input
old_input_tensor = finn_model.graph.input[0]
# Swap the graph input: drop the old one, register the new one
finn_model.graph.input.remove(old_input_tensor)
finn_model.graph.input.append(new_input_tensor)
new_input_index = finn_model.get_node_index(new_input_node)
# Delete every node listed before the new first node.
# NOTE(review): this relies on graph.node being ordered so that everything
# upstream of the first Conv comes before it in the list.
del finn_model.graph.node[0:new_input_index]

# postprocessing: remove final softmax node from training
# NOTE(review): the last node of the graph is removed unconditionally. This
# model has no softmax node (the network ends in a QuantLinear layer), yet the
# step is still run, so whatever node happens to be last is dropped --
# confirm that this is intended.
softmax_node = finn_model.graph.node[-1]
softmax_in_tensor = finn_model.get_tensor_valueinfo(softmax_node.input[0])
softmax_out_tensor = finn_model.get_tensor_valueinfo(softmax_node.output[0])
# The input of the removed node becomes the new graph output
finn_model.graph.output.remove(softmax_out_tensor)
finn_model.graph.output.append(softmax_in_tensor)
finn_model.graph.node.remove(softmax_node)

# remove redundant value_info for primary input/output
# otherwise, newer FINN versions will not accept the model
if finn_model.graph.input[0] in finn_model.graph.value_info:
    finn_model.graph.value_info.remove(finn_model.graph.input[0])
if finn_model.graph.output[0] in finn_model.graph.value_info:
    finn_model.graph.value_info.remove(finn_model.graph.output[0])

# insert a TopK node at the end of the graph, in place of the final softmax
# TopK plays a similar role to softmax for classification:
# k=1 means it picks only the 1 class with the highest prediction value
finn_model = finn_model.transform(InsertTopK(k=1))

# manually set input datatype (not done by brevitas yet)
# The new graph input is declared as INT8 because the input quantizer was
# removed above and the data is expected to arrive already quantized.
finnonnx_in_tensor_name = finn_model.graph.input[0].name
finnonnx_model_in_shape = finn_model.get_tensor_shape(finnonnx_in_tensor_name)
finn_model.set_tensor_datatype(finnonnx_in_tensor_name, DataType["INT8"])
print("Input tensor name: %s" % finnonnx_in_tensor_name)
print("Input tensor shape: %s" % str(finnonnx_model_in_shape))
print("Input tensor datatype: %s" % str(finn_model.get_tensor_datatype(finnonnx_in_tensor_name)))

# save modified model that is now ready for the FINN compiler
finn_model.save(f"{build_dir}/radio_27ml_tidy.onnx")
print("Modified FINN-ready model saved to %s" % f"{build_dir}/radio_27ml_tidy.onnx")



input quant thresholds
[[-127.5 -126.5 -125.5 -124.5 -123.5 -122.5 -121.5 -120.5 -119.5 -118.5
  -117.5 -116.5 -115.5 -114.5 -113.5 -112.5 -111.5 -110.5 -109.5 -108.5
  -107.5 -106.5 -105.5 -104.5 -103.5 -102.5 -101.5 -100.5  -99.5  -98.5
   -97.5  -96.5  -95.5  -94.5  -93.5  -92.5  -91.5  -90.5  -89.5  -88.5
   -87.5  -86.5  -85.5  -84.5  -83.5  -82.5  -81.5  -80.5  -79.5  -78.5
   -77.5  -76.5  -75.5  -74.5  -73.5  -72.5  -71.5  -70.5  -69.5  -68.5
   -67.5  -66.5  -65.5  -64.5  -63.5  -62.5  -61.5  -60.5  -59.5  -58.5
   -57.5  -56.5  -55.5  -54.5  -53.5  -52.5  -51.5  -50.5  -49.5  -48.5
   -47.5  -46.5  -45.5  -44.5  -43.5  -42.5  -41.5  -40.5  -39.5  -38.5
   -37.5  -36.5  -35.5  -34.5  -33.5  -32.5  -31.5  -30.5  -29.5  -28.5
   -27.5  -26.5  -25.5  -24.5  -23.5  -22.5  -21.5  -20.5  -19.5  -18.5
   -17.5  -16.5  -15.5  -14.5  -13.5  -12.5  -11.5  -10.5   -9.5   -8.5
    -7.5   -6.5   -5.5   -4.5   -3.5   -2.5   -1.5   -0.5    0.5    1.5
     2.5    3.5    4.5    5.5    6.5    7