Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Number proxies #250

Draft
wants to merge 87 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
87 commits
Select commit Hold shift + click to select a range
81abdf9
quick enabling proxy number
jjsjann123 Apr 22, 2024
4760d5a
disabling symbolic shape assert
jjsjann123 Apr 22, 2024
9a2ae70
hacky entry
jjsjann123 Apr 22, 2024
7dd515e
errr
jjsjann123 Apr 22, 2024
eb62a00
does it work for shapes?
jjsjann123 Apr 22, 2024
5b6ff2d
missing elif
jjsjann123 Apr 22, 2024
7ea52db
note
jjsjann123 Apr 23, 2024
9e4062c
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Apr 24, 2024
09f2405
trying out to remove Number type from NumberProxy
jjsjann123 Apr 26, 2024
99269b6
quick fix
jjsjann123 Apr 26, 2024
f1d4bd1
relaxing check_valid_length
jjsjann123 Apr 26, 2024
050e8bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
0a2bdd8
type not defined
jjsjann123 Apr 26, 2024
e81ebbf
patching
jjsjann123 Apr 26, 2024
1a2b6d3
fix WIP
jjsjann123 Apr 27, 2024
caa892b
comment out circular import
jjsjann123 Apr 27, 2024
6846ef4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2024
428eebb
patch
jjsjann123 Apr 27, 2024
9205c60
Merge remote-tracking branch 'jiej/number_proxies_is_not_a_number' in…
jjsjann123 Apr 27, 2024
4fd4c2e
fixing more tests
jjsjann123 Apr 27, 2024
894ba4e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2024
c930a11
import
jjsjann123 Apr 27, 2024
b373d8b
import
jjsjann123 Apr 27, 2024
89ba061
quick fixing tests
jjsjann123 Apr 28, 2024
77cb00f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 28, 2024
b142702
fixing grad maybe
jjsjann123 Apr 28, 2024
23ec660
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 28, 2024
f414c21
fix
jjsjann123 Apr 29, 2024
a881996
fixing grad
jjsjann123 Apr 29, 2024
8cd2987
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2024
b36bdf6
moving grad rule
jjsjann123 Apr 29, 2024
06cbd60
questionable patch
jjsjann123 Apr 29, 2024
5b5763a
Merge remote-tracking branch 'origin/main' into number_proxies_is_not…
jjsjann123 May 9, 2024
d154495
quick patching
jjsjann123 May 10, 2024
f3af55c
Merge branch 'main' into number_proxies_is_not_a_number
t-vi May 10, 2024
435c086
cleaning
jjsjann123 May 11, 2024
0f19b66
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2024
a3413fd
more cleaning
jjsjann123 May 11, 2024
c375abf
quick hack to make CI green
jjsjann123 May 11, 2024
4620b39
Merge remote-tracking branch 'jiej/number_proxies_is_not_a_number' in…
jjsjann123 May 11, 2024
6b9ac97
I think I'm fixing it
jjsjann123 May 11, 2024
0a4cf9b
Merge remote-tracking branch 'jiej/number_proxies_is_not_a_number' in…
jjsjann123 May 12, 2024
75a0ba5
Merge branch 'main' into number_proxies
jjsjann123 May 17, 2024
9cf08aa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 17, 2024
fb1b54d
Merge branch 'main' into number_proxies
jjsjann123 May 21, 2024
07dc56a
adding number like checks
jjsjann123 May 21, 2024
ecd7510
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2024
e37a311
quick fix on build
jjsjann123 May 21, 2024
0e607b2
Merge remote-tracking branch 'jiej/number_proxies' into number_proxies
jjsjann123 May 21, 2024
04cc9ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2024
688ce9f
disable implicit evaluation
jjsjann123 May 21, 2024
491ddb3
Merge remote-tracking branch 'jiej/disable_implicit_evaluate_of_numbe…
jjsjann123 May 22, 2024
c20e166
smoke test
jjsjann123 May 22, 2024
0f6cf9b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2024
f1cd374
fixing smoke test
jjsjann123 May 22, 2024
e549dbd
fixing reshape in nvfuserex
jjsjann123 May 22, 2024
4fdd912
Merge remote-tracking branch 'jiej/number_proxies' into number_proxies
jjsjann123 May 22, 2024
5c84029
quick fix
jjsjann123 May 22, 2024
e26417d
fixing nvfuser handling of sequence
jjsjann123 May 22, 2024
206f6a9
unwrap number proxy
jjsjann123 May 22, 2024
a0daeeb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2024
0274f00
clean up
jjsjann123 May 22, 2024
b200382
error message for regex
jjsjann123 May 22, 2024
9bfac26
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2024
a5e4f41
fixing fill_value check
jjsjann123 May 22, 2024
ec369dd
Merge remote-tracking branch 'jiej/number_proxies' into number_proxies
jjsjann123 May 22, 2024
1c00ccd
import fix
jjsjann123 May 22, 2024
0aac6a2
import
jjsjann123 May 22, 2024
1c76716
fixing flip check
jjsjann123 May 22, 2024
b1ca5a1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2024
4cb290b
quick patch
jjsjann123 May 23, 2024
b321d54
Merge remote-tracking branch 'jiej/number_proxies' into number_proxies
jjsjann123 May 23, 2024
3c87a64
quick patch
jjsjann123 May 23, 2024
6cdd428
fixing return
jjsjann123 May 23, 2024
760fc1f
fixing philox (mostly hacky patches)
jjsjann123 May 23, 2024
6cfcda3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
fa50267
typo
jjsjann123 May 23, 2024
0abab1b
typo
jjsjann123 May 23, 2024
781329e
fixing tests
jjsjann123 May 23, 2024
a804552
Merge branch 'proxify_patch' into number_proxies
jjsjann123 May 23, 2024
901897f
Merge remote-tracking branch 'jiej/number_proxies' into number_proxies
jjsjann123 May 23, 2024
19a34ea
fixing numberproxy trace mode check
jjsjann123 May 24, 2024
775aa74
Merge remote-tracking branch 'origin/main' into number_proxies
jjsjann123 May 25, 2024
597e873
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Jun 4, 2024
79a9884
fixing typo
jjsjann123 Jun 4, 2024
8fe43a9
back out from dynamic shape
jjsjann123 Jun 4, 2024
fe6f311
Merge branch 'main' into number_proxies
jjsjann123 Jun 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,9 @@ def proxify(self, value: WrappedValue) -> Any:
co: CACHE_OPTIONS = get_cache_option()
if co is CACHE_OPTIONS.CONSTANT_VALUES:
self.add_constraint((clang.check_tensor_shape_and_metadata, p_orig))
elif co is CACHE_OPTIONS.SYMBOLIC_VALUES:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is our medium term plan w.r.t. defaul caching? If we need this for correctly handling #231 , it would seem that symbolic values should be the default but that in turn would mean that we want to have it work for our supported use cases.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm hoping that #231 wouldn't need to the whole enable CACHE_OPTIONS.SYMBOLIC_VALUES thing. Looks like #231 has passed all CI, which feels promising.

The first step of this is to get number proxies to be plumbed through, I think that might be helpful for #231.
In terms of dynamic shape, we can slowly expanding its support.

Start off with allowing scalar input as number proxies, while still requiring tensor to be constant shape. I'm trying to figure out how/where to properly insert prologue_trace guard. Right now considering doing that from executor. i.e. nvfuser would require reduction dim(s) to be baked in as constant, while torchex doesn't care.

I'll try to come up with a design doc for review.

# TODO: establish guarding logic to allow non-broadcast shape change
self.add_constraint((clang.check_tensor_shape_and_metadata, p_orig))
elif co not in (CACHE_OPTIONS.SAME_INPUT, CACHE_OPTIONS.NO_CACHING):
raise NotImplementedError(f"Unsupported cache option {co}")
return p
Expand All @@ -612,6 +615,10 @@ def proxify(self, value: WrappedValue) -> Any:
self.add_constraint((clang.check_string_value, p, uvalue))
else:
self.add_constraint((clang.check_number_type_and_value, p, uvalue))
elif co is CACHE_OPTIONS.SYMBOLIC_VALUES:
# TODO: establish guarding logic
if p is not uvalue:
value.register_proxy(p)
elif co not in (CACHE_OPTIONS.SAME_INPUT, CACHE_OPTIONS.NO_CACHING):
raise NotImplementedError(f"Unsupported cache option {co}")
return p
Expand Down Expand Up @@ -1403,7 +1410,7 @@ def thunder_general_jit(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDG
)

co: CACHE_OPTIONS = get_cache_option()
if co not in {CACHE_OPTIONS.CONSTANT_VALUES, CACHE_OPTIONS.NO_CACHING}:
if co not in {CACHE_OPTIONS.CONSTANT_VALUES, CACHE_OPTIONS.NO_CACHING, CACHE_OPTIONS.SYMBOLIC_VALUES}:
raise NotImplementedError(f"Only constant constraints is supported")

prologue_trace: TraceCtx = TraceCtx(fn)
Expand Down Expand Up @@ -1434,6 +1441,7 @@ def thunder_general_jit(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDG
uncacheable_classes=(torch.Tensor, int, float, str, NoneType),
)

# NOTE(jiej): numbers are baked in as constant here vvv
with general_jit_ctx(ctx):
with tracectx(computation_trace):
result = jfn(*args, **kwargs)
Expand Down Expand Up @@ -1469,6 +1477,7 @@ def thunder_general_jit(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDG
else:
epilogue_trace = None

# NOTE(jiej): prologue trace is produced here vvv
pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(
ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs, has_epilogue=epilogue_trace is not None
)
Expand Down
11 changes: 3 additions & 8 deletions thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,14 +1031,9 @@ def _infer_tensor_properties(
_requires_grad = False if not dtypes.is_inexact_dtype(_dtype) else _requires_grad
_ddp_type = ddp_type if ddp_type is not None else _ddp_type

# Extracts actual values for shape
# TODO RC1 Enable this
if using_symbolic_values():
raise NotImplementedError(
f"Trying to construct a tensor proxy while using symbolic values, but this is not yet supported"
)

_shape = tuple(pyval(x) for x in _shape)
if not using_symbolic_values():
# Extracts actual values for shape
_shape = tuple(pyval(x) for x in _shape)

# Computes derived properties
_numel = reduce(operator.mul, _shape, 1)
Expand Down
Loading