In [1]:
import math
import copy
import numpy as np
import pickle
from RBM import *
import matplotlib.pyplot as plt

In [2]:
class DBN:
    """Deep Belief Network: a stack of RBMs trained greedily, one layer at a time.

    NOTE(review): ``RBM`` comes from the project ``RBM`` module. Its constructor is
    assumed to be ``RBM(n_visible, n_hidden, k, lr=...)``. That is consistent with
    how the layers are chained below, but confirm against RBM.py.
    """

    def __init__(self, n_v, layers, k=1, lr=0.01):
        """
        Initialization
        Args:
            n_v: the visible layer dimension. Stored for reference only; the
                visible width actually used by generate() is ``layers[-1]``.
            layers: a list of layer widths ordered from the top hidden layer
                down to the visible layer, e.g. [100, 100, 100, 784]. One RBM is
                created for each adjacent pair of widths.
            k: the number of Gibbs sampling steps handed to every RBM
            lr: learning rate handed to every RBM (default matches the value
                that used to be hard-coded)
        """
        self.n_v = n_v
        self.layer_n = layers
        self.k = k
        self.lr = lr

        # Build RBM(layers[i+1], layers[i]) for each adjacent pair, then reverse
        # so that self.layer[0] is the bottom RBM (widest, visible side) and
        # self.layer[-1] is the top RBM.
        self.layer = [RBM(self.layer_n[i + 1], self.layer_n[i], k, lr=lr)
                      for i in range(len(self.layer_n) - 1)]
        self.layer.reverse()

    def train(self, X, max_epoch=None):
        """
        Greedy layer-wise training, starting with the bottom RBM.
        Args:
            X: training data for the bottom RBM, one sample per row.
            max_epoch: number of ``update`` calls made on each RBM. When None,
                falls back to the global ``args.max_epoch``. The original
                notebook read that global unconditionally.
        Raises:
            ValueError: if max_epoch is None and no global ``args`` is defined.
        """
        if max_epoch is None:
            # Legacy fallback: keep old call sites (`model.train(X)`) working.
            legacy_args = globals().get("args")
            if legacy_args is None:
                raise ValueError(
                    "max_epoch must be passed when no global `args` is defined")
            max_epoch = legacy_args.max_epoch

        for layer in self.layer:
            for _ in range(max_epoch):
                layer.update(X)
            # The first value returned by sample_h (presumably the hidden sample)
            # becomes the training input of the next RBM up the stack.
            X = layer.sample_h(X)[0]

    def generate(self, n_sample=100, k=1000):
        """
        Sample generation
        Args:
            n_sample: number of samples to generate. A non-scalar value falls
                back to 100 samples. The original notebook passed the training
                matrix here, and the old implementation always produced 100 rows.
            k: number of Gibbs steps run in the top RBM
        Returns:
            the result of the final downward ``sample_v`` pass, started from
            ``n_sample`` random binary rows of width ``layers[-1]``.
        """
        if np.ndim(n_sample) != 0:
            n_sample = 100
        n_sample = int(n_sample)

        # Start from uniform random binary visible vectors and propagate them
        # upward to the input of the top RBM.
        v = np.random.choice([0, 1], size=(n_sample, self.layer_n[-1]))
        for layer in self.layer[:-1]:
            v = layer.sample_h(v)[0]
        # Run k Gibbs steps in the top RBM and keep its 4th return value
        # (presumably the visible-side sample -- confirm in RBM.gibbs_k).
        _, _, _, v, _, _ = self.layer[-1].gibbs_k(v, k)

        # Propagate back down through the lower RBMs, top to bottom. Slicing
        # already yields a new list, so self.layer is not reordered. The original
        # deep-copied these RBMs. That only matters if sample_v mutates the RBM,
        # which it presumably does not -- confirm in RBM.py.
        for layer in reversed(self.layer[:-1]):
            v = layer.sample_v(v)[0]
        return v

In [3]:
class Args():
    """Experiment configuration: a lightweight stand-in for parsed CLI arguments."""

    def __init__(self):
        # (attribute, default) pairs, applied in this order as instance attributes.
        settings = (
            ("max_epoch", 1000),  # update() calls per RBM in DBN.train
            ("k", 5),             # Gibbs steps, passed to DBN(..., k=args.k)
            ("lr", 0.1),          # NOTE(review): not used by the cells shown here
            ("train", '../data/digitstrain.txt'),
            ("valid", '../data/digitsvalid.txt'),
            ("test", "../data/digitstest.txt"),
            ("n_hidden", 100),    # NOTE(review): not used by the cells shown here
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)

In [4]:
if __name__ == "__main__":
    # Raise on floating-point problems (overflow, divide-by-zero, ...) instead
    # of silently producing inf/nan.
    np.seterr(all='raise')
    # Seed NumPy's global RNG (used by DBN.generate, and presumably by RBM)
    # so that runs are reproducible.
    np.random.seed(0)

    # DBN.train reads the global `args.max_epoch` when no epoch count is passed,
    # so this name must stay module-level.
    args = Args()

    def _split_features_labels(data):
        """Split a loaded CSV matrix into features (passed through binary_data,
        presumably a binarization helper from the RBM module) and the
        last-column labels."""
        return binary_data(data[:, :-1]), data[:, -1]

    train_data = np.genfromtxt(args.train, delimiter=",")
    train_X, train_Y = _split_features_labels(train_data)
    valid_data = np.genfromtxt(args.valid, delimiter=",")
    valid_X, valid_Y = _split_features_labels(valid_data)
    test_data = np.genfromtxt(args.test, delimiter=",")
    test_X, test_Y = _split_features_labels(test_data)

    model = DBN(784, [100, 100, 100, 784], k=args.k)

    model.train(train_X)
    # generate()'s first parameter is a sample count. The original call passed
    # train_X there by mistake.
    samples = model.generate(n_sample=100)
    kkk = samples  # legacy name; later cells may still refer to it