In [359]:
import numpy as np      
import matplotlib.pyplot as plt 
import scipy.io.wavfile 
import subprocess
import librosa
import librosa.display
import IPython.display as ipd
import os
import re
import collections

from pathlib import Path, PurePath   
from tqdm.notebook import tqdm

### Utility functions

In [360]:
def convert_mp3_to_wav(audio:str) -> str:
    """Convert an input MP3 audio track into a WAV file placed next to it.

    The conversion is skipped when the target WAV already exists, so the
    function is cheap to call again (e.g. when the notebook is re-run).

    Args:
        audio (str): Path of an input audio track.

    Returns:
        str: Path of the WAV file for an ``.mp3`` input (extension matched
        case-insensitively); any other input path is returned unchanged.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status
            (its captured output is attached to the exception).
    """
    suffix = Path(audio).suffix
    if suffix.lower() != ".mp3":
        return audio

    # Swap only the extension and keep the rest of the caller's string intact.
    wav_audio = audio[:-len(suffix)] + ".wav"
    if not Path(wav_audio).exists():
        # Arguments are passed as a list (no shell), so file names containing
        # spaces, quotes or other shell metacharacters are handled safely.
        # stdin is closed so ffmpeg can never block waiting for keyboard input.
        subprocess.run(["ffmpeg", "-i", audio, wav_audio],
                       check=True, capture_output=True, stdin=subprocess.DEVNULL)
    return wav_audio

def plot_spectrogram_and_picks(track:np.ndarray, sr:int, peaks:np.ndarray, onset_env:np.ndarray,
                               hop_length:int = None) -> None:
    """Plot the onset strength envelope with its selected peaks above the track's log spectrogram.

    Args:
        track (np.ndarray): Audio time series of the track.
        sr (int): Sampling rate of ``track``.
        peaks (np.ndarray): Frame indices of the peaks picked in ``onset_env``.
        onset_env (np.ndarray): Vector containing the onset strength envelope.
        hop_length (int, optional): Hop length (in samples) used to compute
            ``onset_env``. Defaults to the notebook-level ``HOP_SIZE``.
    """
    if hop_length is None:
        # Resolved at call time so the function does not depend on the settings
        # cell having been executed before this definition cell.
        hop_length = HOP_SIZE

    times = librosa.frames_to_time(np.arange(len(onset_env)), sr=sr, hop_length=hop_length)

    fig, (ax_onset, ax_spec) = plt.subplots(2, 1, sharex=True)

    # The STFT and its display use the same sr / hop_length as the onset envelope,
    # so both panels share a consistent time axis.
    D = librosa.stft(track, hop_length=hop_length)
    librosa.display.specshow(librosa.amplitude_to_db(np.abs(D), ref=np.max),
                             sr=sr, hop_length=hop_length,
                             y_axis='log', x_axis='time', ax=ax_spec)

    ax_onset.plot(times, onset_env, alpha=0.8, label='Onset strength')
    ax_onset.vlines(times[peaks], 0, onset_env.max(), color='r', alpha=0.8,
                    label='Selected peaks')
    ax_onset.legend(frameon=True, framealpha=0.8)
    ax_onset.axis('tight')
    fig.tight_layout()
    plt.show()

def load_audio_picks(audio, duration, hop_size):
    """Load a track, compute its onset strength envelope and pick the envelope's peaks.

    Args:
        audio (string, int, pathlib.Path or file-like object): Audio source, passed
            unchanged to ``librosa.load``.
        duration (int): Only load up to this much audio (seconds, per ``librosa.load``).
        hop_size (int): Hop length (in samples) between onset envelope frames.

    Returns:
        tuple: The audio time series (track) and sampling rate (sr), a vector containing
        the onset strength envelope (onset_env), and the frame indices of the peaks picked
        in the envelope (peaks).

    Raises:
        Exception: Any error raised while loading or analysing ``audio`` is re-raised
            after printing which track failed.
    """
    try:
        track, sr = librosa.load(audio, duration=duration)
        # Keyword arguments: recent librosa releases make these parameters keyword-only.
        onset_env = librosa.onset.onset_strength(y=track, sr=sr, hop_length=hop_size)
        peaks = librosa.util.peak_pick(onset_env, pre_max=10, post_max=10,
                                       pre_avg=10, post_avg=10, delta=0.5, wait=0.5)
    except Exception as e:
        print('An error occurred processing ', str(audio))
        print(e)
        # Re-raise: the results are undefined here, and callers unpack all four values.
        raise

    return track, sr, onset_env, peaks

### Settings

In [363]:
# Audio analysis
DURATION = 30    # seconds of audio loaded from every track and query
HOP_SIZE = 512   # hop length (samples) of the onset strength envelope

# Dataset / LSH
N_TRACKS = 1413  # number of tracks in the dataset (characteristic-matrix columns)
THRESHOLD = 15   # rows per band used when banding the MinHash signatures

In [364]:
# Dataset root; can be overridden with the MP3S_32K_DIR environment variable.
data_folder = Path(os.environ.get("MP3S_32K_DIR", "~/Downloads/mp3s-32k")).expanduser()
# A sorted list (not a one-shot generator): the conversion cell can be re-run and
# the processing order is deterministic.
mp3_tracks = sorted(data_folder.glob("*/*/*.mp3"))
# Deliberately left as a generator: the WAV files only exist after the preprocessing
# cell has converted the MP3s, so the directory must not be scanned yet.
tracks = data_folder.glob("*/*/*.wav")

### Preprocessing

In [365]:
# Convert every MP3 of the dataset to WAV; already-converted tracks are skipped.
# `tracks_names` is still passed to `LSH.populate_matrix` when the index is built.
tracks_names = []
for mp3_path in tqdm(mp3_tracks, total=N_TRACKS, desc="MP3 -> WAV"):
    convert_mp3_to_wav(str(mp3_path))

  0%|          | 0/1413 [00:00<?, ?it/s]

### Audio signals

### MinHash and locality-sensitive hashing (LSH)

In [367]:
class LSH:
    """MinHash + locality-sensitive hashing (LSH) index over onset peaks.

    The characteristic matrix has one row per onset envelope frame and one column per
    track: ``matrix[i, j] == 1`` when a peak was picked at frame ``i`` of track ``j``.
    MinHash compresses each column into a signature, and the banding step (``lsh``)
    puts tracks whose signature bands coincide into the same bucket.
    """

    def __init__(self, n_rows, n_columns, seed=None):
        """
        Args:
            n_rows (int): Number of rows (frames) of the characteristic matrix.
            n_columns (int): Number of columns (tracks) of the characteristic matrix.
            seed (int, optional): Seed of the random hash functions. Fix it to make the
                signatures and query results reproducible across runs.
        """
        self.matrix = np.zeros((n_rows, n_columns))
        self.i = 0  # next column to be filled by `add`
        self.query_signature_matrix = None
        self.signature_matrix = None
        self.rng = np.random.default_rng(seed)

    def populate_matrix(self, tracks, tracks_names):
        """
        Populates the characteristic matrix by extracting the peaks of every track and calling the
        self.add method. A tidy "Author / Album / Title" description of each track is built from its
        path and stored in self.tracks_names, in the same order as the matrix columns.

        Args:
            tracks (iterable): Paths of the WAV tracks, laid out as ``<author>/<album>/<NN>-<title>.wav``.
            tracks_names (list): Unused; kept so existing calls keep working.
        """
        # Start from a clean matrix so re-running this method neither overflows the
        # columns (self.i keeps counting) nor mixes in peaks from a previous run.
        self.matrix[:] = 0
        self.i = 0
        self.tracks_names = []
        for idx, audio in enumerate(tqdm(tracks, total=self.matrix.shape[1])):
            _, _, _, peaks = load_audio_picks(audio, DURATION, HOP_SIZE)
            self.add(peaks)
            self.tracks_names.append(self._describe_track(audio))
            assert np.sum(self.matrix[:, idx]) == len(peaks)

    @staticmethod
    def _describe_track(audio):
        """Build the "Author / Album / Title" description of a track from its path."""
        path = Path(str(audio))
        # The file name starts with a two-digit track number ("01-"), which is dropped.
        title = re.sub(r"^[0-9]{2}-", "", path.stem)
        author, album = path.parent.parent.name, path.parent.name
        author, album, title = (part.replace("_", " ").capitalize() for part in (author, album, title))
        return f"Author: {author}\nAlbum: {album}\nTitle: {title}"

    def add(self, entry):
        """
        Adds each entry to the characteristic matrix, self.matrix[i, j] = 1 if track j presents a peak at position i.
        """
        self.matrix[entry, self.i] = 1
        self.i += 1

    def _get_random_hashing_function(self, scale):
        """
        Obtains a random hashing function of the modular variety, h(x) = (a * x + b) % scale.

        ``a`` is drawn coprime with ``scale``, which makes h a permutation of 0..scale-1. With a
        non-coprime ``a`` (e.g. any even ``a`` when ``scale`` is even) distinct rows would collide,
        and the simulated permutation would be biased.
        """
        if scale < 2:
            a, b = 1, 0
        else:
            a = int(self.rng.integers(1, scale))
            while np.gcd(a, scale) != 1:
                a = int(self.rng.integers(1, scale))
            b = int(self.rng.integers(0, scale))

        def hash_func(x):
            return (a * x + b) % scale

        return hash_func

    def get_n_hashing_functions(self, n=100):
        """
        Calls _get_random_hashing_function n times and stores the resulting hash functions in self.hash_functions.
        """
        self.hash_functions = [self._get_random_hashing_function(self.matrix.shape[0]) for _ in range(n)]

    def perform_minhash(self, n=100):
        """
        Obtains the signature matrix for each track using minhash. As is common practice, rather than actually
        permuting the matrix n times, we simulate this procedure with n hashing functions applied to the row
        numbers: signature[h, j] is the minimum hash over the rows where column j holds a 1.
        Columns without any 1 keep +inf.
        """
        self.get_n_hashing_functions(n)
        row_ids = np.arange(self.matrix.shape[0])
        has_peak = self.matrix == 1
        self.signature_matrix = np.full((len(self.hash_functions), self.matrix.shape[1]), np.inf)
        for hash_i, hash_func in enumerate(self.hash_functions):
            hashed_rows = hash_func(row_ids).astype(float)
            # Broadcast the row hashes across the columns and keep them only where a peak is present.
            self.signature_matrix[hash_i] = np.where(has_peak, hashed_rows[:, None], np.inf).min(
                axis=0, initial=np.inf)

    def _band_slices(self, n_signature_rows):
        """Yield (band index, row slice) for every band of self.rows_per_band signature rows."""
        n_bands = int(np.ceil(n_signature_rows / self.rows_per_band))
        for band_i in range(n_bands):
            start = self.rows_per_band * band_i
            yield band_i, slice(start, start + self.rows_per_band)

    def lsh(self, rows_per_band=None, query=False):
        """
        Performs locality-sensitive hashing on the signature matrix. The signature matrix is divided
        row-wise into bands of rows_per_band rows (the last band may be shorter). Every track's band is
        used as a bucket key, together with the band index so that different bands never share buckets.
        The track description is stored in that bucket. Collisions between different tracks are more
        likely if the bands are smaller, i.e. if rows_per_band is smaller.

        Args:
            rows_per_band (int, optional): Rows per band; defaults to the notebook-level THRESHOLD.
            query (bool): Unused; kept so existing calls keep working.
        """
        if rows_per_band is None:
            # Resolved at call time, so later edits to THRESHOLD are honoured.
            rows_per_band = THRESHOLD
        if rows_per_band < 1:
            raise ValueError("rows_per_band must be a positive integer")
        self.rows_per_band = rows_per_band
        self.buckets = {}
        for band_i, rows in self._band_slices(self.signature_matrix.shape[0]):
            band = self.signature_matrix[rows]
            for col_i in range(band.shape[1]):
                # Exact band values as the key; a dict already hashes them.
                key = (band_i, tuple(band[:, col_i].tolist()))
                self.buckets.setdefault(key, []).append(self.tracks_names[col_i])

    def _query_signature(self, peaks):
        """Compute the MinHash signature of a query from its peak frame indices."""
        query_matrix = np.zeros((self.matrix.shape[0]))
        query_matrix[peaks] = 1
        assert np.sum(query_matrix) == len(peaks)
        rows = np.flatnonzero(query_matrix)
        # initial=+inf mirrors perform_minhash for a query with no peaks.
        return np.array([hash_func(rows).astype(float).min(initial=np.inf)
                         for hash_func in self.hash_functions])

    def _best_match(self, query_signature):
        """Return the track found in the most buckets shared with query_signature, or None."""
        votes = collections.Counter()
        for band_i, rows in self._band_slices(len(query_signature)):
            key = (band_i, tuple(query_signature[rows].tolist()))
            votes.update(self.buckets.get(key, []))
        if not votes:
            return None
        # Ties are resolved in favour of the first track counted.
        return votes.most_common(1)[0][0]

    def process_query(self, query_path):
        """
        The query undergoes the same exact treatment as all the other tracks, that is same parameters and hash functions.
        Once we obtain the signature of the query, each of its bands is looked up in the buckets. We collect the
        tracks information found at those locations, check which one is most represented and provide that as
        the final match. As an example, if the chosen rows_per_band yields 5 bands for each track, processing
        the query gives 5 buckets that could theoretically store different songs: the result that will be chosen
        is the one present in the most buckets (that is, the result that matched for most bands).

        Returns:
            str or None: Description of the best-matching track, or None when the query shares no bucket
            with any indexed track.
        """
        _, _, _, peaks = load_audio_picks(query_path, DURATION, HOP_SIZE)
        self.query_signature_matrix = self._query_signature(peaks)
        return self._best_match(self.query_signature_matrix)

In [368]:
# Rows = onset envelope frames of a DURATION-second clip. `load_audio_picks` calls
# librosa.load without `sr`, so audio is at librosa's default 22050 Hz; assuming
# centred framing (librosa's default) there are 1 + n_samples // HOP_SIZE frames,
# i.e. 1292 for DURATION=30 and HOP_SIZE=512.
N_FRAMES = 1 + (DURATION * 22050) // HOP_SIZE
lsh = LSH(N_FRAMES, N_TRACKS)
# Re-glob (sorted) rather than consuming the one-shot `tracks` generator, so this
# cell can be re-run and the column order is deterministic.
lsh.populate_matrix(sorted(data_folder.glob("*/*/*.wav")), tracks_names)

  0%|          | 0/1413 [00:00<?, ?it/s]

In [369]:
# Build the MinHash signature matrix with 100 random hash functions (the method's default).
lsh.perform_minhash(n=100)

In [370]:
# Split the signatures into bands of THRESHOLD rows and bucket every track.
lsh.lsh(THRESHOLD)

In [373]:
def test_queries(folder):
    for query in os.listdir(folder):
        path = folder + "/" + query
        result = lsh.process_query(path)
        print(f"{query.capitalize()}:")
        print(result)
        print("*"*8)

In [374]:
# Query clips directory, resolved against the current user's home directory
# instead of a hard-coded, user-specific absolute path.
folder = os.path.expanduser("~/Downloads/queries")
test_queries(folder)

Track8.wav:
Author: Green day
Album: American idiot
Title: American idiot
********
Track9.wav:
Author: Depeche mode
Album: Some great reward
Title: Somebody
********
Track10.wav:
Author: Steely dan
Album: Katy lied
Title: Black friday
********
Track2.wav:
Author: Queen
Album: The works
Title: I want to break free
********
Track3.wav:
Author: U2
Album: October
Title: October
********
Track1.wav:
Author: Aerosmith
Album: Aerosmith
Title: Dream on
********
Track4.wav:
Author: Beatles
Album: The white album disc 1
Title: Ob-la-di ob-la-da
********
Track5.wav:
Author: Radiohead
Album: Ok computer
Title: Karma police
********
Track7.wav:
Author: Fleetwood mac
Album: Rumours
Title: Go your own way
********
Track6.wav:
Author: Led zeppelin
Album: Led zeppelin ii
Title: Heartbreaker
********
