In [18]:
import numpy as np 
import sys, os
import matplotlib.pyplot as plt
sys.path.append('../Netket/')
import netket as nk
from jax import numpy as jnp
import itertools
from scipy.special import comb
from jax import jit, vmap
import jax


In [2]:
# Hilbert space of N=4 spin-1/2 sites; the next cell enumerates its basis states.
hilbert = nk.hilbert.Spin(s=0.5, N=4)

In [3]:
# Enumerate every basis configuration of the Hilbert space defined above.
hstates = hilbert.all_states()
# Map each configuration to its integer index; the result is shown as the cell output.
hilbert.states_to_numbers(hstates)

Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],      dtype=int32)

In [None]:
def return_parity(bitstring):
    '''
    Parity sign of one cluster, or of a batch of clusters.

    bitstring holds the spin values of a cluster (e.g. s1s2, or s1s2s3s4);
    only the last axis is reduced, so any leading axes are kept as a batch.
    Expected type of bitstring: jnp.array, dtype=jnp.int8.

    Returns the product of the spins as jnp.int8:
    +1 if even parity, -1 if odd parity.
    '''
    # The spin states are +1 and -1, so their product is the parity sign itself.
    return jnp.prod(bitstring, dtype=jnp.int8, axis=-1)

def naive_cluster_expansion_mat(hilbert):
    '''Brute-force reference implementation of the cluster expansion matrix.

    Row i belongs to the i-th configuration returned by ``hilbert.all_states()``.
    Column 0 is the constant term (all ones). The remaining columns hold the
    parity of every cluster of sites, grouped by cluster size 1..n_sites and,
    within one size, ordered like ``itertools.combinations``.

    For now only works for spin-1/2 systems (site values +1 / -1).

    Parameters
    ----------
    hilbert : object exposing ``size`` (number of sites) and ``all_states()``.

    Returns
    -------
    jnp array, dtype int8, shape ``(n_states, 2**n_sites)``.
    '''
    n_sites = int(hilbert.size)
    # Host-side NumPy rows: iterating them in Python is cheap, and the output
    # gets exactly one row per enumerated state.
    hstates = np.asarray(hilbert.all_states())

    cluster_sizes = range(1, n_sites + 1)
    # Width of each column block, computed once as exact Python ints.
    block_widths = [comb(n_sites, size, exact=True) for size in cluster_sizes]

    # Fill a mutable NumPy buffer in place; updating an immutable JAX array with
    # .at[].set would copy the whole matrix on every single update.
    mat = np.ones((hstates.shape[0], 1 + sum(block_widths)), dtype=np.int8)
    for state_idx, state in enumerate(hstates):
        start_idx = 1  # First column is all ones, so start from second column
        for cluster_size, width in zip(cluster_sizes, block_widths):
            clusters = jnp.array(list(itertools.combinations(state, cluster_size)))
            mat[state_idx, start_idx:start_idx + width] = np.asarray(return_parity(clusters))
            start_idx += width
    return jnp.asarray(mat)


In [33]:
@jit
def _parities_for_clusters(hstates, clusters):
    """Parity of every cluster (rows of `clusters`) for every state (rows of `hstates`)."""
    # hstates[:, clusters] has shape (n_states, n_clusters, cluster_size);
    # the product over the last axis is the parity of each cluster.
    return jnp.prod(hstates[:, clusters], axis=-1, dtype=jnp.int8)


def optim_cluster_expansion(hilbert):
    '''Vectorized cluster expansion matrix, built one cluster size at a time.

    Column 0 is the constant term (all ones). The remaining columns hold the
    parity of every cluster of sites, grouped by cluster size 1..n_sites and,
    within one size, ordered like ``itertools.combinations``.

    The jitted kernel lives at module level so it is not re-created (and
    re-traced) on every call of this function.

    Parameters
    ----------
    hilbert : object exposing ``size`` (number of sites) and ``all_states()``.

    Returns
    -------
    jnp array, dtype int8, shape ``(n_states, 2**n_sites)``.
    '''
    n_sites = int(hilbert.size)
    hstates = jnp.asarray(hilbert.all_states())
    # Take the row count from the enumerated states so that every column block
    # has the same number of rows when they are concatenated.
    n_states = hstates.shape[0]

    columns = [jnp.ones((n_states, 1), dtype=jnp.int8)]
    for cluster_size in range(1, n_sites + 1):
        # Site-index combinations are plain Python work, kept outside the JIT.
        clusters = jnp.array(
            list(itertools.combinations(range(n_sites), cluster_size)),
            dtype=jnp.int32,
        )
        columns.append(_parities_for_clusters(hstates, clusters))

    # Concatenate all column blocks
    return jnp.concatenate(columns, axis=1)


In [118]:
# Two-site example: the resulting 4x4 matrix is small enough to check by hand.
naive_cluster_expansion_mat(nk.hilbert.Spin(s=0.5, N=2))

Array([[ 1,  1,  1,  1],
       [ 1,  1, -1, -1],
       [ 1, -1,  1, -1],
       [ 1, -1, -1,  1]], dtype=int8)

In [20]:
@jit
def _padded_cluster_parities(padded_states, padded_clusters):
    """Parity of every padded cluster for every state, in a single gather + product."""
    # Shape after indexing: (n_states, n_clusters, n_sites).
    return jnp.prod(padded_states[:, padded_clusters], axis=-1, dtype=jnp.int8)


def optim_cluster_expansion_extreme(hilbert):
    '''Cluster expansion matrix computed for all clusters in one vectorized call.

    Trades memory for speed: clusters of every size are processed together.
    Clusters have different lengths, so they cannot be stacked directly into a
    rectangular index array. Each cluster is therefore padded to length
    ``n_sites`` with the index of an extra, always-one column appended to the
    states; the padding slots multiply the product by 1 and leave it unchanged.

    Column 0 is the constant term (all ones). The remaining columns hold the
    parity of every cluster of sites, grouped by cluster size 1..n_sites and,
    within one size, ordered like ``itertools.combinations``.

    NOTE(review): the indexed intermediate has n_states * n_clusters * n_sites
    entries unless the compiler fuses it away -- confirm the memory headroom
    before running this at N=16.

    Parameters
    ----------
    hilbert : object exposing ``size`` (number of sites) and ``all_states()``.

    Returns
    -------
    jnp array, dtype int8, shape ``(n_states, 2**n_sites)``.
    '''
    n_sites = int(hilbert.size)
    hstates = jnp.asarray(hilbert.all_states())
    n_states = hstates.shape[0]

    # Pre-compute all cluster indices, padded to a common length so that the
    # list is rectangular. Index `n_sites` points at the neutral column below.
    pad_index = n_sites
    padded_clusters = jnp.array(
        [
            cluster + (pad_index,) * (n_sites - len(cluster))
            for cluster_size in range(1, n_sites + 1)
            for cluster in itertools.combinations(range(n_sites), cluster_size)
        ],
        dtype=jnp.int32,
    )

    # Neutral element of the product, addressed by the padding index.
    neutral_col = jnp.ones((n_states, 1), dtype=hstates.dtype)
    padded_states = jnp.concatenate([hstates, neutral_col], axis=1)

    parities = _padded_cluster_parities(padded_states, padded_clusters)

    # Prepend the constant (all ones) column
    identity_col = jnp.ones((n_states, 1), dtype=jnp.int8)
    return jnp.concatenate([identity_col, parities], axis=1)


In [40]:
# Ten-site example; the displayed matrix is abbreviated in the cell output.
optim_cluster_expansion(nk.hilbert.Spin(s=0.5, N=10))

Array([[ 1,  1,  1, ...,  1,  1,  1],
       [ 1,  1,  1, ..., -1, -1, -1],
       [ 1,  1,  1, ..., -1, -1, -1],
       ...,
       [ 1, -1, -1, ...,  1,  1, -1],
       [ 1, -1, -1, ...,  1,  1, -1],
       [ 1, -1, -1, ..., -1, -1,  1]], dtype=int8)

### Tests for correctness

In [38]:
# The vectorized implementation must reproduce the brute-force reference exactly.
# Real assertions make a mismatch stop "Restart & Run All" instead of merely
# displaying False.
result_naive = naive_cluster_expansion_mat(nk.hilbert.Spin(s=0.5, N=4))
result_optim = optim_cluster_expansion(nk.hilbert.Spin(s=0.5, N=4))
assert result_naive.shape == result_optim.shape, (result_naive.shape, result_optim.shape)
assert bool(jnp.array_equal(result_naive, result_optim)), "optimized matrix differs from the naive reference"
jnp.allclose(result_naive, result_optim)  # Should be True

Array(True, dtype=bool)

In [None]:
import time

def benchmark_cluster_expansion(n_sites=16):
    """Benchmark the two vectorized implementations against each other.

    Times `optim_cluster_expansion` and `optim_cluster_expansion_extreme` on a
    spin-1/2 system with `n_sites` sites, prints the timings and whether both
    results agree, and returns the matrix produced by the "extreme" variant.

    JAX dispatches work asynchronously, so every timed call is followed by
    `block_until_ready()`; otherwise only the dispatch would be measured.
    The reported times include any tracing/compilation triggered by the call.
    """
    hilbert = nk.hilbert.Spin(s=0.5, N=n_sites)
    # exact=True keeps the column count an integer (no "65536.0" in the banner).
    n_columns = 1 + sum(comb(n_sites, k, exact=True) for k in range(1, n_sites + 1))
    print(f"\n{'='*60}")
    print(f"Benchmarking N={n_sites} (matrix size: {2**n_sites} x {n_columns})")
    print(f"{'='*60}\n")

    # Warmup on a tiny system so backend initialisation is not timed.
    # Errors are deliberately not swallowed: a failure here is a real bug in the
    # implementation, not a missing GPU (JAX falls back to CPU by itself).
    warmup_hilbert = nk.hilbert.Spin(s=0.5, N=4)
    optim_cluster_expansion(warmup_hilbert).block_until_ready()
    optim_cluster_expansion_extreme(warmup_hilbert).block_until_ready()
    print(f"✓ JAX warmup complete (backend: {jax.default_backend()})")

    # Test 1: Optimized with vmap
    print("\n[1] optim_cluster_expansion (vmap + JIT):")
    start = time.perf_counter()
    result_vmap = optim_cluster_expansion(hilbert).block_until_ready()
    elapsed_vmap = time.perf_counter() - start
    print(f"    Time: {elapsed_vmap:.3f}s")
    print(f"    Output shape: {result_vmap.shape}")

    # Test 2: Extreme optimization
    print("\n[2] optim_cluster_expansion_extreme (full vmap + JIT):")
    start = time.perf_counter()
    result_extreme = optim_cluster_expansion_extreme(hilbert).block_until_ready()
    elapsed_extreme = time.perf_counter() - start
    print(f"    Time: {elapsed_extreme:.3f}s")
    print(f"    Output shape: {result_extreme.shape}")

    # Verify correctness: integer matrices, so compare exactly and only show
    # the check mark when they really agree.
    results_match = bool(jnp.array_equal(result_vmap, result_extreme))
    print(f"\n{'✓' if results_match else '✗'} Results match: {results_match}")

    print(f"\n{'='*60}")
    print(f"Speedup: {elapsed_vmap/elapsed_extreme:.2f}x")
    print(f"{'='*60}\n")

    return result_extreme

# Run benchmark for N=16.
# Each result matrix is 2**16 x 2**16 int8 entries (about 4 GiB), so this cell
# needs a machine with plenty of memory.
mat_n16 = benchmark_cluster_expansion(n_sites=16)
