In [None]:
import os
import mido
import tensorflow as tf
import tkinter as tk
from tkinter import filedialog
import numpy as np
from utils.preprocess.quantize_note_timings import quantize_note_timings
from utils.preprocess.normalize_velocities import normalize_velocities
from utils.preprocess.filter_unnecessary_data import filter_unnecessary_data
from utils.train.functions import split_train_validation_data, preprocess_data_for_training
from utils.predict.functions import preprocess_data_for_prediction, postprocess_predictions_to_midi
from utils.impure.functions import process_directory
from music21 import *

In [None]:
def parse_midi_file(midi_file_path):
    """Read the MIDI file at ``midi_file_path`` and return it as a ``mido.MidiFile``."""
    return mido.MidiFile(midi_file_path)


In [None]:
def preprocess_midi_data(midi_data):
    """Clean ``midi_data`` before it is fed to the model.

    The project's preprocessing steps run in a fixed order, each one receiving
    the previous step's result: quantize note timings, then normalize
    velocities, then filter unnecessary data. The final result is returned.
    """
    for step in (quantize_note_timings, normalize_velocities, filter_unnecessary_data):
        midi_data = step(midi_data)
    return midi_data

In [None]:
def load_midi_files(file_directory):
    """Recursively collect the paths of all MIDI files under ``file_directory``.

    A file counts as MIDI when its name ends in ``.mid`` or ``.midi``. The
    comparison is case-insensitive, so files such as ``SONG.MID`` are not
    silently skipped.

    Args:
        file_directory: Root directory to search. A missing directory yields an
            empty list, because ``os.walk`` ignores errors by default.

    Returns:
        Sorted list of paths joined onto ``file_directory``. Sorting makes the
        order independent of filesystem enumeration order, so any
        order-dependent downstream step (such as a train/validation split) is
        reproducible.
    """
    midi_extensions = (".mid", ".midi")
    return sorted(
        os.path.join(root, file)
        for root, _dirs, files in os.walk(file_directory)
        for file in files
        if file.lower().endswith(midi_extensions)
    )

In [None]:
def train_lstm_model(train_sequences, train_labels, validation_sequences, validation_labels, input_shape, num_classes, epochs=20):
    """Build, compile and fit a two-layer LSTM binary classifier.

    The network ends in a single sigmoid unit trained with binary
    cross-entropy. It therefore outputs one probability per sequence: that the
    sequence belongs to label 1.

    Args:
        train_sequences: Training inputs. ``train_sequences.shape[1:]`` must
            equal ``input_shape``.
        train_labels: Binary (0/1) training targets, one per sequence.
        validation_sequences: Validation inputs with the same per-sample shape.
        validation_labels: Binary (0/1) validation targets.
        input_shape: Per-sample input shape, e.g. ``(timesteps, features)``.
        num_classes: Number of label classes. It is only logged and checked:
            the output layer is hard-wired for binary classification.
        epochs: Number of training epochs. Defaults to 20, the previously
            hard-coded value.

    Returns:
        ``(model, history)``: the fitted ``tf.keras.Sequential`` model and the
        ``History`` object returned by ``model.fit``.

    Raises:
        ValueError: if ``num_classes`` > 2. A sigmoid + binary cross-entropy
            head cannot represent more classes, and training would silently
            produce a meaningless model.
    """
    print("Train sequences shape: ", train_sequences.shape)
    print("Train labels shape: ", train_labels.shape)
    print("Validation sequences shape: ", validation_sequences.shape)
    print("Validation labels shape: ", validation_labels.shape)
    print("Input shape: ", input_shape)
    print("Number of classes: ", num_classes)

    if num_classes > 2:
        raise ValueError(
            "train_lstm_model only supports binary labels (0/1); got num_classes={}".format(num_classes))

    # tf.keras.Input(shape=...) replaces the deprecated InputLayer(input_shape=...) form.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        # return_sequences=True so the second LSTM receives the full sequence, not just the last step.
        tf.keras.layers.LSTM(units=128, return_sequences=True),
        tf.keras.layers.LSTM(units=128),
        tf.keras.layers.Dense(units=1, activation='sigmoid'),  # single probability for label 1
    ])

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    history = model.fit(train_sequences, train_labels, epochs=epochs,
                        validation_data=(validation_sequences, validation_labels))

    return model, history

In [None]:
def _build_labelled_sequences(impure_midi_data, pure_midi_data):
    """Convert impure (label 0) and pure (label 1) MIDI data into one (sequences, labels) pair."""
    sequences_impure, labels_impure = preprocess_data_for_training(impure_midi_data, 0)
    sequences_pure, labels_pure = preprocess_data_for_training(pure_midi_data, 1)
    sequences = np.concatenate((sequences_impure, sequences_pure))
    labels = np.concatenate((labels_impure, labels_pure))
    return sequences, labels


def train_machine_learning_model(pure_file_directory="./adl-piano-midi/Pop/Karaoke",
                                 impure_file_directory="./adl-piano-midi-impure/Pop/Karaoke"):
    """Load the pure/impure MIDI corpora, build labelled sequences and train the LSTM classifier.

    Pure files are labelled 1 and impure files 0. Each corpus is split into
    training and validation subsets separately (via
    ``split_train_validation_data``) before the two corpora are combined.

    Args:
        pure_file_directory: Root of the clean MIDI corpus, searched recursively.
        impure_file_directory: Root of the corrupted MIDI corpus, searched recursively.

    Returns:
        ``(model, history)`` as returned by ``train_lstm_model``.

    Raises:
        FileNotFoundError: if either directory contains no MIDI files. Training
            would otherwise run with a single class or fail further down with
            a less obvious error.
    """
    pure_midi_files = load_midi_files(pure_file_directory)
    impure_midi_files = load_midi_files(impure_file_directory)

    print("Loaded {} pure MIDI files".format(len(pure_midi_files)))
    print("Loaded {} impure MIDI files".format(len(impure_midi_files)))

    if not pure_midi_files or not impure_midi_files:
        raise FileNotFoundError(
            "Need MIDI files in both {!r} ({} found) and {!r} ({} found)".format(
                pure_file_directory, len(pure_midi_files),
                impure_file_directory, len(impure_midi_files)))

    preprocessed_pure_midi_data = [preprocess_midi_data(parse_midi_file(file)) for file in pure_midi_files]
    preprocessed_impure_midi_data = [preprocess_midi_data(parse_midi_file(file)) for file in impure_midi_files]

    print("Preprocessed {} pure MIDI files".format(len(preprocessed_pure_midi_data)))
    print("Preprocessed {} impure MIDI files".format(len(preprocessed_impure_midi_data)))

    # Split each corpus separately so both classes appear in training and validation.
    train_impure, validation_impure = split_train_validation_data(preprocessed_impure_midi_data)
    train_pure, validation_pure = split_train_validation_data(preprocessed_pure_midi_data)

    print("Split {} impure MIDI files into {} training files and {} validation files".format(
        len(preprocessed_impure_midi_data), len(train_impure), len(validation_impure)))
    print("Split {} pure MIDI files into {} training files and {} validation files".format(
        len(preprocessed_pure_midi_data), len(train_pure), len(validation_pure)))

    train_sequences, train_labels = _build_labelled_sequences(train_impure, train_pure)

    print("train_sequences shape: ", train_sequences.shape)
    print("train_labels shape: ", train_labels.shape)

    validation_sequences, validation_labels = _build_labelled_sequences(validation_impure, validation_pure)

    print("Created {} training sequences and {} validation sequences".format(
        len(train_sequences), len(validation_sequences)))

    # Per-sample shape (everything after the batch axis) feeds the model's input layer.
    input_shape = train_sequences.shape[1:]
    # Labels are 0/1, so this evaluates to 2 for the binary setup.
    num_classes = np.max(train_labels) + 1

    model, history = train_lstm_model(train_sequences, train_labels, validation_sequences,
                                      validation_labels, input_shape, num_classes)
    return model, history

In [None]:
def preprocess_data_for_prediction(midi_data, sequence_length=32):
    """Turn every message of ``midi_data`` into overlapping fixed-length windows for the model.

    Each message becomes a ``[note, time]`` pair. ``note`` is 0 for messages
    without a ``note`` attribute, such as meta or control messages. Tracks are
    concatenated in order, and every contiguous window of ``sequence_length``
    messages is returned, giving ``len(messages) - sequence_length + 1``
    windows.

    NOTE: this definition shadows the ``preprocess_data_for_prediction``
    imported from ``utils.predict.functions`` at the top of the notebook.

    Args:
        midi_data: Object exposing ``.tracks``, an iterable of message
            iterables (here a ``mido.MidiFile``).
        sequence_length: Number of consecutive messages per window. Must be >= 1.

    Returns:
        ``np.ndarray`` of shape ``(num_windows, sequence_length, 2)``.
        ``num_windows`` is 0 when there are fewer than ``sequence_length``
        messages.

    Raises:
        ValueError: if ``sequence_length`` < 1.
    """
    if sequence_length < 1:
        raise ValueError("sequence_length must be >= 1, got {}".format(sequence_length))

    midi_events = []
    for track in midi_data.tracks:
        for event in track:
            # Non-note messages keep their delta time but get a placeholder pitch of 0.
            # NOTE(review): 0 is also a valid MIDI pitch; confirm this placeholder is intended.
            note = event.note if hasattr(event, 'note') else 0
            midi_events.append([note, event.time])

    # Sliding window: n events yield n - k + 1 windows, the last one ending on the final event.
    sequences = [midi_events[i:i + sequence_length]
                 for i in range(len(midi_events) - sequence_length + 1)]

    # reshape keeps a 3-D shape even when no window fits (np.array([]) alone would be 1-D).
    return np.asarray(sequences).reshape(-1, sequence_length, 2)

def postprocess_predictions_to_midi(predictions, midi_data, threshold=0.000005, output_file_path="output.mid"):
    """Keep the note events whose score reaches ``threshold`` and save them as a new MIDI file.

    NOTE: this definition shadows the ``postprocess_predictions_to_midi``
    imported from ``utils.predict.functions`` at the top of the notebook.

    Args:
        predictions: One score per window. Score ``i`` decides whether the
            ``i``-th note event is kept. Shape ``(N, 1)`` (as produced by a
            single-unit ``model.predict``) or ``(N,)``; it is flattened here.
        midi_data: Source MIDI data (a ``mido.MidiFile``) exposing ``.tracks``.
        threshold: Minimum score for a note event to be kept.
        output_file_path: Where the resulting MIDI file is saved.

    Returns:
        The new single-track ``mido.MidiFile``, also written to
        ``output_file_path``.
    """
    # Reuse the source resolution. mido.MidiFile() otherwise defaults to 480
    # ticks per beat, and every delta time would be played at the wrong scale.
    mid = mido.MidiFile(ticks_per_beat=getattr(midi_data, 'ticks_per_beat', 480))
    out_track = mido.MidiTrack()
    mid.tracks.append(out_track)

    # NOTE(review): tracks are concatenated, not merged in time. Also, only note
    # events are counted here, while preprocess_data_for_prediction emits a
    # window slot for every message. Score i and note event i may therefore not
    # refer to the same moment; confirm the intended alignment.
    note_events = [msg for track in midi_data.tracks for msg in track if hasattr(msg, 'note')]
    scores = np.ravel(predictions)  # assumes one sigmoid score per window
    note_events = note_events[:len(scores)]

    carried_time = 0  # delta ticks of dropped events, owed to the next kept event
    for msg, score in zip(note_events, scores):
        if msg.time < 0:
            continue
        # `not >=` (rather than `<`) also drops NaN scores.
        if not score >= threshold:
            # Keep the timeline intact: the dropped event's delta still elapses.
            carried_time += msg.time
            continue
        # Copy so the source MidiFile's messages are not shared or mutated.
        out_track.append(msg.copy(time=msg.time + carried_time))
        carried_time = 0
        # A velocity-0 note_on already acts as a note-off, so it gets no extra note_off.
        # NOTE(review): the synthetic note_off reuses the note_on's delta time as the
        # note length, and adds that many ticks to the timeline; confirm intended.
        if msg.type == 'note_on' and msg.velocity > 0:
            out_track.append(mido.Message('note_off', note=msg.note, velocity=64, time=msg.time))

    mid.save(output_file_path)
    return mid


def filter_midi_data(midi_data, model, sequence_length=32):
    """Score windows of ``midi_data`` with ``model`` and rebuild a MIDI file from the kept notes.

    Args:
        midi_data: Parsed MIDI data (a ``mido.MidiFile``) exposing ``.tracks``.
        model: Trained Keras model. ``model.predict`` is called on the windows
            built by ``preprocess_data_for_prediction``.
        sequence_length: Window length passed to
            ``preprocess_data_for_prediction``. Presumably it must match the
            window length used during training; confirm.

    Returns:
        The ``mido.MidiFile`` produced by ``postprocess_predictions_to_midi``,
        which also saves it to that function's default output path.

    Raises:
        ValueError: if ``midi_data`` has too few messages to form one window.
    """
    input_notes = preprocess_data_for_prediction(midi_data, sequence_length=sequence_length)

    print("Created {} input notes".format(len(input_notes)))
    print("Input notes shape: ", input_notes.shape)

    # An empty batch would otherwise fail inside model.predict with an opaque shape error.
    if len(input_notes) == 0:
        raise ValueError(
            "MIDI data has too few messages to form a single window of {}".format(sequence_length))

    # One pure/impure score per window.
    note_predictions = model.predict(input_notes)

    print("Note predictions shape: ", note_predictions.shape)
    # Summarise rather than printing every score into the notebook output.
    print("Note predictions min/mean/max: {:.6g} / {:.6g} / {:.6g}".format(
        float(note_predictions.min()), float(note_predictions.mean()), float(note_predictions.max())))

    clean_midi_data = postprocess_predictions_to_midi(note_predictions, midi_data)

    return clean_midi_data


In [None]:
def convert_midi_to_musicxml(midi_data, midi_file_path, output_file_path):
    """Save ``midi_data`` as a MIDI file, then re-export it as MusicXML with music21.

    Args:
        midi_data: Object with a ``save(path)`` method (a ``mido.MidiFile`` here).
        midi_file_path: Where the intermediate MIDI file is written. music21
            parses the file back from this path.
        output_file_path: Destination of the MusicXML file.

    Returns:
        The value returned by music21's ``write``, i.e. the path written.
    """
    # Create destination folders (e.g. ./resultMIDI, ./resultXML) if needed;
    # saving into a missing directory raises FileNotFoundError.
    for path in (midi_file_path, output_file_path):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    midi_data.save(midi_file_path)

    music_rep = converter.parse(midi_file_path)

    return music_rep.write('musicxml', fp=output_file_path)

In [None]:
if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is "__main__", so this cell always trains.
    # `model` holds the (model, history) tuple; the filtering cell indexes model[0].
    model = train_machine_learning_model()
    print("Model trained")
    # One-off dataset preparation: uncomment to build the impure copy of the
    # MIDI corpus (saved under adl-piano-midi-impure) before training.
    # process_directory("./adl-piano-midi")

In [None]:
# Ask the user for the MIDI file to clean, using a hidden Tk root so only the file dialog appears.
root = tk.Tk()
root.withdraw()  # we don't want a full GUI, so keep the root window from appearing
midi_file_path = filedialog.askopenfilename(
    filetypes=[("MIDI files", "*.mid *.midi"), ("All files", "*.*")])
root.destroy()  # release the Tk root once the dialog has returned
if not midi_file_path:
    # The dialog returns an empty value when cancelled; fail clearly instead of inside mido.
    raise ValueError("No MIDI file selected")
midi_data = parse_midi_file(midi_file_path)
# Keep the preprocessed result so prediction sees the same preprocessing as the
# training data (which goes through preprocess_midi_data before training).
midi_data = preprocess_midi_data(midi_data)

In [None]:
print("Filtering MIDI data...")
trained_model = model[0]  # `model` is the (model, history) tuple from train_machine_learning_model
clean_midi = filter_midi_data(midi_data, trained_model)
print("Filtered MIDI data")
# Summarise instead of dumping the repr of every MIDI message into the notebook output.
print("Clean MIDI data: {} track(s), {} message(s)".format(
    len(clean_midi.tracks), sum(len(track) for track in clean_midi.tracks)))

In [None]:
print("Converting MIDI to MusicXML...")
# Name the outputs after the selected input file, e.g. song.mid -> ./resultMIDI/song.mid
# and ./resultXML/song.xml. A separate variable keeps `midi_file_path` pointing at the
# user's input file instead of overwriting it with the output path.
input_stem = os.path.splitext(os.path.basename(midi_file_path))[0]
result_midi_path = os.path.join("./resultMIDI", input_stem + ".mid")
output_file_path_XML = os.path.join("./resultXML", input_stem + ".xml")
# The output folders may not exist in a fresh checkout.
os.makedirs(os.path.dirname(result_midi_path), exist_ok=True)
os.makedirs(os.path.dirname(output_file_path_XML), exist_ok=True)
convert_midi_to_musicxml(clean_midi, result_midi_path, output_file_path_XML)
print("Exported MusicXML to ", output_file_path_XML)