In [1]:
import pennylane as qml
import numpy as np
from pennylane.devices.qubit import create_initial_state, apply_operation
from time import time
from multiprocessing import Pool
from ipywidgets import interact, FloatSlider
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import scipy
import jax
import jax.numpy as jnp

In [2]:
def tfim_hamiltonian(n, J = 1.0, h = 1.0):
    coeffs = []
    ops = []
    
    # ZZ interaction terms
    for i in range(n - 1):
        coeffs.append(J)
        ops.append(qml.PauliZ(i) @ qml.PauliZ(i + 1))
    
    # Transverse field terms (X terms)
    for i in range(n):
        coeffs.append(h)
        ops.append(qml.PauliX(i))
    
    return qml.Hamiltonian(coeffs, ops)

H = tfim_hamiltonian(4, J=1.0, h=0.5)

In [3]:
def insert_zero_sorted(arr):
    if 0 in arr:
        return arr  # Zero is already present
    
    # Find the insertion index for zero
    idx = np.searchsorted(arr, 0)  # Finds the index where 0 should be inserted
    return np.insert(arr, idx, 0)

In [28]:
J_list = insert_zero_sorted(np.linspace(-2, 2, 20))
h_list = insert_zero_sorted(np.linspace(0, 4, 20))
t_list = np.linspace(0.1, np.pi, 5)

n_qubits = 4
state = create_initial_state(range(n_qubits))

# Function to generate Hamiltonian for given (J, h)
def generate_hamiltonian(params):
    J, h = params
    return tfim_hamiltonian(n_qubits, J, h)

# Parallelize Hamiltonian creation
start = time()
with Pool() as pool:
    H_res = pool.map(generate_hamiltonian, [(J, h) for J in J_list for h in h_list])

H_list = jnp.array([qml.matrix(H) for H in H_res])
print(f"Hamiltonian precompute time: {time() - start:.4f} sec")

  self.pid = os.fork()


Hamiltonian precompute time: 0.7661 sec


## Exact

In [29]:
start = time()

results = []
Ht_list = jnp.array([-1j * t * H_list for t in t_list])
results = jax.scipy.linalg.expm(Ht_list)[:, :, 0] # statevector
results = results.reshape(len(t_list), len(J_list), len(h_list), 2**n_qubits)
print(f"Time evolution duration: {time() - start:.4f} sec")

Time evolution duration: 0.5201 sec


In [30]:
results.shape

(5, 21, 20, 16)

## Trotter

In [31]:
dev = qml.device('default.qubit', wires=n_qubits)

n_trot = 1

def trot_circ(t, J, h):
    for step in range(n_trot):  
        
        # ZZ interaction terms
        for i in range(n_qubits - 1):
            theta = 2 * t * J / n_trot
            qml.PauliRot(theta, "ZZ", [i, i+1])
            
        # Transverse field terms (X terms)
        for i in range(n_qubits):
            theta = 2 * t * h / n_trot
            qml.PauliRot(theta, "X", [i])

@jax.jit
@qml.qnode(dev, interface='jax')
def circuit(params):
    trot_circ(*params)
    return qml.state()

vcircuit = jax.vmap(circuit)

A, B, C = np.meshgrid(t_list, J_list, h_list, indexing='ij')
params = np.vstack([A.ravel(), B.ravel(), C.ravel()]).T
results_trot = vcircuit(params).reshape(len(t_list), len(J_list), len(h_list), 2**n_qubits)
results_trot.shape

(5, 21, 20, 16)

## Sampled

In [32]:
# n_shots = 100
# dev = qml.device('default.qubit', wires=n_qubits, shots=n_shots)

# @jax.jit
# @qml.qnode(dev, interface='jax')
# def circuit(params):
#     trot_circ(*params)
#     return qml.sample()

# vcircuit = jax.vmap(circuit)

# A, B, C = np.meshgrid(np.array([np.pi/2]), J_list, h_list, indexing='ij')
# params = np.vstack([A.ravel(), B.ravel(), C.ravel()]).T

# samples = vcircuit(params).reshape(len(J_list), len(h_list), n_shots, n_qubits)

# @jax.jit
# def remove_duplicates_with_counts(matrix):
#     """
#     Removes duplicate rows from a 2D binary matrix and counts occurrences.
#     """
#     unique_matrix, idx, counts = jnp.unique(matrix, axis=0, return_inverse=True, size=n_shots, return_counts=True)
#     return unique_matrix, counts

# # Vectorizing across batch dimension
# batch_fn = jax.vmap(remove_duplicates_with_counts, in_axes=(0))
# batch_fn2 = jax.vmap(batch_fn, in_axes=(0))
# bitstring_matrix, probs_arr = batch_fn2(samples)
# probs_arr /= n_shots
# bitstring_matrix.shape, probs_arr.shape

In [33]:
# @jax.jit
# def _int_conversion_from_bts_array(bit_array):
#     """Convert a bit array to an integer representation.
#     NOTE: This can only handle up to 63 qubits. Then the integer will overflow
#     """
#     n_qubits = len(bit_array)
#     bitarray_asint = 0.0
#     for i in range(n_qubits):
#         bitarray_asint = bitarray_asint + bit_array[i] * 2 ** (n_qubits - 1 - i)
#     return bitarray_asint  # type: ignore

# _int_conversion_from_bts_matrix_vmap = jax.jit(jax.vmap(_int_conversion_from_bts_array, 0, 0))

# @jax.jit
# def sort_and_remove_duplicates(bitstring_matrix):
#     """Sort a bitstring matrix and remove duplicate entries.
#     The lowest bitstring values will be placed in the lowest-indexed rows.
#     """
#     bsmat_asints = _int_conversion_from_bts_matrix_vmap(bitstring_matrix)
#     _, indices = jnp.unique(bsmat_asints, size=n_shots, return_index=True)
#     return bitstring_matrix[indices, :]

In [34]:
# bitstring_matrix[3, 4]

In [35]:
# Define computational basis states
x_vals = np.arange(2**n_qubits) 
cmap = plt.get_cmap("tab10")

# Interactive plot function
def plot_probabilities(t_idx, J_idx, h_idx):
    clear_output(wait=True)  # Speed up rendering by clearing old plots
    
    psi = results[t_idx, J_idx, h_idx]  # Extract state vector
    probabilities = jnp.abs(psi) ** 2  # Compute probabilities

    psi_trot = results_trot[t_idx, J_idx, h_idx]
    probs_trot = jnp.abs(psi_trot) ** 2 

    # Plot
    plt.figure(figsize=(12, 5))
    width = 0.4
    plt.bar(x_vals, probabilities, alpha=0.8, width=width, label='Exact', color=cmap(2), align='center')
    plt.bar(x_vals  + width, probs_trot, alpha=0.8, width=width, label='Trotterized', color=cmap(3), align='center')
    plt.xlabel("Computational Basis State (x)")
    plt.ylabel("Probability |ψ(x, t)|²")
    plt.title(f"TFIM Evolution: $t={t_list[t_idx]:.2f}, J={J_list[J_idx]:.2f}, h={h_list[h_idx]:.2f}$")
    plt.ylim(0, 1)  # Probability range
    plt.xticks(x_vals + width/2, x_vals)  # Show x values
    plt.legend()
    plt.show()

In [36]:
# Create sliders with actual values 
t_slider = widgets.SelectionSlider(options=[(f"{t:.2f}", i) for i, t in enumerate(t_list)], description="t")
J_slider = widgets.SelectionSlider(options=[(f"{J:.2f}", i) for i, J in enumerate(J_list)], description="J")
h_slider = widgets.SelectionSlider(options=[(f"{h:.2f}", i) for i, h in enumerate(h_list)], description="h")

# Create interactive UI
ui = widgets.VBox([t_slider, J_slider, h_slider])
out = widgets.interactive_output(plot_probabilities, {'t_idx': t_slider, 'J_idx': J_slider, 'h_idx': h_slider})

# Display widgets
display(ui, out)

VBox(children=(SelectionSlider(description='t', options=(('0.10', 0), ('0.86', 1), ('1.62', 2), ('2.38', 3), (…

Output()