Skip to content

Commit

Permalink
[inductor] Clear cache on ctx manager exit (pytorch#126146)
Browse files Browse the repository at this point in the history
FIXES pytorch#126128.

Right now, we only clear the cache on ctx manager enter. So state is bad unless we call fresh_inductor_cache again,  usually fine in tests.

Cue compiled autograd tests when going from TestCompiledAutograd -> TestAutogradWithCompiledAutograd.
TestCompiledAutograd uses the ctx manager, but TestAutogradWithCompiledAutograd don't

Pull Request resolved: pytorch#126146
Approved by: https://github.com/jgong5, https://github.com/oulgen
ghstack dependencies: pytorch#126144
  • Loading branch information
xmfan authored and ZelboK committed May 19, 2024
1 parent 80798a7 commit b2efbae
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
16 changes: 16 additions & 0 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -4971,6 +4971,22 @@ def fn(x):
opt_fn = torch.compile(fn, backend="eager")
opt_fn(np.ones([3, 3]))

def test_issue126128(self):
def fn():
x = torch.randn(1, 10)
y = torch.randn(10, 1)
return torch.mm(x, y).sum()

def fn2():
x = torch.randn(10, 100)
y = torch.randn(100, 10)
return torch.mm(x, y).sum()

with torch._inductor.utils.fresh_inductor_cache():
torch.compile(fn)()

torch.compile(fn2)()


instantiate_parametrized_tests(ReproTests)

Expand Down
2 changes: 2 additions & 0 deletions torch/_inductor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,8 @@ def fresh_inductor_cache(cache_entries=None):
except Exception:
log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
raise
finally:
clear_inductor_caches()


def argsort(seq) -> List[int]:
Expand Down

0 comments on commit b2efbae

Please sign in to comment.