Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TRMM, remove deprecated inplace trmm to favor outofplace/inplace trmm API #617

Merged
merged 2 commits into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion clients/benchmarks/client.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ************************************************************************
* Copyright (C) 2016-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -302,6 +302,10 @@ try
value<hipblas_int>(&arg.batch_count)->default_value(1),
"Number of matrices. Only applicable to batched and strided_batched routines")

("inplace",
value<bool>(&arg.inplace)->default_value(false),
"Whether or not to use the in place version of the algorithm. Only applicable to trmm routines")

("verify,v",
value<hipblas_int>(&arg.norm_check)->default_value(0),
"Validate GPU results with CPU? 0 = No, 1 = Yes (default: No)")
Expand Down
182 changes: 133 additions & 49 deletions clients/common/hipblas_template_specialization.cpp

Large diffs are not rendered by default.

66 changes: 59 additions & 7 deletions clients/gtest/trmm_gtest.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ************************************************************************
* Copyright (C) 2016-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -37,7 +37,8 @@ using ::testing::ValuesIn;

// only GCC/VS 2010 comes with std::tr1::tuple, but it is unnecessary, std::tuple is good enough;

typedef std::tuple<vector<int>, double, vector<char>, double, int, bool> trmm_tuple;
typedef std::tuple<vector<int>, double, vector<char>, double, int, bool, bool> trmm_tuple;
typedef std::tuple<bool, bool> trmm_bad_arg_tuple;

/* =====================================================================
README: This file contains testers to verify the correctness of
Expand All @@ -57,16 +58,16 @@ Yet, the goal of this file is to verify result correctness not argument-checkers
Representative sampling is sufficient, endless brute-force sampling is not necessary
=================================================================== */

// vector of vector, each vector is a {M, N, lda, ldb};
// vector of vector, each vector is a {M, N, lda, ldb, ldc};
// add/delete as a group
const vector<vector<int>> matrix_size_range = {
{-1, -1, 1, 1}, {10, 10, 20, 100},
{-1, -1, 1, 1, 1}, {10, 10, 20, 100, 150},
// {600, 500, 600, 600} ,
// {1024, 1024, 1024, 1024}
};

const vector<vector<int>> full_matrix_size_range = {
{192, 192, 192, 192},
{192, 192, 192, 192, 192},
// {640, 640, 960, 960},
// {1000, 1000, 1000, 1000},
// {2000, 2000, 2000, 2000},
Expand Down Expand Up @@ -118,6 +119,8 @@ const vector<int> batch_count_range = {1, 3};
const bool is_fortran[] = {false, true};
const bool is_fortran_false[] = {false};

const bool is_inplace[] = {false, true};

/* ===============Google Unit Test==================================================== */

/* =====================================================================
Expand All @@ -143,6 +146,7 @@ Arguments setup_trmm_arguments(trmm_tuple tup)
double stride_scale = std::get<3>(tup);
int batch_count = std::get<4>(tup);
bool fortran = std::get<5>(tup);
bool inplace = std::get<6>(tup);

Arguments arg;

Expand All @@ -151,6 +155,7 @@ Arguments setup_trmm_arguments(trmm_tuple tup)
arg.N = matrix_size[1];
arg.lda = matrix_size[2];
arg.ldb = matrix_size[3];
arg.ldc = matrix_size[4];

arg.alpha = alpha;

Expand All @@ -165,10 +170,20 @@ Arguments setup_trmm_arguments(trmm_tuple tup)
arg.batch_count = batch_count;

arg.fortran = fortran;
arg.inplace = inplace;

return arg;
}

class trmm_bad_arg_gtest : public ::TestWithParam<trmm_bad_arg_tuple>
{
protected:
trmm_bad_arg_gtest() {}
virtual ~trmm_bad_arg_gtest() {}
virtual void SetUp() {}
virtual void TearDown() {}
};

class trmm_gtest : public ::TestWithParam<trmm_tuple>
{
protected:
Expand All @@ -178,6 +193,16 @@ class trmm_gtest : public ::TestWithParam<trmm_tuple>
virtual void TearDown() {}
};

TEST_P(trmm_bad_arg_gtest, trmm_bad_arg_gtest_test)
{
Arguments arg;

EXPECT_EQ(testing_trmm_bad_arg<float>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_bad_arg<double>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_bad_arg<hipblasComplex>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_bad_arg<hipblasDoubleComplex>(arg), HIPBLAS_STATUS_SUCCESS);
}

TEST_P(trmm_gtest, trmm_gtest_float)
{
// GetParam return a tuple. Tee setup routine unpack the tuple
Expand Down Expand Up @@ -234,6 +259,27 @@ TEST_P(trmm_gtest, trmm_gtest_double_complex)

#ifndef __HIP_PLATFORM_NVCC__

TEST_P(trmm_bad_arg_gtest, trmm_batched_bad_arg_gtest_test)
{
Arguments arg;

EXPECT_EQ(testing_trmm_batched_bad_arg<float>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_batched_bad_arg<double>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_batched_bad_arg<hipblasComplex>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_batched_bad_arg<hipblasDoubleComplex>(arg), HIPBLAS_STATUS_SUCCESS);
}

TEST_P(trmm_bad_arg_gtest, trmm_strided_batched_bad_arg_gtest_test)
{
Arguments arg;

EXPECT_EQ(testing_trmm_strided_batched_bad_arg<float>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_strided_batched_bad_arg<double>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_strided_batched_bad_arg<hipblasComplex>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_strided_batched_bad_arg<hipblasDoubleComplex>(arg),
HIPBLAS_STATUS_SUCCESS);
}

TEST_P(trmm_gtest, trmm_batched_gtest_float)
{
// GetParam return a tuple. Tee setup routine unpack the tuple
Expand Down Expand Up @@ -360,7 +406,8 @@ INSTANTIATE_TEST_SUITE_P(hipblastrmm_matrix_size,
ValuesIn(side_uplo_transA_diag_range),
ValuesIn(stride_scale_range),
ValuesIn(batch_count_range),
ValuesIn(is_fortran)));
ValuesIn(is_fortran),
ValuesIn(is_inplace)));

// THis function mainly test the scope of full_side_uplo_transA_diag_range,.the scope of
// matrix_size_range is small
Expand All @@ -371,4 +418,9 @@ INSTANTIATE_TEST_SUITE_P(hipblastrmm_scalar_transpose,
ValuesIn(full_side_uplo_transA_diag_range),
ValuesIn(stride_scale_range),
ValuesIn(batch_count_range),
ValuesIn(is_fortran_false)));
ValuesIn(is_fortran_false),
ValuesIn(is_inplace)));

INSTANTIATE_TEST_SUITE_P(hipblasTrmmBadArg,
trmm_bad_arg_gtest,
Combine(ValuesIn(is_fortran), ValuesIn(is_inplace)));
19 changes: 11 additions & 8 deletions clients/include/hipblas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@
#ifndef _HIPBLAS_HPP_
#define _HIPBLAS_HPP_

#ifndef HIPBLAS_NO_DEPRECATED_WARNINGS
#define HIPBLAS_NO_DEPRECATED_WARNINGS
#endif
Comment on lines -28 to -30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just comment out for next deprecation use?


/* library headers */
#include "hipblas.h"

Expand Down Expand Up @@ -1798,8 +1794,10 @@ hipblasStatus_t hipblasTrmm(hipblasHandle_t handle,
const T* alpha,
const T* A,
int lda,
T* B,
int ldb);
const T* B,
int ldb,
T* C,
int ldc);

template <typename T, bool FORTRAN = false>
hipblasStatus_t hipblasTrmmBatched(hipblasHandle_t handle,
Expand All @@ -1812,8 +1810,10 @@ hipblasStatus_t hipblasTrmmBatched(hipblasHandle_t handle,
const T* alpha,
const T* const A[],
int lda,
T* const B[],
const T* const B[],
int ldb,
T* const C[],
int ldc,
int batchCount);

template <typename T, bool FORTRAN = false>
Expand All @@ -1828,9 +1828,12 @@ hipblasStatus_t hipblasTrmmStridedBatched(hipblasHandle_t handle,
const T* A,
int lda,
hipblasStride strideA,
T* B,
const T* B,
int ldb,
hipblasStride strideB,
T* C,
int ldc,
hipblasStride strideC,
int batchCount);

// trsm
Expand Down
4 changes: 3 additions & 1 deletion clients/include/hipblas_arguments.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ************************************************************************
* Copyright (C) 2016-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -123,6 +123,7 @@ struct Arguments
int batch_count = 10;

bool fortran = false;
bool inplace = false; // only for trmm

int norm_check = 0;
int unit_check = 1;
Expand Down Expand Up @@ -186,6 +187,7 @@ struct Arguments
OPER(apiCallCount) SEP \
OPER(batch_count) SEP \
OPER(fortran) SEP \
OPER(inplace) SEP \
OPER(norm_check) SEP \
OPER(unit_check) SEP \
OPER(timing) SEP \
Expand Down
Loading