# In [None]:
import time
import jax
import jax.numpy as jnp
from jax import jit, grad

def mod(a, b):
    """Return ``a % b`` (Python floor-modulo; result has the sign of ``b``).

    Used to wrap lattice coordinates for periodic boundary conditions.
    """
    return a % b


def site_index(x, y, Ly):
    """Map lattice coordinates ``(x, y)`` to a flat row-major site index."""
    return x * Ly + y


# -----------------------------
# BdG Hamiltonian
# -----------------------------
def Hubbard_AFM_dwave_BdG(t, mu, Lx, Ly, m_ord, Delta_ord):
    """Build the real-space mean-field BdG matrix for an AFM + d-wave state.

    The lattice is ``Lx x Ly`` with periodic boundary conditions. The Nambu
    basis has ``4 * Lx * Ly`` components ordered as
    ``[particle-up, particle-down, hole-up, hole-down]``; each block lists
    all sites in ``site_index`` order.

    Parameters
    ----------
    t : float or scalar array
        Nearest-neighbour hopping amplitude.
    mu : float or scalar array
        Chemical potential (enters the particle diagonal as ``-mu``).
    Lx, Ly : int
        Lattice dimensions. Must be concrete Python ints (static under
        ``jit``) because they fix the matrix shape and the bond list.
    m_ord : float or scalar array
        Staggered onsite amplitude; enters the diagonal as
        ``±(-1)**(x+y) * m_ord`` with opposite signs for the two spins.
    Delta_ord : float or scalar array
        Pairing amplitude: ``+Delta_ord`` on x-bonds, ``-Delta_ord`` on
        y-bonds (d-wave sign structure).

    Returns
    -------
    jax.Array
        Real symmetric ``(4*Lx*Ly, 4*Lx*Ly)`` float32 matrix.
    """
    Ns = Lx * Ly
    N = 4 * Ns

    def idx(i, spin, particle):
        # spin: 0=up, 1=down
        # particle: 0=particle, 1=hole
        return i + spin * Ns + particle * 2 * Ns

    # Every nonzero entry is an integer combination of (t, mu, m_ord,
    # Delta_ord) whose coefficients depend only on the static lattice.
    # Collecting (row, col, coefficients) in plain Python lists and doing ONE
    # scatter-add keeps the traced program small; issuing a separate
    # ``.at[].add`` per entry unrolls ~36*Ns scatter ops into the jaxpr and
    # makes compilation slow for larger lattices.
    rows, cols = [], []
    k_t, k_mu, k_m, k_D = [], [], [], []

    def add(r, c, kt=0, kmu=0, km=0, kD=0):
        rows.append(r)
        cols.append(c)
        k_t.append(kt)
        k_mu.append(kmu)
        k_m.append(km)
        k_D.append(kD)

    for x in range(Lx):
        for y in range(Ly):
            i = site_index(x, y, Ly)
            # AFM staggering: (-1)^(x+y)
            eta = 1 - 2 * ((x + y) & 1)

            # Onsite terms; the hole block is the negative of the particle
            # block.
            add(idx(i, 0, 0), idx(i, 0, 0), kmu=-1, km=eta)
            add(idx(i, 1, 0), idx(i, 1, 0), kmu=-1, km=-eta)
            add(idx(i, 0, 1), idx(i, 0, 1), kmu=+1, km=-eta)
            add(idx(i, 1, 1), idx(i, 1, 1), kmu=+1, km=eta)

            # Only the +x and +y neighbours are visited, so each bond is
            # generated once; ``sign`` gives the d-wave +1 (x) / -1 (y).
            for dx, dy, sign in ((1, 0, +1), (0, 1, -1)):
                j = site_index(mod(x + dx, Lx), mod(y + dy, Ly), Ly)

                # Hopping for both spins, added symmetrically.
                for s in (0, 1):
                    add(idx(i, s, 0), idx(j, s, 0), kt=-1)
                    add(idx(j, s, 0), idx(i, s, 0), kt=-1)
                    add(idx(i, s, 1), idx(j, s, 1), kt=+1)
                    add(idx(j, s, 1), idx(i, s, 1), kt=+1)

                # Pairing: up-particle/down-hole entries carry +Delta and
                # down-particle/up-hole entries carry -Delta. Each entry is
                # mirrored (the Hermitian conjugate; everything is real).
                for a, sa, b, sb, spin_sign in (
                    (i, 0, j, 1, +1),
                    (j, 0, i, 1, +1),
                    (i, 1, j, 0, -1),
                    (j, 1, i, 0, -1),
                ):
                    coeff = sign * spin_sign
                    add(idx(a, sa, 0), idx(b, sb, 1), kD=coeff)
                    add(idx(b, sb, 1), idx(a, sa, 0), kD=coeff)

    def as_coeffs(k):
        return jnp.asarray(k, dtype=jnp.float32)

    # Elementwise (not matmul) so the values are computed at full float32
    # precision on every backend.
    vals = (
        as_coeffs(k_t) * t
        + as_coeffs(k_mu) * mu
        + as_coeffs(k_m) * m_ord
        + as_coeffs(k_D) * Delta_ord
    )

    H = jnp.zeros((N, N), dtype=jnp.float32)
    # Repeated (row, col) pairs accumulate, just like repeated ``.add``
    # calls. This happens when Lx == 2 or Ly == 2, where two distinct
    # periodic bonds join the same pair of sites.
    H = H.at[
        jnp.asarray(rows, dtype=jnp.int32),
        jnp.asarray(cols, dtype=jnp.int32),
    ].add(vals.astype(H.dtype))
    return H
# -----------------------------
# Energy 
# -----------------------------
def e0val(m_ord, Delta_ord, U):
    """Return the constant mean-field term ``(m_ord**2 + Delta_ord**2) / U``."""
    order_sq = m_ord ** 2 + Delta_ord ** 2
    return order_sq / U


def ground_state_energy(Lx, Ly, eigvals, e0):
    """Return the energy per site from a BdG spectrum.

    Keeps only the non-positive eigenvalues (``min(E, 0)``), sums them,
    divides by the number of lattice sites ``Lx * Ly`` and adds ``e0``.

    NOTE(review): in this file ``eigvals`` is the spectrum of the full
    4*Lx*Ly Nambu matrix. The textbook BdG ground-state energy for a doubled
    Nambu basis carries a factor 1/2 on this sum. Confirm this is consistent
    with the convention used for ``e0`` before trusting absolute energies.
    """
    n_sites = Lx * Ly
    occupied = jnp.minimum(eigvals, 0.0)
    return occupied.sum() / n_sites + e0

# -----------------------------
# Objective function
# -----------------------------
def objective(params, t, mu, U, Lx, Ly):
    """Return the mean-field energy per site for ``params = (m, Delta)``.

    Builds the BdG matrix for the given order parameters, diagonalises it
    with ``eigvalsh`` and combines the spectrum with the ``e0val`` constant
    via ``ground_state_energy``. ``Lx`` and ``Ly`` must be concrete ints.
    """
    m_ord, Delta_ord = params
    spectrum = jnp.linalg.eigvalsh(
        Hubbard_AFM_dwave_BdG(t, mu, Lx, Ly, m_ord, Delta_ord)
    )
    return ground_state_energy(Lx, Ly, spectrum, e0val(m_ord, Delta_ord, U))

# -----------------------------
# JIT compilation
# -----------------------------
# Lx and Ly (positional args 4 and 5) fix the matrix shape and the Python
# loops that build it, so they must be compile-time constants; new values
# trigger a retrace. t, mu and U are traced, so changing their values reuses
# the compiled executable as long as their types stay the same.
grad_objective_jit = jit(grad(objective, argnums=0), static_argnums=(4, 5))
objective_jit = jit(objective, static_argnums=(4, 5))


# -----------------------------
# Driver: gradient-descent optimisation of (m, Delta) for several U
# -----------------------------
if __name__ == "__main__":
    # Model / lattice parameters
    t = 1.0
    mu = 0.0
    Lx, Ly = 8, 8

    # Plain gradient-descent settings
    lr = 0.5
    n_iter = 40

    U_list = [1, 2, 4, 8, 16]
    m_list = []
    Delta_list = []

    for U in U_list:
        print(f"\nU = {U}")
        params = jnp.array([0.1, 0.1])

        # First: compile + run. U is a traced (non-static) argument, so
        # compilation only happens on the first U; later iterations of this
        # loop measure a single warm call of each function.
        t0 = time.perf_counter()
        loss = objective_jit(params, t, mu, U, Lx, Ly)
        g = grad_objective_jit(params, t, mu, U, Lx, Ly)
        # JAX dispatches asynchronously: wait for BOTH results, otherwise
        # the warm-up gradient's execution leaks into the GD timing below.
        jax.block_until_ready((loss, g))
        t1 = time.perf_counter()

        print(f"  JIT compile + first run: {t1 - t0:.3f} s")

        # Gradient descent
        t2 = time.perf_counter()
        for _ in range(n_iter):
            g = grad_objective_jit(params, t, mu, U, Lx, Ly)
            params = params - lr * g
        jax.block_until_ready(params)
        t3 = time.perf_counter()

        print(f"  GD iterations time     : {t3 - t2:.3f} s")

        m_opt, Delta_opt = params
        m_list.append(m_opt)
        Delta_list.append(Delta_opt)

        print(f"  m     = {float(m_opt):.6f}")
        print(f"  Delta = {float(Delta_opt):.6f}")
