
[Relay] ConvertLayout pass. #4335

Merged
merged 1 commit into apache:master from anijain2305:convert_layout on Dec 26, 2019

Conversation

anijain2305 (Contributor) commented Nov 14, 2019

RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009

Refactored to share code between the AlterOpLayout and ConvertLayout passes.

There is considerable refactoring. The following description might help:

  • The older alter_op_layout.h is renamed to infer_layout_util.h and contains only the FInferCorrectLayout-related methods.
  • transform_layout.h now holds several methods from the older alter_op_layout.cc. The idea is to put the code shared between the AlterOpLayout and ConvertLayout passes in transform_layout.h.
  • Both AlterOpLayout and ConvertLayout use FForwardRewrite. FForwardRewrite accepts a context, and this context now has a transform function that is specialized separately for ConvertLayout and AlterOpLayout. For example, the ConvertLayout context-transform function extracts the op's "FTVMConvertOpLayout" attribute and sends the user-supplied desired layout to the packed function.
@anijain2305 anijain2305 force-pushed the anijain2305:convert_layout branch from 1fb5cc4 to a066107 Nov 14, 2019
yzhliu (Member) commented Nov 14, 2019

I'm ok with deprecating plevel. @tqchen what's your thought?

@anijain2305 anijain2305 force-pushed the anijain2305:convert_layout branch 2 times, most recently from 1a764d8 to 76d2092 Nov 14, 2019
tqchen (Member) commented Nov 15, 2019

I feel live-patching the OpAttr is not a pattern we are looking for here. It would be great if we could factor out some common logic between AlterOpLayout and ConvertLayout (e.g., a base class), so ConvertLayout can directly inject the necessary functions.

anijain2305 (Author, Contributor) commented Nov 15, 2019

Ok, let me spend some more time thinking about this :)

@anijain2305 anijain2305 force-pushed the anijain2305:convert_layout branch from 76d2092 to b356a13 Nov 19, 2019
anijain2305 (Author, Contributor) commented Nov 19, 2019

@tqchen Refactored to share code between the AlterOpLayout and ConvertLayout passes.

There is considerable refactoring. The following description might help:

  • The older alter_op_layout.h is split into two files: infer_layout_util.h and transform_layout.h.
  • infer_layout_util.h contains the FInferCorrectLayout-related methods.
  • transform_layout.h now holds several methods from the older alter_op_layout.cc. The idea is to put the code shared between the AlterOpLayout and ConvertLayout passes in transform_layout.h.
  • Both AlterOpLayout and ConvertLayout use FForwardRewrite. FForwardRewrite accepts a context, and this context now has a transform function that is specialized separately for ConvertLayout and AlterOpLayout. For example, the ConvertLayout context-transform function extracts the op's "FTVMConvertOpLayout" attribute and sends the user-supplied desired layout to the packed function.
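(For illustration: a minimal Python sketch of the kind of per-op handler the "FTVMConvertOpLayout" attribute points at. The registration name and signature follow later TVM docs for this pass; the version in this PR passed a single desired-layout string rather than the per-input list shown here, and TVM already ships a conv2d handler, so treat this purely as a sketch of the mechanism.)

```python
from tvm import relay
from tvm.relay.op import op as reg

# Sketch of an FTVMConvertOpLayout handler for nn.conv2d. The ConvertLayout
# context-transform function looks this packed function up on the op and
# passes the user-supplied desired layout(s) to it.
@reg.register_convert_op_layout("nn.conv2d")
def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
    data, weight = inputs
    new_attrs = dict(attrs)
    desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
    if desired_data_layout == "NCHW":
        new_attrs["data_layout"] = "NCHW"
        # "default" lets the handler pick the matching kernel layout.
        new_attrs["kernel_layout"] = (
            "OIHW" if desired_kernel_layout == "default" else desired_kernel_layout
        )
        return relay.nn.conv2d(data, weight, **new_attrs)
    # Returning None leaves the call in its existing layout.
    return None
```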


anijain2305 (Author, Contributor) commented Nov 21, 2019

@tqchen @yzhliu @zhiics Let me know if you have some cycles to review this. It would be good to know if you think the high-level design makes sense.

anijain2305 (Author, Contributor) commented Nov 25, 2019

Ping for review :)

yzhliu (Member) commented Nov 25, 2019

will do shortly

yzhliu (Member) left a comment

high level looks good to me.

```cpp
 * \param new_args The traversed/recursed args to the call.
 * \return The new Call after calling the packed func.
 */
Call GetCallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) const {
```

yzhliu (Member) commented Nov 26, 2019

How about naming it AlterCall?

anijain2305 (Author, Contributor) commented Nov 26, 2019

This function name is shared between Alter and Convert, but I will try to come up with a better name.

```cpp
 */

/*!
 * Copyright (c) 2019 by Contributors
```

yzhliu (Member) commented Nov 26, 2019

Remember to remove all the copyright headers.


```cpp
/*!
 * Copyright (c) 2019 by Contributors
 * \file transform_layout.h
```

yzhliu (Member) commented Nov 26, 2019

It would be better to move the implementation into a .cc file.

anijain2305 (Author, Contributor) commented Nov 26, 2019

The implementation is shared between AlterOpLayout and ConvertLayout, so it has to be in the .h file.

yzhliu (Member) commented Dec 10, 2019

Well, I feel the .h can contain just the data structures and function signatures, while the large function implementations can be moved into a corresponding transform_layout.cc.

anijain2305 (Author, Contributor) commented Dec 23, 2019

I checked again; most of the functions are templated, requiring definition in the header file. What do you think?

@anijain2305 anijain2305 force-pushed the anijain2305:convert_layout branch from b356a13 to ad222fc Nov 27, 2019
@anijain2305 anijain2305 changed the title from "[Relay][WIP] ConvertLayout pass." to "[Relay] ConvertLayout pass." Nov 27, 2019
@anijain2305 anijain2305 force-pushed the anijain2305:convert_layout branch 3 times, most recently from 3b6010b to 0bbee0b Nov 27, 2019
anijain2305 (Author, Contributor) commented Nov 28, 2019

@yzhliu This is ready for review. I made a couple more refactorings that keep the high-level design the same, and also added docs.

@anijain2305 anijain2305 force-pushed the anijain2305:convert_layout branch from 0bbee0b to b45ae89 Nov 29, 2019
anijain2305 (Author, Contributor) commented Dec 1, 2019

Ping @yzhliu

yzhliu (Member) left a comment

Overall LGTM. Sorry I missed the reminder last week.



kice (Contributor) commented Dec 14, 2019

I have many transposes when using tensor cores due to NCHW; about 20% of the time is spent doing transposes.

So, I want an NCHW-to-NHWC converter.

anijain2305 (Author, Contributor) commented Dec 23, 2019

> I have many transposes when using tensor cores due to NCHW; about 20% of the time is spent doing transposes.
>
> So, I want an NCHW-to-NHWC converter.

That is doable. I will first work on getting this PR in, since it sets up the whole infrastructure. The NCHW-to-NHWC layout changes can then go in as a second PR.
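(For reference, a minimal sketch of how such a conversion would be requested once NHWC handlers exist. Module paths and the per-op dict follow present-day TVM; the version merged in this PR took a single desired-layout string instead, so treat the exact API as an assumption.)

```python
import tvm
from tvm import relay

# Minimal sketch, assuming nn.conv2d has an FTVMConvertOpLayout handler
# that understands NHWC (the follow-up PR mentioned above).
def convert_to_nhwc(mod):
    # Map op name -> [desired data layout, desired kernel layout].
    desired_layouts = {"nn.conv2d": ["NHWC", "default"]}
    seq = tvm.transform.Sequential([
        relay.transform.RemoveUnusedFunctions(),
        relay.transform.ConvertLayout(desired_layouts),
    ])
    # ConvertLayout is registered at opt_level 3, so raise the context level.
    with tvm.transform.PassContext(opt_level=3):
        return seq(mod)
```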

@anijain2305 anijain2305 force-pushed the anijain2305:convert_layout branch from b45ae89 to f930b23 Dec 23, 2019
yzhliu approved these changes Dec 26, 2019
@yzhliu yzhliu merged commit 73dda6b into apache:master Dec 26, 2019
1 check passed: continuous-integration/jenkins/pr-merge (This commit looks good)
yzhliu (Member) commented Dec 26, 2019

Thanks @anijain2305, this is merged. Sorry for the delay.

anijain2305 (Author, Contributor) commented Dec 26, 2019

Thanks @yzhliu
No worries about the delay :)

zhiics added a commit to zhiics/tvm that referenced this pull request Dec 31, 2019
zhiics added a commit to neo-ai/tvm that referenced this pull request Jan 9, 2020
* Change upstream url

* Fix bias_add gradient (apache#4516)

* Fix bias_add gradient

A change caused collapse_sum_like to reject implicit dimension
broadcasting for bias_add gradient, so switch to explicit sum reduction
on the non-bias axis dimensions.

* Lint fix

* [Bugfix][Frontend][TFlite] Fix wrong function call in TANH tests (apache#4517)

* Replace sigmoid() with tanh() in tests for TANH

* Fixed extra reshape parameter bug. (apache#4524)

* Use the best tuner possible (apache#4397)

* Use the best tuner possible

* Add comment denoting availability of better tuners

* Fix typos and wording

* [ir] use DataType instead of Type for readability because Type has been deprecated (apache#4513)

* add bfloat16 typeflag support (apache#4525)

* fix empty config caused KeyError (apache#4520)

* fix onnx shape dtype (apache#4528)

* fix crash issue in tsim backend (apache#4527)

* PIL is deprecated and should be replaced with pillow (a fork of PIL) (apache#4533)

Change-Id: If2075df5475505f2da87dae7145af5a7ab83d8a4

* [Relay] External codegen (apache#4482)

* Update legacy places from nnvm to relay. (apache#4535)

* Update legacy places from nnvm to relay.

This PR prepares the current mainline to remove nnvm compiler dep.

* remove legacy stage

* Implement 1d deconvolution (apache#4476)

* [relay][op] add expand op (from ONNX) to relay frontend (apache#4483)

* Add Expand to onnx.py

* add test function for expand

* Fix a onnx frontend test

* Add tests for the value itself instead of shape only on test_expand

* Cleaned up some unnecessary modifications.

* [TOPI] Allow batch matmul to be fused into injective ops (apache#4537)

* [TOPI] Fixed nms max_output_size loop (apache#4541)

One of the loops in hybrid_nms used for
performing the max_output_size reordering
was incorrectly designated as parallel
resulting in incorrect behaviour. This patch
changes that loop to a serial loop.

Change-Id: I97184f5887f5f028d8ab339fa2808eb7630a4017

* [DOCS] Mention Ninja build system in install/from_source.rst (apache#4554)

* [DOCS] Mention Ninja build system in install/from_source.rst

* Address comments

* [PYTHON][FFI] Cythonize NDArray.copyto (apache#4549)

* [PYTHON][FFI] Cythonize NDArray.copyto

* Cythonize the shape property

* vm external codegen (apache#4544)

* [COMMUNITY] @cchung100m -> reviewer (apache#4557)

* [VTA] improved virtual memory mapping (apache#4545)

* [VTA] improved virtual memory mapping

* Update virtual_memory.cc

* [IR] fix style in ir_mutator and ir_visitor (apache#4561)

* [RUNTIME][VULKAN] Fix compiler warning (apache#4559)

* [REFACTOR][DTYPE] Isolate dtype to runtime (apache#4560)

dtype.h -> runtime/data_type.h

Changes:
- Rename all old reference of tvm::Type to DataType
- ExprNode.type -> ExprNode.dtype
- Expr.type() -> Expr.dtype()
- Change Expr related functions to expr_operator.
  - DataType::min() -> min_value(DataType)
  - DataType::max() -> max_value(DataType)
- Move type constructor Int, UInt, Float, Handle, Bool into DataType.
  - Int(bits) -> DataType::Int(bits)
  - UInt(bits) -> DataType::UInt(bits)

* Support standardize runtime module (apache#4532)

* [Relay][Frontend][ONNX] Support auto_pad in Conv and ConvTranspose (apache#4563)

* [TEST] Remove nnvm related code in topi and test script (apache#4562)

* [TEST] Remove nnvm related code in topi and test script

* Remove docs dep

* [Relay] add max_pool3d in relay and TF converter (apache#4551)

* [Relay] add max_pool3d in relay and TF converter

* fix comments

* Remove nnvm (apache#4565)

* [VTA][Chisel] End-to-end Inference with Chisel VTA (apache#4574)

* [VTA][Chisel] End-to-end Inference with Chisel VTA

* Update TensorAlu.scala

* remove unnecessary cast to int32 (apache#4573)

* Fix llvm-enabled build by adding missing intrinsics headers (apache#4575)

* [DEPRECATION] Remove NNVM compiler (apache#4571)

* Remove NNVM compiler

* [Relay/Topi][Op] Added native DepthToSpace and SpaceToDepth Operators (apache#4566)

* Added tvm function stencil for subpixel operations to topi.

* Topi subpixel operators added and tested.

* Added subpixel attrs.

* Added depth_to_space relay attributes.

* depth_to_space fully working.

* Fixed NHWC shape bug.

* SpaceToDepth in and all tests passing.

* lint fixes.

* Added string include

* Fixed topi formatting.

* Added DCR/CDR mode to depthtospace operator.

* [DOC] fix doc in api.py (apache#4580)

* [DEPRECATION] Cleanup legacy verilog support (apache#4576)

This PR cleans up the left over code for legacy verilog support which was experimental.
The new hardware backend path is now supported by VTA via TSIM.

* [RUNTIME] Remove Extension VTable in favor of Unified Object system. (apache#4578)

Before the unified object protocol, we supported passing additional
extension objects around by declaring a type as an extension type.
The old extension mechanism requires the types to register their
constructor and deleter to a VTable and does not enjoy the benefit of the
self-contained deletion property of the new Object system.

This PR upgrades the extension example to make use of the new object system
and removed the old Extension VTable.

Note that the register_extension function on the python side continues to work
when the passed argument does not require explicit container copy/deletion,
which covers the current use cases of the extension mechanism.

* Some Windows and MSVC fixes (apache#4569)

* fix python exception creation in Windows

* better string conversion for msvc

* fix cpp style issue

* [NEWS] add v0.6 release (apache#4558)

* [NEWS] add v0.6 release

* remove link prefix

* fix issue number

* [DOCS]fix typos in autotvm tutorial (apache#4585)

* [Quantization, Calibrate] Fix context creation when current_target is explicitly set (apache#4582)

* [Container] Fix NDArray SaveDLTensor declaration and implementation signature different (apache#4586)

* [TOPI][AutoTVM] NHWC conv2d templates for ARM (apache#3859)

* [AutoTVM][TOPI] NHWC conv2d templates (spatial pack) for ARM

As some frontends (tflite for example) are using NHWC as the default
layout, we are enabling NHWC schedule templates in TOPI and AutoTVM.

* some comments fix

* [FIX][TOPI][X86] schedule dense pack (apache#4539)

* [Relay] Convert Layout Pass. (apache#4335)

* [Relay][AlterLayout] Broadcast with scalar shape (apache#4577)

* [TOPI] add 3D upsampling Op. (apache#4584)

* [TOPI] add 3D upsampling Op.

* fix lint issues

* change align_corners to coordinate_transformation_mode

* fix resize3d half_pixel

* make a simple function and clean up trilinear_resize3d_python

* fix doc

* [Runtime] add necessary const qualifier for NDArray container of parameters (apache#4590)

* [autotvm] fix typos in comment (apache#4591)

* fix tf.compat.v1 issue for tf version <=1.12 (apache#4593)

* [FRONTEND][TF] conv2d_transpose 'SAME' support kernel more than 1x1 (apache#4484)

* [FRONTEND][TF] conv3d_transpose 'SAME' support kernel more than 1x1

* revised as per review comments

* add more fallback workarounds to make all tests pass

* [GraphRuntime] Support parameter out in the graph runtime debug (apache#4598)

* [GraphRuntime] Support parameter out in the graph runtime debug

* Dummy commit to trigger build

* [Perf] Add CublasLt extern support for better Igemm performance (apache#4550)

* cublaslt added

* fix lint

* address comments

* address more comments

* Trigger CI

* Trigger CI

* fix codegenc (apache#4597)

* [REFACTOR][RUNTIME] Update NDArray use the Unified Object System (apache#4581)

* [REFACTOR][RUNTIME] Move NDArray to Object System.

Previously NDArray has its own object reference counting mechanism.
This PR migrates NDArray to the unified object protocol.

The calling convention of NDArray remained intact.
That means NDArray still has its own type_code and
its handle is still DLTensor compatible.

In order to do so, this PR added a few minimum runtime type
detection in TVMArgValue and RetValue only when the corresponding
type is a base type(ObjectRef) that could also refer to NDArray.

This means that even if we return a base reference object ObjectRef
which refers to the NDArray. The type_code will still be translated
correctly as kNDArrayContainer.
If we assign a non-base type(say Expr) that we know is not compatible
with NDArray during compile time, no runtime type detection will be performed.

This PR also adopts the object protocol for NDArray sub-classing and
removed the legacy NDArray subclass protocol.
Examples in apps/extension are now updated to reflect that.

Making NDArray as an Object brings all the benefits of the object system.
For example, we can now use the Array container to store NDArrays.

* Address review comments

* [Relay][Convert Layout] Handling batch norm layout change. (apache#4600)

* [relay][refactor] Cache Op::Get in passes to reduce lookup overhead (apache#4594)

* Refactor to use IsOp utility

* retrigger CI

* Update dmlc_tvm_commit_id.txt

* disable one test_batch_norm unit test for now to check CI

* enable test_batch_norm

Co-authored-by: SWu <SWu@users.noreply.github.com>
Co-authored-by: Ina Dobreva <55383260+inadob@users.noreply.github.com>
Co-authored-by: Josh Fromm <jwfromm@uw.edu>
Co-authored-by: miheer vaidya <v.miheer@gmail.com>
Co-authored-by: Liang ZOU <liang.d.zou@gmail.com>
Co-authored-by: YixinBao <yixin.bao@intel.com>
Co-authored-by: Cody Yu <comaniac0422@gmail.com>
Co-authored-by: masahi <masahi129@gmail.com>
Co-authored-by: Liangfu Chen <liangfu.chen@icloud.com>
Co-authored-by: lhutton1 <35535092+lhutton1@users.noreply.github.com>
Co-authored-by: Tianqi Chen <tqchen@users.noreply.github.com>
Co-authored-by: Alex Gladkov <gladkov_alex@yahoo.com>
Co-authored-by: Takato Yamada <tkclimb0911@gmail.com>
Co-authored-by: Haichen Shen <shenhaichen@gmail.com>
Co-authored-by: mbarrett97 <55580676+mbarrett97@users.noreply.github.com>
Co-authored-by: Hideto Ueno <uenoku.tokotoko@gmail.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Zhao Wu <wuzhaozju@gmail.com>
Co-authored-by: Neo Chien <cchung100m@cs.ccu.edu.tw>
Co-authored-by: Yong Wu <55wuyong@163.com>
Co-authored-by: Dmitri Makarov <dmakarov@users.noreply.github.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: kice <wslikerqs@gmail.com>
Co-authored-by: Yizhi Liu <liuyizhi@apache.org>
Co-authored-by: Wang Yucheng <wyc91543@163.com>
Co-authored-by: 王振华(Zhenhua WANG) <i@jackwish.net>
Co-authored-by: deepIgnorance <zhengsizemax@outlook.com>
Co-authored-by: Animesh Jain <anijain@umich.edu>
Co-authored-by: optima2005 <56945758+optima2005@users.noreply.github.com>
Co-authored-by: zhuochen <zhuochen@outlook.com>
Co-authored-by: Leyuan Wang <laurawly@gmail.com>
zhiics added a commit to neo-ai/tvm that referenced this pull request Jan 11, 2020