rebase #5601
Conversation
…deduce_consistent_op_interpreter
…-Inc/oneflow into deduce_consistent_op_interpreter
Maybe<void> CheckIsDeviceSupportedByOp(const ParallelDesc& parallel_desc,
                                       const std::string& op_type_name) {
  if (IsCpuOnly(op_type_name)) { CHECK_EQ_OR_RETURN(parallel_desc.device_tag(), "cpu"); }
Other devices (besides cpu and cuda) may be supported in the future; we probably shouldn't check only for cpu here.
We can deal with that when other devices are actually supported. Even if this check misses something for now, nothing will go wrong.
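For illustration only, a hedged sketch of how the check could consult a per-op set of supported device tags instead of hard-coding "cpu" once more devices exist; SupportedDeviceTags is a hypothetical helper, not an existing OneFlow API.

Maybe<void> CheckIsDeviceSupportedByOp(const ParallelDesc& parallel_desc,
                                       const std::string& op_type_name) {
  // Hypothetical helper returning the device tags this op supports.
  const std::set<std::string>& supported = SupportedDeviceTags(op_type_name);
  CHECK_OR_RETURN(supported.count(parallel_desc.device_tag()) > 0);
  return Maybe<void>::Ok();
}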
{
  // Infer OpArgMutConsistentTensorMeta.
  const auto& GetInputTensorMeta = [](int32_t i) {
    UNIMPLEMENTED();
Is this UNIMPLEMENTED() directly because InferLogicalShapeAndDType never calls this lambda?
Yes. This is a source op.
@@ -37,7 +37,10 @@ namespace one {

 namespace {

-Maybe<Symbol<Device>> GetDefaultDevice() { return Device::New("cpu", 0); }
+Maybe<Symbol<Device>> GetDefaultDevice(const OpExprInterpContext& ctx) {
+  if (ctx.device.has_value()) { return ctx.device.value(); }
A global default_device_symbol_ would probably be better. I'd also suggest that calling flow.device("type:index") in Python should not create a new device every time but return a singleton device, because I found flow.device("type:index") to be fairly time-consuming on the Python side.
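A minimal sketch of what a cached default device could look like, assuming Device::New returns Maybe<Symbol<Device>> as in the diff above; the helper name and the CHECK_JUST usage are illustrative, not the final implementation.

Maybe<Symbol<Device>> DefaultCpuDevice() {
  // Built once and reused, instead of constructing a new device on every call.
  static Symbol<Device> default_device = CHECK_JUST(Device::New("cpu", 0));
  return default_device;
}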
Resolved.
@@ -86,7 +93,9 @@ class TensorInfo final {
  private:
   std::shared_ptr<const Shape> shape_;
   DataType dtype_;
-  // TODO: Add device info
+  Maybe<Symbol<Device>> device_;  // for local tensor
This could also be changed to Optional.
Resolved.
@@ -264,7 +264,11 @@
   bind_python: True

 - name: "constant"
-  signature: "Tensor Constant(*, Shape shape, Scalar value, DataType dtype)"
+  signature: "Tensor Constant(*, Shape shape, Scalar value, DataType dtype, Int64 device)"
device could default to None; following PyTorch's logic, a None device should mean cpu, i.e. signature: "Tensor Constant(*, Shape shape, Scalar value, DataType dtype, Int64 device=None)", and the functor can receive it with Optional<Int64>.
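As a rough sketch of that suggestion (the helper name, the Optional accessors, and the index-to-device mapping are assumptions, not the actual functor code), the functor could fall back to cpu when the optional device is absent:

Maybe<Symbol<Device>> DeviceFromOptionalIndex(const Optional<int64_t>& device) {
  // None device means cpu, following PyTorch's convention.
  if (!device.has_value()) { return Device::New("cpu", 0); }
  // Hypothetical mapping from the Int64 argument to a concrete device.
  return Device::New("cuda", device.value());
}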
OK. I was indeed having trouble handling the None case.
Resolved.
  }
  const auto& parallel_distribution = JUST(MakeParallelDistribution(sbp_tuple));
  if (!JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
    JUST(attrs.SetAttr<std::string>("nd_sbp", parallel_distribution->DebugString()));
Using DebugString here feels a bit odd.
That's only because user_op_attr does not support sbp as a primitive attribute type. If serialization is the only option, I'd rather serialize to text than to binary: the extra storage is negligible, and readability matters much more. I'll try changing this to PbMessage2TxtString.
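A sketch of what that could look like; the cfg-to-proto conversion shown is an assumption about the cfg wrapper, only PbMessage2TxtString itself is the existing helper being suggested.

ParallelDistribution nd_sbp_proto;
parallel_distribution->ToProto(&nd_sbp_proto);  // assumed cfg -> proto conversion
JUST(attrs.SetAttr<std::string>("nd_sbp", PbMessage2TxtString(nd_sbp_proto)));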
Yeah, PbMessage2TxtString works.
Resolved.
Signed-off-by: daquexian <daquexian566@gmail.com>
…-Inc/oneflow into deduce_consistent_op_interpreter
    placement: flow.placement = None,
    sbp: Union[
        flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp]
    ] = None,
These configure the placement and sbp.
def test_consistent_naive(test_case):
    placement = flow.placement("cpu", {0: [0]})
    sbp = (flow.sbp.broadcast,)
    x = flow.ones((16, 16), placement=placement, sbp=sbp)
An example.
Maybe<Symbol<cfg::ParallelDistribution>> MakeParallelDistribution(
    const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
  static thread_local std::map<std::vector<Symbol<cfg::SbpParallel>>,
                               Symbol<cfg::ParallelDistribution>>
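For readers unfamiliar with the pattern, a generic sketch of the thread_local memoization used here (the helper below is illustrative, not OneFlow code): build the value once per distinct key and reuse it on later calls from the same thread.

#include <functional>
#include <map>

template<typename Key, typename Value>
const Value& Memoize(const Key& key, const std::function<Value(const Key&)>& make) {
  // One cache per thread, so no locking is needed.
  static thread_local std::map<Key, Value> cache;
  auto it = cache.find(key);
  if (it == cache.end()) { it = cache.emplace(key, make(key)).first; }
  return it->second;
}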