[Relay] QLinearMatMul allows 1D weight_scale, weight_zero_point inputs#9946
Closed
yuanfz98 wants to merge 66 commits into apache:main from
Conversation
AndrewZhaoLuo
requested changes
Jan 24, 2022
```python
        return _op.squeeze(x)
    else:
        if force_assert:
            assert num_elem == 1, "Cannot squeeze tensor shape {} to scalar form.".format(x_shape)
```
Contributor
In this code path, the function will return None.
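The implicit-None issue can be reproduced outside TVM. The sketch below mirrors the reviewed logic with a simplified stand-in (the function name and return value are illustrative, not the actual TVM helper): when `force_assert` is false and the shape is not squeezable, no branch returns, so Python falls through and the caller silently receives `None`.

```python
def ensure_scalar_shape(x_shape, force_assert=False):
    # Simplified sketch of the reviewed logic, not the TVM source.
    num_elem = 1
    for dim in x_shape:
        num_elem *= dim
    if num_elem == 1:
        return "squeezed"  # stands in for _op.squeeze(x)
    else:
        if force_assert:
            assert num_elem == 1, "Cannot squeeze tensor shape {} to scalar form.".format(x_shape)
        # No return here: Python falls through and the function yields None.

print(ensure_scalar_shape((2, 3)))  # None -- the bug the reviewer points out
```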
```python
        if force_assert:
            assert num_elem == 1, "Cannot squeeze tensor shape {} to scalar form.".format(x_shape)
        else:
            return x
```
Contributor
This changes the behavior of ensure_scalar_shape -- now it can return non-scalar elements. Instead, in QLinearMatMul's try_resolve_to_const, I would simply wrap the ensure_scalar_shape calls with a check that the number of elements is 1.
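The suggestion above can be sketched as follows. This is a hedged illustration only: the dict-based tensors and the function bodies are stand-ins, not the TVM API, and `allow_non_scalar` is a hypothetical parameter name.

```python
import math

def ensure_scalar_shape(x):
    # Stand-in: in TVM this squeezes a 1-element tensor to scalar form.
    # Its invariant stays intact: it only ever produces scalars.
    assert math.prod(x["shape"]) == 1
    return {"shape": (), "data": x["data"]}

def try_resolve_to_const(x, allow_non_scalar=False):
    # Push the element-count check out to the caller instead of adding
    # a flag to ensure_scalar_shape that breaks its invariant.
    if math.prod(x["shape"]) == 1:
        return ensure_scalar_shape(x)   # always a scalar, as before
    if allow_non_scalar:
        return x                        # e.g. a 1D per-channel scale
    raise ValueError("Cannot resolve shape {} to a scalar".format(x["shape"]))

# A 1D per-channel scale passes through unchanged:
t = try_resolve_to_const({"shape": (3,), "data": [1, 2, 3]}, allow_non_scalar=True)
```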
```diff
-def try_resolve_to_const_scalar(x, dtype_override=None):
+def try_resolve_to_const(x, dtype_override=None, allow1D=False):
     x2 = try_resolve_var_to_const(x, params)
     x3 = ensure_scalar_shape(x2)
```
Contributor
Echoing the above: instead of adding a new flag to ensure_scalar_shape that breaks its invariants, take the check you wrote and move it out here.
Contributor
Also, for each file, we use
* [Caffe Frontend] supporting group > 1 cases for Deconv op
  - Handling group > 1 cases, assuming group == output channels
  - Simply decomposed into Relay split, conv2d_transposed, and multi-leveled concatenate ops
  - Added some test cases
  Signed-off-by: zotanika <zotanika@gmail.com>
* [Caffe Frontend] amending a test case for Deconv op
  Signed-off-by: zotanika <zotanika@gmail.com>
* explicit importing tvm.testing
* changing split axis to 0, according to PR apache#9336
…8136)
* [Caffe Frontend] adding Reduction op
* reformatting Reduction op test script
* reformatting Reduction test script
* [Caffe frontend] Reduction op
  - adding more test cases; handling '0 < axis < num_axes - 1' case to give the result equivalent to Caffe framework
  - skipping Relay multiplication if coeff is 1
  Signed-off-by: zotanika <zotanika@gmail.com>
* linting test script
* linting
* [Caffe Frontend] Supporting multiple grouped (channel-wise) Deconv op
  - Handling group > 1 cases, assuming group == output channels
  - Decomposed into Relay split, transposed conv, and multi-leveled concatenation.
  - Added some test cases.
  Signed-off-by: zotanika <zotanika@gmail.com>
* [Caffe Frontend] supporting variable number of inputs for Eltwise
  - extra handling of rest inputs for PROD, SUM, MAX operations
  - extra testcases
  Signed-off-by: zotanika <zotanika@gmail.com>
* formatting fix
* [Caffe Frontend] reverting codes related Reduction for splitting PR
* Revert "[Caffe Frontend] Supporting multiple grouped (channel-wise) Deconv op"
  This reverts commit 43e25e5.
* instant fix against docker format error
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
* [microNPU] Remove remaining UnsupportedLayout checks
  In apache#9508 the decision was made to remove the UnsupportedLayout exception and the checks that throw it; this PR cleans up some that remained.
  Change-Id: I83bfe233381b83af886343c9569db753e33f9059
* fix lint
  Change-Id: I67c1a5371f0b2e51b6cd39435ef4073d8d17af51
* [microNPU][2c] Initial Performance Model
  * Added the pre-computed performance modelling per block.
  * Added the aggregation of cycles given a stripe config.
  * Implemented the op-specific performance code for conv2d.
  * Created a DeviceConfig class to hold constant performance-related data that is dependent on the accelerator configuration.
  * Added generation of all valid block configs. This is pre-computed and given as an argument when constructing EthosuParts.
  * Implemented selection of the block config that gives the least amount of data read given a StripeConfig.
* Add test guards
* Extended block config testing
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
…_shape_from_cudnn` (apache#9948)
* Introduce SetConvdescriptors to refactor cudnn/conv_forward.cc
* more refactor
* remove cudnn get output
* cpplint
* [microNPU] Add support for scalar values
  PR apache#9515 enabled support for scalar constants, but didn't consider the case of a scalar value where the underlying constant data does not have a shape, i.e. `constant.shape == []`. See the test case for a visual difference when the scalar value is 1.
  Change-Id: Id7a238cb5bf999dd5a8428c097202f9fb940a5f0
* Fix failing test by removing constant
  Before this PR, scalar constants were handled differently, so this test was able to pass. Now that scalar constants are handled in the same manner as tensor constants, the test fails since unexpected TIR is produced in the compilation pipeline. Since the Relay used in this test case is not expected to be produced by higher levels of the compiler, removing this constant for now.
  Change-Id: I4ea5155778809041339e6faac05af3f72c3e3ea5
* clean up finding tensor from inputs
  Change-Id: Ideccf84f8c9149148ff23e2406229cf637c982a3
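The `constant.shape == []` distinction can be illustrated with NumPy (an illustration of the shape semantics only, not the microNPU code): a true scalar has an empty shape, while a 1-element tensor has rank 1, so code that indexes into the shape works for one but not the other.

```python
import numpy as np

scalar = np.array(1)    # shape () -- no dimensions at all
tensor = np.array([1])  # shape (1,) -- a rank-1 tensor with one element

print(scalar.shape)  # ()
print(tensor.shape)  # (1,)

# Code that assumes a non-empty shape (e.g. reads shape[0]) works for
# `tensor` but raises IndexError for `scalar`:
assert tensor.shape[0] == 1
```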
These are required for running the demos under ci_qemu in combination with Zephyr
… modules parameters map. (apache#9846)
* [Runtime][Pipeline executor] Global parameters group name and runtime modules parameters map.
  Solution: To support on-the-fly parameter setting for each runtime module in the pipeline executor, we create a feature that uses a global parameters group name to map the runtime module parameters; after such a mapping is created, the user can set parameters on the fly using the parameters group name.
  trigger build. fix ut issue. polish comments.
  Update python/tvm/contrib/pipeline_executor.py (several fixups)
  Update src/runtime/pipeline/pipeline_executor.h
  Update src/runtime/pipeline/pipeline_struct.h
  address review comments.
* Update python/tvm/contrib/pipeline_executor.py
* fix pylint issue.
Co-authored-by: Cody Yu <comaniac0422@gmail.com>
* jenkinsfile and one test
* formatting
* switch to proper repo for docker
* fix missing - with _
* jostle
* upgrade to latest images
* jostle ci
* update with official images
* jostle ci
Fused operations appended to conv2d for int8 were computed in a separate
loop from the main conv2d computation:
```
for i in ... parallel
for j in ...
accumulator = 0
for k in ..
vectorized_multiply_add(accumulator, data, kernel)
out = accumulator
for k in ..
out = out + fused subsequent ops
```
This patch moves the fused ops one more loop nesting inwards to get
```
for i in ... parallel
for j in ...
accumulator = 0
for k in ..
vectorized_multiply_add(accumulator, data, kernel)
out = accumulator + fused subsequent ops
```
On quantized mobilenetv2, this results in approximately a 30% speedup.
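The transformation can be sketched in plain Python (illustrative only; the actual change lives in the TVM x86 int8 schedule, and `bias` here stands in for arbitrary fused subsequent ops). Fusing the epilogue into the inner loop applies it while the accumulator is still hot, instead of making a second pass over the output buffer.

```python
def conv_then_fused(data, kernel, bias):
    # Before: the fused op runs in a separate loop over the output.
    out = [0] * len(data)
    for j in range(len(data)):
        acc = 0
        for k in range(len(kernel)):
            acc += data[j] * kernel[k]
        out[j] = acc
    for j in range(len(data)):          # second pass over `out`
        out[j] = out[j] + bias
    return out

def conv_fused_inward(data, kernel, bias):
    # After: the fused op is applied as each accumulator is finalized,
    # avoiding the extra traversal of the output buffer.
    out = [0] * len(data)
    for j in range(len(data)):
        acc = 0
        for k in range(len(kernel)):
            acc += data[j] * kernel[k]
        out[j] = acc + bias
    return out

# Both variants compute the same result:
assert conv_then_fused([1, 2], [3, 4], 5) == conv_fused_inward([1, 2], [3, 4], 5)
```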
Type annotations don't do anything; the type conversion needs to be explicit.
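A quick standalone demonstration of that point (the function is a made-up example, not the code in this PR): a Python annotation documents intent but performs no conversion, so the cast must be written out.

```python
def scale(x: float) -> float:
    # The annotation does NOT convert x; whatever is passed in is used as-is.
    return x * 2

print(scale("3"))         # "33" -- the str is passed through unchanged
print(scale(float("3")))  # 6.0 -- the conversion must be explicit
```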
Hello,
This PR fixes #9908. Since `_qnn.op.dense` can already accept 1D vectors (tvm/python/tvm/relay/qnn/op/qnn.py, line 552 in 6eb4ed8), the ONNX QLinearMatMul converter is relaxed to allow 1D weight_scale and weight_zero_point inputs.
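For context, 1D weight_scale / weight_zero_point means per-column (per-channel) quantization parameters. A hypothetical sketch of the dequantization arithmetic, with invented values (this is not the PR's test code, which is omitted above):

```python
# Each output column j of the quantized weight has its own scale and
# zero point, so dequantization is column-wise:
#   w[i][j] = (weight_q[i][j] - w_zp[j]) * w_scale[j]
weight_q = [[10, 20], [30, 40]]  # quantized 2x2 weight (invented values)
w_scale = [0.5, 0.25]            # one scale per column
w_zp = [2, 4]                    # one zero point per column

weight = [[(q - w_zp[j]) * w_scale[j] for j, q in enumerate(row)]
          for row in weight_q]
print(weight)  # [[4.0, 4.0], [14.0, 9.0]]
```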
In #9908, the expected behaviour is tested by the following code:
Output:
Thank you, reviewers, for your time, and have a nice day!
cc @alnah005 @cconvey