thunder.jit fails with nn.Softmax raising got an unexpected keyword argument '_stacklevel' #258

Closed
ptrblck opened this issue Apr 24, 2024 · 3 comments · Fixed by #282
Labels
bug Something isn't working

Comments

@ptrblck
Collaborator

ptrblck commented Apr 24, 2024

🐛 Bug

thunder.jit fails with nn.Softmax raising:

TypeError: softmax() got an unexpected keyword argument '_stacklevel'

pointing to: https://github.com/pytorch/pytorch/blob/a21327e0b03cc18850a0608be2d9c5bd38fd4646/torch/nn/functional.py#L1883

To Reproduce

import torch
from thunder.examine import examine

# Running examine on a plain nn.Softmax module is enough to trigger the failure.
examine(torch.nn.Softmax(1).cuda(), torch.randn(1, 10, device="cuda"))
# TypeError: softmax() got an unexpected keyword argument '_stacklevel'

Environment

  • PyTorch Version (e.g., 1.0): 2.4.0.dev20240423+cu121
  • thunder built from source at e0ab648

CC @Fuzzkatt, as you encountered this issue first.

ptrblck added the bug label on Apr 24, 2024
@Fuzzkatt
Collaborator

Attaching full stacktrace below:

Traceback (most recent call last):
  File "/patwang-space/thunder_transform_pass/explore.py", line 250, in <module>
    explore(model, jfunc_in, fwd_extrace_in, bwd_extrace_in)
  File "/patwang-space/thunder_transform_pass/explore.py", line 8, in explore
    jfunc(*jfunc_in)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 201, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 648, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 269, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 527, in get_computation_and_inputs
    jit_results: TraceResults = interpreter(fn, args, kwargs, sharp_edges=cd.sharp_edges)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 182, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
  File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 1439, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6673, in fn_
    raise e
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6641, in fn_2
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6040, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6040, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6040, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/activation.py", line 1545, in forward
    return F.softmax(input, self.dim, _stacklevel=5)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1252, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/symbol.py", line 257, in __call__
    result = self.meta(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/langctxs.py", line 124, in _fn
    result = fn(*args, **kwargs)
TypeError: softmax() got an unexpected keyword argument '_stacklevel'

Is it possible that thunder is getting confused between regular torch.nn.Softmax (which doesn't have a _stacklevel arg) and torch.nn.functional.softmax (which does have a _stacklevel arg)?
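
For reference, the traceback above shows nn.Softmax.forward calling F.softmax(input, self.dim, _stacklevel=5), and the TypeError is raised once that call reaches a softmax whose signature does not accept _stacklevel. The following is a minimal pure-Python sketch of that failure mode, not thunder's actual code: functional_softmax, replacement_softmax, module_forward, and outcome are hypothetical names, and the only detail taken from PyTorch is the parameter list of torch.nn.functional.softmax (input, dim, _stacklevel, dtype).

```python
import math


def functional_softmax(xs, dim=None, _stacklevel=3, dtype=None):
    """Stand-in with the same parameter names as torch.nn.functional.softmax."""
    # dim and dtype are accepted but unused in this 1-D toy version.
    shift = max(xs)
    exps = [math.exp(x - shift) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def replacement_softmax(xs, dim=None, dtype=None):
    """Hypothetical substitute whose signature omits `_stacklevel`."""
    return functional_softmax(xs, dim, dtype=dtype)


def module_forward(softmax_fn, xs, dim):
    # Same call shape as nn.Softmax.forward in the traceback:
    #     F.softmax(input, self.dim, _stacklevel=5)
    return softmax_fn(xs, dim, _stacklevel=5)


def outcome(softmax_fn):
    """Return 'ok' or the TypeError message raised by the forward call."""
    try:
        module_forward(softmax_fn, [1.0, 2.0, 3.0], 0)
        return "ok"
    except TypeError as exc:
        return str(exc)


print(outcome(functional_softmax))   # the original signature accepts the keyword
print(outcome(replacement_softmax))  # the substitute rejects it with a TypeError
```

If this sketch matches what happens, the mismatch is between the keyword arguments the module passes and the ones the substituted softmax accepts, rather than between nn.Softmax and F.softmax themselves.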

@crcrpar
Collaborator

crcrpar commented Apr 26, 2024

Could you try #282?

@Fuzzkatt
Collaborator

Just tried #282, and I'm now hitting:
TypeError: _softmax() got an unexpected keyword argument '_stacklevel'
It seems _stacklevel is still being passed on to _softmax somehow?
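
One generic way to keep a private keyword such as _stacklevel from reaching an inner implementation is to drop it at the entry point before forwarding. The sketch below is plain Python and makes no claim about what #282 or the eventual fix does: drop_kwargs, softmax_entry, and the toy list-based softmax are hypothetical names used only for illustration.

```python
import functools
import math


def drop_kwargs(*names):
    """Decorator that silently discards the named keyword arguments."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for name in names:
                kwargs.pop(name, None)  # ignore the keyword if it was passed
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@drop_kwargs("_stacklevel")
def softmax_entry(xs, dim=None, dtype=None):
    """Toy 1-D softmax whose own signature has no `_stacklevel` parameter."""
    shift = max(xs)
    exps = [math.exp(x - shift) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# The module-style call from the traceback now goes through.
print(softmax_entry([0.0, 0.0], 0, _stacklevel=5))
```

Only the listed names are dropped, so any other unexpected keyword still raises a TypeError as before.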
