In [1]:
from hydra import compose, initialize
from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from grad_sample.tasks.fullsum_train import Trainer
import jax.numpy as jnp
from netket.vqs import FullSumState
import netket as nk
import flax
import optax
from netket_checkpoint._src.serializers.metropolis import serialize_MetropolisSamplerState, deserialize_MetropolisSamplerState
from advanced_drivers.driver import overdispersed_distribution, VMC_NG, statistics

# GlobalHydra is a process-wide singleton: clear any previous initialization so
# this cell can be re-executed without Hydra refusing to initialize twice.
if GlobalHydra().is_initialized():
    GlobalHydra().clear()

# Compose the `qchem` configuration from the grad_sample conf directory and
# switch it to struct mode (adding keys that are not in the config raises).
# Only `cfg` is needed downstream; no Trainer is built here.
with initialize(config_path="./grad_sample/conf", version_base=None):
    cfg = compose(config_name="qchem")
    OmegaConf.set_struct(cfg, True)
    print(cfg.task)

{'_target_': 'grad_sample.tasks.fullsum_train.Trainer'}


In [2]:
# Reload edited Python modules (e.g. grad_sample, advanced_drivers) automatically
# before executing each cell, so library edits are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [3]:
# Root of the older `qchem_comp` experiment tree. Its MC_12_isauto run lives at
# base_dir + 'MC_12_isauto/schedule_schedule/' (sampler state file:
# 'run_0/sampler_state_default.mpack'); point `run_dir` / `ckpt_sampler` there
# to analyse that run instead.
base_dir = '/mnt/beegfs/home/CPHT/antoine.misery/workdir_link/qchem_comp/166630_0/L30/NNBF/256/'

# Run analysed in this notebook: the `qchem_comp_new` MC_13_0.8 run, using the
# sampler-state file saved under the name "overdispersed".
# TODO: these absolute cluster paths only resolve on the beegfs filesystem;
# consider a configurable root directory.
run_dir = '/mnt/beegfs/home/CPHT/antoine.misery/workdir_link/qchem_comp_new/166630_0/L30/NNBF/256/MC_13_0.8/schedule_schedule/'
ckpt = run_dir + 'run_0.mpack'                                      # saved variational variables
ckpt_sampler = run_dir + 'run_0/sampler_state_overdispersed.mpack'  # saved sampler state

In [4]:
# Physical model from the Hydra config; it provides the Hamiltonian, Hilbert
# space and graph used below. (Instantiating it prints the SCF / CISD / CCSD
# reference energies shown in the cell output.)
model = instantiate(cfg.model)

# Variational ansatz defined on the model's Hilbert space.
ansatz = instantiate(cfg.ansatz, hilbert=model.hilbert_space)

# Optional chunk size for the variational state; None when the key is absent.
chunk_size = cfg.get('chunk_size_vstate')

# Plain SGD with unit learning rate handed to the driver.
opt = optax.sgd(1)

converged SCF energy = -87.7957664369586
E(RCISD) = -87.8839403630959  E_corr = -0.08817392613732344
E(CCSD) = -87.88571506849657  E_corr = -0.08994863153799405
converged SCF energy = -87.7957664369586




In [20]:
# Number of Monte Carlo samples drawn per sampling call.
n_s = 2**13

# Sampler from the Hydra config: sweep_size is the number of Hilbert-space sites,
# and n_s // 2 chains means 2 samples per chain (on a single rank).
sampler = instantiate(
    cfg.sampler,
    hilbert=model.hilbert_space,
    graph=model.graph,
    sweep_size=model.hilbert_space.size,
    n_chains_per_rank=n_s // 2,
)

# TODO: pass seed=... (e.g. seed=0) to make the sampling reproducible across
# kernel restarts.
vstate = nk.vqs.MCState(
    sampler=sampler,
    model=ansatz,
    chunk_size=chunk_size,
    n_samples=n_s,
    n_discard_per_chain=2**6,
)

# Restore the trained variational variables. variables_from_file reads `ckpt`
# by path itself (no need to open the file here) and uses the current
# variables as the structure template. Assigning directly also avoids
# shadowing the builtin `vars`.
vstate.variables = nk.experimental.vqs.variables_from_file(ckpt, vstate.variables)

# Raw msgpack-decoded sampler state saved alongside the checkpoint.
# NOTE(review): `state_dict` is loaded but NOT applied to `vstate`, so the
# sampler presumably starts from fresh chains. A restore along the lines of
#     vstate.sampler_states['overdispersed'] = deserialize_MetropolisSamplerState(vstate.sampler_state, state_dict)
# was sketched here previously; confirm the intended target before enabling it.
with open(ckpt_sampler, 'rb') as f:
    state_dict = flax.serialization.msgpack_restore(f.read())

In [21]:
# Overdispersed sampling distribution with parameter 0.8 (presumably the same
# 0.8 as in the MC_13_0.8 run directory; confirm).
sampling_distribution = overdispersed_distribution(0.8)

# Driver (presumably natural-gradient VMC, per its name) over the loaded
# state, using the JAX form of the model Hamiltonian and diagonal shift 1.
driver = VMC_NG(
    variational_state=vstate,
    hamiltonian=model.hamiltonian.to_jax_operator(),
    sampling_distribution=sampling_distribution,
    optimizer=opt,
    diag_shift=1,
)

In [22]:
# Re-estimate with a larger sample: 2**15 samples (the state was built with
# n_samples = 2**13).
driver.state.n_samples = 2**15
# NOTE(review): writes a private attribute; presumably drops cached samples for
# the 'overdispersed' distribution so the next call resamples. Confirm against
# advanced_drivers.
driver.state._samples_distributions['overdispersed']=None
# Local estimators from the driver. `loss` and `w` feed statistics() later
# (presumably local energies and per-sample weights; TODO confirm). `grad` is
# not used in the visible cells.
grad, loss, w = driver.local_estimators()

In [14]:
# Inspect the per-distribution Metropolis sampler states (accepted-move counts).
# NOTE(review): this ran as In[14], before the In[20]-In[24] cells above, so the
# recorded output (~0.7% acceptance for 'default', ~22% for 'overdispersed')
# may not reflect the current driver. Re-run to refresh.
driver.state.sampler_states

{'default': MetropolisSamplerState(# accepted = 58937/8847360 (0.6661535192418981%), rng state=[1133021160 3047394276]),
 'overdispersed': MetropolisSamplerState(# accepted = 1967629/8847360 (22.2397302698206%), rng state=[1133021160 3047394276])}

In [23]:
# Summarize the local estimators. The fields read in later cells are mean,
# error_of_mean, variance and tau_corr_max. `w` presumably weights each sample
# (overdispersed reweighting); confirm against advanced_drivers.statistics.
stats = statistics(loss, w)

In [24]:
# One-error-bar band around the real part of the mean estimate, printed on one
# line: upper bound (+1) first, then lower bound (-1).
# NOTE(review): the recorded output also contains a stray "nan" line that this
# print does not produce; it is likely stale output.
print(*(stats.mean.real + sign * stats.error_of_mean for sign in (1, -1)))

-87.89265362925502 -87.89265559142254
nan


In [11]:
# Product of the sample variance and the maximal correlation time.
# NOTE(review): this has the units of a variance, not of a standard error. If an
# error bar was intended, something like sqrt(variance * tau / n_samples) may be
# what is wanted; confirm the convention used by `statistics`.
err = stats.variance*stats.tau_corr_max

In [12]:
# Display `err`.
# NOTE(review): the recorded output comes from execution In[12], i.e. before
# `stats` was recomputed in In[23]. Re-run to refresh.
err

Array(1.46460856e-06, dtype=float64)