Skip to content

Commit

Permalink
[dynamo] Collect cell_and_freevars correctly (pytorch#125097)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#125097
Approved by: https://github.com/Skylion007
  • Loading branch information
anijain2305 authored and andoorve committed May 1, 2024
1 parent f61621c commit ce01671
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
33 changes: 33 additions & 0 deletions test/inductor/test_cpu_repro.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: cpu inductor"]
import contextlib
import copy
import functools
import itertools
import math
import platform
Expand Down Expand Up @@ -3048,6 +3049,38 @@ def forward(self, x):
v2 = jit_func(input_tensor)
self.assertEqual(v1, v2)

def test_nn_param_assign_wrapped(self):
class Model2(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)
self.batchnorm = nn.BatchNorm2d(num_features=5)
self.conv_weight = torch.randn(5, 3, 3, 3)
self.conv_bias = torch.randn(5)

def forward(self, x):
self.conv.weight = nn.Parameter(self.conv_weight)
self.conv.bias = nn.Parameter(self.conv_bias, requires_grad=False)
self.conv.eval()
x = self.conv(x)
x = self.batchnorm(x)
x = F.relu(x)
return x

input_tensor = torch.randn(1, 3, 10, 10)
func = Model2().to("cpu")

@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

with torch.no_grad():
func.train(False)
v1 = func(input_tensor)
jit_func = torch.compile(wrapper, fullgraph=True)
v2 = jit_func(input_tensor)
self.assertEqual(v1, v2)

@config.patch(inplace_buffers=True)
def test_in_out_buffer(self):
def fn(x, y):
Expand Down
5 changes: 5 additions & 0 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,11 @@ def cell_and_freevars(self):
self._cell_and_freevars = tuple(
self.code_options["co_cellvars"] or []
) + tuple(self.code_options["co_freevars"] or [])

# An inlined function might depend on the freevar of the parent
# function. So, recursively obtain parent cell and freevars.
if isinstance(self, InliningInstructionTranslator):
self._cell_and_freevars += self.parent.cell_and_freevars()
return self._cell_and_freevars

def prune_dead_locals(self):
Expand Down

0 comments on commit ce01671

Please sign in to comment.