# Restricted Boltzmann Machine (RBM)

In [18]:
import numpy as np

In [19]:
class RBM():
    
    def __init__(self,num_visible,num_hidden):
        
        
        # defined the variables 
        # weights is defined as a matrix of dimension(num_visible,num_hidden) and is uniform distribution
        
        # weights are defiend with zero mean and std_dev = 0.1
        self.num_visible= num_visible
        self.num_hidden= num_hidden
        self.debug_print = True
        random = np.random.RandomState(1234)
        
        self.weights = np.asarray(random.uniform(low=-0.1 * np.sqrt(6. / (num_hidden + num_visible)),
                       	high=0.1 * np.sqrt(6. / (num_hidden + num_visible)),
                       	size=(num_visible, num_hidden)))
        
        self.weights = np.insert(self.weights, 0, 0, axis = 0)
        self.weights = np.insert(self.weights, 0, 0, axis = 1)
        
        
    def train(self, data, max_epochs = 1000, learning_rate = 0.1):
        num_examples = data.shape[0]

        # Insert bias units of 1 into the first column.
        data = np.insert(data, 0, 1, axis = 1)

        for epoch in range(max_epochs):
            
            # Here for forward pass we use the variable pos
            # for forward pass we compute p(h|v;W) prob of h(any h) given v(the previous one) and weights
            # computing the probability is done by passing the activation value through a sigmoid
            
            pos_hidden_activations = np.dot(data, self.weights)      
            pos_hidden_probs = self._logistic(pos_hidden_activations)
            pos_hidden_probs[:,0] = 1 # Fix the bias unit.
            pos_hidden_states = pos_hidden_probs > np.random.rand(num_examples, self.num_hidden + 1)
      
            pos_associations = np.dot(data.T, pos_hidden_probs)
            
            # Here for Backward pass we use the variable neg
            # for forward pass we compute p(v|h;W) prob of v given h and weights
            # computing the probability is done by passing the activation value through a sigmoid
            
            neg_visible_activations = np.dot(pos_hidden_states, self.weights.T)
            neg_visible_probs = self._logistic(neg_visible_activations)
            neg_visible_probs[:,0] = 1 # Fix the bias unit.
            neg_hidden_activations = np.dot(neg_visible_probs, self.weights)
            neg_hidden_probs = self._logistic(neg_hidden_activations)
      
            neg_associations = np.dot(neg_visible_probs.T, neg_hidden_probs)

            # Update of Weights 
            # with error being (pos_associations - neg_associations)/num_examples
            self.weights += learning_rate * ((pos_associations - neg_associations) / num_examples)

            error = np.sum((data - neg_visible_probs) ** 2)
            if self.debug_print:
                
                print("Epoch %s: error is %s" % (epoch, error))
            
    def _logistic(self, x):
        
        # Sigmoid function
        return 1.0 / (1 + np.exp(-x))
    
    def run_visible(self, data):
        
        # This function helps to return the hidden_states
        # the value becomes 1 if we activation value crosses the threshold value and else 0 
        num_examples = data.shape[0]
    
        # Create a matrix, where each row is to be the hidden units (plus a bias unit)
        hidden_states = np.ones((num_examples, self.num_hidden + 1))
    
    # Insert bias units of 1 into the first column of data.
        data = np.insert(data, 0, 1, axis = 1)

    # Calculate the activations of the hidden units.
        hidden_activations = np.dot(data, self.weights)
    # Calculate the probabilities of turning the hidden units on.
        hidden_probs = self._logistic(hidden_activations)
    # Turn the hidden units on with their specified probabilities.
        hidden_states[:,:] = hidden_probs > np.random.rand(num_examples, self.num_hidden + 1)
    
        hidden_states = hidden_states[:,1:]
        return hidden_states

In [20]:
if __name__ == '__main__':
    r = RBM(num_visible = 6, num_hidden =4)
    training_data = np.array([[1,1,1,0,0,0],[1,0,1,0,0,0],[1,1,1,0,0,0],[0,0,1,1,1,0], [0,0,1,1,0,0],[0,0,1,1,1,0]])
    r.train(training_data, max_epochs = 5000)
    print("-----------------------------------------------------------------------------------------------------------")
    print ("Weights Matrix[num_visible,num_hidden]")
    print(r.weights)
    print("-----------------------------------------------------------------------------------------------------------")
    user = np.array([[0,0,0,1,1,0]])
    print("These are the Hidden states")
    print(r.run_visible(user))

Epoch 0: error is 8.854872195841658
Epoch 1: error is 8.59682688499872
Epoch 2: error is 8.264298666927878
Epoch 3: error is 7.980574989778066
Epoch 4: error is 7.8667137038729535
Epoch 5: error is 7.442123897033919
Epoch 6: error is 7.232223416899338
Epoch 7: error is 7.501563523596398
Epoch 8: error is 6.955381194798115
Epoch 9: error is 7.014749315661615
Epoch 10: error is 6.846173478746524
Epoch 11: error is 6.987021194929353
Epoch 12: error is 6.47896760539359
Epoch 13: error is 6.571306070352661
Epoch 14: error is 6.866512798494496
Epoch 15: error is 6.610469311415429
Epoch 16: error is 6.212674896912996
Epoch 17: error is 6.598546663665742
Epoch 18: error is 6.375841431772711
Epoch 19: error is 6.195299279355087
Epoch 20: error is 6.352998186330957
Epoch 21: error is 6.314425927741947
Epoch 22: error is 6.355638609625673
Epoch 23: error is 6.128688732052034
Epoch 24: error is 6.1408890965125265
Epoch 25: error is 6.039612977322262
Epoch 26: error is 6.076392699112309
Epoch 27: e

Epoch 276: error is 3.431340497336528
Epoch 277: error is 3.3903793896309113
Epoch 278: error is 5.288788147753013
Epoch 279: error is 3.041005284061866
Epoch 280: error is 3.490450440730985
Epoch 281: error is 3.057373987697765
Epoch 282: error is 4.107866756007905
Epoch 283: error is 1.7333391927245587
Epoch 284: error is 2.9717141867591303
Epoch 285: error is 2.557696835583402
Epoch 286: error is 3.0221888614422943
Epoch 287: error is 3.680676512063301
Epoch 288: error is 4.970455179830599
Epoch 289: error is 1.629235787685126
Epoch 290: error is 1.5724829886032317
Epoch 291: error is 2.2454628657957194
Epoch 292: error is 2.0237567680484894
Epoch 293: error is 2.4881987658763767
Epoch 294: error is 1.405317911985679
Epoch 295: error is 2.118012688041832
Epoch 296: error is 1.8505221038713526
Epoch 297: error is 2.289092342109842
Epoch 298: error is 4.175345246890846
Epoch 299: error is 1.8963578194954445
Epoch 300: error is 1.9462793160794323
Epoch 301: error is 2.5545290131455918


Epoch 621: error is 0.8719431522775226
Epoch 622: error is 0.7449342270155924
Epoch 623: error is 1.4743905300918214
Epoch 624: error is 0.8736666750643297
Epoch 625: error is 0.7274377554180516
Epoch 626: error is 0.8780493921304607
Epoch 627: error is 0.37209091727614185
Epoch 628: error is 1.503767964675888
Epoch 629: error is 0.9111961278882197
Epoch 630: error is 1.1936654339824655
Epoch 631: error is 1.4157471871821623
Epoch 632: error is 1.4118070725703191
Epoch 633: error is 1.408029434914185
Epoch 634: error is 1.415020017110481
Epoch 635: error is 1.9285531600101364
Epoch 636: error is 1.9264464193136102
Epoch 637: error is 1.7453904112552316
Epoch 638: error is 0.8656568075943217
Epoch 639: error is 1.9290606932613958
Epoch 640: error is 1.4013965665431072
Epoch 641: error is 0.3395309163959443
Epoch 642: error is 0.87192411394156
Epoch 643: error is 1.4005798943753454
Epoch 644: error is 0.8689373085351061
Epoch 645: error is 0.8976801480742171
Epoch 646: error is 0.8140197

Epoch 991: error is 1.3213361002286046
Epoch 992: error is 0.2574678695654585
Epoch 993: error is 0.8786680188505498
Epoch 994: error is 0.8591920974644554
Epoch 995: error is 0.1853869188086949
Epoch 996: error is 0.18425588447033786
Epoch 997: error is 0.18317189833526948
Epoch 998: error is 0.8228009973125561
Epoch 999: error is 0.8844791359857286
Epoch 1000: error is 0.25469474868993375
Epoch 1001: error is 0.3171654521850471
Epoch 1002: error is 1.508028991506064
Epoch 1003: error is 0.8228707516494904
Epoch 1004: error is 0.822318971800898
Epoch 1005: error is 0.8739433800631962
Epoch 1006: error is 0.8868422837998716
Epoch 1007: error is 0.17647225323944998
Epoch 1008: error is 0.6893215190647869
Epoch 1009: error is 0.8207027178318154
Epoch 1010: error is 1.3497359881774764
Epoch 1011: error is 2.1432163125204235
Epoch 1012: error is 0.6881826627948675
Epoch 1013: error is 0.8135130256306612
Epoch 1014: error is 0.25750608668146735
Epoch 1015: error is 2.1963057902221355
Epoch 

Epoch 1214: error is 0.8680949148454501
Epoch 1215: error is 0.7355108370826619
Epoch 1216: error is 0.11924071495794876
Epoch 1217: error is 0.11888305941163545
Epoch 1218: error is 0.11853356113456408
Epoch 1219: error is 0.4795782020055533
Epoch 1220: error is 0.8678186753717294
Epoch 1221: error is 0.8477783699618108
Epoch 1222: error is 0.11840689762208036
Epoch 1223: error is 0.8650433970230494
Epoch 1224: error is 0.784849036172801
Epoch 1225: error is 0.8500864352720701
Epoch 1226: error is 0.8621624252618268
Epoch 1227: error is 0.847022188446909
Epoch 1228: error is 0.47450235326268303
Epoch 1229: error is 0.8441219618574922
Epoch 1230: error is 0.8396815829730575
Epoch 1231: error is 0.11842545771880755
Epoch 1232: error is 0.11796674257812095
Epoch 1233: error is 1.210329035530692
Epoch 1234: error is 1.8397683864497671
Epoch 1235: error is 0.7987379914158793
Epoch 1236: error is 0.8426892280793808
Epoch 1237: error is 0.11659126737552247
Epoch 1238: error is 0.116163347518

Epoch 1581: error is 0.05970108595792492
Epoch 1582: error is 1.3088347851493958
Epoch 1583: error is 0.8787186385784876
Epoch 1584: error is 0.8081627534629032
Epoch 1585: error is 0.05942195803588821
Epoch 1586: error is 0.059320664531895245
Epoch 1587: error is 0.05921954437298017
Epoch 1588: error is 0.05911859675988025
Epoch 1589: error is 1.621718592407942
Epoch 1590: error is 1.2824872039199144
Epoch 1591: error is 0.059189627945947045
Epoch 1592: error is 0.88515712787443
Epoch 1593: error is 0.9116775806089363
Epoch 1594: error is 0.799656710080218
Epoch 1595: error is 0.9079973953042857
Epoch 1596: error is 0.058348154031597425
Epoch 1597: error is 0.058246917224413344
Epoch 1598: error is 0.058145909613798744
Epoch 1599: error is 0.885373085009175
Epoch 1600: error is 0.527357730654133
Epoch 1601: error is 0.057605121531712995
Epoch 1602: error is 0.057508752868301045
Epoch 1603: error is 1.398631615576757
Epoch 1604: error is 0.05712843545722687
Epoch 1605: error is 1.59224

Epoch 1942: error is 0.03315859099700069
Epoch 1943: error is 0.9204211100520037
Epoch 1944: error is 0.032942996449324795
Epoch 1945: error is 0.8894295494721952
Epoch 1946: error is 0.9171646849244282
Epoch 1947: error is 0.8656558086387505
Epoch 1948: error is 0.8902318903596774
Epoch 1949: error is 0.032858133602011085
Epoch 1950: error is 0.03279536646042437
Epoch 1951: error is 0.03273318609636804
Epoch 1952: error is 0.03267157827997584
Epoch 1953: error is 0.03261052916985496
Epoch 1954: error is 0.42495127888780887
Epoch 1955: error is 0.03249426725440337
Epoch 1956: error is 0.8885767978241059
Epoch 1957: error is 0.03264778009917012
Epoch 1958: error is 0.8685602307493663
Epoch 1959: error is 1.6927025413790167
Epoch 1960: error is 0.031992354607606896
Epoch 1961: error is 0.8502455001555265
Epoch 1962: error is 0.03194232298732668
Epoch 1963: error is 0.6448934653231634
Epoch 1964: error is 0.031827234998584406
Epoch 1965: error is 0.0317815414212264
Epoch 1966: error is 0.

Epoch 2253: error is 0.021643305928424915
Epoch 2254: error is 0.934395537238372
Epoch 2255: error is 0.021445172393413703
Epoch 2256: error is 0.9313389020104326
Epoch 2257: error is 0.021309052498918638
Epoch 2258: error is 0.9232938901927009
Epoch 2259: error is 0.02111730317260991
Epoch 2260: error is 0.8856354554586978
Epoch 2261: error is 0.021150249859524673
Epoch 2262: error is 0.9314932794014983
Epoch 2263: error is 0.8664938714638344
Epoch 2264: error is 0.9285431969886929
Epoch 2265: error is 0.9252956973193537
Epoch 2266: error is 0.02108368329552838
Epoch 2267: error is 0.5493859000557467
Epoch 2268: error is 0.8956701488862232
Epoch 2269: error is 0.020991336632913877
Epoch 2270: error is 0.02096449124295077
Epoch 2271: error is 0.9240465445991906
Epoch 2272: error is 1.8257406795304676
Epoch 2273: error is 0.020696306408946378
Epoch 2274: error is 0.020670429493721137
Epoch 2275: error is 0.8951613339787006
Epoch 2276: error is 0.020599414697404707
Epoch 2277: error is 0

Epoch 2582: error is 0.014448246944146308
Epoch 2583: error is 0.014432766380718885
Epoch 2584: error is 0.014417336855484505
Epoch 2585: error is 0.014401957639385466
Epoch 2586: error is 0.014386628017227395
Epoch 2587: error is 0.014371347287415065
Epoch 2588: error is 0.9250049028418523
Epoch 2589: error is 0.01440307377065139
Epoch 2590: error is 0.014386346316104758
Epoch 2591: error is 1.80960687425431
Epoch 2592: error is 0.014353124395948486
Epoch 2593: error is 0.014336627491747534
Epoch 2594: error is 0.014320205037296054
Epoch 2595: error is 0.014303855869973242
Epoch 2596: error is 0.01428757884904446
Epoch 2597: error is 0.014271372855239673
Epoch 2598: error is 0.014255236790340175
Epoch 2599: error is 0.014239169576773883
Epoch 2600: error is 0.01422317015721861
Epoch 2601: error is 0.014207237494213367
Epoch 2602: error is 0.014191370569777541
Epoch 2603: error is 0.01417556838503773
Epoch 2604: error is 0.014159829959861741
Epoch 2605: error is 0.014144154332500422
Ep

Epoch 2939: error is 0.9352288951457386
Epoch 2940: error is 0.01022916660068753
Epoch 2941: error is 0.010218977736039228
Epoch 2942: error is 0.010208830786999986
Epoch 2943: error is 0.010198725186237666
Epoch 2944: error is 0.01018866037554288
Epoch 2945: error is 0.9160103613980854
Epoch 2946: error is 0.01012466802964299
Epoch 2947: error is 0.01011567139958151
Epoch 2948: error is 0.010106697174712909
Epoch 2949: error is 0.010097745097728341
Epoch 2950: error is 0.0683414030539195
Epoch 2951: error is 0.010076113763407299
Epoch 2952: error is 0.010067239801054567
Epoch 2953: error is 0.010058386636869712
Epoch 2954: error is 0.01004955404033426
Epoch 2955: error is 0.010040741784670631
Epoch 2956: error is 0.010031949646781045
Epoch 2957: error is 0.010023177407188204
Epoch 2958: error is 0.010014424849975884
Epoch 2959: error is 0.9367164820256473
Epoch 2960: error is 0.010032702050983878
Epoch 2961: error is 0.010023181753268526
Epoch 2962: error is 0.010013694700923718
Epoch

Epoch 3309: error is 0.007911037630991288
Epoch 3310: error is 0.00790092176388897
Epoch 3311: error is 0.9460593397096679
Epoch 3312: error is 0.007940058044234597
Epoch 3313: error is 0.007929282176750294
Epoch 3314: error is 0.007918570644977065
Epoch 3315: error is 0.00790792274649177
Epoch 3316: error is 0.007897337787724092
Epoch 3317: error is 0.007886815083834018
Epoch 3318: error is 1.4410015766203255
Epoch 3319: error is 0.007956966699281547
Epoch 3320: error is 0.007945449855226756
Epoch 3321: error is 0.00793400744443394
Epoch 3322: error is 0.007922638632437954
Epoch 3323: error is 0.007911342595468942
Epoch 3324: error is 0.007900118520300746
Epoch 3325: error is 0.9388947124696853
Epoch 3326: error is 0.007802274079733864
Epoch 3327: error is 0.007792267417172424
Epoch 3328: error is 0.007782316750322395
Epoch 3329: error is 0.007772421481267779
Epoch 3330: error is 0.9022863617222103
Epoch 3331: error is 0.9451327925634746
Epoch 3332: error is 0.007976573479602563
Epoch

Epoch 3625: error is 0.006127378502271207
Epoch 3626: error is 0.006121408136748765
Epoch 3627: error is 0.0061154631710398245
Epoch 3628: error is 0.9591970638368703
Epoch 3629: error is 0.006032431393932654
Epoch 3630: error is 0.0060271268800189525
Epoch 3631: error is 0.9572546853803922
Epoch 3632: error is 0.0059689620820346356
Epoch 3633: error is 0.00596404113163721
Epoch 3634: error is 0.005959137371296558
Epoch 3635: error is 0.005954250632645113
Epoch 3636: error is 0.9411759948830767
Epoch 3637: error is 0.9251305072910309
Epoch 3638: error is 0.00594702876029707
Epoch 3639: error is 0.05827616956327753
Epoch 3640: error is 1.889295799968356
Epoch 3641: error is 0.005863231319838028
Epoch 3642: error is 0.005859306460022414
Epoch 3643: error is 0.005855386937295626
Epoch 3644: error is 0.005851472720586555
Epoch 3645: error is 0.005847563779181067
Epoch 3646: error is 0.005843660082717858
Epoch 3647: error is 0.005839761601184168
Epoch 3648: error is 0.9554185514034091
Epoch

Epoch 3945: error is 0.00489052280771132
Epoch 3946: error is 0.9268467239028718
Epoch 3947: error is 0.004953678851395651
Epoch 3948: error is 0.9235709784460353
Epoch 3949: error is 0.005035190431413532
Epoch 3950: error is 0.961336233265033
Epoch 3951: error is 0.5720029428077049
Epoch 3952: error is 0.9656402227543917
Epoch 3953: error is 0.0049112185196158
Epoch 3954: error is 0.004907462410880726
Epoch 3955: error is 0.00490371553228838
Epoch 3956: error is 0.004899977823134973
Epoch 3957: error is 0.004896249223272619
Epoch 3958: error is 0.9383862139291314
Epoch 3959: error is 0.004907293261357408
Epoch 3960: error is 0.004903529544812994
Epoch 3961: error is 0.004899774561124323
Epoch 3962: error is 0.9615208331374441
Epoch 3963: error is 0.0048738609000281066
Epoch 3964: error is 0.004870194427647177
Epoch 3965: error is 0.004866536614031296
Epoch 3966: error is 0.004862887403623686
Epoch 3967: error is 0.004859246741370845
Epoch 3968: error is 0.004855614572717297
Epoch 3969

Epoch 4313: error is 0.003987303909748522
Epoch 4314: error is 0.003983966766051978
Epoch 4315: error is 0.003980645287318467
Epoch 4316: error is 0.003977339326277332
Epoch 4317: error is 0.0039740487371762254
Epoch 4318: error is 0.0039707733757642005
Epoch 4319: error is 0.00396751309927517
Epoch 4320: error is 0.0039642677664114105
Epoch 4321: error is 0.0039610372373272895
Epoch 4322: error is 0.003957821373613246
Epoch 4323: error is 0.003954620038279992
Epoch 4324: error is 0.003951433095742863
Epoch 4325: error is 0.003948260411806315
Epoch 4326: error is 0.003945101853648667
Epoch 4327: error is 0.003941957289807098
Epoch 4328: error is 0.6215113541077064
Epoch 4329: error is 0.003927356497630742
Epoch 4330: error is 0.003924274371265803
Epoch 4331: error is 0.00392120565029002
Epoch 4332: error is 0.003918150210154239
Epoch 4333: error is 0.003915107927579843
Epoch 4334: error is 0.9615431949399152
Epoch 4335: error is 0.003921207047553547
Epoch 4336: error is 0.9543371045817

Epoch 4683: error is 0.0033112819076608637
Epoch 4684: error is 0.0033089186281894736
Epoch 4685: error is 0.003306562812391494
Epoch 4686: error is 0.003304214406954919
Epoch 4687: error is 0.9692993965350033
Epoch 4688: error is 0.003261294292697816
Epoch 4689: error is 0.003259196532446699
Epoch 4690: error is 0.00325710459399729
Epoch 4691: error is 0.0032550184355209545
Epoch 4692: error is 0.003252938015548085
Epoch 4693: error is 0.0032508632929647632
Epoch 4694: error is 0.9415585770063426
Epoch 4695: error is 0.0032847681544667883
Epoch 4696: error is 0.0032824837731630376
Epoch 4697: error is 0.0032802062893233183
Epoch 4698: error is 0.0032779356544439557
Epoch 4699: error is 0.0032756718204260474
Epoch 4700: error is 0.003273414739571903
Epoch 4701: error is 0.003271164364581428
Epoch 4702: error is 0.0032689206485484465
Epoch 4703: error is 0.0032666835449573304
Epoch 4704: error is 0.0032644530076793763
Epoch 4705: error is 0.003262228990969364
Epoch 4706: error is 0.0032

Epoch 4880: error is 0.0029843748244947886
Epoch 4881: error is 0.002982691884985379
Epoch 4882: error is 0.0029810112439072186
Epoch 4883: error is 0.0029793328936037924
Epoch 4884: error is 0.0029776568264600262
Epoch 4885: error is 0.002975983034902063
Epoch 4886: error is 0.002974311511396977
Epoch 4887: error is 0.002972642248452521
Epoch 4888: error is 0.0029709752386168128
Epoch 4889: error is 0.0029693104744781054
Epoch 4890: error is 0.0029676479486644832
Epoch 4891: error is 0.0029659876538436557
Epoch 4892: error is 0.0029643295827225715
Epoch 4893: error is 0.002962673728047351
Epoch 4894: error is 0.0029610200826028273
Epoch 4895: error is 0.002959368639212384
Epoch 4896: error is 0.0029577193907377487
Epoch 4897: error is 0.002956072330078608
Epoch 4898: error is 0.002954427450172473
Epoch 4899: error is 0.0029527847439943904
Epoch 4900: error is 0.0029511442045566136
Epoch 4901: error is 0.0029495058249085493
Epoch 4902: error is 0.0029478695981363014
Epoch 4903: error i