# In[1]:
from hls4ml.converters.keras_to_hls import parse_default_keras_layer
from hls4ml.converters.keras_to_hls import keras_handler



# In[2]:
@keras_handler('MultiHeadAttention')
def parse_conv1d_layer(keras_layer, input_names, input_shapes, data_reader, config):
    """Parse a Keras ``MultiHeadAttention`` layer into an hls4ml layer dict.

    NOTE(review): the name ``parse_conv1d_layer`` looks like a copy-paste
    leftover; it is kept unchanged so existing references keep working.
    The handler is registered through the ``keras_handler`` decorator.

    Args:
        keras_layer: Serialized Keras layer with ``class_name`` and ``config``
            entries. ``config`` must provide ``num_heads``, ``key_dim``,
            ``value_dim``, ``query_shape``, ``key_shape``, ``value_shape``
            and ``dtype``. ``output_shape`` and ``attention_axes`` are optional.
        input_names: Input names, forwarded to ``parse_default_keras_layer``.
        input_shapes: Input shapes. ``input_shapes[0]`` must equal the
            configured ``query_shape``. Presumably ``[[None, seq, dim], ...]``
            with the batch dimension first. TODO: confirm the ordering of the
            (query, value, key) inputs.
        data_reader: Unused here. Presumably required by the signature that
            all ``keras_handler`` parsers share.
        config: Unused here; see ``data_reader``.

    Returns:
        ``(layer, output_shape)``: the parsed layer dict and its output shape.
        The output shape is also stored as ``layer['output_shape']``.

    Raises:
        AssertionError: if the layer is not a MultiHeadAttention layer, or if
            the first input shape does not match ``query_shape``.
        Exception: if ``attention_axes`` does not start at axis 1, or if a
            provided query/key/value shape is not rank 3.
    """
    keras_config = keras_layer['config']

    # NOTE(review): asserts are stripped under ``python -O``. They are kept
    # (instead of raising) so that callers still see the same exception type.
    assert 'MultiHeadAttention' in keras_layer['class_name']
    query_shape_cfg = keras_config['query_shape']
    # Compare as lists so that a tuple shape matches a JSON-decoded list shape.
    assert query_shape_cfg is not None and list(input_shapes[0]) == list(query_shape_cfg)

    layer = parse_default_keras_layer(keras_layer, input_names)

    layer['num_heads'] = keras_config['num_heads']
    layer['head_dim_key'] = keras_config['key_dim']
    layer['head_dim_value'] = keras_config['value_dim']
    layer['query_shape'] = keras_config['query_shape']
    layer['key_shape'] = keras_config['key_shape']
    layer['value_shape'] = keras_config['value_shape']
    layer['feature_dim'] = layer['query_shape'][-1]
    # TODO: the sequence length may not be a compile-time constant, so it is
    # not stored as a separate attribute yet.

    # TODO: decide which of the remaining Keras options hls4ml should honour.
    layer['dtype'] = keras_config['dtype']

    # Keras' ``output_shape`` holds only the trailing feature dims, not the
    # batch and sequence dims. Those two dims are taken from the query shape.
    # A bare int is accepted defensively. Note that ``list.extend`` returns
    # None, so the full shape is built by concatenation instead.
    out_dims = keras_config.get('output_shape')
    if out_dims:
        if isinstance(out_dims, int):
            out_dims = [out_dims]
        layer['output_shape'] = list(layer['query_shape'][:2]) + list(out_dims)
    else:
        layer['output_shape'] = layer['query_shape']
    output_shape = layer['output_shape']

    # Only attention over the sequence axis (axis 1) is supported.
    # None means "use the Keras default", which for a rank-3 (batch, seq, dim)
    # input is axis 1. An int is treated as a 1-tuple.
    attention_axes = keras_config.get('attention_axes')
    if attention_axes is None:
        attention_axes = (1,)
    elif isinstance(attention_axes, int):
        attention_axes = (attention_axes,)
    if not attention_axes or attention_axes[0] != 1:
        raise Exception('assigning the attention_axe is not currently supported by hls4ml')
    layer['attention_axes'] = attention_axes

    # Multi-dimensional feature axes are not supported, so every provided
    # query/key/value shape must be rank 3. A None shape (presumably when no
    # separate key tensor was given; TODO confirm) is skipped rather than
    # crashing on len().
    for shape in (layer['query_shape'], layer['key_shape'], layer['value_shape']):
        if shape is not None and len(shape) != 3:
            raise Exception('muti-dimension of feature dim is not currently supported by hls4ml')

    # Attention scores have shape (batch, num_heads, seq_query, seq_key), and
    # the softmax normalizes over the last len(attention_axes) of those axes.
    # TODO: decide whether the rank should be 3, since the batch dim is not
    # part of the weight matrix.
    attn_scores_rank = 4
    layer['softmax_axis'] = tuple(
        range(attn_scores_rank - len(attention_axes), attn_scores_rank)
    )

    return layer, output_shape