# Task #1 

A template code for training an RBM on H$_2$ data for $r = 1.2$ is shown here. Modify this!

Imports and loading in data:

In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt

from RBM_helper import RBM

import H2_energy_calculator

r = np.array(range(20,290,5)) / 100.0 
coeff = np.loadtxt("H2_data/H2_coefficients.txt")[:,:]
training_data = []

for i in range(len(r)):
    training_data.append(torch.from_numpy(np.loadtxt("H2_data/R_"+str(r[i])+"_samples.txt")))

Define the RBM:

In [16]:
n_vis = 2
n_hin = 20

rbm = RBM(n_vis, n_hin)

Train the RBM:

In [17]:
epochs = 10 # number of training steps, normally 500
update_epoch = 5 #the epoch which the energy is recalcualted, normally 100
num_samples = 1000 # number of samples to generate from the RBM to calculate the H2 energy, normally 1000
energiesList = []
true_energiesList = []

for i in range(len(r)):
    
    true_energy = H2_energy_calculator.energy_from_freq(training_data[i], coeff[i])
    print("\n--------------------------RBM-" + str(i) + "---------------------------")
    print("\nTrue Energy: %s for r = %s." % (true_energy, r[i]))

    for e in range(1, epochs+1):
        # do one epoch of training
        rbm.train(training_data[i])   

        # now generate samples and calculate the energy
        if e % update_epoch == 0:
            print("\nEpoch: ", e)
            print("Sampling the RBM...")

            # For sampling the RBM, we need to do Gibbs sampling.
            # Initialize the Gibbs sampling chain with init_state as defined below.
            init_state = torch.zeros(num_samples, n_vis)
            RBM_samples = rbm.draw_samples(15, init_state)

            print("Done sampling. Calculating energy...")       

            energies = H2_energy_calculator.energy(RBM_samples, coeff[i], rbm.wavefunction) 
            print("Energy from RBM samples: ", energies.item())
    print("\nFinal RBM Energy: ", energies.item())
    energiesList.append(energies.item())
    true_energiesList.append(true_energy)


--------------------------RBM-0---------------------------

True Energy: 0.1442108747311382 for r = 0.2.

Epoch:  5
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  0.3310292383600186

Epoch:  10
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  0.25916153703101996

Final RBM Energy:  0.25916153703101996

--------------------------RBM-1---------------------------

True Energy: -0.3239354753254609 for r = 0.25.

Epoch:  5
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.28362351357204796

Epoch:  10
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.2738321510898521

Final RBM Energy:  -0.2738321510898521

--------------------------RBM-2---------------------------

True Energy: -0.6129039934108024 for r = 0.3.

Epoch:  5
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.5744800022883331

Epoch:  10
Sampling the RB


--------------------------RBM-22---------------------------

True Energy: -1.0429909176307628 for r = 1.3.

Epoch:  5
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.9741985499362045

Epoch:  10
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.9917340254775584

Final RBM Energy:  -0.9917340254775584

--------------------------RBM-23---------------------------

True Energy: -1.0329255626795484 for r = 1.35.

Epoch:  5
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.9741731898245597

Epoch:  10
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.980347184711033

Final RBM Energy:  -0.980347184711033

--------------------------RBM-24---------------------------

True Energy: -1.0235793844779455 for r = 1.4.

Epoch:  5
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.9795748810047887

Epoch:  10
Sampling the 


--------------------------RBM-44---------------------------

True Energy: -0.9461560561584954 for r = 2.4.

Epoch:  5
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.9444370479051795

Epoch:  10
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.9354632472195827

Final RBM Energy:  -0.9354632472195827

--------------------------RBM-45---------------------------

True Energy: -0.9454307947133701 for r = 2.45.

Epoch:  5
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.9394401585818217

Epoch:  10
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.9396869294656891

Final RBM Energy:  -0.9396869294656891

--------------------------RBM-46---------------------------

True Energy: -0.9448669090603318 for r = 2.5.

Epoch:  5
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.9351257839216205

Epoch:  10
Sampling th

In [24]:
%matplotlib notebook
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
plt.plot(r,energiesList)

#Define plots
rbm_plot = ax.plot(r, energiesList, label='RBM')
true_plot = ax.plot(r, true_energiesList, dashes=[6, 2], label='True Energy')

#Text formating
params = {'mathtext.default': 'regular' }  
plt.rcParams.update(params)
plt.xlabel('$r$', fontsize=14)
y = plt.ylabel('$E_{bond}$', fontsize=14)
plt.title('H2 Potential Energy Curve using RBMs', fontsize=16)
plt.text(r[-25], energiesList[0] - 0.05, "n = " + str(epochs))
plt.text(r[-25], energiesList[0] - 0.25, "s = " + str(num_samples))
plt.text(r[-25], energiesList[0] - 0.45, "n_vis = " + str(n_vis))
plt.text(r[-25], energiesList[0] - 0.65, "n_hid = " + str(n_hin))

ax.legend()
plt.show()

<IPython.core.display.Javascript object>