In [1]:
## Import packages
import os
import csv
import numpy as np
import pandas as pd
from scipy.io import wavfile

import tensorflow as tf

import vggish_input
import vggish_postprocess
import vggish_params
import vggish_slim

  from ._conv import register_converters as _register_converters


In [2]:
# Input audio file to analyse; the commented-out lines are alternative inputs.
# wav_file = "wav_files/car_18.wav"
wav_file = "elephants/EFAF2011A008.wav"
#wav_file = "gunshots/shot9mm.45.wav"

# Path to a PCA-parameters archive.
# NOTE(review): the inference cell loads its PCA parameters from
# vggish_params.VGGISH_PCA_PARAMS rather than from this path - confirm which
# of the two is intended.
pca_params = "yt8m/vggish_pca_params.npz"
# Checkpoint path handed to vggish_slim.load_vggish_slim_checkpoint.
checkpoint = "yt8m/vggish_model.ckpt"

# CSV file that is read into class_map (column 0 -> column 2, first line skipped).
labels_csv = "csv_files/class_labels_indices.csv"

In [3]:
## I split the conversion of a wav file into embeddings into steps:
# Step 1a: read the wav file; the input is an array of samples indicating dB. The sample rate (per second) is read as well
# Step 1b: for a 2-D array (stereo instead of mono) compute the mean, then normalise (divide by 32768)
# Step 2: build the examples in the form [batch size, num frames, num bands].
    # For the separate batches (because everything cannot go into the NN at once),
    # a log mel spectrogram is created (in the form [num_frames, num_bands])
# Step 3: compute the features: the embedding layer is now produced (PCA components, quantisation, etc.)
    # For this, the previously saved model parameters are loaded
# Step 4: make the predictions

In [4]:
## Steps 1 and 2
## This function reads the wav file and converts the samples into np arrays of [batch size, num frames, num bands]
## NOTE(review): description carried over from the original note; the helper's internals are not visible here.
examples_batch = vggish_input.wavfile_to_examples(wav_file)
# Print the shape as a quick sanity check on what was produced.
print(examples_batch.shape)

(448, 96, 64)


In [None]:
## Read csv-file with labels
# class_map maps the integer in column 0 of each row to the text in column 2.
class_map = {}
# newline="" is what the csv module requires of file objects so that quoted
# fields containing line breaks are parsed correctly.
with open(labels_csv, newline="") as f:
    reader = csv.reader(f)
    # Skip the header through the reader (not the raw file) and tolerate an
    # empty file instead of raising StopIteration.
    next(reader, None)
    for row in reader:
        if not row:
            # csv.reader yields [] for a blank line; nothing to record.
            continue
        class_map[int(row[0])] = row[2]

In [None]:
## Step 3: compute embeddings, project them with PCA, then run the classifier
with tf.Graph().as_default(), tf.Session() as sess:
    # Define the model: load the checkpoint and locate input and output tensors
    # Input: [batch_size, num_frames, num_bands]
    # where [num_frames, num_bands] represents log-mel-scale spectrogram
    # Output: embeddings
    vggish_slim.define_vggish_slim(training=False)
    vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint)

    # Load the PCA parameters inside a context manager so the .npz archive is
    # closed as soon as both arrays have been read. A dedicated name is used so
    # the `pca_params` path string from the configuration cell is not overwritten.
    with np.load(vggish_params.VGGISH_PCA_PARAMS) as pca_npz:
        pca_matrix = pca_npz[vggish_params.PCA_EIGEN_VECTORS_NAME]
        # Column vector, so it broadcasts against the transposed embeddings below.
        pca_means = pca_npz[vggish_params.PCA_MEANS_NAME].reshape(-1, 1)

    features_tensor = sess.graph.get_tensor_by_name(
        vggish_params.VGGISH_INPUT_TENSOR_NAME)
    embedding_tensor = sess.graph.get_tensor_by_name(
        vggish_params.VGGISH_OUTPUT_TENSOR_NAME)
    # NOTE(review): presumably this adds the classifier to the current graph
    # and registers the collections looked up further down - confirm.
    vggish_slim.load_youtube_model(sess, vggish_params.YOUTUBE_CHECKPOINT_FILE)

    # Run inference and postprocessing
    [embedding_batch] = sess.run([embedding_tensor],
                                 feed_dict={features_tensor: examples_batch})

    # PCA projection: subtract the means, multiply by the eigenvector matrix,
    # and transpose back so each row is one example again.
    postprocessed_batch = np.dot(
            pca_matrix, (embedding_batch.T - pca_means)
        ).T

    # Number of frames actually present, capped at MAX_FRAMES.
    num_frames = np.minimum(postprocessed_batch.shape[0], vggish_params.MAX_FRAMES)
    # NOTE(review): resize presumably pads/truncates axis 0 to MAX_FRAMES - confirm.
    data = vggish_postprocess.resize(postprocessed_batch, 0, vggish_params.MAX_FRAMES)
    # Add a leading axis of length 1 to both inputs before feeding them.
    data = np.expand_dims(data, 0)
    num_frames = np.expand_dims(num_frames, 0)

    input_tensor = sess.graph.get_collection("input_batch_raw")[0]
    num_frames_tensor = sess.graph.get_collection("num_frames")[0]
    predictions_tensor = sess.graph.get_collection("predictions")[0]

    ## Step 4: run the classifier on the prepared features
    predictions_val, = sess.run(
        [predictions_tensor],
        feed_dict={
            input_tensor: data,
            num_frames_tensor: num_frames
        })

INFO:tensorflow:Restoring parameters from yt8m/vggish_model.ckpt
INFO:tensorflow:Restoring parameters from models/youtube_model.ckpt


In [None]:
## Filter predictions: keep the highest-scoring classes (at most
## vggish_params.PREDICTIONS_COUNT_LIMIT of them) whose score is strictly
## greater than `hit`, sorted from highest to lowest score.
## NOTE(review): the original note said "top 20 where p>0.1", but the threshold
## below is 0 - confirm the intended cut-off.
scores = predictions_val[0]
# Clamp the limit to the number of available scores: np.argpartition raises
# ValueError when kth is out of bounds.
count = min(vggish_params.PREDICTIONS_COUNT_LIMIT, len(scores))
hit = 0
if count > 0:
    top_indices = np.argpartition(scores, -count)[-count:]
else:
    # A slice of [-0:] would return every index, so select nothing explicitly.
    top_indices = np.array([], dtype=int)
predictions = sorted(
    ((class_map[i], float(scores[i])) for i in top_indices if scores[i] > hit),
    key=lambda p: -p[1],
)

In [None]:
# Show the (label, score) pairs, highest score first.
print(predictions)