# Audio Search Engine
Developed during the Pinnacle 2021 Olympic Hackathon competition.

This notebook is the backend for the Audio Search Engine project.

Authors: Megan Bui, Sam Vanderlinda, Abduselam Shaltu

## Setup
Installs required libraries, loads imports, downloads dataset/models, and defines required functions.

In [None]:
!pip install tf_slim
!pip install scann
!pip install flask_ngrok
!pip install flask_cors
!pip install pydub

In [None]:
# Get labels and indexes of youtube noises.
!wget http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv
!wget http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv
!wget storage.googleapis.com/us_audioset/youtube_corpus/v1/features/features.tar.gz

# Extract dataset.
!tar -xf features.tar.gz

In [None]:
# Clone the TensorFlow models repository from GitHub; it contains the VGGish
# audio encoder sources under research/audioset/vggish.
!git clone https://github.com/tensorflow/models.git
# Change into the VGGish directory. The two files below are downloaded into it,
# and encode_audio later opens them by relative name, so the working directory
# must stay here. (Presumably this is also what makes the bare `import vggish_*`
# statements resolve — confirm.)
%cd models/research/audioset/vggish
# Download the VGGish model checkpoint and PCA parameters from Google Storage.
!curl -O https://storage.googleapis.com/audioset/vggish_model.ckpt
!curl -O https://storage.googleapis.com/audioset/vggish_pca_params.npz

In [7]:
import os
import sys
import csv
import scann
import numpy as np
import tensorflow as tf
from typing import Mapping
from itertools import chain
from flask_ngrok import run_with_ngrok
from __future__ import print_function
from pydub import AudioSegment

# Flask server related imports.
from flask import Flask, json, request, stream_with_context, jsonify
from flask_cors import CORS, cross_origin

# Audio encoder imports.
import vggish_input
import vggish_params
import vggish_postprocess
import vggish_slim

In [18]:
# Build dict with video id mapped to audio features.
# Takes approximately 45 - 50 seconds to run.
def load_video_audio_map(embeddings_dir='/content/audioset_v1_embeddings/bal_train/',
                         num_frames=5):
  """Map every video id found in a directory of .tfrecord files to its audio features.

  Args:
    embeddings_dir: Directory scanned (non-recursively) for files ending in
      ".tfrecord". Every other file is ignored.
    num_frames: Number of leading 'audio_embedding' entries kept per record.
      Records holding fewer entries than this are skipped entirely.

  Returns:
    A dict whose keys are the UTF-8 decoded 'video_id' context feature and
    whose values are flat lists of uint8 values: the first `num_frames`
    embeddings of the record concatenated in order.
  """
  video_audio_map = dict()

  for filename in os.listdir(embeddings_dir):
    # Ignore non-tfrecord files.
    if not filename.endswith(".tfrecord"):
      continue

    record_path = os.path.join(embeddings_dir, filename)
    for example_str in tf.compat.v1.io.tf_record_iterator(record_path):
      seq_example = tf.train.SequenceExample.FromString(example_str)
      frames = seq_example.feature_lists.feature_list['audio_embedding'].feature

      # Too short to fill a fixed-length vector; skip so all values have equal length.
      if len(frames) < num_frames:
        continue

      flattened_byte_list = []
      for frame in frames[0:num_frames]:
        # Each frame stores its embedding as a raw byte string; read it as uint8.
        flattened_byte_list.extend(np.frombuffer(frame.bytes_list.value[0], dtype=np.uint8))

      video_id = str(seq_example.context.feature['video_id'].bytes_list.value[0], 'utf-8')
      video_audio_map[video_id] = flattened_byte_list

  return video_audio_map

# Build dict with video id mapped to start time.
def load_video_start_map(segments_file: str = "/content/balanced_train_segments.csv") -> Mapping[str, int]:
  """Read the segments CSV and map each video id to its start time.

  The first three rows of the file (two stats rows and the header row) are
  discarded. For every remaining non-empty row, column 0 is used as the video
  id and column 1 is parsed as a float and truncated to an int.

  Args:
    segments_file: Path of the segments CSV file.

  Returns:
    A dict such as {"QM4qxOYDwHo": 430}.

  Raises:
    ValueError: If column 1 of a data row is not numeric.
    IndexError: If a non-empty data row has fewer than two columns.
  """
  video_start_map = dict()

  # newline='' is what the csv module expects of the file objects it reads.
  with open(segments_file, newline='') as csv_file:
    csvreader = csv.reader(csv_file)

    # Skip the two stats rows and the header row; tolerate a shorter file.
    for _ in range(3):
      next(csvreader, None)

    # Extract each video data row by row.
    for row in csvreader:
      # csv.reader yields [] for blank lines; there is nothing to record.
      if not row:
        continue
      video_start_map[row[0]] = int(float(row[1]))  # {"QM4qxOYDwHo" : 430}

  return video_start_map

# Build dict mapping index to video id.
def load_index_video_map(video_audio_map):
  """Number the keys of `video_audio_map` from 0 in iteration order.

  Returns a dict {position: video id}.
  """
  return dict(enumerate(video_audio_map.keys()))

# Build a ScaNN index for later ANN lookups. Takes approximately 42 seconds.
def load_search_engine(video_audio_map, num_results):
  """Build a ScaNN searcher over the values of `video_audio_map`.

  Row i of the indexed dataset is the i-th value of the dict in iteration
  order. `num_results` is handed to the ScaNN builder together with the
  "dot_product" distance measure.
  """
  embeddings = np.array(list(video_audio_map.values()))

  # Same configuration steps, applied one at a time: partitioning tree,
  # asymmetric-hashing scoring, then reordering.
  builder = scann.scann_ops_pybind.builder(embeddings, num_results, "dot_product")
  builder = builder.tree(
      num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
  builder = builder.score_ah(2, anisotropic_quantization_threshold=0.2)
  builder = builder.reorder(100)

  return builder.build()

# Convert mp3 file to wav and ensure the saved wav file is 5 seconds long.
def convert_mp3_to_wav(mp3_file_path):
  """Decode an mp3 file and save its first five seconds as a wav file.

  The wav file is written next to the input, with the input's extension
  replaced by ".wav". Only the final extension is replaced, so a ".mp3"
  appearing earlier in the path (for example in a directory name) is left
  untouched.

  Args:
    mp3_file_path: Path of the mp3 file to convert.

  Returns:
    The path of the wav file that was written.
  """
  base_file_path, _ = os.path.splitext(mp3_file_path)
  wav_file_path = f"{base_file_path}.wav"
  sound = AudioSegment.from_mp3(mp3_file_path)

  clip_seconds = 5
  # pydub does things in milliseconds.
  clip_milliseconds = clip_seconds * 1000

  # Add silence if duration not long enough; the slice below trims the excess.
  if sound.duration_seconds < clip_seconds:
    sound += AudioSegment.silent(duration=clip_milliseconds)

  sound = sound[0 : clip_milliseconds]   # First 5 seconds.
  sound.export(wav_file_path, 'wav')

  return wav_file_path

In [19]:
# Encode an audio file in wav format.
# Returns the flattened embedding as a 2-D array plus the per-example batch.
def encode_audio(wav_file, sr=None):
  """Run audio through VGGish and its postprocessor.

  Args:
    wav_file: When `sr` is falsy, passed to vggish_input.wavfile_to_examples.
      When `sr` is given, passed together with `sr` to
      vggish_input.waveform_to_examples instead.
    sr: Optional second argument for waveform_to_examples (presumably the
      sample rate of the waveform — confirm against vggish_input).

  Returns:
    A 2-tuple (flat_embedding, postprocessed_batch). `postprocessed_batch` is
    what Postprocessor.postprocess returned; `flat_embedding` is all of its
    rows concatenated and wrapped with np.array(..., ndmin=2), i.e. a single
    row.

  Note:
    The graph is defined and the checkpoint restored on every call, and both
    model files are opened relative to the current working directory.
  """
  # Paths to downloaded VGGish files.
  checkpoint_path = 'vggish_model.ckpt'
  pca_params_path = 'vggish_pca_params.npz'

  # Produce a batch of log mel spectrogram examples.
  if sr:
    input_batch = vggish_input.waveform_to_examples(wav_file, sr)
  else:
    input_batch = vggish_input.wavfile_to_examples(wav_file)

  # Define VGGish, load the checkpoint, and run the batch through the model to
  # produce embeddings.
  with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess:
    vggish_slim.define_vggish_slim()
    vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint_path)

    features_tensor = sess.graph.get_tensor_by_name(
        vggish_params.INPUT_TENSOR_NAME)
    embedding_tensor = sess.graph.get_tensor_by_name(
        vggish_params.OUTPUT_TENSOR_NAME)
    [embedding_batch] = sess.run([embedding_tensor],
                                feed_dict={features_tensor: input_batch})

  pproc = vggish_postprocess.Postprocessor(pca_params_path)
  postprocessed_batch = pproc.postprocess(embedding_batch)

  # Concatenate the per-example embeddings into one flat vector.
  embedding = []
  for postprocessed_embedding in postprocessed_batch:
    embedding.extend(postprocessed_embedding)

  return np.array(embedding, ndmin=2), postprocessed_batch

## Flask Server

In [20]:
# Load all server required variables needed for recommendation.
# These are module-level names; the /recommend handler reads mp3_file_path,
# index_video_map, video_start_map and search_engine directly.

# Passed to the ScaNN builder in load_search_engine.
num_results = 10
# Every uploaded recording is saved to this one path before conversion.
mp3_file_path = '/content/recording.mp3'

# video id -> flattened audio features.
video_audio_map = load_video_audio_map()
# video id -> start time.
video_start_map = load_video_start_map()
# position in video_audio_map -> video id.
index_video_map = load_index_video_map(video_audio_map)

# ScaNN searcher built over the values of video_audio_map.
search_engine = load_search_engine(video_audio_map, num_results)

In [None]:
# Server details.
api = Flask(__name__)
# Enable flask_cors on the app and set its CORS_HEADERS config value.
cors = CORS(api)
api.config['CORS_HEADERS'] = 'Content-Type'
# Hand the app to flask_ngrok. The recorded output of this cell shows a public
# *.ngrok.io URL reported alongside the local address once the app runs.
run_with_ngrok(api)   

# Endpoint needed for recommendation for audio snippets.
@api.route('/recommend', methods=['GET', 'POST'])
@cross_origin(origin='*',headers=['Content-Type'])
def recommend():
  """Return the videos whose audio is closest to the uploaded recording.

  Expects the recording as the 'rawAudioData' file part of the request. The
  upload is saved to `mp3_file_path`, converted to wav, encoded, and looked up
  in `search_engine`.

  Returns:
    {"videos": [[video_id, start_time_seconds], ...]}, one entry per neighbor
    in the first row of the search result.

  The two temporary audio files are removed whether or not the request
  succeeds, so a failed conversion or encoding does not leave them on disk.
  """
  # Look the upload up before touching the disk, so a request without it
  # fails here and there is nothing to clean up.
  uploaded_audio = request.files['rawAudioData']

  wav_file_path = None
  try:
    uploaded_audio.save(mp3_file_path)
    wav_file_path = convert_mp3_to_wav(mp3_file_path)

    # Encode audio into embedding with shape (1, 640).
    flat_audio_embedding, _batched_embeddings = encode_audio(wav_file_path)

    # Perform approximate nearest neighbor(ANN) search.
    neighbors, _distances = search_engine.search_batched(flat_audio_embedding)

    # Extract video ids and start time from approximate nearest neighbor search.
    videos = []
    for neighbor in neighbors[0]:
      video_id = index_video_map[neighbor]
      start_time_seconds = video_start_map[video_id]
      videos.append([video_id, start_time_seconds])
  finally:
    # Cleanup. Either file may be missing if an earlier step failed.
    for temp_path in (mp3_file_path, wav_file_path):
      if temp_path and os.path.exists(temp_path):
        os.remove(temp_path)

  return {"videos" : videos}

# Start the server when this cell is executed. The recorded output below reports
# the app as "__main__", so the guard holds when run from the notebook.
# api.run() keeps serving until interrupted.
if __name__ == '__main__':
    api.run()

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)


 * Running on http://ba4f-104-199-122-141.ngrok.io
 * Traffic stats available on http://127.0.0.1:4040


127.0.0.1 - - [19/Sep/2021 04:25:46] "[31m[1mGET /recommend HTTP/1.1[0m" 400 -


INFO:tensorflow:Restoring parameters from vggish_model.ckpt


127.0.0.1 - - [19/Sep/2021 04:26:08] "[37mPOST /recommend HTTP/1.1[0m" 200 -


INFO:tensorflow:Restoring parameters from vggish_model.ckpt


127.0.0.1 - - [19/Sep/2021 04:27:10] "[37mPOST /recommend HTTP/1.1[0m" 200 -


INFO:tensorflow:Restoring parameters from vggish_model.ckpt


127.0.0.1 - - [19/Sep/2021 04:29:58] "[37mPOST /recommend HTTP/1.1[0m" 200 -


INFO:tensorflow:Restoring parameters from vggish_model.ckpt


127.0.0.1 - - [19/Sep/2021 04:35:35] "[37mPOST /recommend HTTP/1.1[0m" 200 -


INFO:tensorflow:Restoring parameters from vggish_model.ckpt


127.0.0.1 - - [19/Sep/2021 04:38:42] "[37mPOST /recommend HTTP/1.1[0m" 200 -


INFO:tensorflow:Restoring parameters from vggish_model.ckpt


127.0.0.1 - - [19/Sep/2021 04:39:16] "[37mPOST /recommend HTTP/1.1[0m" 200 -


INFO:tensorflow:Restoring parameters from vggish_model.ckpt


127.0.0.1 - - [19/Sep/2021 04:39:33] "[37mPOST /recommend HTTP/1.1[0m" 200 -


INFO:tensorflow:Restoring parameters from vggish_model.ckpt


127.0.0.1 - - [19/Sep/2021 04:42:59] "[37mPOST /recommend HTTP/1.1[0m" 200 -


INFO:tensorflow:Restoring parameters from vggish_model.ckpt


127.0.0.1 - - [19/Sep/2021 04:43:49] "[37mPOST /recommend HTTP/1.1[0m" 200 -


INFO:tensorflow:Restoring parameters from vggish_model.ckpt


127.0.0.1 - - [19/Sep/2021 04:44:04] "[37mPOST /recommend HTTP/1.1[0m" 200 -


INFO:tensorflow:Restoring parameters from vggish_model.ckpt


127.0.0.1 - - [19/Sep/2021 04:44:13] "[37mPOST /recommend HTTP/1.1[0m" 200 -


INFO:tensorflow:Restoring parameters from vggish_model.ckpt


127.0.0.1 - - [19/Sep/2021 04:44:25] "[37mPOST /recommend HTTP/1.1[0m" 200 -
