In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from text_preprocessing import create_songs_for, preprocess_texts, parse_raw_songs
from Midi_preprocessing import preprocess_midi, encode_midi
import nltk
nltk.download('punkt')

In [None]:
import string

import nltk
import numpy as np
import os
import pretty_midi
from keras.utils import to_categorical
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from tqdm import tqdm
root_path = './'


def parse_raw_song(line):
    """Parse one raw CSV line into a list of song dictionaries.

    A line has the layout ``artist,song name,lyrics``. The dataset sometimes
    packs several songs into one line, separated by ``'&  &  &'``; every
    packed song is parsed the same way, so the result is always a list.

    Args:
        line: One raw text line from the lyrics CSV file.

    Returns:
        A list of dicts with the keys ``'name'`` (lower-cased artist with
        double quotes removed), ``'song_name'`` (lower-cased title) and
        ``'lyrics'`` (the untouched lyrics text), in order of appearance.
        A field that is missing from the line is returned as an empty string.
    """
    multi_song_marker = '&  &  &'  # Indicator for 2 songs in the same line, bad dataset :(
    songs = []
    remaining = line

    while True:
        # str.partition never yields -1, so a line without commas cannot
        # silently lose its last character the way find()-based slicing does.
        artist, _, after_artist = remaining.partition(',')
        title, _, lyrics = after_artist.partition(',')
        curr_lyrics, marker, remaining = lyrics.partition(multi_song_marker)

        songs.append({
            # Double quotes are removed because some artist names contain
            # them, probably an error made when the data was created.
            'name': artist.lower().replace("\"", "").strip(),
            'song_name': title.lower().strip(),
            'lyrics': curr_lyrics,
        })

        # No marker found means this was the last song packed into the line.
        if not marker:
            return songs


def parse_raw_songs(raw_songs):
    """Parse every raw line and return all songs as one flat list.

    Args:
        raw_songs: Iterable of raw text lines from the lyrics file.

    Returns:
        A list of the song dicts produced by ``parse_raw_song``, in input order.
    """
    return [
        song
        for raw_song in raw_songs
        # A single raw line might contain multiple songs, hence the inner loop.
        for song in parse_raw_song(raw_song)
    ]


def create_songs_for(train=True):
    """Load a lyrics set, pair each song with its MIDI file and build next-word pairs.

    Args:
        train: When true, read ``lyrics_train_set.csv`` and look for MIDI files
            in ``midi_files``; otherwise read ``lyrics_test_set.csv`` and look
            in ``midi_files/test``. Both are resolved relative to ``root_path``.

    Returns:
        A list with one dict per song that matched exactly one MIDI file.
        Besides the parsed ``'name'`` and ``'song_name'``, each dict holds
        ``'midi_file'`` (path of the matched file), ``'lyrics'`` (the
        normalised lyrics string), ``'ngrams'`` (array of every token except
        the last) and ``'labels'`` (array holding the token that follows each
        ``'ngrams'`` entry). Songs with zero or several MIDI matches are
        reported on stdout and left out, so every returned dict has all keys.
    """
    if train:
        midi_dir = "midi_files"
        lyrics_file = "lyrics_train_set.csv"
    else:
        midi_dir = "midi_files/test"
        lyrics_file = "lyrics_test_set.csv"

    with open(os.path.join(root_path, lyrics_file), 'r') as raw_lyrics:
        raw_songs = raw_lyrics.read().splitlines()

    midi_files = all_midi_files(midi_dir)
    matched_songs = []

    for song in parse_raw_songs(raw_songs):
        midi_name = song_midi_filename(song)

        matched_midi_files = [midi_file for midi_file in midi_files
                              if midi_name in midi_file.lower()]

        if len(matched_midi_files) != 1:
            # Ambiguous or missing MIDI file: report it and drop the song so
            # callers never see a dict without 'midi_file'/'ngrams'/'labels'.
            print("OH OH", len(matched_midi_files), song)
            continue

        song['midi_file'] = matched_midi_files[0]

        # Mark the end of the song: '&,,,,' is replaced by ' EOS' when it is
        # present, otherwise ' EOS' is appended.
        lyrics = song['lyrics']
        if '&,,,,' in lyrics:
            lyrics = lyrics.replace('&,,,,', ' EOS')
        else:
            lyrics = lyrics + ' EOS'

        # Re-join the tokens with single spaces, then turn every remaining
        # '&' into an 'EOL' token.
        lyrics = ' '.join(nltk.word_tokenize(lyrics)).replace('&', 'EOL')
        lyrics = remove_brackets(lyrics)
        song['lyrics'] = lyrics

        tokens = [token for token in nltk.word_tokenize(lyrics)
                  if token not in string.punctuation]

        # Each token is paired with the token that follows it. Slicing also
        # covers songs with fewer than two tokens, which simply get empty
        # arrays instead of raising KeyError on a key that was never created.
        song['ngrams'] = np.array(tokens[:-1])
        song['labels'] = np.array(tokens[1:])

        matched_songs.append(song)

    return matched_songs

def create_ngram_set(input_list, ngram_value=2):
    """Return the set of distinct n-grams (as tuples) found in a sequence.

    >>> sorted(create_ngram_set([1, 4, 9, 4, 1, 4], ngram_value=2))
    [(1, 4), (4, 1), (4, 9), (9, 4)]
    >>> sorted(create_ngram_set([1, 4, 9, 4, 1, 4], ngram_value=3))
    [(1, 4, 9), (4, 1, 4), (4, 9, 4), (9, 4, 1)]
    """
    # One copy of the sequence per n-gram position, each shifted one step further.
    shifted_views = []
    for offset in range(ngram_value):
        shifted_views.append(input_list[offset:])

    # Zipping the shifted copies lines up consecutive items; zip stops at the
    # shortest copy, so only complete n-grams are produced.
    return {ngram for ngram in zip(*shifted_views)}
  
def remove_brackets(lyrics):
    """Remove every ``( ... )`` span, brackets included, from the lyrics.

    Each ``'('`` is paired with the first ``')'`` that follows it. An opening
    bracket with no closing bracket after it ends the scan and the remaining
    text is kept as is; a ``')'`` that appears before any ``'('`` is ignored.
    Both cases previously made the loop run forever on a growing string.

    Args:
        lyrics: The lyrics text.

    Returns:
        The text with all paired bracket spans cut out.
    """
    kept_parts = []
    cursor = 0

    while True:
        open_bracket = lyrics.find('(', cursor)
        if open_bracket == -1:
            break

        # Only look for the closing bracket after the opening one.
        close_bracket = lyrics.find(')', open_bracket + 1)
        if close_bracket == -1:
            break

        kept_parts.append(lyrics[cursor:open_bracket])
        cursor = close_bracket + 1

    kept_parts.append(lyrics[cursor:])
    return ''.join(kept_parts)


def song_midi_filename(song):
    """Build the ``artist_-_title`` stem used to find a song's MIDI file.

    Spaces in the song's ``'name'`` and ``'song_name'`` are replaced by
    underscores and the two parts are joined with ``'_-_'``.
    """
    artist_part, title_part = (
        song[key].replace(' ', '_') for key in ('name', 'song_name')
    )
    return "_-_".join((artist_part, title_part))


def all_midi_files(midi_dir):
    """List the MIDI files inside ``midi_dir`` (resolved relative to ``root_path``).

    Args:
        midi_dir: Directory name, joined onto ``root_path``.

    Returns:
        A list of paths (directory joined with file name) for every directory
        entry whose name contains ``'.mid'``.
    """
    midi_path = os.path.join(root_path, midi_dir)

    found = []
    for entry in os.listdir(midi_path):
        # '.midi' contains '.mid', so one substring test covers both spellings.
        if '.mid' in entry:
            found.append(os.path.join(midi_path, entry))
    return found


def create_data_for(train=True):
    """Collect the per-song token/next-token arrays of a data set.

    Args:
        train: Forwarded to ``create_songs_for`` to choose the train or test set.

    Returns:
        A tuple ``(X, y)`` of object arrays: ``X`` holds each song's
        ``'ngrams'`` array and ``y`` the matching ``'labels'`` array. Songs
        that carry no ``'ngrams'``/``'labels'`` entry are skipped.
    """
    # Songs that were not matched to a MIDI file may come back without these
    # keys; indexing them would raise KeyError.
    songs = [song for song in create_songs_for(train)
             if 'ngrams' in song and 'labels' in song]

    # Songs have different lengths. Without dtype=object, NumPy >= 1.24
    # refuses to build an array from such ragged input.
    X = np.array([song['ngrams'] for song in songs], dtype=object)
    y = np.array([song['labels'] for song in songs], dtype=object)

    return X, y


def init_tokenizer(text):
    """Create a Keras ``Tokenizer`` fitted on a single text.

    The tokenizer is built with ``filters=''`` (no character filter list) and
    fitted on ``text`` wrapped in a one-element list.

    Args:
        text: The text to fit on, as one string.

    Returns:
        The fitted tokenizer.
    """
    corpus = [text]
    fitted_tokenizer = Tokenizer(filters='')
    fitted_tokenizer.fit_on_texts(corpus)
    return fitted_tokenizer
  





In [None]:
midi_dir = "midi_files"
lyrics_file = "lyrics_train_set.csv"

# Window length: every training sample is n consecutive tokens, the first
# n-1 being the model input and the last one the label. Set before the loop
# so it is defined even when no song matches a MIDI file.
n = 4

# Only reading needs the file handle; it is released before preprocessing.
with open(lyrics_file, 'r') as raw_lyrics:
    raw_songs = raw_lyrics.read().splitlines()

songs = parse_raw_songs(raw_songs)

midi_files = all_midi_files(midi_dir)

sequences = list()
seq_counts = []  # number of windows produced per matched song, in song order
for i, song in enumerate(songs):
    seq_count = 0
    midi_name = song_midi_filename(song)

    matched_midi_files = [midi_file for midi_file in midi_files
                          if midi_name in midi_file.lower()]

    if len(matched_midi_files) != 1:
        # Songs without exactly one MIDI match are reported and skipped; they
        # contribute neither sequences nor a seq_counts entry.
        print("OH OH", len(matched_midi_files), song)
        continue

    songs[i]['midi_file'] = matched_midi_files[0]

    # Mark the end of the song with an EOS token.
    if songs[i]['lyrics'].find('&,,,,') == -1:
        songs[i]['lyrics'] = songs[i]['lyrics'] + ' EOS'
    else:
        songs[i]['lyrics'] = songs[i]['lyrics'].replace('&,,,,', ' EOS')

    # Every remaining '&' becomes an EOL token.
    songs[i]['lyrics'] = ' '.join(nltk.word_tokenize(songs[i]['lyrics']))
    songs[i]['lyrics'] = songs[i]['lyrics'].replace('&', 'EOL')

    songs[i]['lyrics'] = remove_brackets(songs[i]['lyrics'])

    splitted_lyrics = [token for token in nltk.word_tokenize(songs[i]['lyrics']) if token not in string.punctuation]

    # Left-pad with start tokens so the first real words also get a full context.
    splitted_lyrics = (['<s>'] * (n-1)) + splitted_lyrics

    # NOTE(review): this yields len - n windows, so the window ending on the
    # song's final token is never emitted. Confirm whether that is intended
    # before changing it: y and the model's output layer are both sized with
    # len(word_index).
    for j in range(0, len(splitted_lyrics) - n):
        sequence = splitted_lyrics[j:j+n]
        sequences.append(sequence)
        seq_count += 1

    seq_counts.append(seq_count)


# Fit the tokenizer on every token of every window, plus a trailing 'eos'.
# A plain join also works when no window was produced.
all_songs_words = ' '.join(token for window in sequences for token in window) + ' eos'
tokenizer = init_tokenizer(all_songs_words)
word_index = tokenizer.word_index


# Each window is a list of tokens; encode it to a flat array of word ids.
sequences = [np.array(tokenizer.texts_to_sequences(seq)).flatten() for seq in sequences]

sequences = pad_sequences(sequences, maxlen=n)

# split into input and output elements
sequences = np.array(sequences)
X, y = sequences[:,:-1],sequences[:,-1]
# NOTE(review): num_classes must stay equal to the width of the model's
# softmax layer, which is also len(word_index).
y = to_categorical(y, num_classes=len(word_index))

# Load the pre-trained GloVe vectors into a {word: vector} dictionary.
EMBEDDING_DIM = 300
embeddings_index = {}
# The context manager closes the file even if a malformed line raises.
with open('glove.6B.300d.txt', encoding="utf8") as f:
    for line in f:
        values = line.split()
        if not values:
            # Skip blank lines instead of failing on values[0].
            continue
        word = values[0]
        coefs = np.asarray(values[1:], dtype='float32')
        embeddings_index[word] = coefs

print('Found %s word vectors.' % len(embeddings_index))

# Row i holds the vector of the word whose tokenizer id is i; the extra row
# accounts for ids running up to len(word_index).
embedding_matrix = np.zeros((len(word_index) + 1, EMBEDDING_DIM))
unk_words = []
for word, i in word_index.items():
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        embedding_matrix[i] = embedding_vector
    else:
        # words not found in embedding index will be all-zeros.
        unk_words.append(word)

# NOTE(review): the reported count is one more than len(unk_words) - confirm
# whether the extra one is meant to stand for row 0.
print("Found", len(unk_words) + 1, "Unknown words")
print(unk_words)

In [None]:
import pretty_midi
import os
from os.path import isfile, join
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
%matplotlib inline

# Keep only songs that were paired with a MIDI file. A song without a unique
# match has no 'midi_file' entry and would raise KeyError below. The
# sequence-building cell skips those songs too, so this order lines up with
# seq_counts.
files = [song for song in create_songs_for(train=True) if 'midi_file' in song]

onlyfiles = [file['midi_file'] for file in files]

encoded_midis = encode_midi(onlyfiles)

In [None]:
from gensim.models.doc2vec import Doc2Vec, TaggedDocument, FAST_VERSION
from nltk.tokenize import word_tokenize
from collections import OrderedDict
import multiprocessing

cores = multiprocessing.cpu_count()

tagged_data = [TaggedDocument(words=_d, tags=[str(i)]) for i, _d in enumerate(encoded_midis)]

assert FAST_VERSION > -1, "This will be painfully slow otherwise"

doc2vec_model = Doc2Vec(dm=1, vector_size=100, window=10, negative=5, hs=0, min_count=2, sample=0, 
                        epochs=20, workers=cores, alpha=0.05)



doc2vec_model.build_vocab(tagged_data)
print("%s vocabulary scanned & state initialized" % doc2vec_model)

%time doc2vec_model.train(tagged_data, total_examples=len(tagged_data), epochs=doc2vec_model.epochs)


In [None]:
# Build the MIDI input of the model: one embedding per training window, i.e.
# each song's vector repeated once for every window that song produced.
midi_train = []
# tqdm wraps the iterable itself (not enumerate) so it can read its length
# and display a total.
for song_index, doc in enumerate(tqdm(encoded_midis)):
    embedding = doc2vec_model.infer_vector(doc)
    midi_train.extend([embedding] * seq_counts[song_index])

# Every window needs exactly one MIDI vector; fail here with a clear message
# rather than later inside fit().
if len(midi_train) != sum(seq_counts):
    raise ValueError(
        "MIDI embeddings (%d) do not line up with the text windows (%d)"
        % (len(midi_train), sum(seq_counts))
    )

In [None]:
import tensorflow as tf
import tensorflow.keras.backend as K
def perplexity(y_true, y_pred):
    """Perplexity metric: the exponential of the mean categorical cross-entropy.

    The Keras backend cross-entropy is computed with the natural logarithm,
    so the matching base is ``e`` rather than 2, and the exponential is
    applied to the batch mean rather than to each sample separately.

    Args:
        y_true: One-hot encoded target tensor.
        y_pred: Predicted probability tensor with the same shape as ``y_true``.

    Returns:
        A scalar tensor holding the perplexity of the batch.
    """
    cross_entropy = K.categorical_crossentropy(y_true, y_pred)
    return K.exp(K.mean(cross_entropy))

# def categorical_crossentropy(y_true, y_pred):
#   return tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_true, logits=y_pred)

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Embedding, Dense, Input, Concatenate, Dropout, Masking, BatchNormalization, LayerNormalization
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.callbacks import TensorBoard
from time import time
# import keras.backend as K
# K.clear_session()

# Text branch: n-1 word ids -> frozen pre-trained embeddings -> LSTM.
text_in = Input(shape=(n-1,))
embedding = Embedding(input_dim=len(word_index) + 1, output_dim=EMBEDDING_DIM, weights=[embedding_matrix], mask_zero=True, trainable=False)(text_in)
text_norm = BatchNormalization()(embedding)
text_drop = Dropout(0.3)(text_norm)
lstm1 = LSTM(units=128)(text_drop)
lstm_norm = LayerNormalization()(lstm1)

# Melody branch: the 100-dimensional vector of the song's MIDI file.
midi_in = Input(shape=(100,))
midi_norm = BatchNormalization()(midi_in)

# Both branches are concatenated and mapped to a distribution over words.
concat = Concatenate()([lstm_norm, midi_norm])
dropout = Dropout(0.1)(concat)
# NOTE(review): the output width has to stay equal to the num_classes used
# when y was one-hot encoded, which is also len(word_index).
dense = Dense(units=len(word_index), activation='softmax')(dropout)

# batch_size is no longer passed: the TensorFlow 2 TensorBoard callback does
# not use it.
tb_callback = TensorBoard(log_dir='./logs/{}'.format(time()), write_graph=True)

nn = Model(inputs=[text_in, midi_in], outputs=[dense])
# summary() prints the table itself; wrapping it in print() added a stray "None".
nn.summary()

In [None]:
from tensorflow.keras.optimizers import Adam

optimizer = Adam()

# Train on cross-entropy and report perplexity alongside it.
nn.compile(
    loss='categorical_crossentropy',
    optimizer=optimizer,
    metrics=[perplexity],
)

# The two inputs are the word-id windows and one MIDI vector per window.
history = nn.fit(
    x=[X, np.array(midi_train)],
    y=y,
    batch_size=128,
    epochs=10,
    validation_split=0.2,
    callbacks=[tb_callback],
)

In [None]:
# generate a sequence from the model
# generate a sequence from the model
def generate_seq(model, song_embedding, tokenizer, max_length, seed_text, n_words):
    """Sample up to ``n_words`` words from the model, starting from ``seed_text``.

    Args:
        model: Model whose ``predict`` takes ``[word_ids, song_embedding]`` and
            returns one probability distribution over word ids per row.
        song_embedding: MIDI vector of the song, with a leading batch axis of 1.
        tokenizer: Fitted tokenizer providing ``texts_to_sequences`` and
            ``index_word``.
        max_length: Number of context word ids fed to the model.
        seed_text: Starting text; it is split on whitespace, so it may hold
            one word or several.
        n_words: Maximum number of words to generate.

    Returns:
        The seed words followed by the generated words, joined with spaces.
    """
    result = seed_text.split()

    # generate a fixed number of words
    for _ in range(n_words):
        # encode the words generated so far as integers; only the last
        # max_length ids are kept, shorter inputs are left-padded with 0
        encoded = tokenizer.texts_to_sequences([result])
        encoded = pad_sequences(encoded, maxlen=max_length, padding='pre')

        # predict a distribution over the vocabulary; float64 keeps the
        # renormalised probabilities within np.random.choice's sum tolerance
        probs = np.array(model.predict([encoded, song_embedding], verbose=0)[0], dtype='float64')

        # id 0 is the padding value used above and has no entry in
        # index_word, so it must never be sampled
        probs[0] = 0.0
        total = probs.sum()
        if total <= 0:
            break

        # sample an id directly; sampling a probability value and searching
        # for it afterwards collapses ids that share the same probability
        predicted_id = int(np.random.choice(len(probs), p=probs / total))

        # map predicted word index to word
        out_word = tokenizer.index_word[predicted_id]
        # append to input
        result.append(out_word)
    return ' '.join(result)

In [None]:
# Keep only test songs that were paired with a MIDI file; a song without a
# unique match has no 'midi_file' entry and would raise KeyError below.
test_files = [song for song in create_songs_for(train=False) if 'midi_file' in song]

test_onlyfiles = [file['midi_file'] for file in test_files]
print(test_onlyfiles)
encoded_testmidis = encode_midi(test_onlyfiles)

# One inferred Doc2Vec vector per encoded test MIDI file.
test_vectors = [doc2vec_model.infer_vector(midi) for midi in encoded_testmidis]
# Report a count rather than dumping every encoded MIDI file into the output.
print(len(test_vectors), "test MIDI files embedded")


In [None]:
# For every test song, generate 150 words from each of three seed words.
for index, song in enumerate(test_vectors):
    midi_path = test_onlyfiles[index]
    # The model expects a batch axis: one row holding the 100-dim song vector.
    song_embedding = song.reshape((1,100))
    for word in ['hello', 'beautiful', 'world']:
        print(midi_path, word, ':')
        generated = generate_seq(nn, song_embedding, tokenizer, n-1, word, 150)
        print(generated)