From ce01671c46f125c6baf7fa48f2c625f0d4405f12 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Fri, 26 Apr 2024 22:40:48 -0700
Subject: [PATCH] [dynamo] Collect cell_and_freevars correctly (#125097)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125097
Approved by: https://github.com/Skylion007
---
 test/inductor/test_cpu_repro.py   | 33 +++++++++++++++++++++++++++++++
 torch/_dynamo/symbolic_convert.py |  5 +++++
 2 files changed, 38 insertions(+)

diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index fea142e6f6ea0..004db97b8a461 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: cpu inductor"]
 import contextlib
 import copy
+import functools
 import itertools
 import math
 import platform
@@ -3048,6 +3049,38 @@ def forward(self, x):
             v2 = jit_func(input_tensor)
             self.assertEqual(v1, v2)
 
+    def test_nn_param_assign_wrapped(self):
+        class Model2(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)
+                self.batchnorm = nn.BatchNorm2d(num_features=5)
+                self.conv_weight = torch.randn(5, 3, 3, 3)
+                self.conv_bias = torch.randn(5)
+
+            def forward(self, x):
+                self.conv.weight = nn.Parameter(self.conv_weight)
+                self.conv.bias = nn.Parameter(self.conv_bias, requires_grad=False)
+                self.conv.eval()
+                x = self.conv(x)
+                x = self.batchnorm(x)
+                x = F.relu(x)
+                return x
+
+        input_tensor = torch.randn(1, 3, 10, 10)
+        func = Model2().to("cpu")
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            return func(*args, **kwargs)
+
+        with torch.no_grad():
+            func.train(False)
+            v1 = func(input_tensor)
+            jit_func = torch.compile(wrapper, fullgraph=True)
+            v2 = jit_func(input_tensor)
+            self.assertEqual(v1, v2)
+
     @config.patch(inplace_buffers=True)
     def test_in_out_buffer(self):
         def fn(x, y):
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 78d5b08e68b4d..4e52c5fdadf34 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -694,6 +694,11 @@ def cell_and_freevars(self):
             self._cell_and_freevars = tuple(
                 self.code_options["co_cellvars"] or []
             ) + tuple(self.code_options["co_freevars"] or [])
+
+            # An inlined function might depend on the freevar of the parent
+            # function. So, recursively obtain parent cell and freevars.
+            if isinstance(self, InliningInstructionTranslator):
+                self._cell_and_freevars += self.parent.cell_and_freevars()
         return self._cell_and_freevars
 
     def prune_dead_locals(self):
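
Note (not part of the patch): a minimal, standalone Python sketch of the closure structure the new test exercises. The helper names below (make_wrapper, wrapper) are illustrative only and do not appear in the patch. Because the nested function references func, CPython records func as a cell variable of the outer function and a freevar of the wrapper; when Dynamo inlines such a wrapper, the inlining translator also needs the parent frame's cells/freevars, which is what the symbolic_convert.py change collects.

    # Illustrative sketch only -- plain CPython, no Dynamo involved.
    import functools

    def make_wrapper(func):
        # `func` is referenced by the nested function, so CPython stores it
        # in a cell of make_wrapper and exposes it as a freevar of wrapper.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper

    wrapper = make_wrapper(len)
    print(make_wrapper.__code__.co_cellvars)  # ('func',)
    print(wrapper.__code__.co_freevars)       # ('func',)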