# Projection

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Gopal-Dahale/sqte/blob/main/3_projection.ipynb)

## Setup

In [None]:
# !pip install -q pennylane

In [2]:
from multiprocessing import Pool
from time import time

import ipywidgets as widgets
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pennylane as qml
import scipy
from IPython.display import clear_output, display
from ipywidgets import FloatSlider, interact
from pennylane.devices.qubit import apply_operation, create_initial_state
from scipy.sparse import coo_matrix, spmatrix
from functools import partial

In [3]:
def tfim_hamiltonian(n, J=1.0, h=1.0):
    coeffs = []
    ops = []

    # ZZ interaction terms
    for i in range(n - 1):
        coeffs.append(J)
        ops.append(qml.PauliZ(i) @ qml.PauliZ(i + 1))

    # Transverse field terms (X terms)
    for i in range(n):
        coeffs.append(h)
        ops.append(qml.PauliX(i))

    return qml.Hamiltonian(coeffs, ops)


n_qubits = 10
J = 2.0
h = 0.5
H = tfim_hamiltonian(n_qubits, J=J, h=h)

## Exact

In [4]:
start = time()
t = np.pi / 4
U = jax.scipy.linalg.expm(-1j * t * qml.matrix(H))

# evolve equal superposition state
# psi_0 = np.ones(2**n_qubits)/ np.sqrt(2**n_qubits)

# sv_exact = U @ psi_0

print(f"Time evolution duration: {time() - start:.4f} sec")
# sv_exact.shape

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Time evolution duration: 0.3793 sec


## Trotter Sampling

In [22]:
n_shots = 1000
key = jax.random.PRNGKey(0)

dev_sampler = qml.device("default.qubit", wires=n_qubits, shots=n_shots, seed=key)
dev = qml.device("default.qubit", wires=n_qubits)

def trot_circ(t, J, h, ntrot):
    for step in range(ntrot):

        # ZZ interaction terms
        for i in range(n_qubits - 1):
            theta = 2 * t * J / ntrot
            qml.PauliRot(theta, "ZZ", [i, i + 1])

        # Transverse field terms (X terms)
        for i in range(n_qubits):
            theta = 2 * t * h / ntrot
            qml.PauliRot(theta, "X", [i])


def circuit(params, ntrot):
    
    # initial state: equal superposition
    for i in range(n_qubits):
        qml.Hadamard(i)
        
    t, J, h = params
    trot_circ(t, J, h, ntrot)

@partial(jax.jit, static_argnums=1)
@qml.qnode(dev, interface='jax')
def sv_circ(params, ntrot):
    circuit(params, ntrot)
    return qml.state()

@partial(jax.jit, static_argnums=1)
@qml.qnode(dev_sampler, interface='jax')
def sampler_circ(params, ntrot):
    circuit(params, ntrot)
    return qml.sample()

state = sv_circ((t, J, h), 2)
samples = sampler_circ((t, J, h), 1)
state.shape, samples.shape

((16,), (1000, 4))

In [23]:
jnp.abs(jnp.vdot(sv_exact, state))

Array(0.99400723, dtype=float32)

In [24]:
# Need to sort the binary rows and remove duplicates
# for large integers, maybe try numpy lexsort

decimal_values = jnp.array([int("".join(map(str, row)), 2) for row in samples])
_, indices = jnp.unique(decimal_values, return_index=True)
samples = samples[indices]
samples.shape

(16, 4)

In [25]:
samples

Array([[0, 0, 0, 0],
       [0, 0, 0, 1],
       [0, 0, 1, 0],
       [0, 0, 1, 1],
       [0, 1, 0, 0],
       [0, 1, 0, 1],
       [0, 1, 1, 0],
       [0, 1, 1, 1],
       [1, 0, 0, 0],
       [1, 0, 0, 1],
       [1, 0, 1, 0],
       [1, 0, 1, 1],
       [1, 1, 0, 0],
       [1, 1, 0, 1],
       [1, 1, 1, 0],
       [1, 1, 1, 1]], dtype=int32)

In [26]:
# directly creating dense operator
H_mat = qml.matrix(H)
H_mat.shape

(16, 16)

In [27]:
decimal_values = jnp.array([int("".join(map(str, row)), 2) for row in samples])

idxs = np.ix_(decimal_values, decimal_values)
H_proj = H_mat[idxs]
H_proj.shape

(16, 16)

In [28]:
H_proj.real

array([[ 6. ,  0.5,  0.5,  0. ,  0.5,  0. ,  0. ,  0. ,  0.5,  0. ,  0. ,
         0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0.5,  2. ,  0. ,  0.5,  0. ,  0.5,  0. ,  0. ,  0. ,  0.5,  0. ,
         0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0.5,  0. , -2. ,  0.5,  0. ,  0. ,  0.5,  0. ,  0. ,  0. ,  0.5,
         0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  0.5,  0.5,  2. ,  0. ,  0. ,  0. ,  0.5,  0. ,  0. ,  0. ,
         0.5,  0. ,  0. ,  0. ,  0. ],
       [ 0.5,  0. ,  0. ,  0. , -2. ,  0.5,  0.5,  0. ,  0. ,  0. ,  0. ,
         0. ,  0.5,  0. ,  0. ,  0. ],
       [ 0. ,  0.5,  0. ,  0. ,  0.5, -6. ,  0. ,  0.5,  0. ,  0. ,  0. ,
         0. ,  0. ,  0.5,  0. ,  0. ],
       [ 0. ,  0. ,  0.5,  0. ,  0.5,  0. , -2. ,  0.5,  0. ,  0. ,  0. ,
         0. ,  0. ,  0. ,  0.5,  0. ],
       [ 0. ,  0. ,  0. ,  0.5,  0. ,  0.5,  0.5,  2. ,  0. ,  0. ,  0. ,
         0. ,  0. ,  0. ,  0. ,  0.5],
       [ 0.5,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  2. ,  0.5,  0.5,
         0. ,  0.5,  0

In [29]:
# another way to get projected hamiltonian
# use a projector matrix

P = np.zeros((2**n_qubits, samples.shape[0]))
for i, val in enumerate(decimal_values):
    P[val, i] = 1

H_proj2 = P.T @ H_mat @ P
np.allclose(H_proj, H_proj2)

True

In [30]:
P

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [31]:
# initial state projection
# if initial state in original space was |00...0> then
# it requires that |00...0> is present in the subspace
# else we get a vector will all zeros. 

# we start with superposition state
# project it into subspace
psi_0_sub = P.T @ psi_0
print("Norm of psi_0_sub:", jnp.linalg.norm(psi_0_sub))

# normalize if needed
if not jnp.allclose(jnp.linalg.norm(psi_0_sub), 1.0):
    psi_0_sub /= jnp.linalg.norm(psi_0_sub)

# Time-evolved state in subspace
U_sub = jax.scipy.linalg.expm(-1j * t * H_proj) 
sv_sub = U_sub @ psi_0_sub  

# Map back to full space
sv_sub_prime = P @ sv_sub 
print("Norm of sv_sub_prime:", jnp.linalg.norm(sv_sub_prime))

# normalize if need
if not jnp.allclose(jnp.linalg.norm(sv_sub_prime), 1.0):
    sv_sub_prime /= jnp.linalg.norm(sv_sub_prime)

# Compute fidelity
fidelity = jnp.abs(jnp.vdot(sv_exact, sv_sub_prime))**2

print("Fidelity between full-space and subspace evolution:", fidelity)

Norm of psi_0_sub: 1.0
Norm of sv_sub_prime: 0.99999994
Fidelity between full-space and subspace evolution: 1.0


In [20]:
sv_exact

Array([ 0.08967425+0.3915804j , -0.11989825-0.1648028j ,
        0.03397639+0.21295951j, -0.20889324-0.26469016j,
        0.03397638+0.21295948j,  0.01732782-0.1436491j ,
        0.12297151+0.1130721j , -0.11989823-0.16480276j,
       -0.1198982 -0.16480279j,  0.1229715 +0.11307213j,
        0.01732785-0.14364909j,  0.03397642+0.21295948j,
       -0.2088933 -0.26469016j,  0.0339764 +0.21295948j,
       -0.11989821-0.16480277j,  0.08967426+0.39158052j], dtype=complex64)

In [21]:
sv_sub_prime

Array([ 0.08585875+0.45025557j, -0.11999822-0.19680648j,
        0.07696307+0.24587394j, -0.14826289-0.36626562j,
        0.07612782+0.20763102j,  0.        +0.j        ,
        0.18451348+0.23594446j,  0.        +0.j        ,
       -0.12094238-0.22177069j,  0.05088427+0.14680485j,
        0.03292241-0.17371812j,  0.05033822+0.23863302j,
       -0.12718146-0.3011923j ,  0.        +0.j        ,
        0.        +0.j        ,  0.0061716 +0.30002207j], dtype=complex64)