In [1]:
# Adding src module to path
import sys
import os
module_path = os.path.abspath(os.path.join('../..'))

if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
# general imports
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt
%matplotlib notebook

import src.utils as utils
from src.nqs import RBM, Hamiltonian

In [3]:
#Initializing starting values
hamiltonian = np.array([1, 0, 0, -1])
visible_size = int(np.sqrt(len(hamiltonian)))
hidden = 4
steps = 1000

seed = 342

#Finding true ground state energy and displaying it
np.random.seed(seed)

energy_list = []
error_list = []
gstate_list =[]

In [17]:
H = Hamiltonian(values=hamiltonian)

rbm = RBM(visible_size=visible_size, hidden_size=hidden, hamiltonian=H, walker_steps=0)

eig, states = np.linalg.eigh(H)
print(f"Eig: {eig},\nstates: \n{states}")
print(states.shape)
E_truth = np.min(eig)
e_truth_index = np.argmin(eig)
print(f"Eig index: {e_truth_index}")
gstate = states[:,e_truth_index]
print(f"g state: {gstate}")

energy_list = rbm.train(iterations=steps, lr=0.01, print_energy=False)


Eig: [-1.  0.  0.  1.],
states: 
[[0. 0. 0. 1.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]]
(4, 4)
Eig index: 0
g state: [0. 0. 0. 1.]
func:train args:[(<src.nqs.RBM object at 0x0000021A8144F520>,), {'iterations': 1000, 'lr': 0.01, 'print_energy': False}] took: 0.5493900775909424 sec


In [6]:
# plt.ioff() # uncomment to stop plotting in notebook

plt.plot(energy_list)

plt.axhline(y=E_truth, color='b', linestyle='--')
plt.title(f"Training of RBM with {visible_size} visible nodes")


<IPython.core.display.Javascript object>

Text(0.5, 1.0, 'Training of RBM with 2 visible nodes')

In [18]:
print(f"Ground state energy found: {energy_list[-1]}")
print(f"Accuracy: {np.abs(E_truth-energy_list[-1])}")

Ground state energy found: -0.9999854922243896
Accuracy: 1.450777561040617e-05


In [19]:
print(f"RBM wave function: \n{rbm.wave_function()}")
print(f"True ground state: \n{gstate}")


#print(f"RBM state error: {np.abs(gstate)**2 @ np.abs(rbm.wave_function())**2}")
print(f"RBM state error: {np.abs(gstate.T.conj() @ rbm.wave_function())}") # overlap measure

RBM wave function: 
[[-1.11855713e-03-2.14286675e-03j]
 [ 1.49970946e-03+4.17080297e-04j]
 [ 5.52503745e-04-3.05507537e-04j]
 [ 7.00052815e-01+7.14085003e-01j]]
True ground state: 
[0. 0. 0. 1.]
RBM state error: [0.99999567]
