Skip to content

Commit

Permalink
Scatter 0D index for gather, 0D index and 0D updates for scatter. (#4…
Browse files Browse the repository at this point in the history
  • Loading branch information
FeixLiu committed Dec 3, 2022
1 parent a3ae080 commit f9815bf
Show file tree
Hide file tree
Showing 12 changed files with 377 additions and 147 deletions.
78 changes: 55 additions & 23 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1268,37 +1268,69 @@ void GatherInferMeta(const MetaTensor& x,
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
index_dims.size(),
1,
index_dims.size() == 1 || index_dims.size() == 0,
true,
phi::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d",
"The index should be 0D or 1D, when it is not 2D, but we get %d",
index_dims.size()));
}

auto input_dim = x.dims();
auto axis_v = axis.to<int>();
if (axis.FromTensor() || axis_v == 0) {
// if axis.FromTensor(), we can not obtain correct shape of output
int batch_size = index_dims[0];
phi::DDim output_dims(input_dim);
output_dims[0] = batch_size;
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} else {
int index_size = index_dims[0];
std::vector<int> out_dim_vec;
for (int i = 0; i < axis_v; i++) {
out_dim_vec.push_back(input_dim[i]);
if (index_dims.size() == 0) {
// 0D index will decrease the dimension
if (input_dim.size() == 1) {
// the index is a 0D tensor and the x is a 1D tensor
out->set_dims(phi::DDim(phi::Dim<0>()));
} else {
if (axis.FromTensor() || axis_v == 0) {
// decrease the output dimension
std::vector<int> out_dim_vec;
for (int i = 1; i < input_dim.size(); ++i) {
out_dim_vec.emplace_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} else {
std::vector<int> out_dim_vec;
for (int i = 0; i < axis_v; i++) {
out_dim_vec.push_back(input_dim[i]);
}
for (int i = axis_v + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
}
out_dim_vec.push_back(index_size);
for (int i = axis_v + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
} else {
if (axis.FromTensor() || axis_v == 0) {
// if axis.FromTensor(), we can not obtain correct shape of output
int batch_size = index_dims[0];
phi::DDim output_dims(input_dim);
output_dims[0] = batch_size;
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} else {
int index_size = index_dims[0];
std::vector<int> out_dim_vec;
for (int i = 0; i < axis_v; i++) {
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
for (int i = axis_v + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
}

Expand Down
49 changes: 26 additions & 23 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -995,31 +995,34 @@ void ScatterInferMeta(const MetaTensor& x,
"index is a 2D tensor, but we get %d.",
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(index_dims.size() == 1 || index_dims.size() == 0,
true,
phi::errors::InvalidArgument(
"The index should be a 0D or 1D tensor when the "
"index is not a 2D tensor, but we get %d.",
index_dims.size()));
}
if (index_dims.size() != 0) {
PADDLE_ENFORCE_EQ(
index_dims.size(),
1,
phi::errors::InvalidArgument("The index should be a 1D tensor when the "
"index is not a 2D tensor, but we get %d.",
index_dims.size()));
(ref_dims.size() == updates_dims.size()),
true,
phi::errors::InvalidArgument(
"When the Input(Updates) is not a 0D tensor, the "
"Input(X) and Input(Updates) should have the same shape size, "
"but received the size of Input(x)'s shape is %d, the size of "
"Input(Updates)'s shape is %d.",
ref_dims.size(),
updates_dims.size()));
PADDLE_ENFORCE_EQ(
updates_dims[0],
index_dims[0],
phi::errors::InvalidArgument(
"Input(Updates) and Input(Ids) should have same batch-size, but"
" received Input(Updates)'s batch-size is %d, Input(Ids)'s "
"batch-size is %d.",
updates_dims[0],
index_dims[0]));
}
PADDLE_ENFORCE_EQ(
ref_dims.size(),
updates_dims.size(),
phi::errors::InvalidArgument(
"Input(X) and Input(Updates) should have the same shape size, "
"but received the size of Input(x)'s shape is %d, the size of "
"Input(Updates)'s shape is %d.",
ref_dims.size(),
updates_dims.size()));
PADDLE_ENFORCE_EQ(
updates_dims[0],
index_dims[0],
phi::errors::InvalidArgument(
"Input(Updates) and Input(Ids) should have same batch-size, but"
" received Input(Updates)'s batch-size is %d, Input(Ids)'s "
"batch-size is %d.",
updates_dims[0],
index_dims[0]));
out->set_dims(ref_dims);
out->share_lod(x);
out->set_dtype(x.dtype());
Expand Down
9 changes: 4 additions & 5 deletions paddle/phi/kernels/funcs/gather.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,9 @@ void GPUGather(const phi::GPUContext& ctx,
}

// index size
int64_t index_size = index.dims()[0];
if (index_size == 0) return;
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];

auto src_dims = src.dims();
phi::DDim output_dims(src_dims);
output_dims[0] = index_size;

// slice size
int64_t slice_size = 1;
Expand Down Expand Up @@ -246,7 +243,9 @@ void GatherV2CUDAFunction(const DenseTensor* input,
inner_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
if (index->dims().size() != 0) {
out_dim_vec.push_back(index_size);
}
for (int i = axis_index + 1; i < input_dim.size(); i++) {
outer_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]);
Expand Down
28 changes: 18 additions & 10 deletions paddle/phi/kernels/funcs/gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ void CPUGather(const phi::CPUContext& ctx,
const DenseTensor& src,
const DenseTensor& index,
DenseTensor* output) {
// check index of shape 1-D
if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(
index.dims()[1],
Expand All @@ -48,14 +47,15 @@ void CPUGather(const phi::CPUContext& ctx,
"in gather_op, but received value is [%d].",
index.dims()[1]));
} else {
PADDLE_ENFORCE_EQ(index.dims().size(),
1,
phi::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in gather_op,"
"but received shape's size is [%d].",
index.dims().size()));
PADDLE_ENFORCE_EQ(
index.dims().size() == 1 || index.dims().size() == 0,
true,
phi::errors::InvalidArgument(
"The index should be 0D or 1D, when it is not 2D, but we get %d",
index.dims().size()));
}
int64_t index_size = index.dims()[0];

int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];

auto src_dims = src.dims();

Expand Down Expand Up @@ -188,7 +188,9 @@ void GatherV2Function(const phi::CPUContext& ctx,
inner_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
if (index->dims().size() != 0) {
out_dim_vec.push_back(index_size);
}
for (int i = axis_index + 1; i < input_dim.size(); i++) {
outer_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]);
Expand Down Expand Up @@ -224,7 +226,13 @@ void GatherV2GradFunction(const phi::CPUContext& ctx,

if (input->numel() == 0) return;
int axis_index = axis;
int64_t input_index_dim_size = input_dim[axis_index];
int64_t input_index_dim_size;
if (input_dim.size() == out->dims().size()) {
input_index_dim_size = input_dim[axis_index];
} else {
// 0d index
input_index_dim_size = 1;
}

int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
Expand Down
26 changes: 16 additions & 10 deletions paddle/phi/kernels/funcs/scatter.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
const DenseTensor& index,
DenseTensor* output,
bool overwrite = true) {
// check index of shape 1-D
if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(
index.dims()[1],
Expand All @@ -132,26 +131,33 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
"But received value is [%d]",
index.dims()[1]));
} else {
PADDLE_ENFORCE_EQ(index.dims().size(),
1,
phi::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in scatter_op."
"But received value is [%d]",
index.dims().size()));
PADDLE_ENFORCE_EQ(
index.dims().size() == 1 || index.dims().size() == 0,
true,
phi::errors::InvalidArgument(
"index.dims().size() should be 0, 1 or 2 in scatter_op."
"But received value is [%d]",
index.dims().size()));
}
int64_t index_size = index.dims()[0];

int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];

auto src_dims = src.dims();
phi::DDim output_dims(src_dims);
output_dims[0] = index_size;

// slice size
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
size_t slice_size = 1;
if (index.dims().size() != 0) {
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
} else {
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
}

const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();

const size_t& slice_bytes = slice_size * sizeof(T);

// set block and grid num
Expand Down

0 comments on commit f9815bf

Please sign in to comment.