# 2-2. LSTM

<img src="./img/lstm.png" alt="lstm" width="500" align="left"/>

<img src="./img/song_lstm.png" alt="song_lstm" width="500" align="left"/>

In [None]:
import os
import tensorflow as tf
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.utils import np_utils
import numpy as np

In [None]:
# Root directory (relative to the current working directory) for everything
# this notebook writes; MIDI exports go into a "Midi" subfolder of it.
CKPT_DIR = "../generated_output/LSTM"

In [None]:
# Create the output directory if it is missing. exist_ok=True avoids the race
# between a separate existence check and the creation, and makes re-running
# this cell a no-op.
os.makedirs(CKPT_DIR, exist_ok=True)

In [None]:
# Seed NumPy's global RNG so NumPy-based randomness is repeatable across runs.
# NOTE(review): this does not appear to seed TensorFlow/Keras' own RNG (e.g.
# weight initialisation), so training results may still vary run to run —
# confirm whether a TF-level seed is needed for the installed TF version.
np.random.seed(5)

In [None]:
# Training melody as note tokens: the letter is the pitch name and the digit is
# a duration code (note2midi later turns code d into int(8 / d) time steps).
seq = (
    "g8 e8 e4 f8 d8 d4 c8 d8 e8 f8 g8 g8 g4 "
    "g8 e8 e8 e8 f8 d8 d4 c8 e8 g8 g8 e8 e8 e4 "
    "d8 d8 d8 d8 d8 e8 f4 e8 e8 e8 e8 e8 f8 g4 "
    "g8 e8 e4 f8 d8 d4 c8 e8 g8 g8 e8 e8 e4"
).split()

In [None]:
# Vocabulary of 14 note tokens: the seven pitch names with duration code '4'
# get indices 0-6, the same pitches with duration code '8' get indices 7-13.
note2idx = {pitch + code: idx
            for idx, (code, pitch) in enumerate(
                (code, pitch) for code in '48' for pitch in 'cdefgab')}

# Inverse lookup, derived from note2idx so the two tables can never disagree;
# used to decode predicted class indices back into note tokens.
idx2note = {idx: note for note, idx in note2idx.items()}

In [None]:
def seq2dataset(seq, window_size, vocab=None):
    """Turn a note sequence into overlapping, integer-encoded training windows.

    Each row holds ``window_size`` consecutive notes followed by the note that
    comes right after them (the prediction target), so every row has
    ``window_size + 1`` columns.

    Args:
        seq: Sequence of note tokens (e.g. ``'g8'``).
        window_size: Number of input notes per row.
        vocab: Mapping from note token to integer index. Defaults to the
            module-level ``note2idx``.

    Returns:
        ``np.ndarray`` of shape ``(len(seq) - window_size, window_size + 1)``.
        When ``seq`` is too short for a single window, the result has shape
        ``(0, window_size + 1)`` rather than a 1-D empty array, so callers can
        still slice it with ``[:, k]``.

    Raises:
        KeyError: If a note inside some window is missing from ``vocab``.
    """
    if vocab is None:
        vocab = note2idx
    span = window_size + 1
    rows = [[vocab[note] for note in seq[start:start + span]]
            for start in range(len(seq) - window_size)]
    if not rows:
        # np.array([]) would be 1-D; keep the 2-D layout callers rely on.
        return np.empty((0, span), dtype=int)
    return np.array(rows)

In [None]:
# Slide a 5-note window over `seq`: each row holds 4 input notes followed by
# the note the model should predict, all encoded via note2idx.
dataset = seq2dataset(seq, window_size=4)

In [None]:
# Split each window into inputs (all but the last column) and the target (the
# last column). The window size is read from the data instead of repeating 4.
window_size = dataset.shape[1] - 1
x_train = dataset[:, :window_size]
y_train = dataset[:, window_size]

# Scale note indices into [0, 1] by the largest index in the vocabulary.
max_idx_value = max(note2idx.values())
x_train = x_train / float(max_idx_value)

# LSTM input layout: (samples, timesteps, features) with one feature per step.
x_train = x_train.reshape([-1, window_size, 1])

# One-hot targets over the whole vocabulary, not just the notes that happen to
# occur in `seq`, so the softmax layer covers every entry of idx2note.
# (Class count passed positionally for compatibility across Keras versions.)
y_train = np_utils.to_categorical(y_train, len(note2idx))
one_hot_vec_size = y_train.shape[1]

In [None]:
# Stateful LSTM classifier. The batch size is fixed at 1 (required for
# stateful=True); timesteps and features are taken from x_train so the input
# shape always matches the preprocessed data instead of a duplicated (4, 1).
model = Sequential()
model.add(LSTM(128, batch_input_shape=(1,) + x_train.shape[1:], stateful=True))
# One softmax unit per note class.
model.add(Dense(one_hot_vec_size, activation='softmax'))

In [None]:
# Next-note prediction is multi-class classification over one-hot targets.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
class LossHistory(keras.callbacks.Callback):
    """Keras callback that records the training loss at the end of each epoch.

    Losses accumulate across successive ``fit`` calls until ``init()`` is
    called again. This lets the notebook run one epoch per ``fit`` call and
    still collect the full loss curve.
    """

    def __init__(self):
        super().__init__()
        # Start empty so the callback works even if init() is never called.
        self.losses = []

    def init(self):
        """Reset the recorded loss history to an empty list."""
        self.losses = []

    def on_epoch_end(self, batch, logs=None):
        """Append this epoch's loss, or ``None`` if no loss was reported.

        Keras passes the epoch index as the first argument. The parameter keeps
        its original name ``batch`` for backward compatibility.
        """
        # `logs=None` instead of a shared mutable `{}` default.
        self.losses.append((logs or {}).get('loss'))

In [None]:
# Train one epoch per fit() call so the stateful LSTM's hidden state can be
# cleared between epochs; within an epoch the state carries across batches.
history = LossHistory()
history.init()
num_epochs = 2000

for epoch_idx in range(num_epochs):
    print('epochs :', epoch_idx)
    model.fit(x_train, y_train,
              batch_size=1,
              epochs=1,
              shuffle=False,  # keep sample order so the carried-over state matches the sequence
              verbose=2,
              callbacks=[history])
    model.reset_states()

In [None]:
# Evaluate on the training windows; batch_size must equal the stateful
# model's fixed batch size of 1. Clear the LSTM state afterwards.
scores = model.evaluate(x_train, y_train, batch_size=1)
metric_name, metric_value = model.metrics_names[1], scores[1]
print("{}: {:.2f}%".format(metric_name, metric_value * 100))
model.reset_states()

In [None]:
# One-step prediction: for each training window, predict the next note from the
# true preceding notes in x_train, appended after a fixed 4-note seed.
pred_count = 50
note_onestep = ['g8', 'e8', 'e4', 'f8']
pred_out = model.predict(x_train, batch_size=1)
# Slicing caps the loop at the number of available predictions, so a
# pred_count larger than len(x_train) no longer raises IndexError.
for idx in np.argmax(pred_out[:pred_count], axis=1):
    note_onestep.append(idx2note[idx])
model.reset_states()

In [None]:
# Autoregressive ("full song") generation: start from a 4-note seed and feed
# each predicted note back in as the newest step of the input window.
seed_notes = ['g8', 'e8', 'e4', 'f8']
note_fullsong = list(seed_notes)
scale = float(max_idx_value)
seq_in = [note2idx[note] / scale for note in seed_notes]
for i in range(pred_count):
    sample_in = np.reshape(np.array(seq_in), (1, 4, 1))
    pred_out = model.predict(sample_in)
    idx = np.argmax(pred_out)
    note_fullsong.append(idx2note[idx])
    # Slide the window: drop the oldest step and append the new prediction.
    seq_in = seq_in[1:] + [idx / scale]
model.reset_states()

In [None]:
# Show both generated sequences.
for label, notes in (("one step prediction : ", note_onestep),
                     ("full song prediction : ", note_fullsong)):
    print(label, notes)

In [None]:
import music21 as m21
from writeMIDI import writeMIDI

def note2midi(notes, num):
    """Convert note tokens into event tuples and export them with ``writeMIDI``.

    Each token is ``<pitch letter><duration code>`` (e.g. ``'g8'``). Each note
    becomes the tuple ``(pitch + '5', start_step, length, 120)``, where duration
    code ``d`` gives ``length = int(8 / d)`` steps. Every later note is shifted
    back by ``length - 1``, so a multi-step note does not overlap the next one.
    This is valid for codes 1, 2, 4 and 8.
    NOTE(review): '5' presumably selects the octave and 120 the velocity, and
    ``('C', 'piano', 130)`` presumably key/instrument/tempo — confirm against
    writeMIDI.

    The events are written once to ``CKPT_DIR/Midi/LSTM_result_<num>``. That
    directory is created if needed. Nothing is written when ``notes`` is empty.

    Args:
        notes: Sequence of note tokens.
        num: Integer used as the output file-name suffix.

    Returns:
        None. A completion message is printed.
    """
    events = []
    shift = 0  # extra steps consumed so far by notes longer than one step
    for i, note in enumerate(notes):
        pitch, code = note[0], note[1]
        length = int(8 / int(code))
        events.append((pitch + '5', shift + i, length, 120))
        shift += length - 1

    if events:
        midi_dir = os.path.join(CKPT_DIR, 'Midi')
        os.makedirs(midi_dir, exist_ok=True)
        # Write the complete event list a single time, after it is fully built.
        writeMIDI('C', 'piano', 130, events,
                  os.path.join(midi_dir, 'LSTM_result_%d' % num))

    print("LSTM result_%d export complete!" % num)


# Export both predictions; the number becomes the MIDI file-name suffix.
for result_num, result_notes in enumerate((note_onestep, note_fullsong), start=1):
    note2midi(result_notes, result_num)

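[MIDI 재생 사이트 링크 (MIDI playback site)](https://onlinesequencer.net/import) — upload the generated `LSTM_result_*` files from `../generated_output/LSTM/Midi` there to listen to the results.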