In [None]:
import numpy as np
import scipy.io.wavfile as wavfile
from scipy.fftpack import fft
import matplotlib.pyplot as plt
import pretty_midi
import pywt
import keras
import tensorflow as tf
from numba import njit

In [None]:
# Decimal places used to round the running timestamp in NoteUtilities.get_note_roll, so that
# repeatedly adding the window duration does not accumulate float error (e.g. 0.30000000000000004).
ROUNDING_SPECIFICITY = 4

from enum import Enum

class BaseNote(Enum):
    """The twelve pitch classes, each valued by its octave-0 frequency in Hz.

    Note.get_frequency multiplies the value by ``2 ** octave`` to get the frequency in a given octave.
    """
    # Declared in ascending pitch order; NoteUtilities.enumerate_circle_of_fifths iterates this enum
    # and relies on the resulting frequencies being ascending.
    C = 16.35
    C_SHARP = 17.32
    D = 18.35
    D_SHARP = 19.45
    E = 20.60
    F = 21.83
    F_SHARP = 23.12
    G = 24.50
    G_SHARP = 25.96
    A = 27.50
    A_SHARP = 29.14
    B = 30.87
    
class MIDIPitch(Enum):
    """MIDI note numbers from C0 (12) to G9 (127).

    Member names follow ``<BaseNote name><octave>`` (e.g. ``C_SHARP4`` = 61), so Note.midi_pitch and
    Note.from_midi_pitch can convert by name. Note.from_midi_pitch reads the octave from the last
    character only, so every name must end in a single digit.
    """
    C0 = 12
    C_SHARP0 = 13
    D0 = 14
    D_SHARP0 = 15
    E0 = 16
    F0 = 17
    F_SHARP0 = 18
    G0 = 19
    G_SHARP0 = 20
    # Values below 21 (A0, the lowest piano key) are outside the 88-key piano range. They are kept
    # because this enum also fixes each note's position in the sparse velocity arrays
    # (index = value - C0.value). Those arrays must cover every octave-0 note that
    # NoteUtilities.get_note_from_frequency can produce.
    A0 = 21
    A_SHARP0 = 22
    B0 = 23
    C1 = 24
    C_SHARP1 = 25
    D1 = 26
    D_SHARP1 = 27
    E1 = 28
    F1 = 29
    F_SHARP1 = 30
    G1 = 31
    G_SHARP1 = 32
    A1 = 33
    A_SHARP1 = 34
    B1 = 35
    C2 = 36
    C_SHARP2 = 37
    D2 = 38
    D_SHARP2 = 39
    E2 = 40
    F2 = 41
    F_SHARP2 = 42
    G2 = 43
    G_SHARP2 = 44
    A2 = 45
    A_SHARP2 = 46
    B2 = 47
    C3 = 48
    C_SHARP3 = 49
    D3 = 50
    D_SHARP3 = 51
    E3 = 52
    F3 = 53
    F_SHARP3 = 54
    G3 = 55
    G_SHARP3 = 56
    A3 = 57
    A_SHARP3 = 58
    B3 = 59
    C4 = 60
    C_SHARP4 = 61
    D4 = 62
    D_SHARP4 = 63
    E4 = 64
    F4 = 65
    F_SHARP4 = 66
    G4 = 67
    G_SHARP4 = 68
    A4 = 69
    A_SHARP4 = 70
    B4 = 71
    C5 = 72
    C_SHARP5 = 73
    D5 = 74
    D_SHARP5 = 75
    E5 = 76
    F5 = 77
    F_SHARP5 = 78
    G5 = 79
    G_SHARP5 = 80
    A5 = 81
    A_SHARP5 = 82
    B5 = 83
    C6 = 84
    C_SHARP6 = 85
    D6 = 86
    D_SHARP6 = 87
    E6 = 88
    F6 = 89
    F_SHARP6 = 90
    G6 = 91
    G_SHARP6 = 92
    A6 = 93
    A_SHARP6 = 94
    B6 = 95
    C7 = 96
    C_SHARP7 = 97
    D7 = 98
    D_SHARP7 = 99
    E7 = 100
    F7 = 101
    F_SHARP7 = 102
    G7 = 103
    G_SHARP7 = 104
    A7 = 105
    A_SHARP7 = 106
    B7 = 107
    C8 = 108
    C_SHARP8 = 109
    D8 = 110
    D_SHARP8 = 111
    E8 = 112
    F8 = 113
    F_SHARP8 = 114
    G8 = 115
    G_SHARP8 = 116
    A8 = 117
    A_SHARP8 = 118
    B8 = 119
    C9 = 120
    C_SHARP9 = 121
    D9 = 122
    D_SHARP9 = 123
    E9 = 124
    F9 = 125
    F_SHARP9 = 126
    G9 = 127
    
    
    
class Note:
    """A pitch (base note + octave) carrying a velocity.

    Equality and hashing look only at the pitch, so two Notes that differ only in velocity
    collapse to a single element in a set or dict.
    """

    def __init__(self, note: BaseNote, octave: int, velocity: float = 1.0):
        self.note = note
        self.octave = octave
        self.velocity = velocity

    @staticmethod
    def from_midi_pitch(midi_pitch: MIDIPitch, velocity: float = 1.0):
        """Build a Note from a MIDIPitch member named ``<BaseNote name><single-digit octave>``."""
        *name_chars, octave_char = midi_pitch.name
        base_note = BaseNote["".join(name_chars)]
        return Note(base_note, int(octave_char), velocity)

    def midi_name(self):
        """Return the MIDIPitch member name for this note, e.g. ``"C_SHARP4"``."""
        return self.note.name + str(self.octave)

    def midi_pitch(self):
        """Return the integer MIDI note number for this note."""
        member = MIDIPitch[self.midi_name()]
        return member.value

    def __eq__(self, other):
        if not isinstance(other, Note):
            return False
        return self.note == other.note and self.octave == other.octave

    def __hash__(self):
        return hash((self.note.value, self.octave))

    def get_frequency(self):
        """Return the frequency in Hz: the base note's octave-0 frequency doubled once per octave."""
        return self.note.value * 2 ** self.octave


class NoteRollNote:
    """A Note placed on a timeline, with a start time and an end time.

    Equality and hashing delegate to the wrapped Note (pitch only). A set of NoteRollNotes
    therefore holds at most one entry per pitch, regardless of timing.
    """

    def __init__(self, note: Note, start_time: float, end_time: float = 0):
        self.note = note
        self.start_time = start_time
        self.end_time = end_time

    def __eq__(self, other):
        return isinstance(other, NoteRollNote) and self.note == other.note

    def __hash__(self):
        return hash(self.note)

    def generate_wave(self, sampling_frequency):
        """Return ``int(duration * sampling_frequency)`` samples of this note's tone, starting at t = 0."""
        duration = self.end_time - self.start_time
        sample_count = int(duration * sampling_frequency)
        timeline = np.linspace(0, duration, sample_count, endpoint=False)
        # Velocity scaling (`* self.note.velocity`) is currently disabled: every note is generated at unit amplitude.
        return NoteUtilities.harmonic(timeline, self.note.get_frequency(), 1, -np.pi / 2)
        

class NoteUtilities:
    """Namespace of static helpers that turn audio into notes, note rolls, waveforms and MIDI.

    Pipeline used in this notebook:
    audio -> get_note_frequency_window_set -> get_windowset_notes -> get_note_roll
    -> cleanup_note_roll -> create_wave / create_midi_notes.
    note_sets_to_sparse_velocity / sparse_velocity_to_note_sets convert note sets to and from
    fixed-size vectors for the model.
    """

    def __init__(self):
        self.notes = list(BaseNote)

    @staticmethod
    def enumerate_circle_of_fifths(start_octave=0, end_octave=8):
        """Yield ``(base_note, frequency_hz, octave)`` for every semitone in the octave range (inclusive).

        Despite the name, notes are yielded in chromatic order (BaseNote declaration order within
        each octave), so frequencies are strictly ascending. get_note_from_frequency relies on that.
        """
        for octave in range(start_octave, end_octave + 1):
            for note in BaseNote:
                yield (note, note.value * (2 ** octave), octave)

    @staticmethod
    def get_vocab_size(start_octave=0, end_octave=8):
        """Return the length of a sparse velocity vector: one slot per MIDIPitch member.

        The octave arguments are accepted for signature compatibility but are not used.
        """
        return len(MIDIPitch)

    # Cache for get_note_from_frequency, keyed by
    # (frequency rounded to 2 decimal places, start_octave, end_octave) -> (BaseNote, octave).
    # Plain tuples are cached instead of Note objects. Every caller then gets its own Note, so setting
    # `.velocity` on one result cannot leak into notes returned for other windows.
    frequency_note_map = {}

    @staticmethod
    def get_note_from_frequency(frequency, start_octave=0, end_octave=8):
        """Return a new Note (velocity 1.0) for the semitone closest to ``frequency`` (Hz).

        Only octaves ``start_octave..end_octave`` are searched. Frequencies outside that range clamp
        to its lowest or highest note. Returns None if the octave range is empty.
        """
        cache_key = (round(frequency, 2), start_octave, end_octave)
        cached = NoteUtilities.frequency_note_map.get(cache_key)
        if cached is None:
            min_difference = float('inf')
            for base_note, note_frequency, octave in NoteUtilities.enumerate_circle_of_fifths(start_octave, end_octave):
                difference = abs(frequency - note_frequency)
                if difference < min_difference:
                    min_difference = difference
                    cached = (base_note, octave)
                elif difference > min_difference:
                    # Candidates ascend in frequency, so once the distance grows it never shrinks again.
                    break
            if cached is None:
                return None
            NoteUtilities.frequency_note_map[cache_key] = cached

        base_note, octave = cached
        return Note(base_note, octave)

    @staticmethod
    def note_sets_to_sparse_velocity(note_sets: list[set[Note]]):
        """Encode note sets as a mostly-zero float array of shape ``(len(note_sets), vocab_size)``.

        Row ``i`` describes ``note_sets[i]``. Column ``midi_pitch - MIDIPitch.C0.value`` holds that
        note's velocity; absent notes are 0. Indexing a row gives that frame's 1-D vector.
        """
        vocab_size = NoteUtilities.get_vocab_size()
        # A real 2-D float array, rather than an object array of rows, so downstream code
        # (windowing, Keras) can consume it directly.
        sparse_array_set = np.zeros((len(note_sets), vocab_size))
        for idx, note_set in enumerate(note_sets):
            for note in note_set:
                sparse_array_set[idx, note.midi_pitch() - MIDIPitch.C0.value] = note.velocity
        return sparse_array_set

    @staticmethod
    def sparse_velocity_to_note_sets(sparse_array_set):
        """Decode velocity vectors back into sets of Notes (inverse of note_sets_to_sparse_velocity).

        Only entries with velocity > 0 become Notes.
        """
        note_sets: list[set[Note]] = []
        for sparse_array in sparse_array_set:
            velocities = np.asarray(sparse_array)
            note_set: set[Note] = set()
            for index in np.flatnonzero(velocities > 0):
                midi_pitch = MIDIPitch(int(index) + MIDIPitch.C0.value)
                note_set.add(Note.from_midi_pitch(midi_pitch, velocities[index]))
            note_sets.append(note_set)
        return note_sets

    @staticmethod
    def get_note_frequency_window_set(
        data: np.ndarray,
        sample_duration: float,
        sampling_frequency: int,
        threshold_intensity: float = 0.1,
    ):
        """Split ``data`` into consecutive windows and return the prominent frequencies of each.

        Each window spans ``sample_duration * sampling_frequency`` samples (at least 1); the last one
        may be shorter. For each window the result holds a ``(frequencies, magnitudes)`` pair of
        arrays. It keeps only positive-frequency FFT bins whose magnitude exceeds
        ``threshold_intensity`` times that window's peak magnitude.
        """
        windows = []
        samples_per_duration = max(int(sample_duration * sampling_frequency), 1)

        for start in range(0, len(data), samples_per_duration):
            window = data[start : start + samples_per_duration]
            magnitude = np.abs(fft(window))
            frequencies = np.fft.fftfreq(len(window), 1 / sampling_frequency)

            min_magnitude = threshold_intensity * np.max(magnitude)
            keep = (magnitude > min_magnitude) & (frequencies > 0)

            windows.append((frequencies[keep], magnitude[keep]))

        return windows

    @staticmethod
    def get_windowset_notes(noteset):
        """Convert ``(frequencies, magnitudes)`` windows into one set of Notes per window.

        Each frequency snaps to its nearest semitone, and the bin magnitude becomes the velocity. When
        several bins in a window snap to the same pitch, the strongest magnitude is kept.
        """
        note_sets: list[set[Note]] = []
        for frequency_list, magnitude_list in noteset:
            strongest = {}
            for frequency, magnitude in zip(frequency_list, magnitude_list):
                nearest = NoteUtilities.get_note_from_frequency(frequency)
                current = strongest.get(nearest)
                if current is None or magnitude > current.velocity:
                    # Fresh Note per window, so velocities never alias across windows.
                    strongest[nearest] = Note(nearest.note, nearest.octave, magnitude)
            note_sets.append(set(strongest.values()))

        return note_sets

    @staticmethod
    def get_note_roll(note_sets, window_frame_duration, max_notes: int = 3):
        """Merge per-frame note sets into timed notes.

        Frame ``i`` starts at ``i * window_frame_duration``. A note starts in the first frame it appears
        in and ends at the start of the first later frame that lacks it. Notes still sounding after the
        last frame end when that frame ends. At most ``max_notes`` notes sound at once; notes already
        playing take priority, and new notes are admitted in descending velocity order. Returns the
        NoteRollNotes sorted by ``(start_time, end_time)``.
        """
        note_roll: list[NoteRollNote] = []
        pending_notes: set[NoteRollNote] = set()  # equality/hash is by pitch only

        time = 0
        for note_set in note_sets:
            # End every sounding note that is absent from the current frame.
            ended = {roll_note for roll_note in pending_notes if roll_note.note not in note_set}
            for roll_note in ended:
                roll_note.end_time = time
                note_roll.append(roll_note)
            pending_notes -= ended

            # Start new notes, loudest first, while there is room.
            for note in sorted(note_set, key=lambda n: -1 * n.velocity):
                roll_note = NoteRollNote(note, time)
                if roll_note not in pending_notes and len(pending_notes) < max_notes:
                    pending_notes.add(roll_note)

            time = round(time + window_frame_duration, ROUNDING_SPECIFICITY)

        # Flush notes that were still sounding when the last frame ended.
        for roll_note in pending_notes:
            roll_note.end_time = time
            note_roll.append(roll_note)

        return sorted(note_roll, key=lambda n: (n.start_time, n.end_time))

    @staticmethod
    def cleanup_note_roll(note_roll, min_duration=0.01):
        """Return the notes whose duration is strictly greater than ``min_duration``."""
        return [note for note in note_roll if note.end_time - note.start_time > min_duration]

    @staticmethod
    def create_wave(note_roll, sampling_frequency):
        """Mix every note's tone into one waveform and normalize its peak absolute value to 1.

        The output has ``int(latest end_time * sampling_frequency)`` samples. Returns an empty array
        for an empty roll. An all-zero mix is returned unscaled.
        """
        if not note_roll:
            return np.zeros(0)

        length = int(max(note.end_time for note in note_roll) * sampling_frequency)
        output = np.zeros(length)

        for note in note_roll:
            wave = note.generate_wave(sampling_frequency)
            start_index = int(note.start_time * sampling_frequency)
            segment = output[start_index : start_index + len(wave)]
            # Float rounding can make the last note one sample longer than the buffer; trim it
            # rather than fail to broadcast. `segment` is a view, so this writes into `output`.
            segment += wave[: len(segment)]

        peak = np.max(np.abs(output)) if length else 0.0
        return output / peak if peak > 0 else output

    @staticmethod
    def create_midi_notes(note_roll):
        """Convert a note roll into an object array of ``pretty_midi.Note``.

        Velocities are scaled so the loudest note maps to 127, then rounded and clamped to 1..127.
        MIDI velocities must be integers in 0..127, and raw FFT magnitudes are arbitrarily large.
        """
        output = np.empty(len(note_roll), dtype=object)
        if len(note_roll) == 0:
            return output

        max_velocity = max(float(roll_note.note.velocity) for roll_note in note_roll)
        scale = 127.0 / max_velocity if max_velocity > 0 else 0.0

        for i, roll_note in enumerate(note_roll):
            midi_velocity = min(127, max(1, int(round(float(roll_note.note.velocity) * scale))))
            output[i] = pretty_midi.Note(
                velocity=midi_velocity,
                pitch=roll_note.note.midi_pitch(),
                start=roll_note.start_time,
                end=roll_note.end_time,
            )

        return output

    @staticmethod
    def harmonic(t, f1=1, alist=1, philist=0):
        """Sum cosine partials at integer multiples of ``f1``: ``sum_k a_k * cos(2*pi*k*f1*(t - phi_k))``.

        ``alist`` holds one amplitude per partial; a scalar means a single partial. ``philist`` is
        either one value shared by every partial or one value per partial. It is forwarded to
        cosinewave as a time ``delay`` (in the units of ``t``), not as a phase angle in radians.
        """
        # If alist is a scalar, wrap it in a list so there is exactly one partial.
        if np.isscalar(alist):
            alist = [alist]

        common_phase = philist if np.isscalar(philist) else None

        frequency_values = []
        if common_phase is not None:
            for i, a in enumerate(alist):
                frequency_values.append(a * NoteUtilities.cosinewave(t, (i + 1)*f1, common_phase))
        else:
            for i, (a, phi) in enumerate(zip(alist, philist)):
                frequency_values.append(a * NoteUtilities.cosinewave(t, (i + 1)*f1, phi))

        return np.sum(frequency_values, axis=0)

    @staticmethod
    def cosinewave(time , frequency = 1.0, delay=0.0):
        """Return ``cos(2*pi*frequency*(time - delay))``."""
        return np.cos(2 * np.pi * frequency * (time - delay))



In [None]:
# Before I do this on song sections, let's make sure it works on a sine wave!
# A pure 5 kHz tone, analyzed in 1-second windows.
fs = 44_100
duration = 5
sine_frequency = 5000

n_samples = int(fs * duration)
x = np.linspace(0, duration, n_samples, endpoint=False)
y = np.sin(2 * np.pi * sine_frequency * x)

sinewave_noteset = NoteUtilities.get_note_frequency_window_set(
    y,
    sample_duration=1,
    sampling_frequency=fs,
)
print(sinewave_noteset)

In [None]:
rate, song_data = wavfile.read("./data/LZSTH (1).wav")

# wavfile.read returns a 1-D array for mono files and (n_samples, n_channels) otherwise.
# Average all channels down to mono (checking shape[1] would raise IndexError on mono input).
mono_data = song_data.mean(axis=1) if song_data.ndim > 1 else song_data

time = 0.4  # seconds of audio to plot
sample_at_time = rate * time
plot_fs = 10000  # number of points plotted per second of audio
n_samples = int(time * plot_fs)

x = np.linspace(0, sample_at_time, n_samples).astype(np.int32)
y = mono_data[x]

plt.plot(x, y)
plt.xlabel("Sample index")
plt.ylabel("Amplitude")
plt.title(f"First {time} s of the song (mono)")
plt.show()

In [None]:
# Window length, in whatever time unit `rate` uses (seconds when using the rate reported by scipy.io.wavfile.read).
window_duration = 0.1
# Minimum magnitude, as a fraction of each window's peak, for a frequency to count as a note.
# Dropping this to 0.01 sounds like an awful accordion; 0.95 sounds very blocky.
intensity_threshold = 0.2

# Raising either value makes later steps faster: longer windows mean fewer note sets,
# and a higher threshold keeps fewer frequencies per window.
data_note_sets = NoteUtilities.get_note_frequency_window_set(
    mono_data,
    sample_duration=window_duration,
    sampling_frequency=rate,
    threshold_intensity=intensity_threshold,
)

print(data_note_sets[0])

This works nicely, but the detected frequencies are raw FFT bins. Many of them sit very close together, and none of them correspond directly to a musical note. Let's use the note classes defined at the top of the notebook to turn these arrays into musical notes by snapping each frequency to its nearest note!

Let's test out our utilities on some sample frequencies!

In [None]:
close_to_c_octave_4 = NoteUtilities.get_note_from_frequency(261)
closer_to_c_sharp_octave_4 = NoteUtilities.get_note_from_frequency(270)
closer_to_a_sharp_octave_2 = NoteUtilities.get_note_from_frequency(116)

# Print each result with its OWN octave. Previously the C# line printed the C4 note's octave.
for frequency, found in [
    (261, close_to_c_octave_4),
    (270, closer_to_c_sharp_octave_4),
    (116, closer_to_a_sharp_octave_2),
]:
    print(f"{frequency} Hz ->", found.note.name, found.octave)

# C4 ≈ 261.6 Hz, C#4 ≈ 277.1 Hz, A#2 ≈ 116.6 Hz.
assert (close_to_c_octave_4.note, close_to_c_octave_4.octave) == (BaseNote.C, 4)
assert (closer_to_c_sharp_octave_4.note, closer_to_c_sharp_octave_4.octave) == (BaseNote.C_SHARP, 4)
assert (closer_to_a_sharp_octave_2.note, closer_to_a_sharp_octave_2.octave) == (BaseNote.A_SHARP, 2)


In [None]:
print(sinewave_noteset)

# Snap each detected frequency of the test tone to its nearest note, window by window.
note_sets = NoteUtilities.get_windowset_notes(sinewave_noteset)

for window_notes in note_sets:
    for note in window_notes:
        print(note.note.name, note.octave)

Now, let's try this on the actual song sample!

In [None]:

note_sets = NoteUtilities.get_windowset_notes(data_note_sets)

# Preview the notes detected in the first 10 windows of the song.
for set_index, window_notes in enumerate(note_sets[:10]):
    print(f"Set {set_index}:")
    for note in window_notes:
        print(note.note.name, note.octave)

One problem here is we cannot estimate the note duration, but it would seem strange to replay the note over and over again just to keep the note going. So, we need that note to continue playing while it is in subsequent note sets. Ultimately, we need to be able to generate a set of notes, their start time, end times, and velocities.

Each note set's interval is determined by the length of time the window represents. Each note's duration is determined by how many consecutive frames it appears in, and its velocity is already known. From this, we can create our list!

In [None]:
# Maximum number of notes sounding simultaneously; notes that are already playing take priority over new ones.
max_notes = 3

note_roll = NoteUtilities.get_note_roll(note_sets, window_duration, max_notes=max_notes)
clean_note_roll = NoteUtilities.cleanup_note_roll(note_roll, min_duration=0.001)

for roll_note in clean_note_roll[:10]:
    inner = roll_note.note
    print(roll_note.start_time, roll_note.end_time, inner.note.name, inner.octave, inner.velocity)

In [None]:
# Overlay the synthesized waveforms of the first two notes in the roll, each on its own time axis starting at 0.
for roll_note in clean_note_roll[:2]:
    note_wave = roll_note.generate_wave(rate)
    duration = roll_note.end_time - roll_note.start_time
    # Size the time axis from the wave itself so the lengths always match.
    t = np.linspace(0, duration, len(note_wave), endpoint=False)
    plt.plot(t, note_wave, label=f"{roll_note.note.midi_name()} (starts at {roll_note.start_time} s)")

plt.xlabel("Time since note start (s)")
plt.ylabel("Amplitude")
plt.title("First two notes of the cleaned note roll")
plt.legend()
plt.show()

In [None]:
wave = NoteUtilities.create_wave(clean_note_roll, rate)

duration = max(note.end_time for note in clean_note_roll)

# Derive the time axis from the wave itself so both always have the same length.
# Recomputing the length from `duration` can disagree by a sample due to float rounding.
x = np.arange(len(wave)) / rate

plt.plot(x, wave)
plt.xlabel("Time (s)")
plt.ylabel("Normalized amplitude")
plt.title(f"Synthesized note roll ({duration:.1f} s)")
plt.show()

print(len(wave))

Let's listen and see how good the approximation in note-form is!

In [None]:
from IPython.display import Audio, display

# `Audio(...)` only renders as a cell's last expression; display it explicitly since more code follows.
display(Audio(wave, rate=rate))

# Convert to 16-bit PCM. Scale the absolute peak to full scale, clip for safety, then cast.
# The previous `wave * (2**16 - 1) - 2**15` mapping overflowed int16 for any negative sample.
peak = np.max(np.abs(wave))
normalized = wave / peak if peak > 0 else wave
saveable = (np.clip(normalized, -1.0, 1.0) * np.iinfo(np.int16).max).astype(np.int16)

print(np.max(saveable), np.min(saveable))

# Only the first tenth of the rendered audio is written (presumably to keep the file small).
wavfile.write("output.wav", rate, saveable[: len(saveable) // 10])

While generating the tones manually works, there appears to be a lot of noise. Let's try again using pretty_midi to generate the sound!

In [None]:
midi_notes = NoteUtilities.create_midi_notes(clean_note_roll)

cello_program = pretty_midi.instrument_name_to_program('Cello')
cello = pretty_midi.Instrument(program=cello_program)
cello.notes.extend(midi_notes)

midi = pretty_midi.PrettyMIDI()
midi.instruments.append(cello)

# NOTE(review): PrettyMIDI.synthesize renders with a simple waveshape (np.sin by default), so the Cello
# program likely has no audible effect here; PrettyMIDI.fluidsynth would use it. Confirm if timbre matters.
wav = midi.synthesize(fs=rate)
Audio(wav, rate=rate)

Either way, the generated signal is similarly noisy. This is likely due to clashing notes in the song's generation. Let's try reducing the noise using a few common techniques: moving averages and wavelets.

In [None]:

def moving_average(signal, window_size):
    """Return the simple moving average of `signal` over `window_size` samples ("valid" mode).

    The result has ``len(signal) - window_size + 1`` elements. It is computed in O(n) from a running
    cumulative sum with a leading 0, so each window sum is a difference of two prefix sums.
    """
    running_total = np.cumsum(np.insert(signal, 0, 0))
    window_sums = running_total[window_size:] - running_total[:-window_size]
    return window_sums / window_size

# A moving average acts as a simple low-pass filter; the window is measured in samples.
ma_window_size = 20
ma_smoothed_signal = moving_average(wave, window_size=ma_window_size)

print("Moving average smoothed signal:")
Audio(ma_smoothed_signal, rate=rate)


In [None]:

def wavelet_denoising(signal, wavelet, level):
    """Denoise `signal` with wavelet soft-thresholding (VisuShrink universal threshold).

    The noise level ``sigma`` is estimated from the finest detail coefficients via the median
    absolute deviation (``median(|d|) / 0.6745``). Every detail level is soft-thresholded at
    ``sigma * sqrt(2 ln n)``. The approximation coefficients are left untouched; thresholding them
    would strip the signal's low-frequency content, which is not noise. The output is trimmed to
    ``len(signal)``, because ``waverec`` can return one extra sample for odd-length inputs.
    """
    coeffs = pywt.wavedec(signal, wavelet, level=level)
    approximation, details = coeffs[0], coeffs[1:]

    sigma = np.median(np.abs(coeffs[-1])) / 0.6745
    uthresh = sigma * np.sqrt(2 * np.log(len(signal)))

    denoised_details = [pywt.threshold(c, value=uthresh, mode='soft') for c in details]
    reconstructed = pywt.waverec([approximation, *denoised_details], wavelet)
    return reconstructed[: len(signal)]


# Daubechies-4 wavelet with a 4-level decomposition.
wavelet = 'db4'
level = 4
wl_denoised_signal = wavelet_denoising(wave, wavelet=wavelet, level=level)

print("Wavelet Denoised Signal:")
Audio(wl_denoised_signal, rate=rate)


Well, the moving average sounds less noisy and almost decent! Most importantly, the key musical characteristics are still present, and this is something our network should be able to learn from!
Let's transform the note sets into a form whose size is consistent regardless of the number of notes present.

In [None]:
sparse_velocity_set = NoteUtilities.note_sets_to_sparse_velocity(note_sets)

# Inspect one frame's fixed-size velocity vector (frame 100 is an arbitrary example).
example_frame = sparse_velocity_set[100]
print(example_frame)


We also need the ability to transform a velocity set back into note sets once we have predicted new velocity sets. 

In [None]:
note_sets_reconstructed = NoteUtilities.sparse_velocity_to_note_sets(sparse_velocity_set)

# Compare whole sets. Note equality/hash use pitch only (base note + octave), so this checks that every
# pitch survived the round trip. Zipping two sets element-wise would pair up arbitrary elements
# (set order is not meaningful) and silently ignore differing sizes.
frame_count_matches = len(note_sets) == len(note_sets_reconstructed)
mismatched_frames = [
    i
    for i, (original_set, recon_set) in enumerate(zip(note_sets, note_sets_reconstructed))
    if original_set != recon_set
]

if not frame_count_matches:
    print(f"Frame count mismatch: {len(note_sets)} original vs {len(note_sets_reconstructed)} reconstructed")
if mismatched_frames:
    print(f"Mismatch detected in {len(mismatched_frames)} frame(s), first at index {mismatched_frames[0]}")
if frame_count_matches and not mismatched_frames:
    print("No mismatches detected!")

for note in note_sets[100]:
    print(note.note.name, note.octave, note.velocity)

for note in note_sets_reconstructed[100]:
    print(note.note.name, note.octave, note.velocity)
    


Now that we have sets of consistent size, we can break them down into sequences. Each window of `sequence_length` consecutive note sets is an input, and the note set that immediately follows it is the target output. Sliding the window forward by one frame gives the next input/output pair.

In [None]:

sequence_length = 100


# @njit
def create_sequence_windows(note_sets, sequence_length: int):
    """Return every run of ``sequence_length + 1`` consecutive frames as a stacked array.

    `note_sets` is a sequence of per-frame velocity vectors (e.g. the output of
    NoteUtilities.note_sets_to_sparse_velocity). The result has shape
    ``(len(note_sets) - sequence_length, sequence_length + 1, vocab_size)``. Window ``i`` covers frames
    ``i .. i + sequence_length``. Returns an empty array when there are not enough frames.
    """
    rows = np.array([np.array(list(s)) for s in note_sets])
    n_windows = len(rows) - sequence_length
    if n_windows <= 0:
        return np.array([])

    windows = np.lib.stride_tricks.sliding_window_view(rows, sequence_length + 1, axis=0)
    # sliding_window_view appends the window axis last: (n_windows, vocab, L + 1) -> (n_windows, L + 1, vocab).
    # Copy out of the strided view so callers get an ordinary, independent array.
    return np.ascontiguousarray(np.moveaxis(windows, -1, 1))


# @njit
def create_training_sets(sequence_windows):
    """Split stacked windows into model inputs and targets.

    For windows shaped ``(n, L + 1, vocab)``, returns ``X`` of shape ``(n, L, vocab)`` (all but the
    last frame) and ``y`` of shape ``(n, vocab)`` (the last frame, i.e. the frame to predict).
    Empty input yields two empty arrays.
    """
    windows = np.asarray(sequence_windows)
    if len(windows) == 0:
        return np.array([]), np.array([])
    return np.array(windows[:, :-1]), np.array(windows[:, -1])

# Build (input window, next frame) training pairs from the song's velocity vectors.
sparse_velo = NoteUtilities.note_sets_to_sparse_velocity(note_sets)
sequence_data = create_sequence_windows(sparse_velo, sequence_length)
X, y = create_training_sets(sequence_data)

print(X.shape, y.shape)
first_input_frame, first_target = X[0][0], y[0]
print(first_input_frame, first_target)

Now, let's shape a model!

In [None]:
def create_model():
    """Build the (uncompiled) next-frame prediction model.

    Input: a window of shape ``(sequence_length, vocab_size)``, where ``sequence_length`` is read from
    the notebook-level global. Output: ``vocab_size`` non-negative values (relu), i.e. the predicted
    velocity of every pitch in the next frame.
    """
    vocab_size = NoteUtilities.get_vocab_size()
    input_shape = (sequence_length, vocab_size)

    input_layer = keras.layers.Input(input_shape)
    # NOTE(review): relu inside an LSTM is unusual (Keras defaults to tanh) and leaves activations unbounded — confirm this is intended.
    lstm_layer = keras.layers.LSTM(vocab_size, activation="relu")(input_layer)
    # Reinterpret the vocab_size-long summary vector as a vocab_size-step sequence of scalars,
    # so the second LSTM scans across pitches rather than across time.
    lstm_layer = keras.layers.Reshape([vocab_size, 1])(lstm_layer)
    lstm_layer = keras.layers.LSTM(vocab_size)(lstm_layer)
    
    
    # ReLU suits the output layer: we want the vast majority of notes to have a velocity of 0 and be inactive in the output.
    output = keras.layers.Dense(vocab_size, activation="relu")(lstm_layer)

    model = keras.Model(input_layer, output)

    return model

model = create_model()
model.summary()

In [None]:
# Adam with a fixed initial learning rate.
learning_rate = 0.005
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

def mse_with_positive_pressure(y_true: tf.Tensor, y_pred: tf.Tensor):
    """Mean of (squared error + 10 * max(-y_pred, 0)).

    The second term penalizes only negative predictions, pushing predicted velocities to stay >= 0.
    """
    squared_error = tf.square(y_true - y_pred)
    negativity_penalty = 10 * tf.maximum(-y_pred, 0.0)
    combined = squared_error + negativity_penalty
    return tf.reduce_mean(combined)


loss = mse_with_positive_pressure

# Sanity-check the loss on real data by treating target y[0] as a "prediction" of target y[1].
print(y[0].shape)
y_true_test, y_pred_test = y[1], y[0]
print(mse_with_positive_pressure(y_true_test, y_pred_test))


In [None]:
model.compile(optimizer=optimizer, loss=loss)

# Smoke test: one epoch on the first 10 windows, just to make sure the model can fit at all.
model.fit(X[:10], y[:10], epochs=1, batch_size=1, validation_split=0.2)

print(len(X[10]), len(y[10]))

y_pred_test_out = model.predict(X[9:10])
y_true_test, y_pred_test = y[9], tf.squeeze(y_pred_test_out)
print(f"Loss: {loss(y_true_test, y_pred_test)}")


In [None]:
# Callbacks for the full training run:
# - ReduceLROnPlateau: lowers the learning rate when the monitored metric stops improving (Keras default monitor: "val_loss").
# - ModelCheckpoint: saves weights under checkpoints/audio_parsing/, one file per epoch.
#   NOTE(review): Keras 3 requires the filepath to end in ".weights.h5" when save_weights_only=True — confirm the Keras version in use.
# - EarlyStopping: stops after 5 epochs without improvement in training loss and restores the best weights.
callbacks = [
    keras.callbacks.ReduceLROnPlateau(),
    keras.callbacks.ModelCheckpoint("checkpoints/audio_parsing/checkpoint_{epoch}", save_weights_only=True),
    keras.callbacks.EarlyStopping(monitor="loss", patience=5, verbose=1, restore_best_weights=True)
]

epochs = 100

# Full training run; disabled. Uncomment to train on all windows.
# model.fit(X, y, epochs=epochs, batch_size=100, callbacks=callbacks, validation_split=0.2)
