-
Notifications
You must be signed in to change notification settings - Fork 0
/
dbn_example.py
115 lines (103 loc) · 4.44 KB
/
dbn_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Example of usage: Deep Belief Network
#
# Author: Alejandro Pozas-Kerstjens
# Requires: numpy for numerics
#           pytorch as ML framework
#           torchvision for the MNIST dataset
#           matplotlib for plots
#           imageio for output export
# Last modified: Jun, 2018
import imageio
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
from DBN import DBN
from optimizers import SGD
from samplers import PersistentContrastiveDivergence
from torchvision import datasets
#------------------------------------------------------------------------------
# Parameter choices
#------------------------------------------------------------------------------
# Architecture: number of nodes on each hidden layer
hidden_layers = [30, 30]

# Pre-training settings
pretrain_lr = 0.01       # learning rate
weight_decay = 0.0001    # weight decay
momentum = 0.95          # momentum
pretrain_epochs = 20     # number of epochs
k = 5                    # steps of contrastive divergence

# Fine-tuning settings
finetune_lr = 0.0001     # learning rate
finetune_epochs = 30     # number of epochs

# Settings shared by both training phases
batch_size = 20          # batch size
gpu = False              # run on the GPU instead of the CPU

# Output settings
continuous_out = True    # whether we want continuous outputs or not
sample_copies = 5        # number of samples taken from the hidden
                         # representation of each datapoint
#------------------------------------------------------------------------------
# Data preparation
#------------------------------------------------------------------------------
# Device every tensor of the example is moved to.
device = torch.device('cuda' if gpu else 'cpu')

# Raw MNIST images, downloaded into ./mnist on first use. The `data`
# attribute is read instead of `train_data`/`test_data`, which torchvision
# keeps only as deprecated aliases that emit a warning.
data = datasets.MNIST('mnist',
                      train=True,
                      download=True).data.type(torch.float)
test = datasets.MNIST('mnist',
                      train=False).data.type(torch.float)

# Flatten each image into one row and divide by 255 (presumably mapping
# 8-bit pixel values into [0, 1]). The row length is inferred by `view`
# rather than hard-coded.
data = (data.view(data.shape[0], -1) / 255).to(device)
test = (test.view(test.shape[0], -1) / 255).to(device)

# Number of visible units = length of one flattened image.
vis = data.shape[1]
# -----------------------------------------------------------------------------
# Construct DBN
# -----------------------------------------------------------------------------
# A model file in the working directory means training can be skipped.
pre_trained = os.path.isfile('DBN.h5')

# Sampler and optimizer handed to the DBN; both are configured with the
# pre-training hyperparameters.
sampler = PersistentContrastiveDivergence(hidden_activations=True, k=k)
_optimizer_kwargs = {
    'learning_rate': pretrain_lr,
    'momentum': momentum,
    'weight_decay': weight_decay,
}
optimizer = SGD(**_optimizer_kwargs)

# Assemble the network itself.
_dbn_kwargs = {
    'n_visible': vis,
    'hidden_layer_sizes': hidden_layers,
    'sample_copies': sample_copies,
    'sampler': sampler,
    'optimizer': optimizer,
    'continuous_output': continuous_out,
    'device': device,
}
dbn = DBN(**_dbn_kwargs)

# Restore the saved model when one was found.
if pre_trained:
    dbn.load_model('DBN.h5')
# -----------------------------------------------------------------------------
# Training
# -----------------------------------------------------------------------------
# Train from scratch only when no saved model was loaded: pre-train, then
# fine-tune, then write the result to DBN.h5.
# NOTE(review): nesting reconstructed from a whitespace-stripped source;
# all three calls are assumed to belong to this branch -- confirm.
if not pre_trained:
    _pretrain_kwargs = {
        'input_data': data,
        'epochs': pretrain_epochs,
        'batch_size': batch_size,
        'test': test,
    }
    dbn.pretrain(**_pretrain_kwargs)

    _finetune_kwargs = {
        'input_data': data,
        'lr': finetune_lr,
        'epochs': finetune_epochs,
        'batch_size': batch_size,
    }
    dbn.finetune(**_finetune_kwargs)

    dbn.save_model('DBN.h5')
# -----------------------------------------------------------------------------
# Plotting
# -----------------------------------------------------------------------------
print('#########################################')
print('# Generating samples #')
print('#########################################')
# Build an animated GIF: every frame is a grid of generated images, and
# consecutive frames come from consecutive sampling steps of the top RBM.
top_RBM = dbn.gen_layers[-1]
grid = 5     # frames show a grid x grid arrangement of samples
side = 28    # each generated image is side x side pixels

# State of the sampling chains in the top RBM, one row per displayed
# sample; it starts at all zeros and is updated on every frame.
zero = torch.zeros(grid * grid, len(top_RBM.vbias)).to(device)
images = [np.zeros((grid * side, grid * side))]   # first frame is blank
for _ in range(200):
    # The sampler is switched to non-continuous output for the top RBM
    # and the intermediate layers.
    sampler.continuous_output = False
    # One hidden-then-visible pass through the top RBM advances the chains.
    zero = sampler.get_h_from_v(zero, top_RBM.W, top_RBM.hbias)
    zero = sampler.get_v_from_h(zero, top_RBM.W, top_RBM.vbias)
    # Propagate the current state down through the intermediate layers.
    sample = zero
    for gen_layer in reversed(dbn.gen_layers[1:-1]):
        sample = sampler.get_v_from_h(sample, gen_layer.W, gen_layer.vbias)
    # Only the last step, through the first layer, uses the output mode
    # requested by the user.
    sampler.continuous_output = continuous_out
    sample = sampler.get_v_from_h(sample,
                                  dbn.gen_layers[0].W,
                                  dbn.gen_layers[0].vbias)
    datas = sample.data.cpu().numpy().reshape((grid * grid, side, side))
    # Tile the samples into one frame; sample `row + grid * col` lands in
    # cell (row, col). The loop indices are deliberately not called `k`, so
    # the contrastive-divergence hyperparameter `k` is not overwritten.
    image = np.zeros((grid * side, grid * side))
    for row in range(grid):
        for col in range(grid):
            image[side * row:side * (row + 1),
                  side * col:side * (col + 1)] = datas[row + grid * col, :, :]
    images.append(image)
imageio.mimsave('DBN_sample.gif', images, duration=0.1)