# Task #2

Template code for training an RBM on the Rydberg atom data (the full dataset) is provided below. For the first part of this task (determining the minimum number of hidden units), start with 1 hidden unit and increase the number as needed.

Imports and loading in data:

In [None]:
#from google.colab import drive
#drive.mount('/content/drive', force_remount=True)
#!cp -r "/content/drive/My Drive/Project_1_RBM_and_Tomography" "P_1"

Mounted at /content/drive


In [None]:
import numpy as np
import torch
from RBM_helper import RBM
#from P_1.RBM_helper import RBM
import importlib
import Plotters
#import P_1.Plotters as Plotters
importlib.reload(Plotters)

import Rydberg_energy_calculator
#import P_1.Rydberg_energy_calculator as Rydberg_energy_calculator

# Load the Rydberg dataset from the text file and wrap it as a torch tensor.
# (When running from the Colab copy, load "P_1/Rydberg_data.txt" instead.)
training_data = torch.from_numpy(
    np.loadtxt("Rydberg_data.txt")
)
# Sample budget: the number of rows in the dataset, capped at 20000.
# NOTE(review): not referenced by the cells visible here.
use_n_samples = min(training_data.shape[0], 20000)

Define the RBM:

In [None]:
# One visible unit per column of the training data.
n_vis = training_data.size(1)
# Number of hidden units; this is the quantity being tuned in this task.
n_hin = 2

# Build the RBM with the chosen visible/hidden layer sizes.
rbm = RBM(n_vis, n_hin)

Train the RBM:

In [None]:
epochs = 500
num_samples = 500

exact_energy = -4.1203519096
print("Exact energy: ",exact_energy)
print("Number of hidden units:", n_hin)

diff_thresh = 1e-4

plot_title = f"{n_hin} hidden units; training with {num_samples} samples"
plotter = Plotters.XYPlotter(href=diff_thresh,
                             title=plot_title,
                             x_label="Epoch",
                             y_label="Energy difference")
%matplotlib inline

for e in range(1, epochs+1):
    # do one epoch of training
    rbm.train(training_data[:num_samples])   
 
    # now generate samples and calculate the energy
    if e % 10 == 0:
        print("\nEpoch: ", e)
        print("Sampling...")

        init_state = torch.zeros(num_samples, n_vis)
        RBM_samples = rbm.draw_samples(100, init_state)

        print("Done sampling. Calculating energy...") 
 
        energies = Rydberg_energy_calculator.energy(RBM_samples, rbm.wavefunction) 
        print("Energy from RBM samples: ", energies.item())

        plotter.update(e, abs(energies.item() - exact_energy))
        print("Difference:", abs(energies.item() - exact_energy))

        # stopping criterion
        if abs(energies.item() - exact_energy) < diff_thresh:
            print(f'Algorithm stopped after {e} epochs with energy estimation {energies.item()}')
            print("Number of hidden units:", n_hin)
            print("RBM Samples:", num_samples)
            #break

Exact energy:  -4.1203519096
Number of hidden units: 2

Epoch:  10
Sampling...
Done sampling. Calculating energy...
Energy from RBM samples:  -4.120131695654462
Difference: 0.00022021394553828344

Epoch:  20
Sampling...
Done sampling. Calculating energy...
Energy from RBM samples:  -4.120106781353415
Difference: 0.0002451282465854021

Epoch:  30
Sampling...
Done sampling. Calculating energy...
Energy from RBM samples:  -4.12001600714296
Difference: 0.00033590245704040456

Epoch:  40
Sampling...
Done sampling. Calculating energy...
Energy from RBM samples:  -4.120092065605315
Difference: 0.00025984399468548247

Epoch:  50
Sampling...
Done sampling. Calculating energy...
Energy from RBM samples:  -4.120368369375167
Difference: 1.6459775166843826e-05
Algorithm stopped after 50 epochs with energy estimation -4.120368369375167
Number of hidden units: 2
RBM Samples: 500
