# Variational Monte Carlo and Neural Quantum States from (almost) Scratch


Authors: Filippo Vicentini (École Polytechnique) and Alessandro Sinibaldi (EPFL)

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PhilipVinc/Lectures/blob/master/2406_LesHouches/1-tutorial_vmc.ipynb)

14 June, 2024

The objective of this tutorial is to get you to write a **modern** Variational Monte Carlo code yourself, and to understand how it is built, so that you can become a great scientist in the future and build upon this new knowledge.
To do that, we will pick a simple problem, the 2D Transverse-Field Ising model: we will first compute its ground state and, later on, we will try to simulate a quench.
For your information, regardless of what people say these days, the ground state of the 2D Transverse-Field Ising model at (and around) the critical point can be determined to almost numerical precision with relatively simple VMC ansatze, on lattices in excess of hundreds of sites, at a reasonable cost (~half a day on a GPU for ~8x8, a few extra GPUs on larger lattices).
For the 16 sites you're playing with today, you should expect simulations that run in a few seconds for simple networks, and in about a minute for more complex networks.

Code-wise, I'm a fan of **efficient** research. You can write a VMC code in C/C++ or Rust with handwritten CUDA (see the note below), and if you are the kind of person who wants to do that, you are surely capable of it. Nowadays, however, I prefer to be able to prototype an algorithm in 10 minutes, let my GPU do some extra work for me, see that the algorithm does not work, and prototype the next algorithm in 10 minutes.

(Note: recently Andrej Karpathy reimplemented GPT-2 in handwritten C++ and CUDA, and achieved a factor of 10 speedup over PyTorch... so there is value in handwriting things, but only if you know exactly what you want to implement! That code is an unreadable mess, it leverages architecture-specific optimisations, and you cannot change the network architecture at all: https://x.com/karpathy/status/1795484547267834137 )

My goal is to show you that you can write **simple implementations** of many algorithms that are efficient, and you can build upon that.
The main building block will be Automatic Differentiation (AD), which can at times be a bit counter-intuitive but is extremely powerful!
To avoid spending time on 'tough' points that are technically challenging but have no intrinsic physical interest, we will build upon the open-source package we develop, [NetKet](https://www.netket.org/).
In particular, the things that will be taken from NetKet are:
 - The Monte Carlo samplers: because writing one is not fun, but it is not hard either. If you want, you can reimplement them over the weekend!
 - The operators: writing an efficient implementation of (arbitrary) sparse operators that can be indexed in a way that is relevant to VMC is, again, not straightforward, and very error-prone. I don't want you to spend time thinking about that today. Instead, we will use them as a 'black box' with a well-defined interface. Our implementation is very easy to use and flexible, but again, you're free to reimplement it however you want in the future! It will just take you a few sleepless nights and many cups of coffee.
 - Complex-valued Automatic Differentiation fixes: AD in PyTorch has historically had limited support for complex numbers, while in Jax it does some funny, counter-intuitive things (counter-intuitive for physicists; computer scientists will tell you that it makes perfect sense). We will re-use some AD primitives from NetKet that make it more straightforward to work with complex numbers and to translate the formulas on the blackboard into code, but if you're curious about what happens under the hood and why we need to do that, ask me!

---

We will study both the ground state and the time evolution of a paradigmatic quantum system. 

Specifically, we will consider the transverse-field Ising (TFI) model on a 2D square lattice. 
The Hamiltonian is:

$$ 
\mathcal{H}= - \Gamma\sum_{i}\sigma_{i}^{x} - V\sum_{\langle i, j \rangle}\sigma_{i}^{z}\sigma_{j}^{z}, 
$$
where $\sigma_{i}^{x/z}$ denote the $x$ and $z$ Pauli operators on lattice site $i$.
$\Gamma$ is the transverse magnetic field and $V$ is the coupling strength.  
The sum $\langle i, j \rangle$ runs over nearest neighbors and we assume periodic boundary conditions. 

## 0. Installing NetKet

First of all, you need to install NetKet, which you can do by running the following cell:

:::{note}
If you are executing this notebook on **Colab**, you will need to install NetKet.
You can do so by running the following cell.

Keep in mind that this notebook was designed for NetKet version `3.11.1`, which requires Python >=3.9. If you do not have access to a recent version of Python, we strongly recommend running this notebook on Google Colab.
:::

In [None]:
%pip install --quiet netket matplotlib

You can check that the installation was successful by importing the library:

In [None]:
import netket as nk

## 1. Defining The Hamiltonian

The first step in our journey consists of defining the Hamiltonian we are interested in as a NetKet operator.

For this purpose, we first need to define the degrees of freedom we are dealing with (i.e. spins, bosons, fermions etc) by specifying the Hilbert space of the problem. 

For example, let us consider a $4\times4$ square lattice of spins $1/2$.

In NetKet, a spin-$1/2$ configuration is represented as a vector in $\{-1, 1\}^N$.


In [None]:
L = 4  # side of the 2D lattice
N = L * L  # number of spins
hi = nk.hilbert.Spin(s=1 / 2, N=N)  # create the Hilbert space

NetKet is largely based on `jax`, a library for accelerator-oriented array computation and program transformation.
Through `jax.numpy`, it provides a high-performance version of most of the functions in `numpy`.

In [None]:
import jax

print(
    "A Hilbert space configuration looks like: ", hi.random_state(jax.random.key(0), 1)
)  # print a random configuration

To define the Hamiltonian, it is useful to specify a graph object (from `nk.graph`), which contains all the information on the topology of the physical space where the spins are placed, namely, in this case, a 2D square lattice.

In [None]:
graph = nk.graph.Square(length=L, pbc=True)  # pbc = periodic boundary conditions

Now, we can create the Hamiltonian by summing $k$-local operators, which in NetKet are instances of the class ```LocalOperator``` (see details [here](https://netket.readthedocs.io/en/v3.12.1/api/_generated/operator/netket.operator.LocalOperator.html)).

In particular, we only need 1-local operators of the type $ \sigma^{x}_i $ and 2-local operators of the type $ \sigma^{z}_i \sigma^{z}_j $.
NetKet provides the single-site Pauli operators as functions in `netket.operator.spin`, so we can simply import them.



In [None]:
from netket.operator.spin import sigmax, sigmaz

We first consider the interaction 2-local part. 
We run over all the possible edges of the graph to consider all the possible nearest-neighbor couplings.
Note that NetKet automatically recognizes products of local operators as tensor products. 

In [None]:
V = 1
H = sum([-V * sigmaz(hi, i) * sigmaz(hi, j) for (i, j) in graph.edges()])

We now add the 1-local terms, considering $\Gamma = 3.044 \, V$ which corresponds to the critical point of the phase transition in the 2D TFI model.

In [None]:
Gamma = 3.044 * V
H += sum([-Gamma * sigmax(hi, i) for i in range(N)])

## 2. Ground state
The first problem we want to solve is finding the ground state of the above Hamiltonian.

### 2.1 Exact benchmark 
Since we are playing with a number of spins that is still manageable with exact diagonalization methods, we can compute a benchmark for later. 

This is easily done in NetKet by converting the local operator into a sparse $2^N \times 2^N$ matrix, which can then be efficiently diagonalized with standard library routines such as `eigsh` in `scipy.sparse.linalg`, which employs the Lanczos method.
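To make the "$2^N \times 2^N$ matrix" concrete, here is a minimal dense sketch in plain NumPy (no NetKet) that builds the same Hamiltonian on a smaller $3\times3$ periodic lattice with Kronecker products. The helpers `site_op`, `square_edges` and `tfi_hamiltonian` are hypothetical names, and the basis ordering (site 0 as the leftmost Kronecker factor, $\sigma^z=+1$ as the first basis state) is an assumption that need not match NetKet's internal ordering.

```python
import numpy as np

SX = np.array([[0.0, 1.0], [1.0, 0.0]])   # Pauli x
SZ = np.array([[1.0, 0.0], [0.0, -1.0]])  # Pauli z; basis state 0 <-> sigma^z = +1 (assumed ordering)

def site_op(op, i, n):
    """Embed the 2x2 operator `op` on site i of n spins (site 0 = leftmost Kronecker factor)."""
    out = np.eye(1)
    for k in range(n):
        out = np.kron(out, op if k == i else np.eye(2))
    return out

def square_edges(L):
    """Nearest-neighbour bonds of an L x L periodic square lattice (L >= 3 avoids duplicated bonds)."""
    edges = []
    for r in range(L):
        for c in range(L):
            i = r * L + c
            edges.append((i, r * L + (c + 1) % L))    # bond to the right neighbour
            edges.append((i, ((r + 1) % L) * L + c))  # bond to the neighbour below
    return edges

def tfi_hamiltonian(L, Gamma, V):
    """Dense H = -Gamma sum_i X_i - V sum_<ij> Z_i Z_j; only feasible for very small L."""
    n = L * L
    sz = [site_op(SZ, k, n) for k in range(n)]
    sx = [site_op(SX, k, n) for k in range(n)]
    H = np.zeros((2**n, 2**n))
    for i, j in square_edges(L):
        H -= V * sz[i] @ sz[j]
    for k in range(n):
        H -= Gamma * sx[k]
    return H

H = tfi_hamiltonian(3, Gamma=3.044, V=1.0)
print(H.shape)                   # (512, 512)
print(np.linalg.eigvalsh(H)[0])  # ground-state energy of the 3x3 lattice
```

For the $4\times4$ lattice of this notebook the same dense matrix would need $2^{16}\times2^{16}$ double-precision entries (about 34 GB), which is why the sparse representation is used.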

In [None]:
H_sp = H.to_sparse()
print("The Hamiltonian matrix has shape: ", H_sp.shape)

from scipy.sparse.linalg import eigsh

eig_vals, eig_vecs = eigsh(
    H_sp, k=2, which="SA"
)  # k is the number of eigenvalues desired; "SA" selects the smallest algebraic (i.e. lowest) ones
E_gs = eig_vals.min()  # the lowest eigenvalue is the ground-state energy

print("The ground state energy is:", E_gs)

### 2.2 Variational Monte Carlo

We will find the ground-state energy of the system with the Variational Monte Carlo (VMC) method.

The idea is to choose a variational ansatz that approximates the quantum state of the system $|\Psi\rangle \approx |\Psi_{\theta}\rangle$ and to minimize the energy computed on it with respect to its parameters $\theta$. 
Thanks to the Rayleigh-Ritz variational principle, the variational energy is always an upper bound to the true ground-state energy:
\begin{equation}
E_{\theta} = \frac{\langle \Psi_{\theta} | H | \Psi_{\theta}\rangle}{\langle \Psi_{\theta} | \Psi_{\theta}\rangle} \geq E_{\mathrm{GS}}
\end{equation}
for any choice of $\theta$. 
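The bound follows in one line by expanding the variational state on the (unknown) eigenbasis of the Hamiltonian, $H|n\rangle = E_n|n\rangle$ with $E_{\mathrm{GS}} = E_0 \leq E_1 \leq \dots$:

$$
|\Psi_{\theta}\rangle = \sum_n c_n |n\rangle
\quad\Longrightarrow\quad
E_{\theta} = \frac{\sum_n |c_n|^2 E_n}{\sum_n |c_n|^2} \geq \frac{\sum_n |c_n|^2 E_{\mathrm{GS}}}{\sum_n |c_n|^2} = E_{\mathrm{GS}},
$$

with equality only if $|\Psi_{\theta}\rangle$ lies entirely in the ground-state eigenspace.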

Monte Carlo sampling enters when evaluating the energy and its gradient, the latter being needed for the minimization.

The choice of the ansatz is crucial and directly affects the accuracy of the variational approximation.


#### 2.2.1 Mean-field ansatz

We start by considering the simplest ansatz, namely a mean-field state:

$$ \langle \sigma^{z}_1,\dots \sigma^{z}_N| \Psi_{\mathrm{MF}} \rangle = \prod_{i=1}^{N} \Phi_i(\sigma^{z}_i). $$

The variational parameters are contained in the single-spin wave functions $\Phi_i(\sigma^{z}_i)$. 

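For the single-spin wave functions we take the square root of a sigmoid, so that for real $\lambda$ the single-spin probabilities $|\Phi_i(\sigma^z_i)|^2$ are sigmoids (this is what the code below implements, through the factor $0.5$ in front of `log_sigmoid`):

$$ \Phi_i(\sigma^{z}_i) = \left[1+\exp(-\lambda \sigma^z_i)\right]^{-1/2}, $$

thus depending on the complex-valued variational parameter $\lambda$.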
$\lambda$ can be different for different $\Phi_i$, but here we take it to be the same for simplicity. 

**IMPORTANT**: in NetKet one has to define a variational function approximating the **logarithm of the wave-function amplitudes**, and not the wave function amplitudes themselves.
Moreover, the ansatz is a function evaluated on batches of many-body configurations, namely on a (`jax.numpy`) array of dimension $(N_s, N)$ where $N_s$ is the dimension of the batch (a.k.a. number of samples) and $N$ is the number of local degrees of freedom. 
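Before writing it with `flax`, here is a plain-NumPy sketch of the same mean-field log-amplitude, only to make the batch convention concrete: the input has shape $(N_s, N)$ and the output one log-amplitude per sample, shape $(N_s,)$. The function names are hypothetical, and the single value of $\lambda$ is arbitrary.

```python
import numpy as np

def log_sigmoid(z):
    # log(1 / (1 + exp(-z))) = -log(1 + exp(-z)); fine for the moderate |z| used here
    return -np.log1p(np.exp(-z))

def log_psi_mf(lam, x):
    """Mean-field log-amplitude 0.5 * sum_i log_sigmoid(lam * x_i) for x of shape (n_samples, N)."""
    return 0.5 * np.sum(log_sigmoid(lam * x), axis=-1)

rng = np.random.default_rng(0)
x = rng.choice([-1.0, 1.0], size=(10, 16))  # a batch of 10 configurations of 16 spins
lam = 0.3 + 0.2j                            # one complex variational parameter (arbitrary value)
logpsi = log_psi_mf(lam, x)
print(logpsi.shape)                         # (10,)
```

Working with $\log\Psi$ rather than $\Psi$ keeps products over many sites as sums, which avoids overflow and underflow for large $N$.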

The model is defined as a class inheriting from `flax.linen.Module`, where `flax` is a library for models based on `jax`. 

In [None]:
import jax.numpy as jnp
import flax.linen as nn


class MF(nn.Module):
    # The __call__(self, x) function should take as
    # input a batch of states x.shape = (n_samples, N)
    # and should return a vector of n_samples log-amplitudes
    @nn.compact
    def __call__(self, x):
        # Extract the system size
        N = x.shape[-1]

        # A tensor of variational parameters is defined by calling
        # the method `self.param` where the arguments will be:
        lam = self.param("lambda",                     # arbitrary name used to refer to this set of parameters
                         nn.initializers.normal(1.0),  # an initializer used to provide the initial values. 1.0 is the std deviation
                         (1,),                         # The shape of the tensor
                         complex)                      # The dtype of the tensor.

        # compute the log of the single-spin sigmoids (log-probabilities for real lambda)
        p = nn.log_sigmoid(lam * x)

        # sum the output
        out = 0.5 * jnp.sum(p, axis=-1)

        return out

This defines the variational ansatz, or variational wave-function.

To use it, you must do the following:
   1) Create an instance of the class. In this case `MF()`; in other cases you might specify some hyperparameters as well. This object represents the parametrised function itself (the logic), but NOT the parameters.
   2) Call the `.init(random_key, sample_input)` method, which takes as input a `jax.random.key()` PRNG key, used to generate the initial random parameters, and a sample input. The sample input can be anything, as long as it has the right shape and dtype: it is used to infer the shapes of all the parameters in the network. The output of this function is a pytree of parameters, i.e. a (nested) dictionary of variables.
   3) To evaluate the function, call `MF().apply(variables, input)`.

In [None]:
# create an instance of the model
model = MF()

# generate a random sample, and create an initialization for the parameters (stored as pytree dictionary)
sample_input = hi.random_state(jax.random.key(0), 1)
params = model.init(jax.random.key(0), sample_input)

# This is how you evaluate the network for those parameters. 
model.apply(params, sample_input)

# It should work also with many input samples
sample_input = hi.random_state(jax.random.key(0), 10)
model.apply(params, sample_input)

### Sampling

In NetKet we work with Monte Carlo variational states, which are objects that we can sample from and for which we can easily compute expectation values of operators and the corresponding gradients.
Therefore, besides the wave function, we need to construct a sampler.
While you could write your own MCMC sampling code, writing an efficient one and making sure it scales to multiple GPUs is not straightforward (though it is not hard either).
For that reason, we will use the samplers from NetKet: there are several of them, in `netket.sampler.*`, and they are all instances of the class `nk.sampler.Sampler`.

In NetKet, the samples are drawn in parallel from many chains, so they are collected in an array of shape `(n_chains, chain_length, N)`, such that `n_samples = n_chains * chain_length`.

We will use a simple sampler based on the Metropolis-Hastings algorithm, with a local transition rule that consists of flipping a single spin.
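To see what such a sampler does internally, here is a NumPy sketch of single-spin-flip Metropolis sampling of $|\Psi(\sigma)|^2$ with several parallel chains and a sweep size. It is our own illustration, not NetKet's implementation; the function name `metropolis_local` and its arguments are hypothetical.

```python
import numpy as np

def metropolis_local(log_psi, n_sites, n_chains, chain_length, sweep_size, rng):
    """Single-spin-flip Metropolis sampler of |psi|^2 (a sketch, not NetKet's code).

    log_psi maps a batch (n_chains, n_sites) of {-1,+1} configurations to log-amplitudes.
    Returns samples of shape (n_chains, chain_length, n_sites).
    """
    x = rng.choice([-1, 1], size=(n_chains, n_sites))  # random initial state of every chain
    log_p = 2.0 * np.real(log_psi(x))                   # log |psi|^2
    samples = np.empty((n_chains, chain_length, n_sites), dtype=x.dtype)
    for t in range(chain_length):
        for _ in range(sweep_size):                      # keep one sample every sweep_size proposals
            site = rng.integers(n_sites, size=n_chains)  # one random site per chain
            x_new = x.copy()
            x_new[np.arange(n_chains), site] *= -1       # propose a single spin flip
            log_p_new = 2.0 * np.real(log_psi(x_new))
            # accept with probability min(1, |psi(x_new)|^2 / |psi(x)|^2)
            accept = np.log(rng.random(n_chains)) < log_p_new - log_p
            x = np.where(accept[:, None], x_new, x)
            log_p = np.where(accept, log_p_new, log_p)
        samples[:, t, :] = x
    return samples

# Test state: |psi(x)|^2 ∝ exp(a * sum_i x_i), i.e. independent spins with <x_i> = tanh(a)
a = 0.5
log_psi = lambda x: 0.5 * a * np.sum(x, axis=-1)
samples = metropolis_local(log_psi, n_sites=3, n_chains=16, chain_length=2000,
                           sweep_size=3, rng=np.random.default_rng(1))
print(samples.shape)                  # (16, 2000, 3)
print(samples.mean(), np.tanh(a))     # the two numbers should be close
```

Only ratios $|\Psi(\sigma')|^2/|\Psi(\sigma)|^2$ enter the acceptance test, so the normalisation of $\Psi$ is never needed.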

In [None]:
# create the local Metropolis sampler on the Hilbert space
sampler = nk.sampler.MetropolisLocal(
    hi,
    n_chains=16,        # we run 16 independent chains in parallel
    sweep_size=hi.size, # between two stored samples each chain performs hi.size = 16 Metropolis steps:
                        # we effectively generate a chain of length chain_length * sweep_size and keep
                        # only one sample every sweep_size steps, to reduce correlations between samples.
)

# create the state of the sampler
sampler_state = sampler.init_state(model, params, seed=1)
sampler_state = sampler.reset(model, params, sampler_state)

# sample the configurations. 
σ, sampler_state = sampler.sample(model, params, state=sampler_state, chain_length=100)

In [None]:
print("The shape of the samples is: ", σ.shape)

And you can see that the samples will be returned as a tensor with 3 indices, where the first axis corresponds to the different chains, the second axis corresponds to the length of each chain, and the third axis corresponds to the different degrees of freedom of your system.

You can call `sampler.sample` as many times as you want in a training loop.
However, every time you change the parameters you should call `sampler.reset`.
Depending on the sampler you are using, this may do something (like resetting some internal caches) or nothing at all.

### Expectation Values

In Netket, you can compute expectation values of operators and the corresponding gradient using built-in functions. 
However, for educational purposes, we will write it from scratch. 

In VMC, we compute the expectation value of an operator as a statistical average of an estimator over samples drawn from the Born probability distribution of the quantum state. 

For the energy, we can write:
$$
   E_{\theta} = \frac{\langle \Psi_{\theta} | H | \Psi_{\theta} \rangle}{\langle \Psi_{\theta} | \Psi_{\theta} \rangle}  = \sum_\sigma \frac{|\Psi_{\theta}(\sigma)|^2}{\langle \Psi_{\theta} | \Psi_{\theta} \rangle} \frac{\langle \sigma | H | \Psi_{\theta} \rangle}{\langle \sigma | \Psi_{\theta} \rangle} = \mathbb{E}_{|\Psi_{\theta}(\sigma)|^2}[E_{\mathrm{loc}}(\sigma)]
$$
where $E_{\mathrm{loc}}(\sigma) = \langle \sigma | H | \Psi_{\theta} \rangle / \langle \sigma | \Psi_{\theta} \rangle$. 

The statistical mean is approximated as the empirical mean over a set of samples $\sigma^{(i)}$ drawn from $|\Psi_{\theta}(\sigma)|^2 / \langle \Psi_{\theta} | \Psi_{\theta} \rangle$ as: 
$$
E_{\theta} \approx  \frac{1}{N_s} \sum_{i=1}^{N_s} E_{\mathrm{loc}}(\sigma^{(i)}).
$$

For the gradient, we have a similar formula as a statistical average: 
$$
    \partial_k E_{\theta} = \mathbb{E}_{|\Psi_{\theta}(\sigma)|^2} \left[ (\partial_k \log\Psi_{\theta}(\sigma))^* \left( E_\text{loc}(\sigma) - \langle E \rangle\right)\right] \approx \frac{1}{N_s}\sum_i^{N_s} (\partial_k \log\Psi_{\theta}(\sigma_i))^* \left( E_\text{loc}(\sigma_i) - \langle E \rangle\right)
$$ 
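A quick way to convince yourself of the gradient formula is to check it without sampling, on a system small enough to enumerate. The sketch below uses the mean-field ansatz on a 3-spin periodic chain (an assumption chosen for size), replaces the Monte Carlo average by the exact sum over all $2^3$ configurations, and compares with finite differences. For a complex parameter $\lambda = a + ib$ and a holomorphic ansatz, the estimator equals $\partial E/\partial\lambda^*$, hence $\partial E/\partial a = 2\,\mathrm{Re}$ and $\partial E/\partial b = 2\,\mathrm{Im}$ of it. All function names are hypothetical.

```python
import itertools
import numpy as np

def tfi_ring_dense(n, Gamma, V):
    """Dense TFI Hamiltonian on a periodic chain of n >= 3 spins (tiny n only)."""
    SX = np.array([[0.0, 1.0], [1.0, 0.0]])
    SZ = np.diag([1.0, -1.0])
    op = lambda o, k: np.kron(np.kron(np.eye(2**k), o), np.eye(2**(n - k - 1)))
    return (sum(-V * op(SZ, k) @ op(SZ, (k + 1) % n) for k in range(n))
            + sum(-Gamma * op(SX, k) for k in range(n)))

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def energy_and_grad(lam, H, configs):
    """Exact energy and gradient estimator for the mean-field ansatz (exact sums replace sampling)."""
    log_psi = 0.5 * np.sum(np.log(sigmoid(lam * configs)), axis=-1)
    O = 0.5 * np.sum(configs * sigmoid(-lam * configs), axis=-1)  # d log_psi / d lam
    psi = np.exp(log_psi)
    p = np.abs(psi) ** 2 / np.sum(np.abs(psi) ** 2)               # Born weights
    E_loc = (H @ psi) / psi
    E = np.sum(p * E_loc).real
    grad = np.sum(p * np.conj(O) * (E_loc - E))                   # E[(d log psi)^* (E_loc - <E>)]
    return E, grad

n = 3
H = tfi_ring_dense(n, Gamma=1.0, V=1.0)
configs = np.array(list(itertools.product([1.0, -1.0], repeat=n)))  # row m <-> basis state m of H
lam, h = 0.4 + 0.3j, 1e-6
E, G = energy_and_grad(lam, H, configs)
dE_da = (energy_and_grad(lam + h, H, configs)[0] - energy_and_grad(lam - h, H, configs)[0]) / (2 * h)
dE_db = (energy_and_grad(lam + 1j * h, H, configs)[0]
         - energy_and_grad(lam - 1j * h, H, configs)[0]) / (2 * h)
print(dE_da, 2 * G.real)  # should agree
print(dE_db, 2 * G.imag)  # should agree
```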

To compute both the energy and the gradient, we need a function computing $E_{\mathrm{loc}}(\sigma)$ and $\partial_k \log \Psi_{\theta}(\sigma)$ for a batch of samples $\sigma$. In what follows we will only consider holomorphic wave functions.

For the local energies: 
$$
E_{\mathrm{loc}}(\sigma) = \frac{\langle \sigma | H | \Psi_{\theta} \rangle}{\langle \sigma | \Psi_{\theta} \rangle} = \sum_{\eta}  H_{\sigma\eta} \frac{\Psi_{\theta}(\eta)}{\Psi_{\theta}(\sigma)}
$$
we can use the operator method `get_conn_padded` (e.g. `H.get_conn_padded(σ)`) which, given a batch of configurations $\sigma$, returns the connected configurations $|\eta\rangle$ such that $H_{\sigma\eta} \neq 0$, together with the matrix elements $H_{\sigma\eta}$.
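For the TFI Hamiltonian the connected configurations are easy to write by hand: $\sigma$ itself (diagonal element $-V\sum_{\langle i,j\rangle}\sigma_i\sigma_j$) and the $N$ configurations with one spin flipped (element $-\Gamma$). The NumPy sketch below mimics the padded output shapes `(n_samples, n_conn, N)` and `(n_samples, n_conn)` described above and checks the resulting local energies against a dense matrix on a 3-spin chain. It is our own illustration: `tfi_conn_padded` and `local_energy` are hypothetical names, and the random lookup-table wave function is only a test device.

```python
import itertools
import numpy as np

def tfi_conn_padded(sigma, edges, Gamma, V):
    """Connected configurations and matrix elements of the TFI Hamiltonian (sketch).

    sigma: (n_samples, N) in {-1,+1}. Returns eta: (n_samples, N + 1, N), mels: (n_samples, N + 1).
    Entry 0 is the diagonal term (eta = sigma); entry k + 1 has spin k flipped.
    """
    n_samples, N = sigma.shape
    i, j = np.array(edges).T
    diag = -V * np.sum(sigma[:, i] * sigma[:, j], axis=-1)  # <sigma|H|sigma>
    flips = np.repeat(sigma[:, None, :], N, axis=1)         # N copies of each sample
    flips[:, np.arange(N), np.arange(N)] *= -1              # flip spin k in copy k
    eta = np.concatenate([sigma[:, None, :], flips], axis=1)
    mels = np.concatenate([diag[:, None], np.full((n_samples, N), -Gamma)], axis=1)
    return eta, mels

def local_energy(log_psi, sigma, edges, Gamma, V):
    eta, mels = tfi_conn_padded(sigma, edges, Gamma, V)
    # E_loc = sum_eta H_{sigma eta} psi(eta) / psi(sigma), with the ratio taken in log space
    return np.sum(mels * np.exp(log_psi(eta) - log_psi(sigma)[:, None]), axis=-1)

n, Gamma, V = 3, 0.7, 1.0
edges = [(0, 1), (1, 2), (2, 0)]
rng = np.random.default_rng(0)
table = rng.normal(size=2**n) + 1j * rng.normal(size=2**n)  # arbitrary complex log-amplitudes
powers = 2 ** np.arange(n - 1, -1, -1)
log_psi = lambda x: table[((x < 0).astype(int) * powers).sum(axis=-1)]

configs = np.array(list(itertools.product([1.0, -1.0], repeat=n)))
eta, mels = tfi_conn_padded(configs, edges, Gamma, V)
E_loc = local_energy(log_psi, configs, edges, Gamma, V)

# dense reference: (H psi)(sigma) / psi(sigma)
SX, SZ = np.array([[0.0, 1.0], [1.0, 0.0]]), np.diag([1.0, -1.0])
op = lambda o, k: np.kron(np.kron(np.eye(2**k), o), np.eye(2**(n - k - 1)))
H = sum(-V * op(SZ, a) @ op(SZ, b) for a, b in edges) + sum(-Gamma * op(SX, k) for k in range(n))
psi = np.exp(log_psi(configs))
print(eta.shape, mels.shape)                      # (8, 4, 3) (8, 4)
print(np.max(np.abs(E_loc - (H @ psi) / psi)))    # agreement up to round-off
```

Evaluating $\Psi(\eta)/\Psi(\sigma)$ as $\exp(\log\Psi(\eta) - \log\Psi(\sigma))$ is the same log-space trick suggested in the exercise below.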

$\partial_k \log \Psi_{\theta}(\sigma)$ is the Jacobian of the function $\log \Psi_{\theta}(\sigma): \mathbb{C}^{N_{\mathrm{pars}}} \rightarrow \mathbb{C}^{N_{\mathrm{samples}}}$, which maps the parameters to the log-amplitudes of the samples.
This can be computed using the function `jax.jacrev`. 

In [None]:
σ = σ.reshape(-1, N) # we need to reshape it to (n_chains * chain_length, N) = (n_samples, N)
eta, eta_H_sigma = H.get_conn_padded(σ)

print("The shape of the samples is: ", σ.shape) # (n_samples, N)
print("The shape of the connected configurations is: ", eta.shape)  # (n_samples, n_conn, N)
print("The shape of the matrix elements is: ", eta_H_sigma.shape)  # (n_samples, n_conn)

In [None]:
logpsi_fun = lambda pars : model.apply(pars, σ) # we freeze the samples so the functions depend only on the parameters
jacobian = jax.jacrev(logpsi_fun, holomorphic=True)(params)   # compute the jacobian of the log of the (holomorphic) wavefunction
print("The parameters of the model have shape:\n" , jax.tree.map(lambda x: x.shape, params))  # (N_par, )
print("The jacobian of the model has shape:\n" , jax.tree.map(lambda x: x.shape, jacobian))  # (n_samples, N_par)

In [None]:
from functools import partial 
 
# TODO

@partial(jax.jit, static_argnames='model')   # compile the function and make it faster to execute
def compute_eloc(model, parameters, ham, σ):
    # reshape the samples to have shape (n_samples, N), the samples are divided in different Markov chains
    σ = σ.reshape(-1, σ.shape[-1])

    # compute the connected configurations and the matrix elements
    eta, eta_H_sigma = ...

    # compute the local energies (in log-space for numerical stability)
    logpsi_eta = ...    # evaluate the wf on the connected configurations
    logpsi_sigma = ...  # evaluate the wf on the samples
    logpsi_sigma = ...  # add a dimension to match the shape of logpsi_eta
    E_loc = ...         # compute the local energies

    return E_loc

@partial(jax.jit, static_argnames='model')   # compile the function and make it faster to execute
def compute_jacobian(model, parameters, ham, σ):
    # reshape the samples, the samples are divided in different Markov chains
    σ = σ.reshape(-1, σ.shape[-1])

    # compute the dk logpsi
    logpsi_fun = ...    # we freeze the samples so the functions depend only on the parameters
    jacobian = ...  # compute the jacobian of the log of the (holomorphic) wavefunction

    return jacobian

In [None]:
# TODO 

@partial(jax.jit, static_argnames='model')   # compile the function and make it faster to execute
def estimate_energy_and_gradient(model, parameters, ham, σ):
    # reshape the samples, the samples are divided in different Markov chains
    σ = σ.reshape(-1, σ.shape[-1])

    # compute eloc
    E_loc = ...

    # compute jacobian
    jacobian = ...

    # take the number of samples
    n_samples = E_loc.shape[0]

    # compute the energy
    E_average = ...
    E_variance = ...
    E_error = ...
    E = nk.stats.Stats(mean=E_average, error_of_mean=E_error, variance=E_variance)  # create a Netket object containing the statistics

    # center the local energies
    E_loc -= E_average

    # compute the gradient as Ok.conj() @ E_loc / n_samples (operate on pytree with jax.tree.map) 
    E_grad = ...

    return E, E_grad

To use the Hamiltonian inside jax-compiled functions, we first convert it into a jax-compatible operator (going through its Pauli-string representation).

In [None]:
H_jax = H.to_pauli_strings().to_jax_operator()

In [None]:
E, E_grad = estimate_energy_and_gradient(model, params, H_jax, σ)

print("The energy is: ", E)
print("The energy gradient is: ", E_grad)

#### 2.2.2 VMC from scratch

We will now optimize the parameters of the ansatz in order to best approximate the ground state of the Hamiltonian.

For educational purposes, we will write the VMC loop from scratch. 

To minimize the energy we can use the standard Stochastic Gradient Descent (SGD) scheme, which updates the parameters according to: 
\begin{equation}
\theta^{\mathrm{new}} = \theta^{\mathrm{old}} - \eta \nabla E(\theta^\mathrm{old}), 
\end{equation}
where $\eta$ defines the step size and is called learning rate in the Machine Learning community. 

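If we repeat this scheme for many iterations with a sufficiently small learning rate, the parameters will typically approach a local minimum of the energy (up to the statistical fluctuations of the Monte Carlo estimates).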
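As a self-contained illustration of the update rule (not a solution of the NetKet exercise below), here is a NumPy sketch of plain gradient descent on the mean-field ansatz for a 4-spin periodic chain, where exact sums over all configurations replace the Monte Carlo estimates. The system size, the couplings $\Gamma = V = 1$, the initial $\lambda$ and the learning rate are arbitrary assumptions, and the function names are hypothetical. The update uses the same convention as the notebook, $\lambda \leftarrow \lambda - \eta\,\nabla E$, with $\nabla E$ given by the gradient estimator above.

```python
import itertools
import numpy as np

def tfi_ring_dense(n, Gamma, V):
    """Dense TFI Hamiltonian on a periodic chain of n >= 3 spins (tiny n only)."""
    SX = np.array([[0.0, 1.0], [1.0, 0.0]])
    SZ = np.diag([1.0, -1.0])
    op = lambda o, k: np.kron(np.kron(np.eye(2**k), o), np.eye(2**(n - k - 1)))
    return (sum(-V * op(SZ, k) @ op(SZ, (k + 1) % n) for k in range(n))
            + sum(-Gamma * op(SX, k) for k in range(n)))

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def energy_and_grad(lam, H, configs):
    """Exact mean-field energy and gradient (exact |psi|^2 weights replace sampling)."""
    log_psi = 0.5 * np.sum(np.log(sigmoid(lam * configs)), axis=-1)
    O = 0.5 * np.sum(configs * sigmoid(-lam * configs), axis=-1)  # d log_psi / d lam
    psi = np.exp(log_psi)
    p = np.abs(psi) ** 2 / np.sum(np.abs(psi) ** 2)
    E_loc = (H @ psi) / psi
    E = np.sum(p * E_loc).real
    return E, np.sum(p * np.conj(O) * (E_loc - E))

n = 4
H = tfi_ring_dense(n, Gamma=1.0, V=1.0)
configs = np.array(list(itertools.product([1.0, -1.0], repeat=n)))
E_exact = np.linalg.eigvalsh(H)[0]

lam, eta = 0.1 + 0.0j, 0.05      # initial parameter and learning rate (arbitrary choices)
energies = []
for step in range(300):
    E, G = energy_and_grad(lam, H, configs)
    energies.append(E)
    lam = lam - eta * G          # gradient-descent step
print(energies[0], energies[-1], E_exact)
```

The final mean-field energy stays above the exact one, as the variational principle requires.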

In [None]:
# TODO 

from tqdm import tqdm  # tqdm is just a progress bar

energies_mf = []
n_steps = ...
eta = ...

# initialize the parameters and the sampler
params = model.init(jax.random.key(0), hi.random_state(jax.random.key(0), 1))
sampler_state = sampler.init_state(model, params, seed=1)

# For every iteration
for i in tqdm(range(n_steps)):
    # sample
    sampler_state = sampler.reset(model, params, sampler_state)
    σ, sampler_state = sampler.sample(model, params, state=sampler_state, chain_length=100)

    # compute energy and gradient of the energy
    E, E_grad = ...

    # log the energy to a list
    energies_mf.append(E.mean.real)
    
    # performs the SGD update function on every element of the dictionary containing the set of parameters
    params = ...


We can now plot the energy during those optimization steps and compare it with the benchmark previously calculated:


In [None]:
import matplotlib.pyplot as plt

plt.plot(energies_mf, label="MF")
plt.ylabel("Energy")
plt.xlabel("Iteration")
plt.hlines(E_gs, 0, len(energies_mf), linestyles="--", color="black", label="ED")
plt.legend()

print(
    "The relative error with the benchmark is: ",
    jnp.absolute((energies_mf[-1] - E_gs) / E_gs),
)

#### 2.2.3 Jastrow ansatz

We can now use a more correlated ansatz than the simple mean-field. 

For instance, we can take a 2-body Jastrow ansatz entangling different spins, namely:

$$ \log \langle \sigma^{z}_1,\dots \sigma^{z}_N| \Psi_{\mathrm{Jas}} \rangle =  \sum_{i, j} J_{ij} \sigma^{z}_i\sigma^{z}_{j} $$

where the variational parameter is $J$. 

Again we can write the model using `flax`.  If you want to look at the official documentation, look [here](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#defining-your-own-models).
In general, you must write a function that computes the formula above for an arbitrary input $x$.

There are two ways to implement the formula above:
 - either we do an Einstein summation over $i$ and $j$, but then the $J$ matrix must be symmetric;
 - or we just sum over $i<j$.

The second option, while fine in C++, will be very inefficient in Jax, which requires vectorised operations. So the 'only' good way to do this is to construct a symmetric $J$ matrix.
If you have a matrix $J$ of $N\times N$ parameters, the stochastic optimisation will not keep it symmetric, so you need a way to ensure that it stays symmetric.
Either you parametrise only the $N(N-1)/2$ actually free parameters and construct $J$ from them, or you parametrise an $N\times N$ matrix $W$ and use its symmetrised version $J = (W + W^T)/2$. Pick one! (The latter is easier than the former.)

Note: to make your network work with inputs that have an arbitrary number of batch dimensions (a single bitstring, a vector of bitstrings...), you can replace the matrix multiplication `J @ x` with the Einstein summation `jnp.einsum('ij,...j->...i', J, x)`, which is the same thing but carries over the extra dimensions of `x`, labelled by `...`.
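The symmetrisation and the batched Einstein summation can be tried out in NumPy first (the `jnp.einsum` call follows the same subscript rules). In this sketch, `log_psi_jastrow` is a hypothetical name; it builds $J = (W + W^T)/2$ from an unconstrained complex $W$ and contracts it with $x$ on both sides, so it works for a single bitstring as well as for a batch.

```python
import numpy as np

def log_psi_jastrow(W, x):
    """Sketch: sum_ij J_ij x_i x_j with J = (W + W^T) / 2, for x of shape (..., N)."""
    J = 0.5 * (W + W.T)                             # symmetric by construction, whatever W is
    return np.einsum('...i,ij,...j->...', x, J, x)  # batch dimensions '...' are carried over

rng = np.random.default_rng(0)
N = 6
W = rng.normal(size=(N, N)) + 1j * rng.normal(size=(N, N))  # unconstrained complex parameters
x = rng.choice([-1.0, 1.0], size=(5, N))

print(log_psi_jastrow(W, x).shape)     # (5,)  -> a batch of 5 bitstrings
print(log_psi_jastrow(W, x[0]).shape)  # ()    -> a single bitstring

# Same result as the explicit i < j sum: since x_i^2 = 1, the diagonal only adds the constant Tr J
J = 0.5 * (W + W.T)
pairs = np.trace(J) + 2 * sum(J[i, j] * x[:, i] * x[:, j] for i in range(N) for j in range(i + 1, N))
print(np.allclose(pairs, log_psi_jastrow(W, x)))  # True
```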

In [None]:
from flax import linen as nn

class Jastrow(nn.Module):
    @nn.compact
    def __call__(self, x):
        # take the system size 
        N = x.shape[-1]

        # Define the variational parameter 
        Params = ...

        # Create the symmetric J matrix from the parameters
        J = ...

        # x.T@(J@x) but with einsum
        y = ... #jnp.einsum(...)

        return y
    
model = Jastrow()

Now try running the optimisation again with this ansatz, and see if it improves on the previous results.

In [None]:
# TODO 

from tqdm import tqdm  # tqdm is just a progress bar

energies_jastrow = []

# copy from above...

# For every iteration
for i in tqdm(range(n_steps)):
    # sample

    # compute energy and gradient of the energy
    E, E_grad = ...

    # log the energy to a list
    energies_jastrow.append(E.mean.real)
    


And let's plot the energies again and compare:

In [None]:
import matplotlib.pyplot as plt

plt.plot(energies_mf, label="MF")
plt.plot(energies_jastrow, label="Jastrow")
plt.ylabel("Energy")
plt.xlabel("Iteration")
plt.hlines(E_gs, 0, len(energies_jastrow), linestyles="--", color="black", label="ED")
plt.legend()

print(
    "The relative error of the Jastrow ansatz with the benchmark is: ",
    jnp.absolute((energies_jastrow[-1] - E_gs) / E_gs),
)

## Using different optimisers: Adam & Company

Until now you used a very 'simple' (stochastic) gradient descent, but nowadays there are much more advanced gradient-based optimisation schemes.
For example, we could use adaptive moment estimation (Adam).

To do that, we can rely on a package called [optax](https://optax.readthedocs.io) which implements several optimisers itself. 
Below, I show you a simple example of how to implement the training loop using optax. 
You should use this as a playground to try using [ADAM](https://arxiv.org/abs/1412.6980) to optimise the Jastrow (note: for jastrows it will not improve much, but for later networks, it will).
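To demystify what `optax.adam` does with the gradient, here is a plain-NumPy sketch of the Adam update rule of Kingma & Ba (first- and second-moment running averages with bias correction), restricted to real parameters for simplicity. It is our own illustration, not optax's implementation; the function `adam` and the badly scaled quadratic test function are hypothetical.

```python
import numpy as np

def adam(grad_fn, theta0, lr=0.1, b1=0.9, b2=0.999, eps=1e-8, n_steps=500):
    """Plain-NumPy Adam for real parameters (a sketch, not optax's code)."""
    theta = np.array(theta0, dtype=float)
    m = np.zeros_like(theta)  # running average of the gradient (first moment)
    v = np.zeros_like(theta)  # running average of the squared gradient (second moment)
    for t in range(1, n_steps + 1):
        g = grad_fn(theta)
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g**2
        m_hat = m / (1 - b1**t)  # bias corrections for the zero initialisation
        v_hat = v / (1 - b2**t)
        theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta

# A badly scaled quadratic: the per-parameter rescaling by sqrt(v_hat) makes both directions move at a similar pace
f = lambda th: 0.5 * (100.0 * (th[0] - 1.0) ** 2 + (th[1] + 2.0) ** 2)
grad_f = lambda th: np.array([100.0 * (th[0] - 1.0), th[1] + 2.0])
theta = adam(grad_f, [0.0, 0.0], lr=0.1, n_steps=1000)
print(theta, f(theta))  # close to the minimum at (1, -2)
```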

In [None]:
# TODO 
import optax

from tqdm import tqdm  # tqdm is just a progress bar

energies_jastrow_adam = []
n_steps = ...
eta = ...


# initialize the parameters and the sampler
params = model.init(jax.random.key(0), hi.random_state(jax.random.key(0), 1))
sampler_state = sampler.init_state(model, params, seed=1)

# Initialize the optimizer
optimizer = optax.adam(learning_rate=eta)  # eta is the learning rate chosen above
optimizer_state = optimizer.init(params)

# For every iteration
for i in tqdm(range(n_steps)):
    # sample
    sampler_state = sampler.reset(model, params, sampler_state)
    σ, sampler_state = sampler.sample(model, params, state=sampler_state, chain_length=100)

    # compute energy and gradient of the energy
    E, E_grad = ...

    # log the energy to a list
    energies_jastrow_adam.append(E.mean.real)
    
    # performs the optax update
    updates, optimizer_state = optimizer.update(E_grad, optimizer_state)
    params = optax.apply_updates(params, updates)


In [None]:
# Now plot the energies!

### Neural Quantum State (NQS): A 1-layer Perceptron (aka Carleo's RBM)

We now want to use a more sophisticated ansatz, based on a neural network representation of the wave function. 

We consider a two-layer fully-connected feed-forward neural network, whose mathematical expression is:

$$ \log \langle \sigma^{z}_1,\dots \sigma^{z}_N| \Psi_{\mathrm{NQS}} \rangle =  g(W_2 \cdot g (W_1 \cdot \sigma^z + b_1) + b_2)$$

where $W_1 \in \mathbb{C}^{M \times N}, b_1 \in \mathbb{C}^M, W_2 \in \mathbb{C}^{1 \times M}$ and $b_2 \in \mathbb{C}$ are the variational parameters, and $g$ is a non-linear function acting element-wise.
$W_1$ and $b_1$ belong to the first dense layer, while $W_2$ and $b_2$ belong to the second.
$M$ is the number of hidden units (the hidden-unit density is $\alpha = M/N$) and controls the expressivity of the ansatz.
Here we choose $g$ to be the logcosh activation function.

Use the functions in `flax.linen` and the activation function `nk.nn.activation.log_cosh`. 
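Before writing the `flax` version, it may help to see the forward pass in NumPy. The sketch below is our own: it uses a numerically stable $\log\cosh$ (since $\cosh$ is even, flip the sign so that $\mathrm{Re}\,x \geq 0$ and use $\log\cosh x = x + \log(1 + e^{-2x}) - \log 2$), and stores the kernels in the $(\text{in}, \text{out})$ convention, $W_1$ of shape $(N, M)$, so that `x @ W1` works on batches. Parameter shapes, initial scale and names are assumptions.

```python
import numpy as np

def log_cosh(x):
    """Stable log(cosh(x)) for real or complex x (our own sketch)."""
    x = np.asarray(x)
    sgn = np.where(np.real(x) < 0, -1.0, 1.0)  # cosh is even: make Re(x) >= 0
    x = x * sgn
    return x + np.log1p(np.exp(-2.0 * x)) - np.log(2.0)

def nqs_log_psi(params, x):
    """log psi(x) = g(W2 . g(W1 x + b1) + b2), g = log_cosh, for x of shape (..., N)."""
    W1, b1, W2, b2 = params
    h = log_cosh(x @ W1 + b1)    # hidden layer, shape (..., M)
    return log_cosh(h @ W2 + b2)  # one log-amplitude per sample, shape (...,)

rng = np.random.default_rng(0)
N, M = 16, 16
W1 = 0.1 * (rng.normal(size=(N, M)) + 1j * rng.normal(size=(N, M)))
b1 = 0.1 * (rng.normal(size=M) + 1j * rng.normal(size=M))
W2 = 0.1 * (rng.normal(size=M) + 1j * rng.normal(size=M))
b2 = 0.0 + 0.0j
x = rng.choice([-1.0, 1.0], size=(10, N))
out = nqs_log_psi((W1, b1, W2, b2), x)
print(out.shape)        # (10,)
print(log_cosh(1000.0))  # finite, whereas np.log(np.cosh(1000.0)) overflows
```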


In [None]:
# TODO 

class NQS(nn.Module):
    M: int = 1

    @nn.compact
    def __call__(self, x):
        dense1 = ...
        dense2 = ...

        # We apply the dense layer to the input
        y = dense1(x)

        # We apply the activation function
        y = ...

        # We apply the final dense layer
        y = dense2(y)[..., 0]

        # We apply the activation function
        out = ...

        # return one log-amplitude per sample
        return out 

model = NQS(M = N)

And now try to optimise it with SGD and with Adam (you can use the same optax-based training loop, running it with `optax.sgd(learning_rate=0.01)` and `optax.adam(learning_rate=0.01)`), and compare the two both with the Jastrow and with each other!

## Advanced optimisation: Stochastic Reconfiguration

We now optimise it again, this time adding the Stochastic Reconfiguration (SR) preconditioner (also known as Quantum Natural Gradient).
SR preconditions the gradient through the substitution:
$$
\nabla E(\theta) \longrightarrow S^{-1} \nabla E(\theta)
$$
in the SGD update scheme. 
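In terms of dense arrays, one SR step only needs the matrix $O_{mk} = \partial_k \log\Psi_\theta(\sigma^{(m)})$ (the Jacobian on the samples) and the local energies. The NumPy sketch below estimates the gradient and $S$ (the quantum geometric tensor discussed next) from them and solves a linear system instead of inverting $S$. Since the sample estimate of $S$ is often singular, we add a small diagonal shift; the name `diag_shift`, its value, and the function name `sr_update` are assumptions of this sketch.

```python
import numpy as np

def sr_update(O, E_loc, diag_shift=1e-3):
    """Sketch of one Stochastic Reconfiguration step from samples.

    O:     (n_samples, n_params) dense Jacobian d log psi / d theta_k on each sample
    E_loc: (n_samples,) local energies
    Returns delta solving (S + diag_shift * I) delta = grad, plus S and grad.
    """
    n_samples = O.shape[0]
    O_c = O - O.mean(axis=0)                          # centred log-derivatives
    E_c = E_loc - E_loc.mean()                        # centred local energies
    grad = O_c.conj().T @ E_c / n_samples             # energy-gradient estimator
    S = O_c.conj().T @ O_c / n_samples                # QGT estimator (Hermitian, positive semi-definite)
    S_reg = S + diag_shift * np.eye(S.shape[0])       # regularisation: S is often singular
    delta = np.linalg.solve(S_reg, grad)              # never invert S explicitly
    return delta, S, grad

rng = np.random.default_rng(0)
O = rng.normal(size=(200, 5)) + 1j * rng.normal(size=(200, 5))  # fake Jacobian for illustration
E_loc = rng.normal(size=200) + 1j * rng.normal(size=200)        # fake local energies
delta, S, grad = sr_update(O, E_loc)
# the parameters would then be updated as theta <- theta - eta * delta
```

Centring $O$ does not change the gradient estimator, because the centred local energies already sum to zero.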

$S$ is the so-called Quantum Geometric Tensor (QGT), and it corresponds to the metric tensor of the Fubini-Study metric in Hilbert space.
The Fubini-Study metric is the natural metric among quantum states, and the corresponding distance is defined as:
$$
d_{\mathrm{FS}}(|\psi\rangle, |\phi\rangle) = \arccos\sqrt{\frac{\langle \phi | \psi \rangle \langle \psi | \phi \rangle}{\langle \phi | \phi \rangle \langle \psi | \psi \rangle}}.
$$

Indeed, for an infinitesimal parameter change we have that:
$$
d_{\mathrm{FS}}(|\Psi_{\theta}\rangle, |\Psi_{\theta + \delta \theta}\rangle)^2 = \left[\arccos\sqrt{\frac{\langle \Psi_{\theta + \delta \theta} | \Psi_{\theta} \rangle \langle \Psi_{\theta} | \Psi_{\theta + \delta \theta} \rangle}{\langle \Psi_{\theta} | \Psi_{\theta} \rangle \langle \Psi_{\theta + \delta \theta} | \Psi_{\theta + \delta \theta} \rangle}}\,\right]^2 \approx \delta \theta^\dagger S \delta \theta.
$$

The QGT is defined as: 
\begin{equation}
\begin{split}
S_{ij} &= \frac{\langle \partial_i \Psi_{\theta} | \partial_j \Psi_{\theta} \rangle}{\langle \Psi_{\theta} | \Psi_{\theta} \rangle} - \frac{\langle \partial_i \Psi_{\theta} | \Psi_{\theta} \rangle}{\langle \Psi_{\theta} | \Psi_{\theta} \rangle} \frac{\langle \Psi_{\theta} | \partial_j \Psi_{\theta} \rangle}{\langle \Psi_{\theta} | \Psi_{\theta} \rangle} = \\ 
&= \sum_\sigma \frac{|\Psi_{\theta}(\sigma )|^2}{\langle\Psi_{\theta}|\Psi_{\theta}\rangle} (\partial_i \log\Psi_{\theta}(\sigma ) - \langle\partial_i \log\Psi_{\theta} \rangle)^* (\partial_j \log\Psi_{\theta}(\sigma) - \langle\partial_j \log\Psi_{\theta} \rangle) =  \\
&= \frac{1}{N_s} \sum_{m=1}^{N_s} \bigg(\partial_i \log\Psi_{\theta}(\sigma^{(m)}) - \frac{1}{N_s} \sum_{k=1}^{N_s} \partial_i \log\Psi_{\theta}(\sigma^{(k)})\bigg)^* \bigg(\partial_j \log\Psi_{\theta}(\sigma^{(m)}) - \frac{1}{N_s} \sum_{l=1}^{N_s} \partial_j \log\Psi_{\theta}(\sigma^{(l)})\bigg)
\end{split}
\end{equation}
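The link between the QGT and the Fubini-Study distance can be checked numerically on a tiny system. The sketch below computes $S$ exactly (Born weights over all $2^3$ configurations of 3 spins) for a hypothetical two-parameter holomorphic ansatz $\log\Psi_\theta(\sigma) = \theta_0\sum_i\sigma_i + \theta_1\sum_i\sigma_i\sigma_{i+1}$ (periodic), and compares $d_{\mathrm{FS}}^2$ between $\Psi_\theta$ and $\Psi_{\theta+\delta\theta}$ with $\delta\theta^\dagger S\,\delta\theta$ for a small complex $\delta\theta$. The ansatz and the numerical values are assumptions chosen for illustration.

```python
import itertools
import numpy as np

n = 3
x = np.array(list(itertools.product([1.0, -1.0], repeat=n)))  # all 2^n configurations
# O_k(x) = d log psi / d theta_k for the hypothetical ansatz (independent of theta here)
features = np.stack([x.sum(axis=1), (x * np.roll(x, -1, axis=1)).sum(axis=1)], axis=1)

def psi(theta):
    return np.exp(features @ theta)

def qgt_exact(theta):
    p = np.abs(psi(theta)) ** 2
    p = p / p.sum()                             # Born weights |psi|^2 / <psi|psi>
    O_c = features - p @ features               # O_k(x) - <O_k>
    return O_c.conj().T @ (p[:, None] * O_c)    # S_ij = <(O_i - <O_i>)^* (O_j - <O_j>)>

def fs_distance(a, b):
    F = np.abs(np.vdot(a, b)) ** 2 / (np.vdot(a, a).real * np.vdot(b, b).real)
    return np.arccos(np.sqrt(np.clip(F, 0.0, 1.0)))

theta = np.array([0.3 + 0.1j, -0.2 + 0.4j])
dtheta = 1e-4 * np.array([1.0 - 0.5j, 0.7 + 0.2j])
S = qgt_exact(theta)
d2 = fs_distance(psi(theta), psi(theta + dtheta)) ** 2
quad = np.real(dtheta.conj() @ S @ dtheta)
print(d2, quad)  # should agree to leading order in |dtheta|
```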

We need a function that computes the QGT. We can reuse the Jacobian $\partial_k \log\Psi_{\theta}(\sigma)$ already computed for the energy gradient,
so we redefine the function to also return it.

In [None]:
@partial(jax.jit, static_argnames='model')   # compile the function and make it faster to execute
def estimate_energy_and_gradient_and_jacobian(model, parameters, ham, σ):
    # reshape the samples, the samples are divided in different Markov chains
    σ = σ.reshape(-1, σ.shape[-1])

    # compute eloc
    E_loc = ...

    # compute jacobian
    jacobian = ...

    # take the number of samples
    n_samples = E_loc.shape[0]

    # compute the energy
    E_average = ...
    E_variance = ...
    E_error = ...
    E = nk.stats.Stats(mean=E_average, error_of_mean=E_error, variance=E_variance)  # create a Netket object containing the statistics

    # center the local energies
    E_loc -= E_average

    # compute the gradient as Ok.conj() @ E_loc / n_samples (operate on pytree with jax.tree.map) 
    E_grad = ...

    return E, E_grad, jacobian

For practical purposes, it is now easier to work with dense representations of the gradient and of the Jacobian: to convert them from a `pytree` to a `jnp.array` we can use the following function.

In [None]:
def flatten_jacobian(pytree):
    # Apply reshape operation to each leaf of the pytree
    reshaped_pytree = jax.tree.map(lambda x: jnp.reshape(x, (x.shape[0], -1)), pytree)
    
    # Convert the pytree to a flat list of arrays
    flat_list, _ = jax.tree_util.tree_flatten(reshaped_pytree)
    
    # Stack the arrays in the list along the second dimension
    result = jnp.concatenate(flat_list, axis=-1)
    
    return result

params = model.init(jax.random.key(0), hi.random_state(jax.random.key(0), 1))
sampler_state = sampler.init_state(model, params, seed=1)
sampler_state = sampler.reset(model, params, sampler_state)
σ, sampler_state = sampler.sample(model, params, state=sampler_state, chain_length=100)
E, E_grad, jacobian = estimate_energy_and_gradient_and_jacobian(model, params, H_jax, σ)

print("The flattened jacobian has shape: ", flatten_jacobian(jacobian).shape)
print("The flattened grad has shape: ", nk.jax.tree_ravel(E_grad)[0].shape)
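The dense Jacobian and the raveled gradient can only be combined if both list the parameters in the same order. Both `flatten_jacobian` and `jax.flatten_util.ravel_pytree` walk the leaves in `tree_flatten` order and flatten each leaf in row-major order, so the columns should line up. Below is a minimal self-contained check on a hypothetical parameter pytree and a random fake Jacobian; `flatten_jacobian` is repeated verbatim so the cell runs on its own, and we assume that `nk.jax.tree_ravel` follows the same leaf order as `ravel_pytree`.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def flatten_jacobian(pytree):
    # (n_samples, *leaf_shape) -> (n_samples, leaf_size), then concatenate the leaves
    reshaped = jax.tree.map(lambda x: jnp.reshape(x, (x.shape[0], -1)), pytree)
    flat_list, _ = jax.tree_util.tree_flatten(reshaped)
    return jnp.concatenate(flat_list, axis=-1)

n_toy_samples = 5
# hypothetical parameter pytree, loosely shaped like a small Flax model
toy_params = {"Dense": {"kernel": jnp.zeros((2, 3)), "bias": jnp.zeros(3)}, "visible_bias": jnp.zeros(2)}

# fake per-sample Jacobian: same tree structure, with an extra leading sample axis
keys = jax.random.split(jax.random.key(0), 3)
leaves, treedef = jax.tree_util.tree_flatten(toy_params)
jac_leaves = [jax.random.normal(k, (n_toy_samples, *x.shape)) for k, x in zip(keys, leaves)]
toy_jacobian = jax.tree_util.tree_unflatten(treedef, jac_leaves)

J = flatten_jacobian(toy_jacobian)
print(J.shape)   # (5, 11): 6 kernel + 3 bias + 2 visible_bias entries

# row m of J must coincide with the raveled m-th per-sample slice of the Jacobian
for m in range(n_toy_samples):
    row, _ = ravel_pytree(jax.tree.map(lambda x: x[m], toy_jacobian))
    assert jnp.allclose(J[m], row)
```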

In [None]:
# TODO 

def SR(E_grad, jacobian):

    # convert from pytree to dense array
    E_grad_dense, unravel = nk.jax.tree_ravel(E_grad)
    jacobian_dense = flatten_jacobian(jacobian)

    # take the number of samples
    n_samples = jacobian_dense.shape[0]

    # center the jacobians
    jacobian_centered = ...

    # compute the S matrix
    S = ...

    # condition the S matrix 
    S = S + 1e-4 * jnp.eye(S.shape[0])
    
    # solve the linear system (e.g. with jax.scipy.sparse.linalg.cg)
    E_grad_nat = ...

    return unravel(E_grad_nat)
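Why the template adds a small diagonal shift to $S$: after centering, the $N_s$ rows of the Jacobian sum to zero, so the Monte Carlo estimate of $S$ has rank at most $N_s - 1$ and is singular whenever there are more parameters than samples. A minimal sketch on synthetic random data (the sizes and the random stand-ins for the Jacobian and the gradient are our assumptions, not tutorial values) that checks the rank and then solves the shifted system with `jax.scipy.sparse.linalg.cg`, comparing against a dense solve:

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)

n_samples, n_params = 8, 20          # fewer samples than parameters
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
# synthetic stand-ins for the dense Jacobian O[m, k] and the energy gradient
O_toy = jax.random.normal(k1, (n_samples, n_params)) + 1j * jax.random.normal(k2, (n_samples, n_params))
grad_toy = jax.random.normal(k3, (n_params,)) + 0j

O_c = O_toy - O_toy.mean(axis=0, keepdims=True)    # subtract <O_k> column by column
S_toy = O_c.conj().T @ O_c / n_samples              # Monte Carlo QGT estimate, (n_params, n_params)

eigs = jnp.linalg.eigvalsh(S_toy)
n_nonzero = int(jnp.sum(eigs > 1e-10 * eigs.max()))
print(n_nonzero)                                    # n_samples - 1: S_toy is singular

S_reg = S_toy + 1e-4 * jnp.eye(n_params)            # diagonal shift
x_cg, _ = cg(S_reg, grad_toy, tol=1e-10, maxiter=1000)
x_dense = jnp.linalg.solve(S_reg, grad_toy)
rel_err = float(jnp.linalg.norm(x_cg - x_dense) / jnp.linalg.norm(x_dense))
print(rel_err)
```

Conjugate gradient only needs matrix-vector products with $S$, which is what makes it attractive once $S$ is too large to store.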

In [None]:
# TODO 

energies_jastrow = []
n_steps = ...
eta = ...

# initialize the parameters
model = Jastrow()
params = model.init(jax.random.key(0), hi.random_state(jax.random.key(0), 1))

# Initialize the sampler
sampler = nk.sampler.MetropolisLocal(
    hi, 
    n_chains=16,    # we specify 16 chains
)
sampler_state = sampler.init_state(model, params, seed=1)

# For every iteration
for i in tqdm(range(n_steps)):

    # sample
    sampler_state = sampler.reset(model, params, sampler_state)
    σ, sampler_state = sampler.sample(model, params, state=sampler_state, chain_length=100)

    # compute energy and gradient of the energy
    E, E_grad, jacobian = ...

    # do SR
    E_grad_nat = ...

    # log the energy to a list
    energies_jastrow.append(E.mean.real)

    # perform the gradient-descent update with the natural gradient on every leaf of the parameter pytree
    params = ...

In [None]:
plt.plot(energies_mf, label="MF")
plt.plot(energies_jastrow, label="Jastrow (SR)")
plt.ylabel("Energy")
plt.xlabel("Iteration")
plt.hlines(E_gs, 0, len(energies_jastrow), linestyles="--", color="black", label="ED")
plt.legend()

print(
    "The relative error of mf is: ",
    jnp.absolute((energies_mf[-1] - E_gs) / E_gs),
)

print(
    "The relative error of Jastrow is: ",
    jnp.absolute((energies_jastrow[-1] - E_gs) / E_gs),
)

### NQS with SR

Now let's compare how training performs with and without SR when optimizing a neural quantum state (NQS) such as Carleo's RBM.


In [None]:
model = NQS(M = N)

We then proceed to the optimization as before. 

In [None]:
energies_nqs = []
n_steps = 200
eta = 0.01

# initialize the parameters
params = model.init(jax.random.key(0), hi.random_state(jax.random.key(0), 1))

# Initialize the sampler
sampler = nk.sampler.MetropolisLocal(
    hi, 
    n_chains=16,    # we specify 16 chains
)
sampler_state = sampler.init_state(model, params, seed=1)

# For every iteration
for i in tqdm(range(n_steps)):

    # sample
    sampler_state = sampler.reset(model, params, sampler_state)
    σ, sampler_state = sampler.sample(model, params, state=sampler_state, chain_length=100)

    # compute energy and gradient of the energy
    E, E_grad, jacobian = estimate_energy_and_gradient_and_jacobian(model, params, H_jax, σ)

    # do SR
    E_grad_nat = SR(E_grad, jacobian)

    # log the energy to a list
    energies_nqs.append(E.mean.real)

    # perform the gradient-descent update with the natural gradient on every leaf of the parameter pytree
    params = jax.tree.map(lambda x, y: x - eta * y, params, E_grad_nat)

In [None]:
plt.plot(energies_mf, label="MF")
plt.plot(energies_jastrow, label="Jastrow (SR)")
plt.plot(energies_nqs, label="NQS (SR)")
plt.ylabel("Energy")
plt.xlabel("Iteration")
plt.hlines(E_gs, 0, len(energies_nqs), linestyles="--", color="black", label="ED")
plt.legend()

print(
    "The relative error of mf is: ",
    jnp.absolute((energies_mf[-1] - E_gs) / E_gs),
)

print(
    "The relative error of Jastrow is: ",
    jnp.absolute((energies_jastrow[-1] - E_gs) / E_gs),
)

print(
    "The relative error of NQS is: ",
    jnp.absolute((energies_nqs[-1] - E_gs) / E_gs),
)

In [None]:
plt.plot(energies_mf, label="MF")
plt.plot(energies_jastrow, label="Jastrow (SR)")
plt.plot(energies_nqs, label="NQS (SR)")
plt.ylabel("Energy")
plt.xlabel("Iteration")
plt.ylim(E_gs - 0.1, E_gs + 0.1)
plt.hlines(E_gs, 0, len(energies_nqs), linestyles="--", color="black", label="ED")
plt.legend()

### **NOTE**:

You can speed up the calculation of the gradient (and of matrix-vector products with $S$) by computing the vector-Jacobian product (`vjp`) directly, instead of first building the full Jacobian and then applying it to the vector of local energies $E_{\mathrm{loc}}$.
Indeed, to build the full Jacobian matrix, `jacrev` vectorizes the `vjp` over one cotangent vector per output, that is, one per sample, whereas the gradient only needs a single `vjp`.
For the `vjp` you can use the function `nk.jax.vjp` (a refinement of `jax.vjp`).


In [None]:
from functools import partial

# compute the E_loc and center it outside
E_loc = compute_eloc(model, params, H_jax, σ)
E_loc -= jnp.mean(E_loc)

# compute the jacobian and apply it
@partial(jax.jit, static_argnames='model')   # compile the function and make it faster to execute
def build_jacobian_and_apply(model, parameters, σ, E_loc):
    #reshape the samples 
    σ = σ.reshape(-1, σ.shape[-1])

    # compute the dk logpsi
    logpsi_fun = lambda pars : model.apply(pars, σ).conj()            # we freeze the samples so the functions depend only on the parameters
    jacobian = jax.jacrev(logpsi_fun, holomorphic=True)(parameters)   # compute the jacobian of the log of the (holomorphic) wavefunction
    n_samples = E_loc.shape[0]

    # apply the jacobian to the local energies
    E_grad = jax.tree.map(lambda x: jnp.einsum("i..., i -> ...", x.conj(), E_loc) / n_samples, jacobian)

    return E_grad

# vector-Jacobian product
@partial(jax.jit, static_argnames='model')   # compile the function and make it faster to execute
def do_vjp(model, parameters, σ, E_loc):
    #reshape the samples 
    σ = σ.reshape(-1, σ.shape[-1])

    # compute the vjp fun
    logpsi_fun = lambda pars : model.apply(pars, σ).conj()            # we freeze the samples so the functions depend only on the parameters
    _, vjp_fun = nk.jax.vjp(logpsi_fun, parameters, conjugate=True)
    n_samples = E_loc.shape[0]

    # do the vector jacobian product (vjp_fun returns a tuple with one entry per primal input)
    (E_grad,) = vjp_fun(E_loc)
    E_grad = jax.tree.map(lambda x: x / n_samples, E_grad)

    return E_grad

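# jax.block_until_ready waits for JAX's asynchronous dispatch, so that the actual computation is timed
print("Build the Jacobian and apply : ")
%timeit jax.block_until_ready(build_jacobian_and_apply(model, params, σ, E_loc))

print("vjp : ")
%timeit jax.block_until_ready(do_vjp(model, params, σ, E_loc))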
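To see why building the Jacobian and taking a single `vjp` give the same gradient, here is a self-contained sketch in plain JAX with a hypothetical holomorphic toy log-wavefunction and synthetic samples (not the tutorial's model). For holomorphic functions, `jax.vjp` applies the transpose (not the conjugate transpose) of the Jacobian, so the gradient $\frac{1}{N_s}\sum_m O_k^*(\sigma^{(m)}) E_{\mathrm{loc}}(\sigma^{(m)})$ can be obtained as the complex conjugate of one `vjp` with cotangent $E_{\mathrm{loc}}^*$, without storing the $N_s \times N_p$ Jacobian. This conjugation trick is our own choice for the sketch and does not reproduce the conventions of `nk.jax.vjp` exactly.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n_samples, n_hidden, n_sites = 64, 6, 4
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)

# synthetic spin samples and centered local energies (stand-ins for σ and E_loc)
σ_toy = jax.random.choice(k1, jnp.array([-1.0, 1.0]), (n_samples, n_sites))
E_loc_toy = jax.random.normal(k2, (n_samples,)) + 0j
E_loc_toy = E_loc_toy - E_loc_toy.mean()

# hypothetical holomorphic toy ansatz: log psi(sigma) = sum_j log cosh(W_j . sigma), complex W
W = 0.1 * (jax.random.normal(k3, (n_hidden, n_sites)) + 1j * jax.random.normal(k4, (n_hidden, n_sites)))

def log_psi(W):
    return jnp.sum(jnp.log(jnp.cosh(σ_toy @ W.T)), axis=-1)   # shape (n_samples,)

# 1) build the full Jacobian O[m, j, a] = d log psi(sigma_m) / d W[j, a], then contract
O = jax.jacrev(log_psi, holomorphic=True)(W)
grad_jac = jnp.einsum("m...,m->...", O.conj(), E_loc_toy) / n_samples

# 2) a single vjp: vjp_fun(v) = sum_m v_m O[m], so conjugate both the input and the output
_, vjp_fun = jax.vjp(log_psi, W)
(ct,) = vjp_fun(E_loc_toy.conj())
grad_vjp = ct.conj() / n_samples

print(jnp.max(jnp.abs(grad_jac - grad_vjp)))   # should be at round-off level
```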


# 3. Time evolution

The second problem we want to solve is the simulation of the unitary dynamics generated by the same Hamiltonian, but in 1D for simplicity.

In particular, we consider a quantum quench starting from the paramagnetic phase: we prepare the initial state as the ground state of the TFI Hamiltonian with $V = 0$, and then we switch on the interaction by quenching to $\{\Gamma = 1, V = 0.5\}$ (in 1D, the critical point of this model lies at $V = \Gamma$).

### 3.1 Exact benchmark 


We can compute the benchmark using a second-order Taylor scheme for the unitary propagator, that is: 
\begin{equation}
U(\delta t) = e^{-i \delta t H} = 1 - i \delta t H - \frac{\delta t^2}{2} H^2 + O(\delta t^3)
\end{equation}
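This truncated propagator is not unitary, which is why the `exact_dynamics` function renormalizes the state at every step. Multiplying it by its adjoint, all terms up to $\delta t^3$ cancel and one is left with $U^\dagger U = 1 + \frac{\delta t^4}{4}H^4$, while the one-step error with respect to $e^{-i\delta t H}$ is $O(\delta t^3)$. The sketch below checks both statements on a hypothetical random $8\times 8$ real symmetric Hamiltonian (not the TFI model), using `jax.scipy.linalg.expm` as the reference propagator.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)

# hypothetical 8x8 real symmetric "Hamiltonian" and a random normalized state
k_h, k_s = jax.random.split(jax.random.key(1))
A = jax.random.normal(k_h, (8, 8))
H_toy = (A + A.T) / 2
psi_toy = jax.random.normal(k_s, (8,)) + 0j
psi_toy = psi_toy / jnp.linalg.norm(psi_toy)

def taylor2_step(psi, dt):
    # psi - i dt H psi - dt^2/2 H^2 psi, the same update used in exact_dynamics
    Hpsi = H_toy @ psi
    return psi - 1j * dt * Hpsi - 0.5 * dt**2 * (H_toy @ Hpsi)

def local_error(dt):
    return jnp.linalg.norm(taylor2_step(psi_toy, dt) - expm(-1j * dt * H_toy) @ psi_toy)

# the one-step error is O(dt^3): halving dt should divide it by about 8
ratio = float(local_error(1e-2) / local_error(5e-3))
print(ratio)

# the norm grows: ||U psi||^2 - 1 = dt^4 <H^4> / 4, with <H^4> = ||H^2 psi||^2
dt_toy = 1e-2
drift = float(jnp.linalg.norm(taylor2_step(psi_toy, dt_toy)) ** 2 - 1)
H2psi = H_toy @ (H_toy @ psi_toy)
predicted = float(dt_toy**4 / 4 * jnp.vdot(H2psi, H2psi).real)
print(drift, predicted)
```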

In [None]:
L = 12  # length of the 1D chain
N = L   # number of spins
hi = nk.hilbert.Spin(s=1 / 2, N=N)  # create the Hilbert space
graph = nk.graph.Chain(length=L, pbc=True)  # pbc = periodic boundary conditions

In [None]:
def exact_dynamics(state, H_sp, ts, obs_sp):
    obs_vals = []
    dt = ts[1] - ts[0]

    for t in tqdm(ts): 
        state = state / jnp.linalg.norm(state)   # renormalize without modifying the input array in place
        obs_vals.append((state.conj() @ (obs_sp @ state)))

        Hstate = H_sp @ state
        
        state = state - 1j * dt * Hstate + 0.5 * (dt**2) * (H_sp @ Hstate)

    return state, obs_vals

During the time evolution, we monitor the average $x$-magnetization $\langle \sigma_x \rangle = \frac{1}{N} \sum_i \langle \sigma_x^i \rangle$ as the observable. 

In [None]:
obs = sum([sigmax(hi, i) for i in range(N)]) / N
obs_jax = obs.to_pauli_strings().to_jax_operator()
obs_sp = obs.to_sparse()

In [None]:
# We prepare the initial state as the ground state of H with V = 0
Gamma = 1.
H0 = sum([- Gamma * sigmax(hi, i) for i in range(N)])
H0_jax = H0.to_pauli_strings().to_jax_operator()
H0_sp = H0.to_sparse()
eig_vals, eig_vecs = eigsh(
    H0_sp, k=2, which="SA"
)
GS = eig_vecs[:, 0]
E_gs = eig_vals[0]

In [None]:
# We define the Hamiltonian for the quench with {Gamma = 1, V = 0.5} and the time parameters
V = 0.5
H = sum([- Gamma * sigmax(hi, i) for i in range(N)]) + sum([- V * sigmaz(hi, i) * sigmaz(hi, j) for (i, j) in graph.edges()])
H_sp = H.to_sparse()
H_jax = H.to_pauli_strings().to_jax_operator()
ts = jnp.linspace(0, 1.0, 1001)
dt = ts[1] - ts[0]
print("dt = ", dt)   

In [None]:
state, obs_exact = exact_dynamics(GS, H_sp, ts, obs_sp)

### 3.2 time-dependent Variational Monte Carlo (t-VMC)

NetKet simulates the dynamics of a quantum system with the time-dependent Variational Monte Carlo (t-VMC) algorithm.

t-VMC projects the time-dependent Schrödinger equation onto the variational manifold of the ansatz by exploiting a variational principle for dynamics (McLachlan's variational principle, the Dirac-Frenkel principle, or the time-dependent variational principle). 

This leads to a system of ordinary differential equations for the variational parameters whose solution describes the unitary dynamics on the manifold: 
\begin{equation}
S(\theta_t) \dot{\theta}_t = F(\theta_t), 
\end{equation}
where $S(\theta_t)$ is the Quantum Geometric Tensor (as in SR) and $F(\theta_t)$ are the forces, defined as $F(\theta_t) = -i \nabla E(\theta_t)$ with $\nabla E(\theta_t)$ the same energy gradient used in SR. 
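As a sanity check of this equation of motion, consider a toy case of our own choosing in which the ansatz can follow the exact dynamics: a product state $\log\Psi_\theta(\sigma) = \sum_i \theta_i \sigma_i$ evolving under the diagonal Hamiltonian $H = h\sum_i \sigma^z_i$. Since $\Psi(\sigma, t) = \Psi(\sigma, 0)\, e^{-i h t \sum_i \sigma_i}$, the exact parameters evolve as $\dot\theta_i = -ih$. The sketch enumerates all configurations, builds $S$ and $F = -i\nabla E$ exactly, assuming (as in the SR part) $\nabla E = \langle \Delta O^* \Delta E_{\mathrm{loc}}\rangle$, and solves $S\dot\theta = F$.

```python
import itertools
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n_spins, h = 3, 0.7
configs = jnp.array(list(itertools.product([-1.0, 1.0], repeat=n_spins)))   # all 2^n_spins configurations
theta0 = jnp.array([0.2 + 0.1j, -0.4 + 0.3j, 0.5 - 0.2j])

def thetadot_exact(theta):
    log_psi = configs @ theta                  # product state: log psi(sigma) = sum_i theta_i sigma_i
    p = jnp.exp(2 * log_psi.real)
    p = p / p.sum()                            # Born probabilities |psi|^2 / <psi|psi>
    O = configs.astype(jnp.complex128)         # O_k(sigma) = d log psi / d theta_k = sigma_k
    E_loc = h * configs.sum(axis=1)            # H is diagonal: E_loc(sigma) = h sum_i sigma_i
    dO = O - p @ O
    dE = E_loc - p @ E_loc
    S = (dO.conj().T * p) @ dO                 # quantum geometric tensor
    grad = (dO.conj().T * p) @ dE              # energy gradient <dO^* dE_loc>
    F = -1j * grad                             # forces
    return jnp.linalg.solve(S, F)

print(thetadot_exact(theta0))   # every component should equal -1j * h
```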

Once the system has been solved for $\dot{\theta}_t$ at time $t$, the parameters are propagated from $\theta_t$ to $\theta_{t+\delta t}$ with any numerical integration scheme (Euler, Heun, Runge-Kutta, ...). 

We first prepare the ground state of the Hamiltonian with $V=0$. As ansatz we take the RBM (an example of a neural network with one hidden layer) from NetKet's standard models, `nk.models.RBM`, with hidden-unit density `alpha = 1`. 

In [None]:
# take the model
model = nk.models.RBM(alpha=1, param_dtype=complex)

energies_nqs = []
n_steps = 200
eta = 0.01

# initialize the parameters
params = model.init(jax.random.key(0), hi.random_state(jax.random.key(0), 1))

# Initialize the sampler
sampler = nk.sampler.MetropolisLocal(
    hi, 
    n_chains=16,    # we specify 16 chains
)
sampler_state = sampler.init_state(model, params, seed=1)

# For every iteration
for i in tqdm(range(n_steps)):

    # sample
    sampler_state = sampler.reset(model, params, sampler_state)
    σ, sampler_state = sampler.sample(model, params, state=sampler_state, chain_length=50)

    # compute energy and gradient of the energy
    E, E_grad, jacobian = estimate_energy_and_gradient_and_jacobian(model, params, H0_jax, σ)

    # do SR
    E_grad_nat = SR(E_grad, jacobian)

    # log the energy to a list
    energies_nqs.append(E.mean.real)

    # perform the gradient-descent update with the natural gradient on every leaf of the parameter pytree
    params = jax.tree.map(lambda x, y: x - eta * y, params, E_grad_nat)

params_v0 = params

print(
    "The relative error of NQS is: ",
    jnp.absolute((energies_nqs[-1] - E_gs) / E_gs),
)

We also need a function that estimates the expectation value of the observable with Monte Carlo sampling (the same as before for the energy, but without the gradient). 

In [None]:
@partial(jax.jit, static_argnames='model')   # compile the function and make it faster to execute
def estimate_observable(model, parameters, obs, σ):
    # reshape the samples, the samples are divided in different Markov chains
    σ = σ.reshape(-1, σ.shape[-1])

    O_loc = compute_eloc(model, parameters, obs, σ)

    n_samples = O_loc.shape[0]

    # compute the statistics of the observable
    O_average = jnp.mean(O_loc)
    O_variance = jnp.var(O_loc)
    O_error = jnp.std(O_loc) / jnp.sqrt(n_samples)
    O = nk.stats.Stats(mean=O_average, error_of_mean=O_error, variance=O_variance)  # create an object containing the statistics

    return O

We now propagate the state by solving the tVMC equations.

In [None]:
# TODO 

@partial(jax.jit, static_argnames='model')   # compile the function and make it faster to execute
def compute_thetadot(model, params, H_jax, σ):
    # compute grad and jacobians
    E, E_grad, jacobian = ...

    # convert from pytree to dense array
    jacobian_dense = flatten_jacobian(jacobian)

    # take the number of samples
    n_samples = jacobian_dense.shape[0]

    # center the jacobians
    jacobian_centered = ...

    # compute the S matrix
    S = ...

    # convert from pytree to dense array
    F, unravel = nk.jax.tree_ravel(jax.tree.map(lambda x: -1j * x, E_grad))

    # solve the linear system
    thetadot, _ = ...         # nk.optimizer.solver.pinv_smooth

    return unravel(thetadot)

In [None]:
@partial(jax.jit, static_argnames='model')   # compile the function and make it faster to execute
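def RK4_scheme(model, params, H_jax, σ):
    # one classic 4th-order Runge-Kutta step of size dt (the global time step defined above,
    # captured as a constant when jax.jit traces the function).
    # All four stages reuse the samples σ drawn from the state at `params`: this is an approximation,
    # since the intermediate stages should in principle be estimated with samples from their own state.
    thetadot_1 = compute_thetadot(model, params, H_jax, σ)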
    params_1 = jax.tree.map(lambda x, y: x + (dt/2) * y, params, thetadot_1)

    thetadot_2 = compute_thetadot(model, params_1, H_jax, σ)
    params_2 = jax.tree.map(lambda x, y: x + (dt/2) * y, params, thetadot_2)

    thetadot_3 = compute_thetadot(model, params_2, H_jax, σ)
    params_3 = jax.tree.map(lambda x, y: x + dt * y, params, thetadot_3)
    
    thetadot_4 = compute_thetadot(model, params_3, H_jax, σ)

    params = jax.tree.map(lambda a, x, y, w, z: a + (dt/6) * (x + 2*y + 2*w + z), params, thetadot_1, thetadot_2, thetadot_3, thetadot_4)   

    return params
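The same Butcher tableau can be checked in isolation. Below, a hypothetical generic `rk4_step` (the same update as `RK4_scheme`, but with a user-supplied right-hand side `f` in place of `compute_thetadot`) integrates the toy linear ODE $\dot\theta = -i\omega\theta$, with the parameters stored in a pytree. For a fourth-order method, halving the time step should reduce the global error by roughly $2^4 = 16$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def rk4_step(f, params, dt):
    # classic RK4 on an arbitrary pytree of parameters; f maps params -> d params / dt
    add = lambda p, k, c: jax.tree.map(lambda x, y: x + c * y, p, k)
    k1 = f(params)
    k2 = f(add(params, k1, dt / 2))
    k3 = f(add(params, k2, dt / 2))
    k4 = f(add(params, k3, dt))
    return jax.tree.map(lambda a, x, y, w, z: a + (dt / 6) * (x + 2 * y + 2 * w + z), params, k1, k2, k3, k4)

omega = 2.0
f = lambda p: jax.tree.map(lambda x: -1j * omega * x, p)    # toy right-hand side
params0 = {"a": jnp.array([1.0 + 0j, 0.5j]), "b": jnp.array(0.3 + 0.1j)}

def global_error(dt, T=1.0):
    p = params0
    for _ in range(round(T / dt)):
        p = rk4_step(f, p, dt)
    exact = jax.tree.map(lambda x: x * jnp.exp(-1j * omega * T), params0)
    return max(float(jnp.max(jnp.abs(x - y))) for x, y in zip(jax.tree.leaves(p), jax.tree.leaves(exact)))

ratio = global_error(0.1) / global_error(0.05)
print(ratio)   # close to 16 for a fourth-order method
```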

In [None]:
observable_nqs = []

# initialize the parameters
params = params_v0

# Initialize the sampler
sampler = nk.sampler.MetropolisLocal(
    hi, 
    n_chains=16,    # we specify 16 chains
)
sampler_state = sampler.init_state(model, params, seed=1)

# For every time step
for t in tqdm(ts):

    # sample from the current variational state
    sampler_state = sampler.reset(model, params, sampler_state)
    σ, sampler_state = sampler.sample(model, params, state=sampler_state, chain_length=50)

    # log the observable at time t, before the update (as in exact_dynamics)
    observable_nqs.append(estimate_observable(model, params, obs_jax, σ))

    # update the parameters with a numerical integration scheme (here RK4)
    params = RK4_scheme(model, params, H_jax, σ)

In [None]:
plt.plot(ts, jnp.array(obs_exact).real, linestyle='--', color='black', label="ED")
plt.plot(ts, [x.mean.real for x in observable_nqs], label="tVMC")
plt.xlabel(r"$t$")
plt.ylabel(r"$\langle \sigma_x(t) \rangle$")
plt.legend()