In [None]:
import deepqmc
import haiku as hk
import jax

In [None]:
from deepqmc.molecule import Molecule

# Lithium hydride (LiH): nucleus 0 has Z=3 (Li) at the origin, nucleus 1 has
# Z=1 (H) displaced 3.015 bohr along x. Total molecular charge 0, spin=0.
mol = Molecule(
    coords=[
        [0.0, 0.0, 0.0],  # Li
        [3.015, 0.0, 0.0],  # H
    ],
    charges=[3, 1],
    charge=0,
    spin=0,
    unit='bohr',
)

In [None]:
from deepqmc.hamil import MolecularHamiltonian

# Molecular Hamiltonian for the LiH system above. The ansatz code below reads
# `H.mol`, `H.n_up` and `H.n_down` from it.
H = MolecularHamiltonian(mol=mol)

In [None]:
import os

import haiku as hk
from hydra import compose, initialize_config_dir
from hydra.utils import instantiate

import deepqmc
from deepqmc.app import instantiate_ansatz


# Reference ansatz: load the PsiFormer Hydra config that ships inside the
# installed deepqmc package (deepqmc/conf/ansatz/psiformer.yaml).
deepqmc_dir = os.path.dirname(deepqmc.__file__)
config_dir = os.path.join(deepqmc_dir, 'conf/ansatz')
with initialize_config_dir(config_dir=config_dir, version_base=None):
    cfg = compose(config_name='psiformer')

# Hydra: `_recursive_=True` also instantiates nested targets; `_convert_='all'`
# hands plain Python dicts/lists (not OmegaConf containers) to the callables.
_ansatz = instantiate(cfg, _convert_='all', _recursive_=True)
psiformer_ansatz = instantiate_ansatz(H, _ansatz)

In [None]:
from deepqmc.wf.nn_wave_function import eval_log_slater, Psi
from deepqmc.wf.env import ExponentialEnvelopes
from deepqmc.physics import pairwise_diffs
from deepqmc.app import instantiate_ansatz
from deepqmc.types import PhysicalConfiguration
import jax.numpy as jnp

class MyWF(hk.Module):
    """Toy ansatz: exponential envelopes scaled by a learned linear layer.

    The envelope orbitals are multiplied element-wise by the output of one
    dense layer applied to simple per-electron features (electron-nucleus
    differences and a +/-1 spin label). The result is passed, without
    splitting by spin, to ``eval_log_slater``.
    """

    def __init__(
        self,
        hamil,
    ):
        """Store system information and build the envelope module.

        Args:
            hamil: deepqmc Hamiltonian; ``mol``, ``n_up`` and ``n_down`` are
                read from it.
        """
        super().__init__()
        self.mol = hamil.mol
        self.n_up, self.n_down = hamil.n_up, hamil.n_down
        self.charges = hamil.mol.charges
        self.env = ExponentialEnvelopes(
            hamil,
            1,  # NOTE(review): presumably the number of determinants — confirm
            isotropic=False,
            per_shell=False,
            per_orbital_exponent=False,
            spin_restricted=False,
            init_to_ones=False,
            softplus_zeta=False,
        )

    @property
    def spin_slices(self):
        """(up, down) slices along the electron axis; unused by ``__call__``."""
        up_part = slice(None, self.n_up)
        down_part = slice(self.n_up, None)
        return up_part, down_part

    def __call__(self, phys_conf: PhysicalConfiguration, _):
        """Evaluate the wave function for one electron configuration.

        Args:
            phys_conf: configuration with electron positions ``r`` and
                nuclear positions ``R``.
            _: extra positional argument supplied by the caller; ignored.

        Returns:
            ``Psi(sign, log|psi|)``; the sign has its gradient stopped.
        """
        n_elec = self.n_up + self.n_down
        # NOTE: keep the envelope call before the linear layer — haiku creates
        # (and initialises) parameters in call order.
        orbitals = self.env(phys_conf, None)

        # Per-electron features: flattened electron-nucleus differences plus a
        # spin label of +1 for the first n_up electrons and -1 for the rest.
        diffs = pairwise_diffs(phys_conf.r, phys_conf.R).reshape(n_elec, -1)
        spin_labels = jnp.concatenate((jnp.ones(self.n_up), -jnp.ones(self.n_down)))
        features = jnp.concatenate((diffs, spin_labels[..., None]), axis=-1)

        # A single dense layer (not a transformer) giving an (n_elec, n_elec)
        # factor, broadcast over the leading axis of the envelope orbitals
        # (assumed to be the determinant axis — TODO confirm).
        modulation = hk.Linear(n_elec)(features)
        orbitals = orbitals * modulation[None]

        sign, log_abs = eval_log_slater(orbitals)
        return Psi(jax.lax.stop_gradient(sign).squeeze(), log_abs.squeeze())


In [None]:
# Wrap the custom MyWF module into a deepqmc ansatz for the LiH Hamiltonian.
# NOTE(review): `my_ansatz` is never used afterwards — the training cell trains
# `psiformer_ansatz`. Pass `my_ansatz` there to train the custom wave function.
my_ansatz = instantiate_ansatz(H, MyWF)

In [None]:
from deepqmc.sampling import initialize_sampling, MetropolisSampler, DecorrSampler, combine_samplers
from functools import partial

# Electron sampler: chain a DecorrSampler(length=20) with a Metropolis sampler.
# Only `samplers` is bound here; `combine_samplers` receives its remaining
# arguments later, from whoever calls the partial.
elec_sampler = partial(
    combine_samplers,
    samplers=[
        DecorrSampler(length=20),
        partial(MetropolisSampler),
    ],
)
sampler_factory = partial(
    initialize_sampling,
    elec_sampler=elec_sampler,
)

In [None]:
import os

from hydra import compose, initialize_config_dir
from hydra.utils import instantiate

# Optimizer: load the KFAC Hydra config bundled with deepqmc
# (deepqmc/conf/task/opt/kfac.yaml) and instantiate it.
deepqmc_dir = os.path.dirname(deepqmc.__file__)
config_dir = os.path.join(deepqmc_dir, 'conf/task/opt')
with initialize_config_dir(config_dir=config_dir, version_base=None):
    cfg = compose(config_name='kfac')

kfac = instantiate(cfg, _convert_='all', _recursive_=True)

In [None]:
from deepqmc.train import train
# Optimize the PsiFormer ansatz with KFAC; training output goes to workdir 'tmp2'.
# NOTE(review): the custom `my_ansatz` is not trained here — swap it in if intended.
train(
    H,
    psiformer_ansatz,
    kfac,
    sampler_factory,
    steps=1000,
    electron_batch_size=2000,
    seed=42,
    workdir='tmp2',
)

In [None]:
import h5py
# Read the per-iteration local-energy statistics of the run trained above.
# The training cell uses workdir='tmp2'; the path must point into that same
# directory. Reading 'tmp/...' would load a different (stale or missing) run.
RESULT_PATH = 'tmp2/training/result.h5'
with h5py.File(RESULT_PATH, 'r', swmr=True) as result_file:
    energy = result_file['local_energy']['mean'][:]

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Training curve: raw per-iteration mean local energy plus an exponentially
# weighted moving average (half-life of 5 iterations) to expose the trend.
# Assumes `energy` is indexed [iteration, 0, 0] as in the original cell — TODO
# confirm the array layout stored in result.h5.
energy_raw = energy[:, 0, 0]
energy_smooth = pd.Series(energy_raw).ewm(halflife=5).mean()

fig, ax = plt.subplots()
ax.plot(energy_raw, label='raw', alpha=0.5)
ax.plot(energy_smooth, label='EWM (halflife=5)')
ax.set_xlabel('Training iteration')
ax.set_ylabel('Energy')
ax.set_title('LiH training energy')
ax.legend()
plt.show()