In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
from music21 import *

%run lstm_model.ipynb

def sample(m, tem=1.0):
    """Draw one index from probability vector *m* after temperature scaling.

    The probabilities are re-weighted as ``m ** (1 / tem)`` (computed in log
    space), renormalised in float64, and one index is drawn with
    ``np.random.multinomial``. Lower *tem* sharpens the distribution, higher
    *tem* flattens it. ``tem <= 0`` is treated as the zero-temperature limit
    and returns ``argmax(m)``.

    Args:
        m: 1-D array of non-negative probabilities (e.g. a softmax row).
        tem: sampling temperature.

    Returns:
        The sampled index (a NumPy integer).
    """
    probs = np.asarray(m, dtype=np.float64)
    if tem <= 0:
        # Dividing by zero would give inf/NaN logits; use the greedy limit instead.
        return np.argmax(probs)
    with np.errstate(divide="ignore"):
        # log(0) -> -inf is intended: exp(-inf) == 0 keeps impossible tokens impossible.
        logits = np.log(probs) / tem
    logits -= np.max(logits)  # numerical stability: the largest weight becomes exp(0) == 1
    weights = np.exp(logits)
    # Normalising in float64 keeps sum(weights) within multinomial's tolerance, so no
    # per-element epsilon has to be subtracted (that made near-zero entries negative).
    weights /= weights.sum()
    return np.argmax(np.random.multinomial(1, weights, 1))

def generator_X(notes, random_index_seq, window=16):
    """Yield consecutive, non-overlapping windows of *notes* forever, as model input batches.

    Windows start at 0, window, 2*window, ...; when the next window would run
    past the end of *notes*, iteration wraps back to the start. Each yielded
    value is an array of shape ``(1, window) + notes.shape[1:]`` (a batch of one).

    Args:
        notes: array with time on the first axis.
        random_index_seq: unused; kept so existing callers keep working.
        window: number of time steps per window (default 16).

    Raises:
        ValueError: on the first ``next()`` if *notes* has fewer than *window*
            time steps (there is no full window, so nothing could ever be yielded).
    """
    n_steps = notes.shape[0]
    if n_steps < window:
        raise ValueError(
            "notes has {} time steps; need at least {}".format(n_steps, window))
    t = 0
    while True:
        # "t + window <= n_steps" still fits, so the last full window is included.
        if t + window > n_steps:
            t = 0  # wrap around to the beginning
        yield np.array([notes[t:t + window]])
        t += window
        
def generate_notes(notes, note_type):
    """Generate one chorale part per trained model and write the results as MIDI.

    For each part index ``i`` in ``range(len(parts_index))`` (``parts_index``,
    ``load_model`` and ``index_note_transform`` are notebook globals, presumably
    defined by ``%run lstm_model.ipynb`` -- confirm):

    1. load ``models/bach<i>.h5`` and predict 1000 windows fed by ``generator_X``;
    2. pass the predictions to ``markov_chain`` (prints diagnostics; for part 0
       it also random-walks a transition table and writes ``TTTTT.mid``);
    3. sample one index per predicted step with ``sample``, using a fresh
       uniform random temperature in [0, 1) for every step.

    The sampled indices of all parts are decoded with ``index_note_transform``
    and rendered with ``chorale_to_midi``; every part is also written on its
    own to ``Part<i>.mid``.

    Args:
        notes: array whose first axis is time; windowed by ``generator_X``.
        note_type: vocabulary passed through to ``markov_chain``.

    Returns:
        The music21 score for all parts, as returned by ``chorale_to_midi``.
    """
    output_indices = []
    n_parts = len(parts_index)
    for i in range(n_parts):
        model = load_model("models/bach" + str(i) + ".h5")
        # generator_X does not use its second argument, so no random index
        # sequence is built (drawing 300 of them fails on short inputs).
        predict_list = model.predict_generator(generator_X(notes, None), steps=1000)
        # Keep this before sampling: for part 0 it also draws from np.random,
        # which shifts the random stream used by the sampling below.
        markov_chain(predict_list, note_type, i)
        # The temperature is drawn before each sample() call, as in the original loop.
        part_notes = [
            sample(step_probs, np.random.random_sample())
            for window in predict_list
            for step_probs in window
        ]
        output_indices.append(part_notes)

    output_chorale = index_note_transform(np.array(output_indices), "to_note", n_parts)
    score = chorale_to_midi(output_chorale)

    for i in range(n_parts):
        part_score = chorale_to_midi(output_chorale[i][np.newaxis, :])
        make_midi_file('Part' + str(i) + '.mid', part_score)
    return score

def create_non_note_list(all_note_list, vocabulary=None, non_note_tokens=(133, 134)):
    """Return the distinct indices in *all_note_list* whose vocabulary entry is a non-note token.

    An index counts as "non-note" when ``vocabulary[index]`` equals
    ``frozenset([token])`` for some *token* in *non_note_tokens*. By default
    those are 133 and 134, which ``generate_notes`` treats as start and end
    markers.

    Args:
        all_note_list: iterable of vocabulary indices.
        vocabulary: sequence mapping index -> token set. Defaults to the
            notebook-global ``note_type`` for backward compatibility.
        non_note_tokens: token values to treat as non-notes.

    Returns:
        Sorted list of the matching indices (each listed once).
    """
    if vocabulary is None:
        vocabulary = note_type  # fall back to the notebook-global vocabulary
    note_dict = dict(enumerate(vocabulary))
    print(note_dict)
    # A list, not a set: vocabulary entries may be plain (unhashable) sets;
    # `in` on a list compares with ==, and set == frozenset works.
    non_note = [frozenset([token]) for token in non_note_tokens]
    non_note_list = sorted({index for index in all_note_list if note_dict[index] in non_note})
    print(non_note_list)
    return non_note_list


def markov_chain(predict_list, note_type, part_id):
    """Fit a first-order note-transition table from the model's argmax predictions.

    The most likely token at every predicted step is taken, non-note tokens
    (see ``create_non_note_list``) are removed, and consecutive pairs are
    counted. The resulting table is indexed ``[next, current]`` and
    column-normalised, so ``prob_table[:, c]`` is P(next note | current note c),
    which is how ``test_generate`` reads it. Columns of notes that are never
    followed by another note are 0/0 and stay all-NaN.

    For ``part_id == 0`` the table is printed as a DataFrame and a test
    sequence of 800 steps is generated from it with ``test_generate``.

    Args:
        predict_list: model output; indexed ``[window][step]`` to give a
            probability vector (assumed shape (windows, steps, len(note_type))).
        note_type: vocabulary; its length sets the table size and it labels
            the DataFrame rows/columns.
        part_id: index of the voice part the predictions belong to.
    """
    import pandas as pd  # local import so the block does not depend on the notebook defining `pd`

    print(predict_list.shape)
    # Most likely token at every predicted step, flattened in (window, step) order.
    all_note_list = np.argmax(np.asarray(predict_list), axis=-1).ravel()
    non_note_indices = create_non_note_list(all_note_list.tolist())
    # Remove the non-note tokens so they do not create transitions.
    all_note_list = all_note_list[~np.isin(all_note_list, non_note_indices)]
    print(all_note_list.shape)

    n_types = len(note_type)
    prob_table = np.zeros(shape=(n_types, n_types))
    # Row = next note, column = current note, so columns are the conditional
    # distributions over the next note that test_generate draws from.
    np.add.at(prob_table, (all_note_list[1:], all_note_list[:-1]), 1)
    # Columns with no observed successor become NaN (0/0); generate_index
    # relies on that to avoid starting from them.
    with np.errstate(invalid="ignore", divide="ignore"):
        prob_table = prob_table / prob_table.sum(axis=0)
    prob_df = pd.DataFrame(prob_table, index=note_type, columns=note_type)

    if part_id == 0:
        print(prob_df)
        test_generate(prob_table, note_type, 800)
        
def generate_index(prob_table, note_type):
    """Pick a uniformly random note index whose transition column is not all-NaN.

    Candidates are drawn with ``np.random.randint`` and rejected until one has
    at least one non-NaN entry in ``prob_table[:, index]``.

    Args:
        prob_table: 2-D transition table; column ``c`` holds the distribution
            used when the current note is ``c``.
        note_type: vocabulary; only its length is used (candidate range).

    Returns:
        A 1-element integer array holding the chosen index.

    Raises:
        ValueError: if no column is usable, in which case the rejection loop
            could never terminate.
    """
    n_types = len(note_type)
    usable = ~np.all(np.isnan(prob_table[:, :n_types]), axis=0)
    if not usable.any():
        raise ValueError("prob_table has no usable (non-NaN) column to start from")
    while True:
        index = np.random.randint(n_types, size=1)
        if not np.all(np.isnan(prob_table[:, index])):
            return index
        
def test_generate(prob_table, note_type, music_len):
    """Random-walk the transition table for *music_len* steps and write ``TTTTT.mid``.

    Starting from ``generate_index``, each step records the current index and
    draws the next one from column ``prob_table[:, current]``. If the current
    note's column is all-NaN (no recorded successor), the walk restarts from a
    new ``generate_index`` pick. Without this, ``np.random.multinomial`` would
    reject the NaN probabilities. The index sequence is decoded as a single
    part with ``index_note_transform`` and written via ``chorale_to_midi`` /
    ``make_midi_file``.

    Args:
        prob_table: 2-D transition table whose columns are next-note distributions.
        note_type: vocabulary, passed to ``generate_index``.
        music_len: number of indices to generate.
    """
    current_index = generate_index(prob_table, note_type)
    index_list = []
    for _ in range(music_len):
        index_list.append(current_index[0])
        # Column of the current note = distribution over the next note.
        notes_prob = prob_table[:, current_index[0]]
        if np.all(np.isnan(notes_prob)):
            # Dead end: restart from a note that does have successors.
            current_index = generate_index(prob_table, note_type)
            continue
        next_index = np.argmax(np.random.multinomial(1, notes_prob, size=1))
        current_index = [next_index]
    index_list = np.array(index_list)[np.newaxis, :]  # a single part: shape (1, music_len)
    output_chorale = index_note_transform(index_list, "to_note", 1)
    score = chorale_to_midi(output_chorale)

    output_file = "TTTTT.mid"
    make_midi_file(output_file, score)


def chorale_to_midi(chorale):
    """Convert decoded chorale parts into a music21 ``Score``.

    Each element of *chorale* is one part: a sequence of time steps, each step
    an iterable of integer tokens (as produced by ``index_note_transform``; the
    exact container type is not visible here). One time step lasts a quarter
    of a quarter note (``Duration(d / 4)``). Token handling:

    * two or more tokens -> a ``chord.Chord`` of those values starting here;
    * a single token < 128 -> a ``note.Note`` of that value starting here;
    * a single token == 132 -> a rest starting here;
    * any other single token -> extend the current event by one step.

    A new event always closes the pending one first, so events are appended in
    time order. Parts are assigned Violin, Flute, Piccolo and Bassoon in order,
    so at most four parts are supported.
    """
    score = stream.Score()
    part_instrument = [instrument.Violin(), instrument.Flute(), instrument.Piccolo(), instrument.Bassoon()]
    for i, chorale_part in enumerate(chorale):
        part = stream.Part()
        part.insert(part_instrument[i])
        d = 0  # length, in time steps, of the pending event
        new_note = note.Rest()  # pending event; leading hold tokens extend this rest
        for sound in chorale_part:
            sound = list(sound)
            is_chord = len(sound) >= 2
            is_rest = not is_chord and sound[0] == 132
            is_note = not is_chord and sound[0] < 128
            if not (is_chord or is_rest or is_note):
                d += 1  # hold/continuation token
                continue
            # A new event starts: flush the pending one first to keep time order.
            if d > 0:
                new_note.duration = duration.Duration(d / 4)
                part.append(new_note)
            d = 1
            if is_chord:
                new_note = chord.Chord(sound)
            elif is_rest:
                new_note = note.Rest()
            else:
                new_note = note.Note(sound[0])
        # Flush the last event; an empty part has nothing (d == 0) to flush.
        if d > 0:
            new_note.duration = duration.Duration(d / 4)
            part.append(new_note)
        score.insert(part)
    return score

def make_midi_file(output_file, score):
    """Write a music21 stream to *output_file* as a MIDI file and print a confirmation.

    Uses music21's ``midi.translate.music21ObjectToMidiFile``. The file is
    closed even if writing fails.

    Args:
        output_file: destination path.
        score: music21 stream (e.g. the ``Score`` from ``chorale_to_midi``).
    """
    mf = midi.translate.music21ObjectToMidiFile(score)
    mf.open(output_file, 'wb')
    try:
        mf.write()
    finally:
        mf.close()  # do not leak the file handle if write() raises
    print("File " + output_file + " written")

[0.2  0.21 0.22 0.23 0.24 0.25 0.26 0.27 0.28 0.29 0.3  0.31 0.32 0.33
 0.34 0.35 0.36 0.37 0.38 0.39 0.4  0.41 0.42 0.43 0.44 0.45 0.46 0.47
 0.48 0.49 0.5  0.51 0.52 0.53 0.54 0.55 0.56 0.57 0.58 0.59 0.6  0.61
 0.62 0.63 0.64 0.65 0.66 0.67 0.68 0.69 0.7  0.71 0.72 0.73 0.74 0.75
 0.76 0.77 0.78 0.79 0.8  0.81 0.82 0.83 0.84 0.85 0.86 0.87 0.88 0.89
 0.9  0.91 0.92 0.93 0.94 0.95 0.96 0.97 0.98]
