Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ElectraForQuestionAnswering,pass,5



GPT2ForSequenceClassification,pass,5
GPT2ForSequenceClassification,pass,6



Expand All @@ -94,7 +94,7 @@ LayoutLMForMaskedLM,pass,5



LayoutLMForSequenceClassification,pass,5
LayoutLMForSequenceClassification,pass,6



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ElectraForQuestionAnswering,pass,5



GPT2ForSequenceClassification,pass,5
GPT2ForSequenceClassification,pass,6



Expand All @@ -94,7 +94,7 @@ LayoutLMForMaskedLM,pass,5



LayoutLMForSequenceClassification,pass,5
LayoutLMForSequenceClassification,pass,6



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ElectraForQuestionAnswering,pass,5



GPT2ForSequenceClassification,pass,5
GPT2ForSequenceClassification,pass,6



Expand All @@ -94,7 +94,7 @@ LayoutLMForMaskedLM,pass,5



LayoutLMForSequenceClassification,pass,5
LayoutLMForSequenceClassification,pass,6



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ElectraForQuestionAnswering,pass,5



GPT2ForSequenceClassification,pass,5
GPT2ForSequenceClassification,pass,6



Expand All @@ -94,7 +94,7 @@ LayoutLMForMaskedLM,pass,5



LayoutLMForSequenceClassification,pass,5
LayoutLMForSequenceClassification,pass,6



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ElectraForQuestionAnswering,pass,5



GPT2ForSequenceClassification,pass,5
GPT2ForSequenceClassification,pass,6



Expand All @@ -94,7 +94,7 @@ LayoutLMForMaskedLM,pass,5



LayoutLMForSequenceClassification,pass,5
LayoutLMForSequenceClassification,pass,6



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ElectraForQuestionAnswering,pass,5



GPT2ForSequenceClassification,pass,5
GPT2ForSequenceClassification,pass,6



Expand All @@ -94,7 +94,7 @@ LayoutLMForMaskedLM,pass,5



LayoutLMForSequenceClassification,pass,5
LayoutLMForSequenceClassification,pass,6



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ElectraForQuestionAnswering,pass,5



GPT2ForSequenceClassification,pass,5
GPT2ForSequenceClassification,pass,6



Expand All @@ -94,7 +94,7 @@ LayoutLMForMaskedLM,pass,5



LayoutLMForSequenceClassification,pass,5
LayoutLMForSequenceClassification,pass,6



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ElectraForQuestionAnswering,pass,5



GPT2ForSequenceClassification,pass,5
GPT2ForSequenceClassification,pass,6



Expand All @@ -94,7 +94,7 @@ LayoutLMForMaskedLM,pass,5



LayoutLMForSequenceClassification,pass,5
LayoutLMForSequenceClassification,pass,6



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ElectraForQuestionAnswering,pass,5



GPT2ForSequenceClassification,pass,5
GPT2ForSequenceClassification,pass,6



Expand All @@ -94,7 +94,7 @@ LayoutLMForMaskedLM,pass,5



LayoutLMForSequenceClassification,pass,5
LayoutLMForSequenceClassification,pass,6



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ElectraForQuestionAnswering,pass,5



GPT2ForSequenceClassification,pass,5
GPT2ForSequenceClassification,pass,6



Expand All @@ -94,7 +94,7 @@ LayoutLMForMaskedLM,pass,5



LayoutLMForSequenceClassification,pass,5
LayoutLMForSequenceClassification,pass,6



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ElectraForQuestionAnswering,pass,5



GPT2ForSequenceClassification,pass,5
GPT2ForSequenceClassification,pass,6



Expand All @@ -94,7 +94,7 @@ LayoutLMForMaskedLM,pass,5



LayoutLMForSequenceClassification,pass,5
LayoutLMForSequenceClassification,pass,6



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ElectraForQuestionAnswering,pass,5



GPT2ForSequenceClassification,pass,5
GPT2ForSequenceClassification,pass,6



Expand All @@ -94,7 +94,7 @@ LayoutLMForMaskedLM,pass,5



LayoutLMForSequenceClassification,pass,5
LayoutLMForSequenceClassification,pass,6



Expand Down
9 changes: 8 additions & 1 deletion test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11981,13 +11981,20 @@ def fn(t):
self.assertEqual(y, t.sin())

def test_overridden_getattribute(self):
class Bar:
def __init__(self, v):
self.v = v

class Foo:
attribute_map = {}

def __init__(self):
self.attribute_map = {
"a_premap": "a",
}
# `bar` attribute requires propagating sources correctly through
# object.__getattribute__
self.bar = Bar(5)

def __setattr__(self, key, value):
if key in super().__getattribute__("attribute_map"):
Expand Down Expand Up @@ -12015,7 +12022,7 @@ def get_foo():
return f

def fn(x, f):
return x * f.a_premap * f.a * f.b * f.sentinel
return x * f.a_premap * f.a * f.b * f.sentinel * f.bar.v

x = torch.randn(4)

Expand Down
9 changes: 9 additions & 0 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
FlattenScriptObjectSource,
FloatTensorSource,
FSDPNNModuleSource,
GenericAttrSource,
GetItemSource,
GlobalSource,
GlobalStateSource,
Expand Down Expand Up @@ -1046,6 +1047,14 @@ def get_guard_manager_from_source(self, source):
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, GenericAttrSource):
assert base_guard_manager # to make mypy happy
out = base_guard_manager.generic_getattr_manager(
attr=source.member,
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, (AttrSource, UnspecializedParamBufferSource)):
assert base_guard_manager # to make mypy happy

Expand Down
24 changes: 24 additions & 0 deletions torch/_dynamo/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,30 @@ def name(self):
return f"{self.base.name()}.{self.member}"


@dataclasses.dataclass(frozen=True)
class GenericAttrSource(ChainedSource):
member: str

def __post_init__(self):
assert self.base, "Can't construct an AttrSource without a valid base source"
if "." in self.member:
member_parts = self.member.split(".")
object.__setattr__(
self, "base", AttrSource(self.base, ".".join(member_parts[:-1]))
)
object.__setattr__(self, "member", member_parts[-1])

def reconstruct(self, codegen):
codegen(self.base)
codegen.extend_output(codegen.create_load_attrs(self.member))

def guard_source(self):
return self.base.guard_source()

def name(self):
return f"object.__getattribute__({self.base.name()}, {self.member!r})"


@dataclasses.dataclass(frozen=True)
class LocalCellSource(Source):
"""
Expand Down
18 changes: 14 additions & 4 deletions torch/_dynamo/variables/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@
from ..exc import raise_observed_exception, unimplemented
from ..guards import GuardBuilder, install_guard
from ..mutation_guard import unpatched_nn_module_init
from ..source import AttrSource, GetItemSource, TypeSource, WeakRefCallSource
from ..source import (
AttrSource,
GenericAttrSource,
GetItemSource,
TypeSource,
WeakRefCallSource,
)
from ..utils import (
check_unspec_or_constant_args,
cmp_name_to_op_mapping,
Expand Down Expand Up @@ -260,12 +266,16 @@ def call_method(
return result

try:
attr_value = self.objvar.value.__getattribute__(attr_name)
# NB - use object.__getattribute__ to prevent running any user code
attr_value = object.__getattribute__(self.objvar.value, attr_name)
except AttributeError:
raise_observed_exception(AttributeError, tx)

source = self.source and AttrSource(self.source, attr_name)
return VariableTracker.build(tx, attr_value, source)
attr_source = None
if self.objvar.source is not None:
# setup a object.__getattribute__(self.objvar, name) source
attr_source = GenericAttrSource(self.objvar.source, attr_name)
return VariableTracker.build(tx, attr_value, attr_source)

unimplemented(f"non-function or method super: {inner_fn}")

Expand Down
Loading