Skip to content

Commit

Permalink
Merge pull request #617 from amcamd/TRMM_rm_inplace
Browse files Browse the repository at this point in the history
TRMM, remove deprecated inplace trmm to favor outofplace/inplace trmm API
  • Loading branch information
amcamd committed Jul 10, 2023
2 parents fef4067 + 0b16009 commit 215876e
Show file tree
Hide file tree
Showing 18 changed files with 2,041 additions and 1,948 deletions.
6 changes: 5 additions & 1 deletion clients/benchmarks/client.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ************************************************************************
* Copyright (C) 2016-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -302,6 +302,10 @@ try
value<hipblas_int>(&arg.batch_count)->default_value(1),
"Number of matrices. Only applicable to batched and strided_batched routines")

("inplace",
value<bool>(&arg.inplace)->default_value(false),
"Whether or not to use the in place version of the algorithm. Only applicable to trmm routines")

("verify,v",
value<hipblas_int>(&arg.norm_check)->default_value(0),
"Validate GPU results with CPU? 0 = No, 1 = Yes (default: No)")
Expand Down
182 changes: 133 additions & 49 deletions clients/common/hipblas_template_specialization.cpp

Large diffs are not rendered by default.

66 changes: 59 additions & 7 deletions clients/gtest/trmm_gtest.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ************************************************************************
* Copyright (C) 2016-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -37,7 +37,8 @@ using ::testing::ValuesIn;

// only GCC/VS 2010 comes with std::tr1::tuple, but it is unnecessary, std::tuple is good enough;

typedef std::tuple<vector<int>, double, vector<char>, double, int, bool> trmm_tuple;
typedef std::tuple<vector<int>, double, vector<char>, double, int, bool, bool> trmm_tuple;
typedef std::tuple<bool, bool> trmm_bad_arg_tuple;

/* =====================================================================
README: This file contains testers to verify the correctness of
Expand All @@ -57,16 +58,16 @@ Yet, the goal of this file is to verify result correctness not argument-checkers
Representative sampling is sufficient, endless brute-force sampling is not necessary
=================================================================== */

// vector of vector, each vector is a {M, N, lda, ldb};
// vector of vector, each vector is a {M, N, lda, ldb, ldc};
// add/delete as a group
const vector<vector<int>> matrix_size_range = {
{-1, -1, 1, 1}, {10, 10, 20, 100},
{-1, -1, 1, 1, 1}, {10, 10, 20, 100, 150},
// {600, 500, 600, 600} ,
// {1024, 1024, 1024, 1024}
};

const vector<vector<int>> full_matrix_size_range = {
{192, 192, 192, 192},
{192, 192, 192, 192, 192},
// {640, 640, 960, 960},
// {1000, 1000, 1000, 1000},
// {2000, 2000, 2000, 2000},
Expand Down Expand Up @@ -118,6 +119,8 @@ const vector<int> batch_count_range = {1, 3};
const bool is_fortran[] = {false, true};
const bool is_fortran_false[] = {false};

const bool is_inplace[] = {false, true};

/* ===============Google Unit Test==================================================== */

/* =====================================================================
Expand All @@ -143,6 +146,7 @@ Arguments setup_trmm_arguments(trmm_tuple tup)
double stride_scale = std::get<3>(tup);
int batch_count = std::get<4>(tup);
bool fortran = std::get<5>(tup);
bool inplace = std::get<6>(tup);

Arguments arg;

Expand All @@ -151,6 +155,7 @@ Arguments setup_trmm_arguments(trmm_tuple tup)
arg.N = matrix_size[1];
arg.lda = matrix_size[2];
arg.ldb = matrix_size[3];
arg.ldc = matrix_size[4];

arg.alpha = alpha;

Expand All @@ -165,10 +170,20 @@ Arguments setup_trmm_arguments(trmm_tuple tup)
arg.batch_count = batch_count;

arg.fortran = fortran;
arg.inplace = inplace;

return arg;
}

class trmm_bad_arg_gtest : public ::TestWithParam<trmm_bad_arg_tuple>
{
protected:
trmm_bad_arg_gtest() {}
virtual ~trmm_bad_arg_gtest() {}
virtual void SetUp() {}
virtual void TearDown() {}
};

class trmm_gtest : public ::TestWithParam<trmm_tuple>
{
protected:
Expand All @@ -178,6 +193,16 @@ class trmm_gtest : public ::TestWithParam<trmm_tuple>
virtual void TearDown() {}
};

TEST_P(trmm_bad_arg_gtest, trmm_bad_arg_gtest_test)
{
Arguments arg;

EXPECT_EQ(testing_trmm_bad_arg<float>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_bad_arg<double>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_bad_arg<hipblasComplex>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_bad_arg<hipblasDoubleComplex>(arg), HIPBLAS_STATUS_SUCCESS);
}

TEST_P(trmm_gtest, trmm_gtest_float)
{
// GetParam return a tuple. Tee setup routine unpack the tuple
Expand Down Expand Up @@ -234,6 +259,27 @@ TEST_P(trmm_gtest, trmm_gtest_double_complex)

#ifndef __HIP_PLATFORM_NVCC__

TEST_P(trmm_bad_arg_gtest, trmm_batched_bad_arg_gtest_test)
{
Arguments arg;

EXPECT_EQ(testing_trmm_batched_bad_arg<float>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_batched_bad_arg<double>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_batched_bad_arg<hipblasComplex>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_batched_bad_arg<hipblasDoubleComplex>(arg), HIPBLAS_STATUS_SUCCESS);
}

TEST_P(trmm_bad_arg_gtest, trmm_strided_batched_bad_arg_gtest_test)
{
Arguments arg;

EXPECT_EQ(testing_trmm_strided_batched_bad_arg<float>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_strided_batched_bad_arg<double>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_strided_batched_bad_arg<hipblasComplex>(arg), HIPBLAS_STATUS_SUCCESS);
EXPECT_EQ(testing_trmm_strided_batched_bad_arg<hipblasDoubleComplex>(arg),
HIPBLAS_STATUS_SUCCESS);
}

TEST_P(trmm_gtest, trmm_batched_gtest_float)
{
// GetParam return a tuple. Tee setup routine unpack the tuple
Expand Down Expand Up @@ -360,7 +406,8 @@ INSTANTIATE_TEST_SUITE_P(hipblastrmm_matrix_size,
ValuesIn(side_uplo_transA_diag_range),
ValuesIn(stride_scale_range),
ValuesIn(batch_count_range),
ValuesIn(is_fortran)));
ValuesIn(is_fortran),
ValuesIn(is_inplace)));

// THis function mainly test the scope of full_side_uplo_transA_diag_range,.the scope of
// matrix_size_range is small
Expand All @@ -371,4 +418,9 @@ INSTANTIATE_TEST_SUITE_P(hipblastrmm_scalar_transpose,
ValuesIn(full_side_uplo_transA_diag_range),
ValuesIn(stride_scale_range),
ValuesIn(batch_count_range),
ValuesIn(is_fortran_false)));
ValuesIn(is_fortran_false),
ValuesIn(is_inplace)));

INSTANTIATE_TEST_SUITE_P(hipblasTrmmBadArg,
trmm_bad_arg_gtest,
Combine(ValuesIn(is_fortran), ValuesIn(is_inplace)));
19 changes: 11 additions & 8 deletions clients/include/hipblas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@
#ifndef _HIPBLAS_HPP_
#define _HIPBLAS_HPP_

#ifndef HIPBLAS_NO_DEPRECATED_WARNINGS
#define HIPBLAS_NO_DEPRECATED_WARNINGS
#endif

/* library headers */
#include "hipblas.h"

Expand Down Expand Up @@ -1798,8 +1794,10 @@ hipblasStatus_t hipblasTrmm(hipblasHandle_t handle,
const T* alpha,
const T* A,
int lda,
T* B,
int ldb);
const T* B,
int ldb,
T* C,
int ldc);

template <typename T, bool FORTRAN = false>
hipblasStatus_t hipblasTrmmBatched(hipblasHandle_t handle,
Expand All @@ -1812,8 +1810,10 @@ hipblasStatus_t hipblasTrmmBatched(hipblasHandle_t handle,
const T* alpha,
const T* const A[],
int lda,
T* const B[],
const T* const B[],
int ldb,
T* const C[],
int ldc,
int batchCount);

template <typename T, bool FORTRAN = false>
Expand All @@ -1828,9 +1828,12 @@ hipblasStatus_t hipblasTrmmStridedBatched(hipblasHandle_t handle,
const T* A,
int lda,
hipblasStride strideA,
T* B,
const T* B,
int ldb,
hipblasStride strideB,
T* C,
int ldc,
hipblasStride strideC,
int batchCount);

// trsm
Expand Down
4 changes: 3 additions & 1 deletion clients/include/hipblas_arguments.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ************************************************************************
* Copyright (C) 2016-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -123,6 +123,7 @@ struct Arguments
int batch_count = 10;

bool fortran = false;
bool inplace = false; // only for trmm

int norm_check = 0;
int unit_check = 1;
Expand Down Expand Up @@ -186,6 +187,7 @@ struct Arguments
OPER(apiCallCount) SEP \
OPER(batch_count) SEP \
OPER(fortran) SEP \
OPER(inplace) SEP \
OPER(norm_check) SEP \
OPER(unit_check) SEP \
OPER(timing) SEP \
Expand Down
Loading

0 comments on commit 215876e

Please sign in to comment.