In [1]:
import numpy as np 
import sys, os
sys.path.append('../Netket/')
import netket as nk
from jax import numpy as jnp
import itertools
from scipy.special import comb
from jax import jit, vmap
import jax


In [2]:
# Hilbert space of four spin-1/2 sites.
hilbert = nk.hilbert.Spin(s=0.5, N=4)

In [3]:
# Enumerate every basis configuration, then display the integer index of each one.
hstates = hilbert.all_states()
hilbert.states_to_numbers(hstates)

Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],      dtype=int32)

In [4]:
def return_parity(bitstring):
    '''
    Parity of a spin cluster, computed as the product of its entries.

    bitstring: a cluster of spins (e.g. s1s2 or s1s2s3s4) with entries +1/-1,
        expected as a jnp.array with dtype=jnp.int8. Several clusters may be
        stacked along leading axes; the product runs over the last axis only.
    returns: +1 for even parity, -1 for odd parity, as jnp.int8.
    '''
    # Spin values are +1/-1, so multiplying them yields the parity sign directly.
    return jnp.prod(bitstring, axis=-1, dtype=jnp.int8)

def naive_cluster_expansion_mat(hilbert, max_cluster_size=None):
    '''Build the matrix of cluster parities for every basis state.

    Row r corresponds to the r-th state returned by hilbert.all_states().
    Column 0 is the empty cluster (all ones). The remaining columns hold the
    product of the spins on every cluster of 1 site, then 2 sites, and so on,
    with the clusters of each size listed in itertools.combinations order of
    their site indices.

    For now this only works for spin-1/2 systems whose local states are +1/-1.

    hilbert: object exposing `size` (number of sites) and `all_states()`
        (one row per basis state, one column per site).
    max_cluster_size: largest cluster size to include. None (the default)
        keeps every size up to hilbert.size, which gives one column per
        subset of sites (a square matrix when all 2**size states are listed).
    returns: jnp array with dtype jnp.int8 and shape (n_states, n_clusters + 1).
    raises: ValueError if max_cluster_size lies outside [0, hilbert.size].
    '''
    n_sites = int(hilbert.size)
    if max_cluster_size is None:
        max_cluster_size = n_sites
    max_cluster_size = int(max_cluster_size)
    if not 0 <= max_cluster_size <= n_sites:
        raise ValueError(
            f"max_cluster_size must be between 0 and {n_sites}, got {max_cluster_size}"
        )

    # Assemble in NumPy and convert once at the end: updating a jnp array with
    # .at[].set() outside jit returns a fresh copy of the whole matrix each time.
    states = np.asarray(hilbert.all_states()).astype(np.int8)

    # Site-index tuples in column order: by cluster size, then lexicographic.
    clusters = [
        sites
        for cluster_size in range(1, max_cluster_size + 1)
        for sites in itertools.combinations(range(n_sites), cluster_size)
    ]

    mat = np.ones((states.shape[0], len(clusters) + 1), dtype=np.int8)
    for col, sites in enumerate(clusters, start=1):
        # One vectorised product per cluster fills the column for all states.
        mat[:, col] = np.prod(states[:, list(sites)], axis=1, dtype=np.int8)
    return jnp.asarray(mat)


In [5]:
# Smallest example (2 spins): the columns are the empty cluster, s1, s2 and s1*s2.
naive_cluster_expansion_mat(nk.hilbert.Spin(s=0.5, N=2))

Array([[ 1,  1,  1,  1],
       [ 1,  1, -1, -1],
       [ 1, -1,  1, -1],
       [ 1, -1, -1,  1]], dtype=int8)

In [None]:
from jax import jit
from jax import numpy as jnp
import numpy as np
import itertools

@jit
def fwht(x):
    """Fast Walsh–Hadamard Transform of a 1-D vector.

    Returns a new array holding the unnormalised transform in natural
    (Hadamard) ordering; the input is not modified. Callers that need the
    inverse transform must divide by len(x) themselves.

    Supports complex inputs: the input is cast to complex128 before the
    butterflies, so the result is always complex.
    NOTE(review): complex128 presumably requires JAX 64-bit mode to be
    enabled; confirm for environments that do not turn it on.

    x: 1-D array whose length is a power of two (an empty vector is returned
        unchanged apart from the dtype cast).
    returns: 1-D complex array with the same length as x.
    raises: ValueError if x is not 1-D or its length is not a power of two.
    """
    # Shapes are static under jit, so these checks run once at trace time.
    if x.ndim != 1:
        raise ValueError(f"fwht expects a 1-D vector, got an array of shape {x.shape}")
    n = x.shape[0]
    if n & (n - 1):
        raise ValueError(f"fwht expects a length that is a power of two, got {n}")

    x = x.astype(jnp.complex128)
    h = 1
    while h < n:
        # Butterfly stage: combine the two halves of every block of 2*h entries.
        x = x.reshape(-1, 2 * h)
        left = x[:, :h]
        right = x[:, h:2 * h]
        x = jnp.concatenate([left + right, left - right], axis=1)
        x = x.reshape(n)
        h *= 2
    return x

# Helper that maps the hilbert ordering to FWHT ordering, runs the FWHT and returns coeffs in cluster-matrix column order
def fwht_coeffs_in_cluster_col_order(logpsi, hilbert):
    """Compute coefficients c solving H c = logpsi using the FWHT.

    H is the cluster parity matrix whose columns are ordered as: the empty
    cluster first, then every cluster of 1 site, 2 sites, ... with the
    clusters of each size in itertools.combinations order of their site
    indices (the column order built by naive_cluster_expansion_mat).

    Usage: fwht_coeffs_in_cluster_col_order(logpsi, hilbert)

    logpsi: vector of log wavefunction amplitudes, one entry per row of
        hilbert.all_states() and in the same order.
    hilbert: object exposing `size` (number of sites) and `all_states()`
        (one row of +1/-1 spins per basis state).
    returns: NumPy complex array of coefficients in cluster-column order.
    raises: ValueError if logpsi does not have exactly one entry per state.
    """
    n_sites = hilbert.size
    hstates = np.array(hilbert.all_states())  # shape (n_states, n_sites) with values +1/-1
    n_states = hstates.shape[0]

    psi_arr = jnp.array(logpsi)
    # Check the length up front: JAX clamps out-of-range gather indices instead
    # of raising, so a wrong-sized logpsi would otherwise give wrong results silently.
    if psi_arr.shape != (n_states,):
        raise ValueError(
            f"logpsi must have shape ({n_states},) to match hilbert.all_states(), "
            f"got {psi_arr.shape}"
        )

    # Map spins s in {+1,-1} -> bits b in {0,1} with convention b=1 when s==-1;
    # site i contributes bit i of the FWHT index.
    b = ((1 - hstates) // 2).astype(np.int64)
    powers = (1 << np.arange(n_sites)).astype(np.int64)
    indices = (b * powers).sum(axis=1)
    perm = np.argsort(indices)  # mapping FWHT index -> row

    psi_by_index = psi_arr[perm]
    coeffs_by_index = fwht(psi_by_index) / float(n_states)  # indexed by subset mask
    coeffs_by_index_np = np.array(coeffs_by_index)

    # Build the subset mask of every column, in cluster-column order.
    masks = [0]
    for cluster_size in range(1, n_sites + 1):
        for cluster in itertools.combinations(range(n_sites), cluster_size):
            mask = 0
            for site in cluster:
                mask |= (1 << site)
            masks.append(mask)
    masks = np.array(masks, dtype=np.int64)  # one mask per column

    # coeffs in column order: take coeffs_by_index[mask] for each column
    return coeffs_by_index_np[masks]



In [7]:
n_sites = 16
# Constant log-amplitude on every basis state.
psi_test = np.full(2**n_sites, 0.1)
fwht_coeffs_in_cluster_col_order(psi_test, nk.hilbert.Spin(0.5, n_sites))

array([0.1+0.j, 0. +0.j, 0. +0.j, ..., 0. +0.j, 0. +0.j, 0. +0.j],
      shape=(65536,))

### Test that the FWHT coefficients reconstruct an arbitrary $\psi$

In [8]:
n_sites = 4
# Seeded generator so this check gives the same result on every Restart & Run All.
rng = np.random.default_rng(0)
psi_test_rand = rng.random(2**n_sites) + rng.random(2**n_sites) * 1j
cluster_coeffs = fwht_coeffs_in_cluster_col_order(np.log(psi_test_rand), nk.hilbert.Spin(0.5, n_sites))

# Rebuild log(psi) from the coefficients through the explicit parity matrix.
parity_mat = naive_cluster_expansion_mat(nk.hilbert.Spin(0.5, n_sites))
log_psi_reconstructed = jnp.dot(parity_mat, cluster_coeffs)
psi_reconstructed = np.exp(log_psi_reconstructed)

# Fail loudly if the reconstruction is off (same tolerances as np.isclose).
np.testing.assert_allclose(psi_reconstructed, psi_test_rand, rtol=1e-5, atol=1e-8)
np.isclose(psi_reconstructed, psi_test_rand).all()

np.True_