In [64]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import random
import time
import math
import argparse
import matplotlib.pyplot as plt
import pickle
import os

if not os.path.exists('../plot'):
    os.makedirs('../plot')
if not os.path.exists('../dump'):
    os.makedirs('../dump')

seed = 10417617 # do not change or remove

### helper function

In [65]:
def binary_data(inp):
    return (inp > 0.5) * 1.

In [58]:
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

In [59]:
def shuffle_corpus(data):
    random_idx = np.random.permutation(len(data))
    return data[random_idx]

### model

In [60]:
class RBM:
    """
    The RBM base class
    """
    def __init__(self, n_visible, n_hidden, k, lr=0.01, minibatch_size=1):
        """
        n_visible, n_hidden: dimension of visible and hidden layer
        k: number of gibbs sampling steps
        k: number of gibbs sampling steps
        lr: learning rate
        vbias, hbias: biases for visible and hidden layer, initialized as zeros
            in shape of (n_visible,) and (n_hidden,)
        W: weights between visible and hidden layer, initialized using Xavier,
            same as Assignment1-Problem1, in shape of (n_hidden, n_visible)
        Do np.random.seed(seed) before you call any np.random.xx()
        """
        self.n_visible = n_visible
        self.n_hidden = n_hidden
        self.k = k
        self.lr = lr
        
        self.vbias = np.zeros(n_visible,)
        self.hbias = np.zeros(n_hidden,)
        self.W = np.random.normal(0, np.sqrt(6.0/(self.n_hidden+self.n_visible)), (n_hidden, n_visible))


    def h_v(self, v):
        """
        Calculates hidden vector distribution P(h=1|v)
        v: visible vector in shape (N, n_visible)
        return P(h=1|v) in shape (N, n_hidden)
        N is the batch size
        """
        return sigmoid(self.hbias + v.dot(self.W.T))


    def sample_h(self, v):
        """
        Sample hidden vector given distribution P(h=1|v)
        v: visible vector in shape (N, n_visible)
        return hidden vector and P(h=1|v) both in shape (N, n_hidden)
        Do np.random.seed(seed) before you call any np.random.xx()
        """
        p_h = self.h_v(v)
        sample = np.random.binomial(1,p_h)
        return sample, p_h

    def v_h(self, h):
        """
        Calculates visible vector distribution P(v=1|h)
        h: hidden vector in shape (N, n_hidden)
        return P(v=1|h) in shape (N, n_visible)
        """
        return sigmoid(self.vbias + h.dot(self.W))

    def sample_v(self, h):
        """
        Sample visible vector given distribution P(h=1|v)
        h: hidden vector in shape (N, n_hidden)
        return visible vector and P(v=1|h) both in shape (N, n_visible)
        Do np.random.seed(seed) before you call any np.random.xx()
        """
        p_v = self.v_h(h)
        sample = np.random.binomial(1,p_v)
        return sample, p_v

    def gibbs_k(self, v, k=0):
        """
        The v (CD-k) procedure
        v: visible vector, in (N, n_visible)
        k: number of gibbs sampling steps
        return (h0, v0, h_sample, v_sample, prob_h, prob_v)
        h0: initial hidden vector sample, in (N, n_hidden)
        v0: the input v, in (N, n_visible)
        h_sample: the hidden vector sample after k steps, in (N, n_hidden)
        v_sample: the visible vector samplg after k steps, in (N, n_visible)
        prob_h: P(h=1|v) after k steps, in (N, n_hidden)
        prob_v: P(v=1|h) after k steps, in (N, n_visible)
        (Refer to Fig.1 in the handout if unsure on step counting)
        """
        if (k==0): k=self.k
        v0 = v
        h0, p_h0 = self.sample_h(v)
        vi = v0
        hi = h0
        for i in range(k):
            vi, p_vi = self.sample_v(hi)
            hi, p_hi = self.sample_h(vi)
        return (h0, v0, hi, vi, p_hi, p_vi)
        
    def update(self, X):
        """
        Updates RBM parameters with data X
        X: in (N, n_visible)
        Compute all gradients first before updating(/making change to) the
        parameters(W and biases).
        """
        N = X.shape[0]
        h0, v0, hi, vi, p_hi, p_vi = self.gibbs_k(X)
        grad_W = (self.h_v(v0).T.dot(v0) - p_hi.T.dot(vi)) / N
        grad_b = (self.h_v(v0) - p_hi).sum(0) / N
        grad_c = (v0 - vi).sum(0) / N
        
        self.W += self.lr * grad_W
        self.vbias += self.lr * grad_c
        self.hbias += self.lr * grad_b
        return grad_W, grad_b, grad_c
        

    def eval(self, X):
        """
        Computes reconstruction error, set k=1 for reconstruction.
        X: in (N, n_visible)
        Return the mean reconstruction error as scalar
        """
        N = X.shape[0]
        pred = self.sample_v(self.sample_h(X)[0])[0]
        return np.sqrt(((pred - X) ** 2).sum(1)).sum() / N

### main

In [61]:
class Args():
    def __init__(self):
        self.max_epoch = 1000
        self.k = 5
        self.lr = 0.1
        self.train = '../data/digitstrain.txt'
        self.valid = '../data/digitsvalid.txt'
        self.test = "../data/digitstest.txt"
        self.n_hidden = 100

In [62]:
if __name__ == "__main__":
    np.seterr(all='raise')

    args = Args()

    train_data = np.genfromtxt(args.train, delimiter=",")
    train_X = train_data[:, :-1]
    train_Y = train_data[:, -1]
    train_X = binary_data(train_X)
    valid_data = np.genfromtxt(args.valid, delimiter=",")
    valid_X = valid_data[:, :-1]
    valid_X = binary_data(valid_X)
    valid_Y = valid_data[:, -1]
    test_data = np.genfromtxt(args.test, delimiter=",")
    test_X = test_data[:, :-1]
    test_X = binary_data(test_X)
    test_Y = test_data[:, -1]
    
    model = RBM(784, 100, 3, lr=0.01, minibatch_size=1)
    
    for i in range(args.max_epoch):
        model.update(train_X)
        print(model.eval(train_X))

19.21485436442035
18.936579750038756
18.721684651205788
18.532417584174706
18.33817104671222
18.190847781814284
18.01289235299193
17.834181516384174
17.65555825634987
17.467283705746958
17.27886120943937
17.06837284962475
16.876782949266328
16.665168428185492
16.450321738236322
16.23983268854855
16.020356215932946
15.818589473387513
15.615268719617752
15.434046191609948
15.283650879302726
15.10022081768923
14.960057488343942
14.808570036205701
14.689271335113018
14.562643561633095
14.46763081120381
14.361738490569493
14.252996772265144
14.182919468761199
14.092223648138
14.0086752773396
13.94546531145332
13.880815243097066
13.808475375729802
13.762205865405203
13.693575368612011
13.641398627084774
13.60978785542856
13.550691469725049
13.514428722730559
13.456999320910908
13.434003752416269
13.389581223572062
13.362574954891434
13.320149354763004
13.288641067349303
13.26879413995269
13.218451979903445
13.190325114865754
13.184364394794631
13.147568351697817
13.13382148588502
13.09251921

10.495151958962404
10.475541628809925
10.477373300859622
10.473078426426676
10.476611118001024
10.459308673422392
10.460087866833481
10.454479828365928
10.452501031913327
10.45139428161134
10.463295759142865
10.451986591620148
10.450362963562188
10.448239024011315
10.436022271390055
10.4407703503814
10.43767821572865
10.42769997132925
10.432406781708792
10.437938042271123
10.413373321772136
10.430285231267389
10.422787747657146
10.432412246073001
10.414584685425988
10.406470274026077
10.409487267157315
10.40592564720554
10.411702476996073
10.403616109999506
10.405123444247007
10.398375339716829
10.397300176043961
10.393406051604476
10.388932049362877
10.37506795659541
10.38088649432343
10.374908818427844
10.377838161905716
10.378383922425929
10.365802633747226
10.367390817985157
10.373954092898511
10.372354284380854
10.354201446810942
10.35841150575838
10.354387676116847
10.343185888673581
10.34920721150377
10.345286334504003
10.35320149099262
10.336155985725785
10.334592725958291
10.3

9.73253657704483
9.737800106531884
9.729773267544466
9.732292027035937
9.737849479082605
9.735892412511122
9.730164721257422
9.723010429629367
9.733389647888561
9.729660959225628
9.725231427431133
9.725083492672539
9.710118277014288
9.720671373077344
9.714410762359083
9.709318416048106
9.717065089809612
9.71776923439562
9.7182433330201
9.708064572218968
9.722847162752446
9.71636830382166
9.71891856190828
9.6920621890837
9.710296068873024
9.711183578528924
9.703405841917402
9.702312481176897
9.70036949213695
9.69106942876061
9.709108244532537
9.697972704746809
9.703632677702313
9.708399393543043
9.693437215788984
9.709460450980858
9.69782033090462
9.689835924263612
9.700490992982877
9.693550842016885
9.694068740634622
9.688122072386072
9.679211526977
9.680249379166291
9.682483965328013
9.688453757207824
9.696272831067368
9.677868471181004
9.68843454673912
9.67968766824521
9.687530564100378
9.68621413597906
9.683165651022085
9.673497106297306
9.691383135578029
9.690657340202199
9.6723140