In [1]:
from multiprocessing import Pool
from time import time

import ipywidgets as widgets
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pennylane as qml
import scipy
from IPython.display import clear_output, display
from ipywidgets import FloatSlider, interact
from pennylane.devices.qubit import apply_operation, create_initial_state
from scipy.sparse import coo_matrix, spmatrix

In [77]:
def tfim_hamiltonian(n, J=1.0, h=1.0):
    coeffs = []
    ops = []

    # ZZ interaction terms
    for i in range(n - 1):
        coeffs.append(J)
        ops.append(qml.PauliZ(i) @ qml.PauliZ(i + 1))

    # Transverse field terms (X terms)
    # for i in range(n):
    #     coeffs.append(h)
    #     ops.append(qml.PauliX(i))

    return qml.Hamiltonian(coeffs, ops)


n_qubits = 12
J = 2.0
h = 0.5
H = tfim_hamiltonian(n_qubits, J=J, h=h)

## Exact

In [78]:
start = time()
t = np.pi / 2
U = jax.scipy.linalg.expm(-1j * t * qml.matrix(H))

# evolve equal superposition state
psi_0 = np.ones(2**n_qubits)/np.sqrt(2**n_qubits)
sv_exact = U @ psi_0

print(f"Time evolution duration: {time() - start:.4f} sec")
sv_exact.shape

Time evolution duration: 0.3685 sec


(4096,)

## Trotter Sampling

In [104]:
n_shots = 3000
key = jax.random.PRNGKey(0)
dev = qml.device("default.qubit", wires=n_qubits, shots=n_shots, seed=key)

n_trot = 1


def trot_circ(t, J, h):
    for step in range(n_trot):

        # ZZ interaction terms
        for i in range(n_qubits - 1):
            theta = 2 * t * J / n_trot
            qml.PauliRot(theta, "ZZ", [i, i + 1])

        # Transverse field terms (X terms)
        for i in range(n_qubits):
            theta = 2 * t * h / n_trot
            qml.PauliRot(theta, "X", [i])


@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(params):
    trot_circ(*params)
    return qml.sample()


samples = circuit((t, J, h))
samples.shape

(3000, 12)

In [105]:
# Need to sort the binary rows and remove duplicates
# for large integers, maybe try numpy lexsort

decimal_values = jnp.array([int("".join(map(str, row)), 2) for row in samples])
_, indices = jnp.unique(decimal_values, return_index=True)
samples = samples[indices]
samples.shape

(2164, 12)

In [106]:
# directly creating dense operator
H_mat = qml.matrix(H)
H_mat.shape

(4096, 4096)

In [107]:
decimal_values = jnp.array([int("".join(map(str, row)), 2) for row in samples])

idxs = np.ix_(decimal_values, decimal_values)
H_proj = H_mat[idxs]
H_proj.shape

(2164, 2164)

In [108]:
# another way to get projected hamiltonian
# use a projector matrix

P = np.zeros((2**n_qubits, samples.shape[0]))
for i, val in enumerate(decimal_values):
    P[val, i] = 1

H_proj2 = P.T @ H_mat @ P
np.allclose(H_proj, H_proj2)

True

In [109]:
# initial state projection
# if initial state in original space was |00...0> then
# it requires that |00...0> is present in the subspace
# else we get a vector will all zeros. 

# we start with superposition state

# project it into subspace
psi_0_sub = P.T @ psi_0
psi_0_sub /= jnp.linalg.norm(psi_0_sub)

# Time-evolved state in subspace
U_sub = jax.scipy.linalg.expm(-1j * t * H_proj) 
sv_sub = U_sub @ psi_0_sub  

# Map back to full space
sv_sub_prime = P @ sv_sub  # (8,)
print(jnp.allclose(jnp.linalg.norm(sv_sub_prime), 1.0))

# Compute fidelity
fidelity = jnp.abs(jnp.vdot(sv_exact, sv_sub_prime))**2
fidelity

True


Array(0.52829397, dtype=float32)

In [None]:
# project to subspace
# below logic isnt working as expected
# TODO: fix logic

# @jax.jit
# def connected_elements_and_amplitudes_bool(bitstring, diag, sign, imag):
#     """Find the connected element to computational basis state |X>."""
#     bitstring_mask = (bitstring == diag)
#     return bitstring_mask.astype(int), jnp.prod(
#         (-1) ** (jnp.logical_and(bitstring, sign))
#         * jnp.array(1j, dtype="complex64") ** (imag)
#     )

# batch_conn = jax.vmap(connected_elements_and_amplitudes_bool, (0, None, None, None))

# vec_ops = qml.pauli.observables_to_binary_matrix(H.ops)

# d = samples.shape[0]
# operator = coo_matrix((d, d), dtype="complex128")

# for coeff, op, vec_op in zip(H.coeffs, H.ops, vec_ops):
#     d, n = samples.shape
#     row_ids = np.arange(d)

#     # qubit wise representation
#     diag = vec_op[:n]
#     sign = vec_op[n:]
#     imag = np.logical_and(diag, sign).astype(int)

#     # print(op)
#     # print(diag)
#     # print(sign)

#     # convert to int
#     decimal_values = jnp.array([int("".join(map(str, row)), 2) for row in samples])
#     samples_conn, amplitudes = batch_conn(samples, diag, sign, imag)
    
#     decimal_conn = jnp.array([int("".join(map(str, row)), 2) for row in samples_conn])
#     conn_mask = np.isin(decimal_conn, decimal_values, assume_unique=True, kind="sort")

#     # keep samples that are represented both in the original samples and connected elements
#     amplitudes = amplitudes[conn_mask]
#     decimal_conn = decimal_conn[conn_mask]
#     row_ids = row_ids[conn_mask]

#     # Get column indices of non-zero matrix elements
#     col_ids = np.searchsorted(decimal_values, decimal_conn)

#     # print(amplitudes)
#     # print(row_ids)
#     # print(col_ids)
    
#     operator += coeff * coo_matrix((amplitudes, (row_ids, col_ids)), (d, d))
#     # print('-'*100)