Skip to content

Commit

Permalink
[Dynamo] Treat integers stored on nn.Modules as dynamic (pytorch#126466)
Browse files Browse the repository at this point in the history
Fixes pytorch#115711

Pull Request resolved: pytorch#126466
Approved by: https://github.com/jansel
  • Loading branch information
yanboliang authored and pytorchmergebot committed May 18, 2024
1 parent a44d0cf commit 6bb9d60
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 4 deletions.
57 changes: 57 additions & 0 deletions test/dynamo/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torch._dynamo.eval_frame import unsupported
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.testing import expectedFailureDynamic, same
from torch._dynamo.utils import ifdynstaticdefault
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import Parameter, UninitializedParameter

Expand Down Expand Up @@ -1104,6 +1105,37 @@ def forward(self, x):
return self.m(x)


class ModuleWithIntAttr(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(4, 4)
self.step = 10

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + 1
self.step += 1
return self.layer(x) + self.step


class UnspecInlinableModule(torch.nn.Module):
torchdynamo_force_dynamic = True # forced to be a UnspecializedNNModule

def forward(self, x):
return torch.sin(x)


class UnspecModuleWithIntAttr(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = UnspecInlinableModule()
self.step = 10

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + 1
self.step += 1
return self.layer(x) + self.step


def make_test(fn, expected_ops=None):
def test_fn(self):
return torch._dynamo.testing.standard_test(
Expand Down Expand Up @@ -1357,6 +1389,31 @@ def forward(self, x):
self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
self.assertTrue(torch._dynamo.testing.same(out1, out_post))

def test_nn_module_unspec_int_attr(self):
for module_class in [ModuleWithIntAttr, UnspecModuleWithIntAttr]:
mod = module_class()
cnt = torch._dynamo.testing.CompileCounter()
opt_mod = torch.compile(backend=cnt)(copy.deepcopy(mod))
x = torch.randn(3, 4)

# Compiling self.step as static.
ref1 = mod(x)
res1 = opt_mod(x)
self.assertTrue(torch.allclose(ref1, res1))
self.assertEqual(cnt.frame_count, 1)

# Compiling self.step as dynamic.
ref2 = mod(x)
res2 = opt_mod(x)
self.assertTrue(torch.allclose(ref2, res2))
self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1))

# No re-compilation!
ref3 = mod(x)
res3 = opt_mod(x)
self.assertTrue(torch.allclose(ref3, res3))
self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1))

# RuntimeError: SymIntArrayRef expected to contain only concrete integers
@expectedFailureDynamic
def test_lazy_module1(self):
Expand Down
4 changes: 0 additions & 4 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,10 +1162,6 @@ def wrap_literal(self, value):
value in self._common_constants()
# Assume integers from global variables want to be specialized
or not self.source.guard_source().is_local()
# Assume that integers that came from NN modules want to be
# specialized (as we don't expect users to be changing the
# NN modules on the fly)
or self.source.guard_source().is_nn_module()
or is_from_defaults(self.source)
or is_cell_contents(self.source)
):
Expand Down

0 comments on commit 6bb9d60

Please sign in to comment.