In [7]:
# ising_jax_beta_scan.py
import jax
import jax.numpy as jnp
from jax import jit, lax, random
import matplotlib.pyplot as plt

# -------------------------
# Utilities
# -------------------------
def coord_to_index(coords, shape):
    """Converts N-dimensional coordinates to a 1D index using row-major ordering."""
    # This is a robust implementation of np.ravel_multi_index for JAX.
    # The original function was for column-major indexing, which mismatched
    # the row-major output of jnp.unravel_index.
    shape = jnp.array(shape)
    strides = jnp.concatenate([jnp.cumprod(shape[1:][::-1])[::-1], jnp.array([1])])
    return jnp.dot(coords, strides).astype(int)

def build_nn_J(lattice_shape, coupling=1.0):
    dims = len(lattice_shape)
    N = int(jnp.prod(jnp.array(lattice_shape)))
    # Get coordinates in row-major order
    coords = jnp.stack(jnp.unravel_index(jnp.arange(N), lattice_shape), axis=-1)

    def neighbors_of(i):
        c = coords[i]
        neighs = []
        for d in range(dims):
            plus = c.at[d].set((c[d] + 1) % lattice_shape[d])
            minus = c.at[d].set((c[d] - 1) % lattice_shape[d])
            neighs.append(coord_to_index(plus, lattice_shape))
            neighs.append(coord_to_index(minus, lattice_shape))
        return jnp.array(neighs, dtype=int)

    def build_row(i):
        row = jnp.zeros((N,), dtype=jnp.float32)
        neighs = neighbors_of(i)
        row = row.at[neighs].set(coupling)
        return row

    J = jax.vmap(build_row)(jnp.arange(N))
    J = J.at[jnp.diag_indices(N)].set(0.0)
    return J

@jit
def energy(s, J, h=0.0):
    E_int = -0.5 * jnp.dot(s, jnp.dot(J, s))
    E_field = -h * jnp.sum(s)
    return E_int + E_field

@jax.jit
def metropolis_sweep(key, s, J, beta, num_flips=None):
    N = s.shape[0]
    if num_flips is None:
        num_flips = N

    key_idx, key_u = random.split(key)
    flip_indices = random.randint(key_idx, shape=(num_flips,), minval=0, maxval=N)
    u = random.uniform(key_u, shape=(num_flips,))

    def body_fun(i, state):
        s = state
        idx = flip_indices[i]
        ui = u[i]
        # Local field at site idx
        h_i = jnp.dot(J[idx], s)
        # Change in energy if we flip the spin
        deltaE = 2.0 * s[idx] * h_i
        accept_prob = jnp.exp(-beta * deltaE)
        accept = (deltaE <= 0.0) | (ui < accept_prob)
        # Conditionally flip the spin
        s = lax.cond(accept, lambda x: x.at[idx].set(-x[idx]), lambda x: x, s)
        return s

    return lax.fori_loop(0, num_flips, body_fun, s)



# -------------------------
# Experiment: magnetization vs beta
# -------------------------
if __name__ == "__main__":
    key = random.PRNGKey(0)

    Lx, Ly = 16, 16
    lattice_shape = (Lx, Ly)
    N = Lx * Ly
    J = build_nn_J(lattice_shape, coupling=1.0)

    # range of betas
    betas = jnp.linspace(0.1, 0.8, 30)  # inverse temperature
    n_sweeps = 1000
    measure_every = 20
    thermalization = 500

    mags = []

    # This is a standard Python loop. It's correct but not as fast as vmapping.
    for beta in betas:
        print(f"Running for beta = {beta:.3f}...")
        # random initial state
        key, sub = random.split(key)
        s = random.choice(sub, jnp.array([-1, 1], dtype=jnp.int8), (N,)).astype(jnp.int8)

        # equilibrate
        for _ in range(thermalization):
            key, sub = random.split(key)
            s = metropolis_sweep(sub, s, J, float(beta), num_flips=N)

        # measure magnetization
        m_values = []
        for sweep in range(n_sweeps):
            key, sub = random.split(key)
            s = metropolis_sweep(sub, s, J, float(beta), num_flips=N)
            if (sweep + 1) % measure_every == 0:
                m_values.append(jnp.mean(s))
        # Take the mean of the absolute magnetization over the measurement sweeps
        mags.append(jnp.mean(jnp.abs(jnp.array(m_values))))

    # -------------------------
    # Plot results
    # -------------------------
    plt.figure(figsize=(8, 5))
    plt.plot(betas, mags, "o-", lw=2, label="Simulation Data")
    # Add the known critical temperature for the 2D Ising model
    beta_c = jnp.log(1 + jnp.sqrt(2)) / 2
    plt.axvline(beta_c, color='r', linestyle='--', label=r"Exact $\beta_c$")
    plt.xlabel(r"$\beta = \frac{1}{k_B T}$", fontsize=14)
    plt.ylabel(r"$\langle |M| \rangle$", fontsize=14)
    plt.title(f"2D Ising Model Magnetization ({Lx}x{Ly})", fontsize=16)
    plt.grid(True)
    plt.legend()
    plt.show()

    print("Done.")

Running for beta = 0.100...


TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer<~int32[]>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function metropolis_sweep at /tmp/ipykernel_31781/3489134959.py:51 for jit. This concrete value was not available in Python because it depends on the value of the argument num_flips.