Skip to content

Commit

Permalink
Construct a dist tensor with the dtensor_from_local API (PaddlePaddle#60206)
Browse files Browse the repository at this point in the history
  • Loading branch information
pangengzheng authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent c89d3e4 commit d8d2013
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions python/paddle/distributed/auto_parallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,15 +1309,13 @@ def build_distributed_tensor(local_tensor, dist_attr):
)
else:
raise ValueError(f"dim {dim} is not supported.")
# TODO(pangengzheng): construct dist_tensor with _dtensor_from_local api when it is ready.
global_tensor = paddle.zeros(global_shape, dtype=local_tensor.dtype)
mesh = dist.ProcessMesh(
np.array(dist_attr["process_group"]).reshape(
dist_attr["process_shape"]
)
)
placements = to_placements(dist_attr["dims_mapping"], mesh)
dist_tensor = dist.shard_tensor(global_tensor, mesh, placements)
dist_tensor = dtensor_from_local(local_tensor, mesh, placements)
assert (
dist_tensor._local_value().shape == local_tensor.shape
), f"local tensor shape {dist_tensor._local_value().shape} not equal to local_tensor.shape:{local_tensor.shape}"
Expand Down

0 comments on commit d8d2013

Please sign in to comment.