## Lecture 3 Onset Detection - Tutorial Solution

In [None]:
import os
import numpy as np
from matplotlib import pyplot as plt
import librosa
from scipy.signal import medfilt as medfilt
from scipy.ndimage import maximum_filter1d as maxfilt
import sys
sys.path.append('../labs/lab2/')  # location of evaluate.py (imported below)
import evaluate
import IPython.display as ipd


def rms_odf(filename, windowTime, hopTime):
    """Builds a simple onset detection function from per-frame RMS magnitude.

    The audio is read at its native sample rate and zero-padded by half a
    window on each side so that the first frame is centred on t=0. Each
    Hamming-windowed frame is transformed with an FFT and the root mean square
    of the bin magnitudes is taken. The resulting curve is divided by its
    largest value, unless that value is not positive.

    Parameters:
        filename: path of the input audio file
        windowTime: requested analysis window length (in seconds); the window
            actually used is the next power of two in samples
        hopTime: spacing between consecutive frames (in seconds)
    Returns:
        A vector with one normalised RMS magnitude value per frame of audio
    """
    samples, sampleRate = librosa.load(filename, sr=None)
    hopLen = round(sampleRate * hopTime)
    # frame length: smallest power of two that covers the requested duration
    frameLen = int(2 ** np.ceil(np.log2(sampleRate * windowTime)))

    # half a window of silence at both ends centres frame 0 on the first sample
    halfPad = np.zeros(frameLen // 2)
    padded = np.concatenate([halfPad, samples, halfPad])
    numFrames = int(np.floor((len(padded) - frameLen) / hopLen + 1))
    taper = np.hamming(frameLen)

    def frameRms(offset):
        """RMS of the FFT magnitudes of the windowed frame starting at offset."""
        spectrum = np.fft.fft(padded[offset: offset + frameLen] * taper)
        return np.sqrt(np.mean(np.power(np.abs(spectrum), 2)))

    odf = np.array([frameRms(k * hopLen) for k in range(numFrames)],
                   dtype=float)

    # scale to a peak of 1; leave the curve untouched if it has no positive peak
    peak = max(odf)
    if not peak > 0:
        return odf
    return odf / peak


def getOnsets(odf, hopTime, wd=3, thr=0):
    """Performs peak-picking, filtering and thresholding on an onset detection
    function to compute the times of each onset.

    The ODF is detrended by subtracting its running median. A frame is then
    reported as an onset when its detrended value equals the maximum of its
    wd-frame neighbourhood and is at least thr. The comparison is non-strict,
    so every frame of a flat stretch that satisfies both conditions is
    reported.

    Side effect: opens a new matplotlib figure showing the detrended ODF
    (cyan) and the acceptance threshold (yellow). The figure is left current,
    so the caller can add a title or further markers to it.

    Parameters:
        odf: a vector of onset detection function values, one for each frame
        hopTime: time in seconds between frames of the ODF
        wd: width (in frames) of filter for median and maximum filtering;
            scipy's medfilt expects an odd size
        thr: threshold (minimum difference between a peak value and the median
            of the ODF for a peak to be accepted as an onset)
    Returns:
        A vector of onset times in seconds
    """
    t = np.arange(len(odf)) * hopTime
    # remove the slowly varying level so that peaks are measured against it
    medFiltODF = odf - medfilt(odf, wd)
    # running maximum: a frame is a local peak iff it equals this value
    maxFiltODF = maxfilt(medFiltODF, wd, mode='nearest', axis=0)
    # element-wise floor at thr (vectorised instead of a Python-level loop)
    threshold = np.maximum(maxFiltODF, thr)
    peakIndices = np.nonzero(medFiltODF >= threshold)[0]
    peakTimes = peakIndices * hopTime
    plt.figure(figsize=(14, 3))
    plt.xlabel('Time (s)')
    plt.ylabel('RMS ODF')
    plt.plot(t, medFiltODF, 'c')
    plt.plot(t, threshold, 'y')
    return peakTimes


# Evaluation driver: for every .wav file in dataDir, compute the RMS ODF,
# pick onsets, score them against the ground truth stored in the .csv file
# of the same base name, then plot both sets of onsets and embed the audio.
dataDir = '../labs/lab2/data'   # location of data files
windowSize = 0.04   # analysis window length in seconds (rms_odf windowTime)
hopSize = 0.01      # hop between frames in seconds
filterWidth = 5     # median/maximum filter width in frames (getOnsets wd)
threshold = 0.01    # minimum detrended ODF value for an onset (getOnsets thr)
# os.scandir yields entries in arbitrary order; sorting by name makes the
# printed results and figures come out in a reproducible order. The context
# manager closes the directory iterator as soon as the list is built.
with os.scandir(dataDir) as dirEntries:
    wavEntries = sorted((e for e in dirEntries if e.name.endswith('.wav')),
                        key=lambda e: e.name)
for entry in wavEntries:
    # ground-truth onset times live next to the audio: <name>.wav -> <name>.csv
    trueOnsets = np.fromfile(entry.path[:-4] + '.csv', sep = ', ')
    odf = rms_odf(entry.path, windowSize, hopSize)
    # getOnsets also opens the figure that the plt calls below draw on
    estimatedOnsets = getOnsets(odf, hopSize, filterWidth, threshold)
    result = evaluate.evaluate(estimatedOnsets, trueOnsets)
    print('Input file: ' + entry.name)
    print('ODF    C  FP  FN   F    Err')
    # NOTE(review): assumes result[0] holds the five values named in the
    # header line above, in that order -- confirm against evaluate.py
    print('RMS: {:3.0f} {:3.0f} {:3.0f} {:5.3f} {:5.3f}'.format(*result[0]))
    plt.title('Onsets found (red) vs ground truth (green) for file: ' + entry.name)
    # ground truth occupies the upper half of the axes, estimates the lower half
    for j in trueOnsets:
        plt.axvline(j, ymin=0.5, color='g')
    for j in estimatedOnsets:
        # drawn in red so that the markers match the colour named in the title
        plt.axvline(j, ymax=0.5, color='r')
    ipd.display(ipd.Audio(entry.path))