You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Calling size() twice in a row on the same tensor inside a function compiled with thunder.jit causes the interpreter to fail with an "attribute size of TensorProxy object out of sync" AssertionError.
Minimal repro:
import torch
import thunder

def func(a):
    b = a.size()  # first size() lookup on `a`
    c = a.size() # expect error here
    return a

a = torch.randn(100, 100, device='cuda')
jfunc = thunder.jit(func)
jfunc(a)  # tracing `func` raises InterpreterError (see the full error log)
Full error log:
Traceback (most recent call last):
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6683, in fn_
interpretation_result: Any = _interpret_call(wrapped_fn_2, args, kwargs)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6009, in _interpret_call
rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs) # type: ignore
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6239, in _call_dispatch
return _setup_frame_and_run_python_function(compilectx, runtimectx, wrapped_fn, *args, **kwargs)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6368, in _setup_frame_and_run_python_function
raise e.with_traceback(tb)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6669, in fn_2
return fn(*args, **kwargs)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6363, in _setup_frame_and_run_python_function
res, status = _run_frame(frame, compilectx, runtimectx)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6409, in _run_frame
interpretation_result: None | int | INTERPRETER_SIGNALS = compilectx.interpret(
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 412, in interpret
return self._opcode_interpreter(inst, **interpreter_state)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1231, in default_opcode_interpreter
return handler(inst, **interpreter_state)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 3631, in _call_function_ex_handler
return check_and_append(stack, _interpret_call(func, *args, **kwargs))
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6009, in _interpret_call
rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs) # type: ignore
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6239, in _call_dispatch
return _setup_frame_and_run_python_function(compilectx, runtimectx, wrapped_fn, *args, **kwargs)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6368, in _setup_frame_and_run_python_function
raise e.with_traceback(tb)
File "/patwang-space/thunder_transform_pass/min_repro.py", line 6, in func
c = a.size() # expect error here
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6363, in _setup_frame_and_run_python_function
res, status = _run_frame(frame, compilectx, runtimectx)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6409, in _run_frame
interpretation_result: None | int | INTERPRETER_SIGNALS = compilectx.interpret(
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 412, in interpret
return self._opcode_interpreter(inst, **interpreter_state)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1231, in default_opcode_interpreter
return handler(inst, **interpreter_state)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 4624, in _load_method_handler
meth = _interpret_call(getattr, obj, name)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6009, in _interpret_call
rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs) # type: ignore
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6170, in _call_dispatch
res = lookaside_fn(*args, **kwargs)
File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 693, in _general_jit_getattr_lookaside
value = getattr_lookaside(obj, name, *maybe_default)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1684, in _getattr_lookaside
result = wrap_attribute(result, obj, name)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1638, in wrap_attribute
assert plausibly_wrapper_of(
AssertionError: attribute size of TensorProxy object out of sync: <function prop_lookaside_wrap.<locals>.fn.<locals>.fn_ at 0x7f8b2ae6fac0> vs. <function prop_lookaside_wrap.<locals>.fn.<locals>.fn_ at 0x7f8b2ae6ff40>
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/patwang-space/thunder_transform_pass/min_repro.py", line 11, in <module>
jfunc(a)
File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 661, in fn_
cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 277, in cache_info_wrapper
res = fn(*args, **kwargs)
File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 538, in get_computation_and_inputs
jit_results: TraceResults = interpreter(
File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 190, in _general_frontend
return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 1510, in thunder_general_jit
result = jfn(*args, **kwargs)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6692, in fn_
raise InterpreterError(msg) from e
thunder.core.interpreter.InterpreterError: Encountered exception AssertionError: attribute size of TensorProxy object out of sync: <function prop_lookaside_wrap.<locals>.fn.<locals>.fn_ at 0x7f8b2ae6fac0> vs. <function prop_lookaside_wrap.<locals>.fn.<locals>.fn_ at 0x7f8b2ae6ff40> while tracing <function func at 0x7f8bf9163d90>:
The text was updated successfully, but these errors were encountered:
Fuzzkatt changed the title from "using size() twice in a row causes interpreter out of sync error with thunder jit" to "using size() twice in a row on same tensor causes interpreter out of sync error with thunder jit" on May 3, 2024.
Note that calling size twice in the same trace but on different tensors works fine:
# Control case: size() is called twice in one trace, but on two different tensors.
import torch
import thunder

def func(a, b):
    c = a.size()  # size() on the first tensor
    d = b.size() # works: this size() call is on a different tensor
    return a, b

a = torch.randn(100, 100, device='cuda')
b = torch.randn(100, 100, device='cuda')
jfunc = thunder.jit(func)
jfunc(a, b)
Additionally, calling size() only once inside the traced function but invoking the jitted function twice also works:
# Control case: size() is called once in the function; the jitted function is invoked twice.
import torch
import thunder

def func(a):
    b = a.size()  # the only size() call in this function
    return a

a = torch.randn(100, 100, device='cuda')
jfunc = thunder.jit(func)
jfunc(a)  # first invocation
jfunc(a) # works
Using size() twice in a row on the same tensor causes the interpreter to hit an out of sync error with thunder jit.
Minimal repro:
Full error log:
The text was updated successfully, but these errors were encountered: