In [None]:
import sys
# Find jVMC package: presumably a local checkout one directory above the
# notebook, used when running from inside the repository.
sys.path.append(sys.path[0] + "/..")

import jax
# Enable double precision before any JAX arrays are created.
# `jax.config.update` replaces the old `from jax.config import config` import,
# which was deprecated and later removed from JAX.
jax.config.update("jax_enable_x64", True)

import jax.random as random
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# install the pip package and import jVMC
!rm -rf vmc_jax
!git clone --branch dev_0.1.0 https://github.com/markusschmitt/vmc_jax.git
%cd vmc_jax
!python setup.py bdist_wheel
!python -m pip install dist/*.whl
%cd ..

sys.path.insert(0,'/content/vmc_jax/')

import jVMC

In [None]:
# Reference ground-state energies per site, E_0/L, keyed by the chain length
# as a string. Produced with DMRG using the TeNPy library:
# https://github.com/tenpy/tenpy
DMRG_energies = {
    "10": -1.0545844370449059,
    "20": -1.0900383739,
    "100": -1.1194665474274852,
}

L = 10    # number of sites in the spin chain
g = -0.7  # prefactor of the single-site Sx terms in the Hamiltonian

# Neural-network ansatz: jVMC's CpxRBM with 8 hidden units and bias=False.
net = jVMC.nets.CpxRBM(numHidden=8, bias=False)

# Initialize the network parameters from a dummy all-zero configuration of L sites.
initKey = jax.random.PRNGKey(1234)
dummyConfig = jnp.zeros((L,), dtype=np.int32)
params = net.init(initKey, dummyConfig)

# Variational wave function wrapping the network and its parameters.
psi = jVMC.vqs.NQS(net, params)

In [None]:
# Set up the Hamiltonian of an open chain of L sites:
#   H = -sum_{l=0}^{L-2} Sz_l Sz_{l+1}  +  g * sum_{l=0}^{L-1} Sx_l
# Terms are appended per site: the bond to the right neighbour (if any),
# then the on-site Sx term.
hamiltonian = jVMC.operator.BranchFreeOperator()
for site in range(L):
    if site + 1 < L:
        # Nearest-neighbour coupling; there is no bond from site L-1 back to 0.
        hamiltonian.add(jVMC.operator.scal_opstr(-1., (jVMC.operator.Sz(site), jVMC.operator.Sz(site + 1))))
    hamiltonian.add(jVMC.operator.scal_opstr(g, (jVMC.operator.Sx(site), )))

In [None]:
# Monte Carlo sampler over configurations of the L-site chain, using
# jVMC's propose_spin_flip_Z2 update proposer.
sampler = jVMC.sampler.MCSampler(
    psi,
    (L,),
    random.PRNGKey(4321),
    updateProposer=jVMC.sampler.propose_spin_flip_Z2,
    numChains=100,
    sweepSteps=L,
    numSamples=5000,
    thermalizationSweeps=25,
)

# TDVP equation that drives the parameter updates during the ground-state search.
# svdTol and diagonalShift are passed through to jVMC's TDVP solver
# (presumably regularisation of the linear solve - confirm in jVMC docs).
tdvpEquation = jVMC.util.tdvp.TDVP(
    sampler,
    rhsPrefactor=1.,
    svdTol=1e-8,
    diagonalShift=10,
    makeReal='real',
)

# Explicit Euler integrator with a fixed step size.
stepper = jVMC.util.stepper.Euler(timeStep=1e-2)

In [None]:
# Ground-state search: repeatedly advance the variational parameters with the
# Euler stepper and record the per-site energy and energy variance that
# tdvpEquation reports for each step.
numSteps = 300

res = []
for n in range(numSteps):

    newParams, _ = stepper.step(0, tdvpEquation, psi.get_parameters(),
                                hamiltonian=hamiltonian, psi=psi, numSamples=None)
    psi.set_parameters(newParams)

    # Compute each statistic once and store it as a plain real float, so that
    # `res` holds no device arrays and np.array(res) stays a real-valued array.
    energyPerSite = float(jnp.real(tdvpEquation.ElocMean0)) / L
    variancePerSite = float(jnp.real(tdvpEquation.ElocVar0)) / L

    print(n, energyPerSite, variancePerSite)

    res.append([n, energyPerSite, variancePerSite])

In [None]:
# Columns of res: iteration index, energy per site, energy variance per site.
res = np.array(res)
iterations, energies, variances = res[:, 0], res[:, 1], res[:, 2]

fig, ax = plt.subplots(2, 1, sharex=True, figsize=[4.8, 4.8])
if str(L) in DMRG_energies:
    # Distance to the DMRG reference on a log scale. Non-positive differences
    # (e.g. Monte Carlo noise dipping below E_0) are not drawn by semilogy.
    ax[0].semilogy(iterations, energies - DMRG_energies[str(L)], '-', label=r"$L="+str(L)+"$")
    ax[0].set_ylabel(r'$(E-E_0)/L$')
    # Only this branch creates a labelled artist; calling legend() in the other
    # branch would just emit a "No artists with labels found" warning.
    ax[0].legend()
else:
    ax[0].plot(iterations, energies, '-')
    ax[0].set_ylabel(r'$E/L$')

ax[1].semilogy(iterations, variances, '-')
ax[1].set_ylabel(r'Var$(E)/L$')
ax[1].set_xlabel('iteration')
fig.tight_layout()
fig.savefig('gs_search.pdf')