Merge pull request #9143 from kexinzhao/numpy_conv2d_pool2d_fp16
Add float16 support for cudnn conv2d
kexinzhao committed Mar 16, 2018
2 parents c0511c3 + e967d19 commit 8e73101
Showing 5 changed files with 149 additions and 14 deletions.
19 changes: 12 additions & 7 deletions paddle/fluid/operators/conv_cudnn_op.cu.cc
@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
@@ -133,7 +134,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv forward ---------------------
T alpha = 1.0f, beta = 0.0f;
typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
beta = 0.0f;
for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
@@ -280,7 +282,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv backward data ---------------------
T alpha = 1.0f, beta = 0.0f;
typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
@@ -315,16 +318,18 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle

REGISTER_OP_KERNEL(conv2d, CUDNN, ::paddle::platform::CUDAPlace,
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>,
paddle::operators::CUDNNConvOpKernel<double>);
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, ::paddle::platform::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<double>,
paddle::operators::CUDNNConvOpKernel<plat::float16>);
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>);

REGISTER_OP_KERNEL(conv3d, CUDNN, ::paddle::platform::CUDAPlace,
REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>,
paddle::operators::CUDNNConvOpKernel<double>);
REGISTER_OP_KERNEL(conv3d_grad, CUDNN, ::paddle::platform::CUDAPlace,
REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>);
17 changes: 14 additions & 3 deletions paddle/fluid/operators/conv_op.cc
@@ -83,12 +83,23 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
#endif

auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Input")->type());
auto filter_data_type =
framework::ToDataType(ctx.Input<Tensor>("Filter")->type());
PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
"input and filter data type should be consistent");

if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
"float16 can only be used when CUDNN is used");
}

std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
library_);
}

Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
17 changes: 17 additions & 0 deletions paddle/fluid/platform/cudnn_helper.h
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
@@ -80,6 +81,22 @@ enum class PoolingMode {
template <typename T>
class CudnnDataType;

template <>
class CudnnDataType<float16> {
public:
static const cudnnDataType_t type = CUDNN_DATA_HALF;
// The scaling param type is float for HALF and FLOAT tensors
typedef const float ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = 0.0;
return &v;
}
};

template <>
class CudnnDataType<float> {
public:
22 changes: 22 additions & 0 deletions python/paddle/fluid/tests/unittests/op_test.py
@@ -469,6 +469,28 @@ def _numpy_to_lod_tensor(np_value, lod, place):
tensor.set_lod(lod)
return tensor

@staticmethod
def np_dtype_to_fluid_dtype(input):
"""Change the dtype of float16 numpy array
numpy float16 is binded to paddle::platform::float16
in tensor_py.h via the help of uint16 data type since
the internal memory representation of float16 is
uint16_t in paddle and np.uint16 in numpy, which are
themselves binded together by pybind.
Args:
input: input numpy array
Returns:
input: if the dtype of input is np.float16, its dtype will be
changed to np.uint16 so that the internal memory will be
reinterpreted input as of dtype np.uint16.
"""
if input.dtype == np.float16:
input.dtype = np.uint16
return input
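
A quick standalone sketch (not part of this commit) of the reinterpretation the docstring describes; it assumes only numpy:

import numpy as np

# float16 and uint16 are both 2 bytes wide, so relabeling the dtype keeps
# the underlying buffer bit-for-bit identical -- no copy, no value cast.
x = np.array([1.0, 0.5, -2.0], dtype=np.float16)
x.dtype = np.uint16
assert x.dtype == np.uint16
assert x.view(np.float16)[1] == np.float16(0.5)  # original bits preserved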

def _get_gradient(self, input_to_check, place, output_names, no_grad_set):
prog = Program()
block = prog.global_block()
88 changes: 84 additions & 4 deletions python/paddle/fluid/tests/unittests/test_conv2d_op.py
@@ -68,19 +68,24 @@ def setUp(self):
self.init_op_type()
self.init_group()
self.init_dilation()
self.init_data_type()
self.init_test_case()

conv2d_param = {
'stride': self.stride,
'pad': self.pad,
'dilation': self.dilations
}
input = np.random.random(self.input_size).astype("float32")
filter = np.random.random(self.filter_size).astype("float32")

input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = conv2d_forward_naive(input, filter, self.groups,
conv2d_param).astype('float32')
conv2d_param).astype(self.dtype)

self.inputs = {'Input': input, 'Filter': filter}
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
@@ -99,6 +104,8 @@ def test_check_output(self):
self.check_output()

def test_check_grad(self):
if self.dtype == np.float16:
return
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
@@ -111,6 +118,8 @@ def test_check_grad(self):
set(['Input', 'Filter']), 'Output', max_relative_error=0.02)

def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
@@ -126,6 +135,8 @@ def test_check_grad_no_filter(self):
no_grad_set=set(['Filter']))

def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
@@ -148,6 +159,9 @@ def init_test_case(self):
f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 3, 3]

def init_data_type(self):
self.dtype = np.float32

def init_dilation(self):
self.dilations = [1, 1]

@@ -232,36 +246,102 @@ def init_op_type(self):
self.op_type = "conv2d"


class TestFP16CUDNN(TestCUDNN):
def init_data_type(self):
self.dtype = np.float16

def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
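
For orientation (not part of this commit): each FP16 test class in this file, including the padded, strided, grouped, and 1x1 variants below, repeats the same two steps -- switch the dtype to float16 and override test_check_output so the forward result is only verified on a CUDA device that reports float16 support, with a loose tolerance (the base class already skips gradient checks for float16). A hypothetical factory expressing that shared pattern might look like the sketch below; it assumes the np and core imports already present in this test module.

def create_fp16_cudnn_case(parent):
    # Hypothetical helper, not part of the commit: wraps one of the CUDNN
    # test classes and produces its FP16 counterpart.
    class TestFP16Case(parent):
        def init_data_type(self):
            self.dtype = np.float16

        def test_check_output(self):
            # Forward check only, and only where float16 is supported.
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(0)
                if core.is_float16_supported(place):
                    self.check_output_with_place(place, atol=2e-2)

    return TestFP16Case

For example, TestFP16CUDNNWithPad would then be create_fp16_cudnn_case(TestCUDNNWithPad).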


class TestCUDNNWithPad(TestWithPad):
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d"


class TestFP16CUDNNWithPad(TestCUDNNWithPad):
def init_data_type(self):
self.dtype = np.float16

def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)


class TestCUDNNWithStride(TestWithStride):
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d"


class TestFP16CUDNNWithStride(TestCUDNNWithStride):
def init_data_type(self):
self.dtype = np.float16

def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)


class TestCUDNNWithGroup(TestWithGroup):
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d"


class TestFP16CUDNNWithGroup(TestCUDNNWithGroup):
def init_data_type(self):
self.dtype = np.float16

def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)


class TestCUDNNWith1x1(TestWith1x1):
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d"


class TestFP16CUDNNWith1x1(TestCUDNNWith1x1):
def init_data_type(self):
self.dtype = np.float16

def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)


class TestCUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d"


class TestFP16CUDNNWithInput1x1Filter1x1(TestCUDNNWithInput1x1Filter1x1):
def init_data_type(self):
self.dtype = np.float16

def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)


class TestDepthwiseConv(TestConv2dOp):
def init_test_case(self):
self.pad = [1, 1]