From 6bb9d6080d33c817fcbf9e5ae8a59b76812a53d2 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Sat, 18 May 2024 05:02:14 +0000
Subject: [PATCH] [Dynamo] Treat integers stored on nn.Modules as dynamic
 (#126466)

Fixes #115711

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126466
Approved by: https://github.com/jansel
---
 test/dynamo/test_modules.py        | 57 ++++++++++++++++++++++++++++++
 torch/_dynamo/variables/builder.py |  4 ---
 2 files changed, 57 insertions(+), 4 deletions(-)

diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py
index ceb1521ffe69f..b22f02ee2fcc4 100644
--- a/test/dynamo/test_modules.py
+++ b/test/dynamo/test_modules.py
@@ -19,6 +19,7 @@
 from torch._dynamo.eval_frame import unsupported
 from torch._dynamo.mutation_guard import GenerationTracker
 from torch._dynamo.testing import expectedFailureDynamic, same
+from torch._dynamo.utils import ifdynstaticdefault
 from torch.nn.modules.lazy import LazyModuleMixin
 from torch.nn.parameter import Parameter, UninitializedParameter
 
@@ -1104,6 +1105,37 @@ def forward(self, x):
         return self.m(x)
 
 
+class ModuleWithIntAttr(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer = torch.nn.Linear(4, 4)
+        self.step = 10
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x + 1
+        self.step += 1
+        return self.layer(x) + self.step
+
+
+class UnspecInlinableModule(torch.nn.Module):
+    torchdynamo_force_dynamic = True  # forced to be a UnspecializedNNModule
+
+    def forward(self, x):
+        return torch.sin(x)
+
+
+class UnspecModuleWithIntAttr(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer = UnspecInlinableModule()
+        self.step = 10
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x + 1
+        self.step += 1
+        return self.layer(x) + self.step
+
+
 def make_test(fn, expected_ops=None):
     def test_fn(self):
         return torch._dynamo.testing.standard_test(
@@ -1357,6 +1389,31 @@ def forward(self, x):
         self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
         self.assertTrue(torch._dynamo.testing.same(out1, out_post))
 
+    def test_nn_module_unspec_int_attr(self):
+        for module_class in [ModuleWithIntAttr, UnspecModuleWithIntAttr]:
+            mod = module_class()
+            cnt = torch._dynamo.testing.CompileCounter()
+            opt_mod = torch.compile(backend=cnt)(copy.deepcopy(mod))
+            x = torch.randn(3, 4)
+
+            # Compiling self.step as static.
+            ref1 = mod(x)
+            res1 = opt_mod(x)
+            self.assertTrue(torch.allclose(ref1, res1))
+            self.assertEqual(cnt.frame_count, 1)
+
+            # Compiling self.step as dynamic.
+            ref2 = mod(x)
+            res2 = opt_mod(x)
+            self.assertTrue(torch.allclose(ref2, res2))
+            self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1))
+
+            # No re-compilation!
+            ref3 = mod(x)
+            res3 = opt_mod(x)
+            self.assertTrue(torch.allclose(ref3, res3))
+            self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1))
+
     # RuntimeError: SymIntArrayRef expected to contain only concrete integers
     @expectedFailureDynamic
     def test_lazy_module1(self):
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 41b9fbd836ae1..c1b9f68639f53 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -1162,10 +1162,6 @@ def wrap_literal(self, value):
                 value in self._common_constants()
                 # Assume integers from global variables want to be specialized
                 or not self.source.guard_source().is_local()
-                # Assume that integers that came from NN modules want to be
-                # specialized (as we don't expect users to be changing the
-                # NN modules on the fly)
-                or self.source.guard_source().is_nn_module()
                 or is_from_defaults(self.source)
                 or is_cell_contents(self.source)
             ):
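
Illustrative usage (not part of the patch): a minimal sketch of the behavior
this change enables. With the is_nn_module() specialization removed from
wrap_literal, an int attribute mutated on an nn.Module is compiled as static
on the first call, recompiled once as dynamic when its value changes, and then
reused with no further recompilation. The Counter module below is hypothetical;
the call pattern mirrors test_nn_module_unspec_int_attr above.

    import copy

    import torch
    import torch._dynamo.testing


    class Counter(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 4)
            self.step = 10  # a plain Python int stored on the module

        def forward(self, x):
            self.step += 1  # mutating this int used to force recompiles
            return self.layer(x) + self.step


    mod = Counter()
    cnt = torch._dynamo.testing.CompileCounter()
    opt_mod = torch.compile(backend=cnt)(copy.deepcopy(mod))
    x = torch.randn(3, 4)

    opt_mod(x)  # 1st call: self.step is compiled as a static int
    opt_mod(x)  # value changed -> one recompile, self.step becomes dynamic
    opt_mod(x)  # no further recompilation
    print(cnt.frame_count)  # 2 under the default automatic-dynamic config

As in the test above, frame_count would be 1 instead of 2 when dynamic shapes
are on by default (the ifdynstaticdefault(2, 1) branch).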