Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix checkMatrix usage in some tests #4732

Merged
merged 7 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/QMCWaveFunctions/tests/test_ConstantSPOSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ TEST_CASE("ConstantSPOSet", "[wavefunction]")
const int last_index = 2;
sposet->evaluate_notranspose(elec, first_index, last_index, phimat, gphimat, lphimat);

checkMatrix(phimat, spomat);
checkMatrix(lphimat, laplspomat);
auto check = checkMatrix(phimat, spomat);
CHECKED_ELSE(check.result) { FAIL(check.result_message); }
check = checkMatrix(lphimat, laplspomat);
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

//Test of makeClone()
auto sposet_vgl2 = sposet->makeClone();
Expand All @@ -125,8 +127,10 @@ TEST_CASE("ConstantSPOSet", "[wavefunction]")

sposet_vgl2->evaluate_notranspose(elec, first_index, last_index, phimat, gphimat, lphimat);

checkMatrix(phimat, spomat);
checkMatrix(lphimat, laplspomat);
check = checkMatrix(phimat, spomat);
CHECKED_ELSE(check.result) { FAIL(check.result_message); }
check = checkMatrix(lphimat, laplspomat);
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

//Lastly, check if name is correct.
std::string myname = sposet_vgl2->getClassName();
Expand Down
29 changes: 20 additions & 9 deletions src/QMCWaveFunctions/tests/test_DiracDeterminantBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ void test_DiracDeterminantBatched_first()
b(2, 1) = -0.04586322768;
b(2, 2) = 0.3927890292;

checkMatrix(ddb.get_det_engine().get_ref_psiMinv(), b);
auto check = checkMatrix(b, ddb.get_det_engine().get_ref_psiMinv());
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

ParticleSet::GradType grad;
PsiValueType det_ratio = ddb.ratioGrad(elec, 0, grad);
Expand All @@ -85,7 +86,8 @@ void test_DiracDeterminantBatched_first()
b(2, 1) = 0.7119205298;
b(2, 2) = 0.9105960265;

checkMatrix(ddb.get_det_engine().get_ref_psiMinv(), b);
check = checkMatrix(b, ddb.get_det_engine().get_ref_psiMinv());
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

// set virtutal particle position
PosType newpos(0.3, 0.2, 0.5);
Expand Down Expand Up @@ -259,7 +261,8 @@ void test_DiracDeterminantBatched_second()
app_log() << ddb.getPsiMinv() << std::endl;
#endif

checkMatrix(ddb.get_det_engine().get_ref_psiMinv(), orig_a);
auto check = checkMatrix(orig_a, ddb.get_det_engine().get_ref_psiMinv());
CHECKED_ELSE(check.result) { FAIL(check.result_message); }
}

TEST_CASE("DiracDeterminantBatched_second", "[wavefunction][fermion]")
Expand Down Expand Up @@ -356,7 +359,8 @@ void test_DiracDeterminantBatched_delayed_update(int delay_rank, DetMatInvertor
// force update Ainv in ddc using SM-1 code path
ddc.completeUpdates();

checkMatrix(ddc.get_det_engine().get_ref_psiMinv(), a_update1);
auto check = checkMatrix(a_update1, ddc.get_det_engine().get_ref_psiMinv());
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

grad = ddc.evalGrad(elec, 1);

Expand Down Expand Up @@ -411,7 +415,8 @@ void test_DiracDeterminantBatched_delayed_update(int delay_rank, DetMatInvertor
#endif

// compare all the elements of get_ref_psiMinv() in ddc and orig_a
checkMatrix(ddc.get_det_engine().get_ref_psiMinv(), orig_a);
check = checkMatrix(orig_a, ddc.get_det_engine().get_ref_psiMinv());
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

// testing batched interfaces
ResourceCollection pset_res("test_pset_res");
Expand Down Expand Up @@ -446,8 +451,11 @@ void test_DiracDeterminantBatched_delayed_update(int delay_rank, DetMatInvertor
ddc.mw_accept_rejectMove(ddc_ref_list, p_ref_list, 0, isAccepted, true);
ddc.mw_completeUpdates(ddc_ref_list);

checkMatrix(ddc.get_det_engine().get_ref_psiMinv(), a_update1);
checkMatrix(ddc_clone_ref.get_det_engine().get_ref_psiMinv(), a_update1);
check = checkMatrix(a_update1, ddc.get_det_engine().get_ref_psiMinv());
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

check = checkMatrix(a_update1, ddc_clone_ref.get_det_engine().get_ref_psiMinv());
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

ddc.mw_evalGrad(ddc_ref_list, p_ref_list, 1, grad_new);
ddc.mw_ratioGrad(ddc_ref_list, p_ref_list, 1, ratios, grad_new);
Expand All @@ -465,8 +473,11 @@ void test_DiracDeterminantBatched_delayed_update(int delay_rank, DetMatInvertor
ddc.mw_accept_rejectMove(ddc_ref_list, p_ref_list, 2, isAccepted, true);
ddc.mw_completeUpdates(ddc_ref_list);

checkMatrix(ddc.get_det_engine().get_ref_psiMinv(), orig_a);
checkMatrix(ddc_clone_ref.get_det_engine().get_ref_psiMinv(), orig_a);
check = checkMatrix(orig_a, ddc.get_det_engine().get_ref_psiMinv());
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

check = checkMatrix(orig_a, ddc_clone_ref.get_det_engine().get_ref_psiMinv());
CHECKED_ELSE(check.result) { FAIL(check.result_message); }
}

TEST_CASE("DiracDeterminantBatched_delayed_update", "[wavefunction][fermion]")
Expand Down
24 changes: 16 additions & 8 deletions src/QMCWaveFunctions/tests/test_spline_applyrotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ TEST_CASE("Spline applyRotation zero rotation", "[wavefunction]")
for (int k = 0; k < orbitalsetsize; k++)
psiM_rot_manual[i][j] += psiM_bare[i][k] * rot_mat[k][j];
}
checkMatrix(psiM_rot_manual, psiM_rot);
auto check = checkMatrix(psiM_rot_manual, psiM_rot, true);
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

// Check grad
SPOSet::GradMatrix dpsiM_rot_manual(elec_.R.size(), orbitalsetsize);
Expand Down Expand Up @@ -157,7 +158,8 @@ TEST_CASE("Spline applyRotation zero rotation", "[wavefunction]")
d2psiM_rot_manual[i][j] += d2psiM_bare[i][k] * rot_mat[k][j];
}

checkMatrix(d2psiM_rot_manual, d2psiM_rot);
check = checkMatrix(d2psiM_rot_manual, d2psiM_rot, true, 2e-4);
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

} // TEST_CASE

Expand Down Expand Up @@ -329,7 +331,8 @@ TEST_CASE("Spline applyRotation one rotation", "[wavefunction]")
for (int k = 0; k < orbitalsetsize; k++)
psiM_rot_manual[i][j] += psiM_bare[i][k] * rot_mat[k][j];
}
checkMatrix(psiM_rot_manual, psiM_rot);
auto check = checkMatrix(psiM_rot_manual, psiM_rot, true);
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

// Check grad
SPOSet::GradMatrix dpsiM_rot_manual(elec_.R.size(), orbitalsetsize);
Expand Down Expand Up @@ -363,7 +366,8 @@ TEST_CASE("Spline applyRotation one rotation", "[wavefunction]")
d2psiM_rot_manual[i][j] += d2psiM_bare[i][k] * rot_mat[k][j];
}

checkMatrix(d2psiM_rot_manual, d2psiM_rot);
check = checkMatrix(d2psiM_rot_manual, d2psiM_rot, true, 2e-4);
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

} // TEST_CASE

Expand Down Expand Up @@ -630,7 +634,8 @@ TEST_CASE("Spline applyRotation two rotations", "[wavefunction]")
for (int k = 0; k < orbitalsetsize; k++)
psiM_rot_manual[i][j] += psiM_bare[i][k] * rot_mat_tot[k][j];
}
checkMatrix(psiM_rot_manual, psiM_rot);
auto check = checkMatrix(psiM_rot_manual, psiM_rot, true);
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

// Check grad
SPOSet::GradMatrix dpsiM_rot_manual(elec_.R.size(), orbitalsetsize);
Expand Down Expand Up @@ -663,7 +668,8 @@ TEST_CASE("Spline applyRotation two rotations", "[wavefunction]")
for (int k = 0; k < orbitalsetsize; k++)
d2psiM_rot_manual[i][j] += d2psiM_bare[i][k] * rot_mat_tot[k][j];
}
checkMatrix(d2psiM_rot_manual, d2psiM_rot);
check = checkMatrix(d2psiM_rot_manual, d2psiM_rot, true, 2e-4);
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

} // TEST_CASE

Expand Down Expand Up @@ -771,7 +777,8 @@ TEST_CASE("Spline applyRotation complex rotation", "[wavefunction]")
for (int k = 0; k < orbitalsetsize; k++)
psiM_rot_manual[i][j] += psiM_bare[i][k] * rot_mat[k][j];
}
checkMatrix(psiM_rot_manual, psiM_rot);
auto check = checkMatrix(psiM_rot_manual, psiM_rot, true);
CHECKED_ELSE(check.result) { FAIL(check.result_message); }

// Check grad
SPOSet::GradMatrix dpsiM_rot_manual(elec_.R.size(), orbitalsetsize);
Expand Down Expand Up @@ -804,7 +811,8 @@ TEST_CASE("Spline applyRotation complex rotation", "[wavefunction]")
for (int k = 0; k < orbitalsetsize; k++)
d2psiM_rot_manual[i][j] += d2psiM_bare[i][k] * rot_mat[k][j];
}
checkMatrix(d2psiM_rot_manual, d2psiM_rot);
check = checkMatrix(d2psiM_rot_manual, d2psiM_rot, true, 2e-4);
CHECKED_ELSE(check.result) { FAIL(check.result_message); }
} // TEST_CASE
#endif
} // namespace qmcplusplus
12 changes: 8 additions & 4 deletions src/Utilities/for_testing/checkMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@

namespace qmcplusplus
{
template bool approxEquality<float>(float val_a, float val_b);
template bool approxEquality<std::complex<float>>(std::complex<float> val_a, std::complex<float> val_b);
template bool approxEquality<double>(double val_a, double val_b);
template bool approxEquality<std::complex<double>>(std::complex<double> val_a, std::complex<double> val_b);
template bool approxEquality<float>(float val_a, float val_b, std::optional<double> eps);
template bool approxEquality<std::complex<float>>(std::complex<float> val_a,
std::complex<float> val_b,
std::optional<double> eps);
template bool approxEquality<double>(double val_a, double val_b, std::optional<double> eps);
template bool approxEquality<std::complex<double>>(std::complex<double> val_a,
std::complex<double> val_b,
std::optional<double> eps);
} // namespace qmcplusplus
35 changes: 25 additions & 10 deletions src/Utilities/for_testing/checkMatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,27 @@
#include <string>
#include <complex>
#include <type_traits>
#include <optional>
#include "type_traits/complex_help.hpp"
namespace qmcplusplus
{

template<typename T, IsComplex<T> = true>
bool approxEquality(T val_a, T val_b)
bool approxEquality(T val_a, T val_b, std::optional<double> eps)
{
return val_a == ComplexApprox(val_b);
if (eps)
return val_a == ComplexApprox(val_b).epsilon(eps.value());
else
return val_a == ComplexApprox(val_b);
}

template<typename T, IsReal<T> = true>
bool approxEquality(T val_a, T val_b)
bool approxEquality(T val_a, T val_b, std::optional<double> eps)
{
return val_a == Approx(val_b);
if (eps)
return val_a == Approx(val_b).epsilon(eps.value());
else
return val_a == Approx(val_b);
}

struct CheckMatrixResult
Expand All @@ -50,9 +57,13 @@ struct CheckMatrixResult
* left block of b_mat.
* \param[in] b_mat - the matrix to check
* \param[in] check_all - if true continue to check matrix elements after failure
* \param[in] eps - add a tolerance for Catch Approx checks. Default to same as in Approx.
*/
template<class M1, class M2>
CheckMatrixResult checkMatrix(M1& a_mat, M2& b_mat, const bool check_all = false)
CheckMatrixResult checkMatrix(M1& a_mat,
M2& b_mat,
const bool check_all = false,
std::optional<const double> eps = std::nullopt)
{
// This allows use to check a padded b matrix with a nonpadded a
if (a_mat.rows() > b_mat.rows() || a_mat.cols() > b_mat.cols())
Expand All @@ -66,7 +77,7 @@ CheckMatrixResult checkMatrix(M1& a_mat, M2& b_mat, const bool check_all = false
for (int i = 0; i < a_mat.rows(); i++)
for (int j = 0; j < a_mat.cols(); j++)
{
bool approx_equality = approxEquality<typename M1::value_type>(a_mat(i, j), b_mat(i, j));
bool approx_equality = approxEquality<typename M1::value_type>(a_mat(i, j), b_mat(i, j), eps);
if (!approx_equality)
{
matrixElementError(i, j, a_mat, b_mat);
Expand All @@ -78,9 +89,13 @@ CheckMatrixResult checkMatrix(M1& a_mat, M2& b_mat, const bool check_all = false
return {all_elements_match, error_msg.str()};
}

extern template bool approxEquality<float>(float val_a, float val_b);
extern template bool approxEquality<std::complex<float>>(std::complex<float> val_a, std::complex<float> val_b);
extern template bool approxEquality<double>(double val_a, double val_b);
extern template bool approxEquality<std::complex<double>>(std::complex<double> val_a, std::complex<double> val_b);
extern template bool approxEquality<float>(float val_a, float val_b, std::optional<double> eps);
extern template bool approxEquality<std::complex<float>>(std::complex<float> val_a,
std::complex<float> val_b,
std::optional<double> eps);
extern template bool approxEquality<double>(double val_a, double val_b, std::optional<double> eps);
extern template bool approxEquality<std::complex<double>>(std::complex<double> val_a,
std::complex<double> val_b,
std::optional<double> eps);
} // namespace qmcplusplus
#endif
7 changes: 7 additions & 0 deletions src/Utilities/tests/for_testing/test_checkMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ TEST_CASE("checkMatrix_OhmmsMatrix_real", "[utilities][for_testing]")
// This would be how you would fail and print the information about what element failed.
CHECKED_ELSE(check_matrix_result.result) { FAIL(check_matrix_result.result_message); }

//check with epsilon
b_mat(0,2) = 2.6005;
check_matrix_result = checkMatrix(a_mat, b_mat, false, 1e-4);
REQUIRE(check_matrix_result.result == false);
check_matrix_result = checkMatrix(a_mat, b_mat, false, 1e-3);
REQUIRE(check_matrix_result.result == true);

b_mat.resize(4, 4);
b_mat(0, 0) = 2.3;
b_mat(0, 1) = 4.5;
Expand Down