In [29]:
!rm -r sample_data/

In [None]:
!pip install tf_slim
!pip install scann

In [91]:
import csv
import os
from itertools import chain
from typing import Mapping

import numpy as np
import tensorflow as tf

In [None]:
!git clone https://github.com/tensorflow/models.git
%cd models/research/audioset/vggish
!curl -O https://storage.googleapis.com/audioset/vggish_model.ckpt
!curl -O https://storage.googleapis.com/audioset/vggish_pca_params.npz

In [None]:
# Get labels and indexes of youtube noises
!wget http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv
!wget http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv
!wget storage.googleapis.com/us_audioset/youtube_corpus/v1/features/features.tar.gz

# Extract dataset
!tar -xf features.tar.gz

In [15]:
def load_labels_map(class_labels_file: str = "class_labels_indices.csv") -> Mapping[int, str]:
  """Build a map from class index to label text.

  Args:
    class_labels_file: Path of the class-labels CSV. The first row is treated
      as a header; for every other row, column 0 is parsed as the integer
      index and column 2 is used as the label.

  Returns:
    Dict mapping index to label, for example {8: "Shout"}.

  Raises:
    OSError: if the file cannot be opened.
    ValueError: if column 0 of a data row is not an integer.
    IndexError: if a non-empty data row has fewer than three columns.
  """
  index_label_map = dict()

  # newline="" is what the csv module expects so quoted fields parse correctly
  with open(class_labels_file, newline="") as csv_file:
    csvreader = csv.reader(csv_file)

    # Skip header; the None default keeps an empty file from raising StopIteration
    next(csvreader, None)

    for row in csvreader:
      if not row:
        continue  # tolerate blank lines
      index_label_map[int(row[0])] = row[2]

  return index_label_map

In [None]:
# List one file under the directory that the later cells read tfrecords from.
!ls /content/audioset_v1_embeddings/bal_train/a1.tfrecord

In [None]:
# Peek at one tfrecord file: print the 'video_id' context feature of every record.
# tf.python_io is not available in TF 2.x, so use the compat.v1 iterator, the same
# one the full loading cell below uses.
for example_str in tf.compat.v1.io.tf_record_iterator("/content/audioset_v1_embeddings/bal_train/00.tfrecord"):
    seq_example = tf.train.SequenceExample.FromString(example_str)
    print(seq_example.context.feature['video_id'])

In [176]:
# Takes approximately 45 - 50 seconds to run
# Load audio tfrecords into video_audio_map: video_id (str) -> 1-D uint8 array holding
# the first FRAMES_PER_CLIP 'audio_embedding' frames of that record, concatenated.
BAL_TRAIN_DIR = '/content/audioset_v1_embeddings/bal_train/'
FRAMES_PER_CLIP = 5  # records with fewer embedding frames are skipped

# Sorted so the insertion order of video_audio_map (and therefore the row order of
# anything built from it) is the same on every run; os.listdir order is arbitrary.
files = sorted(os.listdir(BAL_TRAIN_DIR))
video_audio_map = dict()
min_sec_found = float("inf")  # smallest number of frames seen in any record

for filename in files:
  if not filename.endswith(".tfrecord"):
    continue

  for example_str in tf.compat.v1.io.tf_record_iterator(os.path.join(BAL_TRAIN_DIR, filename)):
    seq_example = tf.train.SequenceExample.FromString(example_str)
    frames = seq_example.feature_lists.feature_list['audio_embedding'].feature
    min_sec_found = min(min_sec_found, len(frames))
    if len(frames) < FRAMES_PER_CLIP:
      continue

    video_id = str(seq_example.context.feature['video_id'].bytes_list.value[0], 'utf-8')
    # Each frame stores its embedding as one bytes value; view it as uint8 and join
    # the leading frames into a single vector. A later record with the same video_id
    # overwrites an earlier one.
    video_audio_map[video_id] = np.concatenate(
        [np.frombuffer(frame.bytes_list.value[0], dtype=np.uint8)
         for frame in frames[:FRAMES_PER_CLIP]])
print("SMALLEST MIN_SEC_FOUND: ", min_sec_found)

SMALLEST MIN_SEC_FOUND:  1


In [177]:
# Stack the per-video vectors into one (num_videos, 640) uint8 matrix and record which
# video each row came from (row index -> video_id).
# 640 is the per-video vector length (presumably 5 frames x 128 bytes - confirm).
index_video_map = dict(enumerate(video_audio_map))

# Preallocate and fill in place; np.append in a loop recopies the whole matrix for
# every added row. Assigning a vector of the wrong length still raises.
dataset = np.empty((len(index_video_map), 640), np.uint8)
for idx, video_id in index_video_map.items():
  dataset[idx] = video_audio_map[video_id]

In [178]:
# Runs for roughly 42 seconds
# Build the ScaNN searcher over every vector in video_audio_map (rows follow the
# dict's iteration order), configured to return num_results neighbours per query.
import scann
dataset = np.array(list(video_audio_map.values()))
num_results = 10
searcher = (
    scann.scann_ops_pybind.builder(dataset, num_results, "dot_product")
    .tree(num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
    .score_ah(2, anisotropic_quantization_threshold=0.2)
    .reorder(100)
    .build()
)

In [179]:
# Placeholder query batch: one all-ones integer vector of length 640.
queries = np.ones((1, 640), dtype=int)
neighbors, distances = searcher.search_batched(queries)

In [180]:
# Translate the row indices returned for the first query back into video ids and
# print a YouTube link for each.
matched_video_ids = (index_video_map[row_index] for row_index in neighbors[0])
for video_id in matched_video_ids:
  print(f"http://youtube.com/watch?v={video_id}")

http://youtube.com/watch?v=ArsKCV3rkc4
http://youtube.com/watch?v=4HSkwF586ro
http://youtube.com/watch?v=QM4qxOYDwHo
http://youtube.com/watch?v=ZaeARmx4m0k
http://youtube.com/watch?v=DRGpwij9No8
http://youtube.com/watch?v=UGtYWC-ddF4
http://youtube.com/watch?v=zFRreJxXDFw
http://youtube.com/watch?v=Dj6vz-bsHXY
http://youtube.com/watch?v=smTo8842-5c
http://youtube.com/watch?v=FOxIDRWTHZc


In [None]:
from __future__ import print_function

import numpy as np
import tensorflow.compat.v1 as tf

import vggish_input
import vggish_params
import vggish_postprocess
import vggish_slim

# VGGish checkpoint and PCA parameter files (downloaded by the curl cell earlier).
checkpoint_path = 'vggish_model.ckpt'
pca_params_path = 'vggish_pca_params.npz'

# Relative tolerance (up to 10%) for errors in the mean and standard deviation of
# the embeddings. Not referenced by the code in this cell.
rel_error = 0.1

# Stand-in input: a 1 kHz sine wave sampled at 44.1 kHz, a rate chosen high so that
# resampling to 16 kHz is exercised during feature extraction.
########## REPLACE WITH CODE TO LOAD WAVEFORM FROM USER
num_secs, freq, sr = 4, 1000, 44100
t = np.arange(0, num_secs, 1 / sr)
x = np.sin(t * (2 * np.pi * freq))

# Turn the waveform into a batch of log mel spectrogram examples and check that its
# shape is (num_secs, NUM_FRAMES, NUM_BANDS).
input_batch = vggish_input.waveform_to_examples(x, sr)
np.testing.assert_equal(
    input_batch.shape, [num_secs, vggish_params.NUM_FRAMES, vggish_params.NUM_BANDS])

# Build the VGGish graph in a fresh tf.Graph, restore the checkpoint into a session,
# and feed the example batch through it to get one embedding per example.
# NOTE(review): these embeddings come straight from the model output tensor, and
# vggish_postprocess / pca_params_path are not used here. The search index earlier in
# the notebook is built from uint8 vectors of the released features, so these
# embeddings presumably need post-processing before being used as queries - confirm.
with tf.Graph().as_default():
  with tf.Session() as sess:
    vggish_slim.define_vggish_slim()
    vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint_path)

    # Look up the model's input and output tensors by their configured names
    features_tensor = sess.graph.get_tensor_by_name(vggish_params.INPUT_TENSOR_NAME)
    embedding_tensor = sess.graph.get_tensor_by_name(vggish_params.OUTPUT_TENSOR_NAME)

    embedding_batch = sess.run(
        embedding_tensor, feed_dict={features_tensor: input_batch})
    print('Num of embeddings: ', len(embedding_batch))
    print('VGGish embedding: ', embedding_batch[0])