[Prim][PIR] support abs, instance_norm op backward in prim pir #60444

Merged · 14 commits · Jan 4, 2024
2 changes: 2 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -55,6 +55,7 @@

# prim op with one input and one output, with no attribute
UNARY_PRIM_VJP_OPS = [
'abs_grad',
'erf_grad',
'exp_grad',
'floor_grad',
@@ -103,6 +104,7 @@
'dropout_grad',
'gelu_grad',
'hardswish_grad',
'instance_norm_grad',
'layer_norm_grad',
'leaky_relu_grad',
'relu_grad',
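These gen.py lists drive code generation of the prim VJP rules: 'abs_grad' joins UNARY_PRIM_VJP_OPS (ops with one input, one output, and no attributes), and 'instance_norm_grad' joins a second op list in the same file, so the generated backward for both ops dispatches to the hand-written decomposition rules added to paddle/fluid/primitive/rule/vjp/details.h below.
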
100 changes: 100 additions & 0 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -29,6 +29,14 @@ namespace paddle {
namespace primitive {
namespace details {

template <typename T>
void abs_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
auto sign_tmp = sign<T>(x);
set_output<T>(out_grad * sign_tmp, x_grad);
}
}

template <typename T>
void assign_grad(const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
@@ -930,6 +938,98 @@ void gather_nd_grad(const Tensor& x,
}
}

template <typename T>
void instance_norm_grad(const Tensor& x,
const paddle::optional<Tensor>& scale,
const Tensor& saved_mean,
const Tensor& saved_variance,
const Tensor& y_grad,
float epsilon,
Tensor* x_grad,
Tensor* scale_grad,
Tensor* bias_grad) {
const int n = x.dims()[0];
const int c = x.dims()[1];
const int h = x.dims()[2];
const int w = x.dims()[3];

auto promoted_y_grad = y_grad;
if (x.dtype() == phi::DataType::FLOAT16 ||
x.dtype() == phi::DataType::BFLOAT16) {
promoted_y_grad = cast<T>(y_grad, phi::DataType::FLOAT32);
}

Tensor x_hat;
Tensor std_inv;
if (scale_grad || x_grad) {
auto promoted_x = x;
auto promoted_saved_mean = saved_mean;
auto promoted_saved_var = saved_variance;
if (x.dtype() == phi::DataType::FLOAT16 ||
x.dtype() == phi::DataType::BFLOAT16) {
promoted_x = cast<T>(x, phi::DataType::FLOAT32);
promoted_saved_mean = cast<T>(saved_mean, phi::DataType::FLOAT32);
promoted_saved_var = cast<T>(saved_variance, phi::DataType::FLOAT32);
}
auto mean = reshape<T>(promoted_saved_mean, IntArray({n, c, 1, 1}))
.tile(IntArray({1, 1, h, w}));
std_inv = reshape<T>(promoted_saved_var, IntArray({n, c, 1, 1}))
.tile(IntArray({1, 1, h, w}));
x_hat = (promoted_x - mean) * std_inv;
}

// x_grad = scale * std_inv *
//   (y_grad - y_grad.mean((2, 3)) - x_hat * (y_grad * x_hat).mean((2, 3)))
if (x_grad) {
auto scale_data =
reshape<T>(scale.get_ptr() ? scale.get()
: full<T>(IntArray({c}), 1., x.dtype()),
IntArray({1, c, 1, 1}))
.tile(IntArray({n, 1, h, w}));
auto promoted_scale = scale_data;
if (scale_data.dtype() == phi::DataType::FLOAT16 ||
scale_data.dtype() == phi::DataType::BFLOAT16) {
promoted_scale = cast<T>(scale_data, phi::DataType::FLOAT32);
}
auto result =
(promoted_scale * std_inv) *
(promoted_y_grad -
promoted_y_grad.sum(IntArray({2, 3}), promoted_y_grad.dtype(), true) /
(h * w) -
(x_hat * ((promoted_y_grad * x_hat)
.sum(IntArray({2, 3}), promoted_y_grad.dtype(), true) /
(h * w))));
if (x.dtype() == phi::DataType::FLOAT16 ||
x.dtype() == phi::DataType::BFLOAT16) {
set_output<T>(cast<T>(result, x.dtype()), x_grad);
} else {
set_output<T>(result, x_grad);
}
}
// scale_grad = (y_grad * x_hat).sum((0, 2, 3))
if (scale_grad) {
auto result = (promoted_y_grad * x_hat).sum(IntArray({0, 2, 3}));
auto scale_dtype = scale.get_ptr() ? scale.get().dtype() : x.dtype();
if (scale_dtype == phi::DataType::FLOAT16 ||
scale_dtype == phi::DataType::BFLOAT16) {
set_output<T>(cast<T>(result, scale_dtype), scale_grad);
} else {
set_output<T>(result, scale_grad);
}
}
// bias_grad = y_grad.sum((0, 2, 3))
if (bias_grad) {
auto result = promoted_y_grad.sum(IntArray({0, 2, 3}));
auto scale_dtype = scale.get_ptr() ? scale.get().dtype() : x.dtype();
if (scale_dtype == phi::DataType::FLOAT16 ||
scale_dtype == phi::DataType::BFLOAT16) {
set_output<T>(cast<T>(result, scale_dtype), bias_grad);
} else {
set_output<T>(result, bias_grad);
}
}
}

template <typename T>
void pad_grad(const Tensor& input,
const Tensor& out_grad,
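For reference, a minimal NumPy sketch of the math these two new rules implement (not Paddle API; it assumes NCHW layout and that saved_variance already holds the reciprocal standard deviation, which is how the C++ rule above uses it directly as std_inv):

import numpy as np

def abs_grad(x, out_grad):
    # abs backward: d|x|/dx = sign(x)
    return np.sign(x) * out_grad

def instance_norm_grad(x, scale, saved_mean, saved_inv_std, y_grad):
    # x, y_grad: (n, c, h, w); saved_mean, saved_inv_std: per-(n, c) statistics
    n, c, h, w = x.shape
    mean = saved_mean.reshape(n, c, 1, 1)
    std_inv = saved_inv_std.reshape(n, c, 1, 1)
    x_hat = (x - mean) * std_inv

    # scale defaults to ones when not provided, as in the C++ rule
    s = (np.ones(c) if scale is None else scale).reshape(1, c, 1, 1)

    dy_mean = y_grad.mean(axis=(2, 3), keepdims=True)
    dy_xhat_mean = (y_grad * x_hat).mean(axis=(2, 3), keepdims=True)

    x_grad = (s * std_inv) * (y_grad - dy_mean - x_hat * dy_xhat_mean)
    scale_grad = (y_grad * x_hat).sum(axis=(0, 2, 3))
    bias_grad = y_grad.sum(axis=(0, 2, 3))
    return x_grad, scale_grad, bias_grad

The float16/bfloat16 branches in the C++ rule only up-cast the operands to float32 for this arithmetic and cast the results back to the original dtype; the math itself is unchanged.
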
21 changes: 17 additions & 4 deletions test/legacy_test/test_activation_op.py
@@ -474,7 +474,12 @@ def test_check_output(self):
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', check_prim=True, check_pir=True
place,
['X'],
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -1803,7 +1808,9 @@ def test_check_output(self):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', check_prim=True, check_pir=True)
self.check_grad(
['X'], 'Out', check_prim=True, check_pir=True, check_prim_pir=True
)


class TestAbs_ZeroDim(TestAbs):
@@ -4852,7 +4859,11 @@ def test_check_grad(self):
check_prim_pir=True,
)
create_test_act_fp16_class(
TestAbs, check_prim=True, enable_cinn=True, check_pir=True
TestAbs,
check_prim=True,
enable_cinn=True,
check_pir=True,
check_prim_pir=True,
)
create_test_act_fp16_class(TestCeil, grad_check=False, check_pir=True)
create_test_act_fp16_class(
@@ -5019,7 +5030,9 @@ def test_check_grad(self):
create_test_act_bf16_class(
TestSqrtComp, check_prim=True, check_pir=True, check_prim_pir=True
)
create_test_act_bf16_class(TestAbs, check_prim=True, check_pir=True)
create_test_act_bf16_class(
TestAbs, check_prim=True, check_pir=True, check_prim_pir=True
)
create_test_act_bf16_class(TestCeil, grad_check=False, check_pir=True)
create_test_act_bf16_class(
TestFloor,
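In these OpTest cases, adding check_prim_pir=True makes the existing abs gradient checks also verify the gradient produced by the new prim decomposition under the PIR program, alongside the existing check_prim and check_pir checks, covering the fp32, fp16, and bf16 TestAbs variants.
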
6 changes: 5 additions & 1 deletion test/legacy_test/test_instance_norm_op.py
@@ -134,7 +134,11 @@ def test_check_output(self):

def test_check_grad(self):
self.check_grad(
['X', 'Scale', 'Bias'], 'Y', check_prim=True, check_pir=True
['X', 'Scale', 'Bias'],
'Y',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)

def init_test_case(self):
9 changes: 9 additions & 0 deletions test/legacy_test/test_instance_norm_op_v2.py
@@ -229,6 +229,9 @@ def test_check_grad(self):
'Y',
check_prim=self.check_prim,
check_pir=True,
check_prim_pir=False
if os.getenv("FLAGS_enable_pir_in_executor")
else True,
)

def init_dtype(self):
@@ -284,6 +287,9 @@ def test_check_grad(self):
max_relative_error=self.max_relative_error,
check_prim=self.check_prim,
check_pir=True,
check_prim_pir=False
if os.getenv("FLAGS_enable_pir_in_executor")
else True,
)


@@ -356,6 +362,9 @@ def test_check_grad(self):
user_defined_grads=self.user_defined_grads,
check_prim=self.check_prim,
check_pir=True,
check_prim_pir=False
if os.getenv("FLAGS_enable_pir_in_executor")
else True,
)


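In test_instance_norm_op_v2.py the new prim-PIR gradient check is gated on the FLAGS_enable_pir_in_executor environment variable: when the flag is set in the environment the check is skipped, otherwise it runs. As a sketch (equivalent to the ternary used above, not a proposed change to the PR), the same condition can be written as:

check_prim_pir=not os.getenv("FLAGS_enable_pir_in_executor"),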