In [1]:
import jax
import jax.numpy as jnp
from tqdm import tqdm
import netket as nk
import optax
import netket.jax as nkjax
# import netket_pro as nkp
import matplotlib.pyplot as plt
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

from hydra import compose, initialize
from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate

from grad_sample.tasks.fullsum_train import Trainer
from grad_sample.utils.is_distrib import *
from grad_sample.utils.plotting_setup import *
from grad_sample.is_hpsi.expect import *
from grad_sample.is_hpsi.qgt import QGTJacobianDenseImportanceSampling
from grad_sample.is_hpsi.operator import IS_Operator
from netket.jax._jacobian.logic import _multiply_by_pdf
from grad_sample.utils.tree_op import dagger_pytree, vjp_pytree, mul_pytree, shape_tree, pytree_mean
from grad_sample.tasks.fullsum_snr_is import _compute_S_F

In [2]:
%load_ext autoreload
%autoreload 2

In [16]:
# Hydra keeps a process-wide GlobalHydra singleton; clear it first so that this
# cell can be re-executed in the same kernel.
if GlobalHydra().is_initialized():
    GlobalHydra().clear()
# Compose the "base" configuration found under ../conf.
with initialize(version_base=None, config_path="../conf"):
    cfg = compose(config_name="base")
    # Put the config in struct mode before handing it to the Trainer.
    OmegaConf.set_struct(cfg, True)
    # Echo the full config and the selected task so the run is self-describing.
    print(cfg)
    print(cfg.task)
    # cfg = OmegaConf.to_yaml(cfg)
    # take any task from cfg and run it
# analysis = FullSumPruning(cfg)
# `trainer` is the object every later cell reads its state, model, optimizer
# and chunk sizes from.
trainer = Trainer(cfg)

{'device': '5', 'is_mode': None, 'solver_fn': {'_target_': 'netket.optimizer.solver.cholesky'}, 'lr': 0.0022, 'diag_shift': 'schedule', 'n_iter': 6000, 'sample_size': 9, 'chunk_size_jac': 1024, 'chunk_size_vmap': 100, 'save_every': 10, 'run_index': 0, 'base_path': '/scratch/.amisery/grad_sample_fullsum/', 'model': {'_target_': 'grad_sample.models.heisenberg.XXZ', 'h': 1.5, 'L': 16}, 'ansatz': {'_target_': 'netket.models.RBM', 'alpha': 3, 'param_dtype': 'complex'}, 'task': {'_target_': 'grad_sample.tasks.fullsum_snr_is.FullSumIS'}}
{'_target_': 'grad_sample.tasks.fullsum_snr_is.FullSumIS'}
[CudaDevice(id=0)]
{'_target_': 'netket.models.RBM', 'alpha': 3, 'param_dtype': 'complex'}
MC state loaded, num samples 512
/scratch/.amisery/grad_sample_fullsum//xxz_1.5/L16/RBM/alpha3/MC_9/0.0022_schedule/run_0
The ground state energy is: -33.711056040864676


In [8]:
trainer.sample_size

10

: 

In [9]:
# Importance-sampling setup shared by the cells below.
alpha = 0.5                 # value passed as `is_mode` to the IS operator
num_resample = 1000         # number of independent sample batches
chunk_size_resample = 100   # batches processed per vmap chunk
Nsample = 2 ** trainer.sample_size

is_op = IS_Operator(operator=trainer.model.H_jax, is_mode=alpha)
log_q, log_q_vars = is_op.get_log_importance(trainer.vstate)

# Jitted, chunk-vmapped wrapper around _compute_S_F: maps over the leading axis
# of its input (one sample batch per entry). The trainer attributes are read
# inside the lambda, exactly as in the one-line form, and the last argument
# (1e-3) is a fixed value rather than trainer.diag_shift.
compute_S_F = jax.jit(
    nkjax.vmap_chunked(
        lambda s: _compute_S_F(
            s,
            trainer.vstate._apply_fun,
            trainer.vstate.parameters,
            trainer.vstate.model_state,
            log_q,
            log_q_vars,
            trainer.chunk_size_jac,
            is_op,
            trainer.solver_fn,
            1e-3,
        ),
        in_axes=0,
        chunk_size=chunk_size_resample,
    )
)

: 

In [10]:
# with jax.checking_leaks():
# samples = trainer.vstate.sample_distribution(
# log_q,
# variables=log_q_vars, n_samples = Nsample
# )
# samples = trainer.vstate.samples

# test compared to netket when alpha is 2
# O_exp, O_grad, ng  = _compute_S_F(samples, trainer.vstate._apply_fun, trainer.vstate.parameters, trainer.vstate.model_state, log_q, log_q_vars, trainer.chunk_size_jac//2, is_op, trainer.solver_fn, trainer.diag_shift(0))
# exp_netket, grad_netket = trainer.vstate.expect_and_grad(trainer.model.H_jax)
# jax.tree_util.tree_map(lambda x,y: x/y, O_grad, grad_netket)

# batch_sample = samples.reshape((num_resample, 1, Nsample, -1))
# e, grad_e , ng = compute_S_F(batch_sample)

: 

In [11]:
# One batch of Nsample configurations drawn from the importance distribution.
samples = trainer.vstate.sample_distribution(log_q, variables=log_q_vars, n_samples=Nsample)

# Reference estimate on that single batch; O_grad is reused later as the
# `force` argument of expect_grad_var.
O_exp, O_grad, ng = _compute_S_F(
    samples,
    trainer.vstate._apply_fun,
    trainer.vstate.parameters,
    trainer.vstate.model_state,
    log_q,
    log_q_vars,
    trainer.chunk_size_jac // 2,
    is_op,
    trainer.solver_fn,
    trainer.diag_shift(0),
)


: 

In [12]:
# estimate gradient variance with resampling
samples = trainer.vstate.sample_distribution(
log_q,
variables=log_q_vars, n_samples = Nsample * num_resample
)
batch_sample = samples.reshape((num_resample, 1, Nsample, -1))
e, grad_e , ng = compute_S_F(batch_sample)

: 

In [13]:
# Per-parameter variance of the gradient across the resampled batches (axis 0 of
# every leaf of grad_e is the batch index produced by compute_S_F).
# Uses the fully-qualified jax.tree_util.tree_map, as the later cells do, instead
# of a bare `tree_map` that is only in scope through a wildcard import.
var_resampling = jax.tree_util.tree_map(lambda g: jnp.var(g, axis=0), grad_e)

: 

In [11]:
# test compared to theoretical variance formula
def expect_grad_var(
    force, log_psi, parameters, model_state, log_q, q_vars, operator, sigma_psi_sq, sigma_alpha, chunk_size
):
    """Estimate the per-parameter spread of the local force around ``force``.

    Two importance-weighted averages of ``|force - local_force|**2`` are
    returned: one weighted by the normalised weights ``|psi/q|**2 * Z_ratio``,
    and one additionally weighted by the centred ``log|psi|`` and negated.

    Args:
        force: pytree of force estimates; each leaf is transposed and given a
            trailing axis before being compared with the per-sample values.
        log_psi: callable ``log_psi(variables, samples)`` returning log-amplitudes.
        parameters: model parameters; wrapped as ``{"params": parameters}``
            before being passed to ``log_psi``.
        model_state: forwarded to ``nkjax.jacobian`` only.
        log_q: callable ``log_q(q_vars, samples)`` for the importance distribution.
        q_vars: variables passed to ``log_q``.
        operator: object exposing ``.operator`` (must provide
            ``get_conn_padded``) and ``.mode`` (jacobian mode).
        sigma_psi_sq: samples with at least two leading axes, which are merged.
            Presumably drawn from ``|psi|**2`` — TODO confirm against callers.
        sigma_alpha: samples with the same layout, used only to estimate the
            normalisation ``Z_ratio``. Presumably drawn from ``q`` — TODO confirm.
        chunk_size: chunk size for all network evaluations and the jacobian.

    Returns:
        ``(loc_var, grad_var)``: two pytrees obtained by averaging over the
        sample axis (last axis of the per-sample force leaves). ``grad_var`` is
        presumably the derivative of the variance with respect to the
        importance-sampling exponent — TODO confirm the derivation.
    """
    O = operator.operator
    variables = {"params": parameters}

    # Merge the two leading sample axes into a single sample axis.
    sigma_psi_sq = sigma_psi_sq.reshape(sigma_psi_sq.shape[0] * sigma_psi_sq.shape[1], -1)
    sigma_alpha = sigma_alpha.reshape(sigma_alpha.shape[0] * sigma_alpha.shape[1], -1)

    eval_log_psi = nkjax.apply_chunked(lambda x: log_psi(variables, x), chunk_size=chunk_size)
    eval_log_q = nkjax.apply_chunked(lambda x: log_q(q_vars, x), chunk_size=chunk_size)

    log_psi_sigma_psi_sq = eval_log_psi(sigma_psi_sq)
    log_q_sigma_psi_sq = eval_log_q(sigma_psi_sq)

    log_psi_sigma_alpha = eval_log_psi(sigma_alpha)
    log_q_sigma_alpha = eval_log_q(sigma_alpha)

    # Local operator values on sigma_psi_sq: evaluate log_psi on every connected
    # configuration, then restore the (sample, connection) layout.
    eta, etap_mels = O.get_conn_padded(sigma_psi_sq)
    _eta = eta.reshape(-1, eta.shape[-1])
    log_psi_eta = eval_log_psi(_eta).reshape(eta.shape[:-1])

    # |psi/q|^2 written as exp(2 Re(log psi - log q)): same value as
    # |exp(.)|**2 without forming the complex exponential first.
    w_is_sigma_psi_sq = jnp.exp(2.0 * jnp.real(log_psi_sigma_psi_sq - log_q_sigma_psi_sq))
    w_is_sigma_alpha = jnp.exp(2.0 * jnp.real(log_psi_sigma_alpha - log_q_sigma_alpha))
    Z_ratio = 1 / nkstats.mean(w_is_sigma_alpha)

    op_loc = jnp.sum(
        etap_mels * jnp.exp(log_psi_eta - jnp.expand_dims(log_psi_sigma_psi_sq, axis=-1)), axis=-1
    )
    op_loc_c = op_loc - nkstats.mean(op_loc)

    # Centred jacobian evaluated on sigma_psi_sq.
    jacobian_pytree_c = nkjax.jacobian(
        lambda w, sigma: log_psi(w, sigma),
        parameters,
        sigma_psi_sq,
        model_state,
        mode=operator.mode,
        chunk_size=chunk_size,
        dense=False,
        center=True,
    )
    # Presumably the per-sample force contributions conj(O_k) * (E_loc - <E>);
    # depends on the project helpers dagger_pytree / mul_pytree — TODO confirm.
    force_pytree_unrolled = mul_pytree(dagger_pytree(jacobian_pytree_c), op_loc_c)

    # Squared deviation of each per-sample force from the supplied estimate,
    # computed once and shared by both weighted averages below.
    sq_dev = jax.tree_util.tree_map(
        lambda f, f_loc: jnp.abs(jnp.expand_dims(f.T, -1) - f_loc) ** 2,
        force,
        force_pytree_unrolled,
    )

    weights = w_is_sigma_psi_sq * Z_ratio
    loc_var = jax.tree_util.tree_map(lambda d: jnp.mean(weights * d, axis=-1), sq_dev)

    # log|psi| is Re(log psi). Reusing the values computed above avoids a third
    # network evaluation and the overflow/underflow of log(abs(exp(.))).
    log_modulus_sigma = jnp.real(log_psi_sigma_psi_sq)
    log_modulus_sigma = log_modulus_sigma - jnp.mean(log_modulus_sigma)
    grad_var = jax.tree_util.tree_map(
        lambda d: -jnp.mean(weights * log_modulus_sigma * d, axis=-1), sq_dev
    )

    return loc_var, grad_var

In [15]:
# Fresh batch from the importance distribution; expect_grad_var receives it as
# `sigma_alpha`, next to the variational state's own samples as `sigma_psi_sq`.
samples_alpha = trainer.vstate.sample_distribution(log_q, variables=log_q_vars, n_samples=Nsample)

var_exact, grad_var = expect_grad_var(
    O_grad,
    trainer.vstate._apply_fun,
    trainer.vstate.parameters,
    trainer.vstate.model_state,
    log_q,
    log_q_vars,
    is_op,
    trainer.vstate.samples,
    samples_alpha,
    trainer.chunk_size_jac,
)

: 

In [16]:
# # try optimizing alpha starting at alpha =2
# lr=2400
# alpha_s = 2.0
# n_steps = 20
# al = []
# varl = []
# gradvarl = []
# for n in tqdm(range(n_steps)):
#     is_ops = IS_Operator(operator = trainer.model.H_jax, is_mode=alpha_s)
#     log_qs, log_qs_vars = is_ops.get_log_importance(trainer.vstate)
#     samples_alphas = trainer.vstate.sample_distribution(
#                                             log_qs,
#                                             variables=log_qs_vars, n_samples = Nsample
#                                             )

#     O_exp, O_grad, ng  = _compute_S_F(samples_alphas, trainer.vstate._apply_fun, trainer.vstate.parameters, trainer.vstate.model_state, log_q, log_q_vars, trainer.chunk_size_jac//2, is_op, trainer.solver_fn, trainer.diag_shift(0))

#     var_exact, grad_var = expect_grad_var(O_grad, trainer.vstate._apply_fun, trainer.vstate.parameters, trainer.vstate.model_state, log_qs, log_qs_vars, is_ops, trainer.vstate.samples, samples_alphas, trainer.chunk_size_jac)
#     varl.append(pytree_mean(var_exact))
#     gradvarl.append(pytree_mean(grad_var))
#     al.append(alpha_s)
#     alpha_s -= lr*pytree_mean(grad_var)

: 

In [6]:
def apply_gradient(optimizer_fun, optimizer_state, dp, params):
    """Run one optimizer step and return ``(new_optimizer_state, new_params)``.

    ``optimizer_fun`` is called as ``optimizer_fun(dp, optimizer_state, params)``
    and must return ``(updates, new_state)``; the updates are then applied to
    ``params`` with ``optax.apply_updates``.
    """
    step, next_state = optimizer_fun(dp, optimizer_state, params)
    return next_state, optax.apply_updates(params, step)

In [None]:
trainer.opt.update

<function optax.schedules._inject.inject_hyperparams.<locals>.wrapped_transform.<locals>.update_fn(updates, state, params=None, **extra_args)>

: 

In [None]:
import numpy as np
from tqdm import tqdm

# Joint optimisation: the variational parameters are updated every iteration
# with the `ng` output of _compute_S_F, and the importance-sampling exponent
# alpha_s is updated every 10th iteration with a hand-written scalar Adagrad
# step driven by the mean of grad_var.
alpha_s = 2.0
n_iter = 1000
epsilon = 1e-8  # Small constant to prevent division by zero
grad_accum = 0.0  # Accumulated squared gradients
al = []
varl = []
gradvarl = []
is_ops = IS_Operator(operator=trainer.model.H_jax, is_mode=alpha_s)
opt_state = trainer.opt.init(trainer.vstate.parameters)

# A single, manually-advanced progress bar. The loop itself must NOT be wrapped
# in a second tqdm(...), otherwise two bars are drawn for the same loop.
with tqdm(total=n_iter, disable=False, dynamic_ncols=True) as pbar:
    for n in range(n_iter):
        # Refresh the importance distribution for the current state / alpha_s
        log_qs, log_qs_vars = is_ops.get_log_importance(trainer.vstate)

        # Sample from the distribution
        samples_alphas = trainer.vstate.sample_distribution(
            log_qs,
            variables=log_qs_vars,
            n_samples=trainer.Nsample
        )

        # Compute the observable and its gradients
        O_exp, O_grad, ng = _compute_S_F(
            samples_alphas,
            trainer.vstate._apply_fun,
            trainer.vstate.parameters,
            trainer.vstate.model_state,
            log_qs,
            log_qs_vars,
            trainer.chunk_size_jac // 2,
            is_ops,
            trainer.solver_fn,
            1e-4
        )
        opt_state, trainer.vstate.parameters = apply_gradient(trainer.opt.update, opt_state, ng[0], trainer.vstate.parameters)
        pbar.set_postfix_str(
            "E = %s ; alpha = %.2f" % (O_exp, alpha_s)
        )
        pbar.update(1)
        if n % 10 == 1:
            # NOTE(review): O_grad, log_qs and samples_alphas were computed with
            # the parameters from BEFORE the update above, while the parameters
            # and vstate samples passed here are read AFTER it — confirm this
            # mismatch is intended.
            var_exact, grad_var = expect_grad_var(
                O_grad,
                trainer.vstate._apply_fun,
                trainer.vstate.parameters,
                trainer.vstate.model_state,
                log_qs,
                log_qs_vars,
                is_ops,
                trainer.vstate.samples,
                samples_alphas,
                trainer.chunk_size_jac
            )

            # Compute mean of variance and gradient variance
            varl.append(pytree_mean(var_exact))
            grad_var_mean = pytree_mean(grad_var)
            gradvarl.append(grad_var_mean)
            al.append(alpha_s)

            # Adagrad update
            grad_accum += grad_var_mean**2  # Accumulate squared gradient
            step_size = 1.0 / (np.sqrt(grad_accum) + epsilon)  # Compute adaptive step size
            alpha_s -= step_size * grad_var_mean  # Update alpha_s
            # Writes the private attribute directly so the same operator object
            # picks up the new exponent on the next iteration.
            is_ops._is_mode = alpha_s

 73%|███████▎  | 732/1000 [03:41<00:56,  4.71it/s, E = -33.70668437197831 ; alpha = 0.74] 

In [1]:
# History of the importance-sampling exponent (left axis) against the mean
# variance estimate recorded at the same steps (right axis, red).
# Needs the import cell (plt) and one of the optimisation cells (al, varl) to
# have been run in this kernel first.
fig, axes = plt.subplots()
axes.plot(al)
axes.set_xlabel("recorded update")
axes.set_ylabel("alpha")
axes1 = axes.twinx()
axes1.plot(varl, color='red')
axes1.set_ylabel("mean variance estimate", color='red')
axes.set_title("Importance-sampling exponent and variance estimate");

NameError: name 'plt' is not defined

In [None]:
import optax
from tqdm import tqdm

# Same alpha optimisation as the hand-written Adagrad, but delegated to optax;
# the variational parameters are left untouched in this loop.
alpha_s = jnp.array(2.0)  # JAX scalar so that optax can treat it as a parameter
n_steps = 20

optimizer = optax.adagrad(learning_rate=1.0, eps=1e-8)
opt_state = optimizer.init(alpha_s)

# Histories: alpha, mean variance, mean variance-gradient.
al, varl, gradvarl = [], [], []

for n in tqdm(range(n_steps)):
    # Rebuild the IS operator for the current exponent and derive log q from it.
    is_ops = IS_Operator(operator=trainer.model.H_jax, is_mode=alpha_s)
    log_qs, log_qs_vars = is_ops.get_log_importance(trainer.vstate)

    # Draw a batch from the importance distribution.
    samples_alphas = trainer.vstate.sample_distribution(log_qs, variables=log_qs_vars, n_samples=Nsample)

    # Energy, force and the third output of _compute_S_F for this batch.
    O_exp, O_grad, ng = _compute_S_F(
        samples_alphas, trainer.vstate._apply_fun, trainer.vstate.parameters, trainer.vstate.model_state,
        log_qs, log_qs_vars, trainer.chunk_size_jac // 2, is_ops, trainer.solver_fn, trainer.diag_shift(0),
    )

    # Variance estimate and its alpha-gradient estimate.
    var_exact, grad_var = expect_grad_var(
        O_grad, trainer.vstate._apply_fun, trainer.vstate.parameters, trainer.vstate.model_state,
        log_qs, log_qs_vars, is_ops, trainer.vstate.samples, samples_alphas, trainer.chunk_size_jac,
    )

    # Record the pytree-averaged quantities together with the alpha that produced them.
    varl.append(pytree_mean(var_exact))
    grad_var_mean = pytree_mean(grad_var)
    gradvarl.append(grad_var_mean)
    al.append(alpha_s)

    # One Adagrad step on alpha_s.
    updates, opt_state = optimizer.update(grad_var_mean, opt_state)
    alpha_s = optax.apply_updates(alpha_s, updates)

  0%|          | 0/20 [00:00<?, ?it/s]

: 

In [None]:
is_ops = IS_Operator(operator = trainer.model.H_jax, is_mode=alpha_s)

: 

In [23]:
n = 2**9

: 

In [11]:
mean_ratio = pytree_mean(var_exact) / pytree_mean(var_resampling)

: 

In [12]:
print(mean_ratio)

239.12911958054127


: 

In [13]:
# Element-wise ratio between the formula-based variance (var_exact) and the
# resampling variance (var_resampling, transposed to match), divided by 240 so
# that agreement shows up as values close to 1.
# NOTE(review): 240 is hard-coded; it is close to the mean_ratio printed above
# (239.129...) and is presumably meant as that normalisation — consider using
# mean_ratio, and document where the overall factor comes from.
tree_map(lambda x,y: x/y.T/240, var_exact, var_resampling)

{'Dense': {'bias': Array([0.94798186, 0.93445518, 1.05770642, 0.9177523 , 1.03299275,
         1.03832406, 0.96197603, 1.07071494, 1.03794047, 0.97685504,
         0.96668562, 1.05117939, 0.92989883, 0.97273154, 1.03660041,
         1.03166638, 0.96192518, 0.9931266 , 0.96217681, 0.8604974 ,
         0.98463323, 0.99289561, 0.96640825, 0.99972348, 0.94932979,
         1.07249183, 0.95514303, 0.8939865 , 0.9855491 , 1.01235135,
         0.99162151, 1.1014438 , 0.91654865, 0.99931935, 0.88922758,
         1.0032648 , 0.97594523, 1.02288995, 1.0329321 , 0.96949903,
         0.97110824, 1.03157805, 1.01907588, 1.00319094, 0.99868767,
         1.02461605, 0.99448677, 0.99283738], dtype=float64),
  'kernel': Array([[0.97906245, 0.99513282, 1.02232741, 1.02042556, 0.95849349,
          1.01011261, 1.06894005, 0.96622769, 0.95855709, 1.02836845,
          0.96094697, 1.00877245, 1.0309354 , 1.00383187, 0.96546077,
          1.03172647],
         [0.94768516, 0.93217715, 0.96832394, 0.9687081 ,

: 

In [4]:
is_op = IS_Operator(operator = trainer.model.H_jax)

: 

In [5]:
# no is, calculations done with vstate
trainer.vstate.expect(trainer.model.H_jax)

13.9910-0.0023j ± 0.0023 [σ²=0.0550]

: 

In [6]:
# Importance-sampling QGT bound to the current IS operator, then wrapped in an
# SR preconditioner solved with Cholesky.
# NOTE(review): the recorded output of this cell contains a NetKet hint to
# construct SR as `SR(qgt=MyQGTType, {...})`, i.e. to hand the keyword arguments
# to SR instead of pre-configuring the QGT object — confirm and adapt.
qgt1 =QGTJacobianDenseImportanceSampling(
    importance_operator=is_op, chunk_size=trainer.chunk_size_jac
)
sr_is = nk.optimizer.SR(qgt=qgt1, diag_shift=1e-4, solver=nk.optimizer.solver.cholesky, holomorphic=True)

# Lanczos exact diagonalisation of the underlying operator, printed as a reference value.
print("ED:", nk.exact.lanczos_ed(is_op.operator))

# NOTE(review): `log` is not passed to any run() call in the cells visible here.
log = nk.logging.RuntimeLog()


To fix this, construct SR as  `SR(qgt=MyQGTType, {'diag_scale': None, 'diag_shift': 0.0})` .

  obj.__init__(*args, **kwargs)


ED: [-25.05419813]


: 

In [7]:
# VMC driver that optimises trainer.vstate against the IS operator, using plain
# SGD preconditioned by the importance-sampling SR object `sr_is`.
opt = nk.optimizer.Sgd(learning_rate=0.005)
# NOTE(review): `op` and `sr` are not used by the driver below (it receives
# is_op and sr_is) — remove them or confirm they are needed by later cells.
op = trainer.model.H
sr = nk.optimizer.SR(solver=nk.optimizer.solver.cholesky, diag_shift=1e-4, holomorphic= True)
gs_is = nk.VMC(is_op, opt, variational_state=trainer.vstate, preconditioner=sr_is)
# Long-running: the recorded execution of this cell was interrupted (KeyboardInterrupt).
gs_is.run(n_iter=2000)
# trainer.gs.run(n_iter=100)

  0%|          | 0/2000 [00:00<?, ?it/s]

KeyboardInterrupt: 

: 

In [None]:
gs_is.state.expect(is_op)

calling IS expect function


-12.0098+0.0010j ± 0.0024 [σ²=0.0880]

: 

In [None]:
exp, force_psi = trainer.vstate.expect_and_forces(trainer.model.H_jax);

: 

In [None]:
exp, force_hpsi =  gs_is.state.expect_and_forces(is_op)

: 

In [None]:
vstate_fs = nk.vqs.FullSumState(hilbert=trainer.model.hi, model=trainer.ansatz, chunk_size=trainer.chunk_size, seed=0)

: 

In [None]:
fs_e, force_fs = vstate_fs.expect_and_forces(trainer.model.H_jax)

: 

In [None]:
fs_e

-8.999e+00-1.735e-18j ± 0.000e+00 [σ²=1.799e+01]

: 

In [None]:
jax.tree_util.tree_map(lambda x,y: jnp.abs(x/y), force_fs, force_psi)

{'Dense': {'bias': Array([0.93092402, 0.19469222, 0.60256542, 0.8036909 , 0.34496903,
         0.43552084, 0.17513618, 0.64318799, 3.97186977], dtype=float64),
  'kernel': Array([[ 14.94007571,   6.90408318,  15.42507549,   2.96411414,
           14.23575427,   4.3459032 ,  25.14661862,   7.8418824 ,
            7.68163779],
         [ 13.89436609,  18.29542463,  10.73803377,  17.07277434,
           23.18548465,  11.3972946 ,   9.46632964,  17.70094328,
            7.63903291],
         [  4.87370661,   4.4401314 ,  18.63329821,  60.03650311,
            5.54178775,   8.5246653 ,   8.05146881,   4.54279373,
           41.89261462],
         [ 12.79485815,  19.60852314,   2.85288104,  17.0839438 ,
            3.54469869,   5.41633179,   4.54571301,  23.10802942,
            1.34889913],
         [ 19.29749448,  10.68603671,  25.43181035,   7.30196509,
            7.06289746,  13.45059398,  13.43261433,  24.2749312 ,
           10.8381639 ],
         [  7.87447127,  11.598522  ,   6.868

: 

In [None]:
jax.tree_util.tree_map(lambda x,y: jnp.abs(x/y), force_fs, force_hpsi)

{'Dense': {'bias': Array([0.88612486, 0.13404208, 0.99371087, 0.84443808, 0.31374175,
         0.5260428 , 0.646892  , 0.59339698, 0.70365368], dtype=float64),
  'kernel': Array([[24.2962961 ,  8.64068893, 15.49768011, 17.08338936, 10.67443608,
           6.19249675,  4.83248955, 13.60913754,  2.91133851],
         [23.74601295, 15.53804373, 17.24533432, 16.19886708,  8.9841118 ,
           2.50129764,  9.14147717, 27.39917294, 31.73628948],
         [ 8.83096568, 17.87887555, 12.14997468, 22.53397512,  6.36807376,
          13.37405873,  3.47667187, 35.34127013, 43.8496282 ],
         [ 8.17360671, 12.03829636,  3.6821588 , 14.69951997,  3.80583798,
           8.56187697,  7.27511624, 11.03965761,  1.28822285],
         [20.3140993 ,  8.07311876, 21.15403205,  5.9070311 , 29.27741423,
           8.14823679, 16.81618266,  8.12369981, 19.69084503],
         [14.38473771, 12.96873877,  7.66741867,  8.99278708,  7.43838629,
          11.30984246, 27.4971089 ,  8.31011023,  8.82746784],
  

: 