In [1]:
import jax
import jax.numpy as jnp
from tqdm import tqdm
import netket as nk
import netket.jax as nkjax
# import netket_pro as nkp
import matplotlib.pyplot as plt
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from tqdm import tqdm
import jax.numpy as jnp
import jax
from grad_sample.utils.misc import compute_eloc
from grad_sample.utils.distances import curved_dist, fs_dist, param_overlap
from hydra import compose, initialize
from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from grad_sample.tasks.fullsum_analysis import FullSumPruning
from grad_sample.tasks.fullsum_train import Trainer
from grad_sample.utils.is_distrib import *
from grad_sample.utils.plotting_setup import *
from grad_sample.is_hpsi.expect import *
from grad_sample.is_hpsi.qgt import QGTJacobianDenseImportanceSampling
from grad_sample.is_hpsi.operator import IS_Operator

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
if GlobalHydra().is_initialized():
    GlobalHydra().clear()
with initialize(version_base=None, config_path="config_xxz/.hydra/"):
    cfg = compose(config_name="config")
    OmegaConf.set_struct(cfg, True)
    print(cfg)
    print(cfg.task)
    # cfg = OmegaConf.to_yaml(cfg)
    # take any task from cfg and run it
# analysis = FullSumPruning(cfg)
trainer = Trainer(cfg)

{'device': '2', 'solver_fn': {'_target_': 'netket.optimizer.solver.cholesky'}, 'lr': 0.005, 'n_iter': 2000, 'chunk_size_vmap': 100, 'save_every': 10, 'sample_size': 16, 'base_path': '/scratch/.amisery/grad_sample/', 'model': {'_target_': 'grad_sample.models.heisenberg.Heisenberg1d', 'J': 1.0, 'L': 14, 'sign_rule': False, 'acting_on_subspace': 0}, 'diag_shift': 0.0001, 'chunk_size_jac': 572, 'ansatz': {'_target_': 'netket.models.RBM', 'alpha': 3, 'param_dtype': 'complex'}, 'task': {'_target_': 'grad_sample.tasks.fullsum_train.Trainer'}}
{'_target_': 'grad_sample.tasks.fullsum_train.Trainer'}
{'_target_': 'netket.models.RBM', 'alpha': 3, 'param_dtype': 'complex'}
MC state loaded, num samples 10304
The ground state energy is: -25.054198134188105


In [4]:
is_op = IS_Operator(operator = trainer.model.H_jax)

In [5]:
# no is, calculations done with vstate
trainer.vstate.expect(trainer.model.H_jax)

13.9910-0.0023j ± 0.0023 [σ²=0.0550]

In [6]:
qgt1 =QGTJacobianDenseImportanceSampling(
    importance_operator=is_op, chunk_size=trainer.chunk_size_jac
)
sr_is = nk.optimizer.SR(qgt=qgt1, diag_shift=1e-4, solver=nk.optimizer.solver.cholesky, holomorphic=True)

print("ED:", nk.exact.lanczos_ed(is_op.operator))

log = nk.logging.RuntimeLog()


To fix this, construct SR as  `SR(qgt=MyQGTType, {'diag_scale': None, 'diag_shift': 0.0})` .

  obj.__init__(*args, **kwargs)


ED: [-25.05419813]


In [7]:
opt = nk.optimizer.Sgd(learning_rate=0.005)
op = trainer.model.H
sr = nk.optimizer.SR(solver=nk.optimizer.solver.cholesky, diag_shift=1e-4, holomorphic= True)
gs_is = nk.VMC(is_op, opt, variational_state=trainer.vstate, preconditioner=sr_is)
gs_is.run(n_iter=2000)
# trainer.gs.run(n_iter=100)

  0%|          | 0/2000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
gs_is.state.expect(is_op)

calling IS expect function


-12.0098+0.0010j ± 0.0024 [σ²=0.0880]

In [None]:
exp, force_psi = trainer.vstate.expect_and_forces(trainer.model.H_jax);

In [None]:
exp, force_hpsi =  gs_is.state.expect_and_forces(is_op)

In [None]:
vstate_fs = nk.vqs.FullSumState(hilbert=trainer.model.hi, model=trainer.ansatz, chunk_size=trainer.chunk_size, seed=0)

In [None]:
fs_e, force_fs = vstate_fs.expect_and_forces(trainer.model.H_jax)

In [None]:
fs_e

-8.999e+00-1.735e-18j ± 0.000e+00 [σ²=1.799e+01]

In [None]:
jax.tree_util.tree_map(lambda x,y: jnp.abs(x/y), force_fs, force_psi)

{'Dense': {'bias': Array([0.93092402, 0.19469222, 0.60256542, 0.8036909 , 0.34496903,
         0.43552084, 0.17513618, 0.64318799, 3.97186977], dtype=float64),
  'kernel': Array([[ 14.94007571,   6.90408318,  15.42507549,   2.96411414,
           14.23575427,   4.3459032 ,  25.14661862,   7.8418824 ,
            7.68163779],
         [ 13.89436609,  18.29542463,  10.73803377,  17.07277434,
           23.18548465,  11.3972946 ,   9.46632964,  17.70094328,
            7.63903291],
         [  4.87370661,   4.4401314 ,  18.63329821,  60.03650311,
            5.54178775,   8.5246653 ,   8.05146881,   4.54279373,
           41.89261462],
         [ 12.79485815,  19.60852314,   2.85288104,  17.0839438 ,
            3.54469869,   5.41633179,   4.54571301,  23.10802942,
            1.34889913],
         [ 19.29749448,  10.68603671,  25.43181035,   7.30196509,
            7.06289746,  13.45059398,  13.43261433,  24.2749312 ,
           10.8381639 ],
         [  7.87447127,  11.598522  ,   6.868

In [None]:
jax.tree_util.tree_map(lambda x,y: jnp.abs(x/y), force_fs, force_hpsi)

{'Dense': {'bias': Array([0.88612486, 0.13404208, 0.99371087, 0.84443808, 0.31374175,
         0.5260428 , 0.646892  , 0.59339698, 0.70365368], dtype=float64),
  'kernel': Array([[24.2962961 ,  8.64068893, 15.49768011, 17.08338936, 10.67443608,
           6.19249675,  4.83248955, 13.60913754,  2.91133851],
         [23.74601295, 15.53804373, 17.24533432, 16.19886708,  8.9841118 ,
           2.50129764,  9.14147717, 27.39917294, 31.73628948],
         [ 8.83096568, 17.87887555, 12.14997468, 22.53397512,  6.36807376,
          13.37405873,  3.47667187, 35.34127013, 43.8496282 ],
         [ 8.17360671, 12.03829636,  3.6821588 , 14.69951997,  3.80583798,
           8.56187697,  7.27511624, 11.03965761,  1.28822285],
         [20.3140993 ,  8.07311876, 21.15403205,  5.9070311 , 29.27741423,
           8.14823679, 16.81618266,  8.12369981, 19.69084503],
         [14.38473771, 12.96873877,  7.66741867,  8.99278708,  7.43838629,
          11.30984246, 27.4971089 ,  8.31011023,  8.82746784],
  