Skip to content

Commit

Permalink
Small binary RBM example added
Browse files Browse the repository at this point in the history
  • Loading branch information
jan authored and jan committed Apr 20, 2017
1 parent 7497bc6 commit 5536700
Showing 1 changed file with 120 additions and 0 deletions.
120 changes: 120 additions & 0 deletions examples/small_binary_RBM_MNIST.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
''' Example using a small BB-RBMs on the MNIST handwritten digit database.
:Version:
1.1.0
:Date:
20.04.2017
:Author:
Jan Melchior
:Contact:
JanMelchior@gmx.de
:License:
Copyright (C) 2017 Jan Melchior
This file is part of the Python library PyDeep.
PyDeep is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
'''
import numpy as numx
import pydeep.rbm.model as model
import pydeep.rbm.trainer as trainer
import pydeep.rbm.estimator as estimator

import pydeep.misc.io as io
import pydeep.misc.visualization as vis
import pydeep.misc.measuring as mea

# Set random seed (optional)
numx.random.seed(42)

# Input and hidden dimensionality
v1 = v2 = 28
h1 = h2 = 4

# Load data , get it from 'deeplearning.net/data/mnist/mnist.pkl.gz'
train_data = io.load_mnist("../../data/mnist.pkl.gz", True)[0]

# Training paramters
batch_size = 100
epochs = 39

# Create trainer and model
rbm = model.BinaryBinaryRBM(number_visibles=v1 * v2,
number_hiddens=h1 * h2,
data=train_data)
trainer = trainer.PCD(rbm, batch_size)

# Measuring time
measurer = mea.Stopwatch()

# Train model
print('Training')
print('Epoch\t\tRecon. Error\tLog likelihood \tExpected End-Time')
for epoch in range(1, epochs + 1):

# Shuffle training samples (optional)
train_data = numx.random.permutation(train_data)

# Loop over all batches
for b in range(0, train_data.shape[0], batch_size):
batch = train_data[b:b + batch_size, :]
trainer.train(data=batch, epsilon=0.05)

# Calculate Log-Likelihood, reconstruction error and expected end time every 10th epoch
if epoch % 10 == 0:
logZ = estimator.partition_function_factorize_h(rbm)
ll = numx.mean(estimator.log_likelihood_v(rbm, logZ, train_data))
re = numx.mean(estimator.reconstruction_error(rbm, train_data))
print('{}\t\t{:.4f}\t\t\t{:.4f}\t\t\t{}'.format(epoch, re, ll, measurer.get_expected_end_time(epoch, epochs)))
else:
print(epoch)

measurer.end()

# Print end/training time
print("End-time: \t{}".format(measurer.get_end_time()))
print("Training time:\t{}".format(measurer.get_interval()))

# Calculate true partition function
logZ = estimator.partition_function_factorize_h(rbm, batchsize_exponent=h1, status=False)
print("True Partition: {} (LL: {})".format(logZ, numx.mean(estimator.log_likelihood_v(rbm, logZ, train_data))))

# Approximate partition function by AIS (tends to overestimate)
logZ_approx_ = estimator.annealed_importance_sampling(rbm)[0]
print(
"AIS Partition: {} (LL: {})".format(logZ_approx_, numx.mean(estimator.log_likelihood_v(rbm, logZ_approx_, train_data))))

# Approximate partition function by reverse AIS (tends to underestimate)
logZ_approx_up = estimator.reverse_annealed_importance_sampling(rbm, data=train_data)[0]
print("reverse AIS Partition: {} (LL: {})".format(logZ_approx_up, numx.mean(
estimator.log_likelihood_v(rbm, logZ_approx_up, train_data))))

# Reorder RBM features by average activity decreasingly
reordered_rbm = vis.reorder_filter_by_hidden_activation(rbm, train_data)

# Display RBM parameters
vis.imshow_standard_rbm_parameters(reordered_rbm, v1, v2, h1, h2)

# Sample some steps and show results
samples = vis.generate_samples(rbm, train_data[0:30], 30, 1, v1, v2, False, None)
vis.imshow_matrix(samples, 'Samples')

# Display results
vis.show()

0 comments on commit 5536700

Please sign in to comment.