Skip to content

Commit

Permalink
Merge pull request #364 from nirvaank/test_test_value_many
Browse files Browse the repository at this point in the history
Test test value many
  • Loading branch information
lkwagner committed Nov 21, 2022
2 parents dd15a68 + 6b364e8 commit dc36233
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
25 changes: 25 additions & 0 deletions pyqmc/testwf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import time
import numpy as np
import pyqmc.mc as mc


def test_mask(wf, e, epos, mask=None, tolerance=1e-6):
Expand All @@ -14,6 +15,30 @@ def test_mask(wf, e, epos, mask=None, tolerance=1e-6):
assert np.all(error < tolerance)
print("testcase for test_value() with mask passed")

def test_testvalue_many(wf,configs,tol=1e-6):
"""
:parameter wf: a wave function object to be tested
:parameter configs: electron positions
:type configs: (nconf, nelec, 3) array
:returns: max abs errors
:rtype: dictionary
"""
nconf, ne, ndim = configs.configs.shape
val1 = wf.recompute(configs)
wfcopy = copy.copy(wf)

delta=1e-2
tval = np.zeros((nconf,ne))
epos = configs.make_irreducible(0, configs.configs[:, 0, :] + delta)
for e in range(ne):
tval[:,e], savedvals = wf.testvalue(e, epos)

e_all = np.arange(ne)

tmany= wfcopy.testvalue_many(e_all,epos)
terr=tmany-tval
assert np.max(np.abs(terr))<tol

def test_updateinternals(wf, configs):
"""
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_wf_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def run_tests(wf, epos, epsilon):
assert item < epsilon

testwf.test_mask(wf, 0, epos)
#testwf.test_testvalue_many(wf,epos)

for fname, func in zip(
["gradient", "laplacian", "pgradient"],
Expand Down

0 comments on commit dc36233

Please sign in to comment.