-
Notifications
You must be signed in to change notification settings - Fork 756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat lazy tensor indexing #9334
Conversation
Scalar value_scalar = functional::PyUnpackScalar(value); | ||
value_tensor = ASSERT_PTR( | ||
functional::Constant(Shape({}), value_scalar, tensor->dtype(), ASSERT(tensor->device()))); | ||
std::shared_ptr<Tensor> value_tensor; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里代码太长,只是加了一个作用域
// NOTE: LocalToGlobal should be called in eager mode | ||
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LocalToGlobal 只能在 eager 模式下调用
/*static*/ Maybe<void> SliceUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { | ||
return InferLogicalTensorDesc(ctx); | ||
const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0); | ||
auto* y_desc = ctx->MutOutputTensorDesc("y", 0); | ||
y_desc->set_shape(ref_desc.shape()); | ||
y_desc->set_is_dynamic(ref_desc.is_dynamic()); | ||
return Maybe<void>::Ok(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
之前 SliceUpdate 的物理 Tensor 推导是错误的,它支持 S + B -> S,是不能和逻辑 shape 推导共用推导函数(逻辑推导函数中有一些 shape 的检察,物理 tensor shape 推导不需要)
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9334/ |
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9334/ |
Speed stats:
|
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9334/ |
* feat(boxing): collective_boxing slice_boxing support 0size tensor * test(Indexing): add lazy tensor basic indexing * add MaskTensor judgement * format code * feat(TensorIndexing): support lazy advance getitem indexing * feat(Indexing): support lazy indexing for lazy_tensor and free_tensor * fix(Indexing): fix indexing test bug * test(Indexing): test all advance indexing * test(GlobalIndexing): fix eager global indexing bug * test(Indexing): support combined indexing * add last test cases * fix merge bug * fix lazy mode guard * test(Indexing): refine set scalar value test * test(Indexing): enable all bool tensor index setitem * decrease test time * refine 0size shape judgement * add comment
lazy tensor indexing 支持。
主要处理了: