Skip to content

Commit

Permalink
complex log_args to use tuple format for csv parse (#1392)
Browse files Browse the repository at this point in the history
* complex number output format only changes for rocblas-bench
  • Loading branch information
TorreZuk committed Aug 16, 2022
1 parent 8f2273b commit 79fcb2b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 6 deletions.
6 changes: 4 additions & 2 deletions clients/include/argument_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ template <rocblas_argument... Args>
class ArgumentModel
{
// Whether model has a particular parameter
// TODO: Replace with C++17 fold expression ((Args == param) || ...)
static constexpr bool has(rocblas_argument param)
{
for(auto x : {Args...})
if(x == param)
return true;
return false;
// TODO: Replace with C++17 fold expression, a C++17 extension
// return ((Args == param) || ...);
}

public:
Expand Down Expand Up @@ -142,6 +143,7 @@ class ArgumentModel
{
rocblas_internal_ostream name_list;
rocblas_internal_ostream value_list;
value_list.set_csv(true);

if(ArgumentModel_get_log_function_name())
{
Expand Down Expand Up @@ -177,7 +179,7 @@ class ArgumentModel
// apply is a templated lambda for C++17 and a templated fuctor for C++14
//
// For rocblas_ddot, the following template specialization of apply will be called:
// apply<e_N>(print, arg, T{}), apply<e_incx>(print, arg, T{}),, apply<e_incy>(print, arg, T{})
// apply<e_N>(print, arg, T{}), apply<e_incx>(print, arg, T{}), apply<e_incy>(print, arg, T{})
//
// apply in turn calls print with a string corresponding to the enum, for example "N" and the value of N
//
Expand Down
47 changes: 43 additions & 4 deletions library/src/include/rocblas_ostream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ class ROCBLAS_INTERNAL_EXPORT rocblas_internal_ostream
// Flag indicating whether YAML mode is turned on
bool m_yaml = false;

// Flag for CSV output avoid commas in any value representation
bool m_csv = false;

// Get worker for file descriptor
static std::shared_ptr<worker> get_worker(int fd);

Expand Down Expand Up @@ -246,6 +249,12 @@ class ROCBLAS_INTERNAL_EXPORT rocblas_internal_ostream
// Flush the output
void flush();

// csv friendly output set true
void set_csv(bool flag)
{
m_csv = flag;
}

// Destroy the rocblas_internal_ostream
virtual ~rocblas_internal_ostream();

Expand All @@ -266,6 +275,34 @@ class ROCBLAS_INTERNAL_EXPORT rocblas_internal_ostream
// Abort function which safely flushes all IO
friend void rocblas_abort_once();

// stream output is required to allow csv override of std::forward
// which will use of rocblas_*_complex operators
template <typename T>
void stream_output(rocblas_internal_ostream& os, T&& x)
{
os.m_os << std::forward<T>(x);
}

template <>
void stream_output<rocblas_double_complex>(rocblas_internal_ostream& os,
rocblas_double_complex&& x)
{
if(!m_csv)
os.m_os << x; // complex operator<<
else
os << x; // local override not complex operator
}

template <>
void stream_output<rocblas_float_complex>(rocblas_internal_ostream& os,
rocblas_float_complex&& x)
{
if(!m_csv)
os.m_os << x; // complex operator<<
else
os << x; // local override not complex operator
}

/*************************************************************************
* Non-member friend functions for formatted output *
*************************************************************************/
Expand All @@ -274,7 +311,7 @@ class ROCBLAS_INTERNAL_EXPORT rocblas_internal_ostream
template <typename T, std::enable_if_t<!std::is_enum<std::decay_t<T>>{}, int> = 0>
friend rocblas_internal_ostream& operator<<(rocblas_internal_ostream& os, T&& x)
{
os.m_os << std::forward<T>(x);
os.stream_output(os, x);
return os;
}

Expand All @@ -299,10 +336,12 @@ class ROCBLAS_INTERNAL_EXPORT rocblas_internal_ostream

// Complex output
template <typename T>
friend rocblas_internal_ostream& operator<<(rocblas_internal_ostream& os,
const rocblas_complex_num<T>& x)
friend rocblas_internal_ostream& operator<<(rocblas_internal_ostream& os,
rocblas_complex_num<T> x)
{
if(os.m_yaml)
if(os.m_csv)
os.m_os << "(" << std::real(x) << ": " << std::imag(x) << ")";
else if(os.m_yaml)
os.m_os << "'(" << std::real(x) << "," << std::imag(x) << ")'";
else
os.m_os << x;
Expand Down

0 comments on commit 79fcb2b

Please sign in to comment.