# Time evolution with PennyLane-JAX

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Gopal-Dahale/sqte/blob/main/2_jax_time_evolve.ipynb)

## Setup

In [1]:
# !pip install -q pennylane

In [1]:
import os

os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"

In [2]:
import pennylane as qml
import numpy as np
from pennylane.devices.qubit import create_initial_state
from time import time
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import jax
import jax.numpy as jnp

In [3]:
def tfim_hamiltonian(n, J = 1.0, h = 1.0):
    coeffs = []
    ops = []
    
    # ZZ interaction terms
    for i in range(n - 1):
        coeffs.append(J)
        ops.append(qml.PauliZ(i) @ qml.PauliZ(i + 1))
    
    # Transverse field terms (X terms)
    for i in range(n):
        coeffs.append(h)
        ops.append(qml.PauliX(i))
    
    return qml.Hamiltonian(coeffs, ops)

H = tfim_hamiltonian(4, J=1.0, h=0.5)
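As a dependency-free cross-check of `qml.matrix`, the same Hamiltonian $H = J\sum_i Z_iZ_{i+1} + h\sum_i X_i$ can be assembled from Kronecker products. This is a minimal NumPy sketch, assuming PennyLane's convention that wire 0 is the most significant (leftmost) bit; `embed` and `tfim_matrix` are hypothetical helper names, not part of the notebook.

```python
import numpy as np

# Single-qubit matrices
I2 = np.eye(2)
X = np.array([[0.0, 1.0], [1.0, 0.0]])
Z = np.array([[1.0, 0.0], [0.0, -1.0]])

def embed(ops_by_wire, n):
    """Tensor product over n wires; wire 0 is the leftmost (most significant) factor."""
    mat = np.array([[1.0]])
    for w in range(n):
        mat = np.kron(mat, ops_by_wire.get(w, I2))
    return mat

def tfim_matrix(n, J=1.0, h=1.0):
    """Dense H = J * sum_i Z_i Z_{i+1} + h * sum_i X_i on an open chain of n qubits."""
    H = np.zeros((2**n, 2**n))
    for i in range(n - 1):
        H += J * embed({i: Z, i + 1: Z}, n)
    for i in range(n):
        H += h * embed({i: X}, n)
    return H

H_dense = tfim_matrix(4, J=1.0, h=0.5)
print(H_dense.shape)                     # (16, 16)
print(np.allclose(H_dense, H_dense.T))   # True: the TFIM matrix is real symmetric
print(H_dense[0, 0], H_dense[0, 8])      # 3.0 0.5
```

With PennyLane installed, `np.allclose(qml.matrix(H, wire_order=range(4)), tfim_matrix(4, 1.0, 0.5))` should hold for the `H` defined above. The diagonal entry for $|0000\rangle$ is $3J$ (three satisfied bonds), and index 8 is $|1000\rangle$, reached by $X$ on wire 0.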

In [4]:
def insert_zero_sorted(arr):
    if 0 in arr:
        return arr  # Zero is already present
    
    # Find the insertion index for zero
    idx = np.searchsorted(arr, 0)  # Finds the index where 0 should be inserted
    return np.insert(arr, idx, 0)

In [5]:
J_list = insert_zero_sorted(np.linspace(-8, 8, 40))
h_list = insert_zero_sorted(np.linspace(-8, 8, 40))
t_list = np.linspace(0.1, np.pi, 40)

n_qubits = 4
state = create_initial_state(range(n_qubits))

# Function to generate Hamiltonian for given (J, h)
def generate_hamiltonian(params):
    J, h = params
    return tfim_hamiltonian(n_qubits, J, h)

# Build one Hamiltonian per (J, h) pair; J is the outer loop, h the inner loop
start = time()
H_res = [generate_hamiltonian((J, h)) for J in J_list for h in h_list]

H_list = jnp.array([qml.matrix(H, wire_order=range(n_qubits)) for H in H_res])
print(f"Hamiltonian precompute time: {time() - start:.4f} sec")

Hamiltonian precompute time: 1.8047 sec


## Exact

In [6]:
start = time()

Ht_list = jnp.array([-1j * t * H_list for t in t_list])
# The evolved state U(t)|0...0> is the first column of U(t) = exp(-iHt)
results = jax.scipy.linalg.expm(Ht_list)[..., :, 0]
results = results.reshape(len(t_list), len(J_list), len(h_list), 2**n_qubits)
print(f"Time evolution duration: {time() - start:.4f} sec")

Time evolution duration: 1.2540 sec


In [7]:
results.shape

(40, 41, 41, 16)
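The statevector $U(t)|0\ldots0\rangle$ is the first *column* of $U(t) = e^{-iHt}$. Because every TFIM matrix here is real symmetric, $U(t)^T = e^{-iH^Tt} = U(t)$, so its first row and first column coincide; for a Hamiltonian with complex entries (for example one containing $Y$ terms) only the column is the evolved state. A small NumPy check, with random matrices standing in for Hamiltonians and `propagator` as a hypothetical helper:

```python
import numpy as np

def propagator(H, t):
    """U(t) = exp(-iHt) for Hermitian H, via H = V diag(E) V^dagger."""
    E, V = np.linalg.eigh(H)
    return (V * np.exp(-1j * E * t)) @ V.conj().T

rng = np.random.default_rng(0)
A = rng.normal(size=(16, 16))
H_real = (A + A.T) / 2                  # real symmetric, like the TFIM matrix
B = rng.normal(size=(16, 16)) + 1j * rng.normal(size=(16, 16))
H_complex = (B + B.conj().T) / 2        # Hermitian with complex entries

U_real = propagator(H_real, 0.7)
U_complex = propagator(H_complex, 0.7)
psi = U_real[:, 0]                      # U|0...0> is the first column
print(np.allclose(U_real[0, :], U_real[:, 0]))        # True
print(np.allclose(U_complex[0, :], U_complex[:, 0]))  # False
```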

## Trotter

In [8]:
dev = qml.device('default.qubit', wires=n_qubits)

n_trot = 1

def trot_circ(t, J, h):
    for step in range(n_trot):  
        
        # ZZ interaction terms
        for i in range(n_qubits - 1):
            theta = 2 * t * J / n_trot
            qml.PauliRot(theta, "ZZ", [i, i+1])
            
        # Transverse field terms (X terms)
        for i in range(n_qubits):
            theta = 2 * t * h / n_trot
            qml.PauliRot(theta, "X", [i])

@jax.jit
@qml.qnode(dev, interface='jax')
def circuit(params):
    trot_circ(*params)
    return qml.state()

vcircuit = jax.vmap(circuit)

A, B, C = np.meshgrid(t_list, J_list, h_list, indexing='ij')
params = np.vstack([A.ravel(), B.ravel(), C.ravel()]).T
results_trot = vcircuit(params).reshape(len(t_list), len(J_list), len(h_list), 2**n_qubits)
results_trot.shape

(40, 41, 41, 16)
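To gauge how far a single Trotter step (`n_trot = 1`) is from the exact evolution, the sketch below builds the same first-order Trotter product densely in NumPy and reports the infidelity $1 - |\langle\psi_\text{exact}|\psi_\text{Trot}\rangle|^2$ for a few step counts. It assumes `qml.PauliRot(theta, P)` implements $e^{-i\theta P/2}$, which is PennyLane's documented definition; all helper names are hypothetical.

```python
import numpy as np

I2 = np.eye(2)
X = np.array([[0.0, 1.0], [1.0, 0.0]])
Z = np.array([[1.0, 0.0], [0.0, -1.0]])

def embed(ops_by_wire, n):
    """Tensor product over n wires; wire 0 is the leftmost factor."""
    mat = np.array([[1.0]])
    for w in range(n):
        mat = np.kron(mat, ops_by_wire.get(w, I2))
    return mat

def pauli_rot(P, theta):
    """exp(-i theta/2 P) for a Pauli-string matrix P (P @ P = I)."""
    return np.cos(theta / 2) * np.eye(len(P)) - 1j * np.sin(theta / 2) * P

def trotter_unitary(n, t, J, h, n_trot):
    """First-order Trotter circuit with the same gate order as trot_circ."""
    step = np.eye(2**n, dtype=complex)
    for i in range(n - 1):          # ZZ layer (applied first)
        step = pauli_rot(embed({i: Z, i + 1: Z}, n), 2 * t * J / n_trot) @ step
    for i in range(n):              # X layer
        step = pauli_rot(embed({i: X}, n), 2 * t * h / n_trot) @ step
    return np.linalg.matrix_power(step, n_trot)

n, t, J, h = 4, 1.0, 1.0, 0.5
H = (sum(J * embed({i: Z, i + 1: Z}, n) for i in range(n - 1))
     + sum(h * embed({i: X}, n) for i in range(n)))
E, V = np.linalg.eigh(H)
psi_exact = ((V * np.exp(-1j * E * t)) @ V.conj().T)[:, 0]

infidelity = {}
for n_trot in (1, 2, 4, 8, 16):
    psi_trot = trotter_unitary(n, t, J, h, n_trot)[:, 0]
    infidelity[n_trot] = 1 - abs(np.vdot(psi_exact, psi_trot)) ** 2
    print(n_trot, infidelity[n_trot])
```

For a first-order product formula the state error is expected to shrink roughly like $1/n_\text{trot}$, so the infidelity falls roughly like $1/n_\text{trot}^2$. At $J = 0$ or $h = 0$ all terms commute and the Trotter circuit is exact.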

In [9]:
# Define computational basis states
x_vals = np.arange(2**n_qubits) 
cmap = plt.get_cmap("tab10")

# Interactive plot function
def plot_probabilities(t_idx, J_idx, h_idx):
    clear_output(wait=True)  # Speed up rendering by clearing old plots
    
    psi = results[t_idx, J_idx, h_idx]  # Extract state vector
    probabilities = jnp.abs(psi) ** 2  # Compute probabilities

    psi_trot = results_trot[t_idx, J_idx, h_idx]
    probs_trot = jnp.abs(psi_trot) ** 2 

    # Plot
    plt.figure(figsize=(12, 5))
    width = 0.4
    plt.bar(x_vals, probabilities, alpha=0.8, width=width, label='Exact', color=cmap(2), align='center')
    plt.bar(x_vals + width, probs_trot, alpha=0.8, width=width, label='Trotterized', color=cmap(3), align='center')
    plt.xlabel("Computational Basis State (x)")
    plt.ylabel("Probability |ψ(x, t)|²")
    plt.title(f"TFIM Evolution: $t={t_list[t_idx]:.2f}, J={J_list[J_idx]:.2f}, h={h_list[h_idx]:.2f}$")
    # plt.ylim( (pow(10,-1),pow(10,0)) )  # Probability range
    plt.xticks(x_vals + width/2, x_vals)  # Show x values
    plt.yscale('log')
    plt.legend()
    plt.show()

In [10]:
# Create sliders with actual values 
t_slider = widgets.SelectionSlider(options=[(f"{t:.2f}", i) for i, t in enumerate(t_list)], description="t")
J_slider = widgets.SelectionSlider(options=[(f"{J:.2f}", i) for i, J in enumerate(J_list)], description="J")
h_slider = widgets.SelectionSlider(options=[(f"{h:.2f}", i) for i, h in enumerate(h_list)], description="h")

# Create interactive UI
ui = widgets.VBox([t_slider, J_slider, h_slider])
out = widgets.interactive_output(plot_probabilities, {'t_idx': t_slider, 'J_idx': J_slider, 'h_idx': h_slider})

# Display widgets
display(ui, out)


In [11]:
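# Shannon entropy (in bits) of each state's measurement distribution;
# 0 * log2(0) evaluates to nan, which nan_to_num maps to 0
probs = jnp.abs(results.reshape(-1, 2**n_qubits))**2
entropy = -jnp.nan_to_num(probs * jnp.log2(probs)).sum(axis=1).reshape(results.shape[:-1])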

In [12]:
print("entropy uniform state", n_qubits * jnp.log2(2) )
print("entropy basis state", -1 * jnp.log2(1))

entropy uniform state 4.0
entropy basis state -0.0
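The `nan_to_num` approach relies on `0 * log2(0)` evaluating to `nan` and then being zeroed. An alternative that never produces `nan` is to substitute 1 inside the logarithm wherever the probability is zero, since $\log_2 1 = 0$. A NumPy sketch with a hypothetical helper name; the same `where` pattern works with `jax.numpy`:

```python
import numpy as np

def shannon_entropy_bits(probs, axis=-1):
    """Shannon entropy in bits with the convention 0 * log2(0) = 0."""
    probs = np.asarray(probs, dtype=float)
    # Replace zeros by 1 inside the log: log2(1) = 0, so those terms vanish without nan
    safe = np.where(probs > 0, probs, 1.0)
    return -(probs * np.log2(safe)).sum(axis=axis)

n = 4
uniform = np.full(2**n, 1 / 2**n)
basis = np.zeros(2**n)
basis[0] = 1.0
print(shannon_entropy_bits(uniform))       # 4.0
print(abs(shannon_entropy_bits(basis)))    # 0.0
```

Because the reduction runs over the last axis, a whole batch such as `probs.reshape(..., 2**n_qubits)` can be passed in one call.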


In [13]:

X, Y = np.meshgrid(h_list, J_list[:-1])

def entropy_contour(t_idx):
    clear_output(wait=True)
    fig, ax = plt.subplots(figsize=(14, 8))
    CS = ax.contourf(X, Y, entropy[t_idx, :-1, :], cmap="Spectral_r")
    
    # Dashed iso-entropy contour lines drawn on top of the filled contours
    contour_lines = ax.contour(
        CS, colors="black", linestyles="dashed", linewidths=0.8, alpha=0.5
    )
    plt.xlabel(r"$h$", fontsize=12)
    plt.ylabel(r"$J$", fontsize=12)
    plt.title(f"Entropy Plot: $t={t_list[t_idx]:.2f}$")

    plt.colorbar(CS)
    plt.show()

ui = widgets.VBox([t_slider])
out = widgets.interactive_output(entropy_contour, {'t_idx': t_slider})

# Display widgets
display(ui, out)


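- Near $h = 0$ (roughly $-1 \leq h \leq 1$), the entropy stays low for every $t$, so the state remains close to a computational basis state. At exactly $h = 0$ the Hamiltonian contains only the diagonal $Z_i Z_{i+1}$ terms, so $|0000\rangle$ only acquires a phase and the entropy is zero.
- The entropy is reflection-symmetric about both the $h = 0$ and $J = 0$ axes.
- As $t$ increases, many local minima and maxima appear. States with moderate entropy (between the extremes of 0 for a basis state and 4 bits for the uniform state) are favourable because their distribution is spread over several basis states without being uniform. The boundary between the low- and high-entropy regions likely corresponds to phase transitions.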
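Both symmetries follow from the structure of $H = J\sum_i Z_iZ_{i+1} + h\sum_i X_i$. Conjugating by $\prod_i Z_i$ maps $X_i \to -X_i$ while leaving $Z_iZ_{i+1}$ and $|0000\rangle$ unchanged, and it only flips signs of amplitudes, so $h \to -h$ leaves every probability unchanged. Since $H$ is real, $(J, h) \to (-J, -h)$ turns $e^{-iHt}$ into its complex conjugate, which again preserves probabilities; combining the two gives $J \to -J$. A NumPy sketch that checks this at one arbitrary point (helpers are hypothetical):

```python
import numpy as np

I2, X, Z = np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]]), np.diag([1.0, -1.0])

def embed(ops_by_wire, n):
    mat = np.array([[1.0]])
    for w in range(n):
        mat = np.kron(mat, ops_by_wire.get(w, I2))
    return mat

def tfim_matrix(n, J, h):
    return (sum(J * embed({i: Z, i + 1: Z}, n) for i in range(n - 1))
            + sum(h * embed({i: X}, n) for i in range(n)))

def entropy_after(n, J, h, t):
    """Shannon entropy (bits) of |<mu|exp(-iHt)|0...0>|^2."""
    E, V = np.linalg.eigh(tfim_matrix(n, J, h))
    psi = (V * np.exp(-1j * E * t)) @ V.conj()[0]   # first column of U(t)
    p = np.abs(psi) ** 2
    p = p[p > 1e-15]                                # treat 0 * log2(0) as 0
    return float(-(p * np.log2(p)).sum())

n, J, h, t = 4, 1.3, 2.1, 0.8
S = {(sJ, sh): entropy_after(n, sJ * J, sh * h, t) for sJ in (1, -1) for sh in (1, -1)}
print(S)                              # four numerically equal values
print(entropy_after(n, J, 0.0, t))    # ~0: H is diagonal when h = 0
```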

## Time evolution $P_\mu (t)$

See Sections III B and IV A of [this paper](https://arxiv.org/pdf/2412.13839).

In [14]:
t_list = np.linspace(0.0001, 5, 100)

n_qubits = 4
state = create_initial_state(range(n_qubits))

# Function to generate Hamiltonian for given (J, h)
def generate_hamiltonian(params):
    J, h = params
    return tfim_hamiltonian(n_qubits, J, h)

# Build one Hamiltonian per (J, h) pair; J is the outer loop, h the inner loop
start = time()
H_res = [generate_hamiltonian((J, h)) for J in J_list for h in h_list]

H_list = jnp.array([qml.matrix(H, wire_order=range(n_qubits)) for H in H_res])
print(f"Hamiltonian precompute time: {time() - start:.4f} sec")

Hamiltonian precompute time: 3.5135 sec


In [15]:
start = time()

Ht_list = jnp.array([-1j * t * H_list for t in t_list])
# The evolved state U(t)|0...0> is the first column of U(t) = exp(-iHt)
results = jax.scipy.linalg.expm(Ht_list)[..., :, 0]
results = results.reshape(len(t_list), len(J_list), len(h_list), 2**n_qubits)
print(f"Time evolution duration: {time() - start:.4f} sec")
results.shape

Time evolution duration: 2.4886 sec


(100, 41, 41, 16)

In [16]:
flip_levels = range(n_qubits + 1)  # Flip levels from 0 to n
P_mu_t = jnp.abs(results)**2

# Function to compute the number of flips (1s) in the binary representation of mu
def flip_level(mu, n):
    return bin(mu).count('1')

# Group basis state indices by flip level
indices_by_flip = {k: [mu for mu in range(2**n_qubits) if flip_level(mu, n_qubits) == k] for k in flip_levels}
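Each flip level $k$ contains $\binom{n}{k}$ basis states, so the per-level average divides the level's total probability by $\binom{n}{k}$ (1, 4, 6, 4, 1 for four qubits). A quick standalone check of the grouping, with `levels` as a hypothetical stand-in for `indices_by_flip`; on Python 3.10+ `mu.bit_count()` could replace `bin(mu).count('1')`:

```python
from math import comb

n = 4
# Group basis-state indices mu by their number of 1 bits (flips relative to |0000>)
levels = {k: [mu for mu in range(2**n) if bin(mu).count("1") == k] for k in range(n + 1)}
sizes = {k: len(mus) for k, mus in levels.items()}
print(sizes)       # {0: 1, 1: 4, 2: 6, 3: 4, 4: 1}
print(levels[1])   # [1, 2, 4, 8]
assert all(sizes[k] == comb(n, k) for k in sizes)
```

With PennyLane's wire ordering, index 8 is $|1000\rangle$ (wire 0 flipped) and index 1 is $|0001\rangle$.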

In [17]:
avg_P = jnp.array([P_mu_t[:, :, :, indices_by_flip[k]].mean(axis=-1) for k in flip_levels])
avg_P.shape

(5, 100, 41, 41)

In [18]:
# Fit avg_P_k(t) ≈ f * t**(2k) with the exponent fixed at 2k.
# In log space, log P = 2k * log t + log f, so the least-squares log f is the mean residual.
@jax.jit
def fit_prefactor(t, y, k):
    log_f = jnp.mean(jnp.log(y) - 2 * k * jnp.log(t))
    return jnp.exp(log_f)

# Map the fit over the J (axis 1) and h (axis 2) axes of avg_P_k, shape (T, n_J, n_h)
@jax.jit
def compute_fit_params(avg_P_k, t, k):
    fit = lambda y: fit_prefactor(t, y, k)
    return jax.vmap(jax.vmap(fit, in_axes=1), in_axes=1)(avg_P_k)

small_t = t_list <= 0.1  # small-time window used for the fit
fit_params = jnp.array([compute_fit_params(avg_P[k][small_t], t_list[small_t], k) for k in flip_levels])
fit_params.shape

(5, 41, 41)

In [19]:
def plot_p_mu_t(J_idx, h_idx):
    clear_output(wait=True)  # Speed up rendering by clearing old plots

    plt.figure(figsize=(10, 6))
    for k in flip_levels:
        p = np.asarray(avg_P[k, :, J_idx, h_idx])
        positive = p > 0  # zero probabilities cannot be drawn on log axes
        t_filtered, avg_P_filtered = t_list[positive], p[positive]
        if len(t_filtered) > 0:
            plt.plot(t_filtered, avg_P_filtered, label=f'{k}-flip')
            prefactor = float(fit_params[k][J_idx, h_idx])
            plt.plot(t_filtered, prefactor * (t_filtered**(2*k)), ls='--', label=f'fit {k}-flip: {prefactor:.2e} $t^{{{2*k}}}$')
    
    plt.xlabel('$t$')
    plt.ylabel(r'Average $P_\mu(t)$ per flip level')
    plt.yscale('log')
    plt.xscale('log')
    plt.title(f'Log-Log Plot of Average Probabilities, $J={J_list[J_idx]:.2f}, h={h_list[h_idx]:.2f}$')
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.grid(True)
    plt.show()

In [20]:
# Create sliders with actual values 
J_slider = widgets.SelectionSlider(options=[(f"{J:.2f}", i) for i, J in enumerate(J_list)], description="J")
h_slider = widgets.SelectionSlider(options=[(f"{h:.2f}", i) for i, h in enumerate(h_list)], description="h")

# Create interactive UI
ui = widgets.VBox([J_slider, h_slider])
out = widgets.interactive_output(plot_p_mu_t, {'J_idx': J_slider, 'h_idx': h_slider})

# Display widgets
display(ui, out)


#### Curves
**Solid lines**: Average probabilities for each flip level:
  - 0-flip ($ |0000\rangle $)
  - 1-flip ($ |0001\rangle, |0010\rangle, \ldots $)
  - 2-flip ($ |0011\rangle, |0101\rangle, \ldots $)
  - 3-flip ($ |0111\rangle, |1011\rangle, \ldots $)
  - 4-flip ($ |1111\rangle $)

**Dashed lines**: Fitted power laws $ f\, t^{2k} $, where the prefactor $ f $ is shown in the legend.

#### Observations
**Scaling for Small $ t $**: For $ t < 0.01 $, each flip level's curve is a straight line on the log-log axes, indicating power-law scaling $ \log(P_\mu(t)) \approx 2k \log(t) + \log(f) $.
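With the exponent fixed at $2k$, only $\log f$ has to be fitted, and ordinary least squares on $\log P_\mu(t) - 2k\log t$ reduces to a mean. Letting the slope float as well would return an intercept that matches the plotted $f\,t^{2k}$ curve only when the fitted slope happens to equal $2k$. A minimal NumPy sketch on synthetic data (`fit_prefactor` is a hypothetical name):

```python
import numpy as np

def fit_prefactor(t, p_avg, k):
    """Least-squares estimate of f in p_avg ≈ f * t**(2k) with the exponent fixed at 2k.

    Minimizing sum_i (log p_i - 2k log t_i - log f)^2 over log f gives
    log f = mean_i(log p_i - 2k log t_i). Zero probabilities are skipped.
    """
    t = np.asarray(t, dtype=float)
    p_avg = np.asarray(p_avg, dtype=float)
    keep = p_avg > 0
    log_f = np.mean(np.log(p_avg[keep]) - 2 * k * np.log(t[keep]))
    return float(np.exp(log_f))

# Synthetic 2-flip data with a known prefactor f = 0.25
t = np.geomspace(1e-3, 1e-1, 20)
p = 0.25 * t**4
print(fit_prefactor(t, p, k=2))   # ≈ 0.25
```

Log-spaced times (as with `np.geomspace` here) place many more samples inside a small-$t$ window than `np.linspace` over the same range.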

**Deviation at Larger $ t $**: Around $ t \approx 0.1 $, the curves deviate from the fitted lines and begin oscillating, consistent with the paper’s observation that higher-order terms in the time evolution series dominate around $ t \sim 1 $.

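For a state $ |\mu\rangle $ with $ k $ flips from the initial state $ |0000\rangle $, $ P_\mu(t) \sim t^{2k} $ for small $ t $. This is because the TFIM Hamiltonian $ H = J \sum_i Z_i Z_{i+1} + h \sum_i X_i $ (as built by `tfim_hamiltonian`) connects $ |0000\rangle $ to a $ k $-flip state only through $ k $ applications of the $ X_i $ terms (each $ X_i $ flips one spin), while the $ Z_i Z_{i+1} $ terms are diagonal. The lowest nonzero order of the Taylor series of $ \langle\mu|e^{-iHt}|0000\rangle $ is therefore $ t^k $, and squaring the amplitude gives $ t^{2k} $.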

- **1-flip**: $ t^2 $, as one $ X_i $ term connects directly (e.g., $ X_1 |0000\rangle = |1000\rangle $).
- **2-flip**: $ t^4 $, requiring two $ X_i $ terms (e.g., $ X_1 X_2 |0000\rangle = |1100\rangle $), and so on.
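The prefactor can be estimated from the same expansion. For a $k$-flip state $|\mu\rangle$, the only length-$k$ products of terms in $H$ that map $|0000\rangle$ to $|\mu\rangle$ apply one $X_i$ on each flipped site, in any of $k!$ orders, so $\langle\mu|H^k|0000\rangle = k!\,h^k$ and $\langle\mu|e^{-iHt}|0000\rangle \approx (-iht)^k$. This suggests $P_\mu(t) \approx (ht)^{2k}$, i.e. $f \approx h^{2k}$ independently of $J$ at leading order. A NumPy sketch that checks this in double precision (helpers are hypothetical; $H$ follows the sign convention of `tfim_hamiltonian`):

```python
import numpy as np

I2, X, Z = np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]]), np.diag([1.0, -1.0])

def embed(ops_by_wire, n):
    mat = np.array([[1.0]])
    for w in range(n):
        mat = np.kron(mat, ops_by_wire.get(w, I2))
    return mat

def tfim_matrix(n, J, h):
    return (sum(J * embed({i: Z, i + 1: Z}, n) for i in range(n - 1))
            + sum(h * embed({i: X}, n) for i in range(n)))

n, J, h, t = 4, 1.0, 0.5, 1e-2
E, V = np.linalg.eigh(tfim_matrix(n, J, h))
psi = (V * np.exp(-1j * E * t)) @ V.conj()[0]      # U(t)|0000>, first column of U(t)
P = np.abs(psi) ** 2

flips = np.array([bin(mu).count("1") for mu in range(2**n)])
f_est = {k: P[flips == k].mean() / t ** (2 * k) for k in range(n + 1)}
for k in range(n + 1):
    print(k, f_est[k], h ** (2 * k))   # the two columns should agree closely
```

The relative corrections are of order $t^2$, so the agreement degrades as $t$ approaches the region where the curves start to oscillate. Note that in single precision (JAX's default) the highest flip levels at very small $t$ fall below floating-point resolution.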

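For $ t > 0.1 $, the probabilities oscillate because of interference between energy eigenstates:
$$
P_\mu(t) = \left| \sum_n e^{-i E_n t} c_n^I c_n^\mu \right|^2,
$$
where $ E_n $ and $ |E_n\rangle $ are the eigenvalues and eigenvectors of $ H $, $ c_n^I = \langle E_n | I \rangle $ is the overlap with the initial state $ |I\rangle = |0000\rangle $, and $ c_n^\mu = \langle \mu | E_n \rangle $.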
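The spectral formula can be checked directly against matrix exponentiation. In this sketch the eigen-index is called `m` to avoid a clash with the qubit count `n`; `c_I[m]` holds $\langle E_m|0000\rangle$ and row $\mu$ of `V` holds $\langle\mu|E_m\rangle$. Helper names are hypothetical, and SciPy's `expm` serves only as the reference:

```python
import numpy as np
from scipy.linalg import expm

I2, X, Z = np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]]), np.diag([1.0, -1.0])

def embed(ops_by_wire, n):
    mat = np.array([[1.0]])
    for w in range(n):
        mat = np.kron(mat, ops_by_wire.get(w, I2))
    return mat

def tfim_matrix(n, J, h):
    return (sum(J * embed({i: Z, i + 1: Z}, n) for i in range(n - 1))
            + sum(h * embed({i: X}, n) for i in range(n)))

n, J, h = 4, 1.0, 0.5
H = tfim_matrix(n, J, h)
E, V = np.linalg.eigh(H)        # column V[:, m] is the eigenvector |E_m>
c_I = V[0, :].conj()            # c_m^I = <E_m|0000>

def P_mu(t):
    """All P_mu(t) at once: |sum_m exp(-i E_m t) c_m^I c_m^mu|^2, with c_m^mu = V[mu, m]."""
    return np.abs(V @ (np.exp(-1j * E * t) * c_I)) ** 2

ts = np.linspace(0.0, 5.0, 6)
P_spectral = np.array([P_mu(t) for t in ts])
P_direct = np.array([np.abs(expm(-1j * H * t)[:, 0]) ** 2 for t in ts])
print(np.allclose(P_spectral, P_direct))   # True
print(P_spectral.sum(axis=1))              # each row sums to 1 up to rounding
```

Expanding the square gives terms oscillating at the Bohr frequencies $E_m - E_{m'}$, which is the interference the formula describes.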

Large $ |h/J| $ means the $ X_i $ terms dominate, mixing probability rapidly across all flip levels and producing pronounced oscillations.

The 0-flip state ($ |0000\rangle $) starts at probability 1 and decreases slowly, as expected, since it’s the initial state.

**Sampling Times**: For small $ t $ (e.g., $ t < 0.01 $), the lowest flip levels (0-, 1- and 2-flip) dominate. Sampling at such times yields a subspace made mostly of these states, which may miss higher-flip contributions that matter for longer-time dynamics. For $ 0.1 \lesssim t \lesssim 0.5 $, all flip levels have significant probability, making this a good range to sample for a balanced subspace.
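One way to make the choice of sampling times concrete is sketched below under assumptions that are not from the paper: a hypothetical threshold `min_mass` of 1% on the *total* probability in each flip level, $J = h = 1$, and a dense NumPy TFIM matrix built inline. It scans a time grid and reports the first time at which every flip level meets the threshold.

```python
import numpy as np

I2, X, Z = np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]]), np.diag([1.0, -1.0])

def embed(ops_by_wire, n):
    mat = np.array([[1.0]])
    for w in range(n):
        mat = np.kron(mat, ops_by_wire.get(w, I2))
    return mat

def tfim_matrix(n, J, h):
    return (sum(J * embed({i: Z, i + 1: Z}, n) for i in range(n - 1))
            + sum(h * embed({i: X}, n) for i in range(n)))

n, J, h = 4, 1.0, 1.0
ts = np.linspace(0.01, 2.0, 200)
E, V = np.linalg.eigh(tfim_matrix(n, J, h))
# Probabilities |<mu|U(t)|0000>|^2 for every t (rows) and mu (columns)
P = np.array([np.abs((V * np.exp(-1j * E * t)) @ V.conj()[0]) ** 2 for t in ts])

flips = np.array([bin(mu).count("1") for mu in range(2**n)])
# Total probability in each flip level, shape (len(ts), n + 1)
level_mass = np.stack([P[:, flips == k].sum(axis=1) for k in range(n + 1)], axis=1)

min_mass = 1e-2  # hypothetical threshold: every flip level must hold >= 1% probability
balanced = np.all(level_mass >= min_mass, axis=1)
if balanced.any():
    print(f"first balanced time: t = {ts[balanced][0]:.3f}")
else:
    print("no time in range meets the threshold")
```

Summing rather than averaging within a level weights each level by its total share of the measurement outcomes, which is what a sampler actually sees; the threshold and grid would need tuning for other $(J, h)$.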