Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/dynamo/test_dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Optional, Tuple

import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
Expand Down
3 changes: 1 addition & 2 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@
ConstDictKeySource,
DefaultsSource,
DictGetItemSource,
DictSubclassGetItemSource,
FlattenScriptObjectSource,
FloatTensorSource,
FSDPNNModuleSource,
Expand Down Expand Up @@ -1080,7 +1079,7 @@ def get_guard_manager_from_source(self, source):
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, (DictGetItemSource, DictSubclassGetItemSource)):
elif istype(source, DictGetItemSource):
assert base_guard_manager # to make mypy happy
assert isinstance(base_example_value, (dict, collections.OrderedDict))
if isinstance(base_guard_manager, DictGuardManager):
Expand Down
39 changes: 1 addition & 38 deletions torch/_dynamo/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,44 +601,7 @@ def __post_init__(self):
def guard_source(self):
return self.base.guard_source()

def reconstruct(self, codegen: "PyCodegen"):
# Load dict
codegen(self.base)

# Load key
if isinstance(self.index, Source):
codegen(self.index)
else:
codegen.append_output(codegen.create_load_const(self.index))
codegen.append_output(create_instruction("BINARY_SUBSCR"))

def name(self):
if isinstance(self.index, ConstDictKeySource):
return f"{self.base.name()}[{self.index.name()}]"
else:
return f"{self.base.name()}[{self.index!r}]"


# Same as DictGetItemSource but used for dict.__getitem__ calls to ensure that
# torch.compile does not run the overridden __getitem__ method
@dataclasses.dataclass(frozen=True)
class DictSubclassGetItemSource(ChainedSource):
# Key to access in the dictionary. It can be one of the the following types
# 1) ConstDictKeySource
# 2) constant - like string, integer
index: Any

def __post_init__(self):
from .variables import ConstantVariable

assert isinstance(
self.index, ConstDictKeySource
) or ConstantVariable.is_literal(self.index)

def guard_source(self):
return self.base.guard_source()

def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen):
# reconstruct dict.__getitem__(dct, key)

# Load dict.__getitem__
Expand Down
5 changes: 2 additions & 3 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@
ConstDictKeySource,
ConvertIntSource,
DictGetItemSource,
DictSubclassGetItemSource,
FloatTensorSource,
GetItemSource,
GradSource,
Expand Down Expand Up @@ -1275,8 +1274,8 @@ def build_key_value(i, k, v):
source_key = ConstDictKeySource(self.get_source(), i)
key = LazyVariableTracker.create(k, source_key)

source_value = DictSubclassGetItemSource(base, source_key)
res_value = LazyVariableTracker.create(v, source_value)
source_value = DictGetItemSource(self.get_source(), source_key)
value = LazyVariableTracker.create(v, source_value)

return key, value

Expand Down