diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index 1b6b99a8b0cbd..61cafbcbda2cb 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -14,6 +14,7 @@ from typing import Any, Optional, Tuple import torch +import torch._dynamo.config import torch._dynamo.test_case import torch._dynamo.testing import torch._functorch.config diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index b374533f44d71..e92ebd5219663 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -99,7 +99,6 @@ ConstDictKeySource, DefaultsSource, DictGetItemSource, - DictSubclassGetItemSource, FlattenScriptObjectSource, FloatTensorSource, FSDPNNModuleSource, @@ -1080,7 +1079,7 @@ def get_guard_manager_from_source(self, source): example_value=example_value, guard_manager_enum=guard_manager_enum, ) - elif istype(source, (DictGetItemSource, DictSubclassGetItemSource)): + elif istype(source, DictGetItemSource): assert base_guard_manager # to make mypy happy assert isinstance(base_example_value, (dict, collections.OrderedDict)) if isinstance(base_guard_manager, DictGuardManager): diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 511a8c17272de..55600277b280e 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -601,44 +601,7 @@ def __post_init__(self): def guard_source(self): return self.base.guard_source() - def reconstruct(self, codegen: "PyCodegen"): - # Load dict - codegen(self.base) - - # Load key - if isinstance(self.index, Source): - codegen(self.index) - else: - codegen.append_output(codegen.create_load_const(self.index)) - codegen.append_output(create_instruction("BINARY_SUBSCR")) - - def name(self): - if isinstance(self.index, ConstDictKeySource): - return f"{self.base.name()}[{self.index.name()}]" - else: - return f"{self.base.name()}[{self.index!r}]" - - -# Same as DictGetItemSource but used for dict.__getitem__ calls to ensure that -# torch.compile does not run the overridden __getitem__ method -@dataclasses.dataclass(frozen=True) -class DictSubclassGetItemSource(ChainedSource): - # Key to access in the dictionary. It can be one of the the following types - # 1) ConstDictKeySource - # 2) constant - like string, integer - index: Any - - def __post_init__(self): - from .variables import ConstantVariable - - assert isinstance( - self.index, ConstDictKeySource - ) or ConstantVariable.is_literal(self.index) - - def guard_source(self): - return self.base.guard_source() - - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen): # reconstruct dict.__getitem__(dct, key) # Load dict.__getitem__ diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index cb979e804edcf..9187a580fad0b 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -92,7 +92,6 @@ ConstDictKeySource, ConvertIntSource, DictGetItemSource, - DictSubclassGetItemSource, FloatTensorSource, GetItemSource, GradSource, @@ -1275,8 +1274,8 @@ def build_key_value(i, k, v): source_key = ConstDictKeySource(self.get_source(), i) key = LazyVariableTracker.create(k, source_key) - source_value = DictSubclassGetItemSource(base, source_key) - res_value = LazyVariableTracker.create(v, source_value) + source_value = DictGetItemSource(self.get_source(), source_key) + value = LazyVariableTracker.create(v, source_value) return key, value