# Neural Network Training (Run on a GPU)

References: 

https://github.com/shubham3121/music-generation-using-rnn 

https://www.hackerearth.com/blog/developers/jazz-music-using-deep-learning/

## Imports

In [5]:
from matplotlib import pyplot as plt
import pickle

from keras.callbacks import ModelCheckpoint
from keras.models import Sequential
from keras.layers import Activation, Dense, LSTM, Dropout, Flatten
from keras.utils import np_utils

## Load Data

In [10]:
# Restore the preprocessed training data: three pickled objects read back
# sequentially from one file (presumably in the order the preprocessing step
# dumped them -- confirm against that notebook).
# NOTE(review): pickle can execute arbitrary code while loading; only open a
# 'network_data' file produced by your own preprocessing step.
with open('network_data', mode='rb') as filepath:
    network_input, network_output, n_vocab = (
        pickle.load(filepath) for _ in range(3)
    )

## Train the Network on the Data

In [2]:
def create_network(network_in, n_vocab):
    """Build and compile the sequence-to-next-token model.

    The stack is two 128-unit LSTM layers, both returning full sequences.
    Their output is flattened, passed through a 256-unit dense layer, and
    projected to a softmax over ``n_vocab`` classes.

    Args:
        network_in: Training input array. Every axis after the first is used
            as the LSTM ``input_shape`` (presumably ``(timesteps, features)``;
            confirm against the preprocessing step).
        n_vocab: Number of output classes, i.e. the width of the final layer.

    Returns:
        A Keras ``Sequential`` model compiled with categorical cross-entropy
        loss and the Adam optimizer (default settings).
    """
    layer_stack = [
        LSTM(128, input_shape=network_in.shape[1:], return_sequences=True),
        Dropout(0.2),
        LSTM(128, return_sequences=True),
        # Collapse the (timesteps, units) sequence output into one vector.
        Flatten(),
        # NOTE(review): no activation is given, so this layer is linear.
        Dense(256),
        Dropout(0.3),
        Dense(n_vocab),
        Activation('softmax'),
    ]

    model = Sequential()
    for layer in layer_stack:
        model.add(layer)

    # categorical_crossentropy presumably expects one-hot encoded targets;
    # confirm the encoding of network_output.
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    return model

In [3]:
def train(model, network_input, network_output, epochs, *,
          batch_size=32, checkpoint_path='weights.best.music3.hdf5'):
    """Fit ``model`` on the prepared data and checkpoint the best epoch.

    Args:
        model: A compiled Keras model (e.g. from ``create_network``).
        network_input: Training inputs passed straight to ``model.fit``.
        network_output: Training targets passed straight to ``model.fit``.
        epochs: Number of training epochs.
        batch_size: Mini-batch size for ``model.fit`` (keyword-only; the
            default matches the previously hard-coded value).
        checkpoint_path: Where ``ModelCheckpoint`` writes its file
            (keyword-only; the default matches the previously hard-coded
            name).

    Returns:
        The Keras ``History`` object returned by ``model.fit``.
    """
    # No validation data is passed to fit(), so the checkpoint has to monitor
    # training loss. save_best_only overwrites the file only when the loss
    # improves. save_weights_only keeps its Keras default (False), so the
    # whole model is saved, not just its weights.
    checkpoint = ModelCheckpoint(checkpoint_path, monitor='loss', verbose=0,
                                 save_best_only=True)

    history = model.fit(network_input, network_output, epochs=epochs,
                        batch_size=batch_size, callbacks=[checkpoint])

    return history

In [None]:
# NOTE: a 200-epoch run has previously caused the runtime to disconnect.
# Lower this value if training cannot finish within one session.
epochs = 200

model = create_network(network_input, n_vocab)
print('Model created')

# TODO: investigate tuning the learning rate. The model is compiled with
# optimizer='adam', i.e. Adam's default settings.

print('Training in progress')
history = train(model, network_input, network_output, epochs)
print('Training completed')

## Once training finishes, move the saved `.hdf5` checkpoint file to the offline prediction environment.
## Loss graph:

In [None]:
# Plot the per-epoch loss recorded in the History object returned by train().
# If the run also produced validation loss ('val_loss' is present only when
# fit() received validation data), overlay it for comparison.
plt.figure()
plt.plot(history.history['loss'], label='train')
if 'val_loss' in history.history:
    plt.plot(history.history['val_loss'], label='val')
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(loc='upper left')
plt.show()