Skip to content

Commit

Permalink
Revert "[WIP][dynamo] simplify module_key creation logic (pytorch#94945)"
Browse files Browse the repository at this point in the history

This reverts commit 4d753b5.
  • Loading branch information
pruthvistony committed May 2, 2023
1 parent 75818a5 commit f5933a6
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 12 deletions.
8 changes: 3 additions & 5 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import typing
import weakref
from collections.abc import Sized
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
from unittest.mock import patch

import torch
Expand Down Expand Up @@ -1617,10 +1617,8 @@ def __init__(

# Execution record for replaying errors
self.exec_recorder = ExecutionRecorder(code=f_code, code_options=code_options)
# Stack of module being parsed, current nn.module is at the end of ordered dict.
# The first field of tuple is the fully qualified name of current module
# in original hierarchy. The second field is the type of current nn.module
self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {}
# Stack of module being parsed, current nn.module is at the end of ordered dict
self.nn_module_stack: Dict[str, str] = {}
# Flag to indicate whether tracing is used for export.
self.export = export

Expand Down
3 changes: 1 addition & 2 deletions torch/_dynamo/variables/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,8 @@ def call_function(

@contextmanager
def record_nn_module_stack():
fully_qualified_name = self.source.name()
try:
tx.nn_module_stack[self.module_key] = (fully_qualified_name, type(mod))
tx.nn_module_stack[self.module_key] = type(mod)
yield
finally:
del tx.nn_module_stack[self.module_key]
Expand Down
18 changes: 18 additions & 0 deletions torch/ao/quantization/_pt2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,26 @@
from torch.ao.quantization.fx.prepare import (
_is_activation_post_process_node,
)
from collections import OrderedDict
import operator

# TODO[qihan]: longer term, this should happen in the dynamo stack as well
def _get_renamed_nn_module_stack(nn_module_stack):
# initialize with top level parent scope
nn_module_stack_renamed = OrderedDict([("", None)])
if nn_module_stack:
# Rename module_key, e.g. "self_layer1_1__conv1" to "self.layer1.1._conv1", for easier downstream parsing
prev_key = ""
for key, value in nn_module_stack.items():
if not prev_key:
if key.startswith("self_"):
new_key = key[5:]
prev_key = new_key
else:
new_key = prev_key + "." + key[len(prev_key) + 6 :]
nn_module_stack_renamed[new_key] = value
prev_key = new_key
return nn_module_stack_renamed

def _get_tensor_constant_from_node(node, m):
if node is None:
Expand Down
8 changes: 3 additions & 5 deletions torch/ao/quantization/_quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .fx import prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from ._pt2e.utils import (
_get_renamed_nn_module_stack,
_fuse_conv_bn_,
_rearrange_weight_observer_for_addmm,
)
Expand All @@ -20,11 +21,8 @@ def prepare_pt2e(
# TODO: move this information to fx node itself
node_name_to_scope: Dict[str, Tuple[str, type]] = {}
for n in model.graph.nodes:
nn_module_stack = n.meta.get("nn_module_stack", None)
current_scope = ("", type(None))
if nn_module_stack:
bt = list(nn_module_stack.values())[-1]
current_scope = (bt[0].split(".")[-1], bt[1])
renamed_stack = _get_renamed_nn_module_stack(n.meta.get("nn_module_stack", None))
current_scope = list(renamed_stack.items())[-1]
node_name_to_scope[n.name] = current_scope

# TODO: check qconfig_mapping to make sure conv and bn are both configured
Expand Down

0 comments on commit f5933a6

Please sign in to comment.