Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Feat eager global tensor indexing #9138

Merged
merged 17 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 43 additions & 12 deletions oneflow/core/functional/impl/array_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1271,11 +1271,23 @@ class TensorScatterNdUpdateFunctor {
<< Error::RuntimeError() << "The dtype of tensor and updates must be same.";
std::shared_ptr<Tensor> contiguous_index = JUST(functional::ToContiguous(indices));
if (inplace) {
JUST(CheckInplaceValid(tensor));
auto outputs = std::make_shared<TensorTuple>(1);
outputs->at(0) = tensor;
JUST(OpInterpUtil::Dispatch(*op_, {tensor, contiguous_index, updates}, outputs.get()));
return outputs->at(0);
if (tensor->is_global()) {
// NOTE: global tensor_scatter_nd_update inplace must calculate on another tensor and assign
// back because of input's sbp limited
auto output =
JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {tensor, contiguous_index, updates}));
int64_t ndim = tensor->shape()->NumAxes();
std::vector<int64_t> start(ndim, 0);
std::vector<int64_t> stop(tensor->shape()->begin(), tensor->shape()->end());
std::vector<int64_t> step(ndim, 1);
return functional::SliceUpdate(tensor, output, start, stop, step, /*inplace=*/true);
wyg1997 marked this conversation as resolved.
Show resolved Hide resolved
} else {
JUST(CheckInplaceValid(tensor));
auto outputs = std::make_shared<TensorTuple>(1);
outputs->at(0) = tensor;
JUST(OpInterpUtil::Dispatch(*op_, {tensor, contiguous_index, updates}, outputs.get()));
return outputs->at(0);
}
} else {
return OpInterpUtil::Dispatch<Tensor>(*op_, {tensor, contiguous_index, updates});
}
Expand Down Expand Up @@ -2136,7 +2148,7 @@ class TensorGetItemFunctor {
Shape shape(DimVector(target_dims.begin(), target_dims.end()));
if (shape != *(result->shape())) { result = JUST(Reshape(result, shape)); }
if (!tensor_indices.empty()) {
JUST(UnifyLocalTensorAndIndicesOnDevice(x, tensor_indices));
JUST(UnifyInputAndIndicesOnDevice(x, tensor_indices));
result = JUST(ApplyAdvancedIndexing(result, tensor_indices));
}
return result;
Expand Down Expand Up @@ -2233,16 +2245,35 @@ class TensorSetItemFunctor {
if (is_identity) {
result = expand_input;
} else {
CHECK_OR_RETURN(view::IsViewApplicable(expand_input))
<< "combined slice setitem must enable view, please try to set ONEFLOW_DISABLE_VIEW=0";
result = JUST(Slice(expand_input, start, end, step, /*enable_view_slice=*/true));
if (expand_input->is_local()) {
CHECK_OR_RETURN(view::IsViewApplicable(expand_input))
<< "combined slice setitem must enable view, please try to set "
"ONEFLOW_DISABLE_VIEW=0";
Comment on lines +2232 to +2234
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以考虑在ONEFLOW_DISABLE_VIEW=1的情况下也支持一下setitem,不过这个工作可以放到以后去做

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯嗯,我也想到了,其实和 global 是一样的,一顿处理后再写回去就行

result = JUST(Slice(expand_input, start, end, step, /*enable_view_slice=*/true));
} else {
// global tensor
result = JUST(Slice(expand_input, start, end, step, /*enable_view_slice=*/false));
}
}
if (target_shape != *(result->shape())) {
const Shape& slice_result_shape = *(result->shape());
if (target_shape != slice_result_shape) {
result = JUST(functional::View(result, target_shape));
}

JUST(UnifyLocalTensorAndIndicesOnDevice(expand_input, tensor_indices));
JUST(ApplyAdvancedIndexingUpdate(result, tensor_indices, value));
JUST(UnifyInputAndIndicesOnDevice(result, tensor_indices));
result = JUST(ApplyAdvancedIndexingUpdate(result, tensor_indices, value));

// Write the sliced tensor back to the original tensor.
if (result->is_global()) {
if (*result->shape() != slice_result_shape) {
CHECK_EQ_OR_RETURN(result->shape()->elem_cnt(), slice_result_shape.elem_cnt())
<< Error::RuntimeError()
<< "The global tensor size mismatch. Target sizes: " << slice_result_shape.ToString()
<< ", value sizes: " << result->shape()->ToString();
result = JUST(functional::View(result, slice_result_shape));
}
JUST(SliceUpdate(expand_input, result, start, end, step, /*inplace=*/true));
}
}
return Maybe<void>::Ok();
}
Expand Down
118 changes: 70 additions & 48 deletions oneflow/core/functional/tensor_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,9 @@ Maybe<bool> IsContinuousSubspace(const TensorTuple& indices) {
// NOTE(wyg):
// Move indices subspace to be contiguous and ahead.
// e.g. [:, index0, index1] -> [index0, index1, :]
Maybe<void> TransposeFront(const std::shared_ptr<Tensor>& input, const TensorTuple& indices,
std::shared_ptr<Tensor>* output, TensorTuple* valid_indices) {
Maybe<std::vector<int>> TransposeFront(const std::shared_ptr<Tensor>& input,
const TensorTuple& indices, std::shared_ptr<Tensor>* output,
TensorTuple* valid_indices) {
std::vector<int> permute;
permute.reserve(input->ndim());
for (int i = 0; i < input->ndim(); ++i) {
Expand All @@ -177,7 +178,7 @@ Maybe<void> TransposeFront(const std::shared_ptr<Tensor>& input, const TensorTup
} else {
*output = input;
}
return Maybe<void>::Ok();
return permute;
}

Maybe<Tensor> AdjustSubspace(const std::shared_ptr<Tensor>& input, const TensorTuple& indices,
Expand Down Expand Up @@ -215,6 +216,30 @@ Maybe<bool> HasFalseIndex(const TensorIndex& index) {
});
}

// Undo the front-transpose previously applied to a global tensor by applying the
// inverse permutation. If the permutation turns out to be the identity, the input
// tensor is returned as-is without dispatching a Transpose op.
Maybe<Tensor> PermuteBackForGlobalTensor(const std::shared_ptr<Tensor>& result,
                                         const std::vector<int>& permute) {
  CHECK_OR_RETURN(result->is_global());                // NOLINT(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(result->ndim(), permute.size());  // NOLINT(maybe-need-error-msg)
  // Build the inverse permutation: inverse[permute[i]] = i.
  std::vector<int> inverse(permute.size());
  for (int32_t axis = 0; axis < permute.size(); ++axis) { inverse[permute[axis]] = axis; }

  // Detect whether the inverse permutation is the identity mapping.
  bool is_identity = true;
  for (int32_t axis = 0; axis < inverse.size(); ++axis) {
    if (inverse[axis] != axis) {
      is_identity = false;
      break;
    }
  }
  if (is_identity) { return result; }
  return Transpose(result, inverse);
}

} // namespace

Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,
Expand Down Expand Up @@ -362,24 +387,9 @@ Maybe<Tensor> ApplyAdvancedIndexing(const std::shared_ptr<Tensor>& input,
packed_indices = JUST(Transpose(packed_indices, permute))->contiguous();
}

// Align device or placement between input and indices.
if (transposed_input->is_global()) {
if (packed_indices->is_local()) {
const auto& placement = JUST(transposed_input->parallel_desc());
const auto& broadcast_sbp = JUST(MakeBroadcastSbpParallel());
int n = JUST(input->nd_sbp())->sbp_parallel_size();
std::vector<Symbol<SbpParallel>> grad_sbp_tuple;
packed_indices = JUST(ToGlobal(packed_indices, placement,
std::vector<Symbol<SbpParallel>>(n, broadcast_sbp),
grad_sbp_tuple, /* check_meta */ false, /*copy=*/false));
}
} else {
Symbol<Device> device = JUST(transposed_input->device());
if (JUST(packed_indices->device()) != device) {
packed_indices =
JUST(Copy(packed_indices, device->type(), device->device_id(), /*pin_memory=*/false));
}
}
CHECK_EQ_OR_RETURN(transposed_input->is_local(), packed_indices->is_local())
<< Error::RuntimeError() << "The input and indices must be both local or global.";

auto result = JUST(GatherNd(transposed_input, packed_indices));

int required_ndim = input->ndim() - valid_indices.size() + index_ndim;
Expand All @@ -392,9 +402,9 @@ Maybe<Tensor> ApplyAdvancedIndexing(const std::shared_ptr<Tensor>& input,
return result;
}

Maybe<void> ApplyAdvancedIndexingUpdate(const std::shared_ptr<Tensor>& input,
const TensorTuple& indices,
const std::shared_ptr<Tensor>& value) {
Maybe<Tensor> ApplyAdvancedIndexingUpdate(const std::shared_ptr<Tensor>& input,
const TensorTuple& indices,
const std::shared_ptr<Tensor>& value) {
CHECK_GE_OR_RETURN(input->ndim(), indices.size())
<< Error::IndexError() << "Too many indices for tensor of dimension " << input->ndim();
const auto& expanded_indices = JUST(ExpandIndices(indices));
Expand All @@ -404,15 +414,21 @@ Maybe<void> ApplyAdvancedIndexingUpdate(const std::shared_ptr<Tensor>& input,
// transpose the input as long as the first index is null.
std::shared_ptr<Tensor> transposed_input;
TensorTuple valid_indices;
JUST(TransposeFront(input, *expanded_indices, &transposed_input, &valid_indices));
CHECK_EQ_OR_RETURN(JUST(transposed_input->tensor_storage()), JUST(input->tensor_storage()))
<< Error::RuntimeError()
<< "This setitem operator must enable view mechanism, please try to set "
"ONEFLOW_DISABLE_VIEW=0";
const auto& transposed_input_permute =
JUST(TransposeFront(input, *expanded_indices, &transposed_input, &valid_indices));
// NOTE: For local tensor, we make sure that transposed_input is a view of input.
// Therefore we need not transpose it back because we update the value in a same memory
// by tensor_scatter_nd_update operator.
if (input->is_local()) {
CHECK_EQ_OR_RETURN(JUST(transposed_input->tensor_storage()), JUST(input->tensor_storage()))
<< Error::RuntimeError()
<< "This setitem operator must enable view mechanism, please try to set "
"ONEFLOW_DISABLE_VIEW=0";
}

if (valid_indices.empty()) {
CHECK_EQ_OR_RETURN(value->nelement(), 0) << Error::IndexError() << "invalid indices";
return Maybe<void>::Ok();
return input;
}
int index_ndim = valid_indices[0]->ndim();
auto packed_indices = JUST(Stack(valid_indices, 0));
Expand All @@ -426,21 +442,8 @@ Maybe<void> ApplyAdvancedIndexingUpdate(const std::shared_ptr<Tensor>& input,
packed_indices = JUST(Transpose(packed_indices, permute))->contiguous();
}

if (transposed_input->is_global()) {
const auto& placement = JUST(transposed_input->parallel_desc());
const auto& broadcast_sbp = JUST(MakeBroadcastSbpParallel());
int n = JUST(input->nd_sbp())->sbp_parallel_size();
std::vector<Symbol<SbpParallel>> grad_sbp_tuple;
packed_indices =
JUST(ToGlobal(packed_indices, placement, std::vector<Symbol<SbpParallel>>(n, broadcast_sbp),
grad_sbp_tuple, /*check_meta=*/false, /*copy=*/false));
} else {
Symbol<Device> device = JUST(transposed_input->device());
if (JUST(packed_indices->device()) != device) {
packed_indices =
JUST(Copy(packed_indices, device->type(), device->device_id(), /*pin_memory=*/false));
}
}
CHECK_EQ_OR_RETURN(transposed_input->is_local(), packed_indices->is_local())
<< Error::RuntimeError() << "The input and indices must be both local or global.";

Shape expand_shape;
{
Expand Down Expand Up @@ -475,7 +478,11 @@ Maybe<void> ApplyAdvancedIndexingUpdate(const std::shared_ptr<Tensor>& input,
expand_value = JUST(AdjustSubspace(expand_value, indices, index_ndim, /*reverse*/ true));
}
JUST(TensorScatterNdUpdate(transposed_input, packed_indices, expand_value, /*inplace=*/true));
return Maybe<void>::Ok();
// Global tensor is not support view, so we should permute back and copy to origin input if need
if (transposed_input->is_global()) {
return PermuteBackForGlobalTensor(transposed_input, *transposed_input_permute);
}
return transposed_input;
}

Maybe<Tensor> ApplySelectIndexing(const std::shared_ptr<one::Tensor>& input,
Expand All @@ -498,8 +505,8 @@ Maybe<Tensor> ApplySelectIndexing(const std::shared_ptr<one::Tensor>& input,
return functional::AsStrided(input, sizes, strides, storage_offset);
}

Maybe<void> UnifyLocalTensorAndIndicesOnDevice(const std::shared_ptr<Tensor>& x,
TensorTuple& tensor_indices) {
Maybe<void> UnifyInputAndIndicesOnDevice(const std::shared_ptr<Tensor>& x,
TensorTuple& tensor_indices) {
if (x->is_local()) {
const auto x_device = JUST(x->device());
for (int64_t i = 0; i < tensor_indices.size(); ++i) {
Expand All @@ -513,6 +520,21 @@ Maybe<void> UnifyLocalTensorAndIndicesOnDevice(const std::shared_ptr<Tensor>& x,
JUST(Copy(tensor_index, x_device->type(), x_device->device_id(), /*pin_memory=*/false));
}
}
} else {
// global tensor
const auto& placement = JUST(x->parallel_desc());
const auto& broadcast_sbp = JUST(MakeBroadcastSbpParallel());
int n = JUST(x->nd_sbp())->sbp_parallel_size();
std::vector<Symbol<SbpParallel>> grad_sbp_tuple;
for (int64_t i = 0; i < tensor_indices.size(); ++i) {
const auto tensor_index = tensor_indices[i];
if (tensor_index == nullptr) { continue; }
if (tensor_index->is_local()) {
tensor_indices[i] = JUST(ToGlobal(tensor_index, placement,
std::vector<Symbol<SbpParallel>>(n, broadcast_sbp),
grad_sbp_tuple, /*check_meta=*/false, /*copy=*/false));
}
}
}
return Maybe<void>::Ok();
}
Expand Down
10 changes: 5 additions & 5 deletions oneflow/core/functional/tensor_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ Maybe<Tensor> ApplyAdvancedIndexing(const std::shared_ptr<Tensor>& input,
Maybe<Tensor> ApplySelectIndexing(const std::shared_ptr<one::Tensor>& input,
const TensorIndex& index);

Maybe<void> UnifyLocalTensorAndIndicesOnDevice(const std::shared_ptr<Tensor>& x,
TensorTuple& tensor_indices);
Maybe<void> UnifyInputAndIndicesOnDevice(const std::shared_ptr<Tensor>& x,
TensorTuple& tensor_indices);

Maybe<void> ApplyAdvancedIndexingUpdate(const std::shared_ptr<Tensor>& input,
const TensorTuple& indices,
const std::shared_ptr<Tensor>& value);
Maybe<Tensor> ApplyAdvancedIndexingUpdate(const std::shared_ptr<Tensor>& input,
const TensorTuple& indices,
const std::shared_ptr<Tensor>& value);

} // namespace functional
} // namespace one
Expand Down
Loading