Skip to content

Commit 498be7e

Browse files
Revert "Refactor stack_trace preservation for node meta preservation (pytorch#90803)"
This reverts commit 0f1302e. Reverted pytorch#90803 on behalf of https://github.com/DanilBaibak due to Break internal build
1 parent c887837 commit 498be7e

File tree

6 files changed

+48
-40
lines changed

6 files changed

+48
-40
lines changed

test/test_functionalization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def g(x):
178178
from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks
179179
import torch.fx.traceback as fx_traceback
180180
setup_stacktrace_preservation_hooks([loss.grad_fn])
181-
with fx_traceback.preserve_node_meta():
181+
with fx_traceback.override_stack_trace():
182182
loss.backward()
183183
return x.grad
184184

torch/_dynamo/eval_frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def run_node(self, n):
668668
if aten_graph:
669669
# Running graph with interpreter is needed for propagating the stack_trace
670670
def graph_with_interpreter(*args):
671-
with torch.fx.traceback.preserve_node_meta():
671+
with torch.fx.traceback.override_stack_trace():
672672
return torch.fx.Interpreter(graph).run(*args)
673673

674674
graph = make_fx(

torch/_functorch/aot_autograd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ def joint_forward_backward(
790790
backward_out = []
791791
# Call the backwards pass
792792
if grad_primals:
793-
with fx_traceback.preserve_node_meta():
793+
with fx_traceback.override_stack_trace():
794794
backward_out = torch.autograd.grad(
795795
needed_outs,
796796
grad_primals,
@@ -2319,7 +2319,7 @@ def functional_call(*args, **kwargs):
23192319
mod, pytree.tree_unflatten(args[:params_len], params_spec)
23202320
):
23212321
if isinstance(mod, torch.fx.GraphModule):
2322-
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
2322+
with fx_traceback.override_stack_trace(), warnings.catch_warnings():
23232323
warnings.filterwarnings(
23242324
"ignore", "Anomaly Detection has been enabled."
23252325
)

torch/fx/interpreter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_p
153153

154154
@contextmanager
155155
def _set_current_node(self, node):
156-
with fx_traceback.set_current_meta(node.meta):
156+
with fx_traceback.append_stack_trace(node.stack_trace), fx_traceback.set_current_meta(node.meta):
157157
yield
158158

159159
@compatibility(is_backward_compatible=True)
@@ -477,7 +477,7 @@ def transform(self) -> GraphModule:
477477
Transform ``self.module`` and return the transformed
478478
``GraphModule``.
479479
"""
480-
with fx_traceback.preserve_node_meta():
480+
with fx_traceback.override_stack_trace():
481481
result = super().run(enable_io_processing=False)
482482
if result is not None:
483483
def strip_proxy(a : Union[Argument, Proxy]) -> Any:

torch/fx/proxy.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,19 +76,10 @@ def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs:
7676
proxy = proxy_factory_fn(node)
7777

7878
# Optionally set stack trace on the created Node for debugging purposes
79-
if fx_traceback.has_preserved_node_meta():
80-
current_meta: Dict[str, Any] = fx_traceback.get_current_meta()
81-
82-
# Explicitly set the stack_trace and nn_module_stack on the node.meta
83-
# If other meta fields are needed, they can be added here
84-
stack_trace = current_meta.get("stack_trace")
85-
if stack_trace:
86-
proxy.node.stack_trace = stack_trace
87-
88-
nn_module_stack = current_meta.get("nn_module_stack")
89-
if nn_module_stack:
90-
proxy.node.meta["nn_module_stack"] = nn_module_stack
91-
79+
if fx_traceback.is_stack_trace_overridden():
80+
proxy.node.meta = fx_traceback.get_current_meta()
81+
stacks = fx_traceback.format_stack()
82+
proxy.node.stack_trace = '\n'.join(reversed(stacks))
9283
elif self.record_stack_traces:
9384
user_frame = self._find_user_frame()
9485
if user_frame:

torch/fx/traceback.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,80 @@
11
import traceback
22
from contextlib import contextmanager
3-
from typing import List, Any, Dict
3+
from typing import Optional, List, Any, Dict
44
from ._compatibility import compatibility
55

6-
__all__ = ['preserve_node_meta', 'has_preserved_node_meta',
7-
'set_stack_trace', 'format_stack',
8-
'set_current_meta', 'get_current_meta']
6+
__all__ = ['override_stack_trace', 'set_stack_trace', 'append_stack_trace', 'format_stack',
7+
'is_stack_trace_overridden', 'get_current_meta', 'set_current_meta']
98

9+
10+
current_stack: List[str] = []
1011
current_meta: Dict[str, Any] = {}
11-
should_preserve_node_meta = False
12+
is_overridden = False
1213

1314

1415
@compatibility(is_backward_compatible=False)
1516
@contextmanager
16-
def preserve_node_meta():
17-
global should_preserve_node_meta
17+
def override_stack_trace():
18+
global is_overridden
1819

19-
saved_should_preserve_node_meta = should_preserve_node_meta
20+
saved_is_overridden = is_overridden
2021
try:
21-
should_preserve_node_meta = True
22+
is_overridden = True
2223
yield
2324
finally:
24-
should_preserve_node_meta = saved_should_preserve_node_meta
25-
25+
is_overridden = saved_is_overridden
2626

2727
@compatibility(is_backward_compatible=False)
2828
def set_stack_trace(stack : List[str]):
29-
global current_meta
29+
global current_stack
30+
31+
if is_overridden and stack:
32+
current_stack = stack
33+
34+
@compatibility(is_backward_compatible=False)
35+
@contextmanager
36+
def append_stack_trace(stack : Optional[str]):
37+
"""
38+
The content of stack here is an entire stacktraces as a string
39+
"""
40+
global current_stack
3041

31-
if should_preserve_node_meta and stack:
32-
current_meta["stack_trace"] = "".join(stack)
42+
if is_overridden and stack:
43+
try:
44+
current_stack.append(stack)
45+
yield
46+
finally:
47+
current_stack.pop()
48+
else:
49+
yield
3350

3451

3552
@compatibility(is_backward_compatible=False)
3653
def format_stack() -> List[str]:
37-
if should_preserve_node_meta:
38-
return [current_meta.get("stack_trace", "")]
54+
if is_overridden:
55+
return current_stack.copy()
3956
else:
4057
# fallback to traceback.format_stack()
4158
return traceback.format_list(traceback.extract_stack()[:-1])
4259

4360

4461
@compatibility(is_backward_compatible=False)
45-
def has_preserved_node_meta() -> bool:
46-
return should_preserve_node_meta
62+
def is_stack_trace_overridden() -> bool:
63+
return is_overridden
4764

4865

4966
@compatibility(is_backward_compatible=False)
5067
@contextmanager
5168
def set_current_meta(meta : Dict[str, Any]):
5269
global current_meta
5370

54-
if should_preserve_node_meta and meta:
55-
saved_meta = current_meta
71+
old_meta = current_meta
72+
if is_overridden and meta:
5673
try:
5774
current_meta = meta
5875
yield
5976
finally:
60-
current_meta = saved_meta
77+
current_meta = old_meta
6178
else:
6279
yield
6380

0 commit comments

Comments
 (0)