In [1]:
import tensorflow as tf
import numpy as np
import midi_tools
import data_tools

In [2]:
# --- data / training parameters ---
# source material
piece = './pathetique_2_format0.mid'  # midi file that is read into chords
quantization = 0.05                   # time quantization, in seconds, used when reading the midi file into chords

# sizes
n_notes = 128    # number of notes; also the size of the last axis of the network input/output
go_time = 50     # time index from which the fragment is learned (the loss is computed from here on)
batch_size = 10  # batch size handed to the batch generator

In [3]:
# read the midi file into chords on the `quantization` time grid
data = midi_tools.chordify(
    piece,
    quantization=quantization,
    num_notes=n_notes,
)

In [4]:
# generators over `data`:
#   sampler -- advanced with next() to obtain a single sample
#   batcher -- advanced with next() to obtain a training batch (built with batch_size)
sampler = data_tools.sample_generator(
    data,
    go_time,
)
batcher = data_tools.batch_generator(
    data,
    go_time,
    batch_size,
)

In [5]:
# sanity check of the chosen time grid on the piece
# (presumably writes 'quantization_test.mid' for listening -- confirm in midi_tools)
midi_tools.test_quantization(
    piece,
    name='quantization_test.mid',
    quantization=quantization,
)

In [6]:
# draw one sample from the sampler and convert it to midi as a check
test_sample = next(sampler)
midi_tools.chords2midi(test_sample, name='sample_test.mid', quantization=quantization)

In [7]:
# network hyperparameters
depth = 2  # how many rnn cells are stacked
n_1 = 128  # width of the first (fully connected) layer
n_2 = 128  # number of units in each rnn cell

In [8]:
# build graph
# Next-step prediction network: linear layer -> stacked LSTM -> linear layer,
# with a sigmoid on top for prediction and a sigmoid cross-entropy loss for training.
# NOTE(review): checkpoint variable names presumably depend on the scope names and
# creation order below -- avoid reordering if existing checkpoints must stay loadable.
with tf.variable_scope('inputs'):
    # shapes are [batch_size, seq_length, n_notes]; batch size and sequence length are left dynamic (None)
    x = tf.placeholder(dtype=tf.float32, shape=[None, None, n_notes], name='x') # input sequence
    y = tf.placeholder(dtype=tf.float32, shape=[None, None, n_notes], name='y') # target sequence (the training cell feeds the input shifted by one step)

with tf.variable_scope('network'):
    # first layer: fully connected with n_1 outputs and no activation (activation_fn=None)
    x_1 = tf.contrib.layers.fully_connected(inputs=x, num_outputs=n_1, activation_fn=None)

    # rnn layers: `depth` BasicLSTMCells with n_2 units each, wrapped in one MultiRNNCell
    cell = tf.contrib.rnn.MultiRNNCell(cells=[tf.contrib.rnn.BasicLSTMCell(n_2) for _ in range(depth)])

    # initial state for rnn: zero state sized from the run-time batch dimension of x;
    # the prediction cell feeds this tensor structure to continue from a previous state
    initial_state = cell.zero_state(batch_size=tf.shape(x)[0], dtype=tf.float32) # zero state unless overridden via feed_dict

    # unroll the rnn over the sequence; `state` is the rnn state returned alongside the outputs
    x_2, state = tf.nn.dynamic_rnn(cell, x_1, initial_state=initial_state)

    # final layer: fully connected with n_notes outputs and no activation, used as logits
    z = tf.contrib.layers.fully_connected(inputs=x_2, num_outputs=n_notes, activation_fn=None)

with tf.variable_scope('predict'):
    # sigmoid of the logits
    p = tf.nn.sigmoid(z, name='p')

with tf.variable_scope('loss'):
    # only compute loss from go_time
    logits = z[:,go_time:,:] # drop the first go_time steps of the output
    labels = y[:,go_time:,:]
    # loss: mean of the element-wise sigmoid cross-entropy over the kept steps
    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels), name='loss')

with tf.variable_scope('train'):
    # Adam optimizer constructed with no arguments (library defaults)
    train = tf.train.AdamOptimizer().minimize(loss)

In [None]:
# start session
sess = tf.Session()
saver = tf.train.Saver()

# look up the checkpoint state recorded under 'tf_checkpoints'
ckpt = tf.train.get_checkpoint_state('tf_checkpoints')

# without a usable checkpoint start from freshly initialized variables,
# otherwise load the variables from the checkpoint
if not (ckpt and ckpt.model_checkpoint_path):
    sess.run(tf.global_variables_initializer())
    print('model initialized')
else:
    saver.restore(sess, ckpt.model_checkpoint_path)
    print('model restored')

model initialized


In [None]:
# train
n_steps = 100000                               # total number of optimization steps
report_every = 100                             # print the loss and write a checkpoint every this many steps
checkpoint_path = "tf_checkpoints/model.ckpt"  # where the model is saved

for i in range(n_steps):
    batch = next(batcher)
    # next-step prediction: the input is the batch without its last step,
    # the target is the same batch shifted forward by one step
    loss_val, _ = sess.run([loss, train], feed_dict={x: batch[:, :-1, :], y: batch[:, 1:, :]})
    if i % report_every == 0:
        saver.save(sess, checkpoint_path)
        print(i, loss_val)

# the last periodic checkpoint is written before the final steps of the loop;
# save once more so the weights on disk match the fully trained model
saver.save(sess, checkpoint_path)

0 0.691269
100 0.0793691
200 0.0931354


In [None]:
# predict
# seed: the first go_time+1 steps of a newly drawn sample, with a leading batch axis of 1
initial_sample = next(sampler)
initial_seq = np.reshape(initial_sample, [1, -1, n_notes])[:, :go_time + 1, :]

# first part of rnn: run the seed through the network and keep the returned rnn state
p_val, state_val = sess.run([p, state], feed_dict={x: initial_seq})
next_chord = np.round(p_val[0, -1, :])  # round the last step's sigmoid outputs

# continuation: feed each predicted chord back in as a length-1 sequence,
# passing the previous rnn state in through `initial_state`
continued_seq = [next_chord]
for _ in range(go_time):
    new_input = next_chord[np.newaxis, np.newaxis, :]
    p_val, state_val = sess.run(
        [p, state],
        feed_dict={x: new_input, initial_state: state_val},
    )
    next_chord = np.round(p_val[0, -1, :])
    continued_seq.append(next_chord)

In [None]:
# write the seed sample and the generated continuation to midi files.
# The chords were read on the `quantization` time grid, so the same value is passed
# explicitly here (as the earlier sample test does) instead of relying on the
# default of chords2midi, which is not visible from this notebook.
midi_tools.chords2midi(initial_sample, quantization=quantization, name='initial_sample.mid')
midi_tools.chords2midi(continued_seq, quantization=quantization, name='prediction.mid')