# Task #2

A template code for training an RBM on Rydberg atom data (the full dataset) is provided below. For the first part of this task (determining the minimum number of hidden units), start with 20 hidden units. 

Imports and loading in data:

In [1]:
import numpy as np
import torch
from RBM_helper import RBM

import Rydberg_energy_calculator

# Read the whitespace-delimited text file into a NumPy array and wrap it as a
# torch tensor (torch.from_numpy shares memory with the array, no copy).
# Presumably each row is one measured configuration — TODO confirm against
# the data file.
training_data = torch.from_numpy(
    np.loadtxt("Rydberg_data.txt")
)

Define the RBM:

In [2]:
# Layer sizes: one visible unit per column of the training data and a single
# hidden unit.
# NOTE(review): the task text above suggests starting from 20 hidden units —
# confirm that 1 is the intended value for this cell.
n_vis = training_data.size(1)
n_hin = 1

# Build the model with those layer sizes.
rbm = RBM(n_vis, n_hin)

Train the RBM:

In [3]:
# Training schedule and the number of samples drawn for each energy estimate.
epochs = 500
num_samples = 2000

# Exact energy, printed once so the sampled estimates can be compared to it.
exact_energy = -4.1203519096
print("Exact energy: ", exact_energy)

for e in range(1, epochs + 1):
    # One epoch of training on the full data set.
    rbm.train(training_data)

    # Evaluation is only done on every 100th epoch; skip all others.
    if e % 100 != 0:
        continue

    print("\nEpoch: ", e)
    print("Sampling...")

    # Every chain starts from the all-zeros configuration. 100 is passed as
    # the first argument of draw_samples (presumably the number of sampling
    # steps — TODO confirm in RBM_helper).
    init_state = torch.zeros(num_samples, n_vis)
    RBM_samples = rbm.draw_samples(100, init_state)

    print("Done sampling. Calculating energy...")

    # Energy estimate from the drawn samples and the RBM wavefunction.
    energies = Rydberg_energy_calculator.energy(RBM_samples, rbm.wavefunction)
    print("Energy from RBM samples: ", energies.item())

Exact energy:  -4.1203519096

Epoch:  100
Sampling...
Done sampling. Calculating energy...




Energy from RBM samples:  -4.12021464486249

Epoch:  200
Sampling...
Done sampling. Calculating energy...
Energy from RBM samples:  -4.119874615698244

Epoch:  300
Sampling...
Done sampling. Calculating energy...
Energy from RBM samples:  -4.120257511619599

Epoch:  400
Sampling...
Done sampling. Calculating energy...
Energy from RBM samples:  -4.120167211682887

Epoch:  500
Sampling...
Done sampling. Calculating energy...
Energy from RBM samples:  -4.120189557945697


Rydberg Hamiltonian: $H = -\sum_{\langle i,j \rangle} V_{ij} \left( \sigma_i^z \sigma_j^z + \sigma_i^z + \sigma_j^z \right) - \Omega \sum_i \sigma_i^z - h \sum_i \sigma_i^x$

In [4]:
# Search for the smallest number of hidden units whose trained RBM reproduces
# the exact energy to within `tolerance`.
n_vis = training_data.shape[1]
n_hin = 0
# Upper bound on the search. Each iteration is a full training run and the
# energy estimate comes from random samples, so without a cap the loop could
# run indefinitely if the tolerance is never met.
max_n_hin = 20

epochs = 1000
num_samples = 2000

exact_energy = -4.1203519096
tolerance = 0.0001
error = 1.0  # sentinel larger than the tolerance so the loop body runs at least once

while error > tolerance and n_hin < max_n_hin:

    # Try the next larger hidden layer with a freshly initialised RBM.
    n_hin += 1
    print("Hidden Unit number :", n_hin)
    rbm = RBM(n_vis, n_hin)

    for e in range(1, epochs+1):
        rbm.train(training_data)

    # Evaluate once, after all epochs: draw samples starting from the
    # all-zeros state and estimate the energy from them.
    init_state = torch.zeros(num_samples, n_vis)
    RBM_samples = rbm.draw_samples(100, init_state)

    energies = Rydberg_energy_calculator.energy(RBM_samples, rbm.wavefunction)
    RBM_energy = energies.item()

    error = abs(RBM_energy - exact_energy)
    print("Error = {}   RBM Energy = {}  Exact Energy = {}".format(error, RBM_energy, exact_energy))

# Only claim a minimum when the tolerance was actually reached.
if error <= tolerance:
    print("Minimum number of Hidden Units required to get error < 0.0001 = ", n_hin)
else:
    print("No RBM with up to {} hidden units reached error <= {}".format(max_n_hin, tolerance))

Error = 0.00036067219225977   RBM Energy = -4.11999123740774  Exact Energy = -4.1203519096
Error = 0.0005137099304048931   RBM Energy = -4.119838199669595  Exact Energy = -4.1203519096
Error = 0.0004355051754565409   RBM Energy = -4.1199164044245435  Exact Energy = -4.1203519096
Error = 0.00018923475875087803   RBM Energy = -4.120162674841249  Exact Energy = -4.1203519096
Error = 0.00039305754069207666   RBM Energy = -4.119958852059308  Exact Energy = -4.1203519096
Error = 0.0001958661438914433   RBM Energy = -4.120156043456109  Exact Energy = -4.1203519096
Error = 0.0007007227919588743   RBM Energy = -4.119651186808041  Exact Energy = -4.1203519096
Error = 6.752802804754765e-05   RBM Energy = -4.1202843815719525  Exact Energy = -4.1203519096
Minimum number of Hidden Units required to get error < 0.0001 =  1


In [None]:
# Search for the smallest training-set size (in steps of 100 rows) for which
# an RBM reaches the target energy accuracy.

# Multiply the number of hidden units by 2.
# NOTE(review): this mutates n_hin in place, so re-running this cell doubles
# it again — re-run the hidden-unit search cell first to reset it.
n_hin = n_hin * 2
n_vis = training_data.shape[1]

# Total number of rows available; slicing past this would silently clamp.
n_total = training_data.shape[0]

# n is incremented by 100 at the top of the loop, so the first trial uses
# 500 data points.
n = 400

epochs = 1000
num_samples = 2000

exact_energy = -4.1203519096
tolerance = 0.0001
error = 1.0  # sentinel larger than the tolerance so the loop body runs at least once

# Set once the whole data set has been tried; without this the loop would
# keep retraining on the same (clamped) full data set forever.
data_exhausted = False

while error > tolerance and not data_exhausted:

    n += 100
    if n >= n_total:
        # Report the number of rows really used and make this the last trial.
        n = n_total
        data_exhausted = True

    print("Number of Sample Data :", n)
    trimmed_trainingData = training_data[0:n]

    rbm = RBM(n_vis, n_hin)

    for e in range(1, epochs+1):
        rbm.train(trimmed_trainingData)

        # Check the energy every 100 epochs and stop training early as soon
        # as the tolerance is met.
        if e % 100 == 0:

            init_state = torch.zeros(num_samples, n_vis)
            RBM_samples = rbm.draw_samples(100, init_state)

            energies = Rydberg_energy_calculator.energy(RBM_samples, rbm.wavefunction)
            RBM_energy = energies.item()

            error = abs(RBM_energy - exact_energy)
            print("Error = {}   RBM Energy = {}  Exact Energy = {}".format(error, RBM_energy, exact_energy))

            if error <= tolerance:
                break

# Only claim a minimum when the tolerance was actually reached.
if error <= tolerance:
    print("Minimum number of Data required to get error < 0.0001 = ", n)
else:
    print("Error stayed above {} even with all {} data points".format(tolerance, n_total))

Error = 0.9575332134862276   RBM Energy = -3.1628186961137725  Exact Energy = -4.1203519096
Error = 0.27636598370534626   RBM Energy = -3.843985925894654  Exact Energy = -4.1203519096
Error = 0.10351384397347108   RBM Energy = -4.016838065626529  Exact Energy = -4.1203519096
Error = 0.05071689181865313   RBM Energy = -4.069635017781347  Exact Energy = -4.1203519096
Error = 0.027600401146030684   RBM Energy = -4.092751508453969  Exact Energy = -4.1203519096
Error = 0.01840841418688477   RBM Energy = -4.101943495413115  Exact Energy = -4.1203519096
Error = 0.013079668331713812   RBM Energy = -4.107272241268286  Exact Energy = -4.1203519096
Error = 0.00842429179298776   RBM Energy = -4.111927617807012  Exact Energy = -4.1203519096
Error = 0.0073431184826331375   RBM Energy = -4.113008791117367  Exact Energy = -4.1203519096
Error = 0.004361022718651242   RBM Energy = -4.115990886881349  Exact Energy = -4.1203519096
Error = 0.7504903481616436   RBM Energy = -3.3698615614383565  Exact Energy

Error = 0.0016883543060126627   RBM Energy = -4.118663555293987  Exact Energy = -4.1203519096
Error = 0.12395301692289618   RBM Energy = -3.996398892677104  Exact Energy = -4.1203519096
Error = 0.020532917084952196   RBM Energy = -4.099818992515048  Exact Energy = -4.1203519096
Error = 0.006473189836432525   RBM Energy = -4.113878719763568  Exact Energy = -4.1203519096
Error = 0.003091310881950804   RBM Energy = -4.117260598718049  Exact Energy = -4.1203519096
Error = 0.0030241136766138155   RBM Energy = -4.117327795923386  Exact Energy = -4.1203519096
Error = 0.0027968179943433924   RBM Energy = -4.117555091605657  Exact Energy = -4.1203519096
Error = 0.0016590663531985683   RBM Energy = -4.1186928432468015  Exact Energy = -4.1203519096
Error = 0.0013550754548630195   RBM Energy = -4.118996834145137  Exact Energy = -4.1203519096
Error = 0.0017211614188585855   RBM Energy = -4.1186307481811415  Exact Energy = -4.1203519096
Error = 0.001831155724969058   RBM Energy = -4.118520753875031 

Error = 0.0013611735820049375   RBM Energy = -4.118990736017995  Exact Energy = -4.1203519096
Error = 0.001978499521432475   RBM Energy = -4.118373410078568  Exact Energy = -4.1203519096
Error = 0.0013173941813588996   RBM Energy = -4.119034515418641  Exact Energy = -4.1203519096
Error = 0.03519283142970586   RBM Energy = -4.085159078170294  Exact Energy = -4.1203519096
Error = 0.0037103881647997383   RBM Energy = -4.1166415214352  Exact Energy = -4.1203519096
Error = 0.001936127556231071   RBM Energy = -4.118415782043769  Exact Energy = -4.1203519096
Error = 0.0010403838962007583   RBM Energy = -4.119311525703799  Exact Energy = -4.1203519096
Error = 0.0012982595006443276   RBM Energy = -4.119053650099356  Exact Energy = -4.1203519096
Error = 0.0008013299772136051   RBM Energy = -4.1195505796227865  Exact Energy = -4.1203519096
Error = 0.0010876865079714193   RBM Energy = -4.119264223092029  Exact Energy = -4.1203519096
Error = 0.0020557698876757513   RBM Energy = -4.118296139712324  

Error = 0.0010907382141285282   RBM Energy = -4.119261171385872  Exact Energy = -4.1203519096
Error = 0.0009450758510451962   RBM Energy = -4.119406833748955  Exact Energy = -4.1203519096
Error = 0.001826197514982475   RBM Energy = -4.118525712085018  Exact Energy = -4.1203519096
Error = 0.0010238234626553933   RBM Energy = -4.119328086137345  Exact Energy = -4.1203519096
Error = 0.0018580622041994843   RBM Energy = -4.118493847395801  Exact Energy = -4.1203519096
Error = 0.013546458411752837   RBM Energy = -4.106805451188247  Exact Energy = -4.1203519096
Error = 0.0018847428797910126   RBM Energy = -4.118467166720209  Exact Energy = -4.1203519096
Error = 0.0008013205426511405   RBM Energy = -4.119550589057349  Exact Energy = -4.1203519096
Error = 0.0007706524605346488   RBM Energy = -4.119581257139465  Exact Energy = -4.1203519096
Error = 0.0014603318583139568   RBM Energy = -4.118891577741686  Exact Energy = -4.1203519096
Error = 0.0009761392134661762   RBM Energy = -4.11937577038653

Error = 0.0011547157512268313   RBM Energy = -4.119197193848773  Exact Energy = -4.1203519096
Error = 0.0012875345694309814   RBM Energy = -4.119064375030569  Exact Energy = -4.1203519096
Error = 0.0008226858241853208   RBM Energy = -4.119529223775815  Exact Energy = -4.1203519096
Error = 0.0007447268770937043   RBM Energy = -4.119607182722906  Exact Energy = -4.1203519096
Error = 0.0010563922340107723   RBM Energy = -4.119295517365989  Exact Energy = -4.1203519096
Error = 0.0011458237129833293   RBM Energy = -4.119206085887017  Exact Energy = -4.1203519096
Error = 0.0005197536101677613   RBM Energy = -4.119832155989832  Exact Energy = -4.1203519096
Error = 0.007339398732002067   RBM Energy = -4.113012510867998  Exact Energy = -4.1203519096
Error = 0.0003606242892484346   RBM Energy = -4.119991285310752  Exact Energy = -4.1203519096
Error = 0.000751468789456311   RBM Energy = -4.119600440810544  Exact Energy = -4.1203519096
Error = 0.0008468790535554049   RBM Energy = -4.11950503054644

Error = 0.0011168258660037012   RBM Energy = -4.119235083733996  Exact Energy = -4.1203519096
Error = 0.00042774760525077227   RBM Energy = -4.119924161994749  Exact Energy = -4.1203519096
Error = 0.0008734108613674962   RBM Energy = -4.119478498738633  Exact Energy = -4.1203519096
Error = 0.0010106217510141846   RBM Energy = -4.119341287848986  Exact Energy = -4.1203519096
Error = 0.001247905937471927   RBM Energy = -4.119104003662528  Exact Energy = -4.1203519096
Error = 0.0004256691880311081   RBM Energy = -4.119926240411969  Exact Energy = -4.1203519096
Error = 0.0010287677880924662   RBM Energy = -4.119323141811908  Exact Energy = -4.1203519096
Error = 0.00033572137414417114   RBM Energy = -4.120016188225856  Exact Energy = -4.1203519096
Error = 0.0003211400997287228   RBM Energy = -4.120030769500271  Exact Energy = -4.1203519096
Error = 0.004538522408875423   RBM Energy = -4.115813387191125  Exact Energy = -4.1203519096
Error = 0.0016878376952664809   RBM Energy = -4.118664071904