In [1]:
import numpy as np 
import sys, os
sys.path.append('../Netket/')
import netket as nk
from jax import numpy as jnp
import itertools
from scipy.special import comb
from jax import jit, vmap
import jax
import matplotlib.pyplot as plt 
from cluster_expansion import fwht_coeffs_in_cluster_col_order, prepare_fwht_meta_cached, compress_and_reconstruct_cached, _get_topk_indices_jit
import analysis
from analysis import std_phase, ipr, pca_entropy, renyi_entropy, mean_amplitude, uniform_state_overlap
import pandas as pd
from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sanity check of the cluster-expansion helpers on a tiny system: transform the
# log-amplitudes of a random state, reconstruct it, and compare with the input.
n_sites_test = 4
hilb_test = nk.hilbert.Spin(0.5, n_sites_test)

# Seeded generator so the numbers displayed below are identical after
# Restart Kernel -> Run All (the global NumPy RNG was previously unseeded).
rng_test = np.random.default_rng(0)
dim_test = 2**n_sites_test
# Random complex test vector; real and imaginary parts are uniform in [0, 1).
psi_test = rng_test.random(dim_test) + 1j * rng_test.random(dim_test)

# Coefficients of log(psi) in the ordering produced by the cluster_expansion helper.
cluster_coeffs_test = fwht_coeffs_in_cluster_col_order(np.log(psi_test), hilb_test)

# Number of coefficients handed to the compress/reconstruct step; here it equals
# the total number of coefficients (2**n_sites_test). You can vary this.
# NOTE(review): presumably the exact-reconstruction check at the end of the cell
# only holds when all coefficients are kept -- confirm before lowering k_test.
k_test = 2**n_sites_test
# print(f"Preparing cached meta for n_sites={n_sites_test}...", flush=True)
prepare_fwht_meta_cached(hilb_test)  # fill cache (fast)

psi_rec = compress_and_reconstruct_cached(cluster_coeffs_test, k_test, hilb_test)

# Despite the variable name, this is an array of integer indices (not a boolean
# mask); the recorded output shows them pointing at the largest-magnitude entries.
mask = _get_topk_indices_jit(cluster_coeffs_test, 2)
# Coefficients selected by those indices, i.e. the two biggest ones.
coeffs_sel = cluster_coeffs_test[mask]

# Display: all magnitudes, the selected magnitudes, the selected indices, and
# whether the reconstruction matches the original state element-wise.
np.abs(cluster_coeffs_test), np.abs(coeffs_sel), mask, np.isclose(psi_test, psi_rec).all()

(array([0.84193477, 0.0654011 , 0.03888232, 0.08955436, 0.16401743,
        0.1141646 , 0.14511739, 0.16306679, 0.12026936, 0.20501796,
        0.19583379, 0.1511938 , 0.06669498, 0.13212231, 0.13853674,
        0.19989117]),
 array([0.84193477, 0.20501796]),
 Array([0, 9], dtype=int32),
 np.True_)

In [None]:
# Diagnostics attached to every loaded row: column name -> function from `analysis`.
hypotheses = {
    "std_phase" : std_phase,
    "IPR" : ipr,
    "SPCA" : pca_entropy,
    "Renyi_2" : renyi_entropy,
    "uniform_state_overlap" : uniform_state_overlap,
    "mean_amplitude" : mean_amplitude,
}

data_root = '..'


def _list_h5_files(directory):
    """Return the paths of all ``.h5`` files directly inside ``directory``.

    The result is sorted because ``os.listdir`` returns entries in arbitrary
    order; sorting keeps the row order of the loaded dataframe reproducible.
    """
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith('.h5')
    )


def _load_rbm_dataframe(h5_files):
    """Load the given output files and attach the hypothesis and ``idx`` columns.

    Args:
        h5_files: list of ``.h5`` paths passed to
            ``analysis.load_outputs_to_dataframe`` (eigenstates are not loaded).

    Returns:
        The dataframe returned by ``analysis.attach_hypotheses_fields`` with an
        extra integer ``idx`` column.
    """
    df = analysis.load_outputs_to_dataframe(h5_files, load_eigenstates=False)
    df = analysis.attach_hypotheses_fields(df, hypotheses)
    # `idx` is the third underscore-separated token of the file's base name.
    # NOTE(review): assumes file names look like `<a>_<b>_<int>_...` -- confirm.
    df["idx"] = df["file"].apply(lambda path: int(os.path.basename(path).split('_')[2]))
    return df


h5_files_opt = _list_h5_files(f"{data_root}/data/data_optimal_basis_rbm")
df_opt = _load_rbm_dataframe(h5_files_opt)
print(len(df_opt))

# The directory name used to end in a stray comma ("data_unrotated_basis_rbm,"),
# which does not match the naming of the sibling directory above.
# NOTE(review): confirm the on-disk directory really has no trailing comma.
h5_files_raw = _list_h5_files(f"{data_root}/data/data_unrotated_basis_rbm")
df_raw = _load_rbm_dataframe(h5_files_raw)
print(len(df_raw))
