Dynamic constraints and NumberProxies #262
While working on my current prototype in #250, I realized that it unintentionally walked into the dynamic-shapes case via the operations inside slice (lightning-thunder/thunder/core/prims.py, line 3102 at b7154dc):

[AI] We need to allow language-based constraints to be injected and placed in the prologue trace. Right now I'm seeing some of those checks showing up in the compute trace, which could be excessive, but we should be able to const-fold and merge some of those checks as an optimization later.

For now we can just mark everything as constant shapes. Marking down the concrete example here: in the slice operation, you can see this happening.
Note to myself: there is some tricky shape usage in the trace that leads to control flow. #260 (comment)
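To make the note above concrete, here is a minimal, hypothetical sketch (the class names are illustrative, not thunder's actual API) of how shape usage in a trace can leak into Python control flow: a comparison on a proxied number can stay symbolic, but an `if` forces a concrete boolean, silently specializing the trace to the value seen at trace time.

```python
# Hypothetical sketch, not thunder's implementation: a proxied scalar
# whose comparison result gets consumed by Python control flow.

class NumberProxy:
    """Stands in for a dynamic scalar; records the trace-time value."""
    def __init__(self, name, value):
        self.name = name
        self.value = value

    def __gt__(self, other):
        # The comparison itself could remain symbolic...
        return BoolProxy(f"{self.name} > {other}", self.value > other)

class BoolProxy:
    def __init__(self, expr, value):
        self.expr = expr
        self.value = value
        self.leaked = False

    def __bool__(self):
        # ...but an `if` demands a concrete answer, baking the branch
        # taken at trace time into the compiled program.
        self.leaked = True
        return self.value

dim = NumberProxy("dim0", 4)
cond = dim > 2
if cond:                    # Python control flow consumes the proxy here
    branch = "big"
else:
    branch = "small"

print(branch, cond.leaked)  # big True
```

The `leaked` flag marks exactly the point where the trace stops being valid for other input values, which is why such branches need a guard in the prologue.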
Here is a case where the output is incorrect under the symbolic-values cache option:

```python
import thunder
import torch

def foo(flag):
    if flag > 5:
        return torch.ones(1)
    return torch.zeros(1)

jfoo = thunder.jit(foo, cache=thunder.CACHE_OPTIONS.SYMBOLIC_VALUES)

# Currently, the output is incorrect.
print(jfoo(6))  # tensor([1.])
print(jfoo(0))  # tensor([1.])

# pro_trace = thunder.last_prologue_traces(jfoo)[-1]
# trace = thunder.last_traces(jfoo)[-1]
# print(pro_trace)
# print(trace)
```

Currently, the prologue trace doesn't have any check on the input, so the second call wrongly reuses the trace compiled for the `flag > 5` branch. This works fine with the default cache option.
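Conceptually, the missing prologue check could look like the following hand-rolled sketch (this is illustrative only, not thunder's actual prologue or cache mechanism, and all names here are hypothetical): the compiled trace is only valid for inputs on the same side of the `flag > 5` branch, so the prologue must re-evaluate that condition and recompile when it flips.

```python
# Hypothetical sketch of a guarded cache. The compiled traces are
# stand-ins (plain lists instead of tensors); the point is the guard.

def compiled_for_flag_gt_5():
    return [1.0]   # stands in for torch.ones(1)

def compiled_for_flag_le_5():
    return [0.0]   # stands in for torch.zeros(1)

cache = {}  # branch-condition outcome -> compiled trace

def jit_foo(flag):
    key = flag > 5            # the guard the prologue must evaluate
    if key not in cache:      # recompile only when the guard flips
        cache[key] = compiled_for_flag_gt_5 if key else compiled_for_flag_le_5
    return cache[key]()

print(jit_foo(6))  # [1.0]
print(jit_foo(0))  # [0.0] -- correct, unlike the un-guarded trace above
```

With the guard in place, `jit_foo(0)` misses the cache entry for the `True` branch and compiles the other branch, instead of silently reusing the wrong trace.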
Thanks for bringing that up. Yes, it should have been able to do that. Let me move the example there.
🚀 Feature
We'd like thunder.jit to support dynamic constraints so that we don't bake every number in a program in as a compile-time constant. This should allow us to re-use compiled programs and avoid endless recompilation with dynamic shapes.
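The recompilation cost being described can be illustrated with a toy sketch (illustrative only, not thunder's caching logic; all names here are made up): when the concrete value of every scalar is part of the cache key, each distinct input triggers a fresh compile, whereas a symbolic key lets one compiled program serve all of them.

```python
# Hypothetical sketch contrasting value-keyed and symbolically-keyed caches.

compiles = 0

def compile_trace(n):
    """Stands in for an expensive compilation specialized to n."""
    global compiles
    compiles += 1
    return lambda: [0.0] * n

static_cache = {}

def jit_static(n):
    if n not in static_cache:          # keyed on the concrete value
        static_cache[n] = compile_trace(n)
    return static_cache[n]()

for n in (1, 2, 3, 4):
    jit_static(n)
print(compiles)  # 4 -- one compile per distinct input value

symbolic_cache = {}

def jit_symbolic(n):
    key = type(n)                      # a NumberProxy-style symbolic key
    if key not in symbolic_cache:
        symbolic_cache[key] = lambda m: [0.0] * m
    return symbolic_cache[key](n)

for n in (1, 2, 3, 4):
    jit_symbolic(n)
print(len(symbolic_cache))  # 1 -- a single program re-used for all inputs
```

The trade-off is that the symbolic program must carry any constraints it actually relied on (the subject of the issues listed below), since it can no longer assume a specific value.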
Pitch
We'd want to:
prototyping PRs:
- symbolic values cache option: Enable symbolic values #518
- symbolic values cache option by default: Number proxies #250

Issues:
- number proxy is no number: #272
- executor-specific caching rule: #263
- prim should be able to insert constraints to bake in static numbers: #463
- dynamic shape needs to be modeled in trace: #471
- (to be opened) utils.check on NumberProxy needs to be sanitized in prim/clang/torch
- NumberProxy inconsistency introduced by grad transform: #541 (will be addressed in #244)
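For the constraint-insertion item (#463), one way to picture the mechanism is the following sketch. This is a hypothetical illustration, not the design in #463: a prim that needs a static number records a check pinning the proxy to its trace-time value, and the prologue replays the recorded checks on every call.

```python
# Hypothetical sketch: prims record constraints at trace time; the
# prologue replays them to decide whether a cached trace is reusable.
# All names (bake_static, prologue_check) are illustrative.

constraints = []   # collected while tracing; emitted into the prologue

def bake_static(name, value):
    """A prim calls this to freeze a proxied number to its trace-time value."""
    constraints.append((name, value))
    return value

def prologue_check(inputs):
    # Re-run every recorded check; a mismatch means the cached trace
    # cannot be reused and a recompile is needed.
    return all(inputs[name] == value for name, value in constraints)

# Trace time: a reduction prim bakes in the axis length it saw.
dim0 = bake_static("dim0", 8)

print(prologue_check({"dim0": 8}))  # True: cached trace is reusable
print(prologue_check({"dim0": 9}))  # False: must recompile
```

Const-folding duplicate or implied checks, as mentioned above, would then amount to deduplicating and simplifying this recorded list before emitting the prologue.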
Progress
The NumberProxy inheritance PR has been merged: #286.

Currently working on enabling the symbolic values caching option to allow traces to handle dynamic scalar inputs (WIP in PR #250). This is going to be the next milestone. A simple prototype is working where NumberProxy is used to represent a dynamic scalar input to a trace as a number operand. I tried to test the water by setting symbolic values as the default cache option, and CI exploded; I'm still going through all the failures. Aside from minor code-logic patches here and there (since NumberProxy isn't widely used in the existing code base), one of the main challenges I'm seeing right now is the accidental exposure of dynamic shapes (see issues #471 and #463).

Ideally we should resolve #471 and model dynamic shape in our trace. That's going to be a longer endeavor to pull through, and I should get some help once we decide on a plan and start working on it.

Meanwhile, I think in the short term we could push for #463, where we'd just bake in static shapes & constraints for NumberProxy to avoid dynamic shapes arising from dynamic scalar inputs. I need to further evaluate this solution to figure out whether it's enough. Nevertheless, it's still a feature we might want for nvFuser integration; e.g. reduction axes need to be baked in anyway.
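The core idea of "NumberProxy as a number operand" mentioned above can be sketched roughly as follows (a minimal illustration of the inheritance approach, not thunder's actual class): a proxy that subclasses a number type keeps existing numeric code working, while still carrying a symbolic identity the trace can refer to.

```python
# Minimal hypothetical sketch of a number-like proxy. The class name
# IntProxy is illustrative; thunder's NumberProxy is more elaborate.

class IntProxy(int):
    """An int subclass carrying a symbolic name, so arithmetic keeps
    working while the trace can record the symbol instead of the value."""
    def __new__(cls, name, value):
        obj = super().__new__(cls, value)
        obj.name = name
        return obj

n = IntProxy("n", 3)
print(n + 1)               # 4 -- participates in arithmetic like a plain int
print(n.name)              # n -- but still identifiable as a symbolic input
print(isinstance(n, int))  # True
```

Note that plain arithmetic (`n + 1`) falls back to an ordinary `int`, which hints at why CI surfaced so many places where proxy-ness is silently dropped in code that wasn't written with NumberProxy in mind.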