# Toulouse School on Machine Learning in Quantum Many-Body Physics

## Tutorial: Dynamics with neural quantum states

Damian Hofmann

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PhilipVinc/Lectures/blob/main/2204_Toulouse-jax-netket/3-nqs-dynamics.ipynb) 

In this short tutorial session, we will follow up on the previous sessions and lectures on neural quantum states and demonstrate how to compute quantum dynamics using time-dependent variational Monte Carlo in NetKet.

## 0. Setup

To run this notebook, please install the following packages:
```
jax==0.3.4
jaxlib==0.3.2
numpy==1.21.5
netket==3.4.0
matplotlib==3.5.1
```
 
You can run this notebook in Google Colab during the class using the link above.

To install the packages in Colab you need to run the cell below. (If you are running in your own environment, you do not need to run that cell.)

In [None]:
!pip install jax==0.3.4 jaxlib==0.3.2 numpy==1.21.5 netket==3.4.0 matplotlib==3.5.1

## 1. System

As in the previous tutorial, we take the transverse-field Ising model as an example:
$$
    \hat H = \sum_{\langle ij \rangle} \hat\sigma^z_i\hat\sigma^z_j - h \sum_i \hat\sigma^x_i,
$$
where $\langle ij \rangle$ runs over nearest-neighbor pairs of the lattice.
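For the small chains used in this tutorial, the Hamiltonian above can also be written as a dense $2^L \times 2^L$ matrix. The following NumPy sketch builds it from Kronecker products, assuming the unit coupling written above and periodic boundary conditions; the helper names `site_op` and `ising_dense` are hypothetical and not part of NetKet.

```python
import numpy as np

# Pauli matrices and the single-site identity
SX = np.array([[0.0, 1.0], [1.0, 0.0]])
SZ = np.array([[1.0, 0.0], [0.0, -1.0]])
ID = np.eye(2)


def site_op(op, i, L):
    """Embed the single-site operator `op` at site i of an L-site chain."""
    out = op if i == 0 else ID
    for k in range(1, L):
        out = np.kron(out, op if k == i else ID)
    return out


def ising_dense(L, h, J=1.0):
    """Dense J * sum_<ij> Z_i Z_j - h * sum_i X_i on a periodic chain (L >= 3)."""
    H = np.zeros((2**L, 2**L))
    for i in range(L):
        j = (i + 1) % L  # periodic boundary conditions
        H += J * site_op(SZ, i, L) @ site_op(SZ, j, L)
        H -= h * site_op(SX, i, L)
    return H


H = ising_dense(L=4, h=1.0)
print(H.shape)  # (16, 16)
print(np.linalg.eigvalsh(H)[0])  # ground-state energy
```

A quick sanity check: for $h = 0$ the ground state of an even-length ring is the Néel state, where every bond contributes $-1$, so the lowest eigenvalue is $-L$.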

In [None]:
# import some modules
import netket as nk
import flax.linen as nn

import jax
import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt
from tqdm import tqdm

from functools import partial

rng = nk.jax.PRNGSeq(123)

First, let's set up the model and a simple variational ansatz. For demonstration purposes, we use a very small spin chain with an Ising Hamiltonian:

In [None]:
L = 8
hilbert = nk.hilbert.Spin(1/2, N=L)

In [None]:
print(hilbert)

In [None]:
lat = nk.graph.Chain(length=L, pbc=True)
lat.draw()

Define two Ising Hamiltonians (to perform quenches later) as well as an observable, the magnetization along x.

In [None]:
ham = nk.operator.Ising(hilbert, lat, h=1.0)
ham

In [None]:
ham1 = nk.operator.Ising(hilbert, lat, h=0.5)
ham1

In [None]:
mag_x = sum(nk.operator.spin.sigmax(hilbert, i) for i in range(lat.n_nodes))
mag_x

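We use the Jastrow ansatz from yesterday as a first example:
$$ \langle \sigma^{z}_1,\dots,\sigma^{z}_N| \Psi_{\mathrm{jas}} \rangle = \exp \left( \sum_i \left[ J_1 \sigma^{z}_i\sigma^{z}_{i+1} + J_2 \sigma^{z}_i\sigma^{z}_{i+2} \right] \right).$$
As usual in NetKet, the Flax module below returns the logarithm of this amplitude, $\ln \Psi_{\mathrm{jas}}(\sigma)$.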

In [None]:
class JasShort(nn.Module):
    @nn.compact
    def __call__(self, x):
        
        # Define the two variational parameters J1 and J2
        j1 = self.param(
            "j1", nn.initializers.normal(), (1,), complex
        )
        j2 = self.param(
            "j2", nn.initializers.normal(), (1,), complex
        )

        # compute nearest-neighbor (corr1) and next-nearest-neighbor (corr2)
        # correlations on the periodic chain
        corr1 = x * jnp.roll(x, -1, axis=-1)
        corr2 = x * jnp.roll(x, -2, axis=-1)

        # return the log-amplitude ln psi(x), summed over the chain
        return jnp.sum(j1 * corr1 + j2 * corr2, axis=-1)

In [None]:
# Create MC state from ansatz
sampler = nk.sampler.MetropolisLocal(hilbert, n_chains=32)
ansatz = JasShort()
vstate = nk.vqs.MCState(sampler, ansatz,
                        n_samples=16000,
                        sampler_seed=rng.next(), seed=rng.next())

## 2. Time dependent variational Monte Carlo

You have heard in the lecture this morning how to do time evolution on a variational ansatz.
Otherwise, helpful references for the derivation of the TDVP equations of motion are, e.g., Yuan et al. (Quantum 3, 191, 2019), and Stokes et al. (arXiv:2203.14824).

We assume complex parameters $\theta$ with a holomorphic mapping $\theta \mapsto \psi_\theta$.
In order to evolve the variational ansatz $(\theta, s) \mapsto \psi_\theta(s)$ in time, we can locally maximize the fidelity
$$
    \max_{\delta\theta} |\langle \mathrm{e}^{-\gamma \hat H \delta t} \psi_{\theta} | \psi_{\theta + \delta\theta} \rangle|^2.
$$
Taylor expanding this condition to second order in $\delta\theta$ and $\delta t$ yields, after some algebra, the equation of motion
$$
G(\theta) \, \dot\theta = -\gamma F(\theta, t)
$$
with the quantum geometric tensor
$$
    G_{ij}(\theta) = \frac{
       \langle\partial_i\psi_\theta | \partial_j\psi_\theta\rangle
    }{
       \langle \psi_\theta | \psi_\theta \rangle
    } - \frac{
       \langle\partial_i\psi_\theta | \psi_\theta \rangle\langle \psi_\theta | \partial_j\psi_\theta\rangle
    }{
       \langle \psi_\theta | \psi_\theta \rangle^2
    }
$$
and the gradient
$$
    F_i(\theta, t) = \frac{\partial\langle \hat H \rangle}{\partial\theta_i^*}.
$$
Here $\gamma = 1$ results in imaginary-time evolution, while $\gamma = \mathrm i$ gives real-time evolution.

$G$ and $F$ can be estimated by Monte Carlo sampling: given samples $s \sim |\psi_\theta(s)|^2$, we can estimate them as
$$
    G_{ij} = \operatorname{cov}(o_i, o_j)
    \qquad
    F_i = \operatorname{cov}(o_i, h),
$$
where $\operatorname{cov}(a, b) = \mathbb{E}[a^* b] - \mathbb{E}[a^*]\,\mathbb{E}[b]$ is the covariance over the samples, with the local energy
$$
h(s) = \frac{\langle s | \hat H | \psi_\theta \rangle}{\langle s | \psi_\theta \rangle}
    = \sum_{s'} \frac{\psi_\theta(s')}{\psi_\theta(s)} \langle s | \hat H | s' \rangle
$$
and "quantum score function"
$$
    o_j(s) = \frac{\partial\ln\psi_\theta(s)}{\partial\theta_j}.
$$
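To make the estimators concrete, here is a minimal NumPy sketch of the local energy and of the covariance formulas above. It assumes a dense Hamiltonian matrix and a wave function stored as a full amplitude vector, which is only feasible for tiny systems; the score functions are random stand-ins rather than derivatives of an actual ansatz, and all function names are hypothetical.

```python
import numpy as np


def local_energy(H, psi):
    """h(s) = sum_{s'} H[s, s'] psi(s') / psi(s) for a dense Hamiltonian matrix H."""
    return (H @ psi) / psi


def qgt_and_force(O, h, weights=None):
    """Estimate G_ij = cov(o_i, o_j) and F_i = cov(o_i, h).

    O has shape (n_samples, n_params) with O[s, j] = o_j(s); h has shape (n_samples,).
    With weights=None, rows are treated as Monte Carlo samples s ~ |psi(s)|^2.
    Passing the normalized Born probabilities of all configurations instead
    gives the exact (full-summation) values.
    """
    n = O.shape[0]
    w = np.full(n, 1.0 / n) if weights is None else weights
    dO = O - w @ O  # center: o_j - E[o_j]
    dh = h - w @ h
    G = dO.conj().T @ (w[:, None] * dO)  # cov(a, b) = E[a^* b] - E[a^*] E[b]
    F = dO.conj().T @ (w * dh)
    return G, F


# Toy data (hypothetical): random Hermitian H and random amplitudes psi on 8 states
rng = np.random.default_rng(0)
A = rng.normal(size=(8, 8)) + 1j * rng.normal(size=(8, 8))
H = (A + A.conj().T) / 2
psi = rng.normal(size=8) + 1j * rng.normal(size=8)
p = np.abs(psi) ** 2 / np.vdot(psi, psi).real  # Born probabilities

h = local_energy(H, psi)
energy = np.vdot(psi, H @ psi) / np.vdot(psi, psi)
print(np.allclose(p @ h, energy))  # True: <H> is the |psi|^2-average of h(s)

# Random stand-ins for the score functions o_j(s) of a 3-parameter ansatz
O = rng.normal(size=(8, 3)) + 1j * rng.normal(size=(8, 3))
G_exact, F_exact = qgt_and_force(O, h, weights=p)

# Monte Carlo estimate from samples s ~ |psi(s)|^2
samples = rng.choice(8, size=100_000, p=p)
G_mc, F_mc = qgt_and_force(O[samples], h[samples])
```

The same function covers both cases: uniform weights over samples drawn from $|\psi_\theta|^2$ give the Monte Carlo estimate, while exact Born weights over all configurations give the reference value.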

In NetKet, the quantum geometric tensor $G(\theta)$ is available from the variational state class:

In [None]:
vstate.quantum_geometric_tensor()

In [None]:
vstate.quantum_geometric_tensor().to_dense()

### 2.1 DIY time stepping loop

Let us build a very simple ODE solver based on the Euler method where, at each time step, we update our state as
$$
    \theta(t + \delta t) = \theta(t) + \dot\theta \delta t.
$$
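As an illustration of the Euler update, the sketch below applies it directly to the state vector of a single spin, $\dot\psi = -\gamma \hat H \psi$ with $\hat H = -\hat\sigma^x$, rather than to the parameters $\theta$ as in t-VMC. This toy setup and the function names are assumptions for illustration. It shows the first-order accuracy of the method: the error shrinks roughly linearly with $\delta t$, and for $\gamma = \mathrm i$ the norm slowly grows, because each Euler step multiplies every energy eigencomponent by $1 - \mathrm i E\,\delta t$, whose modulus exceeds one.

```python
import numpy as np

# Toy problem (assumption for illustration): one spin with H = -sigma^x
H = -np.array([[0.0, 1.0], [1.0, 0.0]], dtype=complex)
psi0 = np.array([1.0, 0.0], dtype=complex)  # spin up along z


def euler_evolve(H, psi0, dt, t_end, gamma=1j):
    """Integrate d psi / dt = -gamma H psi with fixed-step explicit Euler."""
    psi = psi0.copy()
    for _ in range(int(round(t_end / dt))):
        psi = psi + dt * (-gamma * (H @ psi))
    return psi


def exact_evolve(H, psi0, t, gamma=1j):
    """Exact solution psi(t) = exp(-gamma H t) psi0, via diagonalization of H."""
    E, V = np.linalg.eigh(H)
    return V @ (np.exp(-gamma * E * t) * (V.conj().T @ psi0))


for dt in [0.1, 0.01, 0.001]:
    psi = euler_evolve(H, psi0, dt, t_end=1.0)
    err = np.linalg.norm(psi - exact_evolve(H, psi0, 1.0))
    print(f"dt={dt}: error={err:.1e}, norm={np.linalg.norm(psi):.5f}")
```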
As an interactive task, I will now give you 10 to 15 minutes to implement a solver loop that, given a Hamiltonian, a `vstate`, an initial time `t0`, a (fixed) time step `dt`, an end time `t_end`, and the factor `gamma` from above, performs t-VMC time propagation.

Some hints:
 * `vstate.expect_and_grad` gives you the expectation value and gradient of an operator.
 * We have seen `vstate.quantum_geometric_tensor` above.
 * A standard method for solving a linear (least-squares) equation system is `jnp.linalg.lstsq`. (There are others, which you can also use.)
 * Remember that `vstate.parameters` is a PyTree, and that you need to assign the updated parameters back to it at the end of each step.
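Regarding the least-squares hint: the QGT is often singular or close to singular, in which case a plain linear solve fails while a least-squares solver returns the minimum-norm solution. A small NumPy illustration with a hypothetical $3 \times 3$ QGT whose third parameter is a pure gauge direction (`np.linalg.lstsq` uses `rcond` in the same sense as `jnp.linalg.lstsq`):

```python
import numpy as np

# Hypothetical QGT: the third parameter is a pure gauge direction,
# so its row and column vanish and G is singular.
G = np.array(
    [[2.0, 0.5, 0.0],
     [0.5, 1.0, 0.0],
     [0.0, 0.0, 0.0]],
    dtype=complex,
)
F = np.array([1.0, -1.0, 0.0], dtype=complex)
gamma = 1j  # real-time evolution

try:
    np.linalg.solve(G, -gamma * F)
except np.linalg.LinAlgError as err:
    print("np.linalg.solve fails:", err)

# Minimum-norm least-squares solution; singular values smaller than
# rcond * (largest singular value) are treated as zero.
theta_dot, residuals, rank, sv = np.linalg.lstsq(G, -gamma * F, rcond=1e-14)
print(rank)       # 2
print(theta_dot)  # the gauge component theta_dot[2] is (numerically) zero
```

Choosing the minimum-norm solution means no update is applied along directions that do not change the physical state.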

In [None]:
# version 1, using netket.optimizer.solver.svd
def time_propagation(hamiltonian, vstate, t0, dt, t_end, gamma=1.0j):
    t = t0
    while t < t_end:
        # get energy and gradient
        E, F = vstate.expect_and_grad(hamiltonian)
        # get the QGT object from the variational state
        G = vstate.quantum_geometric_tensor()
        # multiply F by the factor -gamma
        F = jax.tree_map(lambda f: -gamma * f, F)
        # use G.solve and the SVD solver nk.optimizer.solver.svd
        dtheta, _ = G.solve(nk.optimizer.solver.svd, F)
        # apply update theta += dt * dtheta
        vstate.parameters = jax.tree_map(
            lambda x, y: x + dt * y, vstate.parameters, dtheta
        )
        t = t + dt
        yield t, vstate.expect(hamiltonian), vstate.expect(mag_x)
        
# version 2, using jnp.linalg.lstsq (which requires flattening the parameter
# pytree into a vector and repacking the solution into a pytree)
def time_propagation2(hamiltonian, vstate, t0, dt, t_end, gamma=1.0j):
    t = t0
    while t < t_end:
        # get energy and gradient
        E, F = vstate.expect_and_grad(hamiltonian)
        
        # convert G and F to arrays to pass to lstsq
        G = vstate.quantum_geometric_tensor().to_dense()
        # convert F to a vector; the second return value
        # is a function that can convert vectors back to the
        # pytree structure of F (which is the same structure as
        # the params)
        F, unravel_params = nk.jax.tree_ravel(F)
        F *= -gamma
        
        # lstsq returns dtheta and some other stuff, which we ignore by
        # assigning them to `*_`
        # rcond cuts off singular values of G smaller than rcond times the
        # largest singular value when solving the equation
        dtheta, *_ = jnp.linalg.lstsq(G, F, rcond=1e-14)
        # convert back to a pytree
        dtheta = unravel_params(dtheta)
        
        vstate.parameters = jax.tree_map(
            lambda x, y: x + dt * y, vstate.parameters, dtheta)
        t = t + dt
        yield t, vstate.expect(hamiltonian), vstate.expect(mag_x)

In [None]:
# We want to test it on our example system like this (gamma=1 gives us imaginary-time propagation):
times = []
energies = []
mag = []
with tqdm(time_propagation2(ham, vstate, t0=0, dt=0.01, t_end=4, gamma=1)) as progress:
    # we make use of time, energy, and magnetization being returned from the solver loop
    for t, E, mx in progress:
        times.append(t)
        energies.append(E)
        mag.append(mx)
        progress.set_postfix(t=t)

In [None]:
# Plot the results

In [None]:
plt.plot(times, [e.mean.real for e in energies])

In [None]:
plt.plot(times, [m.mean.real for m in mag])

Save the parameters of the approximate ground state we have found:

In [None]:
params0 = jax.tree_map(np.copy, vstate.parameters)

Let's try some real-time propagation, starting from the approximate state we found just now:

In [None]:
vstate.parameters = jax.tree_map(np.copy, params0)

times = []
energies = []
mag = []
with tqdm(time_propagation(ham1, vstate, t0=0, dt=0.001, t_end=1.0, gamma=1j)) as progress:
    for t, E, mx in progress:
        times.append(t)
        energies.append(E)
        mag.append(mx)
        
        progress.set_postfix(t=t, E=E)

In [None]:
plt.plot(times, [e.mean.real for e in energies])

In [None]:
plt.plot(times, [m.mean.real for m in mag])

### 2.2 NetKet TDVP driver

NetKet provides a `TDVP` driver (in `netket.experimental`) that performs time propagation based on the same ideas we have used above, but includes many features beyond that, in particular adaptive and fixed-step Runge-Kutta integrators of various orders.
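Higher-order integrators reduce the time-step error at the price of more evaluations of the right-hand side per step (in t-VMC, each evaluation means estimating $G$ and $F$ and solving the linear system). The sketch below compares the explicit Euler method with Heun's method, a second-order Runge-Kutta scheme, on a single-spin toy problem that is an assumption for illustration; it only shows the convergence order and does not reproduce NetKet's integrator implementation. Halving $\delta t$ should roughly halve the Euler error and quarter the Heun error, so the printed ratios should be close to 2 and 4.

```python
import numpy as np

# Toy problem (assumption): one spin with H = -sigma^x, real-time evolution
H = -np.array([[0.0, 1.0], [1.0, 0.0]], dtype=complex)
psi0 = np.array([1.0, 0.0], dtype=complex)


def rhs(psi):
    """Right-hand side of d psi / dt = -i H psi."""
    return -1j * (H @ psi)


def euler_step(y, dt):
    """First order: one right-hand-side evaluation per step."""
    return y + dt * rhs(y)


def heun_step(y, dt):
    """Second order (Heun): two right-hand-side evaluations per step."""
    k1 = rhs(y)
    k2 = rhs(y + dt * k1)
    return y + 0.5 * dt * (k1 + k2)


def integrate(step, dt, t_end=1.0):
    y = psi0.copy()
    for _ in range(int(round(t_end / dt))):
        y = step(y, dt)
    return y


# exact reference psi(1) = exp(-i H) psi0
E, V = np.linalg.eigh(H)
exact = V @ (np.exp(-1j * E) * (V.conj().T @ psi0))

for name, step in [("Euler", euler_step), ("Heun", heun_step)]:
    err_coarse = np.linalg.norm(integrate(step, 0.02) - exact)
    err_fine = np.linalg.norm(integrate(step, 0.01) - exact)
    print(f"{name}: error ratio when halving dt = {err_coarse / err_fine:.2f}")
```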

In [None]:
import netket.experimental as nkx

In [None]:
vstate.parameters = jax.tree_map(np.copy, params0)

integrator = nkx.dynamics.Euler(dt=0.001)

driver = nkx.TDVP(
    ham1,
    vstate,
    integrator,
    linear_solver=nk.optimizer.solver.svd,
    qgt=nk.optimizer.qgt.QGTJacobianDense(holomorphic=True),
)

log = nk.logging.RuntimeLog()
driver.run(T=1.0, obs={"mx": mag_x}, out=log)

In [None]:
plt.plot(log["Generator"]["iters"], log["Generator"]["Mean"].real)

In [None]:
plt.plot(log["mx"]["iters"], log["mx"]["Mean"].real)

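### 2.3 Check the result for a small system

For $L = 8$ the Hilbert space has only $2^8 = 256$ states, so we can compare the t-VMC result with the exact dynamics computed by QuTiP: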

In [None]:
%pip install qutip

In [None]:
import qutip

In [None]:
hamQ = ham1.to_qobj()
mag_xQ = mag_x.to_qobj()

vstate.parameters = jax.tree_map(np.copy, params0)
psiQ = vstate.to_qobj()

In [None]:
# tlist must start at t = 0, since psiQ is the state at tlist[0]
result = qutip.sesolve(hamQ, psiQ, tlist=np.linspace(0, 1.0, 101), e_ops=[hamQ, mag_xQ])
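`qutip.sesolve` integrates the Schrödinger equation numerically. For a system this small, the same reference curves can also be obtained by diagonalizing $\hat H$ once and propagating in its eigenbasis. A minimal NumPy sketch, assuming dense matrices (for NetKet objects these could come from, e.g., `ham1.to_dense()`, `mag_x.to_dense()` and `vstate.to_array()`); here it is exercised on a random Hermitian matrix, and the function name is hypothetical:

```python
import numpy as np


def exact_expectations(H, psi0, O, times):
    """<psi(t)|O|psi(t)> for psi(t) = exp(-i H t) psi0, via one diagonalization of H."""
    E, V = np.linalg.eigh(H)
    c0 = V.conj().T @ (psi0 / np.linalg.norm(psi0))  # normalized, in the eigenbasis
    values = []
    for t in times:
        psi_t = V @ (np.exp(-1j * E * t) * c0)
        values.append(np.vdot(psi_t, O @ psi_t).real)
    return np.array(values)


# Hypothetical stand-ins: random Hermitian H and a diagonal observable on 16 states
rng = np.random.default_rng(1)
A = rng.normal(size=(16, 16)) + 1j * rng.normal(size=(16, 16))
H = (A + A.conj().T) / 2
O = np.diag(rng.normal(size=16))
psi0 = rng.normal(size=16) + 1j * rng.normal(size=16)

times = np.linspace(0.0, 1.0, 101)
energy = exact_expectations(H, psi0, H, times)  # should stay constant
obs = exact_expectations(H, psi0, O, times)
```

Energy conservation under the quench Hamiltonian is a cheap consistency check for any real-time reference.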

In [None]:
plt.plot(log["mx"]["iters"], log["mx"]["Mean"].real)
plt.plot(result.times, result.expect[1], "k--")

It seems clear that the two-parameter Jastrow ansatz we have used is not up to the task of representing the Ising quench dynamics.

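So, let's use an actual neural quantum state: a restricted Boltzmann machine (RBM) with hidden-unit density $\alpha = 1$ and complex parameters. As before, we first prepare an approximate ground state of `ham` by imaginary-time evolution (`propagation_type="imag"`) and then quench to `ham1`: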

In [None]:
ansatz_nqs = nk.models.RBM(alpha=1, dtype=complex)
sampler_nqs = nk.sampler.MetropolisLocal(hilbert, n_chains=32)
vstate_nqs = nk.vqs.MCState(sampler_nqs, ansatz_nqs,
                        n_samples=1024,
                        sampler_seed=rng.next(), seed=rng.next())

In [None]:
integrator = nkx.dynamics.Heun(dt=0.01)

driver = nkx.TDVP(
    ham,
    vstate_nqs,
    integrator,
    linear_solver=nk.optimizer.solver.svd,
    qgt=nk.optimizer.qgt.QGTJacobianDense(holomorphic=True),
    propagation_type="imag",
)

log = nk.logging.RuntimeLog()
driver.run(T=5.0, obs={"mx": mag_x}, out=log)

(Since this takes a moment, I have saved the ground state locally so that it can be reloaded. You don't need to do this, and you can skip the loading cell below if you have not saved the file.)

In [None]:
import flax

In [None]:
# with open("NQS_Dyn_RBM1.mpack", "wb") as fp:
#     fp.write(flax.serialization.to_bytes(vstate_nqs.variables))

In [None]:
import os
if os.path.exists("NQS_Dyn_RBM1.mpack"):  # skip if no saved state is available
    with open("NQS_Dyn_RBM1.mpack", "rb") as fp:
        vstate_nqs.variables = flax.serialization.from_bytes(vstate_nqs.variables, fp.read())

In [None]:
hamQ = ham1.to_qobj()
mag_xQ = mag_x.to_qobj()
psiQ = vstate_nqs.to_qobj()
result = qutip.sesolve(hamQ, psiQ, tlist=np.linspace(0, 1.0, 100), e_ops=[hamQ, mag_xQ])

In [None]:
integrator = nkx.dynamics.Heun(dt=0.005)
vstate_nqs.n_samples=16000
driver = nkx.TDVP(
    ham1,
    vstate_nqs,
    integrator,
    linear_solver=nk.optimizer.solver.svd,
    qgt=nk.optimizer.qgt.QGTJacobianDense(holomorphic=True),
)

log = nk.logging.RuntimeLog()
driver.run(T=1.0, obs={"mx": mag_x}, out=log)

In [None]:
plt.plot(log["mx"]["iters"], log["mx"]["Mean"].real)
plt.plot(result.times, result.expect[1], "k--")

### 2.4 Quantum geometric tensor

As its name suggests, the quantum geometric tensor $G(\theta)$ has a geometric meaning: it acts as a metric on the space of variational parameters. This metric accounts for the fact that different directions in parameter space change the quantum state by different amounts.
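The metric interpretation can be made precise: to second order in $\delta\theta$, the infidelity $1 - |\langle\psi_\theta|\psi_{\theta+\delta\theta}\rangle|^2 / (\langle\psi_\theta|\psi_\theta\rangle\langle\psi_{\theta+\delta\theta}|\psi_{\theta+\delta\theta}\rangle)$ equals $\delta\theta^\dagger G(\theta)\,\delta\theta$. The NumPy sketch below checks this numerically for the two-parameter Jastrow ansatz on a periodic chain of $L = 4$ sites, computing $G$ exactly by summing over all $2^L$ configurations; the parameter values, the step direction, and the helper names are assumptions chosen for illustration.

```python
import itertools

import numpy as np

L = 4
# all 2^L spin configurations with entries +1 / -1
configs = np.array(list(itertools.product([1.0, -1.0], repeat=L)))
# nearest- and next-nearest-neighbor correlations on the periodic chain
c1 = np.sum(configs * np.roll(configs, -1, axis=1), axis=1)
c2 = np.sum(configs * np.roll(configs, -2, axis=1), axis=1)
# score functions o_j(s) = d ln psi / d theta_j for theta = (J1, J2)
O = np.stack([c1, c2], axis=1).astype(complex)


def psi(theta):
    """Unnormalized Jastrow amplitudes psi_theta(s) = exp(J1 c1(s) + J2 c2(s))."""
    return np.exp(O @ theta)


def exact_qgt(theta):
    """G_ij = E[o_i^* o_j] - E[o_i^*] E[o_j], averaged exactly over |psi(s)|^2."""
    p = np.abs(psi(theta)) ** 2
    p /= p.sum()
    dO = O - p @ O
    return dO.conj().T @ (p[:, None] * dO)


def infidelity(theta, dtheta):
    """1 - |<psi|psi'>|^2 / (<psi|psi> <psi'|psi'>) with psi' = psi_(theta + dtheta)."""
    a, b = psi(theta), psi(theta + dtheta)
    overlap = np.abs(np.vdot(a, b)) ** 2
    return 1.0 - overlap / (np.vdot(a, a).real * np.vdot(b, b).real)


theta = np.array([0.1 + 0.2j, -0.05 + 0.1j])
direction = np.array([1.0 - 0.5j, 0.3 + 0.8j])
G = exact_qgt(theta)

ratios = []
for eps in [1e-2, 1e-3, 1e-4]:
    dtheta = eps * direction
    quadratic = np.vdot(dtheta, G @ dtheta).real  # dtheta^dagger G dtheta
    ratios.append(infidelity(theta, dtheta) / quadratic)
print(ratios)  # approaches 1 as eps -> 0
```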

The most extreme case is a parameter corresponding to a pure gauge freedom:

In [None]:
class JasShortExtra(nn.Module):
    @nn.compact
    def __call__(self, x):
        j1 = self.param(
            "j1", nn.initializers.normal(), (1,), complex
        )
        j2 = self.param(
            "j2", nn.initializers.normal(), (1,), complex
        )
        extra = self.param(
            "extra", nn.initializers.normal(), (1,), complex
        )

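        # compute nearest-neighbor (corr1) and next-nearest-neighbor (corr2)
        # correlations on the periodic chain
        corr1 = x * jnp.roll(x, -1, axis=-1)
        corr2 = x * jnp.roll(x, -2, axis=-1)

        # return the log-amplitude; `extra` adds a constant to ln psi
        return jnp.sum(j1 * corr1 + j2 * corr2, axis=-1) + extra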

Since the model returns $\ln\psi$, our wave function is now
$$ \langle \sigma^{z}_1,\dots,\sigma^{z}_N| \Psi_{\mathrm{jas}} \rangle = \mathrm{e}^{\mathtt{extra}} \cdot \exp \left( \sum_i \left[ J_1 \sigma^{z}_i\sigma^{z}_{i+1} + J_2 \sigma^{z}_i\sigma^{z}_{i+2} \right] \right),$$
so `extra` only changes the norm and the global phase of the quantum state.
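Summing exactly over all $2^L$ configurations of a short chain makes the resulting zero mode explicit. The NumPy sketch below assumes $L = 4$ and arbitrary parameter values (it is not the Monte Carlo estimate that NetKet computes): since $\ln\psi$ depends on `extra` only through an additive constant, $o_{\mathtt{extra}}(s) = 1$ for every $s$, its centered score function vanishes, and so do the corresponding row and column of $G$.

```python
import itertools

import numpy as np

L = 4
configs = np.array(list(itertools.product([1.0, -1.0], repeat=L)))
c1 = np.sum(configs * np.roll(configs, -1, axis=1), axis=1)
c2 = np.sum(configs * np.roll(configs, -2, axis=1), axis=1)
# ln psi = j1*c1 + j2*c2 + extra, so the score functions are
# o_j1 = c1, o_j2 = c2 and o_extra = 1 for every configuration
O = np.stack([c1, c2, np.ones_like(c1)], axis=1).astype(complex)

theta = np.array([0.1 + 0.2j, -0.05 + 0.1j, 0.3 - 0.7j])  # (j1, j2, extra)
log_psi = O @ theta
p = np.exp(2 * log_psi.real)  # |psi(s)|^2
p /= p.sum()                  # Born probabilities

dO = O - p @ O  # centered score functions; the `extra` column is (numerically) zero
G = dO.conj().T @ (p[:, None] * dO)

print(np.round(G, 10))
print(np.linalg.eigvalsh(G))  # one eigenvalue is (numerically) zero
```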

Let's see how this affects the QGT:

In [None]:
jastrow = JasShortExtra()
vs = nk.vqs.MCState(sampler, jastrow)
G = vs.quantum_geometric_tensor().to_dense()

In [None]:
G

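The gauge freedom creates a subspace that is annihilated by the QGT. Since $o_{\mathtt{extra}}(s) = \partial \ln\psi_\theta(s) / \partial\,\mathtt{extra} = 1$ for every configuration, its covariance with every $o_j$ vanishes, so the QGT has a zero row and column for `extra`. Because the gauge freedom lies exactly along a single parameter, it is immediately visible in the QGT matrix. In general, such redundant directions show up as (near-)zero eigenvalues in the QGT's spectrum: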

In [None]:
jnp.linalg.eigvalsh(G)

For a neural quantum state, there is no single parameter that only changes gauge degrees of freedom. The QGT is fully dense:

In [None]:
vstate_nqs.init_parameters()
G = vstate_nqs.quantum_geometric_tensor().to_dense()

In [None]:
G

In [None]:
# the shift lifts every eigenvalue by 0.01, keeping (near-)zero ones visible on the log plot
spectrum = jnp.linalg.eigvalsh(G + 0.01 * np.eye(G.shape[0]))

In [None]:
plt.plot(spectrum)
plt.semilogy()

This is typical for an NQS (especially for networks with more hidden units than this shallow $\alpha = 1$ RBM): the eigenvalues of the QGT span several orders of magnitude, which makes the solution of the t-VMC equation sensitive to noise.

This can make regularization necessary, e.g. a diagonal shift, a spectral cutoff for the QGT, or more advanced methods (references were given in Giuseppe's lecture).
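Both regularizations named above can be sketched in a few lines of NumPy. The example uses a hypothetical ill-conditioned QGT with eigenvalues between $1$ and $10^{-8}$ and a right-hand side with small components along the badly conditioned directions, standing in for Monte Carlo noise. The unregularized solution is dominated by those components, while a diagonal shift or a spectral cutoff keeps the update bounded. The function names and parameter values are assumptions; in NetKet, a diagonal shift can for instance be passed to the QGT constructors through their `diag_shift` argument.

```python
import numpy as np


def solve_diag_shift(G, rhs, shift):
    """Solve (G + shift * I) x = rhs (diagonal-shift regularization)."""
    return np.linalg.solve(G + shift * np.eye(G.shape[0]), rhs)


def solve_spectral_cutoff(G, rhs, rcond):
    """Pseudo-inverse solve that discards eigenvalues below rcond * (largest eigenvalue)."""
    w, V = np.linalg.eigh(G)
    keep = w > rcond * w.max()
    w_inv = np.zeros_like(w)
    w_inv[keep] = 1.0 / w[keep]
    return V @ (w_inv * (V.conj().T @ rhs))


# Hypothetical ill-conditioned QGT: G = Q diag(eigs) Q^dagger with a random unitary Q
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4)))
eigs = np.array([1.0, 1e-2, 1e-5, 1e-8])
G = (Q * eigs) @ Q.conj().T
# small "noise" components along the two badly conditioned eigenvectors
rhs = Q @ np.array([1.0, 1.0, 1e-3, 1e-3])

x_plain = np.linalg.solve(G, rhs)
x_shift = solve_diag_shift(G, rhs, shift=1e-3)
x_cut = solve_spectral_cutoff(G, rhs, rcond=1e-4)
for name, x in [("plain", x_plain), ("shift", x_shift), ("cutoff", x_cut)]:
    print(f"{name}: |x| = {np.linalg.norm(x):.3g}")
```

The diagonal shift damps all directions smoothly (and slightly biases the well-conditioned ones), while the cutoff removes the small-eigenvalue directions entirely; which trade-off works better depends on the problem.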