Skip to content

Commit

Permalink
【Prim】Custom softmax grad (#51474)
Browse files Browse the repository at this point in the history
* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix cast prim and vjp dtype mapping error bug

* Cxx prim custom vjp (#8)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix cast prim and vjp dtype mapping error bug

* [dy2static-ci] fix dy2static ci errors.

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>

* [Prim] enable whitelist and blacklist for custom_vjp

* support softmax grad

* remove additional code

* add test back

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
Co-authored-by: xiongkun <807377414@qq.com>
  • Loading branch information
5 people committed Mar 15, 2023
1 parent 50df017 commit f124c86
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 0 deletions.
21 changes: 21 additions & 0 deletions paddle/fluid/operators/softmax_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
Expand Down Expand Up @@ -156,6 +159,23 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};

class SoftmaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

public:
void Apply() override {
paddle::Tensor out = this->GetSingleForwardOutput("Out");
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::Tensor dx = this->GetSingleInputGrad("X");
auto* dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
int axis = static_cast<int>(this->Attr<int>("axis"));
VLOG(6) << "Runing softmax_grad composite func";
prim::softmax_grad<prim::DescTensor>(out, out_grad, axis, dx_ptr);
this->RecoverOutputName(dx, dx_name);
}
};

DECLARE_INPLACE_OP_INFERER(SoftmaxInplaceInferer, {"X", "Out"});

} // namespace operators
Expand All @@ -172,6 +192,7 @@ REGISTER_OPERATOR(softmax,
ops::SoftmaxOpInferVarType,
ops::SoftmaxOpGradMaker<paddle::framework::OpDesc>,
ops::SoftmaxOpGradMaker<paddle::imperative::OpBase>,
ops::SoftmaxCompositeGradOpMaker,
ops::SoftmaxInplaceInferer,
SoftmaxInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(softmax_grad,
Expand Down
29 changes: 29 additions & 0 deletions paddle/fluid/prim/api/composite_backward/composite_backward_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,35 @@ using Tensor = paddle::Tensor;
using IntArray = paddle::experimental::IntArrayBase<paddle::Tensor>;
// This function should have as same signature as phi, which defined in
// paddle/phi/api/backward/backward_api.h
template <typename T>
void softmax_grad(const Tensor& out,
const Tensor& out_grad,
int axis,
Tensor* x_grad) {
if (x_grad) {
if (out_grad.dims().size() > 0) {
if (axis >= 0) {
auto new_out_grad = out_grad * out;
auto tmp_x_grad = new_out_grad -
out * sum<T>(new_out_grad, {axis}, out.dtype(), true);
set_output<T>(tmp_x_grad, x_grad);
} else {
auto new_out_grad = out_grad * out;
auto tmp_x_grad =
new_out_grad - out * sum<T>(new_out_grad,
{out.dims().size() + axis},
out.dtype(),
true);
set_output<T>(tmp_x_grad, x_grad);
}
} else {
set_output<T>(
full<T>(phi::vectorize(out_grad.dims()), 0.0, out_grad.dtype()),
x_grad);
}
}
}

template <typename T>
void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
if (x_grad) {
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,7 @@
param : [out]
kernel :
func : softmax_grad
composite : softmax_grad(out, out_grad, axis, x_grad)

- backward_op : spectral_norm_grad
forward : spectral_norm (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps) -> Tensor(out)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from utils import TOLERANCE

import paddle
import paddle.nn.functional as F
from paddle.fluid import core


def generate_data(shape, dtype="float32"):
np_data = np.random.random(shape).astype(dtype)
return np_data


class Attr:
def __init__(self) -> None:
self.dtype = None
self.axis = -1
self.shape = None

def set_dtype(self, dtype) -> None:
self.dtype = dtype
return

def set_axis(self, axis) -> None:
self.axis = axis
return

def set_shape(self, shape) -> None:
self.shape = shape
return

def get_rtol(self, flag):
rtol = TOLERANCE[self.dtype][flag].get("rtol")
return rtol

def get_atol(self, flag):
atol = TOLERANCE[self.dtype][flag].get("atol")
return atol


attrs = Attr()


def fn(x):
return F.softmax(x, axis=attrs.axis, dtype=attrs.dtype)


def expect_grad(inputs):
paddle.disable_static()
inputs.stop_gradient = False
res = fn(inputs)

gradients = paddle.grad(res, inputs)
return gradients


class TestCompositeSoftmax(unittest.TestCase):
def setUp(self):
self.dtypes = ["float32", "float64"]
self.shapes = [[2, 3, 4], [2, 3]]
self.axes = [-1, 0, 1]

def cal_composite_grad(self, inputs):
paddle.enable_static()
core._set_prim_forward_enabled(True)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
'x', shape=inputs.shape, dtype=str(inputs.dtype)
)
x.stop_gradient = False
y = fn(x)
blocks = main_program.blocks

fwd_ops = [op.type for op in blocks[0].ops]
# Ensure that softmax in original block
self.assertTrue('softmax' in fwd_ops)

paddle.incubate.autograd.primapi.to_prim(blocks)

fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that softmax is splitted into small ops
self.assertTrue('softmax' not in fwd_ops_new)

z = paddle.static.gradients([y], x)
fwd_ops_grad = [op.type for op in blocks[0].ops]
# Ensure that softmax_grad not in grad block

self.assertTrue('softmax_grad' not in fwd_ops_grad)

exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
paddle.disable_static()
core._set_prim_forward_enabled(False)
return res

def compare_backward(self):
np_data = generate_data(attrs.shape)
tensor_data = paddle.to_tensor(np_data)

expect = expect_grad(tensor_data)[0].numpy()
actual = self.cal_composite_grad(np_data)[0]

assert expect.dtype == actual.dtype
np.testing.assert_allclose(
expect,
actual,
rtol=attrs.get_rtol("backward"),
atol=attrs.get_atol("backward"),
)

def test_backward(self):
for i in self.axes:
for j in self.dtypes:
for t in self.shapes:
attrs.set_axis(i)
attrs.set_dtype(j)
attrs.set_shape(t)
self.compare_backward()


class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
"test composite softmax and prim backward"

def setUp(self):
core._set_prim_backward_enabled(True)
self.dtypes = ["float32", "float64"]
self.shapes = [[], [2, 3, 4], [2, 3]]
self.axes = [-1, 0, 1]

def cal_composite_grad(self, inputs):
paddle.enable_static()
core._set_prim_all_enabled(True)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
'x', shape=inputs.shape, dtype=str(inputs.dtype)
)
x.stop_gradient = False
y = fn(x)
blocks = main_program.blocks
z = paddle.static.gradients([y], x)
paddle.incubate.autograd.primapi.to_prim(blocks)

exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
paddle.disable_static()
core._set_prim_all_enabled(False)
return res

def compare_backward(self):
if not attrs.shape and attrs.axis not in [-1, 0]:
# op softmax does not support both case
return
np_data = generate_data(attrs.shape)
tensor_data = paddle.to_tensor(np_data)

expect = expect_grad(tensor_data)[0].numpy()
actual = self.cal_composite_grad(np_data)[0]

assert expect.dtype == actual.dtype
np.testing.assert_allclose(
expect,
actual,
rtol=attrs.get_rtol("prim_backward"),
atol=attrs.get_rtol("prim_backward"),
)

def test_prim_backward(self):
for i in self.axes:
for j in self.dtypes:
for t in self.shapes:
attrs.set_axis(i)
attrs.set_dtype(j)
attrs.set_shape(t)
self.compare_backward()


if __name__ == '__main__':
unittest.main()

0 comments on commit f124c86

Please sign in to comment.