
Problematic transformations while streamlining scaled dot-product attention #878

Open
iksnagreb opened this issue Aug 28, 2023 · 6 comments
Labels
bug Something isn't working

Comments

@iksnagreb
Contributor

Quick summary

This is not really a bug report; it is more a collection of missing support and minor annoyances I am encountering while working towards streamlining and eventually mapping a scaled dot-product attention operator. The list might grow over time and is mainly meant to start a discussion and document these problems for others.

Details

I am currently playing around with QONNX and FINN graph transformations applied to a dummy single-head scaled dot-product attention operator stripped down to its bare minimum (i.e. it is tiny and has no input/output projections, no weights, no masking, ...). Essentially, it just comprises a chain of MatMul-Softmax-MatMul operators with some Brevitas quantizers in between. I want to understand the streamlining process and eventually work towards mapping this operator pattern to custom HLS operators (this is all WIP). Doing so, I have encountered a few, probably small, problems in some of the transformations, mostly related to FINN assuming that a MatMul operator always involves one dynamic input and one weight initializer. This is not the case for the two-input, weight-free MatMuls within scaled dot-product attention: in queries x keys, both operands are inputs produced by the preceding layer of input projections. In the following I will list the problematic transformations and some ideas for fixing the problems (I might add to this list over time if I find more or gain further insights, either via edit or in the comments below):

  • MoveScalarMulPastMatMul always expects a weight matrix initializer as the right-hand side. As mentioned before, scaled dot-product attention involves MatMuls with two dynamic inputs, both of which may have a scalar multiplication in front. This transformation is currently skipped due to the is_fork_node and is_join_node queries (a two-input MatMul is a join node) and due to testing for the presence of weight initializers. Moving either or both of the scalar multiplications past the MatMul should be a valid transformation for two-input MatMul operators as well (see the quick numpy check after this list). This probably requires some refactoring of the transformation, as simply removing these checks seems to lead to detached sub-graphs.

  • Absorb1BitMulIntoMatMul and Absorb1BitMulIntoConv always test for the presence of weight initializers via assertions, causing the program to terminate instead of simply ignoring two-input MatMul operators without weights. In particular, this means the whole Streamline transformation (which, among others, contains these two) is not applicable when a scaled dot-product attention operator (i.e. a two-input MatMul without weights) appears anywhere in a model graph. This can probably be fixed by turning the assertions into simple conditionals (i.e. if ...:) that skip the transformation.

  • InferShapes fails after FoldTransposeIntoQuantInit: This is probably a bug, but I am not sure whether the transpose of the keys input to the first MatMul of the attention operator should actually be folded at all, as it is probably just part of the pattern we want to detect and map to our new custom op. However, as both transforms (in this order) are part of the ConvertQONNXtoFINN transform, this needs to be fixed. I do not really know why this happens or how to fix it; it fails with a ShapeInferenceError somewhere within onnx.shape_inference.infer_shapes (so not even within FINN or QONNX), but the cause of this might be higher up.
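
For the first point, a quick numpy check (illustrative only) of why moving the scalar Muls past a two-input MatMul is still a valid rewrite: scalar factors on either operand commute with the MatMul.

```python
import numpy as np

a, b = 0.25, 3.0               # scalar Mul factors on the two dynamic inputs
Q = np.random.rand(4, 8)       # e.g. projected queries
K_T = np.random.rand(8, 4)     # e.g. transposed keys

# (a * Q) @ (b * K_T) == (a * b) * (Q @ K_T)
assert np.allclose((a * Q) @ (b * K_T), (a * b) * (Q @ K_T))
```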

Expected behavior

Eventually I want to be able to streamline the scaled dot-product attention operator to only contain MatMul, Softmax and MultiThreshold operators using FINN's ConvertQONNXtoFINN and Streamline transformations "out of the box".
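
For reference, this is roughly the flow I have in mind (only a sketch; the file names are illustrative and the comments mark where it currently breaks for me):

```python
from qonnx.core.modelwrapper import ModelWrapper
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from finn.transformation.streamline import Streamline

model = ModelWrapper("attention.onnx")
# QONNX-to-FINN conversion: currently InferShapes fails after FoldTransposeIntoQuantInit
model = model.transform(ConvertQONNXtoFINN())
# Streamlining: currently Absorb1BitMulIntoMatMul asserts on the two-input MatMuls
model = model.transform(Streamline())
model.save("attention_streamlined.onnx")
```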

Steps to Reproduce

You can have a look at https://github.com/iksnagreb/attention-dummy for the code I use to create the dummy operator and apply some transformations. Note that the code pulls from my fork's feature branch of FINN, but the current dev branch should work as well. I will also attach the ONNX export of the generated dummy operator (without any transforms) here: attention.onnx.zip.
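
The dummy operator looks roughly like this (an illustrative sketch only, not the exact code from the repository above; quantizer settings and sizes are made up):

```python
import math
import torch
from brevitas.nn import QuantIdentity

class DummyAttention(torch.nn.Module):
    """Single-head scaled dot-product attention without projections, weights or masking."""

    def __init__(self, embed_dim=8):
        super().__init__()
        self.scale = 1.0 / math.sqrt(embed_dim)
        # Brevitas quantizers in between the MatMul/Softmax operators
        self.q_quant = QuantIdentity()
        self.k_quant = QuantIdentity()
        self.v_quant = QuantIdentity()
        self.a_quant = QuantIdentity()

    def forward(self, q, k, v):
        # queries x keys^T: a two-input MatMul without a weight initializer (a join node)
        attention = torch.matmul(self.q_quant(q), self.k_quant(k).transpose(-2, -1))
        attention = torch.softmax(self.scale * attention, dim=-1)
        # attention weights x values: again a two-input MatMul without weights
        return torch.matmul(self.a_quant(attention), self.v_quant(v))
```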

Possible fix

I have already mentioned my current understanding of these problems and ideas for solving them above. It would be nice to get some input/guidance on how to solve, or at least work around, these problems. I will continue to work on them and will probably have to adjust or add some transformations anyway to support a scaled dot-product attention operator. I will add any new insights here and will be happy to contribute fixes via PR.

@iksnagreb added the bug label on Aug 28, 2023
@iksnagreb
Contributor Author

Another problematic transformation is MoveScalarAddPastMatMul: it folds the scalar to be added into a bias computed via dot product with the MatMul weights. Of course, this does not work for the dynamic two-input MatMul in the attention operator, as we do not know the weights at streamlining time, so the current behavior of simply skipping the join node is indeed correct.
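
For reference, a quick numpy check (illustrative only) of why this transformation needs the right-hand side to be a known weight initializer: the scalar add turns into a bias row that has to be computed from the weights at transform time.

```python
import numpy as np

A = np.random.rand(4, 8)       # dynamic left-hand side
W = np.random.rand(8, 3)       # weight initializer, known at streamlining time
c = 0.5                        # scalar add on the left-hand side

# (A + c) @ W == A @ W + (c * ones) @ W -- the bias row depends on W
assert np.allclose((A + c) @ W, A @ W + c * np.ones((1, 8)) @ W)
```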

However, we still need a way to deal with these scalar Adds to properly streamline the attention pattern. Currently I am trying the following idea: instead of streamlining the Add nodes downwards through the MatMul, streamline them upwards back into the MultiThreshold via the AbsorbSignBiasIntoMultiThreshold transformation. This currently does not work in cases where, for example, a Transpose gets in the way. I have tracked this down to the order of the streamlining transformations, in particular MoveScalarLinearPastInvariants and AbsorbSignBiasIntoMultiThreshold. Do they strictly have to be in that order, or would it be safe to swap them? Unfortunately, I do not know enough about the details of the streamlining and potential side effects to judge this.

If I ran the streamlining unit tests (like test_streamline_fc.py), would that be enough to validate that I did not break anything?

@iksnagreb
Contributor Author

Re-post of insights I gained while looking at the related issue #892, to have this documented here as well:

It seems to me that currently all uses of FoldTransposeIntoQuantInit deal with Transpose nodes inserted by other transforms, such as GemmToMatMul or the ChannelsLast/ChannelsFirst conversions. Is this true? In our case, however, the Transpose seems to be an inherent part of the model.

@iksnagreb
Contributor Author

As we are gradually moving towards more realistic and complete models using the Brevitas quantized multi-head attention, we are seeing even more issues:

  • FoldQuantWeights seems to propagate shapes backwards, which sometimes messes with the next InferShapes pass, causing it to fail with a ShapeInferenceError somewhere within onnx.shape_inference.infer_shapes. We think shapes should never propagate backwards when transforming the model graph. The faulty line seems to be https://github.com/Xilinx/finn/blob/dev/src/finn/transformation/qonnx/fold_quant_weights.py#L191 and we think that instead of output_shape it should keep the ishape of the tensor queried before.
  • FoldQuantWeights again, line https://github.com/Xilinx/finn/blob/dev/src/finn/transformation/qonnx/fold_quant_weights.py#L185: I think the initializer needs to be set to the reciprocal of the scale, i.e., 1.0 / scale, for the overall effect of the transformation to be the identity. However, I am not really sure, and apparently no one has noticed any issues with this so far; maybe I have to sketch this out on paper first (the fix would be trivial). At least it prevents us from later using the MoveLinearPastEltwiseAdd transformation, as the scales do not match as they are supposed to (I also end up with a Mul scale of about 46519.2148 on an INT32 initializer later on, which seems far too big and matches the reciprocal of the scale on the other branch of the Add node). @maltanar: maybe you have some insights on this?
  • AbsorbAddIntoMultiThreshold assumes the initializer input to always be the second input, i.e., at index 1. However, sometimes, e.g., when exporting a non-quantized floating-point bias, PyTorch seems to decide to generate the initializer as the first input, causing this transformation to fail with an assertion. At least for the Add node, the order of inputs should not matter, and this assumption/restriction seems rather artificial.
  • Apparently we need the MoveLinearPastEltwiseAdd transformation, which is not included in the default Streamline steps, but that is fine, we can do custom steps of course. However, there seem to be issues with propagating the shapes, and subsequent shape inference passes fail (again somewhere within onnx.shape_inference.infer_shapes). We think it is related to the streamlining of quantized bias initializers (as in the issues above), which sometimes seems to rely on what looks like broadcasting semantics, e.g., moving an element-wise operation involving a tensor of shape (8) past an Add operating on a tensor of shape (24, 1, 8), i.e. repeating/broadcasting along the two other dimensions. This seems to be transformed correctly; it is just the next shape inference that breaks. There are currently two ideas to fix this: either (locally) delete the shape annotations and redo the shape inference from scratch (not sure whether that is possible), or "simulate" the broadcasting in numpy to set the correct shape annotation manually (hopefully the numpy semantics match the ONNX ones); a rough sketch of the second idea follows after this list.
  • Brevitas (by default) uses unsigned quantizers for the attention weights, which are not supported by FINN. This is not exactly a streamlining issue, and FINN indeed produces a clear error message, so maybe we should work towards Brevitas setting a more reasonable default?
  • I would like to detect the key input to the attention pattern by finding the input to the first MatMul which is produced by a Transpose operation (recall attention is softmax(query x key^T) x value), i.e. the Transpose serves as a kind of marker. However, Brevitas transposes first and then applies a quantizer, which results in a MultiThreshold where I expect my marker Transpose. Streamlining the Transpose through the MultiThreshold seems to be unsupported and difficult (please convince me otherwise), so probably we should suggest changing Brevitas' behavior?
  • Moving a scalar multiplication by a constant past Reshape, Slice, Squeeze/Unsqueeze and Transpose(?) does not seem to be supported. We probably need this for moving scale factors produced by quantizers past the Slice operations of the multi-heads and packed input projections.
  • Streamlining of packed input projections (there is a flag in Brevitas to enable/disable these, which defaults to True, i.e., enabled) sometimes seems to break with shape inference issues. I am not sure why; this needs more debugging.
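
Regarding the second fix idea for the shape annotation problem above (simulating the broadcasting in numpy), this is roughly what I mean (the helper name is hypothetical, not existing FINN/QONNX code):

```python
import numpy as np

def refresh_broadcast_shape(model, node):
    # Recompute the output shape of an element-wise node (Add, Mul, ...) by
    # simulating multidirectional broadcasting over its input shapes in numpy,
    # then write the annotation back instead of trusting the stale one.
    in_shapes = [model.get_tensor_shape(name) for name in node.input]
    out_shape = np.broadcast_shapes(*in_shapes)   # e.g. (8,) and (24, 1, 8) -> (24, 1, 8)
    model.set_tensor_shape(node.output[0], list(out_shape))
```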

@iksnagreb
Contributor Author

Hm, the issue regarding the inverse of the scale after the FoldQuantWeights transformation seems to be due to some asymmetry in handling the two inputs to the Add node: the transformation always assumes the quantized initializer feeds the second input to the Add, so the scale correction needs to be inserted in front of the first. However, this is not always the case, leading to the scale being inserted directly after the just-folded quantized initializer, seemingly applying the operation twice.

For now, this can be solved by detecting this situation (there are just two possible configurations) and flipping the indices if the quantized initializer feeds the first input.

@iksnagreb
Contributor Author

Most of the issues seem to boil down to violations of input-order assumptions regarding "dynamic" (i.e. produced upstream) vs. initializer inputs. I will try to fix this with a new cleanup transformation in QONNX (roughly sketched below) plus a few extra conditions here and there; I am not sure whether this will be a sustainable solution.
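
This is roughly what I have in mind for the cleanup transformation (just a sketch; the name and op coverage are not final):

```python
from qonnx.transformation.base import Transformation

class MoveInitializerToBack(Transformation):
    # Canonicalize commutative two-input ops so that an initializer input, if
    # present, always sits at index 1. Downstream transforms could then keep
    # their "initializer is the second input" assumption.
    OPS = {"Add", "Mul"}

    def apply(self, model):
        graph_modified = False
        for node in model.graph.node:
            if node.op_type in self.OPS and len(node.input) == 2:
                has_init = [model.get_initializer(name) is not None for name in node.input]
                if has_init == [True, False]:
                    # swapping the inputs of a commutative op preserves semantics
                    node.input[0], node.input[1] = node.input[1], node.input[0]
                    graph_modified = True
        return model, graph_modified
```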

@iksnagreb
Contributor Author

RemoveIdentityOps seems to be broken if the identity operation directly follows a fork node: it seemingly rewires the fork-node output only into the branch containing the identity op, disconnecting the other branches. Currently this happens when using packed input projections.
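
Just to illustrate the rewiring I would expect instead (a sketch, not the actual RemoveIdentityOps code): every consumer of the identity op's output should be redirected to its input, so the other branches of the fork stay connected.

```python
def bypass_identity(model, identity_node):
    # Redirect all consumers of the identity op's output to its input, then drop
    # the node; this leaves the other branches of a preceding fork node intact.
    inp, out = identity_node.input[0], identity_node.output[0]
    for consumer in model.find_consumers(out) or []:
        for idx, name in enumerate(consumer.input):
            if name == out:
                consumer.input[idx] = inp
    model.graph.node.remove(identity_node)
```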
