In [None]:
import hashlib
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from optimization import *

In [None]:
# TODO: replace this function with actual file reading.
def readHamiltonian_params(file, idx, N, seed=None):
    """Return a placeholder random, sparse set of Hamiltonian parameters.

    This is a stand-in for reading parameters from ``file``: ``file`` and
    ``idx`` are currently ignored and random parameters are generated instead.

    Args:
        file: Path of the parameter file (unused by this placeholder).
        idx: Index of the parameter set within ``file`` (unused by this placeholder).
        N (int): Number of spins.
        seed (int | None): If given, draw from a private
            ``np.random.default_rng(seed)`` so the result is reproducible
            independently of global state. If ``None`` (default), draw from
            NumPy's legacy global random state, exactly as before, so callers
            can still make it reproducible with ``np.random.seed(...)``.

    Returns:
        Jijalphabeta (np.ndarray): shape (3, 3, N, N); each entry is drawn
            uniformly from [0, 1) and kept with probability 0.1 (else 0).
        h (np.ndarray): shape (3, N); same construction as ``Jijalphabeta``.

    NOTE(review): the random couplings are neither symmetric under
    (alpha, i) <-> (beta, j) nor zero on the i == j diagonal. Whether that
    yields a Hermitian operator depends on how construct_hamiltonian
    consumes them -- confirm before trusting exact-diagonalization results.
    """
    if seed is None:
        # Legacy global state: reproduces the original random stream.
        def draw(shape):
            return np.random.rand(*shape)
    else:
        draw = np.random.default_rng(seed).random

    # Draw order (couplings, fields, coupling mask, field mask) matches the
    # original implementation so seeded global-state runs are unchanged.
    Jijalphabeta = draw((3, 3, N, N))
    h = draw((3, N))

    # Keep each coupling with probability 0.1, i.e. zero out ~90% of them.
    mask_J = draw((3, 3, N, N)) < 0.1
    Jijalphabeta = Jijalphabeta * mask_J

    # Same sparsification for the local fields.
    mask_h = draw((3, N)) < 0.1
    h = h * mask_h

    return Jijalphabeta, h


def generate_tfim_params(N, J=1.0, h=1.0, periodic=True):
    """
    Generates Jij and h arrays for the transverse field Ising model (TFIM):
    H = -J sum_{<i,j>} sigma^z_i sigma^z_j - h sum_i sigma^x_i

    Args:
        N (int): Number of spins on a 1D chain.
        J (float): Coupling strength.
        h (float): Transverse field strength.
        periodic (bool): If True (default), also couple the last spin to the
            first (periodic boundary). If False, use an open chain with
            N - 1 bonds.

    Returns:
        Jij (np.ndarray): shape (3, 3, N, N); only Jij[2, 2, i, j] is nonzero,
            equal to -J for every nearest-neighbour pair, stored for both
            (i, j) and (j, i).
        h (np.ndarray): shape (3, N); only h[0, i] is nonzero, equal to -h.

    Notes:
        Component index 2 is used for sigma^z and index 0 for sigma^x, and the
        stored values carry the minus signs of the formula above. This
        presumably matches construct_hamiltonian's convention -- confirm there.
        A "bond" that would connect a site to itself (only possible for the
        N == 1 periodic chain) is skipped: sigma^z_i sigma^z_i is the
        identity, not a nearest-neighbour coupling.
    """
    Jij = np.zeros((3, 3, N, N))
    n_bonds = N if periodic else N - 1
    for i in range(max(n_bonds, 0)):
        j = (i + 1) % N  # wraps around only on the final periodic bond
        if i == j:
            continue  # N == 1: no neighbour to couple to
        Jij[2, 2, i, j] = -J
        Jij[2, 2, j, i] = -J  # symmetric storage

    h_arr = np.zeros((3, N))
    h_arr[0, :] = -h  # transverse field in x direction

    return Jij, h_arr

In [None]:
# The placeholder readHamiltonian_params draws random couplings from NumPy's
# global RNG, so seed it here to make the Hamiltonian (and the exact energy
# below) reproducible across kernel restarts.
HAMILTONIAN_SEED = 1234
N_SPINS = 4

np.random.seed(HAMILTONIAN_SEED)
Jij, h = readHamiltonian_params("hamiltonian_params.txt", 0, N_SPINS)
# Alternative benchmark model: Jij, h = generate_tfim_params(10, J=-1.0, h=1.0)
H = construct_hamiltonian(Jij, h)

# With compute_eigenvectors=True, lanczos_ed returns (eigenvalues, eigenvectors).
exact_ground_energy, exact_ground_state = nk.exact.lanczos_ed(
    H, k=1, compute_eigenvectors=True
)
print("Exact ground state energy:", exact_ground_energy[0])

In [None]:
def _array_digest(arr):
    """Return a SHA-256 hex digest of an array's dtype, shape and contents.

    Python's built-in hash() of bytes is salted per interpreter process
    (PYTHONHASHSEED), so it changes after every kernel restart and cannot
    identify a Hamiltonian across runs; a cryptographic digest is stable.
    """
    digest = hashlib.sha256()
    digest.update(str(arr.dtype).encode())
    digest.update(str(arr.shape).encode())
    digest.update(arr.tobytes())
    return digest.hexdigest()


# Stable fingerprints of the Hamiltonian parameters (for provenance / caching).
Jij_hash = _array_digest(Jij)
h_hash = _array_digest(h)

params = generate_params(
    alpha=1000,  # NOTE(review): presumably RBM hidden-unit density; 1000 is unusually large -- confirm
    seed=1234,
    learning_rate=3e-4,
    n_iter=5000,
    show_progress=True,
    out="data/rbm_optimization_test",
    # Jij_hash=Jij_hash,
    # h_hash=h_hash,
)

# Make sure the output directory exists; harmless if optimize_rbm creates it too.
Path(params["out"]).parent.mkdir(parents=True, exist_ok=True)

out = optimize_rbm(H, params)


In [None]:
# Convergence of the RBM variational energy towards the exact (Lanczos)
# ground-state energy, read back from the JSON log written by optimize_rbm.
with open(params["out"] + ".log") as log_file:
    log_data = json.load(log_file)

iters = log_data["Energy"]["iters"]
energy_mean = log_data["Energy"]["Mean"]
# This log stores the mean as a {"real": [...], "imag": [...]} dict; take the
# real part in that case, otherwise assume it is already a real-valued list.
if isinstance(energy_mean, dict):
    energy_mean = energy_mean["real"]
energy_diff = np.abs(np.asarray(energy_mean, dtype=float) - exact_ground_energy[0])

fig, ax = plt.subplots()
ax.plot(iters, energy_diff, color="C1", label="|Energy - Exact|")
ax.set_yscale("log")
ax.set_xlabel("Iteration")
ax.set_ylabel("Energy Difference (log scale)")
ax.set_title("RBM energy error vs. exact ground state")
ax.legend()
plt.show()


In [None]:
# Hand the Hamiltonian, the optimize_rbm result (`out`) and the run parameters
# to write_output. NOTE(review): what write_output stores and where is defined
# in the optimization module -- confirm it does not clobber the params["out"]
# log read by the plotting cell above.
write_output(H, out, params)