
[Relay][VM] JIT #4129
Closed · wants to merge 7 commits

Conversation

MarisaKirisame (Contributor)
Right now the Relay VM does not support any/symbolic shapes with dense, because the dense schedule/compute requires the input shapes to be known as constants beforehand.

This PR adds a polymorphic inline JIT: if a function cannot be lowered at compile time, then for each new shape encountered at runtime we compile a TVM kernel on the fly and invoke it.
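For context, a minimal sketch (not part of the PR) of the kind of program this targets: a dense whose batch dimension is relay.Any(), which the static pipeline cannot schedule ahead of time. The build/run calls below are illustrative and may not match the exact TVM API at the time of this PR.

import numpy as np
import tvm
from tvm import relay

# Dense with a symbolic (Any) batch dimension: the static pipeline cannot pick
# a concrete dense schedule here, which is what motivates the runtime JIT.
batch = relay.Any()
data = relay.var("data", shape=(batch, 784), dtype="float32")
weight = relay.var("weight", shape=(128, 784), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight],
                                            relay.nn.dense(data, weight)))

# With the VM executor, each concrete batch size seen at runtime (e.g. 16, 32)
# would trigger one JIT compilation of a shape-specialized dense kernel.
ex = relay.create_executor("vm", mod=mod, target="llvm")
out = ex.evaluate()(np.random.rand(16, 784).astype("float32"),
                    np.random.rand(128, 784).astype("float32"))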

@wweic @jroesch @junrushao1994 @icemelon9 @vinx13 please review.

MarisaKirisame changed the title from [Relay] [VM] JIT to [Relay][VM] JIT on Oct 15, 2019
@@ -50,11 +51,13 @@
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)
_reg.register_dynamic_compute("split", True)
Contributor
Can ops automatically register dynamic compute? It seems like this would be a useful feature for every op.

Contributor Author

We can make it true by default, because the false case will break immediately; then we can annotate the exceptions.

Contributor

Personally, I think that's a little cleaner, but we can wait to see what other reviewers think :)

// Re-infer types for the function specialized to the concrete runtime shapes.
auto new_func = Downcast<Function>(InferType(new_func_untyped, Module(), GlobalVar()));
// Build a compile-engine cache key from the specialized function and target.
auto key = CCacheKeyNode::make(new_func, target);
CompileEngine ce = CompileEngine::Global(); // WHY cant I use engine_?
// JIT-compile the kernel for this concrete shape (or reuse a cached one).
auto jit_pf = ce->JIT(key);
Contributor
Just a thought: would it be worthwhile to include an option for users to specify a range for each dimension of a shape that can be Any, so we can pre-compile all of those? It would use more memory but save latency at runtime.

Alternatively, we could have the user supply a max value for each dimension that can be Any, then use that max value everywhere and pad when necessary. Thoughts?
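A rough sketch of the second (pad-to-max) idea, purely for illustration; the MAX_BATCH bound and the helper functions here are hypothetical and not part of this PR:

import numpy as np

MAX_BATCH = 64  # hypothetical user-supplied upper bound for the Any dimension

def pad_batch_to_max(x, max_batch=MAX_BATCH):
    # Pad the batch axis up to max_batch so one statically shaped kernel
    # (compiled for max_batch) can serve any smaller batch.
    n = x.shape[0]
    assert n <= max_batch, "runtime batch exceeds the declared maximum"
    pad = np.zeros((max_batch - n,) + x.shape[1:], dtype=x.dtype)
    return np.concatenate([x, pad], axis=0), n

def unpad(y, n):
    # Drop the padded rows from the kernel output.
    return y[:n]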

Contributor Author

I'm inclined to leave it as future work. It is indeed possible, but those approaches are more complex and come with different options (a range? a max value plus padding?). Meanwhile, this PR makes dense computable, is simple, and allows further extensions that do exactly those things. Sound good to you?

Contributor

Definitely, I think we can leave it for a later change.

@wweic (Contributor)

wweic commented Oct 16, 2019

I'm under the impression that for operators that accept/return dynamic-shape tensors, we just need to implement the shape function. Please correct me if I'm wrong. @icemelon9
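For reference, the shape-function pattern being referred to looks roughly like the registrations under python/tvm/relay/op; this is a simplified 2-D sketch, and the exact decorator and hybrid-script details may differ between TVM versions:

from tvm.relay.op import op as _reg
from tvm.hybrid import script

@script
def _dense_shape_func(data_shape, weight_shape):
    # dense output shape is (batch, units); batch comes from the data,
    # units from the first axis of the weight.
    out = output_tensor((2,), "int64")
    out[0] = data_shape[0]
    out[1] = weight_shape[0]
    return out

# False => the shape function only needs the input shapes, not the input data.
@_reg.register_shape_func("nn.dense", False)
def dense_shape_func(attrs, inputs, _):
    return [_dense_shape_func(inputs[0], inputs[1])]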

@MarisaKirisame (Contributor Author)

@wweic Yes, however some schedules do not accept dynamic-shape tensors, such as nn.dense.

@kevinthesun (Contributor)

How do we pick a decent dense schedule at runtime? My feeling is that the JIT compilation done in this PR can be covered by AOT compilation, as in #4118.

@wweic (Contributor)

wweic commented Oct 16, 2019

@MarisaKirisame Got it. I think it might be better to fix the schedule for nn.dense, since the design for Any support is to add a shape function to each operator. We should try to be consistent in the implementation. What do you think?

@MarisaKirisame (Contributor Author)

@wweic I talked with @eqy, and he says the generic schedule is slow. This also allows us to tune a schedule for each specific input shape (so dense with a batch size of 16 vs. 32 need not share the same schedule at all).

The shape function is still needed no matter what; I don't get what you mean by that.

@kevinthesun (Contributor)

@MarisaKirisame I don't think tuning a dynamic-shape kernel is related to JIT.

@MarisaKirisame (Contributor Author)

@kevinthesun They both have their own merits.
For JIT, the problem is that the compiler needs to live in the runtime, so it can never be used on an embedded device. Also, if new shapes are constantly encountered (which I suspect won't happen), compilation will be a major runtime cost.
For dispatch, the problem is complexity, and if the dynamism falls into a few predefined cases, JIT will get the utmost precision.

Both approaches can definitely coexist, and we could have some LoweringStrategy option that picks which one to use.
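Purely to illustrate that last sentence (LoweringStrategy does not exist in the codebase; this is a hypothetical sketch of how the two paths could sit behind one knob):

from enum import Enum

class LoweringStrategy(Enum):
    # Hypothetical knob for how the VM handles shapes unknown at compile time.
    JIT = "jit"            # compile a concrete-shape kernel at runtime (this PR)
    DISPATCH = "dispatch"  # AOT-compile bucketed kernels and dispatch at runtime

def lower_dynamic_call(strategy, shape):
    # Sketch only: route a dynamically shaped call to one of the two mechanisms.
    if strategy is LoweringStrategy.JIT:
        return "jit_compile(shape={})".format(shape)
    return "select_bucketed_kernel(shape={})".format(shape)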

@MarisaKirisame (Contributor Author)

@kevinthesun We can't tune symbolic shapes, but with JIT there are no symbolic/any/dynamic shapes. All that's left are purely concrete shapes, and we can tune them individually.

@kevinthesun (Contributor)

> @kevinthesun We can't tune symbolic shapes, but with JIT there are no symbolic/any/dynamic shapes. All that's left are purely concrete shapes, and we can tune them individually.

But how do we know which shapes will come in? In AOT, we can split a symbolic axis into several buckets and AOT tune/compile a kernel for each bucket. I don't see how we can achieve this in JIT.
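For readers unfamiliar with the bucketing idea, a small illustrative sketch (the bucket boundaries and the dispatch helper are hypothetical, not taken from this PR or from #4118): the symbolic batch axis is rounded up to a bucket, and one AOT-compiled kernel per bucket is reused, padding when needed.

import bisect

# Hypothetical bucket boundaries for a symbolic batch axis.
BATCH_BUCKETS = [1, 8, 16, 32, 64]

def pick_bucket(batch):
    # Round the runtime batch size up to the nearest pre-compiled bucket.
    i = bisect.bisect_left(BATCH_BUCKETS, batch)
    if i == len(BATCH_BUCKETS):
        raise ValueError("batch size exceeds the largest AOT bucket")
    return BATCH_BUCKETS[i]

# At compile time, one dense kernel per bucket would be built and tuned.
# At runtime, inputs are padded up to pick_bucket(batch) and the matching
# pre-compiled kernel is invoked.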

@MarisaKirisame (Contributor Author)

@kevinthesun There will be two runs of the same program.
In the zeroth run, you don't tune and use the default schedule. Whenever this happens, TVM complains that it had to fall back to the default schedule (by printing to stdout). You record all of those workloads.

Before the first run, you tune the schedules according to that stdout output.
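A hedged sketch of that two-pass workflow using AutoTVM; the task-extraction and tuning calls below are assumptions about the API of that era and may need adjustment (e.g. the ops argument of extract_from_program):

import tvm
from tvm import autotvm, relay

# Pass 0 (not shown): run the program with untuned schedules. Every kernel that
# falls back to the default schedule makes TVM print a "Cannot find config for
# ... workload=..." warning to stdout; record the concrete shapes seen there.
recorded_batch_sizes = [16, 32]  # example of what pass 0 might have logged

# Between passes: tune a concrete-shape dense task for each recorded shape.
for batch in recorded_batch_sizes:
    data = relay.var("data", shape=(batch, 784), dtype="float32")
    weight = relay.var("weight", shape=(128, 784), dtype="float32")
    mod = tvm.IRModule.from_expr(
        relay.Function([data, weight], relay.nn.dense(data, weight)))
    tasks = autotvm.task.extract_from_program(mod["main"], target="llvm",
                                              params={}, ops=(relay.op.nn.dense,))
    for task in tasks:
        tuner = autotvm.tuner.XGBTuner(task)
        tuner.tune(n_trial=200,
                   measure_option=autotvm.measure_option(
                       builder=autotvm.LocalBuilder(),
                       runner=autotvm.LocalRunner(number=10)),
                   callbacks=[autotvm.callback.log_to_file("dense_tuning.log")])

# Pass 1: run again with the tuned log applied, so the JIT picks up the tuned
# schedules for those concrete shapes.
# with autotvm.apply_history_best("dense_tuning.log"): ...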

@kevinthesun (Contributor)

@MarisaKirisame The difficulty with dynamic-shape kernels is that we cannot predict which concrete shapes will come in at runtime. In the zeroth run, only a sample of concrete shapes comes in, and in the first run only that sample has been tuned. Many different shapes can arrive, so a lot of tuning would have to happen at runtime. We cannot tune every concrete shape, because of both time and memory complexity. Also, if we want to tune kernels, we should do it before compilation.

@MarisaKirisame (Contributor Author)

@kevinthesun It depends on the workload. For my current workload, only the batch_size is dynamic, and it is either 32 or 16, so the problem is easily solved.
One can also imagine converting from TorchScript to Relay (or something similar), where all of the 'dynamic' shapes take only a single value at runtime.
I agree that this is not a universal solution; however, it will always work (unlike the current VM implementation, which simply breaks on dynamic dense).
It also seems like there is a lot of work left in topi/tvm/relay to make graph dispatching work, so I think the best plan is to have this for now (which I need right now).
After graph dispatching is done, we can figure out (by benchmarking) whether a hybrid makes sense or one approach is just universally better, and then implement a hybrid strategy or swap this out. Does that sound good to you?

@icemelon (Member)

I have some concerns about this PR, as it is a temporary solution and is limited to your current workload. It's definitely possible for a workload to be fully dynamic, e.g. in batch size, sequence length, the output of the nms and arange ops, etc.
Besides, after this PR the VM runtime requires LLVM and TVM in order to JIT these kernels. This increases the runtime size, which is not desirable in many situations.
I think a better solution for your current case is to use cblas in the target, which can support dense with a symbolic shape.
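For reference, the cblas route looks roughly like the following (assuming TVM was built with BLAS support; the exact target string and the VM compile entry point are assumptions and may differ by version):

from tvm import relay

# With an external BLAS library, nn.dense is offloaded to a cblas sgemm call,
# which handles the batch dimension at runtime, so a symbolic batch size does
# not need a statically scheduled TVM kernel.
target = "llvm -libs=cblas"
# vm_exec = relay.vm.compile(mod, target=target)  # hypothetical usage; `mod`
#                                                 # is a Relay module with dense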

@kevinthesun (Contributor)

> @kevinthesun It depends on the workload. For my current workload, only the batch_size is dynamic, and it is either 32 or 16, so the problem is easily solved. [...]

This method only works for very limited cases, but introduces extra complexity and dependencies to the VM. I'm not quite sure this is the way we want to go.

@MarisaKirisame (Contributor Author)

MarisaKirisame commented Oct 16, 2019

@icemelon9 @kevinthesun The 'your current workload' includes all classical vision models (resnet/densenet/vgg) and all classical NLP models (treelstm/lstm, as they already use ADTs). That leaves only bert/transformer uncovered, which AFAIK does not exist in TVM yet.
Batch size is not fully dynamic: 99% of the time it will be the batch size specified by the user; the remaining 1% is len(dataset) mod batch_size (which is also static).
I am not sure that 'everything excluding transformers' is very limited. I am not an expert in ML, so I might be wrong, but in that case please tell me what other popular models besides bert/elmo/transformer I am missing.
It also does not add a dependency to the VM. If the compiler does not exist at runtime, it will fail at runtime, whereas the current approach fails at compile time. It is pretty much the same: a failure is a failure, isn't it?

@kevinthesun (Contributor)

kevinthesun commented Oct 16, 2019

> For my current workload, only the batch_size is dynamic, and it is either 32 or 16, so the problem is easily solved.

This looks like a very limited use case to me.

@MarisaKirisame (Contributor Author)

MarisaKirisame commented Oct 16, 2019

@kevinthesun Batch size is universal across all vision tasks.
In virtually every vision task, the batch size is almost fixed (because of validation/test/leftover batches).
Supporting all vision tasks is by no means 'very limited'.

My current workload is just a single data point among all the workloads this PR will enable.

@icemelon (Member)

The current failure doesn't mean we cannot support this in the future. We can have a better approach to supporting symbolic shapes. I don't think we should merge a half-baked solution that will be replaced or deprecated soon.

Also, we should move the discussion to https://discuss.tvm.ai

@kevinthesun (Contributor)

kevinthesun commented Oct 16, 2019

@MarisaKirisame

> in virtually every vision task, the batch size is almost fixed

It looks to me like your use case is to support a few fixed batch sizes. This definitely doesn't cover most actual use cases for dynamic shapes in practice.

@kevinthesun (Contributor)

Agreed, we should move this discussion to the forum or an RFC issue, since there are some fundamental issues to be resolved.

kevinthesun added the status: need RFC label on Oct 16, 2019
tqchen closed this on Dec 22, 2019
@tqchen (Member)

tqchen commented Dec 22, 2019

Closing due to inactive status for now.
