diff --git a/ivy/functional/frontends/torch/miscellaneous_ops.py b/ivy/functional/frontends/torch/miscellaneous_ops.py index 9369885c74745..7bd69abad45d5 100644 --- a/ivy/functional/frontends/torch/miscellaneous_ops.py +++ b/ivy/functional/frontends/torch/miscellaneous_ops.py @@ -92,6 +92,46 @@ def cartesian_prod(*tensors): return ret +@with_unsupported_dtypes({"2.1.2 and below": "float16"}, "torch") +@to_ivy_arrays_and_back +def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"): + if len(x1.shape) != 3 or len(x2.shape) != 3: + raise ivy.exceptions.IvyError( + "Both ivy arrays need to have 3 dimensions (BxRxM)" + ) + + if ( + compute_mode != "use_mm_for_euclid_dist_if_necessary" + and compute_mode != "use_mm_for_euclid_dist" + and compute_mode != "donot_use_mm_for_euclid_dist" + ): + raise ivy.exceptions.IvyError( + f"{compute_mode} is not a valid value for compute_mode" + ) + if p == 2: + B, P, M = x1.shape + _, R, _ = x2.shape + if ( + compute_mode == "use_mm_for_euclid_dist_if_necessary" + and (P > 25 or R > 25) + or compute_mode == "use_mm_for_euclid_dist" + ): + return ivy.vector_norm( + x1[:, :, None, :] - x2[:, None, :, :], axis=-1, ord=p + ) + else: + distances = ivy.zeros((B, P, R), dtype=x1.dtype) + for b in range(B): + for i in range(P): + for j in range(R): + distances[b, i, j] = ivy.vector_norm( + x1[b, i, :] - x2[b, j, :], ord=p + ) + return distances + else: + return ivy.vector_norm(x1[:, :, None, :] - x2[:, None, :, :], axis=-1, ord=p) + + @to_ivy_arrays_and_back def clone(input, *, memory_format=None): return ivy.copy_array(input) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py index 340fe6bed6930..d0e35c7d98f6b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py @@ -510,6 +510,50 @@ def test_torch_cartesian_prod( ) +@handle_frontend_test( + fn_tree="torch.cdist", + dtypes_and_x=helpers.dtype_and_values( + shape=st.shared(helpers.get_shape(min_num_dims=3, max_num_dims=3), key="shape"), + shared_dtype=True, + num_arrays=2, + allow_inf=False, + available_dtypes=["float32", "float64"], + ), + p=st.integers(min_value=0, max_value=1000000), + compute_mode=st.sampled_from( + [ + "use_mm_for_euclid_dist_if_necessary", + "use_mm_for_euclid_dist", + "donot_use_mm_for_euclid_dist", + ] + ), +) +def test_torch_cdist( + *, + dtypes_and_x, + p, + compute_mode, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, xs = dtypes_and_x + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x1=xs[0], + x2=xs[1], + p=p, + compute_mode=compute_mode, + ) + + # clone @handle_frontend_test( fn_tree="torch.clone",