In [None]:
import os
import tensorflow as tf
import keras
import numpy as np
import scipy.io.wavfile as wavfile
from scipy.fftpack import fft
import pretty_midi
import keras
import re


# Number of decimal places used when rounding the accumulated window timestamps
# in NoteUtilities.get_note_roll (keeps repeated float additions from drifting).
ROUNDING_SPECIFICITY = 4

from enum import Enum

class BaseNote(Enum):
    """The twelve pitch classes, each valued at its octave-0 frequency in Hz.

    Note.get_frequency scales a value by ``2 ** octave``. Members are listed in
    ascending pitch order; NoteUtilities' nearest-note search relies on that
    order for its early exit. Member names must match the prefixes of the
    MIDIPitch member names (Note.from_midi_pitch looks them up by name).
    """
    C = 16.35
    C_SHARP = 17.32
    D = 18.35
    D_SHARP = 19.45
    E = 20.60
    F = 21.83
    F_SHARP = 23.12
    G = 24.50
    G_SHARP = 25.96
    A = 27.50
    A_SHARP = 29.14
    B = 30.87
    
class MIDIPitch(Enum):
    """MIDI note number for every (BaseNote, octave) pair from C0 (12) to G9 (127).

    Member names are ``<BaseNote name><octave>`` (see Note.midi_name). The value
    minus ``MIDIPitch.C0.value`` is also used as the column index into sparse
    velocity arrays, so ``len(MIDIPitch)`` is the model's vocabulary size.
    """
    C0 = 12
    C_SHARP0 = 13   
    D0 = 14
    D_SHARP0 = 15
    E0 = 16
    F0 = 17
    F_SHARP0 = 18
    G0 = 19
    G_SHARP0 = 20
    # Values below 21 (A0) are kept because this enum also determines each note's position in sparse arrays,
    # which should cover more than the playable range.
    # NOTE(review): MIDI note numbers actually run 0-127; 21 (A0) is the bottom of the 88-key piano range,
    # so "not in the MIDI set" likely means "below the piano range" — confirm.
    A0 = 21
    A_SHARP0 = 22
    B0 = 23
    C1 = 24
    C_SHARP1 = 25
    D1 = 26
    D_SHARP1 = 27
    E1 = 28
    F1 = 29
    F_SHARP1 = 30
    G1 = 31
    G_SHARP1 = 32
    A1 = 33
    A_SHARP1 = 34
    B1 = 35
    C2 = 36
    C_SHARP2 = 37
    D2 = 38
    D_SHARP2 = 39
    E2 = 40
    F2 = 41
    F_SHARP2 = 42
    G2 = 43
    G_SHARP2 = 44
    A2 = 45
    A_SHARP2 = 46
    B2 = 47
    C3 = 48
    C_SHARP3 = 49
    D3 = 50
    D_SHARP3 = 51
    E3 = 52
    F3 = 53
    F_SHARP3 = 54
    G3 = 55
    G_SHARP3 = 56
    A3 = 57
    A_SHARP3 = 58
    B3 = 59
    C4 = 60
    C_SHARP4 = 61
    D4 = 62
    D_SHARP4 = 63
    E4 = 64
    F4 = 65
    F_SHARP4 = 66
    G4 = 67
    G_SHARP4 = 68
    A4 = 69
    A_SHARP4 = 70
    B4 = 71
    C5 = 72
    C_SHARP5 = 73
    D5 = 74
    D_SHARP5 = 75
    E5 = 76
    F5 = 77
    F_SHARP5 = 78
    G5 = 79
    G_SHARP5 = 80
    A5 = 81
    A_SHARP5 = 82
    B5 = 83
    C6 = 84
    C_SHARP6 = 85
    D6 = 86
    D_SHARP6 = 87
    E6 = 88
    F6 = 89
    F_SHARP6 = 90
    G6 = 91
    G_SHARP6 = 92
    A6 = 93
    A_SHARP6 = 94
    B6 = 95
    C7 = 96
    C_SHARP7 = 97
    D7 = 98
    D_SHARP7 = 99
    E7 = 100
    F7 = 101
    F_SHARP7 = 102
    G7 = 103
    G_SHARP7 = 104
    A7 = 105
    A_SHARP7 = 106
    B7 = 107
    C8 = 108
    C_SHARP8 = 109
    D8 = 110
    D_SHARP8 = 111
    E8 = 112
    F8 = 113
    F_SHARP8 = 114
    G8 = 115
    G_SHARP8 = 116
    A8 = 117
    A_SHARP8 = 118
    B8 = 119
    C9 = 120
    C_SHARP9 = 121
    D9 = 122
    D_SHARP9 = 123
    E9 = 124
    F9 = 125
    F_SHARP9 = 126
    G9 = 127
    
    
    
class Note:
    """A pitch (BaseNote + octave) with an attached velocity.

    Equality and hashing consider only the pitch, never the velocity, so a
    ``set[Note]`` holds at most one entry per pitch.
    """

    def __init__(self, note: BaseNote, octave: int, velocity: float = 1.0):
        self.note = note
        self.octave = octave
        self.velocity = velocity

    @classmethod
    def from_midi_pitch(cls, midi_pitch: MIDIPitch, velocity: float = 1.0):
        """Build a Note from a MIDIPitch member such as ``MIDIPitch.C_SHARP4``.

        MIDIPitch names are ``<BaseNote name><octave>``: the trailing digits are
        the octave and everything before them is the BaseNote member name.
        """
        name = midi_pitch.name
        note_name = name.rstrip("0123456789")
        return cls(BaseNote[note_name], int(name[len(note_name):]), velocity)

    def midi_name(self):
        """Return the MIDIPitch member name for this pitch, e.g. ``"C_SHARP4"``."""
        return f"{self.note.name}{self.octave}"

    def midi_pitch(self):
        """Return the integer MIDI pitch; raises KeyError if the pitch is not in MIDIPitch."""
        return MIDIPitch[self.midi_name()].value

    def __eq__(self, other):
        # Returning NotImplemented lets Python try the reflected comparison
        # before falling back to identity, per the data-model protocol.
        if not isinstance(other, Note):
            return NotImplemented
        return self.note == other.note and self.octave == other.octave

    def __hash__(self):
        # Must agree with __eq__: velocity is deliberately excluded.
        return hash((self.note.value, self.octave))

    def __repr__(self):
        return (
            f"{type(self).__name__}(BaseNote.{self.note.name}, "
            f"{self.octave}, velocity={self.velocity!r})"
        )

    def get_frequency(self):
        """Return the frequency: the BaseNote's octave-0 value doubled once per octave."""
        return self.note.value * (2 ** self.octave)


class NoteRollNote:
    """A Note placed on a timeline from ``start_time`` to ``end_time``.

    Times are in seconds: generate_wave multiplies durations by a sampling
    frequency to obtain sample counts. Equality and hashing delegate to the
    wrapped Note, so two roll notes with the same pitch compare equal
    regardless of timing.
    """

    def __init__(self, note: Note, start_time: float, end_time: float = 0):
        self.note = note
        self.start_time = start_time
        self.end_time = end_time

    def __eq__(self, other):
        if not isinstance(other, NoteRollNote):
            return NotImplemented
        return self.note == other.note

    def __hash__(self):
        return hash(self.note)

    def generate_wave(self, sampling_frequency):
        """Synthesize this note as a single-harmonic tone via NoteUtilities.harmonic.

        Returns ``int((end_time - start_time) * sampling_frequency)`` samples.
        A note whose end_time precedes its start_time (e.g. the default
        end_time of 0) yields an empty array instead of raising.
        """
        duration = max(self.end_time - self.start_time, 0)
        t = np.linspace(0, duration, int(duration * sampling_frequency), endpoint=False)

        # Velocity is intentionally not applied: scaling by it tended to drown out quieter notes.
        return NoteUtilities.harmonic(t, self.note.get_frequency(), 1, -np.pi / 2)
        

class NoteUtilities:
    """Helpers that move data through the pipeline.

    The stages are: audio -> FFT windows -> note sets -> sparse velocity rows,
    and back: note sets -> note roll -> synthesized wave / MIDI notes.

    Every helper is a static method and is intended to be called on the class.
    """

    def __init__(self):
        self.notes = list(BaseNote)

    @staticmethod
    def enumerate_circle_of_fifths(start_octave=0, end_octave=8):
        """Yield ``(BaseNote, frequency, octave)`` for every note in the inclusive octave range.

        Despite the name, notes are produced in chromatic, ascending-frequency
        order (C, C#, D, ... per octave), not circle-of-fifths order.
        get_note_from_frequency depends on that ordering for its early exit.
        """
        for octave in range(start_octave, end_octave + 1):
            for note in BaseNote:
                yield (note, note.value * (2 ** octave), octave)

    @staticmethod
    def get_vocab_size(start_octave=0, end_octave=8):
        """Return the width of a sparse velocity row: one slot per MIDIPitch member.

        The octave arguments are accepted for compatibility but ignored.
        """
        return len(MIDIPitch)

    # Memo for get_note_from_frequency. The key is
    # (frequency rounded to 2 decimals, start_octave, end_octave). Each value is
    # a (BaseNote, octave) tuple, or None, rather than a Note object, so every
    # call returns a fresh Note and callers can set its velocity without
    # corrupting notes handed out earlier.
    frequency_note_map = {}

    @staticmethod
    def get_note_from_frequency(frequency, start_octave=0, end_octave=8):
        """Return a new Note whose frequency is nearest to *frequency*.

        Frequencies outside the octave range snap to its lowest/highest note.
        Returns None if the octave range is empty.
        """
        key = (round(frequency, 2), start_octave, end_octave)
        if key not in NoteUtilities.frequency_note_map:
            NoteUtilities.frequency_note_map[key] = NoteUtilities._nearest_pitch(
                frequency, start_octave, end_octave
            )
        pitch = NoteUtilities.frequency_note_map[key]
        return None if pitch is None else Note(*pitch)

    @staticmethod
    def _nearest_pitch(frequency, start_octave, end_octave):
        """Scan for the ``(BaseNote, octave)`` closest to *frequency*; None if there are no notes."""
        nearest = None
        min_difference = float("inf")
        for base_note, note_frequency, octave in NoteUtilities.enumerate_circle_of_fifths(start_octave, end_octave):
            difference = abs(frequency - note_frequency)
            if difference < min_difference:
                min_difference = difference
                nearest = (base_note, octave)
            elif difference > min_difference:
                # Note frequencies ascend, so once the gap starts growing it never shrinks again.
                break
        return nearest

    @staticmethod
    def note_sets_to_sparse_velocity(note_sets: list[set[Note]]):
        """Encode note sets as a float array of shape ``(len(note_sets), vocab_size)``.

        Row i holds each note's velocity at column ``midi_pitch - MIDIPitch.C0.value``.
        All other entries are 0.
        """
        vocab_size = NoteUtilities.get_vocab_size()
        sparse_array_set = np.zeros((len(note_sets), vocab_size))
        # Iterating a 2-D array yields row views, so writes land in sparse_array_set.
        for row, note_set in zip(sparse_array_set, note_sets):
            for note in note_set:
                row[note.midi_pitch() - MIDIPitch.C0.value] = note.velocity
        return sparse_array_set

    @staticmethod
    def sparse_velocity_to_note_sets(sparse_array_set):
        """Decode sparse velocity rows back into note sets.

        This is the inverse of note_sets_to_sparse_velocity. Only strictly
        positive entries become notes, and they carry the entry as velocity.
        """
        note_sets: list[set[Note]] = []
        for sparse_array in sparse_array_set:
            note_set: set[Note] = set()
            for i, velocity in enumerate(sparse_array):
                if velocity > 0:
                    midi_pitch = MIDIPitch(int(i + MIDIPitch.C0.value))
                    note_set.add(Note.from_midi_pitch(midi_pitch, velocity))
            note_sets.append(note_set)
        return note_sets

    @staticmethod
    def get_note_frequency_window_set(
        data: np.ndarray,
        sample_duration: float,
        sampling_frequency: int,
        threshold_intensity: float = 0.1,
    ):
        """Split *data* into windows and return the prominent FFT bins of each.

        Each window spans ``sample_duration * sampling_frequency`` samples
        (at least 1). The final window may be shorter.

        Returns:
            A list with one ``(frequencies, magnitudes)`` tuple per window. It
            keeps only positive-frequency bins whose magnitude exceeds
            ``threshold_intensity`` times the window's peak magnitude.
        """
        windows = []
        samples_per_window = max(int(sample_duration * sampling_frequency), 1)

        for start in range(0, len(data), samples_per_window):
            window = data[start : start + samples_per_window]
            spectrum = fft(window)
            frequencies = np.fft.fftfreq(len(spectrum), 1 / sampling_frequency)
            magnitude = np.abs(spectrum)

            min_magnitude = threshold_intensity * magnitude.max()
            keep = (magnitude > min_magnitude) & (frequencies > 0)

            windows.append((frequencies[keep], magnitude[keep]))

        return windows

    @staticmethod
    def get_windowset_notes(noteset):
        """Map each window's ``(frequencies, magnitudes)`` to a set of Notes.

        The velocity of each Note is the bin magnitude. When several bins snap
        to the same pitch, the loudest bin wins.
        """
        note_sets: list[set[Note]] = []
        for frequency_list, magnitude_list in noteset:
            loudest: dict[Note, Note] = {}
            for frequency, magnitude in zip(frequency_list, magnitude_list):
                note = NoteUtilities.get_note_from_frequency(frequency)
                note.velocity = magnitude
                current = loudest.get(note)
                if current is None or magnitude > current.velocity:
                    loudest[note] = note
            note_sets.append(set(loudest.values()))
        return note_sets

    @staticmethod
    def get_note_roll(note_sets, window_frame_duration, max_notes: int = 3):
        """Turn per-window note sets into a time-ordered list of NoteRollNotes.

        Window i starts at ``i * window_frame_duration``, accumulated and rounded
        to ROUNDING_SPECIFICITY decimals. Each window keeps only its
        *max_notes* loudest notes. Each kept note lasts exactly one window.
        """
        note_roll: list[NoteRollNote] = []

        time = 0
        for note_set in note_sets:
            loudest_first = sorted(note_set, key=lambda note: note.velocity, reverse=True)
            for note in loudest_first[:max_notes]:
                note_roll.append(NoteRollNote(note, time, time + window_frame_duration))

            time = round(time + window_frame_duration, ROUNDING_SPECIFICITY)

        note_roll.sort(key=lambda roll_note: (roll_note.start_time, roll_note.end_time))
        return note_roll

    @staticmethod
    def cleanup_note_roll(note_roll, min_duration=0.01):
        """Drop notes whose duration is not strictly greater than *min_duration*."""
        return [note for note in note_roll if note.end_time - note.start_time > min_duration]

    @staticmethod
    def create_wave(note_roll, sampling_frequency):
        """Mix every note's wave into one signal, peak-normalized to [-1, 1].

        Notes are summed at their start sample. Samples past the end of the
        latest note are clipped. An empty roll gives an empty array, and a
        silent mix is returned unscaled rather than divided by zero.
        """
        if not note_roll:
            return np.zeros(0)

        length = int(max(note.end_time for note in note_roll) * sampling_frequency)
        output = np.zeros(length)

        for note in note_roll:
            wave = note.generate_wave(sampling_frequency)
            start_index = int(note.start_time * sampling_frequency)
            # Float rounding can make a wave one sample longer than the space left; clip it.
            end_index = min(start_index + len(wave), length)
            if end_index > start_index:
                output[start_index:end_index] += wave[: end_index - start_index]

        peak = np.max(np.abs(output)) if output.size else 0.0
        return output / peak if peak > 0 else output

    @staticmethod
    def create_midi_notes(note_roll):
        """Convert a note roll into an object array of ``pretty_midi.Note``.

        NOTE(review): velocity is passed through unchanged. MIDI velocities are
        integers 0-127, while velocities here may be floats (FFT magnitudes or
        model outputs); confirm the scaling before writing a MIDI file.
        """
        output = np.empty(len(note_roll), dtype=object)

        for i, note in enumerate(note_roll):
            output[i] = pretty_midi.Note(
                velocity=note.note.velocity,
                pitch=note.note.midi_pitch(),
                start=note.start_time,
                end=note.end_time,
            )

        return output

    @staticmethod
    def harmonic(t, f1=1, alist=1, philist=0):
        """Return ``sum_k a_k * cos(2*pi*k*f1*t + phi_k)`` for harmonics k = 1, 2, ...

        Args:
            t: sample times in seconds.
            f1: fundamental frequency in Hz.
            alist: amplitude, or one amplitude per harmonic.
            philist: phase in radians, either one value shared by all harmonics
                or one per harmonic. (It was previously passed to cosinewave as a
                time delay in seconds, which mixed units.)
        """
        if np.isscalar(alist):
            alist = [alist]
        if np.isscalar(philist):
            philist = [philist] * len(alist)

        t = np.asarray(t)
        components = [
            a * np.cos(2 * np.pi * k * f1 * t + phi)
            for k, (a, phi) in enumerate(zip(alist, philist), start=1)
        ]
        return np.sum(components, axis=0)

    @staticmethod
    def cosinewave(time, frequency=1.0, delay=0.0):
        """Return ``cos(2*pi*frequency*(time - delay))``; *delay* is a time shift in seconds."""
        return np.cos(2 * np.pi * frequency * (time - delay))



# @njit
def create_sequence_windows(note_sets: np.ndarray, sequence_length: int) -> np.ndarray:
    """Slice consecutive sparse velocity rows into overlapping windows.

    Despite the parameter name, *note_sets* holds sparse velocity rows (as
    produced by NoteUtilities.note_sets_to_sparse_velocity), not Note sets.

    Each window holds ``sequence_length + 1`` consecutive rows: the first
    ``sequence_length`` are model input and the last is the target.

    Returns:
        An array of shape ``(len(note_sets) - sequence_length, sequence_length + 1, vocab_size)``,
        or an empty 1-D array when there are not enough rows for one window.
    """
    window_count = len(note_sets) - sequence_length
    span = sequence_length + 1

    # Convert row-by-row so object arrays of 1-D arrays are handled as well as 2-D arrays.
    windows = [
        np.array([np.array(list(row)) for row in note_sets[start : start + span]])
        for start in range(window_count)
    ]
    return np.array(windows)


# @njit
def create_training_sets(sequence_windows: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Split windows into inputs and targets and normalize both by the global peak.

    Each window's last row is its target ``y``; the preceding rows are its
    input ``X``. Both are divided by the largest value across all windows.
    The division is skipped when that peak is not positive (all-zero data).

    Returns:
        ``(X, y)`` as NumPy arrays.
    """
    inputs = [np.array(window[:-1]) for window in sequence_windows]
    targets = [np.array(window[-1]) for window in sequence_windows]

    # Start from 0 so the peak is never negative.
    peak = max([0, *(np.max(window) for window in sequence_windows)])

    X = np.array(inputs)
    y = np.array(targets)
    if peak > 0:
        X = X / peak
        y = y / peak

    return X, y



def create_model(lstm_neurons: int = 256, sequence_length: int = 100):
    """Build the (uncompiled) velocity-prediction model.

    Input shape:  ``(sequence_length, vocab_size)``, a sequence of sparse velocity rows.
    Output shape: ``(vocab_size,)``, one non-negative predicted velocity per pitch.

    NOTE(review): layer creation order determines Keras's auto-generated layer
    names, which saved checkpoints may depend on. Keep the order stable when
    restructuring, or existing weights may stop loading.
    """
    vocab_size = NoteUtilities.get_vocab_size()
    input_shape = (sequence_length, vocab_size)

    input_layer = keras.layers.Input(input_shape)
    # First LSTM returns only its final state: one lstm_neurons-wide vector per sample.
    lstm_layer = keras.layers.LSTM(lstm_neurons, activation="relu", )(input_layer)
    # That vector is reshaped into lstm_neurons steps of 1 feature, so the second LSTM
    # scans across the first LSTM's hidden units rather than across time.
    lstm_layer = keras.layers.Reshape([lstm_neurons, 1])(lstm_layer)
    lstm_layer = keras.layers.LSTM(lstm_neurons)(lstm_layer)
    
    
    #  ReLU suits the output layer because we want the vast majority of notes to have a velocity of 0 and to be inactive in the output.
    output = keras.layers.Dense(vocab_size, activation="relu")(lstm_layer)

    model = keras.Model(input_layer, output)

    return model



def mse_with_positive_pressure(y_true: tf.Tensor, y_pred: tf.Tensor):
    """Mean squared error plus a penalty that pushes predictions to be non-negative.

    Per element the loss is ``(y_true - y_pred) ** 2 + 10 * max(-y_pred, 0)``.
    The result is the mean over all elements.
    """
    negative_penalty = tf.maximum(-y_pred, 0.0) * 10
    squared_error = tf.square(y_true - y_pred)
    return tf.reduce_mean(squared_error + negative_penalty)



def parse_song(song_file: str):
    """Turn one WAV file into ``(X, y)`` training arrays.

    The pipeline is: mono mixdown -> FFT windows -> note sets -> sparse velocity
    rows -> overlapping sequences -> normalized training pairs. The window
    length, intensity threshold and sequence length come from the
    module-level ``training_parameters`` dict.
    """
    rate, song_data = wavfile.read(song_file)

    # wavfile.read returns 1-D data for mono files and (samples, channels) otherwise.
    # Average across every channel; indexing shape[1] on mono data would raise IndexError.
    mono_data = np.average(song_data, axis=1) if song_data.ndim > 1 else song_data

    data_note_sets = NoteUtilities.get_note_frequency_window_set(
        mono_data,
        training_parameters["window_duration"],
        rate,
        training_parameters["intensity_threshold"],
    )
    note_sets = NoteUtilities.get_windowset_notes(data_note_sets)
    sparse_velo = NoteUtilities.note_sets_to_sparse_velocity(note_sets)
    sequence_data = create_sequence_windows(sparse_velo, training_parameters["sequence_length"])
    return create_training_sets(sequence_data)


def parse_song_directory(directory: str):
    """Parse every ``.wav`` file under *directory* (recursively) and stack the results.

    Files are visited in sorted order for reproducibility. Songs too short to
    yield a single training sequence are skipped.

    Returns:
        ``(X, y)`` concatenated along the sample axis.

    Raises:
        ValueError: if no training sequences could be built (no WAV files,
            a missing directory, or only too-short songs).
    """
    Xs = []
    ys = []

    for root, dirs, files in os.walk(directory):
        dirs.sort()  # make os.walk's traversal order deterministic
        for file in sorted(files):
            if not file.lower().endswith(".wav"):
                continue
            full_name = os.path.abspath(os.path.join(root, file))
            X, y = parse_song(full_name)
            if len(X) == 0:
                continue
            Xs.append(X)
            ys.append(y)

    if not Xs:
        raise ValueError(f"No training sequences could be built from .wav files under {directory!r}")

    return np.concatenate(Xs, axis=0), np.concatenate(ys, axis=0)



# Hyperparameters shared by data preparation, model construction and training.
training_parameters = {
    "learning_rate": 0.0001,
    "batch_size": 100,
    # Seconds of audio per FFT window (one row of the sparse velocity matrix).
    "window_duration": 0.1,
    # Fraction of a window's peak FFT magnitude a bin must exceed to count as a note.
    "intensity_threshold": 0.2,
    "lstm_neurons": 256,
    # Number of consecutive windows fed to the model to predict the next one.
    "sequence_length": 100,
    "epochs": 1000
}

checkpoint_path = "checkpoints/FFT_Waveforms_with_LSTM/checkpoint_{epoch}"
directory = "./data/Wavs"



X, y = parse_song_directory(directory)

X = tf.constant(X)
y = tf.constant(y)

# Shuffling may seem odd because we're specifically trying to learn the sequential relationship, but we also want to be able to control this with the sequence length parameter.
dataset = tf.data.Dataset.from_tensor_slices((X, y))
shuffled_dataset = dataset.shuffle(buffer_size=len(X), reshuffle_each_iteration=True)
shuffled_X, shuffled_y = next(iter(shuffled_dataset.batch(len(X))))

# Use numpy arrays for training - this enables the use of a validation split.
shuffled_X = shuffled_X.numpy()
shuffled_y = shuffled_y.numpy()


# NOTE(review): ReduceLROnPlateau watches val_loss while EarlyStopping watches training loss; confirm the mismatch is intended.
callbacks = [
    keras.callbacks.ReduceLROnPlateau(monitor="val_loss"),
    keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True),
    keras.callbacks.EarlyStopping(monitor="loss", patience=5, verbose=1, restore_best_weights=True)
]


# Pass sequence_length explicitly so the model's input shape always matches the windows built from training_parameters.
model = create_model(training_parameters["lstm_neurons"], training_parameters["sequence_length"])

optimizer = keras.optimizers.Adam(training_parameters["learning_rate"])
loss = mse_with_positive_pressure
# NOTE(review): "accuracy" checks for exact matches and is likely meaningless for continuous velocity targets; consider "mae".
model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])
model.fit(shuffled_X, shuffled_y, epochs=training_parameters["epochs"], batch_size=training_parameters["batch_size"], callbacks=callbacks, validation_split=0.2)


In [None]:

import os
import numpy as np
from scipy.io import wavfile

# Directory that holds the training checkpoints (checkpoint_path is set in the training cell).
checkpoint_dir = os.path.dirname(checkpoint_path)
def get_last_checkpoint(directory=None):
    """Return the highest epoch number among ``checkpoint_<epoch>.*`` files.

    Args:
        directory: folder to scan. Defaults to the module-level ``checkpoint_dir``.

    Raises:
        ValueError: if no epoch-numbered checkpoint file is present. The bare
            ``checkpoint`` index file alone does not count.
    """
    if directory is None:
        directory = checkpoint_dir
    pattern = re.compile(r'checkpoint_(\d+)\.')
    epochs = [int(match.group(1)) for match in map(pattern.match, os.listdir(directory)) if match]
    if not epochs:
        raise ValueError(f"No epoch-numbered checkpoints found in {directory!r}")
    return max(epochs)

checkpoint_path = f"{checkpoint_dir}/checkpoint_{get_last_checkpoint()}"

# Rebuild with the same hyperparameters used for training so the weights' shapes match.
model = create_model(training_parameters["lstm_neurons"], training_parameters["sequence_length"])
model.load_weights(checkpoint_path)


def generate_song(song_duration: float, seed_note_sets: list[set[Note]], frame_duration=None, predictor=None):
    """Autoregressively predict note sets for *song_duration* seconds.

    The seed note sets are encoded as sparse velocity rows and form the initial
    context window. Each prediction is appended to the window and the oldest
    row is dropped.

    Args:
        song_duration: seconds of music to generate.
        seed_note_sets: initial context, one note set per window.
        frame_duration: seconds per window. Defaults to the module-level ``window_duration``.
        predictor: object with a Keras-style ``predict``. Defaults to the module-level ``model``.

    Returns:
        One predicted note set per generated window (empty list if none).
    """
    if frame_duration is None:
        frame_duration = window_duration
    if predictor is None:
        predictor = model

    # Small epsilon keeps float error (e.g. 0.3 / 0.1 == 2.9999999999999996) from dropping a window.
    windows_to_generate = int(song_duration / frame_duration + 1e-9)

    input_sequence = np.array([np.vstack(NoteUtilities.note_sets_to_sparse_velocity(seed_note_sets))])

    generated = []
    for _ in range(windows_to_generate):
        # verbose=0: predict() otherwise prints a progress bar on every call.
        predicted_note_set = predictor.predict(input_sequence, verbose=0)
        generated.append(predicted_note_set)
        # Slide the context window: append the prediction, drop the oldest row.
        input_sequence = np.concatenate((input_sequence, [predicted_note_set]), axis=1)[:, 1:]

    if not generated:
        return []

    return NoteUtilities.sparse_velocity_to_note_sets(np.vstack(generated))

def generate_base_seed_set(sequence_length: int):
    """Return *sequence_length* seed windows, each a new set holding only C4 at the default velocity."""
    return [{Note(BaseNote.C, 4)} for _ in range(sequence_length)]


def moving_average(signal, window_size):
    """Return the simple moving average of *signal* over *window_size* samples.

    Uses a running sum, so each output value costs O(1). The result has
    ``len(signal) - window_size + 1`` samples (no padding).

    Raises:
        ValueError: if window_size is less than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    running_total = np.cumsum(np.insert(signal, 0, 0))
    return (running_total[window_size:] - running_total[:-window_size]) / window_size



def save_wav(audio, rate, filename):
    """Write *audio* to *filename* as 16-bit signed PCM, peak-normalized to full scale.

    The signal is divided by its peak absolute value so it spans [-1, 1], then
    mapped symmetrically onto [-32767, 32767]. Silence is written as zeros.
    The parent directory is created if it does not exist.
    """
    audio = np.asarray(audio, dtype=np.float64)
    peak = np.max(np.abs(audio)) if audio.size else 0.0
    if peak > 0:
        audio = audio / peak

    # Signed 16-bit PCM is centred on 0. Round, then clip so no sample can wrap around on the int16 cast.
    pcm = np.clip(np.round(audio * 32767), -32768, 32767).astype(np.int16)

    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)
    wavfile.write(filename, rate, pcm)

    
seconds_of_audio = 120
# Generation must use the same frame length and context size the model was trained with.
window_duration = training_parameters["window_duration"]
max_notes = 10
rate = 44_100
seed_note_sets = generate_base_seed_set(training_parameters["sequence_length"])

song = generate_song(seconds_of_audio, seed_note_sets)
note_roll = NoteUtilities.get_note_roll(song, window_duration, max_notes)
clean_note_roll = NoteUtilities.cleanup_note_roll(note_roll, min_duration=0)
wave = NoteUtilities.create_wave(clean_note_roll, rate)


ma_window_size = 20
# NOTE(review): the smoothed signal is computed but the unsmoothed `wave` is what gets saved; confirm which is intended.
ma_smoothed_signal = moving_average(wave, ma_window_size)
save_wav(wave, rate, "./predictions/FFT_Waveforms_with_LSTM/predicted_song.wav")
