
CNN Synthesis fail : ERROR: [XFORM 203-504] Stop unrolling loop 'Product1' (firmware/nnet_utils/nnet_dense_latency.h:37) in function 'nnet::dense_latency<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config6>' because it may cause large runtime and excessive memory usage due to increase in code size. Please avoid unrolling the loop or form sub-functions for code in the loop body. ERROR: [HLS 200-70] Pre-synthesis failed. #1013

Open
ZWhimsi opened this issue May 14, 2024 · 13 comments

@ZWhimsi

ZWhimsi commented May 14, 2024

Environment

  • OS: Ubuntu 16.04.7
  • Python: 3.9.12
  • hls4ml: 0.8.1

For all the pip dependencies, refer to the env.txt file.
env.txt

Quick Summary

I attempted to synthesize some CNN models and all attempts failed. After removing layers one by one, I identified Conv2D as the problematic layer due to excessive memory usage and large runtime. MLPs, however, work perfectly. I attempted synthesis with all NN configurations using the following scripts.

Test_script.txt
example.py.txt

I hope the conversion from .py or .yaml to .txt will not introduce too many artefacts or bugs.

Details

Steps to Reproduce

Warning: All provided files are in .txt format; you will need to change the extension to .py or .yml for the process to work. Also, remember to modify the path to your own file locations.

To replicate the bug, follow these steps:

  1. Clone the hls4ml repository from GitHub.
  2. Check out the master branch at the specified commit hash.
  3. Install all dependencies from env.txt.
  4. Rename Test_script.txt to Test_script.py and execute it to create the models to be synthesized.
  5. Modify the yaml.txt configuration file to include the CNN model you wish to synthesize.
  6. Run the example.py script to attempt synthesis.

Step 1: Install dependencies from env.txt

pip install -r env.txt

Step 2: Execute Test_script to create the models

python Test_script.py

Step 3: Update yml configuration to include the CNN model

Update this YAML file to point to the desired NN (H5 + JSON):
yaml.txt
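
For reference, here is a sketch of what yaml.txt looks like for this run. The values simply mirror the configuration dict that example.py prints in the log below, so treat it as illustrative rather than a canonical template:

Backend: Vivado
ClockPeriod: 5
IOType: io_parallel
KerasJSON: ../models_test/model_simple_cnn_cifar10.json
KerasH5: ../models_test/model_simple_cnn_cifar10.h5
OutputDir: ../my-hls-test
ProjectName: myproject
Part: xczu9eg-ffvb1156-2-e
HLSConfig:
  Model:
    Precision: ap_fixed<2,1>
    ReuseFactor: 1000000
    Strategy: Resources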

Step 4: Run the example script

python example.py

Expected behavior

It should synthesize the CNN.

Actual behavior

(hls4ml_env) ~/Bureau/Stage_2024/Scripts_python$ python example.py
2024-05-14 09:41:33.489814: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-05-14 09:41:34.586838: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
{'Backend': 'Vivado', 'ClockPeriod': 5, 'HLSConfig': {'Model': {'Precision': 'ap_fixed<2,1>', 'ReuseFactor': 1000000, 'Strategy': 'Resources'}}, 'IOType': 'io_parallel', 'KerasJSON': '../models_test/model_simple_cnn_cifar10.json', 'KerasH5': '../models_test/model_simple_cnn_cifar10.h5', 'OutputDir': '../my-hls-test', 'Part': 'xczu9eg-ffvb1156-2-e', 'ProjectName': 'myproject', 'Stamp': 'eAAc27c6', 'Version': '1.0.0'}
{'Backend': 'Vivado', 'ClockPeriod': 5, 'HLSConfig': {'Model': {'Precision': 'ap_fixed<2,1>', 'ReuseFactor': 1000000, 'Strategy': 'Resources'}}, 'IOType': 'io_parallel', 'KerasJSON': '../models_test/model_simple_cnn_cifar10.json', 'KerasH5': '../models_test/model_simple_cnn_cifar10.h5', 'OutputDir': '../my-hls-test', 'Part': 'xczu9eg-ffvb1156-2-e', 'ProjectName': 'myproject', 'Stamp': 'eAAc27c6', 'Version': '1.0.0'}
Interpreting Sequential
Topology:
Layer name: input_layer_6, layer type: InputLayer, input shapes: [None], output shape: [None, 32, 32, 3]
Layer name: conv2d_8, layer type: Conv2D, input shapes: [[None, 32, 32, 3]], output shape: [None, 30, 30, 32]
Layer name: max_pooling2d_8, layer type: MaxPooling2D, input shapes: [[None, 30, 30, 32]], output shape: [None, 15, 15, 32]
Layer name: flatten_6, layer type: Reshape, input shapes: [[None, 15, 15, 32]], output shape: [None, 7200]
Layer name: dense_11, layer type: Dense, input shapes: [[None, 7200]], output shape: [None, 10]
Creating HLS model
WARNING: Layer conv2d_8 requires "dataflow" pipeline style. Switching to "dataflow" pipeline style.
/home/pierric/anaconda3/envs/hls4ml_env/lib/python3.9/site-packages/hls4ml/backends/fpga/passes/fix_softmax_table_size.py:34: UserWarning: Softmax layer dense_11_softmax table size is too large for inputbitwidth 2. Setting table size to 4.To avoid this warning, please increase input bitwidth ordecrease table size.
  warnings.warn(
Writing HLS project
Done

****** Vivado(TM) HLS - High-Level Synthesis from C, C++ and SystemC v2018.3 (64-bit)
  **** SW Build 2405991 on Thu Dec  6 23:36:41 MST 2018
  **** IP Build 2404404 on Fri Dec  7 01:43:56 MST 2018
	** Copyright 1986-2018 Xilinx, Inc. All Rights Reserved.

source /media/pierric/DATA/Viv18.3/Vivado/2018.3/scripts/vivado_hls/hls.tcl -notrace
INFO: [HLS 200-10] Running '/media/pierric/DATA/Viv18.3/Vivado/2018.3/bin/unwrapped/lnx64.o/vivado_hls'
INFO: [HLS 200-10] On os Ubuntu 16.04.7 LTS
INFO: [HLS 200-10] In directory '/home/pierric/Bureau/Stage_2024/my-hls-test'
INFO: [HLS 200-10] Creating and opening project '/home/pierric/Bureau/Stage_2024/my-hls-test/myproject_prj'.
INFO: [HLS 200-10] Adding design file 'firmware/myproject.cpp' to the project
INFO: [HLS 200-10] Adding test bench file 'myproject_test.cpp' to the project
INFO: [HLS 200-10] Adding test bench file 'firmware/weights' to the project
INFO: [HLS 200-10] Adding test bench file 'tb_data' to the project
INFO: [HLS 200-10] Creating and opening solution '/home/pierric/Bureau/Stage_2024/my-hls-test/myproject_prj/solution1'.
INFO: [XFORM 203-101] Allowed max sub elements number after partition is 4096.
INFO: [XFORM 203-1161] The maximum of name length is set into 80.
INFO: [HLS 200-10] Setting target device to 'xczu9eg-ffvb1156-2-e'
INFO: [SYN 201-201] Setting up clock 'default' with a period of 5ns.
INFO: [SYN 201-201] Setting up clock 'default' with an uncertainty of 0.625ns.
***** C SIMULATION *****
INFO: [SIM 211-2] *************** CSIM start ***************
INFO: [SIM 211-4] CSIM will launch GCC as the compiler.
   Compiling ../../../../myproject_test.cpp in debug mode
   Compiling ../../../../firmware/myproject.cpp in debug mode
   Generating csim.exe
INFO: Unable to open input/predictions file, using default input.
-0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5
INFO: Saved inference results to file: tb_data/csim_results.log
INFO: [SIM 211-1] CSim done with 0 errors.
INFO: [SIM 211-3] *************** CSIM finish ***************
***** C SIMULATION COMPLETED IN 0h0m14s *****
***** C/RTL SYNTHESIS *****
INFO: [SCHED 204-61] Option 'relax_ii_for_timing' is enabled, will increase II to preserve clock frequency constraints.
INFO: [HLS 200-10] Analyzing design file 'firmware/myproject.cpp' ...
WARNING: [HLS 200-40] In file included from firmware/myproject.cpp:1:
In file included from firmware/myproject.cpp:4:
In file included from firmware/parameters.h:12:
In file included from firmware/nnet_utils/nnet_conv2d.h:6:
firmware/nnet_utils/nnet_conv2d_resource.h:82:31: warning: comparison of unsigned expression >= 0 is always true [-Wtautological-compare]
            	if (i_acc + 1 >= multscale) {
                	~~~~~~~~~ ^  ~~~~~~~~~
firmware/nnet_utils/nnet_conv2d.h:52:9: note: in instantiation of function template specialization 'nnet::conv_2d_resource_cl<ap_fixed<2, 1, 5, 3, 0>, ap_fixed<2, 1, 5, 3, 0>, config2>' requested here
    	conv_2d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
    	^
firmware/myproject.cpp:37:2: note: in instantiation of function template specialization 'nnet::conv_2d_cl<ap_fixed<2, 1, 5, 3, 0>, ap_fixed<2, 1, 5, 3, 0>, config2>' requested here
 nnet::conv_2d_cl<input_t, layer2_t, config2>(input_layer_6, layer2_out, w2, b2);
 ^
1 warning generated.
WARNING: [HLS 214-113] Either use an argument of the function or declare the variable inside the dataflow loop body: firmware/myproject.cpp:37:74
WARNING: [HLS 200-471] Dataflow form checks found 1 issue(s) in file firmware/myproject.cpp
INFO: [HLS 200-111] Finished Linking Time (s): cpu = 00:01:11 ; elapsed = 00:01:20 . Memory (MB): peak = 435.098 ; gain = 0.125 ; free physical = 3444 ; free virtual = 6466
INFO: [HLS 200-111] Finished Checking Pragmas Time (s): cpu = 00:01:11 ; elapsed = 00:01:20 . Memory (MB): peak = 435.098 ; gain = 0.125 ; free physical = 3444 ; free virtual = 6466
INFO: [HLS 200-10] Starting code transformations ...
INFO: [HLS 200-489] Unrolling loop 'PixelLoop' (firmware/nnet_utils/nnet_conv2d_latency.h:41) in function 'void nnet::conv_2d_latency_cl<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config2>(FORWARD_REFERENCE*, FORWARD_REFERENCE*, FORWARD_REFERENCE::weight_t*, FORWARD_REFERENCE::bias_t*)' completely with a factor of 1.
INFO: [XFORM 203-603] Inlining function 'nnet::product::mult<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> >::product' into 'nnet::conv_2d_latency_cl<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config2>' (firmware/nnet_utils/nnet_conv2d_latency.h:55).
INFO: [XFORM 203-603] Inlining function 'nnet::product::mult<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> >::product' into 'nnet::dense_latency<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config6>' (firmware/nnet_utils/nnet_dense_latency.h:42).
INFO: [XFORM 203-603] Inlining function 'nnet::conv_2d_latency_cl<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config2>' into 'nnet::conv_2d_cl<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config2>' (firmware/nnet_utils/nnet_conv2d.h:50).
INFO: [XFORM 203-603] Inlining function 'nnet::dense<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config6>' into 'myproject' (firmware/myproject.cpp:50).
INFO: [XFORM 203-603] Inlining function 'nnet::softmax<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, softmax_config7>' into 'myproject' (firmware/myproject.cpp:52).
INFO: [HLS 200-111] Finished Standard Transforms Time (s): cpu = 00:04:36 ; elapsed = 00:04:45 . Memory (MB): peak = 962.977 ; gain = 528.004 ; free physical = 2970 ; free virtual = 6012
INFO: [HLS 200-10] Checking synthesizability ...
INFO: [XFORM 203-602] Inlining function 'nnet::cast<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config2_mult>' into 'nnet::conv_2d_cl<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config2>' (firmware/nnet_utils/nnet_conv2d_latency.h:82->firmware/nnet_utils/nnet_conv2d.h:50) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 4>' into 'nnet::pool_op<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 4, (nnet::Pool_Op)0>' (firmware/nnet_utils/nnet_pooling.h:57) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::pool_op_limit<config4>' into 'nnet::pooling2d_cl<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config4>' (firmware/nnet_utils/nnet_pooling.h:205) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::pad_val<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, (nnet::Pool_Op)0>' into 'nnet::pooling2d_cl<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config4>' (firmware/nnet_utils/nnet_pooling.h:231) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::pool_op<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 4, (nnet::Pool_Op)0>' into 'nnet::pooling2d_cl<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config4>' (firmware/nnet_utils/nnet_pooling.h:247) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::cast<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config6>' into 'nnet::dense_latency<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config6>' (firmware/nnet_utils/nnet_dense_latency.h:66) automatically.
INFO: [XFORM 203-602] Inlining function 'std::exp' into 'nnet::exp_fcn_float' (firmware/nnet_utils/nnet_activation.h:131) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::exp_fcn_float' into 'nnet::init_exp_table<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, softmax_config7>' (firmware/nnet_utils/nnet_activation.h:154) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> >::operator()' into 'nnet::reduce<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 2, nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> > >' (firmware/nnet_utils/nnet_common.h:43) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::reduce<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 2, nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> > >' into 'nnet::reduce<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 4, nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> > >' (firmware/nnet_utils/nnet_common.h:45) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> >::operator()' into 'nnet::reduce<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 4, nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> > >' (firmware/nnet_utils/nnet_common.h:45) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> >::operator()' into 'nnet::reduce<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 8, nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> > >' (firmware/nnet_utils/nnet_common.h:45) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::reduce<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 8, nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> > >' into 'nnet::reduce<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 10, nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> > >' (firmware/nnet_utils/nnet_common.h:45) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::reduce<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 2, nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> > >' into 'nnet::reduce<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 10, nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> > >' (firmware/nnet_utils/nnet_common.h:45) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> >::operator()' into 'nnet::reduce<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 10, nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> > >' (firmware/nnet_utils/nnet_common.h:45) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::reduce<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0>, 2, nnet::Op_add<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0> > >' into 'nnet::reduce<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0>, 4, nnet::Op_add<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0> > >' (firmware/nnet_utils/nnet_common.h:45) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::reduce<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0>, 4, nnet::Op_add<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0> > >' into 'nnet::reduce<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0>, 8, nnet::Op_add<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0> > >' (firmware/nnet_utils/nnet_common.h:45) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::reduce<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0>, 8, nnet::Op_add<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0> > >' into 'nnet::reduce<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0>, 10, nnet::Op_add<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0> > >' (firmware/nnet_utils/nnet_common.h:45) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::reduce<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0>, 2, nnet::Op_add<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0> > >' into 'nnet::reduce<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0>, 10, nnet::Op_add<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0> > >' (firmware/nnet_utils/nnet_common.h:45) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::reduce<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 10, nnet::Op_max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0> > >' into 'nnet::softmax_stable<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, softmax_config7>' (firmware/nnet_utils/nnet_activation.h:239) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::softmax_idx_from_real_val<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, softmax_config7>' into 'nnet::softmax_stable<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, softmax_config7>' (firmware/nnet_utils/nnet_activation.h:254) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::reduce<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0>, 10, nnet::Op_add<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0> > >' into 'nnet::softmax_stable<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, softmax_config7>' (firmware/nnet_utils/nnet_activation.h:262) automatically.
INFO: [XFORM 203-602] Inlining function 'nnet::softmax_idx_from_real_val<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0>, softmax_config7>' into 'nnet::softmax_stable<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, softmax_config7>' (firmware/nnet_utils/nnet_activation.h:265) automatically.
WARNING: [SYNCHK 200-23] firmware/nnet_utils/nnet_activation.h:138: variable-indexed range selection may cause suboptimal QoR.
INFO: [SYNCHK 200-10] 0 error(s), 1 warning(s).
INFO: [HLS 200-111] Finished Checking Synthesizability Time (s): cpu = 00:14:59 ; elapsed = 00:15:08 . Memory (MB): peak = 1010.977 ; gain = 576.004 ; free physical = 2927 ; free virtual = 5989
INFO: [XFORM 203-502] Unrolling all loops for pipelining in function 'nnet::softmax_stable<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, softmax_config7>' (firmware/nnet_utils/nnet_activation.h:217:46).
INFO: [XFORM 203-502] Unrolling all loops for pipelining in function 'nnet::init_invert_table<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0>, softmax_config7>' (firmware/nnet_utils/nnet_activation.h:162:53).
INFO: [XFORM 203-502] Unrolling all loops for pipelining in function 'nnet::init_exp_table<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, softmax_config7>' (firmware/nnet_utils/nnet_activation.h:151:48).
INFO: [XFORM 203-502] Unrolling all loops for pipelining in function 'nnet::dense_latency<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config6>' (firmware/nnet_utils/nnet_dense_latency.h:17:48).
INFO: [XFORM 203-502] Unrolling all loops for pipelining in function 'nnet::pooling2d_cl<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config4>' (firmware/nnet_utils/nnet_pooling.h:202:29).
INFO: [XFORM 203-502] Unrolling all loops for pipelining in function 'nnet::max<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, 4>' (firmware/nnet_utils/nnet_pooling.h:10:28).
INFO: [XFORM 203-502] Unrolling all loops for pipelining in function 'nnet::relu<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, relu_config3>' (firmware/nnet_utils/nnet_activation.h:40:43).
INFO: [XFORM 203-502] Unrolling all sub-loops inside loop 'PartitionLoop' (firmware/nnet_utils/nnet_conv2d_latency.h:35) in function 'nnet::conv_2d_cl<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config2>' for pipelining.
INFO: [HLS 200-489] Unrolling loop 'Loop-1' (firmware/nnet_utils/nnet_activation.h:243) in function 'nnet::softmax_stable<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, softmax_config7>' completely with a factor of 10.
INFO: [HLS 200-489] Unrolling loop 'Loop-2' (firmware/nnet_utils/nnet_activation.h:252) in function 'nnet::softmax_stable<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, softmax_config7>' completely with a factor of 10.
INFO: [HLS 200-489] Unrolling loop 'Loop-3' (firmware/nnet_utils/nnet_activation.h:266) in function 'nnet::softmax_stable<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, softmax_config7>' completely with a factor of 10.
INFO: [HLS 200-489] Unrolling loop 'Loop-1' (firmware/nnet_utils/nnet_activation.h:162) in function 'nnet::init_invert_table<ap_fixed<18, 8, (ap_q_mode)0, (ap_o_mode)0, 0>, softmax_config7>' completely with a factor of 4.
INFO: [HLS 200-489] Unrolling loop 'Loop-1' (firmware/nnet_utils/nnet_activation.h:151) in function 'nnet::init_exp_table<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, softmax_config7>' completely with a factor of 4.
INFO: [HLS 200-489] Unrolling loop 'Product1' (firmware/nnet_utils/nnet_dense_latency.h:37) in function 'nnet::dense_latency<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config6>' completely with a factor of 7200.
ERROR: [XFORM 203-504] Stop unrolling loop 'Product1' (firmware/nnet_utils/nnet_dense_latency.h:37) in function 'nnet::dense_latency<ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, ap_fixed<2, 1, (ap_q_mode)5, (ap_o_mode)3, 0>, config6>' because it may cause large runtime and excessive memory usage due to increase in code size. Please avoid unrolling the loop or form sub-functions for code in the loop body.
ERROR: [HLS 200-70] Pre-synthesis failed.
command 'ap_source' returned error code
	while executing
"source [lindex $::argv 1] "
	("uplevel" body line 1)
	invoked from within
"uplevel \#0 { source [lindex $::argv 1] } "

INFO: [Common 17-206] Exiting vivado_hls at Tue May 14 10:25:02 2024...
CSynthesis report not found.
Vivado synthesis report not found.
Cosim report not found.
Timing report not found.
Found 1 solution(s) in ../my-hls-test/myproject_prj.
Reports for solution "solution1":

C SIMULATION RESULT:
INFO: [SIM 2] *************** CSIM start ***************
INFO: [SIM 4] CSIM will launch GCC as the compiler.
   Compiling ../../../../myproject_test.cpp in debug mode
   Compiling ../../../../firmware/myproject.cpp in debug mode
   Generating csim.exe
INFO: Unable to open input/predictions file, using default input.
-0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5
INFO: Saved inference results to file: tb_data/csim_results.log
INFO: [SIM 1] CSim done with 0 errors.
INFO: [SIM 3] *************** CSIM finish ***************

Synthesis report not found.
Co-simulation report not found.

Possible fix

I saw someone mention a fix for a similar problem in hls4ml version 0.3.0, but I did not try it because I expect the current repository to be stable.

@ZWhimsi ZWhimsi added the bug label May 14, 2024
@bo3z
Contributor

bo3z commented May 14, 2024

This is not a bug - when loops are unrolled, operations are parallelised. However, there is a limit to how much can be executed in parallel (due to available resources, critical path and sometimes compiler issues with scheduling).

To address this problem, use the Resource strategy and consider increasing the reuse factor (a larger reuse factor corresponds to less parallelism). For more details on the reuse factor, see the tutorial: https://github.com/fastmachinelearning/hls4ml-tutorial/blob/main/part2_advanced_config.ipynb.

In general with Vivado HLS (hls4ml's Vivado backend) the limit is 4,096, so the reuse factor should be at least number_of_layer_multiplications / 4,096, where number_of_layer_multiplications typically corresponds to the total number of weights in the layer.
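
As a concrete example with the model from the log: the final dense layer maps 7,200 inputs to 10 outputs, roughly 72,000 multiplications, so the reuse factor would need to be at least 72,000 / 4,096 ≈ 18 (and, as discussed further down, it also has to divide the number of multiplications evenly, so for example 18 or 20 would both qualify here).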

Alternatively, you could also consider using Vitis HLS, which should have better parallelism behaviour (though fully unrolling is still unlikely to work).
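
A minimal sketch of setting this through the Python API instead of the YAML file could look like the following; the paths, part number and reuse factor are placeholders taken from the log above, and the functions used are the ones covered in the linked tutorial:

import hls4ml
from tensorflow.keras.models import model_from_json

# Load the Keras model (JSON + H5), as done in example.py
model = model_from_json(open('../models_test/model_simple_cnn_cifar10.json').read())
model.load_weights('../models_test/model_simple_cnn_cifar10.h5')

# Model-level configuration: Resource strategy with a modest reuse factor
config = hls4ml.utils.config_from_keras_model(model, granularity='model')
config['Model']['Strategy'] = 'Resource'
config['Model']['ReuseFactor'] = 72

hls_model = hls4ml.converters.convert_from_keras_model(
    model,
    hls_config=config,
    io_type='io_stream',  # io_stream is generally recommended for CNNs (see below)
    output_dir='../my-hls-test',
    part='xczu9eg-ffvb1156-2-e',
)
hls_model.build(csim=False, synth=True)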

@ZWhimsi
Author

ZWhimsi commented May 14, 2024

Hi, thank you for your response.

However, I encountered this issue with a reuse factor of 1,000,000, the "Resources" strategy, and an approximately 10,000-parameter CNN. If these settings are not viable, what can I do to resolve this issue?

@bo3z
Contributor

bo3z commented May 14, 2024

Make sure to use io_stream as the io_type, and inspect the config YAML and the generated parameters.h/defines.h files to ensure that the properties you set (reuse factor, strategy) were actually propagated. A 10,000-parameter model should work.

@bo3z
Contributor

bo3z commented May 14, 2024

Furthermore, there is a typo - the strategy is Resource not Resources :)

@ZWhimsi
Author

ZWhimsi commented May 14, 2024

Yes, my bad. I corrected it. Is there a way to estimate the complexity of the process for a network to avoid attempting 'Latency' mode when it is clearly infeasible? A 10K-parameter CNN is relatively small, and I'd like to understand what kind of network would require caution when unrolling and optimizing for performance.

@bo3z
Contributor

bo3z commented May 14, 2024

The HLS will fail automatically when there are more than 4,096 unrolls (i.e. parallel multiplications). This parameter is hard-coded and generally not exposed in hls4ml. It was determined heuristically; the conclusion was that beyond 4,096 the HLS compiler has issues scheduling and completing the compilation. You can, however, modify this variable in build_project.tcl, and there are cases where 8,192 or 16,384 could compile (but take longer).

So as long as every layer has fewer multiplications than the threshold, you will not get this error. Looking at the error you posted, the loop in the last dense layer is being unrolled with a factor of 7,200, above the threshold. In this case you can either:

  1. Try increasing the factor from 4,096 to 8,192 in build_project.tcl OR
  2. Keep Latency for all layers except the dense layer, dense_11, where you can use Resource and set the reuse factor to, for example, 2 (this should already be okay, but keep in mind the number of DSPs is limited if you are using them). For more information on tuning individual layers see: https://github.com/fastmachinelearning/hls4ml-tutorial/blob/main/part7a_bitstream.ipynb. You can modify the strategy, precision and reuse factor for each layer individually; a sketch follows below.
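
A rough sketch of option 2 through the Python API, where model is the Keras model loaded as in the earlier sketch; the layer name comes from the topology printed in your log, and the numbers are only placeholders:

import hls4ml

# 'name' granularity exposes a LayerName section so each layer can be tuned individually
config = hls4ml.utils.config_from_keras_model(model, granularity='name')
config['Model']['Strategy'] = 'Latency'
config['LayerName']['dense_11']['Strategy'] = 'Resource'
config['LayerName']['dense_11']['ReuseFactor'] = 2

hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=config, output_dir='../my-hls-test', part='xczu9eg-ffvb1156-2-e'
)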

@ZWhimsi
Author

ZWhimsi commented May 14, 2024

Thanks for your help!

@ZWhimsi
Author

ZWhimsi commented May 14, 2024

The reuse factor is capped at the number of parameters for each layer, correct? I have valid reuse factors ranging from 1, 2, ..., up to 18,432 for a Conv2D layer with input (15, 15, 32) and output (13, 13, 64). However, why is 1 considered a valid reuse factor if the number of parameters in the layer divided by the reuse factor is greater than 4,096?

@ZWhimsi
Author

ZWhimsi commented May 14, 2024

Additionally, I'm curious why the synthesis fails in "Latency" mode even with a large reuse factor like 1,000,000. Wouldn't it be possible to select the maximum reuse factor for each layer instead?

@bo3z
Contributor

bo3z commented May 14, 2024

The reuse factor is capped at the number of parameters for each layer, correct? I have valid reuse factors ranging from 1, 2, ..., up to 18,432 for a Conv2D layer with input (15, 15, 32) and output (13, 13, 64). However, why is 1 considered a valid reuse factor if the number of parameters in the layer divided by the reuse factor is greater than 4,096?

'valid' here means the reuse factor has to divide the number of parameters/multiplications in the layer. This is to ensure QoR: otherwise the last clock cycle would have fewer multiplications than the others, which leads to unequal load, unused hardware, scheduling overhead, etc. The idea is that in every clock cycle you do the same number of multiplications, namely total_layer_mult / layer_reuse_factor.

On the other hand, 4,096 is a Vivado HLS concept (not hls4ml), so if you use a different backend, e.g. Quartus targeting Intel boards, it might have a different limit. When you unroll more than 4,096 the reuse factor is still 'valid'; it just means the HLS compiler (not hls4ml) has issues scheduling such high parallelism.
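
To make the two constraints concrete, a small helper along these lines lists the 'valid' reuse factors for a layer and flags the ones that still exceed the Vivado HLS unroll limit. This is only an illustration of the rule described above, not hls4ml code:

def reuse_factor_options(n_mult, unroll_limit=4096):
    """List reuse factors that divide n_mult and whether the resulting
    per-cycle parallelism (n_mult / rf) stays under the HLS unroll limit."""
    return [(rf, n_mult // rf <= unroll_limit)
            for rf in range(1, n_mult + 1) if n_mult % rf == 0]

# The Conv2D layer mentioned above has 3*3*32*64 = 18,432 kernel multiplications,
# so rf = 1 is 'valid' (it divides 18,432 evenly) even though 18,432 parallel
# products exceed the 4,096 limit and the HLS compiler refuses to unroll that far.
for rf, fits in reuse_factor_options(18432)[:6]:
    print(rf, 'within unroll limit' if fits else 'exceeds unroll limit')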

Additionally, I'm curious why the synthesis fails in "Latency" mode even with a large reuse factor like 1,000,000. Wouldn't it be possible to select the maximum reuse factor for each layer instead?

Latency has a different implementation. All the loops are unrolled but the multiplications are limited using another pragma. Therefore, the design will be scheduled such that it completes in reuse_factor clock cycles (plus some constant offset), but the loops are treated as unrolled in the HLS code.

@ZWhimsi
Author

ZWhimsi commented May 14, 2024

Thank you! Can you provide more details on how the Latency mode works? I read the paper and the documentation but didn't find the information I was looking for. Is there a place where I can find answers without having to delve deeply into the code?

@bo3z
Contributor

bo3z commented May 14, 2024

Unfortunately, there is no more documentation than the tutorials, website and papers.

Please have a look here: https://github.com/fastmachinelearning/hls4ml/blob/main/hls4ml/templates/vivado/nnet_utils/nnet_dense_latency.h. The code is short and the most important part of it is the pragma: pragma HLS allocation (https://docs.amd.com/r/en-US/ug1399-vitis-hls/pragma-HLS-allocation).

In the Latency strategy, all model weights are stored in registers and all the loops (there are only 2) are fully unrolled. However, the implementation still takes more than one clock cycle, determined by the reuse factor. Typically the latency is reuse_factor clock cycles plus some overhead (invoking the layer, adding the bias; typically 2 clock cycles).

In this case, the number of parallel multiplications is limited through the pragma HLS allocation, so that at any time at most n_in * n_out / reuse_factor multiplications occur in parallel. However, since the loops are unrolled and the weights are in registers, the HLS compiler will aim to balance expressions to minimise latency/resources. Therefore, there is no guarantee of the order in which the weights will be processed.
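
As a rough worked example with the dense_11 layer from the log (7,200 inputs x 10 outputs, roughly 72,000 multiplications): with a reuse factor of 25 the allocation pragma caps the design at about 72,000 / 25 = 2,880 multiplications in flight at any time, and the layer takes roughly 25 clock cycles plus the small constant overhead mentioned above. The loops themselves are still written as fully unrolled, though, which is why the unroll error appears in Latency mode even with an enormous reuse factor.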

Importantly: both convolutional and recurrent layers are implemented using (i.e. require) matrix multiplication, so they build on top of the dense layer, which is essentially a matrix multiplication.

@ZWhimsi ZWhimsi closed this as completed May 15, 2024
@ZWhimsi
Author

ZWhimsi commented May 16, 2024

I am working on a script designed to forecast FPGA resource usage (specifically DSPs and BRAMs) for hls4ml. The goal is to estimate resource usage and throughput without synthesizing every model in HLS. My script also checks each layer to see whether the number of parameters and multiply-accumulate operations (MACCs) divided by the reuse factor exceeds 4,096, issuing a warning if this threshold is surpassed. @vloncar @bo3z I know you have already answered similar questions; if you have time to look at this, it would be pretty awesome :)

However, when I run this script on CNNs with 10k, 100k, and 1M parameters that fail to synthesize due to extensive unrolling, I do not receive any warnings. The only successful synthesis method I've found is using IO streams and resource constraints.

Can you help me understand what I might be missing in my script that causes it to miss these warnings? Additionally, I would appreciate any suggestions or improvements to enhance my script.

import argparse
import json
import numpy as np
import torch
from tensorflow.keras.models import model_from_json, load_model
from tensorflow.keras.layers import Dense, Conv2D, Flatten, LSTM, BatchNormalization, GRU, Activation, MaxPooling2D
import math
from tabulate import tabulate
from colorama import Fore, Style

def format_table_with_colors(data):
    """
    Formats and colors the data table for better visibility when displayed.
    Uses Colorama library colors to highlight keys.
    """
    formatted_data = []
    for key, value in data.items():
        formatted_data.append([Fore.GREEN + key + Style.RESET_ALL, value])
    return formatted_data

# Configuration dictionary to handle all global settings
config = {
    "freq_mhz": 250,
    "DSP_max": 2520,
    "DSP_used": 0,
    "reuse_factor": 2,
    "mem_bandwidth_gbps": 10,
    "BRAM_capacity_bits": 1024 * 36,
    "BRAM_used": 0,
    "BRAM_max": 32.1 * 10**6,
    "weight_bits": 32,
    "info_layer": {},
    "throughput": float('inf'),
    "batch_size": 32,
    "io_type": "stream"
}

def estimations_keras(model):
    """
    Estimates FPGA resource needs and performance metrics for each layer in a Keras model.
    Calculates DSP and BRAM requirements and adjusts throughput based on resource utilization.
    """
    total_macc = sum(macc_per_layer_keras(layer) for layer in model.layers)
    for layer in model.layers:
        if config["DSP_used"] < config["DSP_max"]:
            DSP_needed = math.ceil(config["info_layer"][layer.name]['macc'] / config["reuse_factor"])
            config["DSP_used"] += DSP_needed
            if config["info_layer"][layer.name]['macc'] != 0:
                new_throughput = math.floor(DSP_needed * config["freq_mhz"] * 10**6 / config["info_layer"][layer.name]['macc'])
                if config["throughput"] > new_throughput:
                    config["throughput"] = new_throughput

        if config["BRAM_used"] < config["BRAM_max"] / config["BRAM_capacity_bits"]:
            config["BRAM_used"] += math.ceil(config["info_layer"][layer.name]['params'] * config["weight_bits"] / config["BRAM_capacity_bits"])
            if config["io_type"] == "stream":
                config["BRAM_used"] += math.ceil(np.prod(layer.output.shape[1:3]) * config["weight_bits"] / (config["BRAM_capacity_bits"] * config["reuse_factor"]))

    formatted_table = format_table_with_colors({'Throughput': config["throughput"], 'DSP': config["DSP_used"], 'BRAM': config["BRAM_used"]})
    print(tabulate(formatted_table))

def macc_per_layer_keras(layer):
    """
    Calculates the multiply-accumulate operations (MACCs) for each type of layer in a Keras model.
    """
    if isinstance(layer, (Conv2D, Dense, LSTM, GRU, BatchNormalization)):
        return layer_specific_maccs(layer)
    elif isinstance(layer, MaxPooling2D):
        config["info_layer"][layer.name] = {'macc': 0, 'params': 0}
        return 0
    elif isinstance(layer, Activation) or "activation" in layer.__class__.__name__.lower():
        return estimate_activation_maccs(layer)
    elif isinstance(layer, Flatten):
        config["info_layer"][layer.name] = {'macc': 0, 'params': 0}
        return 0
    return 0

def layer_specific_maccs(layer):
    """
    Detail specific calculations per layer type, including Conv2D, Dense, LSTM, GRU, BatchNormalization.
    """
    output_area = np.prod(layer.output.shape[1:3])
    if isinstance(layer, Conv2D):
        macc_per_output = np.prod(layer.kernel_size) * layer.filters
    elif isinstance(layer, Dense):
        macc_per_output = layer.units
    elif isinstance(layer, LSTM) or isinstance(layer, GRU):
        input_dims = layer.input.shape[-1]
        num_units = layer.units
        macc_per_output = 4 * num_units if isinstance(layer, LSTM) else 3 * num_units
        macc_per_output *= (input_dims + num_units + 1)  # +1 for bias

    elif isinstance(layer, BatchNormalization):
        return 2 * output_area  # Normalize and scale

    total_macc = output_area * macc_per_output
    if layer.use_bias:
        total_macc += output_area
    config["info_layer"][layer.name] = {'macc': total_macc, 'params': layer.count_params()}
    return total_macc

def estimate_activation_maccs(layer):
    """
    Estimates the MACCs for activation functions based on their complexity.
    """
    output_elements = np.prod(layer.output.shape[1:])
    if layer.activation.__name__ in ['softmax', 'sigmoid', 'tanh']:
        return 5 * output_elements  # Approximation for expensive operations
    elif layer.activation.__name__ == 'relu':
        return 0  # Simple max(0, x)
    return 0  # Direct computation for other activations

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Estimate performance metrics for a Keras or PyTorch model on FPGA")
    parser.add_argument("--model_file", help="Path to the model file (JSON + H5)")
    parser.add_argument("--framework", choices=["keras", "pytorch"], default="keras", help="Choose the framework (Keras or PyTorch)")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for processing")
    parser.add_argument("--quantization", type=int, default=32, help="Number of bits needed to represent a weight")
    parser.add_argument("--reuse_factor", type=int, default=2, help="Number of DSP reuses")
    parser.add_argument("--frequency", type=int, default=250, help="Operating frequency in MHz")
    parser.add_argument("--io_type", default="stream", help="IO type: stream or parallel")
    args = parser.parse_args()

    # Example of how to load and use the model depending on the framework
    if args.framework == "keras":
        model = model_from_json(open(args.model_file + ".json").read())
        try:
            model.load_weights(args.model_file + ".h5")
        except OSError:
            model.load_weights(args.model_file + "_weights.h5")
        model.summary()
        config.update({
            "batch_size": args.batch_size,
            "reuse_factor": args.reuse_factor,
            "freq_mhz": args.frequency,
            "weight_bits": args.quantization,
            "io_type": args.io_type
        })
        estimations_keras(model)
    else:
        # Placeholder for PyTorch functionality
        print("PyTorch estimation functionality not implemented.")

@ZWhimsi ZWhimsi reopened this May 16, 2024