
ERROR: [XFORM 203-504] Stop unrolling loop 'MultLoop' #904

Open
behnamarefy opened this issue Oct 31, 2023 · 8 comments

@behnamarefy

Hello, I'm working on a CNN and trying to implement it with hls4ml 0.7.1 and Vivado 2019.2, but with every reuse factor I try I get this error:

ERROR: [XFORM 203-504] Stop unrolling loop 'MultLoop' (firmware/nnet_utils/nnet_dense_resource.h:52) in function 'nnet::conv_2d_cl<nnet::array<ap_fixed<22, 6, (ap_q_mode)5, (ap_o_mode)3, 0>, 256u>, nnet::array<ap_fixed<22, 6, (ap_q_mode)5, (ap_o_mode)3, 0>, 128u>, config7>' because it may cause large runtime and excessive memory usage due to increase in code size. Please avoid unrolling the loop or form sub-functions for code in the loop body.

Here is my code:

nb_epoch = 200     # number of epochs to train on
batch_size = 1024  # training batch size
input_shape=[2,128]

filters_per_conv_layer = [256, 128, 64 ]
neurons_per_dense_layer = [ 64]

x = x_in = Input(input_shape + [1])

for i, f in enumerate(filters_per_conv_layer):
    print(('Adding convolutional block {} with N={} filters').format(i, f))
    x = Conv2D(
        int(f),
        kernel_size=(2, 8),
        strides=(1, 1),
        padding='same',
        kernel_initializer='lecun_uniform',
        # kernel_regularizer=l1(0.0001),
        # use_bias=False,
        name='conv_{}'.format(i),
    )(x)
    x = BatchNormalization(name='bn_conv_{}'.format(i))(x)
    x = Activation('relu', name='conv_act_%i' % i)(x)
    x = MaxPooling2D(pool_size=(1, 2), name='pool_{}'.format(i))(x)
x = Flatten()(x)

for i, n in enumerate(neurons_per_dense_layer):
    print(('Adding dense block {} with N={} neurons').format(i, n))
    x = Dense(n, kernel_initializer='lecun_uniform',  name='dense_%i' % i)(x)
    x = BatchNormalization(name='bn_dense_{}'.format(i))(x)
    x = Activation('relu', name='dense_act_%i' % i)(x)
x = Dense(11, name='output_dense')(x)
x_out = Activation('softmax', name='output_softmax')(x)
    

model = Model(inputs=x_in, outputs=x_out)
model.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer='adam')

model.summary()

I pruned this network and then loaded the pruned model; here is the whole code:

import matplotlib.pyplot as plt
import numpy as np
import time
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from tensorflow_model_optimization.sparsity.keras import strip_pruning
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper

from qkeras.utils import _add_supported_quantized_objects
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.regularizers import l1
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import ZeroPadding2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import MaxPool2D
import os

from tensorflow.keras.models import Model
os.environ['PATH'] = '/tools/Xilinx/Vivado/2019.2/bin:' + os.environ['PATH']
co = {}
_add_supported_quantized_objects(co)
co['PruneLowMagnitude'] = pruning_wrapper.PruneLowMagnitude

model_pruned = tf.keras.models.load_model('pruned_behnam2.h5', custom_objects=co)
model_pruned = strip_pruning(model_pruned)
LOSS = tf.keras.losses.CategoricalCrossentropy()
OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=3e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=True)

model_pruned.compile(loss=LOSS, optimizer=OPTIMIZER, metrics=["accuracy"])

model_keras = model = tf.keras.models.load_model('behnam2.h5')
model_keras.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer='adam')

model_keras.summary()
model_pruned.summary()
for layer in model_pruned.layers:
    if layer.__class__.__name__ in ['Conv2D', 'Dense']:
        w = layer.get_weights()[0]
        layersize = np.prod(w.shape)
        print("{}: {}".format(layer.name, layersize))  # 0 = weights, 1 = biases
        if layersize > 4096:  # assuming that shape[0] is batch, i.e., 'None'
            print("Layer {} is too large ({}), are you sure you want to train?".format(layer.name, layersize))

conv_0: 4096
conv_1: 524288
Layer conv_1 is too large (524288), are you sure you want to train?
conv_2: 131072
Layer conv_2 is too large (131072), are you sure you want to train?
dense_0: 131072
Layer dense_0 is too large (131072), are you sure you want to train?
output_dense: 704

import hls4ml
import plotting

# First, the baseline model
hls_config = hls4ml.utils.config_from_keras_model(model_pruned, granularity='name')

# Set the precision and reuse factor for the full model
hls_config['Model']['Precision'] = 'ap_fixed<22,6>'
hls_config['Model']['ReuseFactor'] = 4
hls_config['Model']['Strategy'] = 'resource'
# Create an entry for each layer; here you can, for instance, change the strategy of a layer
# to 'resource' or increase the reuse factor individually for large layers.
# In this case every layer is switched to the Resource strategy with a reuse factor of 4.
for Layer in hls_config['LayerName'].keys():
    hls_config['LayerName'][Layer]['Strategy'] = 'resource'
    hls_config['LayerName'][Layer]['ReuseFactor'] = 4
    hls_config['LayerName'][Layer]['Precision'] = 'ap_fixed<22,6>'
# The 'Stable' softmax implementation gives the best numerical accuracy for high-accuracy models;
# the default latency implementation is faster but numerically less stable.
hls_config['LayerName']['output_softmax']['Strategy'] = 'Stable'
hls_config['LayerName']['dense_0']['ReuseFactor'] = 64
plotting.print_dict(hls_config)

cfg = hls4ml.converters.create_config(backend='Vivado')
cfg['IOType'] = 'io_stream'  # Must set this if using CNNs!
cfg['HLSConfig'] = hls_config
cfg['KerasModel'] = model
cfg['OutputDir'] = 'pruned_cnn/'
cfg['XilinxPart'] = 'xcu250-figd2104-2L-e'

hls_model = hls4ml.converters.keras_to_hls(cfg)
hls_model.compile()

hls_model.build(csim=False, synth=True, vsynth=False)


INFO: [HLS 200-489] Unrolling loop 'InitAccum' (firmware/nnet_utils/nnet_dense_resource.h:37) in function 'nnet::conv_2d_cl<nnet::array<ap_fixed<22, 6, (ap_q_mode)5, (ap_o_mode)3, 0>, 256u>, nnet::array<ap_fixed<22, 6, (ap_q_mode)5, (ap_o_mode)3, 0>, 128u>, config7>' completely with a factor of 128.
INFO: [HLS 200-489] Unrolling loop 'MultLoop' (firmware/nnet_utils/nnet_dense_resource.h:52) in function 'nnet::conv_2d_cl<nnet::array<ap_fixed<22, 6, (ap_q_mode)5, (ap_o_mode)3, 0>, 256u>, nnet::array<ap_fixed<22, 6, (ap_q_mode)5, (ap_o_mode)3, 0>, 128u>, config7>' completely with a factor of 131072.
ERROR: [XFORM 203-504] Stop unrolling loop 'MultLoop' (firmware/nnet_utils/nnet_dense_resource.h:52) in function 'nnet::conv_2d_cl<nnet::array<ap_fixed<22, 6, (ap_q_mode)5, (ap_o_mode)3, 0>, 256u>, nnet::array<ap_fixed<22, 6, (ap_q_mode)5, (ap_o_mode)3, 0>, 128u>, config7>' because it may cause large runtime and excessive memory usage due to increase in code size. Please avoid unrolling the loop or form sub-functions for code in the loop body.
ERROR: [HLS 200-70] Pre-synthesis failed.
command 'ap_source' returned error code
    while executing
"source build_prj.tcl"
    ("uplevel" body line 1)
    invoked from within
"uplevel \#0 [list source $arg] "

INFO: [Common 17-206] Exiting vivado_hls at Mon Oct 30 13:35:42 2023...
CSynthesis report not found.
Vivado synthesis report not found.
Cosim report not found.
Timing report not found.
@calad0i
Contributor

calad0i commented Oct 31, 2023

Vivado has a hardcoded unroll limit of 4096 that cannot easily be circumvented. The warning message (Layer ... is too large (524288), are you sure you want to train?) is telling you that the layer is too large to fit on the chip (or to pass Vivado's checks).
If your model is very sparse, you could try the Vitis backend and see if everything works (I don't have much experience with that). Otherwise, try shrinking the layer sizes or using the Resource strategy.
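
For reference, a minimal sketch of what that would look like with the Python API, assuming hls4ml 0.7.x (the layer name 'conv_1' and the reuse factor 128 are placeholders, not values taken from this issue):

import hls4ml

# model_pruned: the Keras model loaded earlier in this thread
hls_config = hls4ml.utils.config_from_keras_model(model_pruned, granularity='name')

# Switch only the oversized layer to the Resource strategy and raise its reuse factor,
# so its multiplication loop is no longer fully unrolled.
hls_config['LayerName']['conv_1']['Strategy'] = 'Resource'
hls_config['LayerName']['conv_1']['ReuseFactor'] = 128

Newer hls4ml releases also accept backend='Vitis' in hls4ml.converters.create_config if you want to try the Vitis flow.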

@behnamarefy
Author

Thanks for your answer @calad0i. As you suggested, and as I found in the hls4ml documentation and other answers, I should use the Resource strategy with higher reuse factors. I tried a reuse factor of 1024 for all layers and 704 for output_dense (the maximum reuse factor for that layer). That got me past the error, but now I hit a new one:

WARNING: [XFORM 203-104] Completely partitioning array 'data.V' (firmware/nnet_utils/nnet_dense_stream.h:29) accessed through non-constant indices on dimension 1 (firmware/nnet_utils/nnet_dense_resource.h:139:17), which may result in long runtime and suboptimal QoR due to large multiplexers. Please consider wrapping the array access into a function or using a register file core instead.
WARNING: [XFORM 203-104] Completely partitioning array 'data.V' (firmware/nnet_utils/nnet_dense_stream.h:29) accessed through non-constant indices on dimension 1 (firmware/nnet_utils/nnet_dense_resource.h:56:17), which may result in long runtime and suboptimal QoR due to large multiplexers. Please consider wrapping the array access into a function or using a register file core instead.
/tools/Xilinx/Vivado/2019.2/bin/rdiArgs.sh: line 280: 2913665 Killed                  "$RDI_PROG" "$@"
CSynthesis report not found.
Vivado synthesis report not found.
Cosim report not found.
Timing report not found

I also tried a reuse factor of 128 for all layers and 704 for output_dense, and I got this error:

error:  [XFORM 203-133] Bitwidth of reshaped elements (90122 bits) exceeds the maximum Bitwidth (65536) for array 'w7.v'

It's very important for me to solve this problem, and I hope you can help me, @vloncar @jmduarte @thesps.
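
Side note: hls4ml generally requires the reuse factor to divide the layer's number of multiplications (n_in × n_out), which is why 704 = 64 × 11 is the largest value for output_dense. A minimal sketch for listing the candidates, assuming that simplified divisibility rule:

def candidate_reuse_factors(n_in, n_out):
    # Reuse factors that split the n_in * n_out multiplications evenly
    n_mult = n_in * n_out
    return [rf for rf in range(1, n_mult + 1) if n_mult % rf == 0]

print(candidate_reuse_factors(64, 11))  # ..., 176, 352, 704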

@vloncar
Contributor

vloncar commented Nov 1, 2023

It's very important for me to solve this problem

Is this a homework assignment? 😄

Your model is too big; you need to reduce the number of filters, a lot.

There may be something else going on since that loop shouldn't be unrolled, but even with that resolved I would not expect this model to work.
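
A rough way to see the scale of the problem: with the Resource strategy each layer instantiates roughly weights / reuse_factor multipliers, which is exactly the 'MultLoop' unroll factor reported in the log above. A back-of-envelope sketch (the ~12k DSP figure for the xcu250 is approximate):

# conv_1 weight count from the layer-size printout above
weights_conv_1 = 524288
reuse_factor = 4
multipliers = weights_conv_1 // reuse_factor
print(multipliers)  # 131072 -- matches the reported MultLoop unroll factor

# The xcu250 has on the order of 12k DSP slices, so this single layer alone
# asks for roughly 10x more multipliers than the whole device provides.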

@behnamarefy
Author

No, it's part of my project :) :) @vloncar
I decreased the model size by using Conv1D instead of Conv2D. Here is the new network; the parameter count is nearly half of the previous one. Do you think it still can't be implemented?

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 128, 2)]          0         
                                                                 
 conv_0 (Conv1D)             (None, 128, 256)          4352      
                                                                 
 bn_conv_0 (BatchNormalizati  (None, 128, 256)         1024      
 on)                                                             
                                                                 
 conv_act_0 (Activation)     (None, 128, 256)          0         
                                                                 
 pool_0 (MaxPooling1D)       (None, 64, 256)           0         
                                                                 
 conv_1 (Conv1D)             (None, 64, 128)           262272    
                                                                 
 bn_conv_1 (BatchNormalizati  (None, 64, 128)          512       
 on)                                                             
                                                                 
 conv_act_1 (Activation)     (None, 64, 128)           0         
                                                                 
 pool_1 (MaxPooling1D)       (None, 32, 128)           0         
                                                                 
 conv_2 (Conv1D)             (None, 32, 64)            65600     
                                                                 
 bn_conv_2 (BatchNormalizati  (None, 32, 64)           256       
 on)                                                             
                                                                 
 conv_act_2 (Activation)     (None, 32, 64)            0         
                                                                 
 pool_2 (MaxPooling1D)       (None, 16, 64)            0         
                                                                 
 flatten (Flatten)           (None, 1024)              0         
                                                                 
 dense_0 (Dense)             (None, 128)               131200    
                                                                 
 bn_dense_0 (BatchNormalizat  (None, 128)              512       
 ion)                                                            
                                                                 
 dense_act_0 (Activation)    (None, 128)               0         
                                                                 
 output_dense (Dense)        (None, 11)                1419      
                                                                 
 output_softmax (Activation)  (None, 11)               0         
                                                                 
=================================================================
Total params: 467,147
Trainable params: 465,995
Non-trainable params: 1,152
for layer in model_pruned.layers:
    if layer.__class__.__name__ in ['Conv1D', 'Dense']:
        w = layer.get_weights()[0]
        layersize = np.prod(w.shape)
        print("{}: {}".format(layer.name, layersize))  # 0 = weights, 1 = biases
        if layersize > 4096:  # assuming that shape[0] is batch, i.e., 'None'
            print("Layer {} is too large ({}), are you sure you want to train?".format(layer.name, layersize))
conv_0: 4096
conv_1: 262144
Layer conv_1 is too large (262144), are you sure you want to train?
conv_2: 65536
Layer conv_2 is too large (65536), are you sure you want to train?
dense_0: 131072
Layer dense_0 is too large (131072), are you sure you want to train?
output_dense: 1408

@vloncar
Contributor

vloncar commented Nov 1, 2023

Try with 10-50k parameters, not half a million. All the weights will be stored on chip, so you can't really go large.
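
To put numbers on "all weights stored on chip", a quick estimate at the precision configured here (plain arithmetic, nothing hls4ml-specific):

total_params = 467147     # from the model summary above
bits_per_weight = 22      # ap_fixed<22,6> as configured
print(total_params * bits_per_weight / 1e6, 'Mbit')  # ~10.3 Mbit of weights alone

# A 10-50k parameter model at the same precision needs only ~0.2-1.1 Mbit,
# which is much easier to keep on chip alongside activations and control logic.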

@behnamarefy
Author

Thank you so much @vloncar. I changed the model architecture and, as you suggested, reduced the parameter count to under 50k. Here is my new model:

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 128, 2)]          0         
                                                                 
 conv_0 (Conv1D)             (None, 128, 128)          896       
                                                                 
 bn_conv_0 (BatchNormalizati  (None, 128, 128)         512       
 on)                                                             
                                                                 
 conv_act_0 (Activation)     (None, 128, 128)          0         
                                                                 
 pool_0 (MaxPooling1D)       (None, 64, 128)           0         
                                                                 
 conv_1 (Conv1D)             (None, 64, 64)            24640     
                                                                 
 bn_conv_1 (BatchNormalizati  (None, 64, 64)           256       
 on)                                                             
                                                                 
 conv_act_1 (Activation)     (None, 64, 64)            0         
                                                                 
 pool_1 (MaxPooling1D)       (None, 32, 64)            0         
                                                                 
 conv_2 (Conv1D)             (None, 32, 16)            3088      
                                                                 
 bn_conv_2 (BatchNormalizati  (None, 32, 16)           64        
 on)                                                             
                                                                 
 conv_act_2 (Activation)     (None, 32, 16)            0         
                                                                 
 pool_2 (MaxPooling1D)       (None, 16, 16)            0         
                                                                 
 flatten_1 (Flatten)         (None, 256)               0         
                                                                 
 dense_0 (Dense)             (None, 64)                16448     
                                                                 
 bn_dense_0 (BatchNormalizat  (None, 64)               256       
 ion)                                                            
                                                                 
 dense_act_0 (Activation)    (None, 64)                0         
                                                                 
 dense_1 (Dense)             (None, 16)                1040      
                                                                 
 bn_dense_1 (BatchNormalizat  (None, 16)               64        
 ion)                                                            
                                                                 
 dense_act_1 (Activation)    (None, 16)                0         
                                                                 
 output_dense (Dense)        (None, 11)                187       
                                                                 
 output_softmax (Activation)  (None, 11)               0         
                                                                 
=================================================================
Total params: 47,451
Trainable params: 46,875
Non-trainable params: 576

and the configuration I used is:

import hls4ml
import plotting

# First, the baseline model
hls_config = hls4ml.utils.config_from_keras_model(model_pruned, granularity='name')

# Set the precision and reuse factor for the full model
hls_config['Model']['Precision'] = 'ap_fixed<22,6>'
hls_config['Model']['ReuseFactor'] = 96
hls_config['Model']['Strategy'] = 'resource'
# Create an entry for each layer; here you can, for instance, change the strategy of a layer
# to 'resource' or increase the reuse factor individually for large layers.
# In this case every layer is switched to the Resource strategy with a reuse factor of 96.
for Layer in hls_config['LayerName'].keys():
    hls_config['LayerName'][Layer]['Strategy'] = 'resource'
    hls_config['LayerName'][Layer]['ReuseFactor'] = 96
    hls_config['LayerName'][Layer]['Precision'] = 'ap_fixed<22,6>'
    # hls_config['LayerName'][Layer]['Trace'] = True
# The 'Stable' softmax implementation gives the best numerical accuracy for high-accuracy models;
# the default latency implementation is faster but numerically less stable.
hls_config['LayerName']['output_softmax']['Strategy'] = 'Stable'
hls_config['LayerName']['dense_0']['ReuseFactor'] = 64
hls_config['LayerName']['dense_1']['ReuseFactor'] = 64
hls_config['LayerName']['output_dense']['ReuseFactor'] = 16
plotting.print_dict(hls_config)

cfg = hls4ml.converters.create_config(backend='Vivado')
cfg['IOType'] = 'io_stream'  # Must set this if using CNNs!
cfg['HLSConfig'] = hls_config
cfg['KerasModel'] = model
cfg['OutputDir'] = 'pruned_cnn/'
cfg['XilinxPart'] = 'xcu250-figd2104-2L-e'
cfg['strategy']   = 'resource'
#cfg['Interface'] = 'axi_stream'

hls_model = hls4ml.converters.keras_to_hls(cfg)
hls_model.compile()

The code now passes the pre-synthesis step and all the errors are gone, but after two days synthesis is still running and seems stuck after the following warning:

INFO: [HLS 200-10] ----------------------------------------------------------------
INFO: [HLS 200-42] -- Implementing module 'conv_1d_cl_array_ap_fixed_128u_array_ap_fixed_22_6_5_3_0_64u_config7_s' 
INFO: [HLS 200-10] ----------------------------------------------------------------
INFO: [SCHED 204-11] Starting scheduling ...
INFO: [SCHED 204-61] Pipelining loop 'KernelShiftWidth'.
INFO: [SCHED 204-61] Pipelining result : Target II = 1, Final II = 1, Depth = 1.
WARNING: [SCHED 204-65] Unable to satisfy pipeline directive: Loop's control-flow is too complicated to be pipelined.
INFO: [SCHED 204-11] Finished scheduling.

Should I stop it and rerun with a different reuse factor for the conv layers? It is running on a 12th Gen Intel® Core™ i9-12900K (24 threads).

@behnamarefy
Author

Thank you @vloncar, I was able to get through synthesis by reducing the model parameters. 👍 :):):)

@abd0smaali

abd0smaali commented Nov 12, 2023

@vloncar, regarding your suggestion to @behnamarefy: I also reduced the number of parameters in my model. However, I still encounter the same issue, as indicated in solution1.log. Here is the YAML configuration I used:


KerasJson: LeNet5_MNIST_model.json
KerasH5:   LeNet5_MNIST_weights.h5
OutputDir: lenet5-hls-test
ProjectName: lenet5
Part: xc7z020clg400-1
ClockPeriod: 5
Backend: Vivado
IOType: io_stream
HLSConfig:
  Model:
    Precision: ap_fixed<16,6>
    ReuseFactor: 1 
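
For what it's worth, a sketch of how per-layer overrides could look in this YAML file, mirroring the per-layer Strategy/ReuseFactor advice earlier in the thread, in case the error in solution1.log also points at a fully unrolled loop (the layer name and reuse factor below are placeholders, not values taken from your log):

HLSConfig:
  Model:
    Precision: ap_fixed<16,6>
    ReuseFactor: 1
  LayerName:
    conv_1:
      Strategy: Resource
      ReuseFactor: 16   # placeholder: should evenly divide the layer's multiplications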

Model: "LeNet5_MNIST"


Model: "LeNet5_MNIST"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0
_________________________________________________________________
conv_0 (Conv2D)              (None, 26, 26, 16)        144
_________________________________________________________________
bn_conv_0 (BatchNormalizatio (None, 26, 26, 16)        64
_________________________________________________________________
conv_act_0 (Activation)      (None, 26, 26, 16)        0
_________________________________________________________________
pool_0 (MaxPooling2D)        (None, 13, 13, 16)        0
_________________________________________________________________
conv_1 (Conv2D)              (None, 11, 11, 16)        2304
_________________________________________________________________
bn_conv_1 (BatchNormalizatio (None, 11, 11, 16)        64
_________________________________________________________________
conv_act_1 (Activation)      (None, 11, 11, 16)        0
_________________________________________________________________
pool_1 (MaxPooling2D)        (None, 5, 5, 16)          0
_________________________________________________________________
conv_2 (Conv2D)              (None, 3, 3, 24)          3456
_________________________________________________________________
bn_conv_2 (BatchNormalizatio (None, 3, 3, 24)          96
_________________________________________________________________
conv_act_2 (Activation)      (None, 3, 3, 24)          0
_________________________________________________________________
pool_2 (MaxPooling2D)        (None, 1, 1, 24)          0
_________________________________________________________________
flatten (Flatten)            (None, 24)                0
_________________________________________________________________
dense_0 (Dense)              (None, 42)                1008
_________________________________________________________________
bn_dense_0 (BatchNormalizati (None, 42)                168
_________________________________________________________________
dense_act_0 (Activation)     (None, 42)                0
_________________________________________________________________
dense_1 (Dense)              (None, 64)                2688
_________________________________________________________________
bn_dense_1 (BatchNormalizati (None, 64)                256
_________________________________________________________________
dense_act_1 (Activation)     (None, 64)                0
_________________________________________________________________
output_dense (Dense)         (None, 10)                650
_________________________________________________________________
output_softmax (Activation)  (None, 10)                0
=================================================================
Total params: 11,848
Trainable params: 11,560
Non-trainable params: 288
_________________________________________________________________

The commands used in the terminal:

1. /hls4ml convert -c co_nf.yml
2. /hls4ml build -p lenet5-hls-test -a

lenet5-hls-test.tar.gz
