In [1]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import CRBM
import input_manipulation

"""
	This file stores the code for initializing the weights of the RNN-RBM. We initialize the parameters of the RBMs by
	training them directly on the data with CD-k. We initialize the parameters of the RNN with small weights.
"""

num_epochs = 100 #The number of epochs to train the CRBM
lr = 0.01 #The learning rate for the CRBM


num_conv_filters = 8
conv_strides = 2
span = 123
num_timesteps = 32

size_conv_filters = 4
hidden_width = np.floor((num_timesteps-size_conv_filters)/conv_strides) + 1
n_hidden_recurrent = 100


In [12]:
songs = input_manipulation.get_songs('Game_Music_Midi')

x  = tf.placeholder(tf.float32, [num_timesteps, span], name="x") #The placeholder variable that holds our data

#Testing
batch_size = tf.placeholder(tf.int64, [1], name="batch_size")
#parameters of CRBM
W   = tf.Variable(tf.random_normal([size_conv_filters, span, 1, num_conv_filters], 0.01), name="W") #The weight matrix of the RBM
bh  = tf.Variable(tf.zeros([hidden_width,num_conv_filters], tf.float32), name="bh") #The RNN -> RBM hidden bias vector
bv  = tf.Variable(tf.zeros([num_timesteps, span], tf.float32), name="bv")#The RNN -> RBM visible bias vector

#parameters related to RNN
Wuh = tf.Variable(tf.random_normal([n_hidden_recurrent, int(hidden_width*num_conv_filters)], 0.0001), name="Wuh")  #The RNN -> RBM hidden weight matrix
Wuv = tf.Variable(tf.random_normal([n_hidden_recurrent, int(num_timesteps*span)], 0.0001), name="Wuv") #The RNN -> RBM visible weight matrix
Wvu = tf.Variable(tf.random_normal([int(num_timesteps*span), n_hidden_recurrent], 0.0001), name="Wvu") #The data -> RNN weight matrix
Wuu = tf.Variable(tf.random_normal([n_hidden_recurrent, n_hidden_recurrent], 0.0001), name="Wuu") #The RNN hidden unit weight matrix
bu  = tf.Variable(tf.zeros([1, n_hidden_recurrent],  tf.float32), name="bu")   #The RNN hidden unit bias vector
u0  = tf.Variable(tf.zeros([1, n_hidden_recurrent], tf.float32), name="u0") #The initial state of the RNN

#The RBM bias vectors. These matrices will get populated during rnn-rbm training and generation
BH_t = tf.Variable(tf.ones([hidden_width,num_conv_filters],  tf.float32), name="BH_t")
BV_t = tf.Variable(tf.ones([num_timesteps, span],  tf.float32), name="BV_t")

#Build the RBM optimization
saver = tf.train.Saver()

100%|██████████| 45/45 [00:01<00:00, 32.49it/s]


In [13]:
song = songs[0][0,:,:]
song.shape

(32, 123)

In [14]:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

array([[0., 0., 0., ..., 1., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       [0., 1., 0., ..., 1., 1., 0.],
       ...,
       [0., 0., 1., ..., 1., 0., 1.],
       [0., 0., 1., ..., 1., 1., 1.],
       [1., 0., 0., ..., 0., 0., 1.]], dtype=float32)

In [24]:
gs = CRBM.gibbs_sample(x, W, bv, bh, 10)
x_sample = sess.run(gs, feed_dict={x:song})

In [25]:
x_sample

array([[0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 1., 0., ..., 0., 1., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 1.],
       [0., 1., 1., ..., 0., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)

In [29]:
sum(sum(song-x_sample))

-2052.0

In [35]:
h_sample = CRBM.crbm_inference(x_sample, W, bh)
h = CRBM.crbm_inference(song, W, bh)

In [36]:
sess.run(h_sample)
sess.run(h)

array([[ 0.71566916,  0.18752742, -0.34768224, -0.34198058,  1.2602948 ,
         0.4239071 ,  0.4831197 ,  0.83092785],
       [ 0.71566916,  0.18752742, -0.34768224, -0.34198058,  1.2602948 ,
         0.4239071 ,  0.4831197 ,  0.83092785],
       [ 0.71566916,  0.18752742, -0.34768224, -0.34198058,  1.2602948 ,
         0.4239071 ,  0.4831197 ,  0.83092785],
       [ 3.0173056 ,  2.3149362 ,  0.97044605,  1.5812275 , -1.7198399 ,
        -0.57958686, -0.58893   ,  2.6540585 ],
       [ 5.1004767 ,  1.4820268 ,  0.07543936,  0.64307356,  6.463602  ,
        -1.6956797 ,  0.6888623 ,  0.04951668],
       [ 2.7088323 ,  1.2094274 ,  2.693478  , -2.2750654 , -1.028224  ,
        -4.0044336 , -1.0212129 , -5.4219894 ],
       [ 4.1201043 , -0.6346062 , -0.85693854,  0.04538471,  1.8206699 ,
         1.648764  , -1.9534055 ,  2.4111166 ],
       [ 2.2835813 ,  4.3705726 , -1.2843188 ,  1.5207521 ,  2.7914424 ,
        -0.35044503, -0.01205558, -0.886302  ],
       [-0.56505626,  1.4670893 

In [38]:
fc = CRBM.free_energy(x, h, W, bv, bh) - CRBM.free_energy(x_sample, h_sample, W, bv, bh)

In [54]:
W_ = tf.multiply(-lr,tf.gradients(fc,W))
W_ = sess.run(tf.reshape(W_,W.shape))

In [55]:
W_

array([[[[-2.9999999e-02, -2.9999999e-02, -2.3373613e-02, ...,
          -2.9999999e-02, -2.9999999e-02, -2.9999999e-02]],

        [[-1.3000000e-01, -1.3000000e-01, -9.3374066e-02, ...,
          -1.3000000e-01, -1.3000000e-01, -1.3000000e-01]],

        [[-4.9999997e-02, -4.9999997e-02, -3.9999999e-02, ...,
          -4.9999997e-02, -4.9999997e-02, -4.9999997e-02]],

        ...,

        [[-4.9999997e-02, -4.9999997e-02, -4.9999997e-02, ...,
          -4.9999997e-02, -4.9999997e-02, -4.9999997e-02]],

        [[-1.1000000e-01, -1.1000000e-01, -8.3374061e-02, ...,
          -1.1000000e-01, -1.1000000e-01, -1.1000000e-01]],

        [[-1.2000000e-01, -1.2000000e-01, -8.3374061e-02, ...,
          -1.2000000e-01, -1.2000000e-01, -1.2000000e-01]]],


       [[[-3.9999999e-02, -3.9999999e-02, -3.3373613e-02, ...,
          -3.9999999e-02, -3.9999999e-02, -3.9999999e-02]],

        [[-9.9999998e-03, -9.9999998e-03, -4.4654871e-07, ...,
          -9.9999988e-03, -9.9999998e-03, -9.9999998e

In [47]:
size_bt = tf.cast(tf.shape(song[1])[0],tf.float32)

In [48]:
size_bt

<tf.Tensor 'Cast:0' shape=() dtype=float32>

In [50]:
sess.run(size_bt)

123.0

In [51]:
W_ = tf.Print(W_,[W_])

In [52]:
W_

<tf.Tensor 'Print:0' shape=(4, 123, 1, 8) dtype=float32>

In [53]:
sess.run(W_)

array([[[[-2.9999999e-02, -2.9999999e-02, -2.3373613e-02, ...,
          -2.9999999e-02, -2.9999999e-02, -2.9999999e-02]],

        [[-1.3000000e-01, -1.3000000e-01, -9.3374066e-02, ...,
          -1.3000000e-01, -1.3000000e-01, -1.3000000e-01]],

        [[-4.9999997e-02, -4.9999997e-02, -3.9999999e-02, ...,
          -4.9999997e-02, -4.9999997e-02, -4.9999997e-02]],

        ...,

        [[-4.9999997e-02, -4.9999997e-02, -4.9999997e-02, ...,
          -4.9999997e-02, -4.9999997e-02, -4.9999997e-02]],

        [[-1.1000000e-01, -1.1000000e-01, -8.3374061e-02, ...,
          -1.1000000e-01, -1.1000000e-01, -1.1000000e-01]],

        [[-1.2000000e-01, -1.2000000e-01, -8.3374061e-02, ...,
          -1.2000000e-01, -1.2000000e-01, -1.2000000e-01]]],


       [[[-3.9999999e-02, -3.9999999e-02, -3.3373613e-02, ...,
          -3.9999999e-02, -3.9999999e-02, -3.9999999e-02]],

        [[-9.9999998e-03, -9.9999998e-03, -4.4654871e-07, ...,
          -9.9999988e-03, -9.9999998e-03, -9.9999998e