
[Semi-auto] Add Shard/Replicate/Partial in DistTensor #58930

Merged: 13 commits into PaddlePaddle:develop, Nov 15, 2023

Conversation

@ForFishes (Member) commented Nov 12, 2023

PR types

New features

PR changes

Others

Description

DistTensor can now be constructed via Shard/Replicate/Partial, and conversion from S/R/P to dim_mapping is added. Constructing DistTensor via TensorDistAttr will be deprecated later. In dynamic semi-auto parallelism, DistAttr appears in SPMD rule inference and conversion, while S/R/P appears in Reshard.

  • Rework Reshard to be based on Shard/Replicate/Partial.
  • Rework all Python APIs involved in constructing DistTensor so they support construction from a ProcessMesh and Placements.
  • Rework the code that invokes the SPMD inference rules to use the DistAttr converted from Shard/Replicate/Partial.
```python
import paddle
import paddle.distributed as dist
from paddle.base import core

tensor = paddle.rand([2, 10])
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
d_tensor = paddle.Tensor(tensor, process_mesh=mesh, placements=core.Shard(0))
```
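
For reference, the relationship between a placements list and the dims_mapping it is converted to can be sketched in plain Python. This is an illustrative helper, not the conversion code added in this PR; the convention that dims_mapping[i] names the mesh axis sharding tensor dim i (or -1 if that dim is replicated) is an assumption here.

```python
from paddle.base import core

def placements_to_dims_mapping(placements, tensor_ndim):
    """Illustrative sketch only: derive a dims_mapping from a placements list."""
    # Assumed convention: dims_mapping[i] is the mesh axis that shards tensor
    # dim i, or -1 if that tensor dim is replicated across the mesh.
    dims_mapping = [-1] * tensor_ndim
    for mesh_axis, placement in enumerate(placements):
        # Only Shard contributes; Replicate/Partial leave the entry at -1.
        if isinstance(placement, core.Shard):
            dims_mapping[placement.get_dim()] = mesh_axis
    return dims_mapping

# A [2, 10] tensor on the 1-D mesh above with placements=[Shard(0)]:
print(placements_to_dims_mapping([core.Shard(0)], tensor_ndim=2))  # [0, -1]
```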

Pcard-73145

paddle-bot (bot) commented Nov 12, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

#else
PADDLE_THROW(platform::errors::Unavailable(
"Placements to PyObject is not supported in the current "
"PaddlePaddle, please recompile and installPaddlePaddle with the option "

Contributor:
install PaddlePaddle

ForFishes (Member, Author):
done, thx

@@ -13,6 +13,7 @@
// limitations under the License.

#pragma once
#include <ostream>

Contributor:
This header isn't used.

ForFishes (Member, Author):
done thx

ReduceType reduce_type_;
};

using Placements = std::vector<std::shared_ptr<Placement>>;

Contributor:
Could this be moved into the DistTensorMeta class?

ForFishes (Member, Author):
done, thx

}

bool operator==(const Placement& other) const override {
const Shard* other_shard = dynamic_cast<const Shard*>(&other);

Contributor:
Would it be better to add a shard_axis method to Placement and access it directly here?

ForFishes (Member, Author):
This is specific to Shard, just like reduce_type_ is specific to Partial.

@ForFishes ForFishes closed this Nov 13, 2023
@ForFishes ForFishes reopened this Nov 13, 2023
.def(py::init([](int64_t dim) {
return std::make_shared<phi::distributed::Shard>(dim);
}))
.def("get_dim", &phi::distributed::Shard::get_dim)

Contributor:
Should the names used for tensor dimensions and mesh dimensions be distinguished? If both are called dim, it is easy to confuse them, e.g. axis vs. dim.

ForFishes (Member, Author):
The dim here refers specifically to the tensor dim, so the naming is kept consistent with that.
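
To make the author's point concrete, here is a tiny example: the integer passed to Shard is a tensor dimension, while the mesh axis that performs the sharding is determined by the position of that placement in the placements list, not by this value.

```python
from paddle.base import core

# Shard(0) means "split tensor dim 0"; which mesh axis does the splitting is
# given by where the placement sits in the placements list, not by this value.
s = core.Shard(0)
print(s.get_dim())  # 0, i.e. the sharded *tensor* dimension
```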

dist_attr.set_dims_mapping(dist_tensor_meta_.dim_mapping());
dist_attr.mark_annotated("process_mesh");
dist_attr.mark_annotated("dims_mapping");
dist_attr_ = dist_attr;

Contributor:
It feels like eventually only one of dist_tensor_meta_ and dist_attr_ can remain as a data member; otherwise there will be data-synchronization problems, and static semi-auto already hit quite a few pitfalls there.
For example, dist_attr could always be a temporary variable, only returned through a function like get_dist_attr(dist_tensor_meta_).

ForFishes (Member, Author):
Yes, agreed. The plan is that DistTensor will eventually hold only a DistTensorMeta and a DenseTensor pointer, and everything else will be removed. That change is fairly large, though, so it is being split across separate PRs.
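
As an aside, the single-source-of-truth pattern the reviewer describes can be sketched as follows. This is a generic illustration in Python, not Paddle code; get_dist_attr here is a hypothetical helper standing in for the function the reviewer mentions.

```python
def get_dist_attr(meta):
    # Hypothetical converter: derive the dist attr fields from the meta.
    return {"process_mesh": meta["process_mesh"],
            "dims_mapping": meta["dims_mapping"]}

class DistTensorSketch:
    """Illustrative only: one stored member, everything else derived from it."""

    def __init__(self, dist_tensor_meta):
        self._meta = dist_tensor_meta  # the single source of truth

    @property
    def dist_attr(self):
        # Recomputed on every access, so it can never drift out of sync with
        # the meta, which is exactly the synchronization problem being avoided.
        return get_dist_attr(self._meta)
```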

@@ -121,12 +146,14 @@ class DistTensor final
private:
friend class ReshardFunction;

// The global dimensions(shape)
// The global dimensions(shape), will move to DistTensorMeta

Contributor:
It feels like tensor shape, Mesh, and dims mapping all end up using "dims"... This naming may need some discussion later.

ForFishes (Member, Author):
Yes, the naming is not unified yet.

private:
std::shared_ptr<const ProcessMesh> process_mesh_;
Placements placements_;
std::shared_ptr<const DenseTensorMeta> tensor_meta_;

Contributor:
Does the whole DenseTensorMeta need to be stored here?

ForFishes (Member, Author):
Yes, we need to keep dtype, strides, layout, and similar information to use later as a cache for SPMD inference; it is essentially the same as DenseTensorMeta.

self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

def run_test_placements(self):
self.placements = [core.Replicate(), core.Replicate()]

Contributor:
Should these classes be imported directly under paddle.distributed, so the core prefix isn't needed?

ForFishes (Member, Author):
This unit test is only a temporary one for now and does not involve API changes; the API will be reworked as a whole later, and this test will be updated then.
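
For context, the temporary usage the test exercises looks roughly like the sketch below. Hedged: the 2-D mesh is only an illustration, and passing placements as a list to the constructor is an assumption based on the snippets in this PR. If the classes were later re-exported under paddle.distributed as suggested, core.Replicate() would simply become dist.Replicate().

```python
import paddle
import paddle.distributed as dist
from paddle.base import core

# A 2-D mesh needs one placement per mesh axis.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
placements = [core.Replicate(), core.Replicate()]  # replicate on both axes

tensor = paddle.rand([4, 8])
d_tensor = paddle.Tensor(tensor, process_mesh=mesh, placements=placements)
```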

public:
virtual ~Placement() = default;

virtual bool is_shard(std::optional<int> dim = std::nullopt) const {

Contributor:
Will the Placement base class ever actually be instantiated? Should these just be pure virtual functions?

ForFishes (Member, Author):
No, there will be no actual instances of the base class.
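
To spell out the semantics being discussed, here is a small Python mirror of the C++ interface above (illustrative only, not bindings from this PR): the base class supplies a default is_shard that returns false rather than being pure virtual, and the optional dim lets callers ask either "is this any Shard?" or "is this Shard(dim)?".

```python
from typing import Optional

class Placement:
    # Mirrors the C++ base: a default body instead of a pure virtual function,
    # so non-Shard placements simply answer False.
    def is_shard(self, dim: Optional[int] = None) -> bool:
        return False

class Shard(Placement):
    def __init__(self, dim: int):
        self._dim = dim

    def is_shard(self, dim: Optional[int] = None) -> bool:
        # No dim given: "is this a Shard at all?"  With a dim: "is it Shard(dim)?"
        return dim is None or dim == self._dim

print(Shard(0).is_shard())   # True
print(Shard(0).is_shard(1))  # False
```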

@LiYuRio (Contributor) left a comment:
LGTM

private:
std::shared_ptr<const ProcessMesh> process_mesh_;
Placements placements_;
std::shared_ptr<const DenseTensorMeta> tensor_meta_;

Contributor:
What is the reason for using shared_ptr here?

ForFishes (Member, Author):
It avoids one copy, and the tensor meta may also be used elsewhere; its lifetime matches that of the DistTensor.

@XiaoguangHu01 (Contributor) left a comment:
LGTM

@XieYunshen (Contributor) left a comment:
LGTM for set_tests_properties(test_dist_tensor_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)

@jzhang533 (Contributor) left a comment:
LGTM

@ForFishes ForFishes merged commit bd158d1 into PaddlePaddle:develop Nov 15, 2023
28 checks passed
@ForFishes ForFishes deleted the srp_in_semi_auto branch November 15, 2023 02:22
SecretXV pushed a commit to SecretXV/Paddle that referenced this pull request Nov 28, 2023