# Equation 4 Hamiltonian - ZXW Diagram and Time Evolution

This notebook creates a ZXW diagram representing equation 4 from paper 2408:
**Ĥ^jk = F_jk σ̂_+^j σ̂_-^k + F_kj σ̂_-^j σ̂_+^k**

Which decomposes into Pauli strings as: **(F_jk/2) * (X_j X_k - Y_j Y_k)**

We then compute time evolution using **Trotterization** from the "How to Sum and Exponentiate Hamiltonians" paper:
**exp(-iHt) ≈ [exp(-iH₁t/n) exp(-iH₂t/n) ... exp(-iHₘt/n)]^n**


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pyzx as zx

# Import ZXW functions
from pauli_hamiltonian_zx import PauliHamiltonianZX

# Import equation 4 functions
from cor_decay_zxw import (
    compute_F_jk_equation5,
    create_equation4_hamiltonian,
    setup_positions_3d_grid,
    create_equation4_hamiltonian_zxw,
    compute_equation4_time_evolution
)




## Setup: Atom Positions and F_jk Matrix


In [2]:
# Parameters
N = 4 # Number of atoms (small for visualization)
lambda_val = 1.0  # Wavelength
gam = 1.0  # Decay rate Γ
m = 1.5  # Spacing multiplier

print(f"Setting up {N} atoms in 3D grid")
print(f"Wavelength: {lambda_val}, Decay rate: {gam}")

# Setup 3D grid positions
# x, y, z = setup_positions_3d_grid(N, m, lambda_val)
two_dim_positions = setup_positions_2d_grid(N, m, lambda_val)
print(two_dim_positions)
for i in range(len(two_dim_positions[0])):
    print(f"  Atom {i}: ({two_dim_positions[0][i]:.3f}, {two_dim_positions[1][i]:.3f})")
#for i in range(N):
#    print(f"  Atom {i}: ({x[i]:.3f}, {y[i]:.3f}, {z[i]:.3f})")


Setting up 4 atoms in 3D grid
Wavelength: 1.0, Decay rate: 1.0
(array([0.  , 0.  , 0.75, 0.75]), array([0.  , 0.75, 0.  , 0.75]), array([0., 0., 0., 0.]))
  Atom 0: (0.000, 0.000)
  Atom 1: (0.000, 0.750)
  Atom 2: (0.750, 0.000)
  Atom 3: (0.750, 0.750)


In [3]:
# Compute F_jk matrix using equation 5
F_matrix = compute_F_jk_equation5(*two_dim_positions, lambda_val, gam)

print(f"F_jk matrix shape: {F_matrix.shape}")
print(f"\nF_jk matrix (real part):")
print(F_matrix)

print(f"\nSample coupling values:")
for j in range(min(3, N)):
    for k in range(j+1, min(3, N)):
        print(f"  F_{j}{k} = {F_matrix[j, k].real:.6f} + {F_matrix[j, k].imag:.6f}i")


F_jk matrix shape: (4, 4)

F_jk matrix (real part):
[[0.+1.5j        0.+0.15198794j 0.+0.15198794j 0.-0.05659477j]
 [0.+0.15198794j 0.+1.5j        0.-0.05659477j 0.+0.15198794j]
 [0.+0.15198794j 0.-0.05659477j 0.+1.5j        0.+0.15198794j]
 [0.-0.05659477j 0.+0.15198794j 0.+0.15198794j 0.+1.5j       ]]

Sample coupling values:
  F_01 = 0.000000 + 0.151988i
  F_02 = 0.000000 + 0.151988i
  F_12 = 0.000000 + -0.056595i


In [4]:
# Create Pauli string Hamiltonian for equation 4
# This uses: (F_jk/2) * (X_j X_k - Y_j Y_k)
pauli_strings = create_equation4_hamiltonian(N, F_matrix)

print(f"Number of Pauli terms: {len(pauli_strings)}")
print(f"\nFirst 6 Pauli strings (showing X_j X_k and -Y_j Y_k pairs):")
for i, (coeff, gates) in enumerate(pauli_strings[:]):
    print(f"  {i+1}: {coeff:.6f} * {gates}")

# Create ZXW Hamiltonian
hamiltonian = PauliHamiltonianZX(pauli_strings)
print(f"\nTotal qubits: {hamiltonian.total_qubits}")


Number of Pauli terms: 20

First 6 Pauli strings (showing X_j X_k and -Y_j Y_k pairs):
  1: 0.000000+0.750000j * ['X0', 'X0']
  2: -0.000000-0.750000j * ['Y0', 'Y0']
  3: 0.000000+0.075994j * ['X0', 'X1']
  4: -0.000000-0.075994j * ['Y0', 'Y1']
  5: 0.000000+0.075994j * ['X0', 'X2']
  6: -0.000000-0.075994j * ['Y0', 'Y2']
  7: 0.000000-0.028297j * ['X0', 'X3']
  8: 0.000000+0.028297j * ['Y0', 'Y3']
  9: 0.000000+0.750000j * ['X1', 'X1']
  10: -0.000000-0.750000j * ['Y1', 'Y1']
  11: 0.000000-0.028297j * ['X1', 'X2']
  12: 0.000000+0.028297j * ['Y1', 'Y2']
  13: 0.000000+0.075994j * ['X1', 'X3']
  14: -0.000000-0.075994j * ['Y1', 'Y3']
  15: 0.000000+0.750000j * ['X2', 'X2']
  16: -0.000000-0.750000j * ['Y2', 'Y2']
  17: 0.000000+0.075994j * ['X2', 'X3']
  18: -0.000000-0.075994j * ['Y2', 'Y3']
  19: 0.000000+0.750000j * ['X3', 'X3']
  20: -0.000000-0.750000j * ['Y3', 'Y3']

Total qubits: 4


In [5]:
# Build the ZXW diagram
print("Building ZXW diagram...")
graph = hamiltonian.build_graph()
print("Graph built successfully")

# Visualize the diagram
print("\nDisplaying ZXW diagram:")
zx.draw(graph)


Building ZXW diagram...
Graph built successfully

Displaying ZXW diagram:


In [6]:
# Simplify the diagram
print("Simplifying ZXW diagram...")
simplified_graph = hamiltonian.simplify_graph()
print("Graph simplified")

# Visualize simplified diagram
print("\nDisplaying simplified ZXW diagram:")
zx.draw(simplified_graph)


Simplifying ZXW diagram...
Graph simplified

Displaying simplified ZXW diagram:


In [None]:
## Visualize ZXW Diagram

In [None]:
## Time Evolution using Trotterization

In [None]:
# Set up initial state and time points
# Initial state: uniform superposition (matching MATLAB code)
dim = 2 ** hamiltonian.total_qubits
initial_state = np.ones(dim, dtype=complex) / np.sqrt(dim)

# Time points (matching MATLAB: t from 0 to 0.2/gam)
t_initial = 0.0
t_final = 0.2 / gam
t_points = 50
times = np.linspace(t_initial, t_final, t_points)

print(f"Initial state: uniform superposition")
print(f"State dimension: {dim}")
print(f"Time range: {t_initial:.4f} to {t_final:.4f}")
print(f"Number of time points: {t_points}")

# Trotterization parameters
n_trotter = 10
print(f"\nTrotterization: {n_trotter} steps")

TypeError: PauliHamiltonianZX.compute_eigenvalues() got an unexpected keyword argument 'k'

In [None]:
# Compute time evolution using Trotterization
print("Computing time evolution using Trotterization...")

evolved_states = []
psi_amplitude = []

for t in times:
    # Evolve state using Trotterization
    evolved_state = hamiltonian.evolve_state(
        initial_state, 
        t, 
        n_trotter=n_trotter,
        use_tensor_network=True
    )
    evolved_states.append(evolved_state)
    
    # Compute log(|ψ|)/2 (matching MATLAB output)
    psi_amplitude.append(np.log(np.linalg.norm(evolved_state)) / 2)

evolved_states = np.array(evolved_states)
psi_amplitude = np.array(psi_amplitude)

print(f"Computed evolution for {len(times)} time points")
print(f"Final state norm: {np.linalg.norm(evolved_states[-1]):.6f}")


## Plot Results


In [None]:
# Plot time evolution (matching MATLAB output format)
plt.figure(figsize=(12, 5))

# Plot 1: log(|ψ|)/2 vs time (matching MATLAB)
plt.subplot(1, 2, 1)
plt.plot(times * gam, psi_amplitude, '.-b', linewidth=0.5, markersize=3)
plt.xlabel('tΓ', fontsize=12)
plt.ylabel('log(|ψ|)/2', fontsize=12)
plt.title('Collective Decay Amplitude', fontsize=14)
plt.grid(True, alpha=0.3)

# Plot 2: State norm over time
state_norms = [np.linalg.norm(state) for state in evolved_states]
plt.subplot(1, 2, 2)
plt.plot(times * gam, state_norms, '.-r', linewidth=0.5, markersize=3)
plt.xlabel('tΓ', fontsize=12)
plt.ylabel('|ψ|', fontsize=12)
plt.title('State Norm Over Time', fontsize=14)
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nTime evolution statistics:")
print(f"  Initial norm: {np.linalg.norm(initial_state):.6f}")
print(f"  Final norm: {np.linalg.norm(evolved_states[-1]):.6f}")
print(f"  Norm change: {np.linalg.norm(evolved_states[-1]) - np.linalg.norm(initial_state):.6f}")
