
Commit b7aafba: fix

lisamhy committed Oct 13, 2023
1 parent 9671c06
Showing 5 changed files with 213 additions and 56 deletions.
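The common thread across all five files is the handling of a 0-d (scalar) index: before the kernels index into x, the index is promoted to shape {1}, and when source/updates has one dimension fewer than x, a leading unit dimension is inserted so its rank matches what the promoted index implies. Below is a minimal standalone sketch of that promotion rule, not the kernel code itself; Shape and PromoteZeroDimIndex are hypothetical stand-ins for Paddle's DDim and DenseTensor machinery.

// Sketch of the shape promotion this commit applies (hypothetical names;
// shapes stand in for tensors).
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// A 0-d index is treated as shape {1}; a source with one dimension fewer
// than x gains a leading unit dimension so later IndexAdd/Scatter calls see
// rank-consistent operands.
void PromoteZeroDimIndex(const Shape &x, Shape *index, Shape *source) {
  if (index->empty()) {  // 0-d index
    *index = {1};
    if (source->size() == x.size() - 1) {
      source->insert(source->begin(), 1);  // e.g. {3} -> {1, 3}
    }
  }
}

int main() {
  Shape x = {4, 3};
  Shape index;         // 0-d
  Shape source = {3};  // rank(x) - 1
  PromoteZeroDimIndex(x, &index, &source);
  std::cout << index.size() << " " << source.size() << "\n";  // prints: 1 2
}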
paddle/phi/kernels/cpu/scatter_grad_kernel.cc (39 additions, 27 deletions)
@@ -77,6 +77,18 @@ void ScatterGradKernel(const Context &ctx,
     reducer = "assign";
   }
 
+  DenseTensor new_index = index;
+  DenseTensor new_source = source;
+  if (index.dims().size() == 0) {
+    new_index.Resize({1});
+
+    if (source.dims().size() == x.dims().size() - 1) {
+      auto dims = vectorize(source.dims());
+      dims.insert(dims.begin(), 1);
+      new_source.Resize(make_ddim(dims));
+    }
+  }
+
   if (x_grad) {
     ctx.template Alloc<T>(x_grad);
   }
@@ -104,9 +116,10 @@ void ScatterGradKernel(const Context &ctx,
     auto ones = Full<T, Context>(ctx, vectorize(out_grad.dims()), 1);
     auto counts = include_self ? ones : zeros;
 
-    // N = N.index_add(dim, index, ones_like(source));
-    auto src_ones = Full<T, Context>(ctx, vectorize(source.dims()), 1);
-    auto src_cnts = IndexAdd<T, Context>(ctx, counts, index, src_ones, axis);
+    // N = N.index_add(dim, index, ones_like(new_source));
+    auto src_ones = Full<T, Context>(ctx, vectorize(new_source.dims()), 1);
+    auto src_cnts =
+        IndexAdd<T, Context>(ctx, counts, new_index, src_ones, axis);
 
     // N.masked_fill_(N == 0, 1);
     auto mask = Equal<T, Context>(ctx, src_cnts, zeros);
Expand Down Expand Up @@ -136,9 +149,9 @@ void ScatterGradKernel(const Context &ctx,
auto masked_self = Where<T, Context>(ctx, mask, ones, x);

// Tensor masked_self_result = masked_self.index_reduce(dim, index,
// source, reducer, include_self);
// new_source, reducer, include_self);
auto masked_self_result = Scatter<T, Context>(
ctx, x, index, source, false, axis, reducer, include_self);
ctx, x, index, new_source, false, axis, reducer, include_self);

// grad_self = grad * masked_self_result / masked_self;
auto grad_mul_masked_self_result =
@@ -148,16 +161,16 @@ void ScatterGradKernel(const Context &ctx,
     }
 
     if (updates_grad) {
-      // Tensor src_zero = source == 0;
-      auto src_ones = Full<T, Context>(ctx, vectorize(source.dims()), 1);
-      auto src_zeros = Full<T, Context>(ctx, vectorize(source.dims()), 1);
-      auto src_zero = Equal<T, Context>(ctx, source, src_zeros);
+      // Tensor src_zero = new_source == 0;
+      auto src_ones = Full<T, Context>(ctx, vectorize(new_source.dims()), 1);
+      auto src_zeros = Full<T, Context>(ctx, vectorize(new_source.dims()), 1);
+      auto src_zero = Equal<T, Context>(ctx, new_source, src_zeros);
       auto src_zero_t = Cast<bool, Context>(ctx, src_zero, x.dtype());
 
       // Tensor src_num_zeros = zeros_like(self).index_add(dim, index,
       // src_zero.to(self.dtype())).index_select(dim, index);
       auto src_num_zeros_inner =
-          IndexAdd<T, Context>(ctx, zeros, index, src_zero_t, axis);
+          IndexAdd<T, Context>(ctx, zeros, new_index, src_zero_t, axis);
 
       auto src_num_zeros =
           IndexSelect<T, Context>(ctx, src_num_zeros_inner, index, axis);
@@ -170,11 +183,11 @@ void ScatterGradKernel(const Context &ctx,
           BitwiseAnd<bool, Context>(ctx, src_zero, src_num_zeros_equal_one);
 
       // // For src positions with src_single_zero, (grad *
-      // result).index_select(dim,index) / source.masked_fill(src_zero, 1)
+      // result).index_select(dim,index) / new_source.masked_fill(src_zero, 1)
       // // would incorrectly propagate zeros as the gradient
-      // Tensor masked_src = source.masked_fill(src_single_zero, 1);
+      // Tensor masked_src = new_source.masked_fill(src_single_zero, 1);
       auto masked_src =
-          Where<T, Context>(ctx, src_single_zero_bool, src_ones, source);
+          Where<T, Context>(ctx, src_single_zero_bool, src_ones, new_source);
 
       // Tensor masked_src_result = x.index_reduce(dim, index, masked_src,
       // reducer, include_self);
@@ -184,7 +197,7 @@ void ScatterGradKernel(const Context &ctx,
       // Tensor grad_src1 = where(src_single_zero,
       //                          (grad * masked_src_result).index_select(dim,
       //                          index), (grad * result).index_select(dim,
-      //                          index) / source.masked_fill(src_zero, 1));
+      //                          index) / new_source.masked_fill(src_zero, 1));
       auto grad_mul_masked_src_result =
           Multiply<T, Context>(ctx, out_grad, masked_src_result);
       auto grad_mul_masked_src_result_index_select =
@@ -196,7 +209,7 @@ void ScatterGradKernel(const Context &ctx,
           IndexSelect<T, Context>(ctx, grad_mul_out, index, axis);
 
       auto src_masked_fill_one =
-          Where<T, Context>(ctx, src_zero, src_ones, source);
+          Where<T, Context>(ctx, src_zero, src_ones, new_source);
       auto where_2 = Divide<T, Context>(
           ctx, grad_mul_out_index_select, src_masked_fill_one);
 
@@ -208,8 +221,8 @@ void ScatterGradKernel(const Context &ctx,
 
       // if ((src_num_zeros > 1).any().item<bool>()) {
       //   auto node = std::make_shared<DelayedError>(
-      //     "index_reduce(): Double backward is unsupported for source when >1
-      //     zeros in source are scattered to the same position in x",
+      //     "index_reduce(): Double backward is unsupported for new_source when
+      //     >1 zeros in new_source are scattered to the same position in x",
       //     /* num inputs */ 1);
       //   auto result = node->apply({ grad_src1 });
       //   grad_src = result[0];
@@ -223,13 +236,12 @@ void ScatterGradKernel(const Context &ctx,
       auto src_num_zeros_greater_one_any =
           Any<bool, Context>(ctx, src_num_zeros_greater_one, {}, false);
 
-      // bool *out_data =
-      //     reinterpret_cast<bool*>(src_num_zeros_greater_one_any.data());
       bool out_data = src_num_zeros_greater_one_any.template data<bool>()[0];
       if (out_data) {
-        VLOG(3)
-            << "index_reduce(): Double backward is unsupported for source when "
-               ">1 zeros in source are scattered to the same position in x";
+        VLOG(3) << "index_reduce(): Double backward is unsupported for "
+                   "new_source when "
+                   ">1 zeros in new_source are scattered to the same position "
+                   "in x";
         *updates_grad = grad_src1;
       } else {
         *updates_grad = grad_src1;
@@ -244,15 +256,15 @@ void ScatterGradKernel(const Context &ctx,
     auto self_is_result = Equal<T, Context>(ctx, x, out);
     auto self_is_result_t = Cast<bool, Context>(ctx, self_is_result, x.dtype());
 
-    // Tensor source_is_result = (source == value).to(x.scalar_type());
-    auto source_is_result = Equal<T, Context>(ctx, source, value);
+    // Tensor source_is_result = (new_source == value).to(x.scalar_type());
+    auto source_is_result = Equal<T, Context>(ctx, new_source, value);
     auto source_is_result_t =
         Cast<bool, Context>(ctx, source_is_result, x.dtype());
 
     // Tensor N_to_distribute = self_is_result.index_add(axis, index,
     // source_is_result);
     auto N_to_distribute = IndexAdd<T, Context>(
-        ctx, self_is_result_t, index, source_is_result_t, axis);
+        ctx, self_is_result_t, new_index, source_is_result_t, axis);
 
     // Tensor grad_distributed = grad / N_to_distribute;
     auto grad_distributed = Divide<T, Context>(ctx, out_grad, N_to_distribute);
@@ -287,8 +299,8 @@ void ScatterGradKernel(const Context &ctx,
     auto self_dims = out_grad.dims();
     auto zeros = Full<T, Context>(ctx, vectorize(self_dims), 0);
     // auto ones = Full<T, Context>(ctx, vectorize(self_dims), 1);
-    auto src_ones = Full<T, Context>(ctx, vectorize(source.dims()), 1);
-    auto src_cnts = IndexAdd<T, Context>(ctx, zeros, index, src_ones, axis);
+    auto src_ones = Full<T, Context>(ctx, vectorize(new_source.dims()), 1);
+    auto src_cnts = IndexAdd<T, Context>(ctx, zeros, new_index, src_ones, axis);
     auto mask = Equal<T, Context>(ctx, src_cnts, zeros);
     *x_grad = Where<T, Context>(ctx, mask, out_grad, zeros);
   }
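For the mean reducer, the hunks above build a per-slot participation count N with Full/IndexAdd/Equal/Where and divide the incoming gradient by it, as the inlined PyTorch-style comments describe (N = N.index_add(...), then N.masked_fill_(N == 0, 1)). A 1-D, axis-0 sketch of that counting logic, using a hypothetical free function rather than the phi kernel interface:

// Count-based backward for the mean reducer, 1-D case (a sketch, not the
// kernel; the kernel expresses the same steps as tensor ops).
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<double> MeanGradWrtX(const std::vector<double> &out_grad,
                                 const std::vector<int64_t> &index,
                                 bool include_self) {
  // N starts at 1 per slot if x itself joins the mean, else 0.
  std::vector<double> counts(out_grad.size(), include_self ? 1.0 : 0.0);
  for (int64_t i : index) counts[i] += 1.0;  // N = N.index_add(0, index, ones)
  for (double &c : counts) {
    if (c == 0.0) c = 1.0;  // N.masked_fill_(N == 0, 1)
  }
  std::vector<double> x_grad(out_grad.size());
  for (size_t i = 0; i < out_grad.size(); ++i) {
    x_grad[i] = out_grad[i] / counts[i];  // grad_self = grad / N
  }
  return x_grad;
}

int main() {
  // Two updates land in slot 1; with include_self its mean averages three
  // values, so the incoming gradient is split by 3.
  auto g = MeanGradWrtX({1.0, 1.0, 1.0}, {1, 1}, /*include_self=*/true);
  for (double v : g) std::cout << v << " ";  // prints: 1 0.333333 1
}

Slots no update ever reaches keep a count of 1 after the masked fill, so their gradient passes through unchanged.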
paddle/phi/kernels/cpu/scatter_kernel.cc (5 additions, 2 deletions)
@@ -66,8 +66,11 @@ void ScatterKernel(const Context &ctx,
     new_index.Resize(make_ddim({index_dim}));
   } else if (index.dims().size() == 0) {
     new_index.Resize(make_ddim({1}));
-    if (updates.dims().size() == 0) {
-      new_updates.Resize(make_ddim({1}));
+
+    if (updates.dims().size() == x.dims().size() - 1) {
+      auto dims = vectorize(updates.dims());
+      dims.insert(dims.begin(), 1);
+      new_updates.Resize(make_ddim(dims));
     }
   } else {
     PADDLE_ENFORCE_EQ(
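The forward change means a 0-d index now scatters exactly like a shape-{1} index once updates has been given its leading unit dimension. A sketch of the resulting behavior for 1-D x with the add reducer (hypothetical helper, not the phi ScatterKernel signature):

// 1-D scatter-add; after promotion, scalar index 1 with scalar update 10
// arrives here as index {1}, updates {10}.
#include <cstdint>
#include <iostream>
#include <vector>

void ScatterAdd1D(std::vector<double> *x,
                  const std::vector<int64_t> &index,
                  const std::vector<double> &updates) {
  for (size_t i = 0; i < index.size(); ++i) {
    (*x)[index[i]] += updates[i];
  }
}

int main() {
  std::vector<double> x = {1.0, 2.0, 3.0};
  ScatterAdd1D(&x, {1}, {10.0});
  for (double v : x) std::cout << v << " ";  // prints: 1 12 3
}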
paddle/phi/kernels/gpu/scatter_grad_kernel.cu (39 additions, 25 deletions)
@@ -82,6 +82,18 @@ void ScatterGradKernel(const Context &ctx,
     reducer = "assign";
   }
 
+  DenseTensor new_index = index;
+  DenseTensor new_source = source;
+  if (index.dims().size() == 0) {
+    new_index.Resize({1});
+
+    if (new_source.dims().size() == x.dims().size() - 1) {
+      auto dims = vectorize(new_source.dims());
+      dims.insert(dims.begin(), 1);
+      new_source.Resize(make_ddim(dims));
+    }
+  }
+
   if (x_grad) {
     ctx.template Alloc<T>(x_grad);
   }
@@ -109,9 +121,10 @@ void ScatterGradKernel(const Context &ctx,
     auto ones = Full<T, Context>(ctx, vectorize(out_grad.dims()), 1);
     auto counts = include_self ? ones : zeros;
 
-    // N = N.index_add(dim, index, ones_like(source));
-    auto src_ones = Full<T, Context>(ctx, vectorize(source.dims()), 1);
-    auto src_cnts = IndexAdd<T, Context>(ctx, counts, index, src_ones, axis);
+    // N = N.index_add(dim, index, ones_like(new_source));
+    auto src_ones = Full<T, Context>(ctx, vectorize(new_source.dims()), 1);
+    auto src_cnts =
+        IndexAdd<T, Context>(ctx, counts, new_index, src_ones, axis);
 
     // N.masked_fill_(N == 0, 1);
     auto mask = Equal<T, Context>(ctx, src_cnts, zeros);
Expand Down Expand Up @@ -141,9 +154,9 @@ void ScatterGradKernel(const Context &ctx,
auto masked_self = Where<T, Context>(ctx, mask, ones, x);

// Tensor masked_self_result = masked_self.index_reduce(dim, index,
// source, reduce, include_self);
// new_source, reduce, include_self);
auto masked_self_result = Scatter<T, Context>(
ctx, x, index, source, false, axis, reducer, include_self);
ctx, x, index, new_source, false, axis, reducer, include_self);

// grad_self = grad * masked_self_result / masked_self;
auto grad_mul_masked_self_result =
@@ -153,16 +166,16 @@ void ScatterGradKernel(const Context &ctx,
     }
 
     if (updates_grad) {
-      // Tensor src_zero = source == 0;
-      auto src_ones = Full<T, Context>(ctx, vectorize(source.dims()), 1);
-      auto src_zeros = Full<T, Context>(ctx, vectorize(source.dims()), 1);
-      auto src_zero = Equal<T, Context>(ctx, source, src_zeros);
+      // Tensor src_zero = new_source == 0;
+      auto src_ones = Full<T, Context>(ctx, vectorize(new_source.dims()), 1);
+      auto src_zeros = Full<T, Context>(ctx, vectorize(new_source.dims()), 1);
+      auto src_zero = Equal<T, Context>(ctx, new_source, src_zeros);
       auto src_zero_t = Cast<bool, Context>(ctx, src_zero, x.dtype());
 
       // Tensor src_num_zeros = zeros_like(self).index_add(dim, index,
       // src_zero.to(self.dtype())).index_select(dim, index);
       auto src_num_zeros_inner =
-          IndexAdd<T, Context>(ctx, zeros, index, src_zero_t, axis);
+          IndexAdd<T, Context>(ctx, zeros, new_index, src_zero_t, axis);
 
       auto src_num_zeros =
           IndexSelect<T, Context>(ctx, src_num_zeros_inner, index, axis);
@@ -175,11 +188,11 @@ void ScatterGradKernel(const Context &ctx,
           BitwiseAnd<bool, Context>(ctx, src_zero, src_num_zeros_equal_one);
 
       // // For src positions with src_single_zero, (grad *
-      // result).index_select(dim,index) / source.masked_fill(src_zero, 1)
+      // result).index_select(dim,index) / new_source.masked_fill(src_zero, 1)
      // // would incorrectly propagate zeros as the gradient
-      // Tensor masked_src = source.masked_fill(src_single_zero, 1);
+      // Tensor masked_src = new_source.masked_fill(src_single_zero, 1);
       auto masked_src =
-          Where<T, Context>(ctx, src_single_zero_bool, src_ones, source);
+          Where<T, Context>(ctx, src_single_zero_bool, src_ones, new_source);
 
       // Tensor masked_src_result = x.index_reduce(dim, index, masked_src,
       // reduce, include_self);
@@ -189,7 +202,7 @@ void ScatterGradKernel(const Context &ctx,
       // Tensor grad_src1 = where(src_single_zero,
       //                          (grad * masked_src_result).index_select(dim,
       //                          index), (grad * result).index_select(dim,
-      //                          index) / source.masked_fill(src_zero, 1));
+      //                          index) / new_source.masked_fill(src_zero, 1));
       auto grad_mul_masked_src_result =
           Multiply<T, Context>(ctx, out_grad, masked_src_result);
       auto grad_mul_masked_src_result_index_select =
@@ -201,7 +214,7 @@ void ScatterGradKernel(const Context &ctx,
           IndexSelect<T, Context>(ctx, grad_mul_out, index, axis);
 
       auto src_masked_fill_one =
-          Where<T, Context>(ctx, src_zero, src_ones, source);
+          Where<T, Context>(ctx, src_zero, src_ones, new_source);
       auto where_2 = Divide<T, Context>(
           ctx, grad_mul_out_index_select, src_masked_fill_one);
 
@@ -213,8 +226,8 @@ void ScatterGradKernel(const Context &ctx,
 
      // if ((src_num_zeros > 1).any().item<bool>()) {
      //   auto node = std::make_shared<DelayedError>(
-      //     "index_reduce(): Double backward is unsupported for source when >1
-      //     zeros in source are scattered to the same position in x",
+      //     "index_reduce(): Double backward is unsupported for new_source when
+      //     >1 zeros in new_source are scattered to the same position in x",
      //     /* num inputs */ 1);
      //   auto result = node->apply({ grad_src1 });
      //   grad_src = result[0];
@@ -242,9 +255,10 @@ void ScatterGradKernel(const Context &ctx,
       //     src_num_zeros_greater_one_any.template data<T>()[0];
 
       if (out_data) {
-        VLOG(3)
-            << "index_reduce(): Double backward is unsupported for source when "
-               ">1 zeros in source are scattered to the same position in x";
+        VLOG(3) << "index_reduce(): Double backward is unsupported for "
+                   "new_source when "
+                   ">1 zeros in new_source are scattered to the same position "
+                   "in x";
         *updates_grad = grad_src1;
       } else {
         *updates_grad = grad_src1;
@@ -258,15 +272,15 @@ void ScatterGradKernel(const Context &ctx,
     auto self_is_result = Equal<T, Context>(ctx, x, out);
     auto self_is_result_t = Cast<bool, Context>(ctx, self_is_result, x.dtype());
 
-    // Tensor source_is_result = (source == value).to(x.scalar_type());
-    auto source_is_result = Equal<T, Context>(ctx, source, value);
+    // Tensor source_is_result = (new_source == value).to(x.scalar_type());
+    auto source_is_result = Equal<T, Context>(ctx, new_source, value);
     auto source_is_result_t =
         Cast<bool, Context>(ctx, source_is_result, x.dtype());
 
     // Tensor N_to_distribute = self_is_result.index_add(axis, index,
     // source_is_result);
     auto N_to_distribute = IndexAdd<T, Context>(
-        ctx, self_is_result_t, index, source_is_result_t, axis);
+        ctx, self_is_result_t, new_index, source_is_result_t, axis);
 
     // Tensor grad_distributed = grad / N_to_distribute;
     auto grad_distributed = Divide<T, Context>(ctx, out_grad, N_to_distribute);
@@ -301,8 +315,8 @@ void ScatterGradKernel(const Context &ctx,
     auto self_dims = out_grad.dims();
     auto zeros = Full<T, Context>(ctx, vectorize(self_dims), 0);
     // auto ones = Full<T, Context>(ctx, vectorize(self_dims), 1);
-    auto src_ones = Full<T, Context>(ctx, vectorize(source.dims()), 1);
-    auto src_cnts = IndexAdd<T, Context>(ctx, zeros, index, src_ones, axis);
+    auto src_ones = Full<T, Context>(ctx, vectorize(new_source.dims()), 1);
+    auto src_cnts = IndexAdd<T, Context>(ctx, zeros, new_index, src_ones, axis);
     auto mask = Equal<T, Context>(ctx, src_cnts, zeros);
     *x_grad = Where<T, Context>(ctx, mask, out_grad, zeros);
   }
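The updates_grad branch in both the CPU and GPU files deals with zeros under a multiplicative reduction: grad * result / source is the cheap gradient, but it breaks wherever source is 0, so a position contributing exactly one zero recomputes the product with that zero masked to 1, and more than one zero makes double backward unsupported (hence the VLOG). A 1-D sketch of that case analysis, assuming a single slot reduced with mul and include_self; the helper is hypothetical:

// Gradient of out = x * src[0] * ... * src[n-1] with respect to src[k].
#include <cstddef>
#include <iostream>
#include <vector>

double GradWrtSource(double x, const std::vector<double> &src, size_t k,
                     double out_grad) {
  int zeros = 0;
  for (double s : src) zeros += (s == 0.0);
  if (src[k] == 0.0 && zeros == 1) {
    // Recompute the product with src[k] masked to 1; this equals
    // d(out)/d(src[k]) and cannot be recovered from result / src[k].
    double prod = x;
    for (size_t i = 0; i < src.size(); ++i) prod *= (i == k) ? 1.0 : src[i];
    return out_grad * prod;
  }
  double result = x;
  for (double s : src) result *= s;
  // With src[k] == 0 and another zero present, the true gradient is 0.
  return src[k] == 0.0 ? 0.0 : out_grad * result / src[k];
}

int main() {
  // x = 2 and sources {3, 0} reduce by mul into one slot: out = 0.
  std::cout << GradWrtSource(2.0, {3.0, 0.0}, 1, 1.0) << "\n";  // prints: 6
  std::cout << GradWrtSource(2.0, {3.0, 0.0}, 0, 1.0) << "\n";  // prints: 0
}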
paddle/phi/kernels/gpu/scatter_kernel.cu (5 additions, 2 deletions)
@@ -278,8 +278,11 @@ void ScatterKernel(const Context& ctx,
     new_index.Resize(make_ddim({index_dim}));
   } else if (index.dims().size() == 0) {
     new_index.Resize(make_ddim({1}));
-    if (updates.dims().size() == 0) {
-      new_updates.Resize(make_ddim({1}));
+
+    if (updates.dims().size() == x.dims().size() - 1) {
+      auto dims = vectorize(updates.dims());
+      dims.insert(dims.begin(), 1);
+      new_updates.Resize(make_ddim(dims));
     }
   } else {
     PADDLE_ENFORCE_EQ(
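Before this commit, both ScatterKernel variants promoted updates only when it was itself 0-d; a 0-d index paired with updates of rank rank(x) - 1 left new_updates without the leading unit dimension that the promoted index {1} implies. A shapes-only sketch contrasting the two rules (hypothetical helpers). Note the new rule subsumes the old one: a 0-d updates with 1-D x satisfies updates.size() == x.size() - 1 and still becomes {1}.

// Old vs. new promotion rule for the updates tensor, shapes only.
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

Shape OldRule(const Shape &x, const Shape &updates) {
  (void)x;  // the old rule never consulted x
  Shape out = updates;
  if (updates.empty()) out = {1};  // only the 0-d updates case was handled
  return out;
}

Shape NewRule(const Shape &x, const Shape &updates) {
  Shape out = updates;
  if (updates.size() == x.size() - 1) out.insert(out.begin(), 1);
  return out;
}

int main() {
  Shape x = {4, 3}, updates = {3};  // 0-d index, row-shaped updates
  std::cout << OldRule(x, updates).size() << "\n";  // prints: 1 (rank mismatch)
  std::cout << NewRule(x, updates).size() << "\n";  // prints: 2, i.e. {1, 3}
}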
