Skip to content

Commit

Permalink
change geoopt api
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoRenaud committed Aug 31, 2020
1 parent 0d63d56 commit 940ee7c
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 55 deletions.
94 changes: 51 additions & 43 deletions qmctorch/solver/geometry_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from .. import log


class GeoSolver(SolverOrbital):
class GeoSolver():

def __init__(self, wf=None, sampler=None, optimizer=None, scheduler=None, output=None, rank=0):
def __init__(self, solver, opt_geo=None):

SolverOrbital.__init__(self, wf, sampler,
optimizer, scheduler, output, rank)
self.solver = solver
self.opt_geo = opt_geo

def run(self, nepoch, geo_lr=1e-2, batchsize=None,
nepoch_wf_init=100, nepoch_wf_update=50,
Expand All @@ -29,86 +29,94 @@ def run(self, nepoch, geo_lr=1e-2, batchsize=None,
Defaults to half the number of epoch
"""

if not hasattr(self.observable, 'geometry'):
self.observable.geometry = []
if not hasattr(self.solver.observable, 'geometry'):
self.solver.observable.geometry = []

# save the optimizer used for the wf params
opt_wf = deepcopy(self.opt)
opt_wf = deepcopy(self.solver.opt)

# create the optmizier for the geo opt
opt_geo = optim.SGD(self.wf.parameters(), lr=geo_lr)
if self.opt_geo is None:
opt_geo = optim.SGD(
self.solver.wf.parameters(), lr=geo_lr)
else:
opt_geo = self.opt_geo

# save the grad method
eval_grad_wf = self.evaluate_gradient
eval_grad_wf = self.solver.evaluate_gradient

# log data
self.prepare_optimization(batchsize, None)
self.solver.prepare_optimization(batchsize, None)
self.log_data_geo(nepoch)

# init the traj
xyz = [self.wf.geometry(None)]
xyz = [self.solver.wf.geometry(None)]

# initial wf optimization
self.set_params_requires_grad(wf_params=True,
geo_params=False)
self.freeze_parameters(self.freeze_params_list)
self.run_epochs(nepoch_wf_init)
self.solver.set_params_requires_grad(wf_params=True,
geo_params=False)
self.solver.freeze_parameters(self.solver.freeze_params_list)
self.solver.run_epochs(nepoch_wf_init)

# iterations over geo optim
for n in range(nepoch):

# make one step geo optim
self.set_params_requires_grad(wf_params=False,
geo_params=True)
self.opt = opt_geo
self.evaluate_gradient = self.evaluate_grad_auto
self.run_epochs(1)
xyz.append(self.wf.geometry(None))
self.solver.set_params_requires_grad(wf_params=False,
geo_params=True)
self.solver.opt = opt_geo
self.solver.evaluate_gradient = self.solver.evaluate_grad_auto
self.solver.run_epochs(1)
xyz.append(self.solver.wf.geometry(None))

# make a few wf optim
self.set_params_requires_grad(wf_params=True,
geo_params=False)
self.freeze_parameters(self.freeze_params_list)
self.opt = opt_wf
self.evaluate_gradient = eval_grad_wf
self.solver.set_params_requires_grad(wf_params=True,
geo_params=False)
self.solver.freeze_parameters(
self.solver.freeze_params_list)
self.solver.opt = opt_wf
self.solver.evaluate_gradient = eval_grad_wf

cumulative_loss = self.run_epochs(nepoch_wf_update)
cumulative_loss = self.solver.run_epochs(nepoch_wf_update)

# save checkpoint file
if chkpt_every is not None:
if (n > 0) and (n % chkpt_every == 0):
self.save_checkpoint(n, cumulative_loss)
self.solver.save_checkpoint(n, cumulative_loss)

# dump
self.observable.geometry = xyz
self.observable.save(hdf5_group or 'geo_opt', self.hdf5file)
self.solver.observable.geometry = xyz
self.solver.observable.save(
hdf5_group or 'geo_opt', self.solver.hdf5file)

# save traj
filename = self.wf.mol.name + '_go_traj.xyz'
save_trajectory(filename, self.wf.atoms, xyz)
filename = self.solver.wf.mol.name + '_go_traj.xyz'
save_trajectory(filename, self.solver.wf.atoms, xyz)

return self.observable
return self.solver.observable

def log_data_geo(self, nepoch):
"""Log data for the optimization."""
log.info('')
log.info(' Optimization')
log.info(
' Number Parameters : {0}', self.wf.get_number_parameters())
' Number Parameters : {0}', self.solver.wf.get_number_parameters())
log.info(' Number of epoch : {0}', nepoch)
log.info(
' Batch size : {0}', self.sampler.get_sampling_size())
log.info(' Loss function : {0}', self.loss.method)
log.info(' Clip Loss : {0}', self.loss.clip)
log.info(' Gradients : {0}', self.grad_method)
' Batch size : {0}', self.solver.sampler.get_sampling_size())
log.info(
' Resampling mode : {0}', self.resampler.options.mode)
' Loss function : {0}', self.solver.loss.method)
log.info(' Clip Loss : {0}', self.solver.loss.clip)
log.info(
' Resampling every : {0}', self.resampler.options.resample_every)
' Gradients : {0}', self.solver.grad_method)
log.info(
' Resampling steps : {0}', self.resampler.options.nstep_update)
' Resampling mode : {0}', self.solver.resampler.options.mode)
log.info(
' Output file : {0}', self.hdf5file)
' Resampling every : {0}', self.solver.resampler.options.resample_every)
log.info(
' Checkpoint every : {0}', self.chkpt_every)
' Resampling steps : {0}', self.solver.resampler.options.nstep_update)
log.info(
' Output file : {0}', self.solver.hdf5file)
log.info(
' Checkpoint every : {0}', self.solver.chkpt_every)
log.info('')
20 changes: 10 additions & 10 deletions tests/H2_go_traj.xyz
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,26 @@ H 0.00000 0.00000 0.52918

2

H -0.00032 -0.00074 -0.00083
H 0.00154 -0.00053 0.53143
H 0.00048 0.00145 -0.00371
H 0.00075 0.00099 0.53399

2

H -0.00092 -0.00039 -0.00272
H 0.00178 -0.00043 0.53418
H 0.00040 0.00214 -0.00734
H 0.00364 0.00072 0.53632

2

H -0.00114 -0.00066 -0.00527
H 0.00096 -0.00225 0.53525
H -0.00087 0.00233 -0.00980
H 0.00293 0.00162 0.53949

2

H -0.00120 -0.00023 -0.00814
H 0.00050 -0.00236 0.53750
H -0.00092 0.00257 -0.01196
H -0.00009 0.00184 0.54237

2

H -0.00109 0.00000 -0.01155
H 0.00023 -0.00128 0.54020
H -0.00083 0.00241 -0.01513
H -0.00008 0.00144 0.54557

4 changes: 2 additions & 2 deletions tests/test_h2.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def setUp(self):
optimizer=self.opt)

# geo solver
self.geo = GeoSolver(wf=self.wf, sampler=self.sampler,
optimizer=self.opt)
opt_geo = optim.SGD(self.wf.parameters(), lr=1E-2)
self.geo = GeoSolver(self.solver, opt_geo=opt_geo)

# ground state energy
self.ground_state_energy = -1.16
Expand Down

0 comments on commit 940ee7c

Please sign in to comment.