# Bridge tutorial

This tutorial demonstrates bridge on a small example.

To keep the presentation simple, the basis states come from a first-order (explicit Euler) discretization of the dynamics. The only source of error in the basis states is therefore *discretization error*, so we expect bridge to improve substantially on them.

To simulate *optimization error*, we can add noise to the basis states. As this noise increases, the improvement of bridge over the basis states shrinks.

We note two details:
- On a small example, we must be careful not to end up in a trivial case. If the model has many symmetries, the dimension of the effective subspace explored by the dynamics can become smaller than the number of basis states, and bridge then becomes close to exact. To avoid this uninteresting case, we use open boundary conditions for the Ising model, which reduces the number of symmetries.
- The reported optimal infidelity within the subspace may occasionally be larger than the infidelity actually achieved by bridge, even though it should be a lower bound. This comes from numerical instability in the exact calculation of the optimal infidelity within the subspace, which requires inverting the Gram matrix of the basis states; that matrix can be ill-conditioned.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import qutip as qt

from tqdm.auto import tqdm

import netket as nk
import netket_pro as nkp

import bridge

In [None]:
# Parameters for sampling the Rayleigh matrix
n_samples = 3000
sweep_size = 40
n_chains = 20
n_discard_per_chain = 200

# Amplitude of the noise added to the basis states (emulates optimization error);
# set to 0 to keep only the discretization error.
base_states_noise = 1e-5

# Times
# base_times: times of the basis states (one Euler step between consecutive entries).
# bridge_times: times at which bridge reconstructs the state.
# exact_times: union of both grids, so every base/bridge time has an exact reference.
base_times = np.linspace(0, 2, 20)
bridge_times = np.linspace(0, 2, 50)
# np.unique already returns sorted values; np.concatenate (unlike the np.concat
# alias) also works on NumPy < 2.0.
exact_times = np.unique(np.concatenate((base_times, bridge_times)))

# Hamiltonian parameters
J = 0.2
h = 1.

# Graph parameters
grid_width = 3
grid_height = 3
n_qubits = grid_width * grid_height
# QuTiP dims of a ket on n_qubits qubits: [[2, ..., 2], [1, ..., 1]]
dims_vec = [n_qubits*[2], n_qubits*[1]]

# Random seed
rng_jax = jax.random.PRNGKey(0)

# Hilbert space
# Open boundary conditions (pbc=False) reduce the number of symmetries of the model.
graph = nk.graph.Grid([grid_width, grid_height], pbc=False)
hilbert_space = nk.hilbert.Spin(0.5, n_qubits, inverted_ordering=True)

# Operators
# Transverse-field Ising Hamiltonian: H = -J sum_<ij> Z_i Z_j - h sum_i X_i
hamiltonian = sum(-J * nk.operator.spin.sigmaz(hilbert_space, i) * nk.operator.spin.sigmaz(hilbert_space, j) for i, j in graph.edges())
hamiltonian += sum(-h * nk.operator.spin.sigmax(hilbert_space, i) for i in graph.nodes())
hamiltonian = hamiltonian.to_jax_operator()
# Average transverse magnetization (1/N) sum_k X_k
total_x = sum(nk.operator.spin.sigmax(hilbert_space, k, dtype=complex).to_pauli_strings().to_jax_operator() for k in range(n_qubits)) * (1 / n_qubits)

# Dense QuTiP copies of the operators, used for the exact reference dynamics
hamiltonian_qt = hamiltonian.to_qobj()
total_x_qt = total_x.to_qobj()

### Exact dynamics

In [None]:
# Initial state: uniform superposition of all computational-basis states, normalized.
psi_0 = qt.Qobj(np.ones(2**n_qubits), dims=dims_vec)
psi_0 = psi_0 / psi_0.norm()

# Tight tolerances so the reference dynamics is effectively exact; states are stored
# so that infidelities can be computed later.
options = dict(atol=1e-8, rtol=1e-8, store_states=True)

exact_dynamics_res = qt.sesolve(
    hamiltonian_qt,
    psi_0,
    exact_times,
    e_ops=[total_x_qt],
    options=options,
)
exact_dynamics_states = exact_dynamics_res.states
exact_dynamics_observables = exact_dynamics_res.expect[0]

# Dense state vectors, one row per entry of exact_times.
exact_dynamics_states_np = np.stack([ket.full().ravel() for ket in exact_dynamics_states])

### Generation of the basis states.

The basis states are constructed with $ | \psi_{k+1} \rangle = (1 - i H \delta) | \psi_k \rangle $, an explicit Euler step of size $\delta$. We then optionally add noise to these states, with amplitude `base_states_noise`.

In [None]:
# Generate basis states with explicit (first-order) Euler steps:
# |psi_{k+1}> = (1 - i H delta) |psi_k>
# The step is taken as a difference so it stays correct if base_times does not start at 0.
delta = base_times[1] - base_times[0]
m_states = len(base_times)

psi = psi_0
base_states_qt = [psi_0]
for _ in range(m_states - 1):
    # Rebind explicitly instead of using `+=`, so no state already stored in
    # base_states_qt can ever be modified in place.
    psi = psi + (-1.j * delta) * hamiltonian_qt * psi
    base_states_qt.append(psi)
base_states_np = np.array([state.full()[:, 0] for state in base_states_qt])

# Optional noise emulating optimization error. A seeded generator keeps the run
# reproducible (the global np.random state is never seeded in this notebook).
noise_rng = np.random.default_rng(0)
base_states_np += base_states_noise * noise_rng.random(base_states_np.shape)

sampler = nk.sampler.MetropolisLocal(hilbert_space)
model = nk.models.LogStateVector(hilbert_space)
base_states = [nk.vqs.MCState(sampler, model) for _ in range(m_states)]
for state, amplitudes in zip(base_states, base_states_np):
    # The model's "logstate" parameters are set to the log of the (noisy) basis vector.
    state.variables = {"params": {"logstate": jnp.log(amplitudes)}}

### Bridge

In [None]:
# Bridge: sample the Rayleigh matrix of the basis states and obtain, for each time in
# bridge_times, the coefficients of the linear combination of basis states.
sampling_rule = nk.sampler.rules.LocalRule()
with nk.utils.timing.timed_scope(force=True) as timer:
    bridge_coefficients, rayleigh_matrix_estimate, info_dict = bridge.bridge(
        base_states,
        hamiltonian,
        bridge_times,
        n_samples,
        sweep_size,
        n_chains,
        n_discard_per_chain,
        sampling_rule=sampling_rule,
        chunk_size=1,
        decimal_precision_solver=None,
    )

# Build one variational state per row of coefficients (i.e. per bridge time)
bridge_vstates = [
    bridge.construct_linear_state(base_states, alpha, cls=nk.vqs.MCState, n_samples=2**15)
    for alpha in bridge_coefficients
]

### Computing the infidelity errors

In [None]:
def _match_time_indices(times, reference_times, atol=1e-14):
    """Return, for each entry of ``times``, the index of the equal entry in ``reference_times``.

    Raises ValueError if some time matches (within ``atol``) zero or several reference
    times, which would otherwise silently misalign states and exact references.
    """
    hits = np.abs(times[:, np.newaxis] - reference_times[np.newaxis, :]) < atol
    if not np.all(hits.sum(axis=1) == 1):
        raise ValueError("Every time must match exactly one reference time.")
    # np.nonzero is row-major, so column indices come out in the order of ``times``.
    return np.nonzero(hits)[1]


def _expectation_values(states, operator_sparse):
    """Return Re(<psi|A|psi> / <psi|psi>) for each row psi of ``states``."""
    numerator = np.einsum('kn,nk->k', states.conj(), operator_sparse @ states.T)
    return np.real(numerator / np.linalg.norm(states, axis=1)**2)


# Matching the times
base_indices = _match_time_indices(base_times, exact_times)
bridge_indices = _match_time_indices(bridge_times, exact_times)

total_x_sparse = total_x.to_sparse()

# Infidelity of the basis states
base_states_infidelities = np.real(1 - bridge.numpy_tools.fidelity(base_states_np, exact_dynamics_states_np[base_indices]))

# Observable of the basis states
base_states_observables_exact = _expectation_values(base_states_np, total_x_sparse)

# Infidelity of the bridge states
# Row k of bridge_states_full is sum_i bridge_coefficients[k, i] * base_states_np[i].
bridge_states_full = np.asarray(bridge_coefficients) @ base_states_np
bridge_states_infidelity = np.real(1 - bridge.numpy_tools.fidelity(bridge_states_full, exact_dynamics_states_np[bridge_indices]))

# Observable of the bridge states
bridge_states_observables_exact = _expectation_values(bridge_states_full, total_x_sparse)

# Optimal infidelity that can be obtained within the subspace spanned by the basis states
optimal_infidelity = bridge.distance_to_subspace(exact_dynamics_states_np[bridge_indices], base_states_np, decimal_precision=100)

### Computing observables of linear combination states

There are two ways of estimating observables:

- Either we construct the linear combination state $\sum_k \alpha_k | \phi_k \rangle$ using `bridge.construct_linear_state`. Then we only have to call the method `expect` as usual.
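- Or we can use the following formula for a linear combination of states:
$$ \langle A \rangle = \frac{\alpha^\dagger G^{(A)} \alpha}{\alpha^\dagger G \alpha}, $$
where $G_{ij} = \langle \phi_i | \phi_j \rangle$ is the Gram matrix of the basis states and $G^{(A)}_{ij} = \langle \phi_i | A | \phi_j \rangle$. We then compute $G$ and $G^{(A)}$ using `bridge.estimate_projected_operators_sum_of_states`.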

In [None]:
# First method: sample <X> directly from each linear-combination variational state
bridge_observables_1_stats = [vstate.expect(total_x) for vstate in tqdm(bridge_vstates)]
bridge_observables_1 = [np.real(stats.mean.item()) for stats in bridge_observables_1_stats]
bridge_observables_1_eom = [stats.error_of_mean.item() for stats in bridge_observables_1_stats]

In [None]:
# Second method: estimate the Gram matrix G and the projected operator G^(A) of the basis
# states once, then evaluate <A> at every bridge time as a ratio of quadratic forms in
# the coefficients.
# Reuse `sampling_rule` (the rule passed to bridge.bridge) so both estimates sample the
# same way. Timing uses the same NetKet helper as the bridge cell instead of the
# IPython-only %time magic.
with nk.utils.timing.timed_scope(force=True):
    gram_matrix, projected_total_x = bridge.estimate_projected_operators_sum_of_states(base_states, [total_x], sampling_rule=sampling_rule, n_samples=2**17, n_discard_per_chain=5, chunk_size=1)

numerator = np.einsum('ki,ij,kj->k', bridge_coefficients.conj(), projected_total_x, bridge_coefficients)
denominator = np.einsum('ki,ij,kj->k', bridge_coefficients.conj(), gram_matrix, bridge_coefficients)
bridge_observables_2 = np.real(numerator / denominator)

### Plot

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(16,16))

# Top panel: infidelity with respect to the exact dynamics (log scale)
ax[0].plot(base_times, base_states_infidelities, ls='--', marker='o', c='m', label='Basis states')
ax[0].plot(bridge_times, bridge_states_infidelity, ls='-', marker='o', label='Bridge')
ax[0].plot(bridge_times, optimal_infidelity, ls=':', marker='o', label='Optimal subspace states')

ax[0].set_yscale('log')
ax[0].set_xlabel('Time')
ax[0].set_ylabel('Infidelity error')
ax[0].grid()
ax[0].legend()

# Bottom panel: average transverse magnetization (1/N) sum_k <X_k>.
# These are the observable values themselves, not errors.
ax[1].plot(exact_times, exact_dynamics_observables, ls='-', c='k', label='Exact')
ax[1].plot(base_times, base_states_observables_exact, ls='--', label='Basis states (without sampling)')
ax[1].plot(bridge_times, bridge_states_observables_exact, ls='--', label='Bridge (without sampling)')
ax[1].errorbar(bridge_times, np.real(bridge_observables_1), yerr=bridge_observables_1_eom, ls='--', marker='o', label='Bridge with method 1')
ax[1].plot(bridge_times, bridge_observables_2, ls='--', label='Bridge with method 2')

ax[1].set_xlabel('Time')
ax[1].set_ylabel(r'Magnetization $\langle \sum_k \sigma^x_k \rangle / N$')
ax[1].grid()
ax[1].legend()