In [27]:
import jax
import jax.numpy as jnp
from tqdm import tqdm
import netket as nk
import netket.jax as nkjax
# import netket_pro as nkp
import matplotlib.pyplot as plt
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from tqdm import tqdm
import jax.numpy as jnp
import jax
from grad_sample.utils.misc import compute_eloc
from grad_sample.utils.distances import curved_dist, fs_dist, param_overlap
from hydra import compose, initialize
from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from grad_sample.tasks.fullsum_analysis import FullSumPruning
from grad_sample.tasks.fullsum_train import Trainer
from grad_sample.utils.is_distrib import *
from grad_sample.utils.plotting_setup import *
from grad_sample.is_hpsi.expect import *

In [28]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
if GlobalHydra().is_initialized():
    GlobalHydra().clear()
with initialize(version_base=None, config_path="config_xxz/.hydra/"):
    cfg = compose(config_name="config")
    OmegaConf.set_struct(cfg, True)
    print(cfg)
    print(cfg.task)
    # cfg = OmegaConf.to_yaml(cfg)
    # take any task from cfg and run it
# analysis = FullSumPruning(cfg)
trainer = Trainer(cfg)

{'device': '2', 'solver_fn': {'_target_': 'netket.optimizer.solver.cholesky'}, 'lr': 0.005, 'n_iter': 2000, 'chunk_size_jac': 529, 'chunk_size_vmap': 100, 'save_every': 10, 'sample_size': 1000, 'base_path': '/scratch/.amisery/grad_sample/', 'model': {'_target_': 'grad_sample.models.ising.TFI', 'h': 1.0, 'L': 3}, 'diag_shift': 1e-10, 'ansatz': {'_target_': 'netket.models.RBM', 'alpha': 1, 'param_dtype': 'complex'}, 'task': {'_target_': 'grad_sample.tasks.fullsum_train.Trainer'}}
{'_target_': 'grad_sample.tasks.fullsum_train.Trainer'}
{'_target_': 'netket.models.RBM', 'alpha': 1, 'param_dtype': 'complex'}
MC state loaded, num samples 99000
The ground state energy is: -12.019894878851488


In [30]:
from grad_sample.is_hpsi.operator import IS_Operator

In [31]:
is_op = IS_Operator(operator = trainer.model.H_jax)

In [32]:
# no is, calculations done with vstate
trainer.vstate.expect(trainer.model.H_jax)

-8.994-0.000j ± 0.013 [σ²=17.995]

In [33]:
opt = nk.optimizer.Sgd(learning_rate=0.05)
gs_is = nk.VMC(is_op, opt, variational_state=trainer.vstate)

In [34]:
gs_is.state.expect(is_op)

calling IS expect function


-9.061+0.000j ± 0.015 [σ²=17.513]

In [35]:
exp, force_psi = trainer.vstate.expect_and_forces(trainer.model.H_jax);

In [36]:
exp, force_hpsi =  gs_is.state.expect_and_forces(is_op)

In [37]:
vstate_fs = nk.vqs.FullSumState(hilbert=trainer.model.hi, model=trainer.ansatz, chunk_size=trainer.chunk_size, seed=0)

In [38]:
fs_e, force_fs = vstate_fs.expect_and_forces(trainer.model.H_jax)

In [39]:
fs_e

-8.999e+00+8.674e-19j ± 0.000e+00 [σ²=1.799e+01]

In [40]:
jax.tree_util.tree_map(lambda x,y: jnp.abs(x/y), force_fs, force_psi)

{'Dense': {'bias': Array([0.80200407, 0.71492321, 0.77877045, 0.72721456, 1.39375299,
         1.18401654, 0.66050903, 1.37631275, 0.87865644], dtype=float64),
  'kernel': Array([[0.98613843, 1.0004671 , 0.99714517, 0.8663101 , 0.99123658,
          1.00329971, 0.998728  , 1.0099765 , 0.94008266],
         [0.99452459, 1.01849637, 1.0151408 , 1.01237536, 0.94941663,
          0.95519367, 0.96369056, 1.00821689, 0.99276784],
         [0.99461843, 1.00263253, 1.00849126, 1.01056971, 1.03552404,
          1.00415478, 0.96224512, 0.96941046, 1.02204975],
         [0.98817727, 0.98867954, 1.06243627, 0.98409622, 1.01084305,
          0.999002  , 0.99635892, 0.98127486, 0.91137643],
         [1.00083866, 0.97943217, 1.00613207, 1.0183052 , 0.96577039,
          0.97998638, 1.08679391, 0.99392406, 1.0197336 ],
         [1.00167082, 1.01285716, 1.00265761, 0.98487206, 0.98487132,
          1.04312933, 1.01845382, 0.99788264, 1.03529333],
         [0.99622686, 0.99414851, 0.9058706 , 0.98321878

In [41]:
jax.tree_util.tree_map(lambda x,y: jnp.abs(x/y), force_fs, force_hpsi)

{'Dense': {'bias': Array([2.02971367, 0.14530042, 0.55961618, 0.58379928, 0.59781475,
         0.93865656, 0.7519104 , 0.3036536 , 0.19896532], dtype=float64),
  'kernel': Array([[0.99004084, 0.99607277, 0.96353237, 0.66489415, 1.09936149,
          1.09574807, 0.87306474, 1.15670332, 0.93469459],
         [1.0028826 , 1.03298743, 0.98477257, 0.99497506, 1.04961052,
          1.17537479, 1.00034849, 1.03625132, 1.00851914],
         [1.05827409, 1.19441705, 0.95470828, 0.98844495, 0.82317456,
          1.03392214, 1.61073845, 1.10819005, 1.00745805],
         [1.07812843, 1.05350461, 0.94263376, 1.10600695, 0.91602707,
          1.0200935 , 0.96340718, 1.07538043, 1.34716478],
         [1.01462202, 0.9989614 , 0.97581719, 1.0107022 , 0.98828183,
          1.0601734 , 0.91289388, 0.96875192, 1.01114081],
         [0.96336258, 1.05626124, 0.96720576, 1.04344203, 0.96493344,
          1.036487  , 1.06161486, 1.03615628, 0.93728672],
         [1.0339755 , 1.08054995, 1.94438036, 0.94715659