[dynamo] Add more dynamo call_methods and getattr support for Placement (pytorch#117733)

Summary:
Add Dynamo support for constant-folding the Placement methods is_shard, is_partial, and is_replicate, and for getattr access to the dim attribute.
This fix is part of: pytorch#117670
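
For illustration, a minimal sketch of what this enables (assuming a PyTorch build with distributed support; `route` is a hypothetical function, not part of this PR): branching on Placement predicates and the `dim` attribute now compiles without graph breaks.

```python
import torch
from torch.distributed._tensor import Replicate, Shard

def route(p):
    # With this patch, dynamo folds is_shard()/is_replicate() and the
    # `dim` attribute into trace-time constants, so this Placement-dependent
    # control flow survives fullgraph=True instead of graph-breaking.
    if p.is_shard() and p.dim == 0:
        return 0
    if p.is_replicate():
        return 1
    return 2

compiled_route = torch.compile(route, backend="eager", fullgraph=True)
assert compiled_route(Shard(0)) == route(Shard(0))
assert compiled_route(Replicate()) == route(Replicate())
```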

Test Plan:
Unit tests and CI
- Unit test: `buck2 test mode/dev-nosan //caffe2/test/distributed/_tensor:dtensor_compile -- test_placement_compile`

Differential Revision: D52863073

Pull Request resolved: pytorch#117733
Approved by: https://github.com/yanboliang
yoyoyocmu authored and pytorchmergebot committed Jan 22, 2024
1 parent f612e96 commit 56ef5af
Showing 2 changed files with 46 additions and 4 deletions.
27 changes: 27 additions & 0 deletions test/distributed/_tensor/test_dtensor_compile.py
@@ -19,6 +19,7 @@
     Replicate,
     Shard,
 )
+from torch.distributed._tensor.placement_types import _Partial
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
     CheckpointImpl,
@@ -98,6 +99,32 @@ def device_type(self) -> str:
     def world_size(self) -> int:
         return 2

+    def test_placement_compile(self):
+        def fn(x):
+            a = 0
+            if x.is_replicate():
+                a += 1
+            if x.is_shard():
+                a += 2
+                if x.dim < 0:
+                    raise RuntimeError("dim < 0")
+            if x.is_shard(0):
+                a += 2
+            if x.is_shard(dim=0):
+                a += 2
+            if x.is_shard(dim=None):
+                a += 2
+            if x.is_partial():
+                a += 3
+            return a
+
+        compiled_fn = torch.compile(backend="aot_eager", fullgraph=True)(fn)
+
+        for x in [Shard(0), Replicate(), _Partial()]:
+            opt_fn = fn(x)
+            compiled_out = compiled_fn(x)
+            self.assertEqual(opt_fn, compiled_out)
+

     def test_fakify_dtensor(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

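Why comparing eager and compiled outputs is a sufficient check: each predicate's result is fully determined by the Placement instance itself, which is exactly what lets dynamo bake the answers in as trace-time constants. A quick eager-only sketch (same imports as the test; `_Partial` is private API and may change):

```python
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.placement_types import _Partial

assert Shard(0).is_shard()
assert Shard(0).is_shard(dim=0)
assert not Shard(1).is_shard(dim=0)  # dim mismatch
assert Replicate().is_replicate()
assert _Partial().is_partial()
assert Shard(0).dim == 0  # the getattr path added in distributed.py below
```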
23 changes: 19 additions & 4 deletions torch/_dynamo/variables/distributed.py
@@ -103,6 +103,11 @@ def is_placement(value):
     def as_python_constant(self):
         return self.value

+    def var_getattr(self, tx, name: str) -> VariableTracker:
+        if name == "dim":
+            return ConstantVariable.create(self.value.dim)
+        return super().var_getattr(tx, name)
+
     def call_method(
         self,
         tx,
@@ -112,10 +117,20 @@
     ) -> "VariableTracker":
         from . import ConstantVariable

-        allowed_methods = ["__init__", "__setattr__"]
-        # placement types dynamo tracking allows only __init__
-        # and __setattr__ methods, the latter is for case like `Shard(dim)`
-        if name in allowed_methods:
+        # Dynamo tracking of Placement types allows only the methods below;
+        # __setattr__ covers construction such as `Shard(dim)`.
+        # Every method in the list must satisfy:
+        # 1. Its input arguments are constants and do not need to be guarded on;
+        # 2. Its output is constant with respect to its inputs.
+        constant_fold_functions = [
+            "__init__",
+            "__setattr__",
+            "is_shard",
+            "is_partial",
+            "is_replicate",
+        ]
+
+        if name in constant_fold_functions:
             try:
                 value_type = type(self.value)
                 assert (
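The truncated `try:` block above goes on to execute the requested method eagerly and wrap its result. Conceptually, the constant-fold path reduces to the following sketch (`fold_placement_method` is a hypothetical helper for illustration, not the actual Dynamo code):

```python
from torch._dynamo.variables import ConstantVariable

def fold_placement_method(placement, name, args, kwargs):
    # Hypothetical sketch: the arguments are already trace-time constants
    # and the method is pure, so it can simply be invoked on the real
    # Placement object and the result baked into the graph as a constant.
    method = getattr(type(placement), name)
    result = method(placement, *args, **kwargs)
    return ConstantVariable.create(result)
```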
