## Bridge tutorial with real p-tVMC states

We simulate the dynamics of the transverse-field Ising Hamiltonian $H = -J \sum_{\langle i, j \rangle} \sigma^z_i \sigma^z_j - h \sum_i \sigma^x_i$ with $(J, h) = (0.1, 1)$ on a $4 \times 4$ lattice with periodic boundary conditions. The example states have been optimized with p-tVMC to approximate the dynamics of the state $| + \ldots + \rangle$ with a discrete time step $\delta = 0.05$. We are now going to apply bridge to these states and see how much it improves the precision of the dynamics, in less than five minutes.

In [None]:
import glob

import numpy as np
import matplotlib.pyplot as plt
import jax
import qutip as qt

from tqdm.auto import tqdm

import netket as nk
import nqxpack

import bridge

In [None]:
# --- Monte Carlo sampling settings ---
n_samples = 3000
sweep_size = 40
n_chains = 20
n_discard_per_chain = 200

# --- Time grid: 96 points from t = 0 to t = 4.75 (spacing 0.05) ---
times = np.linspace(0, 4.75, 96)

# --- Hamiltonian couplings ---
J = 0.1  # strength of the sigma^z sigma^z nearest-neighbour term
h = 1.0  # strength of the transverse sigma^x field

# --- Lattice geometry ---
grid_width = 4
grid_height = 4
n_qubits = grid_width * grid_height
# QuTiP ket dimensions for n_qubits two-level systems: [[2, ..., 2], [1, ..., 1]]
dims_vec = [[2] * n_qubits, [1] * n_qubits]

# --- Random seed ---
rng_jax = jax.random.PRNGKey(0)

# --- Lattice graph and Hilbert space ---
graph = nk.graph.Grid([grid_width, grid_height], pbc=True)
# NOTE(review): inverted_ordering=True presumably aligns the basis ordering with QuTiP's; confirm.
hilbert_space = nk.hilbert.Spin(0.5, n_qubits, inverted_ordering=True)

# --- Operators ---
_spin = nk.operator.spin

# H = -J * sum_<ij> Z_i Z_j  -  h * sum_i X_i, converted to a JAX operator
hamiltonian = sum(
    -J * _spin.sigmaz(hilbert_space, i) * _spin.sigmaz(hilbert_space, j)
    for i, j in graph.edges()
)
hamiltonian += sum(-h * _spin.sigmax(hilbert_space, i) for i in graph.nodes())
hamiltonian = hamiltonian.to_jax_operator()

# Average x-magnetization (1 / N) * sum_k X_k, as a complex Pauli-string JAX operator
total_x = sum(
    _spin.sigmax(hilbert_space, k, dtype=complex).to_pauli_strings().to_jax_operator()
    for k in range(n_qubits)
) * (1 / n_qubits)

In [None]:
# Load the p-tVMC states. Later cells compare state_list[k] with the exact state at times[k],
# so the files must be loaded in time order.
import re


def _natural_sort_key(path):
    """Split *path* into text and integer chunks so that 'state_2' sorts before 'state_10'."""
    # re.split with a capturing group alternates non-digit / digit chunks,
    # so odd positions are always digit runs and can safely be compared as integers.
    return [int(chunk) if idx % 2 else chunk for idx, chunk in enumerate(re.split(r"(\d+)", path))]


# Natural sort: identical to plain sorting for zero-padded names, and keeps time order
# when the numbers in the file names are not zero-padded.
state_files = sorted(glob.glob("example_states/state*"), key=_natural_sort_key)
state_list = [nqxpack.load(file_name)['state'] for file_name in state_files]

m_states = len(state_list)

if m_states == 0:
    # FileNotFoundError is still an Exception subclass, so existing broad handlers keep working.
    raise FileNotFoundError("Failed to load any state")
print(f"Loaded {m_states} states")

In [None]:
# Exact reference dynamics computed with QuTiP.
hamiltonian_qt = hamiltonian.to_qobj()
total_x_qt = total_x.to_qobj()

# Initial state |+...+>: every one of the 2**n_qubits amplitudes is equal, then normalized.
psi_0 = qt.Qobj(np.ones(2**n_qubits), dims=dims_vec)
psi_0 /= psi_0.norm()

# Tight solver tolerances; store_states=True because the states are read back below.
options = dict(atol=1e-8, rtol=1e-8, store_states=True, progress_bar='tqdm')

exact_dynamics_res = qt.sesolve(hamiltonian_qt, psi_0, times, e_ops=[total_x_qt], options=options)
exact_dynamics_states = exact_dynamics_res.states
# One dense row per time step; .full() gives a column, [:, 0] flattens it.
exact_dynamics_states_np = np.array([ket.full()[:, 0] for ket in exact_dynamics_states])
exact_dynamics_observables = exact_dynamics_res.expect[0]

In [None]:
# Run bridge on the loaded p-tVMC states, timing the whole call.
sampling_rule = nk.sampler.rules.LocalRule()
with nk.utils.timing.timed_scope(force=True) as timer:
    bridge_states, rayleigh_matrix_estimate, info_dict = bridge.bridge(
        state_list,
        hamiltonian,
        times,
        n_samples,
        sweep_size,
        n_chains,
        n_discard_per_chain,
        sampling_rule=sampling_rule,
        chunk_size=1,
        decimal_precision_solver=None,
    )

# Print every diagnostic entry returned by bridge.
for name, value in info_dict.items():
    print(name, value)

In [None]:
def _normalized_expectation(states, operator_matrix):
    """Return <psi|O|psi> / <psi|psi> for each row psi of *states*.

    *operator_matrix* is any matrix supporting ``@`` with a dense 2-D array (e.g. a scipy sparse
    matrix). The result is complex; its imaginary part is numerical noise for a Hermitian O.
    """
    return (
        np.einsum('kn,nk->k', states.conj(), operator_matrix @ states.T)
        / np.linalg.norm(states, axis=1) ** 2
    )


# Sparse matrix of total_x, built once and reused for both sets of states
# (previously it was rebuilt for each observable computation).
total_x_sparse = total_x.to_sparse()

# Infidelity of the basis states
base_states_np = np.array([state.to_array(normalize=False) for state in state_list])
base_states_infidelities = np.real(1 - bridge.numpy_tools.fidelity(base_states_np, exact_dynamics_states_np))

# Observable of the basis states
base_states_observables_exact = _normalized_expectation(base_states_np, total_x_sparse)

# Dense bridge states: row k is sum_i bridge_states[k, i] * base_states_np[i]
# (same as the former np.inner(bridge_states, base_states_np.T), written as a matrix product).
bridge_states_full = np.asarray(bridge_states) @ base_states_np

# Infidelity of the bridge states
bridge_states_infidelity = np.real(1 - bridge.numpy_tools.fidelity(exact_dynamics_states_np, bridge_states_full))

# Observable of the bridge states
bridge_states_observables_exact = _normalized_expectation(bridge_states_full, total_x_sparse)

# Optimal infidelity that can be obtained within the subspace spanned by the basis states
optimal_infidelity = bridge.distance_to_subspace(exact_dynamics_states_np, base_states_np, decimal_precision=100)

There are two ways of estimating observables:

- Either we construct the linear combination state $\sum_k \alpha_k | \phi_k \rangle$ using `bridge.construct_linear_state`, and then simply call its `expect` method as usual.
- Or we rely on the following formula for the expectation value in a linear combination of states:
$$ \langle A \rangle = \frac{\alpha^\dagger G^{(A)} \alpha}{\alpha^\dagger G \alpha}, $$
where $G_{ij} = \langle \phi_i | \phi_j \rangle$ is the Gram matrix of the basis states and $G^{(A)}_{ij} = \langle \phi_i | A | \phi_j \rangle$. Both $G$ and $G^{(A)}$ can be estimated with `bridge.estimate_projected_operators_sum_of_states`.

In [None]:
# First method: build each linear-combination state explicitly and sample <total_x> from it.
sampler = nk.sampler.MetropolisLocal(hilbert_space)

bridge_observables_1 = []      # Monte Carlo mean of <total_x>, one entry per row of bridge_states
bridge_observables_1_eom = []  # matching error of the mean

for alpha in tqdm(bridge_states):
    combined_state = bridge.construct_linear_state(state_list, alpha, cls=nk.vqs.MCState, sampler=sampler)
    stats = combined_state.expect(total_x)
    bridge_observables_1.append(stats.mean.item())
    bridge_observables_1_eom.append(stats.error_of_mean.item())

In [None]:
# Second method
%time gram_matrix, projected_total_x = bridge.estimate_projected_operators_sum_of_states(state_list, [total_x], sampling_rule=sampler.rule, n_samples=2**17, n_chains=2**13, n_discard_per_chain=5, chunk_size=1)

numerator = np.einsum('ki,ij,kj->k', bridge_states.conj(), projected_total_x, bridge_states)
denominator = np.einsum('ki,ij,kj->k', bridge_states.conj(), gram_matrix, bridge_states)
bridge_observables_2 = numerator / denominator

In [None]:
# Plot: top panel = infidelity w.r.t. the exact dynamics (log scale),
# bottom panel = average x-magnetization <total_x> over time.
fig, ax = plt.subplots(2, 1, figsize=(16, 16))

ax[0].plot(times, base_states_infidelities, ls='--', marker='o', c='m', label='Basis states')
ax[0].plot(times, bridge_states_infidelity, ls='-', marker='o', label='Bridge')
ax[0].plot(times, optimal_infidelity, ls=':', marker='o', label='Optimal subspace states')

ax[0].set_yscale('log')
ax[0].set_xlabel('Time')
ax[0].set_ylabel('Infidelity error')
ax[0].grid()
ax[0].legend()


# The expectation values are computed as complex numbers. total_x is Hermitian, so their
# imaginary parts are numerical noise. Plot only the real part, as already done for method 1,
# instead of letting matplotlib discard the imaginary part with a ComplexWarning.
ax[1].plot(times, np.real(exact_dynamics_observables), ls='-', c='k', label='Exact')
ax[1].plot(times, np.real(base_states_observables_exact), ls='--', label='Basis states (without sampling)')
ax[1].plot(times, np.real(bridge_states_observables_exact), ls='--', label='Bridge (without sampling)')
ax[1].errorbar(times, np.real(bridge_observables_1), yerr=bridge_observables_1_eom, ls='--', marker='o', label='Bridge with method 1')
ax[1].plot(times, np.real(bridge_observables_2), ls='--', label='Bridge with method 2')

ax[1].set_xlabel('Time')
# This panel shows magnetization values, not errors.
ax[1].set_ylabel('Magnetization')
ax[1].grid()
ax[1].legend()

plt.show()