# Wave Function Optimization

We present here a complete example of how to use QMCTorch on an H2 molecule.
We first need to import all the relevant modules:

In [None]:
from torch import optim
from qmctorch.scf import Molecule
from qmctorch.wavefunction import SlaterJastrow
from qmctorch.solver import Solver
from qmctorch.sampler import Metropolis
from qmctorch.utils import set_torch_double_precision
from qmctorch.utils.plot_data import plot_energy
set_torch_double_precision()

## Creating the system

The first step is to define a molecule. Here we use an H2 molecule with both hydrogen atoms
on the z-axis, separated by 1.38 atomic units. We choose Slater orbitals, which can be obtained via `ADF`, and simply reload a previous calculation to create the molecule:

In [None]:
mol = Molecule(load='./hdf5/H2_adf_dzp.hdf5')

We then define the wave function for this molecule. We also specify here
the determinants we want to use in the CI expansion: we include all the single
and double excitations with 2 electrons and 2 orbitals.

In [None]:
wf = SlaterJastrow(mol, configs='single_double(2,2)')

As a sampler we use a simple Metropolis-Hastings algorithm with 5000 walkers. The walkers are initially localized around the atoms.
Each walker performs 200 steps of size 0.2 atomic units, and only the last position of each walker is kept (`ntherm=-1`).
During each move all the electrons are moved simultaneously, with displacements drawn from a normal distribution centered on their current location.

In [None]:
sampler = Metropolis(nwalkers=5000,
                     nstep=200, step_size=0.2,
                     ntherm=-1, ndecor=100,
                     nelec=wf.nelec, init=mol.domain('atomic'),
                     move={'type': 'all-elec', 'proba': 'normal'})
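To make the sampling step concrete, here is a minimal pure-Python sketch of a Metropolis walk with all-electron normal moves. It is not the QMCTorch implementation: the toy density (a product of 1s orbitals `exp(-r)`), the helper names `log_density` and `metropolis`, and the reduced walker count are all assumptions chosen for illustration.

```python
import math
import random


def log_density(walker):
    """log|psi|^2 for a toy wave function: a product of 1s orbitals exp(-r)."""
    return -2.0 * sum(math.sqrt(x * x + y * y + z * z) for x, y, z in walker)


def metropolis(nwalkers, nstep, step_size, nelec, seed=0):
    """Return the final walker positions and the overall acceptance rate."""
    rng = random.Random(seed)
    # Start every electron from a unit Gaussian around the origin.
    walkers = [[[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(nelec)]
               for _ in range(nwalkers)]
    logp = [log_density(w) for w in walkers]
    accepted = 0
    for _ in range(nstep):
        for i, walker in enumerate(walkers):
            # All-electron move: displace every electron at once with normal noise.
            trial = [[c + rng.gauss(0.0, step_size) for c in elec]
                     for elec in walker]
            logp_trial = log_density(trial)
            # Accept with probability min(1, |psi(trial)|^2 / |psi(current)|^2).
            if math.log(1.0 - rng.random()) < logp_trial - logp[i]:
                walkers[i], logp[i] = trial, logp_trial
                accepted += 1
    return walkers, accepted / (nstep * nwalkers)


# Only the last position of each walker is kept, as with ntherm=-1.
walkers, acceptance = metropolis(nwalkers=200, nstep=200, step_size=0.2, nelec=2)
mean_r = sum(math.sqrt(x * x + y * y + z * z)
             for w in walkers for x, y, z in w) / (200 * 2)
print(f"acceptance rate: {acceptance:.2f}, mean electron radius: {mean_r:.2f}")
```

For this toy density the exact mean radius is 1.5 atomic units, so the printed value gives a rough check that the walkers sample the intended distribution.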

We will use the Adam optimizer implemented in PyTorch, with a custom learning rate for each layer.

In [None]:
lr_dict = [{'params': wf.jastrow.parameters(), 'lr': 1E-2},
           {'params': wf.ao.parameters(), 'lr': 1E-6},
           {'params': wf.mo.parameters(), 'lr': 2E-3},
           {'params': wf.fc.parameters(), 'lr': 2E-3}]
opt = optim.Adam(lr_dict, lr=1E-3)


A scheduler can also be used to progressively decrease the learning rate during the optimization. Here we define a step scheduler that multiplies every learning rate by 0.90 every 100 steps.

In [None]:
scheduler = optim.lr_scheduler.StepLR(opt, step_size=100, gamma=0.90)
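The decay produced by a step scheduler can be written in closed form: the learning rate is multiplied by `gamma` once every `step_size` scheduler steps. The helper `step_lr` below is a hypothetical stand-alone function, not part of PyTorch or QMCTorch, that reproduces this rule for a single parameter group.

```python
def step_lr(base_lr, epoch, step_size=100, gamma=0.90):
    """Learning rate after `epoch` scheduler steps under a step decay rule."""
    return base_lr * gamma ** (epoch // step_size)


# Learning rate of the Jastrow parameter group (base value 1E-2) at a few epochs.
for epoch in (0, 99, 100, 250):
    print(epoch, step_lr(1e-2, epoch))
```

With `step_size=100`, the first decay happens at step 100, so a run shorter than that keeps every learning rate at its initial value.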

We can now assemble the solver

In [None]:
solver = Solver(wf=wf, sampler=sampler, optimizer=opt, scheduler=scheduler)

## Configuring the solver

Many parameters of the optimization can be controlled. We can specify which observables to track during the optimization. Here both the local energies and the variational parameters are recorded.

In [None]:
solver.configure(track=['local_energy', 'parameters'])

Some variational parameters can be frozen and therefore not optimized. Here we freeze the AO parameters,
so only the Jastrow parameters, the MO coefficients and the CI coefficients will be optimized.

In [None]:
solver.configure(freeze=['ao'])

Either the mean or the variance of the local energies can be used as a loss function. Here we minimize the energy to optimize the wave function.

In [None]:
solver.configure(loss='energy')
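The two candidate loss functions can be compared on a toy problem. The sketch below is not QMCTorch code: it assumes a 1D harmonic oscillator with trial wave function `exp(-a x^2)`, for which the local energy is known analytically, and the names `local_energy` and `losses` are hypothetical.

```python
import math
import random


def local_energy(x, a):
    """E_L(x) for psi = exp(-a x^2) and H = -1/2 d^2/dx^2 + x^2/2."""
    return a + x * x * (0.5 - 2.0 * a * a)


def losses(a, nsample=50_000, seed=0):
    """Return the mean and the variance of the local energy for parameter a."""
    rng = random.Random(seed)
    # |psi|^2 = exp(-2 a x^2) is a Gaussian with variance 1 / (4 a).
    sigma = math.sqrt(1.0 / (4.0 * a))
    eloc = [local_energy(rng.gauss(0.0, sigma), a) for _ in range(nsample)]
    energy = sum(eloc) / nsample
    variance = sum((e - energy) ** 2 for e in eloc) / nsample
    return energy, variance


e_exact, var_exact = losses(a=0.5)  # exact ground state
e_trial, var_trial = losses(a=1.0)  # a poorer trial wave function
print(f"a=0.5: energy={e_exact:.4f}, variance={var_exact:.4f}")
print(f"a=1.0: energy={e_trial:.4f}, variance={var_trial:.4f}")
```

At the exact ground state (`a=0.5`) the local energy is constant, so the variance vanishes while the energy reaches its minimum. This is why either quantity can serve as a loss.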

The gradients of the wave function w.r.t. the variational parameters can be computed directly via automatic differentiation (`grad='auto'`) or manually (`grad='manual'`) via a reduced noise formula. We pick here a manual calculation.

In [None]:
solver.configure(grad='manual')
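The reduced noise formula is not spelled out here. A common low-variance estimator in variational Monte Carlo, assumed to be the kind of expression meant, is `dE/dp = 2 (<E_L O> - <E_L><O>)` with `O = d ln(psi) / dp`. The sketch below checks this estimator on a toy 1D harmonic oscillator with trial wave function `exp(-a x^2)`; the function `energy_gradient` is hypothetical and independent of QMCTorch.

```python
import math
import random


def energy_gradient(a, nsample=100_000, seed=0):
    """Estimate dE/da as 2 (<E_L O> - <E_L><O>) with O = d ln(psi) / da."""
    rng = random.Random(seed)
    # Sample x from |psi|^2 = exp(-2 a x^2), a Gaussian with variance 1 / (4 a).
    sigma = math.sqrt(1.0 / (4.0 * a))
    xs = [rng.gauss(0.0, sigma) for _ in range(nsample)]
    eloc = [a + x * x * (0.5 - 2.0 * a * a) for x in xs]  # local energy
    dlnpsi = [-x * x for x in xs]                         # d ln(psi) / da
    mean_e = sum(eloc) / nsample
    mean_o = sum(dlnpsi) / nsample
    mean_eo = sum(e * o for e, o in zip(eloc, dlnpsi)) / nsample
    return 2.0 * (mean_eo - mean_e * mean_o)


a = 1.0
grad_mc = energy_gradient(a)
# Analytic reference: E(a) = a/2 + 1/(8a), hence dE/da = 1/2 - 1/(8 a^2).
grad_exact = 0.5 - 1.0 / (8.0 * a * a)
print(f"Monte Carlo: {grad_mc:.4f}, analytic: {grad_exact:.4f}")
```

Because the estimator is a covariance between the local energy and the log-derivative, its statistical noise shrinks as the local energy becomes constant, that is, as the wave function approaches an eigenstate.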

We also configure the resampling so that the positions of the walkers are updated by performing
25 MC steps from their previous positions after each optimization step.

In [None]:
solver.configure(resampling={'mode': 'update',
                            'resample_every': 1,
                            'nstep_update': 25})

## Running the wave function optimization

We can now run the optimization. Here we use 50 optimization steps (epochs), with all the points
in a single mini-batch.

In [None]:
obs = solver.run(50)

In [None]:
plot_energy(obs.local_energy, e0=-1.1645, show_variance=True)