-
Notifications
You must be signed in to change notification settings - Fork 664
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat empty op #5659
Feat empty op #5659
Changes from 18 commits
c94183d
631ed38
75be6a9
f355238
275d4b7
7631138
55e5c83
8221bc6
fd72499
6512197
a86be56
a5dd716
52e3966
f4faf11
a2d2147
fdcaa11
38c3827
a341960
d9a1dd6
b52f857
87a02af
a5f1e73
143d471
dd09067
6a1df24
d00b380
9d9fccf
7429b74
b9dea29
e4e80db
911e765
e210383
3b2d05c
2223c0b
25717bc
a790dfc
3560c43
31184f7
5478eff
476afdc
a88ed9b
6367fb3
b9faece
0da4d90
66b571e
a75bdfb
2fbd207
30d74d6
8713b41
8caeb3f
76a932e
cdb16ad
ae9d3a7
3701588
90e4ad1
fa7f007
6373845
d898283
43ce019
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,6 +52,30 @@ Maybe<EagerMirroredTensorImpl*> TensorImpl4Tensor(const std::shared_ptr<Tensor>& | |
return tensor->mut_eager_mirrored_tensor_impl(); | ||
} | ||
|
||
class MutMirroredTensorMeta : public TensorMeta { | ||
public: | ||
MutMirroredTensorMeta() : TensorMeta(std::make_shared<const Shape>(), kInvalidDataType) {} | ||
MutMirroredTensorMeta(const MutMirroredTensorMeta&) = default; | ||
MutMirroredTensorMeta(MutMirroredTensorMeta&&) = default; | ||
~MutMirroredTensorMeta() override = default; | ||
}; | ||
|
||
std::vector<TensorMeta*>* ThreadLocalDefaultOutputMutTensorMetas(int64_t size) { | ||
static thread_local std::vector<MutMirroredTensorMeta> struct_vec; | ||
static thread_local std::vector<TensorMeta*> ptr_vec; | ||
struct_vec.resize(size); | ||
ptr_vec.resize(size); | ||
if (size == 1) { | ||
ptr_vec.at(0) = &struct_vec.at(0); // unfold loop | ||
} else if (size == 2) { | ||
ptr_vec.at(0) = &struct_vec.at(0); // unfold loop | ||
ptr_vec.at(1) = &struct_vec.at(1); // unfold loop | ||
} else { | ||
for (int i = 0; i < size; ++i) { ptr_vec.at(i) = &struct_vec.at(i); } | ||
} | ||
return &ptr_vec; | ||
} | ||
|
||
} // namespace | ||
|
||
Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, | ||
|
@@ -69,12 +93,16 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in | |
} | ||
std::shared_ptr<EagerBlobObjectList> output_eager_blob_objects = | ||
std::make_shared<EagerBlobObjectList>(outputs->size()); | ||
auto* output_tensor_metas = ThreadLocalDefaultOutputMutTensorMetas(outputs->size()); | ||
for (int i = 0; i < outputs->size(); i++) { | ||
if (!outputs->at(i)) { | ||
outputs->at(i) = | ||
const auto& tensor_impl = | ||
std::make_shared<MirroredTensor>(std::make_shared<EagerMirroredTensorImpl>()); | ||
} | ||
if (JUST(outputs->at(i)->has_eager_blob_object())) { | ||
outputs->at(i) = tensor_impl; | ||
output_tensor_metas->at(i) = tensor_impl->mut_tensor_meta(); | ||
} else { | ||
bool has_eager_blob_object = JUST(outputs->at(i)->has_eager_blob_object()); | ||
CHECK_OR_RETURN(has_eager_blob_object); | ||
output_eager_blob_objects->at(i) = JUST(outputs->at(i)->eager_blob_object()); | ||
} | ||
lixinqi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
@@ -109,14 +137,21 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in | |
return CHECK_JUST(TensorImpl4Tensor(inputs.at(i)))->mut_tensor_meta(); | ||
}, | ||
[&](int32_t i) -> TensorMeta* { | ||
return CHECK_JUST(TensorImpl4Tensor(outputs->at(i)))->mut_tensor_meta(); | ||
// using thread_local TensorMeta pointer if inplace. | ||
// using tensor_impl TensorMeta pointer if not inplace. | ||
return output_tensor_metas->at(i); | ||
Comment on lines
+141
to
+143
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 非inplace时正常 infer 到 thread_local TensorMeta 中，inplace 时 infer 到实际的 tensor_impl 中 |
||
})); | ||
|
||
for (int i = 0; i < output_eager_blob_objects->size(); i++) { | ||
auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); | ||
if (!output_eager_blob_objects->at(i)) { | ||
auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); | ||
JUST(tensor_impl->InitEagerBlobObject(JUST(outputs->at(i)->device())->mem_case())); | ||
output_eager_blob_objects->at(i) = JUST(tensor_impl->eager_blob_object()); | ||
} else { | ||
// output i is inplaced. | ||
// check thread_local TensorMeta and tensor_impl TensorMeta. | ||
CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas->at(i)->shape()); | ||
CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas->at(i)->dtype()); | ||
Comment on lines
+152
to
+156
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果是inplace则直接检察infer的结果 |
||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,11 +42,7 @@ namespace one { | |
} else { | ||
const auto& impl = | ||
std::make_shared<EagerMirroredTensorImpl>(tensor_meta, requires_grad, is_leaf); | ||
const auto& tensor = std::make_shared<MirroredTensor>(impl); | ||
const auto& outputs = std::make_shared<TensorTuple>(); | ||
outputs->push_back(tensor); | ||
JUST(RunEmptyOp(outputs.get())); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 看起来是这里的删除，使得MakeTensor不再创建blob_object了 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 对，后来这里要完全去掉，eager tensor只能由op的接口创建 |
||
return tensor; | ||
return std::make_shared<MirroredTensor>(impl); | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revert之前的改动,不共享shape的内存