[EinsumOp] Make EinsumOp support bfloat16. (#43085)
* make einsum_v2 the default and add a new flag: FLAG_einsum_opt=1|0

* make EinsumOp support bf16

* add unittest for BF16

* add a CUDA-version condition for TestBF16

* fix bugs

* fix
2742195759 committed May 31, 2022
1 parent 0ae8a2d commit a4bb38c
Showing 6 changed files with 44 additions and 16 deletions.
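
For context, the sketch below is not part of the diff; it illustrates the kind of usage this commit enables (bfloat16 einsum forward and backward on GPU). It assumes a Paddle build with CUDA 11 or newer and a bfloat16-capable device; the tensor values are illustrative.

# Sketch only: assumes a CUDA >= 11 build of Paddle and a bfloat16-capable GPU.
import numpy as np
import paddle

A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16).cuda()
B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16).cuda()
A.stop_gradient = False
B.stop_gradient = False

C = paddle.einsum('i,i->', A, B)   # forward now dispatches to the bfloat16 kernel
C.backward()                       # einsum_grad is registered for bfloat16 as well
print(C.item())                    # 1*2 + 2*3 = 8.0
print(A.grad)                      # dC/dA = B
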
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/eigen/broadcast.cc
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

@@ -73,6 +74,7 @@ struct EigenBroadcastGrad<Eigen::DefaultDevice, T, Rank> {
template struct FUNCTOR<Eigen::DefaultDevice, T, 6>
INSTANTIATION(EigenBroadcast, bool);
INSTANTIATION(EigenBroadcast, dtype::float16);
INSTANTIATION(EigenBroadcast, dtype::bfloat16);
INSTANTIATION(EigenBroadcast, float);
INSTANTIATION(EigenBroadcast, double);
INSTANTIATION(EigenBroadcast, int);
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/eigen/broadcast.cu
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

@@ -73,6 +74,7 @@ struct EigenBroadcastGrad<Eigen::GpuDevice, T, Rank> {
template struct FUNCTOR<Eigen::GpuDevice, T, 6>
INSTANTIATION(EigenBroadcast, bool);
INSTANTIATION(EigenBroadcast, dtype::float16);
INSTANTIATION(EigenBroadcast, dtype::bfloat16);
INSTANTIATION(EigenBroadcast, float);
INSTANTIATION(EigenBroadcast, double);
INSTANTIATION(EigenBroadcast, int);
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/einsum_grad_kernel.cu
@@ -24,4 +24,5 @@ PD_REGISTER_KERNEL(einsum_grad,
phi::EinsumGradKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/tile_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(tile,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
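
The GPU tile kernel is registered for bfloat16 here because the einsum backward path (PerformTileAndReduction in the next file) tiles gradients over broadcast dimensions. A minimal op-level sketch of what this registration enables, assuming a CUDA 11+ GPU build; paddle.tile is used purely as an illustration:

# Sketch only: bfloat16 now works with the GPU tile kernel.
import paddle

x = paddle.ones([2, 3]).astype(paddle.bfloat16).cuda()
y = paddle.tile(x, repeat_times=[2, 1])   # shape [4, 3], dtype stays bfloat16
print(y.shape, y.dtype)
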
32 changes: 18 additions & 14 deletions paddle/phi/kernels/impl/einsum_grad_impl.h
@@ -197,20 +197,24 @@ void EinsumGradKernel(const Context& dev_ctx,
// Release the cached tensors (dTC) right away to save memory; they are no longer needed.
cache.clear();
*(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
labeltype,
labelshape,
broadcast_dims,
ellipsis_dims[0],
ops[0],
dA);
*(x_grad[1]) = PerformTileAndReduction<T, Context>(dev_ctx,
labeltype,
labelshape,
broadcast_dims,
ellipsis_dims[1],
ops[1],
dB);
if (x_grad[0]) {
*(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
labeltype,
labelshape,
broadcast_dims,
ellipsis_dims[0],
ops[0],
dA);
}
if (x_grad[1]) {
*(x_grad[1]) = PerformTileAndReduction<T, Context>(dev_ctx,
labeltype,
labelshape,
broadcast_dims,
ellipsis_dims[1],
ops[1],
dB);
}
}
}
} // namespace phi
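
The new if (x_grad[0]) / if (x_grad[1]) guards handle the case where a gradient is requested for only one operand, so the other x_grad slot is a null pointer and its tile-and-reduce work can be skipped. A hedged Python-level sketch of that situation (illustrative values, CUDA 11+ GPU build assumed):

# Sketch only: B needs no gradient, so EinsumGradKernel is handed a null slot for it.
import numpy as np
import paddle

A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16).cuda()
B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16).cuda()
A.stop_gradient = False   # gradient requested for A only
B.stop_gradient = True    # no gradient requested for B

C = paddle.einsum('i,i->', A, B)
C.backward()
print(A.grad)   # dC/dA = B = [2., 3.]
print(B.grad)   # None: the skipped slot is never filled
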
18 changes: 18 additions & 0 deletions python/paddle/fluid/tests/unittests/test_einsum_v2.py
@@ -478,5 +478,23 @@ def test_shape(self):
        self.assertEqual(C.shape, (-1, 384))


class TestBF16(unittest.TestCase):
    """
    EinsumOp supports the bfloat16 type; this unit test checks the result is correct.
    """

    def test_shape(self):
        cuda_major = paddle.version.cuda().split('.')[0].strip()
        if paddle.is_compiled_with_cuda() and int(cuda_major) >= 11:
            # MatmulKernel supports bfloat16 only when the CUDA version is at least 11.0.
            A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16)
            A = A.cuda()
            B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16)
            B = B.cuda()
            C = paddle.einsum('i,i->', A, B)
            self.assertEqual(C.item(), 8.0)


if __name__ == "__main__":
    unittest.main()
