Skip to content

Commit

Permalink
Add mixed f32/f64 precision dot_ex (#2418)
Browse files Browse the repository at this point in the history
  • Loading branch information
daineAMD committed Mar 25, 2024
1 parent a7bfe42 commit 1ae1122
Show file tree
Hide file tree
Showing 17 changed files with 157 additions and 81 deletions.
5 changes: 4 additions & 1 deletion clients/benchmarks/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,10 @@ struct perf_blas_dot_ex<
rocblas_half> && std::is_same_v<Tx, Ty> && std::is_same_v<Ty, Tr> && std::is_same_v<Tex, float>)
|| (std::is_same_v<
Tx,
rocblas_bfloat16> && std::is_same_v<Tx, Ty> && std::is_same_v<Ty, Tr> && std::is_same_v<Tex, float>)>>
rocblas_bfloat16> && std::is_same_v<Tx, Ty> && std::is_same_v<Ty, Tr> && std::is_same_v<Tex, float>)
|| (std::is_same_v<
Tx,
float> && std::is_same_v<Tx, Ty> && std::is_same_v<Tr, double> && std::is_same_v<Tr, Tex>)>>
: rocblas_test_valid
{
void operator()(const Arguments& arg)
Expand Down
77 changes: 51 additions & 26 deletions clients/common/cblas_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,30 @@ template void ref_scal<double, rocblas_complex_num<double>*>(int64_t
int64_t incx);

template <>
void ref_dot<rocblas_half>(int64_t n,
const rocblas_half* x,
int64_t incx,
const rocblas_half* y,
int64_t incy,
rocblas_half* result)
void ref_dot<>(
int64_t n, const float* x, int64_t incx, const float* y, int64_t incy, double* result)
{
int64_t ix = incx >= 0 ? 0 : (1 - n) * incx;
int64_t iy = incy >= 0 ? 0 : (1 - n) * incy;

double r = 0.0;
for(int64_t i = 0; i < n; i++)
{
r += double(x[ix]) * y[iy];
ix += incx;
iy += incy;
}

*result = r;
}

template <>
void ref_dot<>(int64_t n,
const rocblas_half* x,
int64_t incx,
const rocblas_half* y,
int64_t incy,
rocblas_half* result)
{
int64_t ix = incx >= 0 ? 0 : (1 - n) * incx;
int64_t iy = incy >= 0 ? 0 : (1 - n) * incy;
Expand All @@ -211,12 +229,12 @@ void ref_dot<rocblas_half>(int64_t n,
}

template <>
void ref_dot<rocblas_bfloat16>(int64_t n,
const rocblas_bfloat16* x,
int64_t incx,
const rocblas_bfloat16* y,
int64_t incy,
rocblas_bfloat16* result)
void ref_dot<>(int64_t n,
const rocblas_bfloat16* x,
int64_t incx,
const rocblas_bfloat16* y,
int64_t incy,
rocblas_bfloat16* result)
{
int64_t ix = incx >= 0 ? 0 : (1 - n) * incx;
int64_t iy = incy >= 0 ? 0 : (1 - n) * incy;
Expand All @@ -233,37 +251,44 @@ void ref_dot<rocblas_bfloat16>(int64_t n,
}

template <>
void ref_dotc<float>(
void ref_dotc<>(
int64_t n, const float* x, int64_t incx, const float* y, int64_t incy, double* result)
{
ref_dot(n, x, incx, y, incy, result);
}

template <>
void ref_dotc<float, float>(
int64_t n, const float* x, int64_t incx, const float* y, int64_t incy, float* result)
{
ref_dot(n, x, incx, y, incy, result);
}

template <>
void ref_dotc<double>(
void ref_dotc<double, double>(
int64_t n, const double* x, int64_t incx, const double* y, int64_t incy, double* result)
{
ref_dot(n, x, incx, y, incy, result);
}

template <>
void ref_dotc<rocblas_half>(int64_t n,
const rocblas_half* x,
int64_t incx,
const rocblas_half* y,
int64_t incy,
rocblas_half* result)
void ref_dotc<rocblas_half, rocblas_half>(int64_t n,
const rocblas_half* x,
int64_t incx,
const rocblas_half* y,
int64_t incy,
rocblas_half* result)
{
ref_dot(n, x, incx, y, incy, result);
}

template <>
void ref_dotc<rocblas_bfloat16>(int64_t n,
const rocblas_bfloat16* x,
int64_t incx,
const rocblas_bfloat16* y,
int64_t incy,
rocblas_bfloat16* result)
void ref_dotc<rocblas_bfloat16, rocblas_bfloat16>(int64_t n,
const rocblas_bfloat16* x,
int64_t incx,
const rocblas_bfloat16* y,
int64_t incy,
rocblas_bfloat16* result)
{
ref_dot(n, x, incx, y, incy, result);
}
Expand Down
67 changes: 34 additions & 33 deletions clients/gtest/blas1_gtest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ Tests:
api: [ FORTRAN_64 ]
function:
- dot: *half_bfloat_single_double_complex_real_precisions
- dot_ex: *half_bfloat_single_double_complex_real_precisions
- dot_ex: *dot_ex_precisions
- dot_batched: *single_precision
- dot_batched_ex: *single_precision
- dot_strided_batched: *single_precision
Expand All @@ -453,6 +453,7 @@ Tests:
function:
- dot: *single_precision
- dot_ex: *single_precision
- dot_ex: *single_in_double_out_precision
- dot_batched: *single_precision
- dot_batched_ex: *single_precision
- dot_strided_batched: *single_precision
Expand Down Expand Up @@ -1056,7 +1057,7 @@ Tests:
- copy: *single_double_precisions_complex_real
- dot: *half_bfloat_single_double_complex_real_precisions
- dotc: *single_double_precisions_complex
- dot_ex: *half_bfloat_single_double_complex_real_precisions
- dot_ex: *dot_ex_precisions
- dotc_ex: *half_bfloat_single_double_complex_real_precisions
- swap: *single_double_precisions_complex_real
- rot: *rot_precisions
Expand Down Expand Up @@ -1139,7 +1140,7 @@ Tests:
- copy: *single_double_precisions_complex_real
- dot: *half_bfloat_single_double_complex_real_precisions
- dotc: *single_double_precisions_complex
- dot_ex: *half_bfloat_single_double_complex_real_precisions
- dot_ex: *dot_ex_precisions
- dotc_ex: *half_bfloat_single_double_complex_real_precisions

- name: blas1_batched
Expand Down Expand Up @@ -1524,12 +1525,12 @@ Tests:
- dotc: *single_double_precisions_complex
- dotc_batched: *half_bfloat_single_double_complex_real_precisions
- dotc_strided_batched: *half_bfloat_single_double_complex_real_precisions
- dot_ex: *half_bfloat_single_double_complex_real_precisions
- dotc_ex: *half_bfloat_single_double_complex_real_precisions
- dot_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dotc_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dot_strided_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dotc_strided_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dot_ex: *dot_ex_precisions
- dotc_ex: *dot_ex_precisions
- dot_batched_ex: *dot_ex_precisions
- dotc_batched_ex: *dot_ex_precisions
- dot_strided_batched_ex: *dot_ex_precisions
- dotc_strided_batched_ex: *dot_ex_precisions
- scal: *single_double_precisions_complex_real
- scal: *scal_single_double_complex_real_alpha_complex_out
- scal_ex: *scal_ex_bfloat_half_single_double_complex_real_precisions
Expand Down Expand Up @@ -1598,12 +1599,12 @@ Tests:
- dotc_bad_arg: *single_double_precisions_complex
- dotc_batched_bad_arg: *half_bfloat_single_double_complex_real_precisions
- dotc_strided_batched_bad_arg: *half_bfloat_single_double_complex_real_precisions
- dot_ex: *half_bfloat_single_double_complex_real_precisions
- dotc_ex: *half_bfloat_single_double_complex_real_precisions
- dot_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dotc_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dot_strided_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dotc_strided_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dot_ex: *dot_ex_precisions
- dotc_ex: *dot_ex_precisions
- dot_batched_ex: *dot_ex_precisions
- dotc_batched_ex: *dot_ex_precisions
- dot_strided_batched_ex: *dot_ex_precisions
- dotc_strided_batched_ex: *dot_ex_precisions
- scal_bad_arg: *single_double_precisions_complex_real
- scal_bad_arg: *scal_single_double_complex_real_alpha_complex_out
- scal_ex_bad_arg: *scal_ex_bfloat_half_single_double_complex_real_precisions
Expand Down Expand Up @@ -1669,12 +1670,12 @@ Tests:
- dotc_bad_arg: *single_double_precisions_complex
- dotc_batched_bad_arg: *half_bfloat_single_double_complex_real_precisions
- dotc_strided_batched_bad_arg: *half_bfloat_single_double_complex_real_precisions
- dot_ex: *half_bfloat_single_double_complex_real_precisions
- dotc_ex: *half_bfloat_single_double_complex_real_precisions
- dot_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dotc_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dot_strided_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dotc_strided_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dot_ex: *dot_ex_precisions
- dotc_ex: *dot_ex_precisions
- dot_batched_ex: *dot_ex_precisions
- dotc_batched_ex: *dot_ex_precisions
- dot_strided_batched_ex: *dot_ex_precisions
- dotc_strided_batched_ex: *dot_ex_precisions
- scal_bad_arg: *single_double_precisions_complex_real
- scal_bad_arg: *scal_single_double_complex_real_alpha_complex_out
- scal_ex_bad_arg: *scal_ex_bfloat_half_single_double_complex_real_precisions
Expand Down Expand Up @@ -1746,12 +1747,12 @@ Tests:
- dotc: *single_double_precisions_complex
- dotc_batched: *half_bfloat_single_double_complex_real_precisions
- dotc_strided_batched: *half_bfloat_single_double_complex_real_precisions
- dot_ex: *half_bfloat_single_double_complex_real_precisions
- dot_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dot_strided_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dotc_ex: *half_bfloat_single_double_complex_real_precisions
- dotc_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dotc_strided_batched_ex: *half_bfloat_single_double_complex_real_precisions
- dot_ex: *dot_ex_precisions
- dot_batched_ex: *dot_ex_precisions
- dot_strided_batched_ex: *dot_ex_precisions
- dotc_ex: *dot_ex_precisions
- dotc_batched_ex: *dot_ex_precisions
- dotc_strided_batched_ex: *dot_ex_precisions
- scal: *single_double_precisions_complex_real
- scal: *scal_single_double_complex_real_alpha_complex_out
- scal_batched: *single_double_precisions_complex_real
Expand Down Expand Up @@ -1822,12 +1823,12 @@ Tests:
# - dotc: *single_double_precisions_complex
# - dotc_batched: *half_bfloat_single_double_complex_real_precisions
# - dotc_strided_batched: *half_bfloat_single_double_complex_real_precisions
# - dot_ex: *half_bfloat_single_double_complex_real_precisions
# - dot_batched_ex: *half_bfloat_single_double_complex_real_precisions
# - dot_strided_batched_ex: *half_bfloat_single_double_complex_real_precisions
# - dotc_ex: *half_bfloat_single_double_complex_real_precisions
# - dotc_batched_ex: *half_bfloat_single_double_complex_real_precisions
# - dotc_strided_batched_ex: *half_bfloat_single_double_complex_real_precisions
# - dot_ex: *dot_ex_precisions
# - dot_batched_ex: *dot_ex_precisions
# - dot_strided_batched_ex: *dot_ex_precisions
# - dotc_ex: *dot_ex_precisions
# - dotc_batched_ex: *dot_ex_precisions
# - dotc_strided_batched_ex: *dot_ex_precisions
- scal: *single_double_precisions_complex_real
- scal: *scal_single_double_complex_real_alpha_complex_out
- scal_batched: *single_double_precisions_complex_real
Expand Down
7 changes: 5 additions & 2 deletions clients/gtest/blas_ex/dot_ex_gtest.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ************************************************************************
* Copyright (C) 2018-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2018-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -112,7 +112,10 @@ namespace
T2> && std::is_same_v<T2, T3> && std::is_same_v<T1, rocblas_half> && std::is_same_v<T4, float>)
|| (std::is_same_v<
T1,
T2> && std::is_same_v<T2, T3> && std::is_same_v<T1, rocblas_bfloat16> && std::is_same_v<T4, float>)))>;
T2> && std::is_same_v<T2, T3> && std::is_same_v<T1, rocblas_bfloat16> && std::is_same_v<T4, float>)
|| (std::is_same_v<
T1,
T2> && std::is_same_v<T3, T4> && std::is_same_v<T1, float> && std::is_same_v<T3, double>)))>;

// Creates tests for one of the BLAS 1 functions
// ARG passes 1-3 template arguments to the testing_* function
Expand Down
2 changes: 1 addition & 1 deletion clients/include/blas1/testing_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ void testing_dot(const Arguments& arg)

// CPU BLAS
cpu_time_used = get_time_us_no_sync();
(CONJ ? ref_dotc<T> : ref_dot<T>)(N, hx, incx, hy_ptr, incy, &cpu_result);
(CONJ ? ref_dotc<T, T> : ref_dot<T, T>)(N, hx, incx, hy_ptr, incy, &cpu_result);
cpu_time_used = get_time_us_no_sync() - cpu_time_used;

// For large N, rocblas_half tends to diverge proportional to N
Expand Down
3 changes: 2 additions & 1 deletion clients/include/blas1/testing_dot_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ void testing_dot_batched(const Arguments& arg)
cpu_time_used = get_time_us_no_sync();
for(size_t b = 0; b < batch_count; ++b)
{
(CONJ ? ref_dotc<T> : ref_dot<T>)(N, hx[b], incx, hy_ptr[b], incy, &cpu_result[b]);
(CONJ ? ref_dotc<T, T>
: ref_dot<T, T>)(N, hx[b], incx, hy_ptr[b], incy, &cpu_result[b]);
}
cpu_time_used = get_time_us_no_sync() - cpu_time_used;

Expand Down
4 changes: 2 additions & 2 deletions clients/include/blas1/testing_dot_strided_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ void testing_dot_strided_batched(const Arguments& arg)
cpu_time_used = get_time_us_no_sync();
for(size_t b = 0; b < batch_count; ++b)
{
(CONJ ? ref_dotc<T>
: ref_dot<T>)(N, hx[b], incx, hy_ptr + b * stride_y, incy, &cpu_result[b]);
(CONJ ? ref_dotc<T, T>
: ref_dot<T, T>)(N, hx[b], incx, hy_ptr + b * stride_y, incy, &cpu_result[b]);
}
cpu_time_used = get_time_us_no_sync() - cpu_time_used;

Expand Down
5 changes: 3 additions & 2 deletions clients/include/blas_ex/testing_dot_batched_ex.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ************************************************************************
* Copyright (C) 2018-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2018-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -279,7 +279,8 @@ void testing_dot_batched_ex(const Arguments& arg)
cpu_time_used = get_time_us_no_sync();
for(int b = 0; b < batch_count; ++b)
{
(CONJ ? ref_dotc<Tx> : ref_dot<Tx>)(N, hx[b], incx, hy_ptr[b], incy, &cpu_result[b]);
(CONJ ? ref_dotc<Tx, Tr>
: ref_dot<Tx, Tr>)(N, hx[b], incx, hy_ptr[b], incy, &cpu_result[b]);
}
cpu_time_used = get_time_us_no_sync() - cpu_time_used;

Expand Down
4 changes: 2 additions & 2 deletions clients/include/blas_ex/testing_dot_ex.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ************************************************************************
* Copyright (C) 2018-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2018-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -253,7 +253,7 @@ void testing_dot_ex(const Arguments& arg)

// CPU BLAS
cpu_time_used = get_time_us_no_sync();
(CONJ ? ref_dotc<Tx> : ref_dot<Tx>)(N, hx, incx, hy_ptr, incy, cpu_result);
(CONJ ? ref_dotc<Tx, Tr> : ref_dot<Tx, Tr>)(N, hx, incx, hy_ptr, incy, cpu_result);
cpu_time_used = get_time_us_no_sync() - cpu_time_used;

// For large N, rocblas_half tends to diverge proportional to N
Expand Down
6 changes: 3 additions & 3 deletions clients/include/blas_ex/testing_dot_strided_batched_ex.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ************************************************************************
* Copyright (C) 2018-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2018-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -310,8 +310,8 @@ void testing_dot_strided_batched_ex(const Arguments& arg)
cpu_time_used = get_time_us_no_sync();
for(size_t b = 0; b < batch_count; ++b)
{
(CONJ ? ref_dotc<Tx>
: ref_dot<Tx>)(N, hx[b], incx, hy_ptr + b * stride_y, incy, &cpu_result[b]);
(CONJ ? ref_dotc<Tx, Tr>
: ref_dot<Tx, Tr>)(N, hx[b], incx, hy_ptr + b * stride_y, incy, &cpu_result[b]);
}
cpu_time_used = get_time_us_no_sync() - cpu_time_used;

Expand Down
8 changes: 4 additions & 4 deletions clients/include/cblas_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ inline void ref_copy(
}

// dot
template <typename T>
void ref_dot(int64_t n, const T* x, int64_t incx, const T* y, int64_t incy, T* result);
template <typename T, typename Tr>
void ref_dot(int64_t n, const T* x, int64_t incx, const T* y, int64_t incy, Tr* result);

template <>
inline void
Expand Down Expand Up @@ -199,8 +199,8 @@ inline void ref_dot(int64_t n,
}

// dotc
template <typename T>
void ref_dotc(int64_t n, const T* x, int64_t incx, const T* y, int64_t incy, T* result);
template <typename T, typename Tr>
void ref_dotc(int64_t n, const T* x, int64_t incx, const T* y, int64_t incy, Tr* result);

template <>
inline void ref_dotc(int64_t n,
Expand Down
Loading

0 comments on commit 1ae1122

Please sign in to comment.