In [14]:
import numpy as np 
import sys, os
import matplotlib.pyplot as plt
sys.path.append('../Netket/')
import netket as nk
from jax import numpy as jnp
import itertools
from scipy.special import comb
from jax import jit, vmap
import jax


In [3]:
# Four spin-1/2 sites: a small Hilbert space (2**4 = 16 basis states) for quick experiments.
hilbert = nk.hilbert.Spin(s=0.5, N=4)

In [4]:
# Enumerate every basis configuration of `hilbert` and display the integer
# label that the Hilbert space assigns to each of them.
hstates = hilbert.all_states()
hilbert.states_to_numbers(hstates)

Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],      dtype=int32)

In [5]:
def return_parity(bitstring):
    '''
    Parity of one cluster (or a batch of clusters) of spin values.

    bitstring: jnp array whose last axis holds the spins of a cluster
               (e.g. s1s2 or s1s2s3s4), encoded as +1 / -1.
    returns:   the product of the spins along the last axis as jnp.int8,
               i.e. +1 when the number of -1 spins is even and -1 when it is odd.
    '''
    # With a +1/-1 encoding the parity is simply the product of the entries.
    return jnp.prod(bitstring, dtype=jnp.int8, axis=-1)

def naive_cluster_expansion_mat(hilbert):
    '''Build the full cluster-expansion (parity) matrix one basis state at a time.

    Row ``i`` belongs to the ``i``-th configuration returned by
    ``hilbert.all_states()``. Column 0 is the constant term (all ones). The
    remaining columns hold the product of the +1/-1 spin values over every
    cluster of sites, ordered by cluster size (1, 2, ..., n_sites) and, within
    one size, in ``itertools.combinations`` order.

    This is the straightforward reference implementation; it only works for
    spin-1/2 systems whose full 2**n_sites configuration space is enumerated.

    Args:
        hilbert: object exposing ``size`` (number of sites) and
            ``all_states()`` (array of shape (n_states, n_sites) with +1/-1 entries).

    Returns:
        jnp.int8 array of shape (2**n_sites, 2**n_sites).

    Raises:
        ValueError: if ``all_states()`` does not contain exactly 2**n_sites rows.
    '''
    n_sites = int(hilbert.size)
    # Work on the host: the matrix is filled row by row, and a functional
    # ``.at[...].set`` update would copy the whole matrix on every assignment.
    hstates = np.asarray(hilbert.all_states(), dtype=np.int8)
    matsize = 2**n_sites
    if hstates.shape[0] != matsize:
        raise ValueError(
            f"expected {matsize} basis states for {n_sites} spin-1/2 sites, "
            f"got {hstates.shape[0]}"
        )

    # The site subsets are the same for every state, so enumerate them once.
    clusters_by_size = [
        np.array(list(itertools.combinations(range(n_sites), cluster_size)), dtype=np.intp)
        for cluster_size in range(1, n_sites + 1)
    ]

    mat = np.ones((matsize, matsize), dtype=np.int8)
    for state_idx, state in enumerate(hstates):
        start_idx = 1  # First column is all ones, so start from the second column
        for clusters in clusters_by_size:
            stop_idx = start_idx + len(clusters)
            # state[clusters] has shape (n_clusters, cluster_size); the parity
            # of each cluster is the product of its +1/-1 spins.
            mat[state_idx, start_idx:stop_idx] = np.prod(state[clusters], axis=-1, dtype=np.int8)
            start_idx = stop_idx
    return jnp.asarray(mat)


In [6]:
@jit
def _parities_per_cluster_size(hstates, clusters):
    """Parities of all clusters of one size, for every state at once.

    hstates:  (n_states, n_sites) array of +1/-1 spins.
    clusters: (n_clusters, cluster_size) integer array of site indices.
    returns:  (n_states, n_clusters) jnp.int8 array of spin products.
    """
    # Gather -> (n_states, n_clusters, cluster_size), then multiply over the cluster axis.
    return jnp.prod(hstates[:, clusters], axis=-1, dtype=jnp.int8)


def optim_cluster_expansion(hilbert):
    '''Vectorised cluster-expansion matrix, one JIT-compiled call per cluster size.

    Produces the same matrix layout as the row-by-row construction: column 0
    is all ones, followed by the parity (product of +1/-1 spins) of every
    cluster of sites, grouped by cluster size 1..n_sites and ordered within a
    size as ``itertools.combinations(range(n_sites), size)``.

    The kernel is defined at module level so that JAX can reuse its compiled
    version across calls with the same shapes; a jitted closure created inside
    this function would be a new function object, and hence recompiled, on
    every call.

    Note that the result is dense with shape (2**n_sites, 2**n_sites), so the
    memory footprint grows as 4**n_sites bytes.

    Args:
        hilbert: object exposing ``size`` and ``all_states()`` (spin-1/2, +1/-1 entries).

    Returns:
        jnp.int8 array of shape (2**n_sites, 2**n_sites).
    '''
    n_sites = hilbert.size
    hstates = jnp.asarray(hilbert.all_states())
    matsize = 2**n_sites

    # Constant term first, then one block of columns per cluster size.
    columns = [jnp.ones((matsize, 1), dtype=jnp.int8)]
    for cluster_size in range(1, n_sites + 1):
        # Site-index combinations are plain Python work, done outside the JIT.
        clusters = jnp.array(
            list(itertools.combinations(range(n_sites), cluster_size)), dtype=jnp.int32
        )
        columns.append(_parities_per_cluster_size(hstates, clusters))

    return jnp.concatenate(columns, axis=1)


In [7]:
# Smallest non-trivial case: 2 sites give a 4 x 4 matrix that is easy to check by hand.
naive_cluster_expansion_mat(nk.hilbert.Spin(s=0.5, N=2))

Array([[ 1,  1,  1,  1],
       [ 1,  1, -1, -1],
       [ 1, -1,  1, -1],
       [ 1, -1, -1,  1]], dtype=int8)

In [8]:
@jit
def _parities_all_cluster_sizes(hstates, all_cluster_arrays):
    """Parities of every cluster of every size, for every state, in one compiled call.

    hstates:            (n_states, n_sites) array of +1/-1 spins.
    all_cluster_arrays: list of (n_clusters_k, k) integer site-index arrays,
                        one entry per cluster size k.
    returns:            (n_states, sum_k n_clusters_k) jnp.int8 array.
    """
    # The Python loop is unrolled at trace time: cluster arrays of different
    # sizes have different shapes and cannot be stacked into a single array.
    blocks = [
        jnp.prod(hstates[:, clusters], axis=-1, dtype=jnp.int8)
        for clusters in all_cluster_arrays
    ]
    return jnp.concatenate(blocks, axis=1)


def optim_cluster_expansion_extreme(hilbert):
    '''Cluster-expansion matrix computed with a single JIT-compiled call.

    All cluster sizes are handled inside one compiled kernel instead of one
    call per size. Column 0 is all ones; the remaining columns are the
    parities (products of +1/-1 spins) of every cluster of sites, grouped by
    cluster size 1..n_sites and ordered within a size as
    ``itertools.combinations(range(n_sites), size)``.

    The kernel lives at module level so that its compiled version is reused
    across calls with the same number of sites.

    Args:
        hilbert: object exposing ``size`` and ``all_states()`` (spin-1/2, +1/-1 entries).

    Returns:
        jnp.int8 array of shape (2**n_sites, 2**n_sites).
    '''
    n_sites = hilbert.size
    hstates = jnp.asarray(hilbert.all_states())
    matsize = 2**n_sites

    # Pre-compute all cluster indices outside JIT (pure Python operations).
    cluster_indices_list = [
        jnp.array(list(itertools.combinations(range(n_sites), cluster_size)), dtype=jnp.int32)
        for cluster_size in range(1, n_sites + 1)
    ]

    parities = _parities_all_cluster_sizes(hstates, cluster_indices_list)

    # Prepend the constant (all-ones) column.
    ones_col = jnp.ones((matsize, 1), dtype=jnp.int8)
    return jnp.concatenate([ones_col, parities], axis=1)


In [9]:
# 12 sites -> a (2**12, 2**12) = 4096 x 4096 matrix; only the corners are displayed.
optim_cluster_expansion_extreme(nk.hilbert.Spin(s=0.5, N=12))

Array([[ 1,  1,  1, ...,  1,  1,  1],
       [ 1,  1,  1, ..., -1, -1, -1],
       [ 1,  1,  1, ..., -1, -1, -1],
       ...,
       [ 1, -1, -1, ...,  1,  1, -1],
       [ 1, -1, -1, ...,  1,  1, -1],
       [ 1, -1, -1, ..., -1, -1,  1]], dtype=int8)

In [10]:
# Same 12-site system with the per-cluster-size variant, for a visual comparison of the corners.
optim_cluster_expansion(nk.hilbert.Spin(s=0.5, N=12))

Array([[ 1,  1,  1, ...,  1,  1,  1],
       [ 1,  1,  1, ..., -1, -1, -1],
       [ 1,  1,  1, ..., -1, -1, -1],
       ...,
       [ 1, -1, -1, ...,  1,  1, -1],
       [ 1, -1, -1, ...,  1,  1, -1],
       [ 1, -1, -1, ..., -1, -1,  1]], dtype=int8)

### Tests for correctness

In [36]:
# Tests for correctness: both optimized implementations must reproduce the
# naive reference matrix exactly (the entries are integers, so no tolerance).
spin_size = 6
hil = nk.hilbert.Spin(s=0.5, N=spin_size)

res_naive = naive_cluster_expansion_mat(hil)
res_extreme = optim_cluster_expansion_extreme(hil)
res_vec = optim_cluster_expansion(hil)

np.testing.assert_array_equal(
    res_extreme, res_naive, err_msg="optim_cluster_expansion_extreme differs from the naive matrix"
)
np.testing.assert_array_equal(
    res_vec, res_naive, err_msg="optim_cluster_expansion differs from the naive matrix"
)

# Displayed as the cell result.
np.array_equal(res_naive, res_extreme) and np.array_equal(res_naive, res_vec)


np.True_

In [12]:
import time

def benchmark_cluster_expansion(n_sites=8):
    """Time the two JAX implementations on the same Hilbert space and compare them.

    Runs ``optim_cluster_expansion`` and ``optim_cluster_expansion_extreme``
    once each on a spin-1/2 space with ``n_sites`` sites, prints the wall-clock
    time of each call, checks that the two matrices are identical and prints
    the ratio of the two timings.

    Each timed call is followed by ``block_until_ready()``: JAX dispatches work
    asynchronously, so without it the timer could stop before the computation
    has finished. The timings are single-shot and may include compilation.

    Args:
        n_sites: number of spin-1/2 sites; the matrices are 2**n_sites square.

    Returns:
        The matrix produced by ``optim_cluster_expansion_extreme``.
    """
    hilbert = nk.hilbert.Spin(s=0.5, N=n_sites)
    n_states = 2**n_sites  # 1 constant column + sum_k C(n, k) cluster columns = 2**n
    print(f"\n{'='*60}")
    print(f"Benchmarking N={n_sites} (matrix size: {n_states} x {n_states})")
    print(f"{'='*60}\n")

    # Warmup on a tiny system so backend initialisation is not attributed to
    # the first timed call. Best effort: report a failure instead of hiding it.
    try:
        optim_cluster_expansion_extreme(nk.hilbert.Spin(s=0.5, N=4)).block_until_ready()
        print("JAX/GPU warmup complete")
    except Exception as exc:
        print(f"JAX warmup failed ({exc!r}); continuing with the benchmark")

    # Test 1: Optimized with vmap
    print("\n[1] optim_cluster_expansion (vmap + JIT):")
    start = time.perf_counter()
    result_vmap = optim_cluster_expansion(hilbert).block_until_ready()
    elapsed_vmap = time.perf_counter() - start
    print(f"    Time: {elapsed_vmap:.3f}s")
    print(f"    Output shape: {result_vmap.shape}")

    # Test 2: Extreme optimization
    print("\n[2] optim_cluster_expansion_extreme (full vmap + JIT):")
    start = time.perf_counter()
    result_extreme = optim_cluster_expansion_extreme(hilbert).block_until_ready()
    elapsed_extreme = time.perf_counter() - start
    print(f"    Time: {elapsed_extreme:.3f}s")
    print(f"    Output shape: {result_extreme.shape}")

    # Verify correctness (integer matrices, so compare exactly)
    print(f"\nResults match: {bool(jnp.array_equal(result_vmap, result_extreme))}")

    print(f"\n{'='*60}")
    print(f"Speedup: {elapsed_vmap/elapsed_extreme:.2f}x")
    print(f"{'='*60}\n")

    return result_extreme

# Run benchmark for N=10 (1024 x 1024 matrix); the name `mat_n16` is historical
# and does not reflect the system size used here.
mat_n16 = benchmark_cluster_expansion(n_sites=10)



Benchmarking N=10 (matrix size: 1024 x ~1024.0)

JAX/GPU warmup complete

[1] optim_cluster_expansion (vmap + JIT):
    Time: 0.706s
    Output shape: (1024, 1024)

[2] optim_cluster_expansion_extreme (full vmap + JIT):
    Time: 0.157s
    Output shape: (1024, 1024)

Results match: True

Speedup: 4.49x

    Time: 0.706s
    Output shape: (1024, 1024)

[2] optim_cluster_expansion_extreme (full vmap + JIT):
    Time: 0.157s
    Output shape: (1024, 1024)

Results match: True

Speedup: 4.49x



In [15]:
# 2-site case: small enough to compare entry by entry with the naive 4 x 4 matrix.
optim_cluster_expansion_extreme(nk.hilbert.Spin(s=0.5, N=2))

Array([[ 1,  1,  1,  1],
       [ 1,  1, -1, -1],
       [ 1, -1,  1, -1],
       [ 1, -1, -1,  1]], dtype=int8)

In [57]:
# Cluster coefficients of a uniform wavefunction on 4 sites (16 amplitudes, all 0.1):
# solve cluster_mat @ c = log(psi), i.e. log(psi) is expanded in cluster parities.
psi_test = 0.1 * np.ones(16)
cluster_mat = optim_cluster_expansion_extreme(nk.hilbert.Spin(s=0.5, N=4)) 
cluster_coeffs = jnp.linalg.solve(cluster_mat, np.log(psi_test))
# Reciprocal of the 2-norm of psi_test, printed for reference.
print(1./np.linalg.norm(psi_test))
# exp(c_k) per column: for a uniform psi only the constant term differs from 1.
np.exp(cluster_coeffs) 

2.5


array([0.1, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,
       1. , 1. , 1. ])

In [58]:
from jax import jit
from jax import numpy as jnp
import numpy as np
import itertools

@jit
def fwht(x):
    """Fast Walsh-Hadamard transform of a 1-D vector (unnormalised).

    Equivalent to multiplying ``x`` by the 2**k x 2**k Hadamard matrix in
    natural (Sylvester) ordering, computed with log2(n) butterfly stages.
    The input is not modified; a new array is returned.

    Args:
        x: 1-D array whose length is a power of two. Real or complex; it is
            cast to complex128 before the transform (this presumably relies on
            JAX 64-bit mode being enabled; otherwise JAX falls back to complex64).

    Returns:
        Complex array of the same length holding the transformed vector.
        Apply the transform twice and divide by ``n`` to recover the input.

    Raises:
        ValueError: if the length of ``x`` is not a power of two.
    """
    n = x.shape[0]
    # Shapes are static under jit, so this check runs once at trace time.
    if n & (n - 1):
        raise ValueError(f"fwht requires a power-of-two length, got {n}")
    x = x.astype(jnp.complex128)
    h = 1
    # Python-level loop: unrolled at trace time into log2(n) butterfly stages.
    while h < n:
        # Group the vector into blocks of 2*h and combine the two halves of each block.
        x = x.reshape(-1, 2 * h)
        left = x[:, :h]
        right = x[:, h:2 * h]
        x = jnp.concatenate([left + right, left - right], axis=1)
        x = x.reshape(n)
        h *= 2
    return x

# Helper: reorder psi from Hilbert-space ordering to FWHT (bit-mask) ordering,
# run the FWHT, and return the coefficients in cluster-matrix column order.
def fwht_coeffs_in_cluster_col_order(psi, hilbert):
    """Solve H c = psi for c with a fast Walsh-Hadamard transform.

    H is the cluster (parity) matrix; the returned coefficients follow the
    column ordering produced by optim_cluster_expansion_extreme: the constant
    term first, then clusters grouped by size in itertools.combinations order.

    psi:     amplitudes, one per row of hilbert.all_states().
    hilbert: object exposing ``size`` and ``all_states()`` (+1/-1 spin values).
    returns: NumPy array of coefficients, one per cluster-matrix column.
    """
    n_sites = hilbert.size
    spins = np.array(hilbert.all_states())  # (n_states, n_sites), entries +1/-1

    # Encode every configuration as an integer: site j contributes 2**j when
    # its spin is -1 (bit = 1) and nothing when its spin is +1 (bit = 0).
    bits = ((1 - spins) // 2).astype(np.int64)
    place_values = (1 << np.arange(n_sites)).astype(np.int64)
    row_to_mask = (bits * place_values).sum(axis=1)
    rows_sorted_by_mask = np.argsort(row_to_mask)  # FWHT index -> Hilbert row

    # Transform in mask ordering; dividing by the length inverts the Hadamard matrix.
    psi_mask_order = jnp.array(psi)[rows_sorted_by_mask]
    n_states = psi_mask_order.shape[0]
    coeffs_mask_order = np.array(fwht(psi_mask_order) / float(n_states))

    # Subset mask of every column, in the SAME ORDER as the cluster-matrix
    # columns; mask 0 is the empty cluster (constant column).
    column_masks = [0]
    for cluster_size in range(1, n_sites + 1):
        column_masks.extend(
            # Sites in a combination are distinct, so the sum equals the bitwise OR.
            sum(1 << site for site in cluster)
            for cluster in itertools.combinations(range(n_sites), cluster_size)
        )
    column_masks = np.array(column_masks, dtype=np.int64)  # length n_states

    # Pick the coefficient belonging to each column's subset mask.
    return coeffs_mask_order[column_masks]

# Verification test with a complex psi: the FWHT coefficients, multiplied by
# the explicit cluster matrix, must reproduce psi.
N = 4
hil = nk.hilbert.Spin(s=0.5, N=N)
print(f"Running FWHT verification for N={N}")

# Random complex psi in Hilbert ordering; seeded so that re-running the cell
# gives the same vector and the same reported error.
rng = np.random.default_rng(0)
psi = (rng.random(2**N) + 1j * rng.random(2**N)).astype(np.complex128)
coeffs_col = fwht_coeffs_in_cluster_col_order(psi, hil)

# Reconstruct psi using the explicit matrix (feasible only for small N)
cluster_mat = optim_cluster_expansion_extreme(hil)
cluster_mat_c = jnp.array(cluster_mat, dtype=jnp.complex128)
recon = cluster_mat_c @ jnp.array(coeffs_col)

max_err = float(jnp.max(jnp.abs(recon - psi)))
print(f"Max absolute reconstruction error: {max_err:.2e}")
assert jnp.allclose(recon, jnp.array(psi), atol=1e-8), "FWHT reconstruction failed"
print(f"FWHT verified (complex psi) for N={N}")


Running FWHT verification for N=4
Max absolute reconstruction error: 3.51e-16
FWHT verified (complex psi) for N=4


In [64]:
# Uniform-amplitude test at 16 sites (2**16 = 65536 amplitudes) via the FWHT route,
# which avoids building the explicit cluster matrix. Note that this expands
# psi itself, not log(psi).
n_sites = 16
psi_test = 0.1 * np.ones(2**n_sites)
fwht_coeffs_in_cluster_col_order(psi_test, nk.hilbert.Spin(0.5,n_sites))

array([0.1+0.j, 0. +0.j, 0. +0.j, ..., 0. +0.j, 0. +0.j, 0. +0.j],
      shape=(65536,))