In [1]:
import numpy as np
import matplotlib.pyplot as plt
import midi
from tqdm import tqdm
import glob
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

%matplotlib inline

In [2]:
lowerBound = 24
upperBound = 102
bound_diff = upperBound-lowerBound
note_range = bound_diff

num_timesteps = 15 #This is the number of timesteps that we will create at a time
n_visible = 2*note_range*num_timesteps #This is the size of the visible layer. 
n_hidden = 50 #This is the size of the hidden layer

batch_size = 100 #The number of training examples that we are going to send through the RBM at a time. 
lr = tf.constant(0.005, tf.float32) #The learning rate of our model

x  = tf.placeholder(tf.float32, [None, n_visible], name="x") #The placeholder variable that holds our data
W  = tf.Variable(tf.random_normal([n_visible, n_hidden], 0.01), name="W") #The weight matrix that stores the edge weights
bh = tf.Variable(tf.zeros([1, n_hidden],  tf.float32, name="bh")) #The bias vector for the hidden layer
bv = tf.Variable(tf.zeros([1, n_visible],  tf.float32, name="bv")) #The bias vector for the visible layer

In [3]:


def midiToNoteStateMatrix(midifile, squash=True, bound_diff=bound_diff):
    pattern = midi.read_midifile(midifile)

    timeleft = [track[0].tick for track in pattern]

    posns = [0 for track in pattern]

    statematrix = []
    time = 0

    state = [[0,0] for x in range(bound_diff)]
    statematrix.append(state)
    condition = True
    while condition:
        if time % (pattern.resolution / 4) == (pattern.resolution / 8):
            # Crossed a note boundary. Create a new state, defaulting to holding notes
            oldstate = state
            state = [[oldstate[x][0],0] for x in range(bound_diff)]
            statematrix.append(state)
        for i in range(len(timeleft)): #For each track
            if not condition:
                break
            while timeleft[i] == 0:
                track = pattern[i]
                pos = posns[i]

                evt = track[pos]
                if isinstance(evt, midi.NoteEvent):
                    if (evt.pitch < lowerBound) or (evt.pitch >= upperBound):
                        pass
                        # print "Note {} at time {} out of bounds (ignoring)".format(evt.pitch, time)
                    else:
                        if isinstance(evt, midi.NoteOffEvent) or evt.velocity == 0:
                            state[evt.pitch-lowerBound] = [0, 0]
                        else:
                            state[evt.pitch-lowerBound] = [1, 1]
                elif isinstance(evt, midi.TimeSignatureEvent):
                    if evt.numerator not in (2, 4):
                        # We don't want to worry about non-4 time signatures. Bail early!
                        # print "Found time signature event {}. Bailing!".format(evt)
                        out =  statematrix
                        condition = False
                        break
                try:
                    timeleft[i] = track[pos + 1].tick
                    posns[i] += 1
                except IndexError:
                    timeleft[i] = None

            if timeleft[i] is not None:
                timeleft[i] -= 1

        if all(t is None for t in timeleft):
            break

        time += 1

    S = np.array(statematrix)
    statematrix = np.hstack((S[:, :, 0], S[:, :, 1]))
    statematrix = np.asarray(statematrix).tolist()
    return statematrix

def noteStateMatrixToMidi(statematrix, name="example", bound_diff=bound_diff):
    statematrix = np.array(statematrix)
    if not len(statematrix.shape) == 3:
        statematrix = np.dstack((statematrix[:, :bound_diff], statematrix[:, bound_diff:]))
    statematrix = np.asarray(statematrix)
    pattern = midi.Pattern()
    track = midi.Track()
    pattern.append(track)
    
    tickscale = 55
    
    lastcmdtime = 0
    prevstate = [[0,0] for x in range(bound_diff)]
    for time, state in enumerate(statematrix + [prevstate[:]]):  
        offNotes = []
        onNotes = []
        for i in range(bound_diff):
            n = state[i]
            p = prevstate[i]
            if p[0] == 1:
                if n[0] == 0:
                    offNotes.append(i)
                elif n[1] == 1:
                    offNotes.append(i)
                    onNotes.append(i)
            elif n[0] == 1:
                onNotes.append(i)
        for note in offNotes:
            track.append(midi.NoteOffEvent(tick=(time-lastcmdtime)*tickscale, pitch=note+lowerBound))
            lastcmdtime = time
        for note in onNotes:
            track.append(midi.NoteOnEvent(tick=(time-lastcmdtime)*tickscale, velocity=40, pitch=note+lowerBound))
            lastcmdtime = time
            
        prevstate = state
    
    eot = midi.EndOfTrackEvent(tick=1)
    track.append(eot)

    midi.write_midifile("{}.mid".format(name), pattern)

def get_songs():
    files = glob.glob('../dataset/*.mid')
    songs = []
    for f in tqdm(files):
        try:
            song = np.array(midiToNoteStateMatrix(f))
            if np.array(song).shape[0] > 50:
                songs.append(song)
        except Exception as e:
            raise e           
    return songs

In [8]:
def sample(probs):
    return tf.floor(probs + tf.random_uniform(tf.shape(probs), 0, 1))


def gibbs_step(count, k, xk):
    probs = tf.sigmoid(tf.matmul(xk, W) + bh)
    hk = sample(probs)

    xk = sample(tf.sigmoid(tf.matmul(hk, tf.transpose(W)) + bv))
    return count+1, k, xk


def generate_x_sample(k):
    ct = tf.constant(0) #counter
    [_, _, x_sample] = control_flow_ops.while_loop(
        lambda count,
        num_iter,
        *args: count < num_iter,
        gibbs_step,
        [ct, tf.constant(k), x],
        parallel_iterations=10,
        back_prop=False)

    x_sample = tf.stop_gradient(x_sample) 
    return x_sample

num_epochs = 100

def generate_rnn():
    #TODO: Revisit this
    x_sample = generate_x_sample(1) 
    h = sample(tf.sigmoid(tf.matmul(x, W) + bh)) 
    h_sample = sample(tf.sigmoid(tf.matmul(x_sample, W) + bh)) 

    size_bt = tf.cast(tf.shape(x)[0], tf.float32)
    W_adder  = tf.multiply(lr/size_bt, tf.subtract(tf.matmul(tf.transpose(x), h), tf.matmul(tf.transpose(x_sample), h_sample)))
    bv_adder = tf.multiply(lr/size_bt, tf.reduce_sum(tf.subtract(x, x_sample), 0, True))
    bh_adder = tf.multiply(lr/size_bt, tf.reduce_sum(tf.subtract(h, h_sample), 0, True))
    updt = [W.assign_add(W_adder), bv.assign_add(bv_adder), bh.assign_add(bh_adder)]
    
def save_samples(sample):
    print('OUTPUT: ')
    print(sample.shape)
    print(sample[0].shape)
    for i in range(sample.shape[0]):
        if any(sample[i,:]):
            S = np.reshape(sample[i,:], (num_timesteps, 2*note_range))
            noteStateMatrixToMidi(S, "../output/{}_generated_chord_{}".format(num_epochs, i))
    
def main():
    from datetime import datetime

    def sample(probs):
        return tf.floor(probs + tf.random_uniform(tf.shape(probs), 0, 1))

    songs = get_songs()
    print("songs count", len(songs))
    songs  = [np.array(s) for s in songs]

    x_sample = generate_x_sample(1) 
    h = sample(tf.sigmoid(tf.matmul(x, W) + bh)) 
    h_sample = sample(tf.sigmoid(tf.matmul(x_sample, W) + bh)) 

    size_bt = tf.cast(tf.shape(x)[0], tf.float32)
    W_adder  = tf.multiply(lr/size_bt, tf.subtract(tf.matmul(tf.transpose(x), h), tf.matmul(tf.transpose(x_sample), h_sample)))
    bv_adder = tf.multiply(lr/size_bt, tf.reduce_sum(tf.subtract(x, x_sample), 0, True))
    bh_adder = tf.multiply(lr/size_bt, tf.reduce_sum(tf.subtract(h, h_sample), 0, True))
    updt = [W.assign_add(W_adder), bv.assign_add(bv_adder), bh.assign_add(bh_adder)]
    
    start = datetime.now()

    with tf.Session() as session:

        init = tf.initialize_all_variables()
        session.run(init)

        epoch_index = 0
        for epoch in tqdm(range(num_epochs)):
            epoch_index += 1

            for song in songs:
                song = np.array(song)
                ind = int(np.floor(song.shape[0]/num_timesteps)*num_timesteps)
                song = song[:ind]
                song = np.reshape(
                    song,
                    list(map(int,[song.shape[0]/num_timesteps, song.shape[1]*num_timesteps]))
                )
                print(song.shape)
                print(song[0:batch_size].shape)
                for i in range(1, len(song), batch_size): 
                    trainX = song[i:i+batch_size]
                    session.run(
                        updt,
                        feed_dict={x: trainX}
                    )



            if epoch_index % 100 == 0:
                print('epoch: {}'.format(epoch_index))

        end = datetime.now()
        print('time: ', end - start)

        sample = generate_x_sample(1).eval(session=session, feed_dict={x: np.zeros((10, n_visible))})

        save_samples(sample)

main()

100%|██████████| 2/2 [00:00<00:00,  2.07it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

songs count 2
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)


  6%|▌         | 6/100 [00:00<00:03, 24.49it/s]

(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)


 12%|█▏        | 12/100 [00:00<00:03, 24.19it/s]

(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)


 14%|█▍        | 14/100 [00:00<00:03, 22.55it/s]

(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)


 20%|██        | 20/100 [00:00<00:03, 23.13it/s]

(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)


 23%|██▎       | 23/100 [00:01<00:03, 22.97it/s]

(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)


 29%|██▉       | 29/100 [00:01<00:03, 22.31it/s]

(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)


 32%|███▏      | 32/100 [00:01<00:03, 21.81it/s]

(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)


 35%|███▌      | 35/100 [00:01<00:03, 20.18it/s]

(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)


 37%|███▋      | 37/100 [00:01<00:03, 19.52it/s]

(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)


 41%|████      | 41/100 [00:02<00:03, 19.21it/s]

(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)


 45%|████▌     | 45/100 [00:02<00:02, 18.41it/s]

(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)


 47%|████▋     | 47/100 [00:02<00:02, 18.21it/s]

(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)


 49%|████▉     | 49/100 [00:02<00:02, 17.95it/s]

(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)


 53%|█████▎    | 53/100 [00:03<00:02, 17.21it/s]

(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)


 57%|█████▋    | 57/100 [00:03<00:02, 17.00it/s]

(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)


 60%|██████    | 60/100 [00:03<00:02, 17.26it/s]

(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)


 67%|██████▋   | 67/100 [00:03<00:01, 17.96it/s]

(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)


 74%|███████▍  | 74/100 [00:03<00:01, 18.62it/s]

(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)


 77%|███████▋  | 77/100 [00:04<00:01, 18.66it/s]

(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)


 83%|████████▎ | 83/100 [00:04<00:00, 18.95it/s]

(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)


 90%|█████████ | 90/100 [00:04<00:00, 19.43it/s]

(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)


 94%|█████████▍| 94/100 [00:04<00:00, 19.77it/s]

(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)


100%|██████████| 100/100 [00:04<00:00, 20.03it/s]


(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
(93, 2340)
(93, 2340)
(101, 2340)
(100, 2340)
epoch: 100
time:  0:00:05.025480
OUTPUT: 
(10, 2340)
(2340,)
