In [3]:
from keras_self_attention import SeqSelfAttention

In [4]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import TimeDistributed
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Bidirectional
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Flatten
import pickle

from preprocessing import prepare_notes

##need to upgrade tensorflow to use SeqSelfAttention
#from tensorflow.keras.layers import SeqSelfAttention

# First simple model to get MVP

In [6]:
lstm_input, lstm_output = prepare_notes()

In [5]:
pickle_file = open("data/notes", "rb")
notes = pickle.load(pickle_file)
  

In [49]:
VOCAB = (len(set(notes)))

In [54]:
def add_weights(lstm_input):
    '''
    Build and compile the model
    
    lstm_input: lstm_input.shape[1] = number of steps, lstm_input.shape[2] = number features needed for
    Bi-directional
    
    model_output: number of categories to output for classification
    
    returns: compiled model
    '''
    
    model = Sequential()
    model.add(Bidirectional(LSTM(512,input_shape=(lstm_input.shape[1], lstm_input.shape[2]), return_sequences=True)))
    model.add(SeqSelfAttention(attention_width=15, attention_activation='sigmoid'))
    model.add(Dropout(0.3))
    
#     model.add(Bidirectional(LSTM(512,input_shape=(lstm_input.shape[1], lstm_input.shape[2]), return_sequences=True)))
#     model.add(SeqSelfAttention(attention_width=15, attention_activation='sigmoid'))
#     model.add(Dropout(0.3))
    
#     model.add(Bidirectional(LSTM(512,input_shape=(lstm_input.shape[1], lstm_input.shape[2]), return_sequences=True)))
#     model.add(SeqSelfAttention(attention_width=15, attention_activation='sigmoid'))
#     model.add(Dropout(0.3))
    
    model.add(Flatten())
    model.add(Dense(VOCAB))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
    model.save_weights('first_train.h5')
    
    return model



In [55]:
def fit_model(model, lstm_input, lstm_output):
    '''
    fit model and save checkpoint
    
    model: lstm model from function build_model()
    lstm_input: input to the model- output from function prepare_notes() 
    lstm_output: targets for model- output from function prepare_notes()
    
    returns: None
    '''
    
    # checkpoint
    filepath="checkpoint/weights-{epoch:02d}-{val_accuracy:.2f}.hdf5"
    checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
    callbacks_list = [checkpoint]
    #only 10 epochs to start and see if the model works
    model.fit(lstm_input, lstm_output, epochs=1, batch_size=64, callbacks=callbacks_list)
    print(model.summary())

# Train the model on the data

In [56]:
def train(vocab):
    '''
    calls the prepare notes function
    calls the build model function 
    calls the fit_model function 
    '''
    lstm_input, lstm_output = prepare_notes()
    model = build_model(lstm_input)
    fit_model(model, lstm_input, lstm_output)
    
    

# Train the model

In [57]:
train(VOCAB)

68473
68473
  14/1070 [..............................] - ETA: 7:58:17 - loss: 7.5254

KeyboardInterrupt: 

In [51]:
#print(lstm_input[1])
print(lstm_input.shape)

(68473, 100, 1)


In [22]:
lstm_output.shape

(68473, 2424)

In [23]:
len(lstm_input)

68473

In [41]:
print(len(set(notes)))

2424


In [7]:
# Get all pitch names
pitchnames = sorted(set(item for item in notes))
# Get all pitch names
n_vocab = len(set(notes))
    
#print(pitchnames)
print(n_vocab)
#print(notes)
print(len(notes))

2424
68573


In [11]:
pitchnames

['0.1.3.7.81.0',
 '0.1.5.80.25',
 '0.1.5.80.5',
 '0.1.5.81.0',
 '0.1.50.75',
 '0.1.60.5',
 '0.1.62.5',
 '0.10.0',
 '0.2.3.5.70.5',
 '0.2.4.70.5',
 '0.2.4.71.0',
 '0.2.40.5',
 '0.2.5.70.5',
 '0.2.50.5',
 '0.2.52/3',
 '0.2.60.25',
 '0.2.60.5',
 '0.2.61/3',
 '0.2.62/3',
 '0.2.64/3',
 '0.2.70.5',
 '0.2.70.75',
 '0.2.71.0',
 '0.2.71/3',
 '0.20.0',
 '0.20.25',
 '0.20.5',
 '0.20.75',
 '0.21.0',
 '0.21/3',
 '0.22/3',
 '0.3.50.5',
 '0.3.51/3',
 '0.3.52/3',
 '0.3.6.91/3',
 '0.3.6.94.0',
 '0.3.60.5',
 '0.3.62/3',
 '0.3.70.25',
 '0.3.70.5',
 '0.3.70.75',
 '0.3.71.0',
 '0.3.71.5',
 '0.3.71/3',
 '0.30.25',
 '0.30.5',
 '0.30.75',
 '0.31.0',
 '0.31.5',
 '0.31.75',
 '0.31/3',
 '0.33.0',
 '0.34/3',
 '0.4.52.0',
 '0.4.60.5',
 '0.4.60.75',
 '0.4.70.25',
 '0.4.70.5',
 '0.4.70.75',
 '0.4.71.0',
 '0.4.71.75',
 '0.4.71/3',
 '0.4.72.0',
 '0.4.72.5',
 '0.4.72.75',
 '0.4.72/3',
 '0.4.74.0',
 '0.4.80.25',
 '0.40.25',
 '0.40.5',
 '0.40.75',
 '0.41.0',
 '0.41.5',
 '0.41.75',
 '0.41/3',
 '0.410/3',
 '0.42.0',
 '0.42

In [9]:
n_vocab

2424