
using size() twice in a row on same tensor causes interpreter out of sync error with thunder jit #349

Closed
Fuzzkatt opened this issue May 2, 2024 · 1 comment · Fixed by #352
Labels: bug (Something isn't working), jit

Fuzzkatt (Contributor) commented on May 2, 2024

Using size() twice in a row on the same tensor causes the interpreter to hit an out-of-sync error with thunder.jit.

Minimal repro:

import torch
import thunder

def func(a):
    b = a.size()
    c = a.size() # expect error here
    return a

a = torch.randn(100, 100, device='cuda')
jfunc = thunder.jit(func)
jfunc(a)

Full error log:

Traceback (most recent call last):
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6683, in fn_
    interpretation_result: Any = _interpret_call(wrapped_fn_2, args, kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6009, in _interpret_call
    rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6239, in _call_dispatch
    return _setup_frame_and_run_python_function(compilectx, runtimectx, wrapped_fn, *args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6368, in _setup_frame_and_run_python_function
    raise e.with_traceback(tb)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6669, in fn_2
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6363, in _setup_frame_and_run_python_function
    res, status = _run_frame(frame, compilectx, runtimectx)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6409, in _run_frame
    interpretation_result: None | int | INTERPRETER_SIGNALS = compilectx.interpret(
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 412, in interpret
    return self._opcode_interpreter(inst, **interpreter_state)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1231, in default_opcode_interpreter
    return handler(inst, **interpreter_state)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 3631, in _call_function_ex_handler
    return check_and_append(stack, _interpret_call(func, *args, **kwargs))
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6009, in _interpret_call
    rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6239, in _call_dispatch
    return _setup_frame_and_run_python_function(compilectx, runtimectx, wrapped_fn, *args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6368, in _setup_frame_and_run_python_function
    raise e.with_traceback(tb)
  File "/patwang-space/thunder_transform_pass/min_repro.py", line 6, in func
    c = a.size() # expect error here
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6363, in _setup_frame_and_run_python_function
    res, status = _run_frame(frame, compilectx, runtimectx)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6409, in _run_frame
    interpretation_result: None | int | INTERPRETER_SIGNALS = compilectx.interpret(
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 412, in interpret
    return self._opcode_interpreter(inst, **interpreter_state)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1231, in default_opcode_interpreter
    return handler(inst, **interpreter_state)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 4624, in _load_method_handler
    meth = _interpret_call(getattr, obj, name)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6009, in _interpret_call
    rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6170, in _call_dispatch
    res = lookaside_fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 693, in _general_jit_getattr_lookaside
    value = getattr_lookaside(obj, name, *maybe_default)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1684, in _getattr_lookaside
    result = wrap_attribute(result, obj, name)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1638, in wrap_attribute
    assert plausibly_wrapper_of(
AssertionError: attribute size of TensorProxy object out of sync: <function prop_lookaside_wrap.<locals>.fn.<locals>.fn_ at 0x7f8b2ae6fac0> vs. <function prop_lookaside_wrap.<locals>.fn.<locals>.fn_ at 0x7f8b2ae6ff40>

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/patwang-space/thunder_transform_pass/min_repro.py", line 11, in <module>
    jfunc(a)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 661, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 277, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 538, in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 190, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
  File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 1510, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6692, in fn_
    raise InterpreterError(msg) from e
thunder.core.interpreter.InterpreterError: Encountered exception AssertionError: attribute size of TensorProxy object out of sync: <function prop_lookaside_wrap.<locals>.fn.<locals>.fn_ at 0x7f8b2ae6fac0> vs. <function prop_lookaside_wrap.<locals>.fn.<locals>.fn_ at 0x7f8b2ae6ff40> while tracing <function func at 0x7f8bf9163d90>:
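
Until this is fixed, a possible workaround (a sketch based on the repro above, not a confirmed fix) is to call size() only once per tensor inside the traced function and reuse the result:

import torch
import thunder

def func(a):
    b = a.size()   # single size() call
    c = b          # reuse the result instead of calling a.size() a second time
    return a

a = torch.randn(100, 100, device='cuda')
jfunc = thunder.jit(func)
jfunc(a)  # a single size() call per tensor does not trigger the assertion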
Fuzzkatt added the bug label on May 2, 2024
Fuzzkatt changed the title from "using size() twice in a row causes interpreter out of sync error with thunder jit" to "using size() twice in a row on same tensor causes interpreter out of sync error with thunder jit" on May 3, 2024
Fuzzkatt (Contributor, Author) commented on May 3, 2024

Note that calling size() twice in the same trace but on different tensors works fine:

import torch
import thunder

def func(a, b):
    c = a.size()
    d = b.size() # works
    return a, b

a = torch.randn(100, 100, device='cuda')
b = torch.randn(100, 100, device='cuda')
jfunc = thunder.jit(func)
jfunc(a, b)

Additionally, calling size() only once in the function but invoking the jitted function twice is also fine:

import torch
import thunder

def func(a):
    b = a.size()
    return a

a = torch.randn(100, 100, device='cuda')
jfunc = thunder.jit(func)
jfunc(a)
jfunc(a) # works
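
For reference, a minimal pytest-style regression test covering these cases might look like the sketch below (the test names are illustrative and not part of the thunder test suite; CPU tensors are used here so the sketch runs without a GPU):

import torch
import thunder

def test_size_twice_on_same_tensor():
    # Failing case from this issue: two size() calls on the same tensor.
    # Expected to pass once the fix from #352 is merged.
    def func(a):
        b = a.size()
        c = a.size()
        return a

    jfunc = thunder.jit(func)
    jfunc(torch.randn(100, 100))

def test_size_on_two_different_tensors():
    # Passing case: one size() call per tensor.
    def func(a, b):
        c = a.size()
        d = b.size()
        return a, b

    jfunc = thunder.jit(func)
    jfunc(torch.randn(100, 100), torch.randn(100, 100))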
