# In [None]:  (Jupyter cell marker, kept as a comment so the file runs as a script)
import numpy as np
from numba import njit
import matplotlib.pyplot as plt

# ──────────────────────────────────────────────
# 1. Global constants & Pauli operators
# ──────────────────────────────────────────────
PI = np.pi

# Field (magnitude, polar angle, azimuth) around which the Fisher
# information is evaluated by finite differences.
B0 = 1.0
theta0 = PI / 4
phi0 = PI / 4

h_FD = 1.0e-4                 # finite-difference step in (B, theta, phi)
theta_rot = PI / 3            # dt argument of the fixed final rotation (U4_full)
total_T = 2.0                 # total control time
N_slices = 50                 # number of piecewise-constant control slices
tau = total_T / N_slices      # duration of one control slice
n_epochs = 100                # number of GRAPE iterations
J = 0.01                      # σz⊗σz coupling strength passed to U4_full

# Single-qubit Pauli matrices and identity
sigma_x = np.array([[0, 1], [1, 0]], dtype=np.complex128)
sigma_y = np.array([[0, -1j], [1j, 0]], dtype=np.complex128)
sigma_z = np.array([[1, 0], [0, -1]], dtype=np.complex128)
I2 = np.identity(2, dtype=np.complex128)

# Paulis acting on the first qubit of the register: σ ⊗ I
sigma1x, sigma1y, sigma1z = (np.kron(s, I2) for s in (sigma_x, sigma_y, sigma_z))
I4 = np.identity(4, dtype=np.complex128)

# Bell basis, one state per row (basis order |00>, |01>, |10>, |11>):
# |Φ->, |Ψ+>, |Ψ->, |Φ+>
bell_states = np.array(
    [
        [1, 0, 0, -1],
        [0, 1, 1, 0],
        [0, 1, -1, 0],
        [1, 0, 0, 1],
    ],
    dtype=np.complex128,
) / np.sqrt(2)

# Rank-1 projectors |b_k><b_k|, stacked along axis 0 -> shape (4, 4, 4)
projectors = np.array([np.outer(b, b.conj()) for b in bell_states])

# Initial state |Φ+> and its density matrix
psi0 = np.array([1, 0, 0, 1], dtype=np.complex128) / np.sqrt(2)
rho0 = np.outer(psi0, psi0.conj())

# ──────────────────────────────────────────────
# 2. Unitary evolution with full Hamiltonian
# ──────────────────────────────────────────────
@njit(inline="always")
def U4_full(vx, vy, vz, Jcouple, dt):
    """Return exp(-i H dt) as a 4x4 complex matrix, where

        H = vx σx⊗I + vy σy⊗I + vz σz⊗I + Jcouple σz⊗σz

    in the basis |00>, |01>, |10>, |11>.

    σz⊗σz is diagonal in the second qubit, so H splits into two 2x2 blocks:
    block 0 acts on {|00>, |10>} (indices 0, 2) with effective z-field
    vz + Jcouple, and block 1 acts on {|01>, |11>} (indices 1, 3) with
    vz - Jcouple. Each block is a single-spin rotation

        exp(-i dt h·σ) = cos(|h| dt) I - i sin(|h| dt)/|h| (h·σ).
    """
    U = np.zeros((4, 4), dtype=np.complex128)

    for blk in range(2):
        # Effective z-field on qubit 1 given qubit 2 is |0> (blk 0) or |1> (blk 1).
        hz = vz + Jcouple if blk == 0 else vz - Jcouple
        mag = np.sqrt(vx * vx + vy * vy + hz * hz)
        if mag < 1e-14:
            # Zero-field limit: sin(mag*dt)/mag -> dt (the previous 0.0 was the
            # wrong limit; U -> I - i dt H to first order).
            cs = 1.0
            ss = dt
        else:
            cs = np.cos(mag * dt)
            ss = np.sin(mag * dt) / mag

        lo = blk        # index whose first qubit is |0>
        hi = blk + 2    # index whose first qubit is |1>
        U[lo, lo] = cs - 1j * ss * hz
        U[lo, hi] = -1j * ss * (vx - 1j * vy)
        U[hi, lo] = -1j * ss * (vx + 1j * vy)
        U[hi, hi] = cs + 1j * ss * hz
    return U

@njit
def probabilities(theta_rot, tau, n_slices,
                  B, theta, phi,
                  Vx, Vy, Vz, Jcouple,
                  projectors):
    """Return outcome probabilities after the controlled evolution and readout rotation.

    Starting from the module-level density matrix ``rho0``, the state is
    propagated through ``n_slices`` steps of length ``tau``. Step t uses
    U4_full with field (B, theta, phi) converted to Cartesian components plus
    the controls (Vx[t], Vy[t], Vz[t]) and coupling ``Jcouple``. A fixed
    rotation U4_full(1/√3, 1/√3, 1/√3, 0, theta_rot) is then applied.

    Returns a float64 array with p[k] = Re Tr(rho @ projectors[k]) for every
    projector along axis 0 of ``projectors``. Previously this was hard-coded
    to 4, which read out of bounds without error (Numba does no bounds
    checking) when fewer projectors were passed.

    NOTE: ``rho0`` is read as a global; Numba freezes global arrays at
    compile time, so rebinding ``rho0`` after the first call has no effect.
    Assumes Vx/Vy/Vz have at least ``n_slices`` entries (not checked).
    """
    rho = rho0.copy()
    sin_t = np.sin(theta)
    Bx = B * sin_t * np.cos(phi)
    By = B * sin_t * np.sin(phi)
    Bz = B * np.cos(theta)

    for t in range(n_slices):
        U = U4_full(Bx + Vx[t], By + Vy[t], Bz + Vz[t], Jcouple, tau)
        rho = U @ rho @ U.conj().T

    # Fixed readout rotation about the (1, 1, 1)/√3 axis, no coupling.
    coeff = 1.0 / np.sqrt(3.0)
    U_rot = U4_full(coeff, coeff, coeff, 0.0, theta_rot)
    rho = U_rot @ rho @ U_rot.conj().T

    n_out = projectors.shape[0]
    p = np.empty(n_out, dtype=np.float64)
    for k in range(n_out):
        p[k] = np.real(np.trace(rho @ projectors[k]))
    return p

@njit
def f0_and_trace(Vx, Vy, Vz,
                 theta_rot, tau, n_slices,
                 B0, theta0, phi0,
                 h, Jcouple, projectors):
    """Classical Fisher information for (B, theta, phi) and two scalar summaries.

    dP/dB, dP/dtheta and dP/dphi are central differences with step ``h``
    around (B0, theta0, phi0). The outcome probabilities come from
    ``probabilities``. The Fisher matrix is

        F[i, j] = sum_k dP_k/dx_i * dP_k/dx_j / P_k.

    Outcomes with P_k <= p_floor are skipped. Their term is 0/0 in exact
    arithmetic. Numerically, an exactly-zero P_k raises ZeroDivisionError
    under Numba's default error model, and a round-off-negative P_k would
    add a spurious negative contribution.

    Returns:
        f0:     1 / (1/F[0,0] + 1/F[1,1] + 1/F[2,2]), built from the diagonal
                of F only. For positive-definite F, (F^-1)_ii >= 1/F_ii, so
                f0 >= 1 / Tr(F^-1).
        tr_inv: Tr(F^-1), the Cramér-Rao bound on the summed variances.

    Raises whatever Numba's np.linalg.inv raises if F is singular.
    """
    P0   = probabilities(theta_rot, tau, n_slices, B0,   theta0,   phi0,   Vx, Vy, Vz, Jcouple, projectors)
    P_Bp = probabilities(theta_rot, tau, n_slices, B0+h, theta0,   phi0,   Vx, Vy, Vz, Jcouple, projectors)
    P_Bm = probabilities(theta_rot, tau, n_slices, B0-h, theta0,   phi0,   Vx, Vy, Vz, Jcouple, projectors)
    P_tp = probabilities(theta_rot, tau, n_slices, B0,   theta0+h, phi0,   Vx, Vy, Vz, Jcouple, projectors)
    P_tm = probabilities(theta_rot, tau, n_slices, B0,   theta0-h, phi0,   Vx, Vy, Vz, Jcouple, projectors)
    P_pp = probabilities(theta_rot, tau, n_slices, B0,   theta0,   phi0+h, Vx, Vy, Vz, Jcouple, projectors)
    P_pm = probabilities(theta_rot, tau, n_slices, B0,   theta0,   phi0-h, Vx, Vy, Vz, Jcouple, projectors)

    n_out = P0.shape[0]
    grad = np.empty((3, n_out))
    grad[0] = (P_Bp - P_Bm) / (2 * h)
    grad[1] = (P_tp - P_tm) / (2 * h)
    grad[2] = (P_pp - P_pm) / (2 * h)

    # Probabilities at or below this are treated as impossible outcomes.
    p_floor = 1e-14

    # F is symmetric: compute the upper triangle and mirror it.
    F = np.zeros((3, 3))
    for i in range(3):
        for j in range(i, 3):
            acc = 0.0
            for k in range(n_out):
                if P0[k] > p_floor:
                    acc += grad[i, k] * grad[j, k] / P0[k]
            F[i, j] = acc
            F[j, i] = acc

    Cov = np.linalg.inv(F)
    tr_inv = Cov[0, 0] + Cov[1, 1] + Cov[2, 2]
    f0 = 1.0 / (1.0 / F[0, 0] + 1.0 / F[1, 1] + 1.0 / F[2, 2])
    return f0, tr_inv

# ──────────────────────────────────────────────
# 3. Initial control: Hc = -H0
# ──────────────────────────────────────────────
# Initial control Hc = -H0: cancel the true field (B0, theta0, phi0) exactly.
# Tie it to those constants instead of repeating their values, so changing
# the true field cannot silently break the cancellation.
B     = B0
theta = theta0
phi   = phi0

vx = -B * np.sin(theta) * np.cos(phi)
vy = -B * np.sin(theta) * np.sin(phi)
vz = -B * np.cos(theta)

# Constant control on every slice (float64 arrays, updated in place by GRAPE).
V1x = np.full(N_slices, vx)
V1y = np.full(N_slices, vy)
V1z = np.full(N_slices, vz)

# ──────────────────────────────────────────────
# 3b. Alternative initial control: fields obtained from a previous GRAPE run (uncomment to use)
# ──────────────────────────────────────────────
# V1x = np.array([ -7.8623698155e-01, -5.6317828764e-01, -4.8817258475e-01, -4.7085824148e-01, -4.7268058031e-01, -4.7866340454e-01, -4.8393671914e-01, -4.8755301544e-01, -4.8982069247e-01, -4.9126605695e-01, -4.9230593520e-01, -4.9319007797e-01, -4.9404077548e-01, -4.9490174125e-01, -4.9577100521e-01, -4.9663449359e-01, -4.9746925888e-01, -4.9826047533e-01, -4.9899216569e-01, -4.9965666015e-01, -5.0025072434e-01, -5.0077125607e-01, -5.0122106479e-01, -5.0160358917e-01, -5.0192400063e-01, -5.0218854082e-01, -5.0240280532e-01, -5.0256827709e-01, -5.0268257497e-01, -5.0273228030e-01, -5.0268726946e-01, -5.0249217147e-01, -5.0205764479e-01, -5.0125615032e-01, -4.9992141525e-01, -4.9787429940e-01, -4.9498566484e-01, -4.9130938330e-01, -4.8732031006e-01, -4.8427551954e-01, -4.8467186484e-01, -4.9262493793e-01, -5.1361745511e-01, -5.5228377333e-01, -6.0531146341e-01, -6.4352977707e-01, -5.7209926064e-01, -1.5027915293e-01, 1.1441357591e+00, 4.1217965101e+00 ])
# V1y = np.array([ -4.7157906540e-01, -4.8532449125e-01, -5.0031596915e-01, -5.1046387085e-01, -5.1495878691e-01, -5.1544991775e-01, -5.1388715674e-01, -5.1167696759e-01, -5.0957987350e-01, -5.0788373913e-01, -5.0661186127e-01, -5.0567315165e-01, -5.0495374449e-01, -5.0435982542e-01, -5.0383031135e-01, -5.0333177103e-01, -5.0285642281e-01, -5.0240558976e-01, -5.0198632741e-01, -5.0160727992e-01, -5.0127183282e-01, -5.0098245917e-01, -5.0073334801e-01, -5.0051911853e-01, -5.0033214885e-01, -5.0016584726e-01, -5.0002042978e-01, -4.9990834199e-01, -4.9985312803e-01, -4.9990234366e-01, -5.0012073181e-01, -5.0059225430e-01, -5.0140650824e-01, -5.0263037849e-01, -5.0427423518e-01, -5.0622408146e-01, -5.0817615048e-01, -5.0956401946e-01, -5.0953244440e-01, -5.0701156267e-01, -5.0100417814e-01, -4.9121168606e-01, -4.7912968047e-01, -4.6961489086e-01, -4.7254685641e-01, -5.0340707641e-01, -5.8031092610e-01, -7.1433048348e-01, -8.9191981838e-01, -1.0603256471e+00 ])
# V1z = np.array([ -1.1278034584e+00, -9.0767232576e-01, -7.9197764169e-01, -7.3509335521e-01, -7.0959554482e-01, -6.9986061290e-01, -6.9742576915e-01, -6.9795937372e-01, -6.9940744590e-01, -7.0090292018e-01, -7.0215256101e-01, -7.0311395500e-01, -7.0383471211e-01, -7.0438178809e-01, -7.0481090712e-01, -7.0516288936e-01, -7.0546464524e-01, -7.0573154533e-01, -7.0597351154e-01, -7.0619762140e-01, -7.0640592079e-01, -7.0660431588e-01, -7.0679305276e-01, -7.0697630821e-01, -7.0715822170e-01, -7.0733943493e-01, -7.0752393227e-01, -7.0771522271e-01, -7.0791438429e-01, -7.0812329980e-01, -7.0834068907e-01, -7.0856552542e-01, -7.0879534892e-01, -7.0902917351e-01, -7.0926968122e-01, -7.0953271495e-01, -7.0985184778e-01, -7.1029142229e-01, -7.1095338203e-01, -7.1197226045e-01, -7.1348686445e-01, -7.1553162348e-01, -7.1778953617e-01, -7.1901442893e-01, -7.1585969227e-01, -7.0076012856e-01, -6.5824681482e-01, -5.5673372102e-01, -3.1996346532e-01, 2.5755951089e-01 ])

# ──────────────────────────────────────────────
# 4. GRAPE optimization
# ──────────────────────────────────────────────
tr_inv_list = []

eps = 1e-5  # forward-difference step for d f0 / d V[i]

for epoch in range(n_epochs + 1):
    # Learning-rate schedule.
    # NOTE(review): while n_epochs <= 300 the second branch is never reached,
    # so lr stays 1e-8 for the whole run.
    lr = 1e-8 if epoch <= 300 else 1e-6

    if epoch > 0:
        # f0 at the current controls was evaluated at the end of the previous
        # epoch and the controls have not changed since, so reuse it.
        f0_base = f0_val

        grads = []
        for V in (V1x, V1y, V1z):
            g = np.zeros_like(V)
            for i in range(N_slices):
                v_orig = V[i]
                V[i] = v_orig + eps
                f_pert, _ = f0_and_trace(V1x, V1y, V1z,
                                         theta_rot, tau, N_slices,
                                         B0, theta0, phi0,
                                         h_FD, J, projectors)
                g[i] = (f_pert - f0_base) / eps
                # Restore exactly: `V[i] -= eps` can leave a rounding residue
                # that accumulates into a drift of the controls.
                V[i] = v_orig
            grads.append(g)

        # Gradient ascent on f0.
        V1x += lr * grads[0]
        V1y += lr * grads[1]
        V1z += lr * grads[2]

    f0_val, tr_inv = f0_and_trace(V1x, V1y, V1z,
                                  theta_rot, tau, N_slices,
                                  B0, theta0, phi0,
                                  h_FD, J, projectors)

    tr_inv_list.append(tr_inv)

    if epoch % 10 == 0:
        print(f"Epoch {epoch:04d}:  f0 = {f0_val:.6e}   Tr[Fcl^-1] = {tr_inv:.12e}")

# ──────────────────────────────────────────────
# 5. Print optimised fields
# ──────────────────────────────────────────────
print("\n# Optimised control fields:")

# Dump each optimised field as a paste-ready `np.array([...])` literal.
for _label, _field in (("V1x", V1x), ("V1y", V1y), ("V1z", V1z)):
    print(f"{_label} = np.array([", ", ".join(f"{v:.10e}" for v in _field), "])")

# Training curve: Tr[F_cl^-1] after each epoch (epoch 0 = initial control).
plt.plot(np.arange(n_epochs + 1), tr_inv_list)
plt.title('Training Progress of Tr[F_cl^-1]')
plt.xlabel('Epoch')
plt.ylabel('Tr[F_cl^-1]')
plt.grid(True)
plt.show()