Skip to content

Commit

Permalink
[fp16] fix fp16 support for nn.PairwiseDistance (#50849)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ainavo committed Feb 27, 2023
1 parent ebea088 commit 587120e
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 5 deletions.
91 changes: 91 additions & 0 deletions python/paddle/fluid/tests/unittests/test_pairwise_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,97 @@ def test_pairwise_distance_broadcast_2(self):
dygraph_functional_ret, excepted_value, rtol=1e-05
)

def test_pairwise_distance_fp16(self):
epsilon = 1e-6
all_shape = [[5], [100, 100]]
dtypes = ['float16']
p_list = [-1, 0, 1, 2, np.inf, -np.inf]
places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
keeps = [False, True]
for place in places:
for shape in all_shape:
for dtype in dtypes:
for p in p_list:
for keepdim in keeps:
x_np = np.random.random(shape).astype(dtype)
y_np = np.random.random(shape).astype(dtype)
# Currently, the CPU does not support float16
if dtype == "float16" and isinstance(
place, paddle.CPUPlace
):
continue
static_ret = test_static(
place,
x_np,
y_np,
p,
epsilon=epsilon,
keepdim=keepdim,
)
dygraph_ret = test_dygraph(
place,
x_np,
y_np,
p,
epsilon=epsilon,
keepdim=keepdim,
)
excepted_value = np_pairwise_distance(
x_np, y_np, p, epsilon=epsilon, keepdim=keepdim
)

self.assertEqual(
static_ret.shape, excepted_value.shape
)
self.assertEqual(
dygraph_ret.shape, excepted_value.shape
)

np.testing.assert_allclose(
static_ret, excepted_value, atol=1e-03
)
np.testing.assert_allclose(
dygraph_ret, excepted_value, atol=1e-03
)
static_functional_ret = test_static(
place,
x_np,
y_np,
p,
epsilon=epsilon,
keepdim=keepdim,
)
dygraph_functional_ret = test_dygraph(
place,
x_np,
y_np,
p,
epsilon=epsilon,
keepdim=keepdim,
)

self.assertEqual(
static_functional_ret.shape,
excepted_value.shape,
)
self.assertEqual(
dygraph_functional_ret.shape,
excepted_value.shape,
)

np.testing.assert_allclose(
static_functional_ret,
excepted_value,
atol=1e-03,
)
np.testing.assert_allclose(
dygraph_functional_ret,
excepted_value,
atol=1e-03,
)


if __name__ == "__main__":
unittest.main()
8 changes: 4 additions & 4 deletions python/paddle/nn/functional/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None):
Parameters:
x (Tensor): Tensor, shape is :math:`[N, D]` or :math:`[D]`, where :math:`N`
is batch size, :math:`D` is the dimension of vector. Available dtype is
float32, float64.
float16, float32, float64.
y (Tensor): Tensor, shape is :math:`[N, D]` or :math:`[D]`, where :math:`N`
is batch size, :math:`D` is the dimension of vector. Available dtype is
float32, float64.
float16, float32, float64.
p (float, optional): The order of norm. Default: :math:`2.0`.
epsilon (float, optional): Add small value to avoid division by zero.
Default: :math:`1e-6`.
Expand Down Expand Up @@ -84,10 +84,10 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None):
check_type(keepdim, 'keepdim', (bool), 'PairwiseDistance')

check_variable_and_dtype(
x, 'x', ['float32', 'float64'], 'PairwiseDistance'
x, 'x', ['float16', 'float32', 'float64'], 'PairwiseDistance'
)
check_variable_and_dtype(
y, 'y', ['float32', 'float64'], 'PairwiseDistance'
y, 'y', ['float16', 'float32', 'float64'], 'PairwiseDistance'
)
sub = paddle.subtract(x, y)
if epsilon != 0.0:
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/nn/layer/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class PairwiseDistance(Layer):
Shape:
- x: :math:`[N, D]` or :math:`[D]`, where :math:`N` is batch size, :math:`D`
is the dimension of the data. Available data type is float32, float64.
is the dimension of the data. Available data type is float16, float32, float64.
- y: :math:`[N, D]` or :math:`[D]`, y have the same dtype as x.
- output: The same dtype as input tensor.
- If :attr:`keepdim` is True, the output shape is :math:`[N, 1]` or :math:`[1]`,
Expand Down

0 comments on commit 587120e

Please sign in to comment.