Skip to content

Commit

Permalink
Merge pull request #4341 from jtkrogel/nx_gpu_flags
Browse files Browse the repository at this point in the history
Nexus: add CPU/GPU flags for batched code
  • Loading branch information
ye-luo committed Nov 23, 2022
2 parents 8cc1aa2 + 83f6887 commit 7d401d9
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions nexus/lib/qmcpack_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -1940,9 +1940,9 @@ class radfunc(QIxml):
#end class radfunc

class slaterdeterminant(QIxml):
attributes = ['optimize','delay_rank']
attributes = ['optimize','delay_rank','gpu','matrix_inverter']
elements = ['determinant']
write_types = obj(optimize=yesno)
write_types = obj(optimize=yesno,gpu=yesno)
#end class slaterdeterminant

class determinant(QIxml):
Expand Down Expand Up @@ -4596,7 +4596,8 @@ def generate_bspline_builder(type = 'bspline',
spo_up = 'spo_u',
spo_down = 'spo_d',
sposets = None,
system = None
system = None,
orbitals_cpu = None,
):
tilematrix = identity(3,dtype=int)
if system!=None:
Expand All @@ -4620,7 +4621,7 @@ def generate_bspline_builder(type = 'bspline',
spindatasets = True
)
)
if sort!=None:
if sort is not None:
bsb.sort = sort
#end if
if truncate and buffer!=None:
Expand All @@ -4629,15 +4630,18 @@ def generate_bspline_builder(type = 'bspline',
if hybridrep is not None:
bsb.hybridrep = hybridrep
#end if
if twist!=None:
if twist is not None:
bsb.twistnum = system.structure.select_twist(twist)
elif twistnum!=None:
elif twistnum is not None:
bsb.twistnum = twistnum
elif len(system.structure.kpoints)==1:
bsb.twistnum = 0
else:
bsb.twistnum = None
#end if
if orbitals_cpu is not None and orbitals_cpu:
bsb.gpu = False
#end if
return bsb
#end def generate_bspline_builder

Expand Down Expand Up @@ -4744,6 +4748,7 @@ def generate_determinantset(up = 'u',
spo_down = 'spo_d',
spin_polarized = False,
delay_rank = None,
matrix_inv_cpu = None,
system = None
):
if system is None:
Expand Down Expand Up @@ -4780,6 +4785,9 @@ def generate_determinantset(up = 'u',
if delay_rank is not None:
dset.slaterdeterminant.delay_rank = delay_rank
#end if
if matrix_inv_cpu is not None and matrix_inv_cpu:
dset.slaterdeterminant.matrix_inverter = 'host'
#end if
return dset
#end def generate_determinantset

Expand Down Expand Up @@ -7192,6 +7200,8 @@ def generate_qmcpack_input(**kwargs):
J1_rcut_open = 5.0,
J2_rcut_open = 10.0,
driver = 'legacy', # legacy,batched
orbitals_cpu = None, # place/evaluate orbitals on cpu if on gpu
matrix_inv_cpu = None, # evaluate matrix inverse on cpu if on gpu
qmc = None, # opt,vmc,vmc_test,dmc,dmc_test
)

Expand Down Expand Up @@ -7343,6 +7353,7 @@ def generate_basic_input(**kwargs):
href = kw.orbitals_h5,
spin_polarized = kw.spin_polarized,
system = kw.system,
orbitals_cpu = kw.orbitals_cpu,
)
#end if
if kw.partition is None:
Expand All @@ -7358,6 +7369,7 @@ def generate_basic_input(**kwargs):
dset = generate_determinantset(
spin_polarized = kw.spin_polarized,
delay_rank = kw.delay_rank,
matrix_inv_cpu = kw.matrix_inv_cpu,
system = kw.system,
)
elif kw.det_format=='old':
Expand Down

0 comments on commit 7d401d9

Please sign in to comment.