In [2]:
import shutil
import os
import re
from datetime import datetime
import tarfile
import zipfile

from torchvision import datasets
from torchvision.models import list_models
import torchvision.models as models
import torch

def log_message(message, level="info"):
    """
    Print a tagged log message; for the "error" level, also raise.

    Parameters:
        message (str): The message to log.
        level (str): The level of the message ("info", "warning", or "error").
            Any other value is printed with its upper-cased name as the tag,
            instead of being silently dropped.

    Raises:
        ValueError: If ``level`` is "error"; the exception carries ``message``.
    """
    if level == "info":
        print(f"[INFO] {message}")
    elif level == "warning":
        print(f"[WARNING] {message}")
    elif level == "error":
        print(f"[ERROR] {message}")
        # Callers rely on this raise to abort on invalid configuration.
        raise ValueError(message)
    else:
        # Unknown levels used to produce no output at all, hiding the message.
        print(f"[{str(level).upper()}] {message}")

def sanitize_string(s):
    """
    Normalise free text into an identifier-friendly token.

    The text is lower-cased, every space becomes an underscore, and any
    character that is not an ASCII lowercase letter, digit or underscore is
    dropped (so tabs, hyphens and punctuation disappear entirely).

    Parameters:
        s (str): The input string to sanitize.

    Returns:
        str: The sanitized string.
    """
    underscored = s.lower().replace(' ', '_')
    return re.sub(r'[^a-z0-9_]', '', underscored)

def ensure_directory_exists(dir_path):
    """
    Create the project directory layout, open up its permissions, and clear
    previous build artefacts.

    On every call:
      1. ``dir_path`` and its subdirectories ``dataset``, ``src``,
         ``checkpoints`` and ``output`` are created if missing.
      2. Permissions are set to 0o777 on ``dir_path`` and on every directory
         and file below it. Symbolic links are skipped: ``os.chmod`` follows
         links, so chmod-ing one would change its target, which may live
         outside the project (and a dangling link would raise).
      3. Everything inside ``checkpoints`` and ``output`` is deleted.
         ``dataset`` and ``src`` are left untouched. Deletion is best-effort:
         failures are printed, not raised.

    Parameters:
        dir_path (str): The path of the directory to ensure exists.

    Returns:
        str: The absolute path of the directory.
    """
    os.makedirs(dir_path, exist_ok=True)

    subdirs = ['dataset', 'src', 'checkpoints', 'output']
    dirs_to_clear = ['checkpoints', 'output']

    for subdir in subdirs:
        os.makedirs(os.path.join(dir_path, subdir), exist_ok=True)

    # Permissions are opened up before clearing so that entries inside
    # checkpoints/output are deletable.
    for root, dirs, files in os.walk(dir_path):
        # os.walk does not descend into linked directories, so `root` is either
        # the caller-supplied top folder or a real directory inside it.
        os.chmod(root, 0o777)
        for name in dirs + files:
            entry = os.path.join(root, name)
            if not os.path.islink(entry):
                os.chmod(entry, 0o777)

    for clear_dir in dirs_to_clear:
        clear_path = os.path.join(dir_path, clear_dir)
        for filename in os.listdir(clear_path):
            file_path = os.path.join(clear_path, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)  # Remove file or link (not its target)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)  # Remove directory and all contents
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')

    return os.path.abspath(dir_path)

def check_file_exists(file_path):
    """
    Report whether ``file_path`` names an existing regular file.

    Directories do not count; a symlink counts only if it resolves to a
    regular file (``os.path.isfile`` follows links).

    Parameters:
        file_path (str): The path of the file to check.

    Returns:
        bool: True if a regular file exists at the path, False otherwise.
    """
    is_regular_file = os.path.isfile(file_path)
    return is_regular_file

def is_archive_file(file_path):
    """
    Check if a file is a readable zip or tar archive.

    Parameters:
        file_path (str | os.PathLike): The path of the file to check.

    Returns:
        bool: True if the file is a valid zip or tar archive. False otherwise,
        including when the path does not exist or is a directory.
    """
    # tarfile.is_tarfile opens the path itself and lets FileNotFoundError /
    # IsADirectoryError escape (unlike zipfile.is_zipfile), so rule out
    # non-files up front. File-like objects are passed through unchanged.
    if isinstance(file_path, (str, bytes, os.PathLike)) and not os.path.isfile(file_path):
        return False
    return zipfile.is_zipfile(file_path) or tarfile.is_tarfile(file_path)

In [3]:
from finn.util.basic import pynq_part_map


def setup_project(prj_name, brd_name, model_type, project_folder=None, model_py_file=None, model_pth_file=None, torch_vision_model=None, finn_pretrained_model=None, dataset_type=None, custom_dataset=None, torch_vision_dataset=None, working_folder="/home/fastqnn/finn/notebooks/Fast-QNN/outputs/txaviour/"):
    """
    Validate a project configuration and set up its folder structure.

    The board name and model type are validated *before* the project folder is
    prepared. ``ensure_directory_exists`` empties the ``checkpoints`` and
    ``output`` subfolders, so validating first keeps an invalid call from
    wiping a previous run's artefacts. Checks on user-supplied files run
    afterwards, because they look inside the (now existing) folder.

    Parameters:
        prj_name (str): The project name (free text); a sanitized copy is used for paths.
        brd_name (str): Target board; must be a key of FINN's ``pynq_part_map``.
        model_type (str): 'untrained', 'custom_pretrained', 'torch_vision_pretrained'
            or 'finn_pretrained'.
        project_folder (str, optional): The main folder for the project. When omitted,
            ``<working_folder>/<sanitized name>_0`` is used.
        model_py_file (str, optional): File name, inside ``<project_folder>/src``, of the
            Python script defining the model; required for 'untrained' and 'custom_pretrained'.
        model_pth_file (str, optional): File name, inside ``<project_folder>/src``, of the
            .pth weights; required for 'custom_pretrained'.
        torch_vision_model (str, optional): A name from ``torchvision.models.list_models()``;
            required for 'torch_vision_pretrained'.
        finn_pretrained_model (str, optional): One of the FINN test model names listed in the
            body (e.g. 'cnv_1w1a'); required for 'finn_pretrained'.
        dataset_type (str, optional): 'torch_vision_dataset' or 'custom_dataset';
            required for 'untrained'.
        custom_dataset (str, optional): File name, inside ``<project_folder>/dataset``, of a
            zip or tar archive; required when dataset_type is 'custom_dataset'.
        torch_vision_dataset (str, optional): Case-insensitive name of a public attribute of
            ``torchvision.datasets``; required when dataset_type is 'torch_vision_dataset'.
        working_folder (str, optional): Parent directory for the auto-generated project folder.

    Returns:
        dict: Project setup information: display/stripped name, absolute folder, board,
        model type, and the model/dataset selections that apply to that model type.

    Raises:
        ValueError: Raised through ``log_message(..., level="error")`` for any invalid input.
    """
    log_message("Setting up project")
    prj_info = {}

    # Ensure Project Name is provided
    if not prj_name:
        log_message("Project Name is required", level="error")

    prj_name_stripped = sanitize_string(prj_name)
    prj_info["Display_Name"] = prj_name
    prj_info["Stripped_Name"] = prj_name_stripped

    # Folder-independent validation first (see docstring).
    available_boards = list(pynq_part_map.keys())
    if brd_name not in available_boards:
        log_message(f"'{brd_name}' is not a valid board name. Available board names are: {available_boards}", level="error")

    valid_model_types = ["untrained", "custom_pretrained", "torch_vision_pretrained", "finn_pretrained"]
    if model_type not in valid_model_types:
        log_message(f"Model Type must be one of {valid_model_types}", level="error")

    if not project_folder:
        # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # for production mode
        timestamp = "0"
        project_folder = os.path.join(working_folder, f"{prj_name_stripped}_{timestamp}")

    folder_path = ensure_directory_exists(project_folder)
    prj_info["Folder"] = folder_path
    prj_info["Board_name"] = brd_name
    prj_info['Model_Type'] = model_type

    src_folder = os.path.join(project_folder, 'src')

    # Model-type specific requirements
    if model_type in ["untrained", "custom_pretrained"]:
        if not model_py_file or not check_file_exists(os.path.join(src_folder, model_py_file)):
            log_message(f"Model Py File '{model_py_file}' does not exist in '{src_folder}'", level="error")
        prj_info["Model_Py_File"] = model_py_file

    if model_type == "custom_pretrained":
        if not model_pth_file or not check_file_exists(os.path.join(src_folder, model_pth_file)):
            log_message(f"Model Pth File '{model_pth_file}' does not exist in '{src_folder}'", level="error")
        prj_info["Model_Pth_File"] = model_pth_file

    if model_type == "torch_vision_pretrained":
        available_models = list_models(module=models)
        if not torch_vision_model or torch_vision_model not in available_models:
            log_message(f"Torch Vision Model must be one of: {available_models}", level="error")
        prj_info["Torch_Vision_Model"] = torch_vision_model

    if model_type == "finn_pretrained":
        available_models = ["cnv_1w1a", "cnv_1w2a", "cnv_2w2a", "lfc_1w1a", "lfc_1w2a", "sfc_1w1a", "sfc_1w2a", "sfc_2w2a", "tfc_1w1a", "tfc_1w2a", "tfc_2w2a", "quant_mobilenet_v1_4b"]
        if not finn_pretrained_model or finn_pretrained_model not in available_models:
            log_message(f"Finn Pretrained Model must be one of: {available_models}", level="error")
        prj_info["Finn_Pretrained_Model"] = finn_pretrained_model

    # Dataset requirements for untrained models
    if model_type == "untrained":
        valid_dataset_types = ["torch_vision_dataset", "custom_dataset"]
        if dataset_type not in valid_dataset_types:
            log_message(f"Dataset Type must be one of {valid_dataset_types} for Untrained models", level="error")
        prj_info["Dataset_Type"] = dataset_type

        if dataset_type == "custom_dataset":
            dataset_folder = os.path.join(project_folder, 'dataset')
            # Check the name before joining: os.path.join raises TypeError on None.
            custom_dataset_path = os.path.join(dataset_folder, custom_dataset) if custom_dataset else None
            if not custom_dataset_path or not check_file_exists(custom_dataset_path) or not is_archive_file(custom_dataset_path):
                log_message(f"Custom Dataset '{custom_dataset}' must exist in '{dataset_folder}' and be an archive file (zip or tar)", level="error")
            prj_info["Custom_Dataset"] = custom_dataset

        elif dataset_type == "torch_vision_dataset":
            available_datasets = [cls_name.lower() for cls_name in dir(datasets) if not cls_name.startswith('_')]
            if not torch_vision_dataset or torch_vision_dataset.lower() not in available_datasets:
                log_message(f"Torch Vision Dataset must be one of: {available_datasets}", level="error")
            prj_info["Torch_Vision_Dataset"] = torch_vision_dataset

    log_message(f"Project setup complete. {prj_name} has been initialized.")
    return prj_info

In [4]:
from finn.util.test import get_test_model_trained


def load_pretrained_model(model_name, model_type, src_folder, initial_channels = 3, max_size = 4096, search_step = 1):
    """
    Load a pre-trained model and find an input shape it accepts.

    Downloads are stored under ``src_folder``; ``TORCH_HOME`` is pointed there
    for the rest of the process. The model is switched to eval mode, then
    probed (under ``torch.no_grad``) with random inputs of shape
    (1, initial_channels, H, W). A list of common sizes is tried first, then
    an exhaustive scan of H and W over ``range(1, max_size + 1, search_step)``.
    Eval mode matters: probing in train mode would update BatchNorm running
    statistics with random data and corrupt the pretrained weights' state.

    Parameters:
        model_name (str): For 'torch_vision_pretrained', a torchvision hub model name
            (e.g. 'alexnet', 'resnet50'). For 'finn_pretrained', one of the keys of
            ``finn_models`` below (e.g. 'cnv_1w1a').
        model_type (str): 'torch_vision_pretrained' or 'finn_pretrained'.
        src_folder (str): Folder where model downloads are stored (created if missing).
        initial_channels (int): Channel count C used for every probe input.
        max_size (int): Largest H and W tried by the exhaustive scan.
        search_step (int): Stride of the exhaustive scan. The default of 1 tries every
            size, which can mean up to max_size**2 forward passes; raise it to keep the
            fallback scan tractable.

    Returns:
        tuple: ``(model, input_shape)``, where ``input_shape`` is the first
        (1, C, H, W) tuple for which a forward pass succeeded.

    Raises:
        ValueError: Raised through ``log_message(..., level="error")`` if the
        type/name pair is unsupported or no compatible shape is found.
    """
    # FINN test model name -> get_test_model_trained(net, weight_bits, act_bits)
    finn_models = {
        "cnv_1w1a": ("CNV", 1, 1),
        "cnv_1w2a": ("CNV", 1, 2),
        "cnv_2w2a": ("CNV", 2, 2),
        "lfc_1w1a": ("LFC", 1, 1),
        "lfc_1w2a": ("LFC", 1, 2),
        "sfc_1w1a": ("SFC", 1, 1),
        "sfc_1w2a": ("SFC", 1, 2),
        "sfc_2w2a": ("SFC", 2, 2),
        "tfc_1w1a": ("TFC", 1, 1),
        "tfc_1w2a": ("TFC", 1, 2),
        "tfc_2w2a": ("TFC", 2, 2),
        "quant_mobilenet_v1_4b": ("mobilenet", 4, 4),
    }

    log_message(f"Loading {model_type} Model: {model_name}")
    os.makedirs(src_folder, exist_ok=True)
    # Redirect torch.hub downloads/cache into the project's src folder.
    os.environ['TORCH_HOME'] = src_folder

    pretrained_model = None
    if model_type == "torch_vision_pretrained":
        pretrained_model = torch.hub.load('pytorch/vision:v0.10.0', model_name, pretrained=True)
    elif model_type == "finn_pretrained" and model_name in finn_models:
        pretrained_model = get_test_model_trained(*finn_models[model_name])

    if pretrained_model is None:
        # Without this, the probes below would fail with an unrelated TypeError.
        log_message(f"Unsupported model '{model_name}' for model type '{model_type}'", level="error")

    pretrained_model.eval()

    def _accepts(shape):
        """Return True if one forward pass on a random tensor of `shape` succeeds."""
        try:
            with torch.no_grad():
                pretrained_model(torch.rand(*shape))
            return True
        except (RuntimeError, ValueError):
            return False

    # Common input shapes (N, C, H, W), tried first.
    common_shapes = [
        (1, initial_channels, 32, 32),  # Typical for many models
        (1, initial_channels, 28, 28),  # Typical for many models
        (1, initial_channels, 224, 224),  # Typical for many models
        (1, initial_channels, 299, 299),  # For models like Inception
        (1, initial_channels, 128, 128),  # Smaller size
        (1, initial_channels, 256, 256),  # Larger square size
        (1, initial_channels, 320, 240),  # Common non-square size
        (1, initial_channels, 240, 320),  # Non-square (swapped dimensions)
        (1, initial_channels, 256, 128),  # Non-square
        (1, initial_channels, 128, 256),  # Non-square (swapped)
    ]

    for shape in common_shapes:
        if _accepts(shape):
            log_message(f"Compatible common input shape found: {shape}")
            return pretrained_model, shape

    # Fallback: exhaustive scan. PyTorch image layout is NCHW, so dim 2 is height.
    for height in range(1, max_size + 1, search_step):
        for width in range(1, max_size + 1, search_step):
            shape = (1, initial_channels, height, width)
            if _accepts(shape):
                log_message(f"Compatible input shape found: {shape}")
                return pretrained_model, shape

    log_message("Could not determine a compatible input shape within the specified range.", level="error")

In [5]:
prj_name_input = "AlexNet 1W1A Test"
board_name_input = "Pynq-Z2"
# NOTE(review): not passed to setup_project, which derives its own folder path.
prj_folder_input = sanitize_string(prj_name_input)
model_type_input = "torch_vision_pretrained"
torch_vision_model_input = "alexnet"
finn_pretrained_model_input = None  # e.g. "cnv_1w1a"; used when model_type_input == "finn_pretrained"
Project_Info = setup_project(
    prj_name=prj_name_input,
    brd_name=board_name_input,
    model_type=model_type_input,
    torch_vision_model=torch_vision_model_input,
    finn_pretrained_model=finn_pretrained_model_input,
)

input_model = None
input_model_shape = None

# setup_project stores the selected model name under a model-type specific key.
PRETRAINED_NAME_KEYS = {
    "torch_vision_pretrained": "Torch_Vision_Model",
    "finn_pretrained": "Finn_Pretrained_Model",
}

if Project_Info['Model_Type'] == "untrained":
    log_message("Training for untrained models are not supported at the moment!", level="error")
elif Project_Info['Model_Type'] == "custom_pretrained":
    log_message("Custom Pretrained models are not supported at the moment!", level="error")
elif Project_Info['Model_Type'] in PRETRAINED_NAME_KEYS:
    pretrained_folder = os.path.join(Project_Info['Folder'], "src")
    pretrained_name = Project_Info[PRETRAINED_NAME_KEYS[Project_Info['Model_Type']]]
    # FINN's LFC/SFC/TFC test networks are fully connected MNIST-style models that
    # presumably take 1 input channel (28x28); everything else is probed with 3.
    # TODO confirm when adding new FINN models.
    input_channels = 1 if pretrained_name.startswith(("lfc_", "sfc_", "tfc_")) else 3
    input_model, input_model_shape = load_pretrained_model(
        pretrained_name,
        Project_Info['Model_Type'],
        pretrained_folder,
        initial_channels=input_channels,
    )
else:
    log_message("Unsupported Model Type", level="error")

[INFO] Setting up project
[INFO] Project setup complete. AlexNet 1W1A Test has been initialized.
[INFO] Loading torch_vision_pretrained Model: alexnet


Using cache found in /home/fastqnn/finn/notebooks/Fast-QNN/outputs/txaviour/alexnet_1w1a_test_0/src/hub/pytorch_vision_v0.10.0


[INFO] Compatible common input shape found: (1, 3, 224, 224)


In [177]:
def set_onnx_checkpoint(project_info, suffix):
    """
    Build the path of an ONNX checkpoint inside the project's ``checkpoints`` folder.

    Only the path is computed (and a log line printed); writing the file is
    the caller's job. The file name is
    ``<Stripped_Name>_<sanitized suffix>.onnx``, e.g. suffix "Brevitas Export"
    gives ``<Stripped_Name>_brevitas_export.onnx``.

    Parameters:
        project_info (dict): Project metadata; its 'Folder' and 'Stripped_Name' entries are read.
        suffix (str): Human-readable checkpoint label; it is sanitized before use.

    Returns:
        str: The full path to the checkpoint file.
    """
    log_message(f"Saving Checkpoint: {suffix}")
    checkpoint_dir = os.path.join(project_info['Folder'], "checkpoints")
    checkpoint_file = "{}_{}.onnx".format(project_info['Stripped_Name'], sanitize_string(suffix))
    return os.path.join(checkpoint_dir, checkpoint_file)


In [178]:
from brevitas.export import export_qonnx
from qonnx.util.cleanup import cleanup as qonnx_cleanup

# Export the (Brevitas) PyTorch model to QONNX, tracing it with a random input
# of the probed shape, then clean the exported graph in place.
export_onnx_path = set_onnx_checkpoint(Project_Info, "Brevitas Export")
export_qonnx(
    input_model,
    torch.randn(input_model_shape),
    export_onnx_path,
    opset_version=9,
)
qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)

[INFO] Saving Checkpoint: Brevitas Export


In [179]:
from qonnx.core.modelwrapper import ModelWrapper
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

# Load the cleaned QONNX export, convert it to FINN's representation and checkpoint it.
model = ModelWrapper(export_onnx_path).transform(ConvertQONNXtoFINN())
model.save(set_onnx_checkpoint(Project_Info, "QONNX to FINN"))

[INFO] Saving Checkpoint: QONNX to FINN


In [180]:
from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
from qonnx.transformation.remove import RemoveIdentityOps
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.fold_constants import FoldConstants
from qonnx.transformation.general import (GiveReadableTensorNames,
                                          GiveUniqueNodeNames,
                                          RemoveStaticGraphInputs,
                                          RemoveUnusedTensors, GiveUniqueParameterTensors, SortGraph)
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_data_layouts import InferDataLayouts

def tidy_up_transforms(input_tidy_model, save_name):
    """
    Run a fixed sequence of QONNX clean-up passes over a model, then save it.

    The passes run in this order: unique parameter tensors, shape inference,
    constant folding, unique node names, readable tensor names, datatype
    inference, static-input removal, data-layout inference, unused-tensor
    removal, double->single float, graph sort, identity-op removal.

    Parameters:
        input_tidy_model (ModelWrapper): The model to transform.
        save_name (str): The path the transformed model is saved to.

    Returns:
        ModelWrapper: The transformed model.
    """
    cleanup_passes = (
        GiveUniqueParameterTensors,
        InferShapes,
        FoldConstants,
        GiveUniqueNodeNames,
        GiveReadableTensorNames,
        InferDataTypes,
        RemoveStaticGraphInputs,
        InferDataLayouts,
        RemoveUnusedTensors,
        DoubleToSingleFloat,
        SortGraph,
        RemoveIdentityOps,
    )

    tidied = input_tidy_model
    for make_pass in cleanup_passes:
        # A fresh pass instance for every application.
        tidied = tidied.transform(make_pass())

    tidied.save(save_name)
    return tidied

In [None]:
# Tidy the freshly converted FINN graph and checkpoint it.
model = tidy_up_transforms(
    model,
    set_onnx_checkpoint(Project_Info, "Tidy ONNX Post Finn"),
)

[INFO] Saving Checkpoint: Tidy ONNX Post Finn


In [None]:
from qonnx.core.datatype import DataType
from qonnx.transformation.merge_onnx_models import MergeONNXModels
from finn.util.pytorch import ToTensor

# Input pre-processing (ToTensor scaling and a UINT8 input annotation) is not
# merged into the graph by default; the deployed application must do it.
log_message(
    "Skipping Pre-Processing. Will be expecting the user to handle it in application!",
    level="warning",
)

In [None]:
"""
global_inp_name = model.graph.input[0].name
ishape = model.get_tensor_shape(global_inp_name)
# preprocessing: torchvision's ToTensor divides uint8 inputs by 255
totensor_pyt = ToTensor()
chkpt_preproc_name = set_onnx_checkpoint(Project_Info,"Pre Proc ONNX Finn")
export_qonnx(totensor_pyt, torch.randn(ishape), chkpt_preproc_name)
qonnx_cleanup(chkpt_preproc_name, out_file=chkpt_preproc_name)
pre_model = ModelWrapper(chkpt_preproc_name)
pre_model = pre_model.transform(ConvertQONNXtoFINN())
model = model.transform(MergeONNXModels(pre_model))
# add input quantization annotation: UINT8 for all BNN-PYNQ models
global_inp_name = model.graph.input[0].name
model.set_tensor_datatype(global_inp_name, DataType["UINT8"])
"""

In [None]:
from qonnx.transformation.insert_topk import InsertTopK

# Post-processing: append a TopK(k=1) node at the graph output, checkpoint,
# then re-tidy the graph.
model = model.transform(InsertTopK(k=1))
model.save(set_onnx_checkpoint(Project_Info, "Post Processing"))
model = tidy_up_transforms(
    model,
    set_onnx_checkpoint(Project_Info, "Tidy Post PrePost Proc"),
)

In [None]:
from onnx import TensorProto
from qonnx.transformation.base import Transformation
from qonnx.transformation.infer_data_layouts import InferDataLayouts
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_shapes import InferShapes


class CustomReluBreakdown(Transformation):
    """Replace every ReLU node with the equivalent ``Max(x, 0)``.

    Each Max node takes over the ReLU's position in ``graph.node`` and its
    output tensor name, so node order and downstream consumers are unchanged.
    One single-element float zero initializer is added per replaced node.
    """

    def apply(self, transform_model):
        # Local import: `oh` is otherwise only bound by a later notebook cell.
        from onnx import helper as oh

        graph = transform_model.graph
        graph_modified = False

        # Walk indices backwards and swap nodes in place, instead of mutating
        # graph.node while iterating over it.
        for idx in range(len(graph.node) - 1, -1, -1):
            n = graph.node[idx]
            if n.op_type != "Relu":
                continue

            input_name = n.input[0]
            output_name = n.output[0]
            node_name = n.name

            # Derive the constant's name from the output tensor (unique in the
            # graph) rather than the node name, which ONNX allows to be empty.
            zero_const_name = f"{output_name}_zero_const"
            zero_tensor = oh.make_tensor(
                zero_const_name,  # Name of the constant tensor
                TensorProto.FLOAT,  # Data type
                dims=[1],  # Shape of the tensor (broadcasts against the input)
                vals=[0.0]  # Value of the tensor
            )
            graph.initializer.append(zero_tensor)

            max_node = oh.make_node(
                "Max",
                inputs=[input_name, zero_const_name],
                outputs=[output_name],
                name=f"{node_name}_max"
            )

            graph.node.remove(n)
            graph.node.insert(idx, max_node)
            graph_modified = True

        # Infer shapes after modifying the graph
        transform_model = transform_model.transform(InferShapes())
        return (transform_model, graph_modified)

In [None]:
from onnx import helper as oh
from onnx import TensorProto

class ReplaceMaxWithArithmetic(Transformation):
    """
    Replace binary ``Max`` nodes with ``(a + b + |a - b|) / 2``.

    The replacement Sub, Abs, Add, Add and Div nodes are inserted, in
    dependency order, at the position of the original Max node, so the node
    list stays topologically sorted (appending them at the end, as before,
    left consumers of the Max output ahead of its producer). Max nodes whose
    input count is not exactly two are left unchanged.
    """

    def apply(self, transform_model):
        graph = transform_model.graph
        graph_modified = False

        # Walk backwards so replacing one node with several never shifts the
        # indices that are still to be visited.
        for idx in range(len(graph.node) - 1, -1, -1):
            node = graph.node[idx]
            if node.op_type != "Max" or len(node.input) != 2:
                continue

            max_input_1 = node.input[0]
            max_input_2 = node.input[1]
            max_output = node.output[0]

            diff_output = max_output + "_diff"  # a - b
            abs_output = max_output + "_abs"  # |a - b|
            sum_inputs_output = max_output + "_sum_inputs"  # a + b
            final_sum_output = max_output + "_final_sum"  # (a + b) + |a - b|
            const_two_name = max_output + "_const_two"

            const_two = oh.make_tensor(
                name=const_two_name,
                data_type=TensorProto.FLOAT,
                dims=[],
                vals=[2.0]
            )
            transform_model.graph.initializer.append(const_two)

            replacement = [
                oh.make_node("Sub", [max_input_1, max_input_2], [diff_output], name=max_output + "_sub"),
                oh.make_node("Abs", [diff_output], [abs_output], name=max_output + "_abs"),
                oh.make_node("Add", [max_input_1, max_input_2], [sum_inputs_output], name=max_output + "_sum_inputs"),
                oh.make_node("Add", [sum_inputs_output, abs_output], [final_sum_output], name=max_output + "_final_sum"),
                # Writes the original Max output name, so consumers are untouched.
                oh.make_node("Div", [final_sum_output, const_two_name], [max_output], name=max_output + "_div"),
            ]

            graph.node.remove(node)
            for offset, new_node in enumerate(replacement):
                graph.node.insert(idx + offset, new_node)
            graph_modified = True

        transform_model = transform_model.transform(InferShapes())
        return (transform_model, graph_modified)

In [None]:
from onnx import helper as oh, TensorProto
from qonnx.transformation.base import Transformation
from qonnx.transformation.infer_shapes import InferShapes

class ReplaceAbsWithLinearArithmeticNoBranching(Transformation):
    """
    Replace ``Abs`` nodes with the branch-free product ``x * Sign(x)``.

    |x| has no purely linear expression. The former rewrite computed
    ``(x + (-x)) * 0.5``, which is identically zero, so every Abs silently
    became 0. ``x * Sign(x)`` is exact for every finite x (Sign(0) = 0) and
    still avoids control flow. The ONNX Sign operator requires opset >= 9.

    The Sign and Mul nodes are inserted at the position of the original Abs
    node, and the Mul writes the original Abs output name, so graph order and
    downstream consumers are preserved.
    """

    def apply(self, transform_model):
        graph = transform_model.graph
        graph_modified = False

        # Walk backwards so replacing one node with two never shifts the
        # indices that are still to be visited.
        for idx in range(len(graph.node) - 1, -1, -1):
            node = graph.node[idx]
            if node.op_type != "Abs":
                continue

            abs_input = node.input[0]
            abs_output = node.output[0]
            sign_output = abs_output + "_sign"

            sign_node = oh.make_node(
                "Sign", [abs_input], [sign_output], name=abs_output + "_sign"
            )
            mul_node = oh.make_node(
                "Mul", [abs_input, sign_output], [abs_output], name=abs_output + "_mul_sign"
            )

            graph.node.remove(node)
            graph.node.insert(idx, sign_node)
            graph.node.insert(idx + 1, mul_node)
            graph_modified = True

        transform_model = transform_model.transform(InferShapes())
        return (transform_model, graph_modified)


In [None]:
from qonnx.transformation.batchnorm_to_affine import BatchNormToAffine
from qonnx.transformation.general import ConvertSubToAdd, ConvertDivToMul
from finn.transformation.streamline import Streamline, RoundAndClipThresholds, CollapseRepeatedMul, ConvertSignToThres, \
    CollapseRepeatedAdd
from qonnx.transformation.lower_convs_to_matmul import LowerConvsToMatMul
from qonnx.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
from finn.transformation.streamline.absorb import (AbsorbTransposeIntoMultiThreshold,
                                                   AbsorbScalarMulAddIntoTopK,
                                                   AbsorbSignBiasIntoMultiThreshold,
                                                   AbsorbAddIntoMultiThreshold,
                                                   AbsorbMulIntoMultiThreshold, FactorOutMulSignMagnitude,
                                                   Absorb1BitMulIntoMatMul, Absorb1BitMulIntoConv)
from finn.transformation.streamline.reorder import MakeMaxPoolNHWC, MoveScalarLinearPastInvariants, MoveAddPastMul, \
    MoveScalarAddPastMatMul, MoveAddPastConv, MoveScalarMulPastMatMul, MoveScalarMulPastConv, \
    MoveMaxPoolPastMultiThreshold, MoveLinearPastEltwiseAdd, MoveLinearPastFork
from finn.transformation.streamline.absorb import AbsorbConsecutiveTransposes

def streamline_transforms(input_streamline_model, save_name, num_iterations=5):
    """
    Repeatedly apply a fixed streamlining pipeline, tidying and saving after each pass.

    One iteration does the following, in order:
      - rewrites ReLU/Max/Abs into simpler arithmetic;
      - runs FINN's convert/collapse/move/absorb passes and ``Streamline``;
      - lowers convolutions to MatMul and makes MaxPool NHWC;
      - runs a second ``Streamline`` plus threshold/TopK absorption;
      - calls ``tidy_up_transforms``, which saves to ``save_name`` (overwriting
        the previous iteration's file).
    Presumably the repetition exists because some passes only match after
    others have moved nodes. TODO confirm the required count.

    Parameters:
        input_streamline_model (ModelWrapper): The model to transform.
        save_name (str): Path the model is saved to after every iteration.
        num_iterations (int): Number of full passes over the pipeline (default 5).

    Returns:
        ModelWrapper: The transformed model. With ``num_iterations == 0`` the
        input is returned unchanged and nothing is saved.
    """
    # Order matters and some passes deliberately appear twice.
    streamline_passes = (
        CustomReluBreakdown,
        ReplaceMaxWithArithmetic,
        ReplaceAbsWithLinearArithmeticNoBranching,
        AbsorbSignBiasIntoMultiThreshold,
        ConvertSubToAdd,
        ConvertDivToMul,
        CollapseRepeatedMul,
        BatchNormToAffine,
        ConvertSignToThres,
        MoveAddPastMul,
        MoveScalarAddPastMatMul,
        MoveAddPastConv,
        MoveScalarMulPastMatMul,
        MoveScalarMulPastConv,
        MoveAddPastMul,
        MoveScalarLinearPastInvariants,
        CollapseRepeatedAdd,
        AbsorbAddIntoMultiThreshold,
        FactorOutMulSignMagnitude,
        MoveMaxPoolPastMultiThreshold,
        AbsorbMulIntoMultiThreshold,
        Absorb1BitMulIntoMatMul,
        Absorb1BitMulIntoConv,
        AbsorbMulIntoMultiThreshold,
        Streamline,
        AbsorbConsecutiveTransposes,
        LowerConvsToMatMul,
        MakeMaxPoolNHWC,
        AbsorbTransposeIntoMultiThreshold,
        ConvertBipolarMatMulToXnorPopcount,
        Streamline,
        AbsorbScalarMulAddIntoTopK,
        RoundAndClipThresholds,
        MoveLinearPastEltwiseAdd,
        MoveLinearPastFork,
    )

    for _ in range(num_iterations):
        for make_pass in streamline_passes:
            input_streamline_model = input_streamline_model.transform(make_pass())
        # Save the transformed model after tidying up
        input_streamline_model = tidy_up_transforms(input_streamline_model, save_name)

    # TODO: apply LowerConvsToMatMul / MakeMaxPoolNHWC only when the model still
    # contains Conv nodes (e.g. len(model.get_nodes_by_op_type("Conv")) > 0).
    return input_streamline_model

In [None]:
# Streamline the post-processed model and checkpoint the result.
model = streamline_transforms(
    model,
    set_onnx_checkpoint(Project_Info, "Streamlined ONNX"),
)

In [None]:
from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
from finn.transformation.fpgadataflow.convert_to_hw_layers import (
    InferBinaryMatrixVectorActivation,
    InferQuantizedMatrixVectorActivation,
    InferLabelSelectLayer,
    InferThresholdingLayer,
    InferStreamingMaxPool,
    InferConvInpGen,
    InferAddStreamsLayer,
    InferChannelwiseLinearLayer,
    InferConcatLayer,
    InferDuplicateStreamsLayer,
    InferGlobalAccPoolLayer,
    InferLookupLayer,
    InferPool,
    InferStreamingEltwise,
    InferUpsample,
    InferVectorVectorActivation
)

def to_hw_transforms(input_hw_model, save_name):
    """
    Convert a streamlined model's nodes into FINN hardware layers and save it.

    Every FINN ``Infer*`` pass from ``convert_to_hw_layers`` is applied in a fixed
    order. These are followed by ``RemoveCNVtoFCFlatten``, ``AbsorbConsecutiveTransposes``,
    the shared ``tidy_up_transforms`` step and a final save.

    Parameters:
        input_hw_model (ModelWrapper): The model to transform.
        save_name (str): File path the transformed model is written to.

    Returns:
        ModelWrapper: The transformed model.
    """
    # Pass order is kept exactly as in the hand-written chain; later passes may
    # depend on nodes produced by earlier ones.
    hw_conversion_passes = (
        InferBinaryMatrixVectorActivation,
        InferQuantizedMatrixVectorActivation,
        InferLabelSelectLayer,
        InferThresholdingLayer,
        InferConvInpGen,
        InferStreamingMaxPool,
        InferAddStreamsLayer,
        InferChannelwiseLinearLayer,
        InferConcatLayer,
        InferDuplicateStreamsLayer,
        InferGlobalAccPoolLayer,
        InferLookupLayer,
        InferPool,
        InferStreamingEltwise,
        InferUpsample,
        InferVectorVectorActivation,
        RemoveCNVtoFCFlatten,
        AbsorbConsecutiveTransposes,
    )

    hw_model = input_hw_model
    for pass_cls in hw_conversion_passes:
        # Instantiate each pass immediately before applying it.
        hw_model = hw_model.transform(pass_cls())

    hw_model = tidy_up_transforms(hw_model, save_name)
    # NOTE(review): tidy_up_transforms also receives save_name; if it already
    # saves, this explicit save is redundant (harmless) — confirm.
    hw_model.save(save_name)
    return hw_model

In [None]:
# Convert the streamlined model to FINN HW layers and checkpoint the result.
model = to_hw_transforms(
    input_hw_model=model,
    save_name=set_onnx_checkpoint(Project_Info, "To HW Layers"),
)

In [None]:
from finn.transformation.fpgadataflow.create_dataflow_partition import (
    CreateDataflowPartition,
)
from qonnx.custom_op.registry import getCustomOp

def dataflow_partitioning(input_data_model, save_name):
    """
    Split a model into a parent graph and a dataflow child graph.

    ``CreateDataflowPartition`` is applied to ``input_data_model`` and the resulting
    parent model is saved to ``save_name``. The child model referenced by the
    ``model`` attribute of the first ``StreamingDataflowPartition`` node is then
    loaded and returned.

    Parameters:
        input_data_model (ModelWrapper): The model to partition.
        save_name (str): File path the parent model is written to.

    Returns:
        ModelWrapper: The dataflow (child) model, NOT the parent model.

    Raises:
        ValueError: If partitioning produced no ``StreamingDataflowPartition`` node.
    """
    # Apply dataflow partitioning and checkpoint the parent graph.
    parent_model = input_data_model.transform(CreateDataflowPartition())
    parent_model.save(save_name)

    sdp_nodes = parent_model.get_nodes_by_op_type("StreamingDataflowPartition")
    if not sdp_nodes:
        # Fail with a descriptive message instead of an IndexError from sdp_nodes[0].
        log_message(
            "CreateDataflowPartition produced no StreamingDataflowPartition node; "
            "check that the model was converted to HW layers.",
            level="error",
        )
    if len(sdp_nodes) > 1:
        # Only the first partition is returned; make that visible to the user.
        log_message(
            f"Found {len(sdp_nodes)} StreamingDataflowPartition nodes; "
            "only the first one is returned.",
            level="warning",
        )

    # The partition node stores the path of the child (dataflow) model.
    sdp_inst = getCustomOp(sdp_nodes[0])
    dataflow_model_filename = sdp_inst.get_nodeattr("model")
    return ModelWrapper(dataflow_model_filename)

In [None]:
# dataflow_partitioning saves the parent graph itself and returns the dataflow
# (child) model, which is checkpointed separately here.
model = dataflow_partitioning(
    input_data_model=model,
    save_name=set_onnx_checkpoint(Project_Info, "Dataflow Partition Parent Model"),
)
model.save(set_onnx_checkpoint(Project_Info, "Dataflow Partition Streaming Model"))

In [None]:
from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers

def specialize_layers_transform(input_specialize_model, board_name, save_name):
    """
    Specialize a dataflow model's HW layers for the FPGA part of a PYNQ board.

    ``board_name`` is looked up in ``pynq_part_map`` to obtain the FPGA part.
    An unknown board therefore fails at that lookup. ``SpecializeLayers`` is
    applied for that part, and the result is saved.

    Parameters:
        input_specialize_model (ModelWrapper): The dataflow model to transform.
        board_name (str): PYNQ board name, used as the key into ``pynq_part_map``.
        save_name (str): File path the specialized model is written to.

    Returns:
        ModelWrapper: The specialized dataflow model.
    """
    target_part = pynq_part_map[board_name]
    specialized_model = input_specialize_model.transform(SpecializeLayers(target_part))
    specialized_model.save(save_name)
    return specialized_model

In [None]:
# Specialize HW layers for the configured board and checkpoint the result.
model = specialize_layers_transform(
    input_specialize_model=model,
    board_name=Project_Info['Board_name'],
    save_name=set_onnx_checkpoint(
        Project_Info, f"Specialize Model Layers to {Project_Info['Board_name']}"
    ),
)

In [None]:
def folding_transform(input_folding_model, save_name, folding_config=None):
    """
    Apply a folding configuration to the MVAU_hls and ConvolutionInputGenerator_rtl
    layers of a specialized model, give nodes unique names, and save the result.

    Parameters:
        input_folding_model (ModelWrapper): The specialized model to transform.
        save_name (str): File path the folded model is written to.
        folding_config (list[tuple[int, int, list[int]]] | None): One
            ``(PE, SIMD, inFIFODepths)`` entry per ``MVAU_hls`` node, in the
            order returned by ``get_nodes_by_op_type``. The SIMD of the i-th entry
            is also applied to the i-th ``ConvolutionInputGenerator_rtl`` node.
            When ``None``, a built-in table is used. NOTE(review): that table
            appears to match the FINN CNV end-to-end example; confirm it fits the
            current model.

    Returns:
        ModelWrapper: The folded dataflow model.

    Raises:
        ValueError: If the model has more ConvolutionInputGenerator_rtl nodes
            than ``folding_config`` has entries.
    """
    if folding_config is None:
        # Built fresh on every call so callers can never mutate a shared default.
        folding_config = [
            (16, 3, [128]),
            (32, 32, [128]),
            (16, 32, [128]),
            (16, 32, [128]),
            (4, 32, [81]),
            (1, 32, [2]),
            (1, 4, [2]),
            (1, 8, [128]),
            (5, 1, [3]),
        ]

    # Apply PE / SIMD / input FIFO depths to the fully connected (MVAU) layers.
    fc_layers = input_folding_model.get_nodes_by_op_type("MVAU_hls")
    if len(fc_layers) != len(folding_config):
        # zip() below truncates silently; surface the mismatch so unconfigured
        # layers or unused config entries are noticed.
        log_message(
            f"folding_config has {len(folding_config)} entries but the model has "
            f"{len(fc_layers)} MVAU_hls layers; only the first "
            f"{min(len(fc_layers), len(folding_config))} are configured.",
            level="warning",
        )
    for fcl, (pe, simd, ififodepth) in zip(fc_layers, folding_config):
        fcl_inst = getCustomOp(fcl)
        fcl_inst.set_nodeattr("PE", pe)
        fcl_inst.set_nodeattr("SIMD", simd)
        fcl_inst.set_nodeattr("inFIFODepths", ififodepth)

    # Each sliding-window generator takes the SIMD of the config entry at the same
    # index. Presumably the i-th SWG feeds the i-th MVAU; confirm for other topologies.
    swg_layers = input_folding_model.get_nodes_by_op_type("ConvolutionInputGenerator_rtl")
    if len(swg_layers) > len(folding_config):
        log_message(
            f"Model has {len(swg_layers)} ConvolutionInputGenerator_rtl layers but "
            f"folding_config only has {len(folding_config)} entries.",
            level="error",
        )
    for swg, (_, simd, _) in zip(swg_layers, folding_config):
        getCustomOp(swg).set_nodeattr("SIMD", simd)

    # Apply unique node names to all nodes and save the transformed model.
    input_folding_model = input_folding_model.transform(GiveUniqueNodeNames())
    input_folding_model.save(save_name)

    return input_folding_model


In [None]:
# Fold the specialized model and checkpoint the result.
model = folding_transform(
    input_folding_model=model,
    save_name=set_onnx_checkpoint(Project_Info, "Folded Model"),
)

In [None]:
from finn.transformation.fpgadataflow.make_zynq_proj import ZynqBuild

def zynq_build_transform(input_zynq_model, save_name, brd_name, target_clk_ns=10):
    """
    Run FINN's ZynqBuild on a folded model for a PYNQ board and save the result.

    Parameters:
        input_zynq_model (ModelWrapper): Folded model to transform.
        save_name (str): File path the transformed model is written to.
        brd_name (str): Name of the PYNQ board to target (passed as ``platform``).
        target_clk_ns (float): Target clock period in nanoseconds, passed to
            ``ZynqBuild`` as ``period_ns``. Defaults to 10 ns (100 MHz).

    Returns:
        ModelWrapper: The model after applying ZynqBuild.

    Raises:
        ValueError: If ``target_clk_ns`` is not positive.
    """
    # Reject an impossible clock period up front rather than deep inside the build.
    if target_clk_ns <= 0:
        log_message(
            f"target_clk_ns must be positive, got {target_clk_ns}.",
            level="error",
        )

    # Apply ZynqBuild transformation
    input_zynq_model = input_zynq_model.transform(
        ZynqBuild(platform=brd_name, period_ns=target_clk_ns)
    )

    # Save the transformed model
    input_zynq_model.save(save_name)

    return input_zynq_model


In [None]:
from finn.transformation.fpgadataflow.make_pynq_driver import MakePYNQDriver

def pynq_driver_transform(input_driver_model, save_name):
    """
    Generate a PYNQ driver for a ZynqBuild model and save the resulting model.

    ``MakePYNQDriver`` is applied with the ``"zynq-iodma"`` platform string.

    Parameters:
        input_driver_model (ModelWrapper): Model produced by ZynqBuild.
        save_name (str): File path the transformed model is written to.

    Returns:
        ModelWrapper: The model after driver generation.
    """
    driver_model = input_driver_model.transform(MakePYNQDriver("zynq-iodma"))
    driver_model.save(save_name)
    return driver_model


In [None]:
# Build the bitstream for the configured board, then generate the PYNQ driver.
model = zynq_build_transform(
    input_zynq_model=model,
    save_name=set_onnx_checkpoint(Project_Info, "Zynq Build"),
    brd_name=Project_Info['Board_name'],
)
model = pynq_driver_transform(
    input_driver_model=model,
    save_name=set_onnx_checkpoint(Project_Info, "Pynq Driver"),
)