In [1]:
import pennylane as qml
import numpy as np
from pennylane.devices.qubit import create_initial_state, apply_operation
from time import time
from multiprocessing import Pool
from ipywidgets import interact, FloatSlider
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import scipy
import jax
import jax.numpy as jnp

In [2]:
def tfim_hamiltonian(n, J=1.0, h=1.0):
    """Build the open-chain transverse-field Ising Hamiltonian on ``n`` qubits.

        H = J * sum_{i=0}^{n-2} Z_i Z_{i+1}  +  h * sum_{i=0}^{n-1} X_i

    Open boundary conditions: there is no Z_{n-1} Z_0 wrap-around term.
    Both terms enter with a plus sign; flip the signs of ``J`` and ``h`` to
    match a ``-J``, ``-h`` convention.

    Args:
        n: number of qubits / spins (wires 0 .. n-1).
        J: nearest-neighbour ZZ coupling.
        h: transverse-field strength.

    Returns:
        ``qml.Hamiltonian`` whose terms are all ZZ bonds (in wire order)
        followed by all single-site X terms (in wire order).
    """
    zz_bonds = [qml.PauliZ(i) @ qml.PauliZ(i + 1) for i in range(n - 1)]
    field_terms = [qml.PauliX(i) for i in range(n)]
    coefficients = [J] * len(zz_bonds) + [h] * len(field_terms)
    return qml.Hamiltonian(coefficients, zz_bonds + field_terms)

# Example instance: 4 spins, J = 1.0, h = 0.5.
H = tfim_hamiltonian(4, J=1.0, h=0.5)

In [3]:
def insert_zero_sorted(arr):
    """Return the ascending array ``arr`` with a 0 added at its sorted position.

    If ``arr`` already contains an exact 0, the very same object is returned
    unchanged. Otherwise a new array with one extra element is returned and
    ``arr`` itself is not modified. Used to make the parameter grids sample
    the point 0 exactly.

    Assumes ``arr`` is sorted in ascending order (``np.searchsorted`` relies
    on it).
    """
    if 0 not in arr:
        # Left-most slot where 0 fits without breaking the ascending order.
        position = np.searchsorted(arr, 0)
        arr = np.insert(arr, position, 0)
    return arr

In [4]:
# Parameter grids. insert_zero_sorted guarantees that J = 0 and h = 0 are sampled
# exactly (an even-length linspace symmetric about 0 does not contain 0).
J_list = insert_zero_sorted(np.linspace(-8, 8, 40))
h_list = insert_zero_sorted(np.linspace(-8, 8, 40))
t_list = np.linspace(0.1, np.pi, 40)

n_qubits = 4
state = create_initial_state(range(n_qubits))

# Function to generate Hamiltonian for given (J, h)
def generate_hamiltonian(params):
    """Return ``tfim_hamiltonian(n_qubits, J, h)`` for a ``(J, h)`` pair."""
    J, h = params
    return tfim_hamiltonian(n_qubits, J, h)

# One Hamiltonian per (J, h) grid point, J-major (h varies fastest). The later
# reshape(len(t_list), len(J_list), len(h_list), ...) relies on this ordering.
# A plain comprehension is used instead of multiprocessing.Pool. Pool workers
# cannot unpickle a function defined interactively in __main__ under the
# "spawn" start method (macOS/Windows). The per-item work here is also far
# smaller than the process start-up and pickling overhead.
start = time()
H_res = [generate_hamiltonian((J, h)) for J in J_list for h in h_list]

# Explicit wire_order pins the matrix basis to wires 0..n_qubits-1 regardless of
# the order in which terms appear in each Hamiltonian.
wire_order = list(range(n_qubits))
H_list = jnp.array([qml.matrix(H, wire_order=wire_order) for H in H_res])
print(f"Hamiltonian precompute time: {time() - start:.4f} sec")

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Hamiltonian precompute time: 1.5218 sec


## Exact time evolution

In [5]:
start = time()

# -i t H for every t and every (J, h) Hamiltonian.
# Broadcasting (len(t_list), 1, 1, 1) against H_list's
# (len(J_list) * len(h_list), 2**n, 2**n) replaces the Python loop over t_list.
Ht_list = -1j * t_list[:, None, None, None] * H_list

# U = exp(-i t H). The state evolved from |0...0> is U|0...0>, i.e. the first
# COLUMN of U, so index the LAST axis. `[:, :, 0]` would select the first row,
# which equals the first column only when H is real symmetric.
results = jax.scipy.linalg.expm(Ht_list)[..., 0]  # statevector
results = results.reshape(len(t_list), len(J_list), len(h_list), 2**n_qubits)

# JAX dispatches work asynchronously; wait for the result so the timing
# measures the computation rather than only its dispatch.
results.block_until_ready()
print(f"Time evolution duration: {time() - start:.4f} sec")

Time evolution duration: 0.7280 sec


In [6]:
# Expected: (len(t_list), len(J_list), len(h_list), 2**n_qubits), per the reshape above.
results.shape

(40, 41, 41, 16)

## Trotterized time evolution

In [7]:
dev = qml.device('default.qubit', wires=n_qubits)

# Number of first-order Trotter steps used by trot_circ.
n_trot = 1

def trot_circ(t, J, h):
    """Apply a first-order Trotterization of exp(-i t H) for the TFIM.

    H = J * sum_i Z_i Z_{i+1} + h * sum_i X_i, the same Hamiltonian built by
    ``tfim_hamiltonian``. Each of the ``n_trot`` steps applies
    exp(-i (t/n_trot) J Z_i Z_{i+1}) to every neighbouring pair, then
    exp(-i (t/n_trot) h X_i) to every wire.

    ``qml.PauliRot(theta, P, wires)`` implements exp(-i theta/2 P), hence the
    factor 2 in the rotation angles.

    Args:
        t: total evolution time.
        J: ZZ coupling strength.
        h: transverse-field strength.
    """
    # The angles are the same for every step and every wire, so compute them once.
    theta_zz = 2 * t * J / n_trot
    theta_x = 2 * t * h / n_trot

    for _ in range(n_trot):
        # ZZ interaction terms
        for i in range(n_qubits - 1):
            qml.PauliRot(theta_zz, "ZZ", [i, i + 1])

        # Transverse field terms (X terms)
        for i in range(n_qubits):
            qml.PauliRot(theta_x, "X", [i])

@jax.jit
@qml.qnode(dev, interface='jax')
def circuit(params):
    """Return the state vector after ``trot_circ`` acts on the device's initial state.

    ``params`` is a length-3 sequence ``(t, J, h)`` forwarded to ``trot_circ``.
    """
    t, J, h = params
    trot_circ(t, J, h)
    return qml.state()

# Batch the jitted circuit over the leading axis of its input (one row per run).
vcircuit = jax.vmap(circuit)

# All (t, J, h) combinations on the grid, one row each. indexing='ij' makes t the
# slowest and h the fastest varying axis, so the flat batch of states reshapes
# back to (t, J, h, amplitude).
A, B, C = np.meshgrid(t_list, J_list, h_list, indexing='ij')
params = np.stack([A, B, C], axis=-1).reshape(-1, 3)
results_trot = vcircuit(params).reshape(len(t_list), len(J_list), len(h_list), 2**n_qubits)
results_trot.shape

(40, 41, 41, 16)

In [8]:
# Define computational basis states
x_vals = np.arange(2**n_qubits)
cmap = plt.get_cmap("tab10")

# Interactive plot function
def plot_probabilities(t_idx, J_idx, h_idx, show_trotter=False):
    """Bar-plot the measurement probabilities |psi(x)|^2 of the evolved state.

    Reads the module-level ``results`` (exact) and, when requested,
    ``results_trot`` (Trotterized) arrays, both indexed as [t, J, h, x].

    Args:
        t_idx, J_idx, h_idx: indices into ``t_list``, ``J_list`` and ``h_list``.
        show_trotter: if True, draw the Trotterized probabilities beside the
            exact ones for comparison. Off by default, so the default figure
            shows only the exact probabilities.
    """
    # Replace the previous figure in the output area; wait=True avoids flicker.
    clear_output(wait=True)

    probabilities = jnp.abs(results[t_idx, J_idx, h_idx]) ** 2

    width = 0.4
    # A single series is centred on each x. With two series the bars sit on
    # either side of x, so the tick stays centred under the pair.
    offset = width / 2 if show_trotter else 0.0

    fig, ax = plt.subplots(figsize=(12, 5))
    ax.bar(x_vals - offset, probabilities, alpha=0.8, width=width, label='Exact', color=cmap(2), align='center')
    if show_trotter:
        probs_trot = jnp.abs(results_trot[t_idx, J_idx, h_idx]) ** 2
        ax.bar(x_vals + offset, probs_trot, alpha=0.8, width=width, label='Trotterized', color=cmap(3), align='center')

    ax.set_xlabel("Computational Basis State (x)")
    ax.set_ylabel("Probability |ψ(x, t)|²")
    ax.set_title(f"TFIM Evolution: $t={t_list[t_idx]:.2f}, J={J_list[J_idx]:.2f}, h={h_list[h_idx]:.2f}$")
    ax.set_xticks(x_vals)  # one tick per basis state, aligned with the bars
    ax.set_yscale('log')
    ax.legend()
    plt.show()

In [9]:
def _index_slider(values, name):
    """SelectionSlider labelled with the formatted values whose value is the array index."""
    return widgets.SelectionSlider(
        options=[(f"{v:.2f}", idx) for idx, v in enumerate(values)],
        description=name,
    )

# Sliders show the physical parameter values but hand array indices to the plot.
t_slider = _index_slider(t_list, "t")
J_slider = _index_slider(J_list, "J")
h_slider = _index_slider(h_list, "h")

# Stack the controls and re-run plot_probabilities whenever any slider moves.
ui = widgets.VBox([t_slider, J_slider, h_slider])
out = widgets.interactive_output(
    plot_probabilities,
    {'t_idx': t_slider, 'J_idx': J_slider, 'h_idx': h_slider},
)

display(ui, out)

VBox(children=(SelectionSlider(description='t', options=(('0.10', 0), ('0.18', 1), ('0.26', 2), ('0.33', 3), (…

Output()

In [10]:
# Shannon entropy (in bits) of the measurement distribution |psi(x)|^2 at every
# (t, J, h) grid point: 0 for a single basis state, n_qubits for a uniform one.
probs = jnp.abs(results.reshape(-1, 2**n_qubits))**2
# p * log2(p) is NaN when p == 0 (0 * -inf); nan_to_num maps it to the
# 0 * log 0 = 0 convention. The result is reshaped back to the (t, J, h) grid
# shape of `results` rather than a hard-coded (40, 41, 41).
entropy = -jnp.nan_to_num(probs * jnp.log2(probs)).sum(axis=1).reshape(results.shape[:-1])

In [11]:
# Reference Shannon entropies in bits (log base 2, same as `entropy`) for a
# distribution over the 2**n_qubits computational basis states:
#   uniform state: -sum_x 2^-n log2(2^-n) = n_qubits
#   basis state:   -1 * log2(1)           = 0
# The basis-state value is written as 0.0 - ... so it prints 0.0 instead of -0.0.
print("entropy uniform state", n_qubits * jnp.log2(2))
print("entropy basis state", 0.0 - 1.0 * jnp.log2(1.0))

entropy uniform state 4.0
entropy basis state -0.0


In [12]:
from mpl_toolkits.axes_grid1.inset_locator import (
    inset_axes,
    mark_inset,
    zoomed_inset_axes,
)

# Plot grid: rows follow J and columns follow h, matching entropy[t_idx, J, h].
# Every J value is used, so the plot spans the full symmetric range -8..8.
X, Y = np.meshgrid(h_list, J_list)

def entropy_contour(t_idx):
    """Filled contour plot of the entropy over the (h, J) grid at time ``t_list[t_idx]``."""
    # Replace the previous figure in the output area; wait=True avoids flicker.
    clear_output(wait=True)
    fig, ax = plt.subplots(figsize=(14, 8))
    CS = ax.contourf(X, Y, entropy[t_idx], cmap="Spectral_r")

    # Dashed black contour lines at the same levels as the filled regions.
    ax.contour(CS, colors="black", linestyles="dashed", linewidths=0.8, alpha=0.5)
    ax.set_xlabel(r"$h$", fontsize=12)
    ax.set_ylabel(r"$J$", fontsize=12)
    ax.set_title(f"Entropy Plot: $t={t_list[t_idx]:.2f}$")

    fig.colorbar(CS, ax=ax)
    plt.show()

# Re-use the time slider created for the probability plot. Both outputs observe
# the same widget, so moving t redraws both views.
out = widgets.interactive_output(entropy_contour, {'t_idx': t_slider})
ui = widgets.VBox([t_slider])

display(ui, out)

VBox(children=(SelectionSlider(description='t', options=(('0.10', 0), ('0.18', 1), ('0.26', 2), ('0.33', 3), (…

Output()

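- Near $h = 0$ (roughly $-1 \leq h \leq 1$) the entropy is relatively low, implying that the quantum state stays close to a computational basis state for all times $t$. At exactly $h = 0$ the Hamiltonian is diagonal, so $|0\dots0\rangle$ is an eigenstate and the entropy is zero.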
- The entropy is symmetric under reflection across both the $h = 0$ and $J = 0$ axes.
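- As time increases, many local minima and maxima appear. States with moderate entropy (strictly between 0 and the maximum of $n = 4$ bits for 4 qubits) are favourable because their measurement distribution is non-uniform without being concentrated on a single basis state. The boundary between the low- and high-entropy regions likely corresponds to the TFIM's phase transition, which is expected at $|h| = |J|$ in the thermodynamic limit. With only 4 qubits it can at best appear as a smoothed finite-size crossover.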