In [None]:
import sys

# This is for managing relative imports from the notebook
if '..' not in sys.path: sys.path.append('..')
    
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from scipy.io import wavfile

import warnings
warnings.filterwarnings("ignore") # To supress WavFileWarning: Chunk (non-data) not understood, skipping it.

In [None]:
# Tempo (beats per minute) used when converting MIDI ticks to seconds
DEFAULT_BPM = 120
# Length in seconds of each audio / label sample window
TIME_WINDOW_SEC = 2.0

In [None]:
def generate_dataset(wav_filename, time_window=TIME_WINDOW_SEC):
#{
    """
    Reads a WAV file, cuts it into time_window second samples
    and returns the spectrogram of every sample as network input

    wav_filename = path of the WAV file to read
    time_window = length of each sample in seconds

    Returns X of shape (m, Tx, n_freq):
      m = number of time_window length samples
      Tx = number of spectrogram time steps per sample
      n_freq = number of frequencies per time step

    Only the first audio channel is used. The spectrogram of the
    middle sample is drawn on figure 1 as a sanity check.
    """

    rate, data = wavfile.read(wav_filename)

    # Mono files are read as 1-D arrays: give them a channel axis so the
    # channel count and the channel indexing below work for them too
    if data.ndim == 1: data = data[:, np.newaxis]

    n_samples, n_channels = data.shape
    wav_step = int(time_window * rate)    # Audio samples per window
    if wav_step < 1:
        raise ValueError(f"time_window {time_window} sec is shorter than one audio sample at {rate} Hz")

    print(f"Contents of audio file: {wav_filename}:")
    print("  Length of data:", n_samples / rate, "seconds")
    print("  Sample rate:", rate, "Hz")
    print("  Number of channels:", n_channels)
    print("  Samples per channel:", n_samples)

    # Zero-pad end of data to assure all samples are length time_window.
    # The pad is counted in whole audio samples (not float seconds) so the
    # padded length is always an exact multiple of wav_step, and a file that
    # already fills its last window does not gain an extra all-silent window.
    # An empty file still yields one (silent) window.
    p = (-n_samples) % wav_step if n_samples else wav_step
    data = np.concatenate((data, np.zeros((p, n_channels))))

    audio_sec = data.shape[0] / rate

    print(f"\nContents after padding for {time_window} sec window discritization:")
    print("  Length of data:", audio_sec, "seconds")
    print("  Samples per channel:", data.shape[0])

    # Note: wav data may be in stereo, but for now, just using the first channel
    # NOTE(review): fs is fixed at 1000 instead of the file's sample rate, so the
    # f/t axes returned by spectrogram are not in real units. Left unchanged
    # because changing it would alter the values fed to the network -- confirm.
    spectral = [signal.spectrogram(data[i:i+wav_step, 0], fs=1000, noverlap=128, nfft=256)
                for i in range(0, data.shape[0], wav_step)]

    # Plot spectrogram of the middle sample for good measure
    mid_sample = len(spectral) // 2
    mid_pt = mid_sample * wav_step

    plt.figure(1)
    plt.title(f"Spectogram of sample {mid_sample} length {time_window} sec ")
    plt.specgram(data[mid_pt:mid_pt+wav_step, 0], NFFT=256, Fs=1000, noverlap=128)

    # Write into network input form: each (n_freq, Tx) spectrogram is
    # transposed so that time is the leading axis within a sample
    X = np.stack([np.transpose(pxx) for _, _, pxx in spectral])
    m, Tx, n_freq = X.shape

    print("\nDerived spectral dataset:")
    print("  X dataset shape:", X.shape)
    print("  Number of frequencies n_freq:", n_freq)
    print("  Number of time steps per window Tx:", Tx)
    print("  Number of windowed samples m:", m, "\n")

    return X
#}

In [None]:
import mido
from mido import MidiFile

In [None]:
def tick2bin(tick, ppqn, Ty, time_window, tempo=500000):
#{
    """
    Converts an absolute MIDI tick count into a label time-bin index

    tick = tick count to convert
    ppqn = ticks per beat (quarter note) of the MIDI file
    Ty = number of time bins per time_window
    time_window = length of one sample window in seconds
    tempo = microseconds per beat; the default 500000 is 120 BPM

    Returns the bin index, truncated to an int
    """
    bins_per_second = Ty / time_window
    return int(bins_per_second * mido.tick2second(tick, ppqn, tempo=tempo))
#}

def bin2tick(time_bin, ppqn, Ty, time_window, tempo=500000):
#{
    """
    Converts a label time-bin index into an absolute MIDI tick count

    time_bin = bin index to convert
    ppqn = ticks per beat (quarter note) of the MIDI file
    Ty = number of time bins per time_window
    time_window = length of one sample window in seconds
    tempo = microseconds per beat; the default 500000 is 120 BPM

    Returns the tick count, truncated to an int
    """
    seconds_per_bin = time_window / Ty
    return int(mido.second2tick(time_bin * seconds_per_bin, ppqn, tempo=tempo))
#}

def note2index(note):
#{
    """Maps a MIDI note code onto an output node index by subtracting 21"""
    note_offset = 21
    return note - note_offset
#}
def index2note(index):
#{
    """Maps an output node index back onto a MIDI note code by adding 21"""
    note_offset = 21
    return index + note_offset
#}

In [None]:
def reshape2samples(contiguous_data, time_step, dim_tone):
#{
    """
    Splits one contiguous (dim_tone, total_steps) block, e.g. a
    "full song", into consecutive samples of time_step steps each

    contiguous_data = 2-D block with time running along the last axis
    time_step = number of time steps per sample (e.g. Tx or Ty)
    dim_tone = number of values per time step (e.g. n_freq or n_tones)

    Returns an array of shape (-1, time_step, dim_tone)
    """

    # Put time on the leading axis first, so that the reshape
    # below slices the song into runs of consecutive time steps
    time_major = np.asanyarray(contiguous_data).transpose()
    sample_shape = (-1, time_step, dim_tone)
    return time_major.reshape(sample_shape)
#}

# Just a quick demonstration of reshape2samples()
# A = np.array([[1,2,3,4,5,6,-6, -5,-4,-3,-2,-1],
#               [7,8,9,10,11,12,-12,-11,-10,-9,-8,-7], 
#               [13,14,15,16,17,18,-18,-17,-16,-15,-14,-13]])

# print("A_orig =\n", A, '\n\nA_reshaped =\n', reshape2samples(A, 3, 3))

In [None]:
def format_label_data(midi_file, Ty, time_window, Y_full):
#{
    """
    Parses midi_file data and writes ground-truth
    note labels into a time scaled data-block Y

    midi_file = the single-song MIDI to parse (only tracks[0] is read)
    Ty = number of label time bins per time_window
    time_window = length of one sample window in seconds
    Y_full = (n_tones, m * Ty) block of labels, written IN PLACE:
             set to 1 for every bin during which a note is held

    Returns Y_full cut into samples of shape (-1, Ty, n_tones)
    """

    on_notes = {}                         # note code --> tick at which it was struck
    accrued_ticks = 0
    ppqn = midi_file.ticks_per_beat

    # Take the number of output rows from the block itself
    # rather than relying on a notebook-level global
    n_tones = Y_full.shape[0]

    for message in midi_file.tracks[0]:
    #{
        # Every message (notes or not) advances the clock
        accrued_ticks += message.time
        if message.type not in ('note_on', 'note_off'): continue

        # note_on with non-zero velocity --> note has been struck
        # note_on with zero velocity, or note_off --> note has been released
        if message.type == 'note_on' and message.velocity > 0:
            on_notes[message.note] = accrued_ticks
            continue

        # A release with no matching strike carries no label information
        struck_at = on_notes.pop(message.note, None)
        if struck_at is None: continue

        # MIDI-note codes 21-108 map to output rows 0-87; skip notes that fall
        # outside the block instead of wrapping around / overflowing the index
        row = note2index(message.note)
        if not 0 <= row < n_tones: continue

        # Scale tick-window to spectral-bin-window
        start = tick2bin(struck_at, ppqn, Ty, time_window)
        end   = tick2bin(accrued_ticks, ppqn, Ty, time_window)
        Y_full[row, start:end] = 1
    #}
    return reshape2samples(Y_full, Ty, n_tones)
#}

In [None]:
def generate_label_data(midi_filename, m, n_tones, Ty, time_window=TIME_WINDOW_SEC):
#{
    """
    Loads a single-track MIDI file and builds the label
    dataset that matches m spectrogram samples

    midi_filename = path of the MIDI file to read
    m = number of time_window length samples to produce
    n_tones = number of output nodes per time step
    Ty = number of label time steps per sample
    time_window = length of each sample in seconds

    Returns the label block produced by format_label_data.
    The middle sample is drawn on figure 2.
    """
    midi_song = MidiFile(midi_filename)
    micros_per_beat = mido.bpm2tempo(DEFAULT_BPM)

    n_tracks = len(midi_song.tracks)
    assert n_tracks == 1, f"MIDI file contains invalid number of tracks: {n_tracks}"

    # Message times are relative, so their sum is the length of the track in ticks
    track = midi_song.tracks[0]
    total_ticks = sum(message.time for message in track)
    song_seconds = mido.tick2second(total_ticks, midi_song.ticks_per_beat, micros_per_beat)

    print("Output paramaters defined as:")
    print("  Number of output nodes n_tones:", n_tones)
    print("  Number of time steps per output window Ty:", Ty, "\n")

    print("Contents of MIDI file:", midi_filename)
    print("  Tempo:", DEFAULT_BPM, "BPM or", micros_per_beat, "micros/beat")
    print("  Number of ticks per beat (PPQN):", midi_song.ticks_per_beat)
    print("  Number of messages:", len(track))
    print("  Number of seconds of messages:", song_seconds)

    # One row per tone, one column per label time step of the whole song
    label_block = np.zeros((n_tones, m * Ty))
    Y = format_label_data(midi_song, Ty, time_window, label_block)

    # Plot the labels of the middle sample
    middle = int(m / 2)
    plt.figure(2)
    plt.title(f"Label data sample {middle} length {time_window} sec ")
    plt.pcolormesh(np.transpose(Y[middle, :, :]))

    print(f"\nDiscretized MIDI lables into {time_window} sec window samples:")
    print("  Y label dataset shape:", Y.shape, "\n")

    return Y
#}

In [None]:
# Generate the TRAINING dataset: one spectrogram per TIME_WINDOW_SEC window of audio
X_train = generate_dataset('../data/audio/88_Key_Ascending_Chromatic_Scale.wav', TIME_WINDOW_SEC)

# Network inputs (read off the shape of the spectrogram block)
m = X_train.shape[0]                  # Number of TIME_WINDOW_SEC length samples
Tx = X_train.shape[1]                 # Number of time steps per sample
n_freq = X_train.shape[2]             # Number of frequencies per sample

# Network output (defined by choice)
# NOTE(review): generate_model has no layer that shortens the time axis, so the
# network appears to emit Tx steps per sample; Ty only equals Tx while Tx <= 1375.
# Confirm the 1375 cap is still wanted.
n_tones = 88                          # Number of output nodes == 88 on/off piano keys
Ty = min(Tx, 1375)                    # Number of time steps per TIME_WINDOW_SEC length output sample

# Generate the matching label dataset from the MIDI transcription of the same piece
Y_train = generate_label_data('../data/midi_cleaned/88_Key_Ascending_Chromatic_Scale.mid', m, n_tones, Ty)

In [None]:
# Generate VALIDATION datasets
X_val = generate_dataset('../data/audio/88_Key_Descending_Chromatic_Scale.wav', TIME_WINDOW_SEC)

# The label block must hold one sample per VALIDATION window. Using the training
# count `m` here would mis-size Y_val whenever the two recordings differ in length.
m_val = X_val.shape[0]
Y_val = generate_label_data('../data/midi_cleaned/88_Key_Descending_Chromatic_Scale.mid', m_val, n_tones, Ty)

print("Validating on:", m_val, "samples")

# Per-sample shape must match what the network is built for
assert (Tx == X_val.shape[1]) and (n_freq == X_val.shape[2]), \
    f"Validation data shape {X_val.shape} does not match training {X_train.shape}"

In [None]:
import keras
from keras.layers import Input, Conv1D, BatchNormalization, Activation, Dropout, GRU, TimeDistributed, Dense

In [None]:
# Create a Keras model

def generate_model(Tx, n_freq, Ty, n_tones):
#{
    """
    Builds the network: two stacked GRU layers followed
    by a sigmoid layer applied at every time step

    Tx = number of time steps per input sample
    n_freq = number of frequencies per time step
    Ty = not used by this function
    n_tones = number of output nodes per time step

    Prints the model summary and returns the uncompiled model
    """
    spectrogram_in = Input(shape=(Tx, n_freq))

    # Layers are listed in the order they are applied
    layer_stack = [
        # First recurrent block: GRU emitting a value at every time step
        GRU(units=128, return_sequences=True),
        Dropout(0.2),
        BatchNormalization(),

        # Second recurrent block, with an extra dropout after normalization
        GRU(units=128, return_sequences=True),
        Dropout(0.2),
        BatchNormalization(),
        Dropout(0.2),

        # One sigmoid output per tone, applied independently at each time step
        TimeDistributed(Dense(n_tones, activation="sigmoid")),
    ]

    hidden = spectrogram_in
    for layer in layer_stack:
        hidden = layer(hidden)

    model = keras.models.Model(inputs=spectrogram_in, outputs=hidden)
    model.summary()

    return model
#}

In [None]:
# Generate model and configure for training
model = generate_model(Tx, n_freq, Ty, n_tones)
# Adam optimizer with explicit hyper-parameters
opt = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, decay=0.01)
# Early stopping watches the validation loss with a patience of 5 epochs
early_stop = [keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)]
# NOTE(review): the outputs are per-tone sigmoids trained with binary cross-entropy;
# 'categorical_accuracy' may not be the intended metric for such labels -- confirm.
# The plotting cell reads the 'categorical_accuracy' history keys, so any change
# of metric must be mirrored there.
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['categorical_accuracy'])

In [None]:
# Train model
# batch_size = m (the number of training samples), so each epoch is a single batch.
# Training runs for at most 1000 epochs, subject to the early_stop callback.
history = model.fit(X_train, Y_train, batch_size = m, validation_data=(X_val, Y_val), callbacks=early_stop, epochs=1000)

In [None]:
# Persist the trained model to disk
model.save('../data/h5/music_model_19_02_08.h5') 

# List all data in history
print(history.history.keys())

# Plot training vs. validation accuracy per epoch
# (keys match the metric name passed to model.compile)
plt.figure(1)
plt.plot(history.history['categorical_accuracy'])
plt.plot(history.history['val_categorical_accuracy'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['train', 'validate'], loc='upper left')

# Plot training vs. validation loss per epoch
plt.figure(2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'validate'], loc='upper left')

In [None]:
import mido

# Minimum step-to-step change of a note's output that counts as an on/off flip
FLIP_TOLERANCE = 0.5

def output_midi_file(nn_output, filename, ppqn, Ty):
#{
    """
    Converts network output into note on/off events
    and saves them as a single-track MIDI file

    nn_output = (samples, time steps, notes) block of per-note outputs
    filename = path of the MIDI file to write
    ppqn = ticks per beat (quarter note) of the written file
    Ty = number of time steps per TIME_WINDOW_SEC length sample

    A note is switched on (velocity 64) when its output rises by more than
    FLIP_TOLERANCE between two consecutive time steps, and switched off
    (note_on with velocity 0) when it falls by more than FLIP_TOLERANCE.

    Raises ValueError if nn_output is not 3-D or has fewer than Ty steps per sample
    """
    nn_output = np.asarray(nn_output)
    if nn_output.ndim != 3 or nn_output.shape[1] < Ty:
        raise ValueError(f"nn_output shape {nn_output.shape} does not provide {Ty} time steps per sample")

    midi_file_out = mido.MidiFile(ticks_per_beat=ppqn)
    track_out = mido.MidiTrack()

    # Lay the samples end to end on one contiguous time axis. Flips that fall on
    # the boundary between two samples are then seen like any other, and the time
    # of every step is its true global position (Ty steps per sample).
    frames = nn_output[:, :Ty, :].reshape(-1, nn_output.shape[2])

    # deltas[k] is the change between global step k and global step k + 1
    deltas = np.diff(frames, axis=0)
    flipped_on = deltas > FLIP_TOLERANCE
    flipped_off = deltas < -FLIP_TOLERANCE

    last_ticks = 0
    for k in np.flatnonzero((flipped_on | flipped_off).any(axis=1)):
    #{
        total_ticks = bin2tick(int(k) + 1, ppqn, Ty, TIME_WINDOW_SEC)

        # Messages of this step only: a fresh list per step, so earlier
        # messages are never appended (or re-timed) a second time
        step_flips = [mido.Message('note_on', note=index2note(int(on)), velocity=64)
                      for on in np.flatnonzero(flipped_on[k])]
        step_flips += [mido.Message('note_on', note=index2note(int(off)), velocity=0)
                       for off in np.flatnonzero(flipped_off[k])]

        # Only first message/step can have non-zero relative time
        step_flips[0].time = total_ticks - last_ticks
        last_ticks = total_ticks
        track_out.extend(step_flips)
    #}

    track_out.append(mido.MetaMessage('end_of_track'))
    midi_file_out.tracks.append(track_out)
    midi_file_out.save(filename)

    # Summarize what was written
    tempo = mido.bpm2tempo(DEFAULT_BPM)
    accrued_ticks = sum(message.time for message in midi_file_out.tracks[0])
    seconds = mido.tick2second(accrued_ticks, midi_file_out.ticks_per_beat, tempo)

    print("Saved to MIDI file:", filename)
    print("  Number of ticks per beat (PPQN):", ppqn)
    print("  Number of messages:", len(midi_file_out.tracks[0]))
    print("  Number of seconds of messages:", seconds)
#}

In [None]:
# Run the trained model over the training inputs
output = model.predict(X_train)
# Reuse the tick resolution (PPQN) of the reference MIDI file for the generated one
ppqn = MidiFile('../data/midi_cleaned/88_Key_Ascending_Chromatic_Scale.mid').ticks_per_beat
# Convert the network output to note events and write them out as a MIDI file
output_midi_file(output, '../data/midi_output/88_Key_Ascending_Chromatic_Scale.mid', ppqn, Ty)

# output1 = model.predict(X_train[0:25,:,:])
# output2 = model.predict(X_train[25:,:,:])

# ppqn1 = MidiFile('data/midi_cleaned/Cry_me_a_river_simple.mid').ticks_per_beat
# ppqn2 = MidiFile('data/midi_cleaned/A_fine_romance_simple.mid').ticks_per_beat

# output_midi_file(output1, 'data/midi_output/Cry_me_a_river_output.mid', ppqn1)
# output_midi_file(output1, 'data/midi_output/A_fine_romance_output.mid', ppqn2)