<a href="https://colab.research.google.com/github/aaronseymour7/MLQ/blob/main/GFH.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install netket
!pip install qiskit

In [None]:
import netket as nk
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

def get_lowest_data(n_sites_list, *, n_iter=500, n_samples=1008, alpha=2,
                    learning_rate=0.01, diag_shift=0.01):
    """Compare VMC and exact (Lanczos) ground-state energies of the Heisenberg chain.

    For every chain length N in ``n_sites_list`` this builds ``nk.graph.Chain(N)``
    (NetKet's default boundary conditions), computes the lowest eigenvalue
    exactly with Lanczos, and optimises a complex RBM with ``VMC_SR`` on the
    same Hamiltonian.

    Args:
        n_sites_list: iterable of chain lengths N.
        n_iter: number of VMC optimisation steps.
        n_samples: Monte Carlo samples per step.
        alpha: RBM hidden-unit density.
        learning_rate: SGD learning rate.
        diag_shift: SR diagonal shift.
        (The defaults reproduce the previously hard-coded values.)

    Returns:
        A list with one dict per N, holding "N", "VMC" (VMC energy), "Lanczos"
        (exact energy), "diff" (|VMC - Lanczos|) and "VMC_err" (Monte Carlo
        error of the VMC mean).
    """
    results = []

    for n in n_sites_list:
        print(f"--- Analyzing lowest e_val for N = {n} ---")

        # Lattice and Hamiltonian on the full (unconstrained) spin-1/2 space.
        graph = nk.graph.Chain(n)
        hi = nk.hilbert.Spin(s=0.5, N=graph.n_nodes)
        hamiltonian = nk.operator.Heisenberg(hilbert=hi, graph=graph)

        # Exact reference: only the lowest eigenvalue is needed.
        e_exact = float(nk.exact.lanczos_ed(hamiltonian, k=1)[0])

        # Variational ansatz. The Hilbert space carries no total_sz constraint,
        # so single-spin-flip (MetropolisLocal) moves are valid here.
        model = nk.models.RBM(alpha=alpha, param_dtype=complex)
        sampler = nk.sampler.MetropolisLocal(hi)
        vstate = nk.vqs.MCState(sampler, model, n_samples=n_samples)

        # SGD steps driven through NetKet's stochastic-reconfiguration driver.
        optimizer = nk.optimizer.Sgd(learning_rate=learning_rate)
        vmc = nk.driver.VMC_SR(hamiltonian, optimizer, diag_shift=diag_shift,
                               variational_state=vstate)
        vmc.run(n_iter=n_iter)

        energy_stats = vstate.expect(hamiltonian)
        e_vmc = float(energy_stats.mean.real)
        results.append({
            "N": n,
            "VMC": e_vmc,
            "Lanczos": e_exact,
            "diff": abs(e_vmc - e_exact),
            "VMC_err": float(energy_stats.error_of_mean),
        })

    return results

def get_highest_data(n_sites_list, *, n_iter=500, n_samples=1008, alpha=2,
                     learning_rate=0.01, diag_shift=0.01):
    """Compare VMC and exact (Lanczos) estimates of the highest Heisenberg eigenvalue.

    The highest eigenvalue of H equals minus the lowest eigenvalue of -H, so
    both Lanczos and VMC work on -H and their results are negated back.

    Args:
        n_sites_list: iterable of chain lengths N.
        n_iter: number of VMC optimisation steps.
        n_samples: Monte Carlo samples per step.
        alpha: RBM hidden-unit density.
        learning_rate: SGD learning rate.
        diag_shift: SR diagonal shift.
        (The defaults reproduce the previously hard-coded values.)

    Returns:
        A list with one dict per N, holding "N", "VMC" and "Lanczos" (both
        estimates of the highest eigenvalue of H), "diff" (|VMC - Lanczos|) and
        "VMC_err" (Monte Carlo error of the <-H> mean, which has the same
        magnitude as the error on the negated value).
    """
    results = []

    for n in n_sites_list:
        print(f"--- Analyzing Highest e_val for N = {n} ---")

        # Lattice and sign-flipped Hamiltonian: minimising -H targets the top of
        # the spectrum of H.
        graph = nk.graph.Chain(n)
        hi = nk.hilbert.Spin(s=0.5, N=graph.n_nodes)
        hamiltonian = -1 * nk.operator.Heisenberg(hilbert=hi, graph=graph)

        e_exact = -1 * float(nk.exact.lanczos_ed(hamiltonian, k=1)[0])

        # Unconstrained Hilbert space, so single-spin-flip moves are valid.
        model = nk.models.RBM(alpha=alpha, param_dtype=complex)
        sampler = nk.sampler.MetropolisLocal(hi)
        vstate = nk.vqs.MCState(sampler, model, n_samples=n_samples)

        optimizer = nk.optimizer.Sgd(learning_rate=learning_rate)
        vmc = nk.driver.VMC_SR(hamiltonian, optimizer, diag_shift=diag_shift,
                               variational_state=vstate)
        vmc.run(n_iter=n_iter)

        energy_stats = vstate.expect(hamiltonian)
        e_vmc = -1 * float(energy_stats.mean.real)
        results.append({
            "N": n,
            "VMC": e_vmc,
            "Lanczos": e_exact,
            "diff": abs(e_vmc - e_exact),
            "VMC_err": float(energy_stats.error_of_mean),
        })

    return results
def get_1es_data(n_sites_list, *, n_iter=500, n_samples=1008, alpha=2,
                 learning_rate=0.01, diag_shift=0.01, n_eigs=10, tol=1e-6):
    """Compare a fixed-Sz-sector VMC energy with the exact first excited level.

    VMC minimises the energy inside the Sz=1 sector (even N) or the Sz=3/2
    sector (odd N). The exact reference is the first eigenvalue of the full
    Hamiltonian lying more than ``tol`` above the ground level.

    Args:
        n_sites_list: iterable of chain lengths N.
        n_iter: number of VMC optimisation steps.
        n_samples: Monte Carlo samples per step.
        alpha: RBM hidden-unit density.
        learning_rate: SGD learning rate.
        diag_shift: SR diagonal shift.
        n_eigs: number of lowest eigenvalues requested from Lanczos, capped
            below the Hilbert-space dimension.
        tol: eigenvalues within ``tol`` of the minimum count as degenerate
            with the ground level.

    Returns:
        A list with one dict per N, holding "N", "VMC", "Lanczos", "diff" and
        "VMC_err". "Lanczos" and "diff" are NaN when all ``n_eigs`` eigenvalues
        are degenerate with the ground level (no excited level was reached).
    """
    results = []

    for n in n_sites_list:
        print(f"--- Analyzing First Excited State for N = {n} ---")

        # Target magnetisation sector for VMC.
        # NOTE(review): a state with magnetisation Sz has total spin S >= |Sz|,
        # so the Sz=3/2 sector used for odd N only contains S >= 3/2 states. If
        # the first excited level of an odd chain has S = 1/2 instead, "diff"
        # compares two different states for odd N. Confirm before relying on it.
        target_sz = 1 if n % 2 == 0 else 1.5
        graph = nk.graph.Chain(n)
        hi = nk.hilbert.Spin(s=0.5, total_sz=target_sz, N=graph.n_nodes)
        hamiltonian = nk.operator.Heisenberg(hilbert=hi, graph=graph)

        # Exact reference on the full (unconstrained) space. A tolerance decides
        # degeneracy, because round-then-unique can split two numerically equal
        # values that straddle a rounding boundary.
        lanczos_hi = nk.hilbert.Spin(s=0.5, N=graph.n_nodes)
        lanczos_ham = nk.operator.Heisenberg(hilbert=lanczos_hi, graph=graph)
        k = min(n_eigs, lanczos_hi.n_states - 1)  # Lanczos needs k < dimension
        evals = np.sort(np.asarray(nk.exact.lanczos_ed(lanczos_ham, k=k), dtype=float))
        excited = evals[evals > evals[0] + tol]
        # Use NaN rather than None when no excited level was reached, so the
        # difference, printing and plotting below do not crash.
        e_exact = float(excited[0]) if excited.size else float("nan")

        # The sector has fixed total_sz, so moves must conserve magnetisation:
        # MetropolisExchange swaps spins on graph edges.
        model = nk.models.RBM(alpha=alpha, param_dtype=complex)
        sampler = nk.sampler.MetropolisExchange(hi, graph=graph)
        vstate = nk.vqs.MCState(sampler, model, n_samples=n_samples)

        optimizer = nk.optimizer.Sgd(learning_rate=learning_rate)
        vmc = nk.driver.VMC_SR(hamiltonian, optimizer, diag_shift=diag_shift,
                               variational_state=vstate)
        vmc.run(n_iter=n_iter)

        energy_stats = vstate.expect(hamiltonian)
        e_vmc = float(energy_stats.mean.real)
        results.append({
            "N": n,
            "VMC": e_vmc,
            "Lanczos": e_exact,
            "diff": abs(e_vmc - e_exact),
            "VMC_err": float(energy_stats.error_of_mean),
        })

    return results


def printer(data):
    """Print one summary line per result dict produced by the get_*_data functions.

    Each dict must contain "N" and "VMC". "Lanczos" and "diff" may be None
    (no exact reference available). In that case a placeholder is printed
    instead of crashing on the float format spec.
    """
    for d in data:
        lanczos = d.get("Lanczos")
        diff = d.get("diff")
        lanczos_str = f"{lanczos:.8f}" if lanczos is not None else "n/a"
        diff_str = (f"Abs Difference: {diff:.8f}" if diff is not None
                    else "Exact not computed")
        print(f"N={d['N']} | VMC: {d['VMC']:.8f} | Lanczos: {lanczos_str} | {diff_str} \n")


def plot_simulation_results(results_list):
    """Draw |VMC - Lanczos| versus system size in three side-by-side panels.

    ``results_list`` must be interleaved as built by ``get_all_data``: for
    every system size, the ground-state, first-excited and highest-state
    result lists, in that order. Only the first dict of every entry is used;
    it needs the keys "N", "diff" and "VMC_err".
    """
    # De-interleave: entry k belongs to panel k % 3.
    ground = [entry[0] for entry in results_list[0::3]]
    excited = [entry[0] for entry in results_list[1::3]]
    highest = [entry[0] for entry in results_list[2::3]]

    panels = zip(
        (ground, excited, highest),
        ('Ground State', 'First Excited', 'Highest'),
        ('red', 'blue', 'green'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    for ax, (rows, title, color) in zip(axes, panels):
        sizes = [row['N'] for row in rows]
        gaps = [row['diff'] for row in rows]
        bars = [row['VMC_err'] for row in rows]

        ax.errorbar(
            sizes, gaps, yerr=bars, fmt='-o', color=color,
            capsize=6, elinewidth=2, markersize=8, label='VMC Stat. Error',
        )
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('System Size (N)', fontsize=12)
        ax.set_ylabel('Abs Difference |VMC - Lanczos|', fontsize=12)
        ax.set_xticks(sizes)
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.legend()

    plt.suptitle("VMC Accuracy (1D Heisenberg)", fontsize=16, y=1.05)
    plt.tight_layout()
    plt.show()

def get_all_data(n_sizes):
    """Run the ground, first-excited and highest-state studies for every size and plot them.

    For each N the three ``get_*_data`` functions are called with ``[N]``.
    Each result list is printed and collected in the interleaved order
    (ground, first excited, highest) that ``plot_simulation_results`` expects.

    Returns:
        The collected list of result lists. It is the same list that is
        plotted, returned so callers can reuse the data without rerunning VMC.
    """
    results = []
    for n in n_sizes:
        # Order matters: plot_simulation_results de-interleaves by position.
        for collect in (get_lowest_data, get_1es_data, get_highest_data):
            data = collect([n])
            results.append(data)
            printer(data)
    plot_simulation_results(results)
    return results





# Larger sweep used in earlier runs: [6, 10, 16, 20, 22]
n_sizes = list(range(7, 10))  # N = 7, 8, 9
get_all_data(n_sizes=n_sizes)

#plot_simulation_results(results)





In [None]:
import netket as nk
import numpy as np
import jax
import jax.numpy as jnp

def make_vstate(hi, graph, n_samples=1024, alpha=2):
    """Build an MCState from a complex-parameter RBM and an exchange sampler on `graph`.

    The exchange sampler keeps the magnetisation fixed, which makes it usable
    on fixed-total_sz Hilbert spaces.
    """
    rbm = nk.models.RBM(alpha=alpha, param_dtype=complex)
    exchange_sampler = nk.sampler.MetropolisExchange(hi, graph=graph)
    return nk.vqs.MCState(sampler=exchange_sampler, model=rbm, n_samples=n_samples)

def run_vmc(vstate, H, n_iter=600, lr=0.01, diag_shift=0.01):
    """Optimise `vstate` in place against `H` using SGD with NetKet's VMC_SR driver.

    Returns None. The optimised parameters live on `vstate` afterwards.
    """
    nk.driver.VMC_SR(
        H,
        nk.optimizer.Sgd(learning_rate=lr),
        diag_shift=diag_shift,
        variational_state=vstate,
    ).run(n_iter=n_iter, show_progress=True)

def get_1es_even(n, graph, n_samples, alpha, n_iter):
    """Even N: 1ES lives in Sz=1 sector. Just minimize energy there.

    Returns:
        (vstate, hamiltonian): the optimised state and the Heisenberg operator,
        both defined on the Sz=1 Hilbert space.
    """
    sector = nk.hilbert.Spin(s=0.5, total_sz=1, N=n)
    ham = nk.operator.Heisenberg(hilbert=sector, graph=graph)
    state = make_vstate(sector, graph, n_samples, alpha)
    run_vmc(state, ham, n_iter=n_iter)
    return state, ham

def get_1es_odd(n, graph, vstate_gs, H_gs, beta, n_samples, alpha, n_iter):
    """
    Odd N: 1ES lives in same Sz=+1/2 sector as GS.
    Minimize E + beta*|<gs|es>|^2 with frozen GS.

    Args:
        n: number of sites.
        graph: lattice graph, forwarded to make_vstate.
        vstate_gs: optimised ground-state variational state. Only its
            log_value is read here; its parameters are never modified.
        H_gs: Hamiltonian on the Sz=+1/2 sector. It is used for the energy
            term and is also the Hamiltonian returned.
        beta: weight of the overlap penalty.
        n_samples, alpha: forwarded to make_vstate.
        n_iter: number of hand-rolled SGD steps (learning rate 0.005, no SR
            preconditioning, unlike run_vmc).

    Returns:
        (vs_es, H_gs): the optimised excited-state vstate and the same H_gs.

    NOTE(review): the overlap estimator and the way its gradient is combined
    with E_grad are flagged inline below and should be verified before the
    result is trusted.
    """
    hi = nk.hilbert.Spin(s=0.5, total_sz=0.5, N=n)
    # vstate_gs must share same hi — rebuild gs in this hi if needed
    vs_es     = make_vstate(hi, graph, n_samples, alpha)
    optimizer = nk.optimizer.Sgd(learning_rate=0.005)
    opt_state = optimizer.init(vs_es.parameters)

    for step in range(n_iter):
        # Energy estimate and its gradient w.r.t. the excited-state parameters.
        E_stats, E_grad = vs_es.expect_and_grad(H_gs)

        # Overlap gradient via ratio trick + JAX autodiff
        # Flatten all leading batch axes so samples is (n_total, n_sites).
        samples = vs_es.samples.reshape(-1, vs_es.samples.shape[-1])
        # The GS is frozen: its log-amplitudes are constants for autodiff.
        log_gs  = jax.lax.stop_gradient(vstate_gs.log_value(samples))

        def overlap_fn(params):
            # NOTE(review): relies on the private MCState._apply_fun and assumes
            # the model's variables are exactly {"params": params}. Confirm this
            # against the installed NetKet version.
            log_es = vs_es._apply_fun({"params": params}, samples)
            # NOTE(review): exp(conj(log_gs) - log_es) = conj(psi_gs)/psi_es.
            # If samples follow |psi_es|^2 (NetKet's default), its mean
            # estimates sum_x conj(psi_es(x) * psi_gs(x)) / <es|es>. That is not
            # <gs|es> for complex amplitudes, and this RBM has complex params.
            # The usual ratio estimator is exp(log_gs - log_es). The value also
            # changes when either wavefunction is rescaled, so the penalty can
            # shrink without the states becoming orthogonal. Verify.
            return jnp.mean(jnp.exp(jnp.conj(log_gs) - log_es))

        overlap      = overlap_fn(vs_es.parameters)
        # Holomorphic derivative of the overlap estimate with the samples held
        # fixed; the dependence of the sampling distribution on params is ignored.
        dO_dtheta    = jax.grad(overlap_fn, holomorphic=True)(vs_es.parameters)
        # Penalty gradient taken as beta * conj(O) * dO/dtheta.
        # NOTE(review): confirm this uses the same complex-gradient convention
        # as E_grad from NetKet's expect_and_grad before the two are summed.
        penalty_grad = jax.tree_util.tree_map(
            lambda g: beta * jnp.conj(overlap) * g, dO_dtheta
        )
        total_grad = jax.tree_util.tree_map(
            lambda eg, pg: eg + pg, E_grad, penalty_grad
        )

        # Let the optimizer turn the gradient into updates, then add them to the
        # current parameters by hand.
        updates, opt_state = optimizer.update(total_grad, opt_state)
        vs_es.parameters   = jax.tree_util.tree_map(
            lambda p, u: p + u, vs_es.parameters, updates
        )

        # The reported E and overlap come from this step's samples, i.e. before
        # the update above was applied.
        if step % 100 == 0 or step == n_iter - 1:
            print(f"  step {step:4d} | E={float(E_stats.mean.real):+.5f} | "
                  f"|<gs|es>|={float(abs(overlap)):.4f}")

    return vs_es, H_gs


def _first_excited_energy(evals, tol=1e-6):
    """Return the lowest value in `evals` lying more than `tol` above its minimum.

    Skips degenerate copies of the ground level. Returns NaN when `evals` is
    empty or every value is within `tol` of the minimum, i.e. too few
    eigenvalues were computed to reach the next level.
    """
    evals = np.sort(np.asarray(evals, dtype=float).ravel())
    if evals.size == 0:
        return float("nan")
    above = evals[evals > evals[0] + tol]
    return float(above[0]) if above.size else float("nan")


def analyze(n_sites_list, beta=15.0, n_iter_gs=600, n_iter_es=800,
            n_samples=1024, alpha=2, n_eigs=10):
    """Compare VMC ground/first-excited energies of the Heisenberg chain with Lanczos.

    For each N the routine:
      1. computes exact levels on the full space;
      2. optimises the ground state in its Sz sector (0 for even N, 1/2 for odd N);
      3. obtains the first excited state, via the Sz=1 sector for even N or an
         overlap-penalised optimisation in the Sz=1/2 sector for odd N;
      4. reports energies and a normalised |<gs|es>| diagnostic.

    Args:
        n_sites_list: iterable of chain lengths N.
        beta: overlap-penalty weight for the odd-N excited state.
        n_iter_gs: VMC steps for the ground state.
        n_iter_es: optimisation steps for the excited state.
        n_samples: Monte Carlo samples per step.
        alpha: RBM hidden-unit density.
        n_eigs: number of lowest eigenvalues requested from Lanczos, capped
            below the Hilbert-space dimension. It must exceed the ground-level
            degeneracy for E1_exact to be found.

    Returns:
        A list with one dict per N, holding "N", "parity", "E0_exact",
        "E1_exact", "E0_VMC", "E1_VMC" and "overlap". "E1_exact" is NaN if no
        level above the ground level was among the computed eigenvalues.
    """
    results = []

    for n in n_sites_list:
        print(f"\n{'='*50}  N={n}  ({'odd' if n%2 else 'even'})")

        graph = nk.graph.Chain(n)
        is_odd = (n % 2 == 1)

        # Exact reference. The Heisenberg model is spin-rotation invariant, so
        # for odd N every level is at least doubly degenerate (Sz = +-1/2), and
        # ed[1] would just repeat E0. Take the first *distinct* level instead,
        # and request enough eigenvalues to get past the ground multiplet.
        hi_full = nk.hilbert.Spin(s=0.5, N=n)
        H_full  = nk.operator.Heisenberg(hilbert=hi_full, graph=graph)
        k       = min(n_eigs, hi_full.n_states - 1)  # Lanczos needs k < dimension
        ed      = nk.exact.lanczos_ed(H_full, k=k)
        e0_ex   = float(np.min(ed))
        e1_ex   = _first_excited_energy(ed)
        print(f"Lanczos: E0={e0_ex:.5f}  E1={e1_ex:.5f}  gap={e1_ex-e0_ex:.5f}")

        # Ground state (always Sz=0 for even, Sz=0.5 for odd)
        sz_gs = 0.5 if is_odd else 0
        hi_gs = nk.hilbert.Spin(s=0.5, total_sz=sz_gs, N=n)
        H_gs  = nk.operator.Heisenberg(hilbert=hi_gs, graph=graph)
        print("\n-- Ground State --")
        vs_gs = make_vstate(hi_gs, graph, n_samples, alpha)
        run_vmc(vs_gs, H_gs, n_iter=n_iter_gs)
        e0_vmc = float(vs_gs.expect(H_gs).mean.real)
        print(f"VMC E0={e0_vmc:.5f}  (diff={abs(e0_vmc-e0_ex):.5f})")

        # Excited state
        print("\n-- First Excited State --")
        if is_odd:
            # NOTE(review): if the ground level is degenerate inside the Sz=+1/2
            # sector, penalising overlap with the single vs_gs may not exclude
            # the other ground-level states.
            vs_es, H_es = get_1es_odd(n, graph, vs_gs, H_gs, beta,
                                       n_samples, alpha, n_iter_es)
        else:
            vs_es, H_es = get_1es_even(n, graph, n_samples, alpha, n_iter_es)

        e1_vmc = float(vs_es.expect(H_es).mean.real)

        # Overlap diagnostic. Meaningful only for odd N; for even N the states
        # live in different Sz sectors and are orthogonal by symmetry.
        if is_odd:
            # Normalised |<gs|es>| from
            #   E_es[psi_gs/psi_es] * E_gs[psi_es/psi_gs]
            #     = |<gs|es>|^2 / (<gs|gs> <es|es>),
            # assuming each MCState samples from its own |psi|^2 (NetKet's
            # default). The product is independent of the arbitrary norms of
            # both RBMs.
            s_es = vs_es.samples.reshape(-1, vs_es.samples.shape[-1])
            s_gs = vs_gs.samples.reshape(-1, vs_gs.samples.shape[-1])
            r_es = jnp.mean(jnp.exp(vs_gs.log_value(s_es) - vs_es.log_value(s_es)))
            r_gs = jnp.mean(jnp.exp(vs_es.log_value(s_gs) - vs_gs.log_value(s_gs)))
            ovlp = float(jnp.sqrt(jnp.abs(r_es * r_gs)))
        else:
            ovlp = 0.0  # different sectors, orthogonal by symmetry

        print(f"VMC E1={e1_vmc:.5f}  (diff={abs(e1_vmc-e1_ex):.5f})  |<gs|es>|={ovlp:.5f}")

        results.append({
            "N": n, "parity": "odd" if is_odd else "even",
            "E0_exact": e0_ex,  "E1_exact": e1_ex,
            "E0_VMC":   e0_vmc, "E1_VMC":   e1_vmc,
            "overlap":  ovlp,
        })

    print("\n\n=== Summary ===")
    print(f"{'N':>3} {'par':>4} {'E0_VMC':>10} {'E0_ex':>10} "
          f"{'E1_VMC':>10} {'E1_ex':>10} {'|<0|1>|':>9}")
    for r in results:
        print(f"{r['N']:>3} {r['parity']:>4} {r['E0_VMC']:>10.5f} {r['E0_exact']:>10.5f} "
              f"{r['E1_VMC']:>10.5f} {r['E1_exact']:>10.5f} {r['overlap']:>9.5f}")

    return results


if __name__ == "__main__":
    # Sweep N = 4..8, covering both even and odd chain lengths.
    analyze(list(range(4, 9)))