# Boltzmann Sampling for Small Ising Systems (JAX version)

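This notebook reproduces, in **Python/JAX** (`jax.numpy`), a brute‑force Boltzmann sampler originally written in Julia. It enumerates all $2^n$ spin configurations $s\in\{-1,+1\}^n$ of an Ising model and assigns each one the Boltzmann weight
$$w(s)=\exp\bigl(\tfrac12 s^\top J s + h^\top s\bigr).$$
It then draws independent samples exactly from the normalised distribution $p(s)=w(s)/\sum_{s'}w(s')$; no Markov chain is involved, so unlike Gibbs sampling there is no burn‑in or autocorrelation. The samples are compressed into a histogram with one row `[count, s₁, …, sₙ]` per distinct sampled configuration, and the result is saved to CSV. Because time and memory grow as $2^n$, this approach is only practical for small $n$.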


In [7]:
import jax.numpy as jnp
from jax import random
from typing import Any
import pandas as pd
import numpy as np

In [8]:
def int_to_spin(ints: jnp.ndarray, n: int) -> jnp.ndarray:
    """Decode configuration indices into {-1, +1} spin vectors.

    Bit k of each integer (least-significant bit first) becomes spin k:
    a 0 bit maps to -1 and a 1 bit maps to +1.

    Args:
        ints: array of integers in [0, 2**n); used here as a 1-D array of
            configuration indices.
        n: number of spins (bits) to decode per integer.

    Returns:
        Array of shape (len(ints), n) holding -1/+1 values.
    """
    bit_positions = jnp.arange(n)
    # Put each integer on its own row, then shift it right by 0..n-1 along the columns.
    shifted = jnp.right_shift(jnp.expand_dims(ints, 1), bit_positions)
    lowest_bit = jnp.bitwise_and(shifted, 1)
    # Affine map {0, 1} -> {-1, +1}.
    return lowest_bit * 2 - 1

In [9]:
def sample_generation(sample_number: int,
                      J: jnp.ndarray,
                      h: jnp.ndarray,
                      key: Any) -> np.ndarray:
    """
    Draw exact Boltzmann samples of a small Ising model by brute-force enumeration.

    All 2**n spin configurations are enumerated. Each is assigned the
    log-weight ``0.5 * s^T J s + h^T s``, and ``sample_number``
    configurations are drawn independently (with replacement) from the
    normalised distribution.

    Args:
        sample_number: number of samples to draw (must be >= 0).
        J: (n, n) coupling matrix.
        h: length-n field vector.
        key: JAX PRNG key used for sampling.

    Returns:
        int64 array of shape (k, n + 1), one row per distinct sampled
        configuration: ``[count, s_1, ..., s_n]`` with spins in {-1, +1}.
        Rows are ordered by configuration index, and s_1 is the
        least-significant bit of that index.

    Raises:
        ValueError: if ``J`` is not square, ``h`` does not have shape (n,),
            n is too large to enumerate, or ``sample_number`` is negative.
    """
    J = jnp.asarray(J)
    h = jnp.asarray(h)
    if J.ndim != 2 or J.shape[0] != J.shape[1]:
        raise ValueError(f"J must be a square matrix, got shape {J.shape}")
    n = J.shape[0]
    # Without this check an h of shape (n, 1) would silently broadcast the
    # (2**n,) log-weights into a (2**n, 2**n) matrix.
    if h.shape != (n,):
        raise ValueError(f"h must have shape ({n},), got {h.shape}")
    # Configuration indices are uint32 and are decoded with 32-bit integer
    # arithmetic, so n = 31 is the hard ceiling. In practice memory
    # (2**n * n spins) is the binding limit well before that.
    if n > 31:
        raise ValueError(f"brute-force enumeration needs n <= 31, got n = {n}")
    if sample_number < 0:
        raise ValueError(f"sample_number must be non-negative, got {sample_number}")

    configs = jnp.arange(2 ** n, dtype=jnp.uint32)
    spins = int_to_spin(configs, n)

    # Log of the unnormalised Boltzmann weight (i.e. minus the energy).
    # A non-zero diagonal in J contributes only the constant 0.5 * trace(J),
    # because s_i**2 == 1; it cancels when normalising.
    log_w = 0.5 * jnp.einsum('bi,ij,bj->b', spins, J, spins) + spins @ h
    # Shift by the maximum so the largest weight is exp(0) = 1 and nothing overflows.
    w = jnp.exp(log_w - jnp.max(log_w))
    probs = w / jnp.sum(w)

    samples = random.choice(key, configs, shape=(sample_number,), p=probs, replace=True)

    # Compress the raw draws into per-configuration counts and keep only the
    # configurations that were actually drawn.
    counts = jnp.bincount(samples, length=configs.size)
    nonzero = jnp.nonzero(counts)[0]
    histogram = jnp.concatenate([counts[nonzero][:, None], int_to_spin(nonzero, n)], axis=1)

    return np.asarray(histogram, dtype=np.int64)


In [10]:
file_adj = "input_adjacency.csv"   # weighted adjacency matrix: couplings J_ij off-diagonal, fields h_i on the diagonal
num_samples = 100_000              # number of samples to draw
seed = 0                           # PRNG seed

J_df = pd.read_csv(file_adj, header=None)
adjacency = jnp.asarray(J_df.to_numpy(), dtype=jnp.float32)
if adjacency.ndim != 2 or adjacency.shape[0] != adjacency.shape[1]:
    raise ValueError(f"{file_adj} must contain a square matrix, got shape {adjacency.shape}")

# The diagonal stores the external fields h. Strip it from J so that J holds
# only the pairwise couplings. Left in place, it would also act as a
# self-coupling 0.5 * h_i * s_i**2 = 0.5 * h_i. That is a constant which
# cancels on normalisation, but it conflates fields with couplings.
h = jnp.diag(adjacency)
J = adjacency - jnp.diag(h)
key = random.PRNGKey(seed)

In [11]:
hist = sample_generation(num_samples, J, h, key)

# Sanity checks: every draw is counted exactly once, and all spin entries are ±1.
assert hist[:, 0].sum() == num_samples, "histogram counts do not add up to num_samples"
assert np.isin(hist[:, 1:], (-1, 1)).all(), "spin columns must contain only -1/+1"

hist[:10]  # preview first 10 histogram rows

array([[35629,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1],
       [ 2698,     1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1],
       [ 1205,    -1,     1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1],
       [   52,     1,     1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1],
       [  657,    -1,    -1,     1,    -1,    -1,    -1,    -1,    -1,
           -1],
       [  344,     1,    -1,     1,    -1,    -1,    -1,    -1,    -1,
           -1],
       [  164,    -1,     1,     1,    -1,    -1,    -1,    -1,    -1,
           -1],
       [   37,     1,     1,     1,    -1,    -1,    -1,    -1,    -1,
           -1],
       [ 1161,    -1,    -1,    -1,     1,    -1,    -1,    -1,    -1,
           -1],
       [  645,     1,    -1,    -1,     1,    -1,    -1,    -1,    -1,
           -1]])

In [None]:
output_file = "output_samples.csv"
# Write one row per sampled configuration: count first, then the n spins.
(
    pd.DataFrame(hist)
      .to_csv(output_file, index=False, header=False)
)
print(f"histogram saved to {output_file}")

histogram saved to output_samples.csv
