In [9]:
import numpy as np
from tqdm import tqdm
import tensorflow as tf
from glob import glob
import os
import argparse
from tensorflow.keras.models import load_model
from clean import downsample_mono, envelope
from sklearn.preprocessing import LabelEncoder
import argparse
import pandas as pd

In [2]:
# Path to the TFLite model to inspect (float32 or int8-quantized export).
model_path = "./model.tflite"

# Placeholder for processed features pasted from the Edge Impulse project.
# It is not consumed anywhere in this cell.
features = [
  # <COPY FEATURES HERE!>
]

# Build the interpreter and read the tensor metadata before allocating buffers.
interpreter = tf.lite.Interpreter(model_path=model_path)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.allocate_tensors()

# Show the model's I/O signature; each section is framed by blank lines and
# the label and the details list are printed on separate lines.
print()
print("Input details:", input_details, sep="\n")
print()
print("Output details:", output_details, sep="\n")
print()


Input details:
[{'name': 'stft_1_input', 'index': 0, 'shape': array([    1, 40000,     1], dtype=int32), 'shape_signature': array([   -1, 40000,     1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]

Output details:
[{'name': 'Identity', 'index': 138, 'shape': array([ 1, 44], dtype=int32), 'shape_signature': array([-1, 44], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]



In [7]:
def _split_into_windows(signal, step):
    """Split a signal into consecutive, non-overlapping windows of `step` samples.

    Args:
        signal: array-like of samples; it is flattened to 1-D.
        step: window length in samples (must be > 0).

    Returns:
        float32 array of shape (n_windows, step, 1) with
        n_windows = ceil(len(signal) / step). The last window is zero-padded
        to full length; n_windows is 0 for an empty signal.
    """
    samples = np.asarray(signal).reshape(-1)
    n_windows = -(-samples.shape[0] // step)  # ceiling division
    padded = np.zeros((n_windows * step,), dtype=np.float32)
    padded[:samples.shape[0]] = samples
    return padded.reshape(n_windows, step, 1)


def _run_tflite(interpreter, input_detail, output_detail, window):
    """Score one (step, 1) window with an allocated TFLite interpreter.

    When a tensor's quantization scale is non-zero (integer-quantized model),
    the input is quantized and the output dequantized with TFLite's affine
    mapping real = scale * (q - zero_point). For float models the scale is
    0.0 and the data passes through unchanged.

    Returns:
        The output tensor as an array; float32 when dequantization applies.
    """
    batch = window[np.newaxis, ...]  # add the batch dimension -> (1, step, 1)
    in_scale, in_zero = input_detail['quantization']
    in_dtype = input_detail['dtype']
    if in_scale and np.issubdtype(in_dtype, np.integer):
        limits = np.iinfo(in_dtype)
        batch = np.clip(np.round(batch / in_scale) + in_zero, limits.min, limits.max)
    interpreter.set_tensor(input_detail['index'], batch.astype(in_dtype))
    interpreter.invoke()
    scores = interpreter.get_tensor(output_detail['index'])
    out_scale, out_zero = output_detail['quantization']
    if out_scale:
        scores = (scores.astype(np.float32) - out_zero) * out_scale
    return scores


def make_prediction(args):
    """Classify every .wav file under ``args.src_dir`` with a TFLite model.

    Each file is loaded via ``downsample_mono(path, args.sr)``, filtered with
    the boolean mask returned by ``envelope(..., threshold=args.threshold)``,
    cut into ``int(args.sr * args.dt)``-sample windows (last one zero-padded),
    and each window is scored by the interpreter. A file's prediction is the
    mean of its window scores.

    Args:
        args: namespace with ``model_fn``, ``src_dir``, ``sr``, ``dt``,
            ``threshold`` and ``pred_fn``. An optional boolean ``verbose``
            attribute prints the raw per-window scores and their mean.

    Side effects:
        Prints the actual vs. predicted class per file and saves the stacked
        per-file mean scores with ``np.save`` to ``logs/<pred_fn>``, creating
        ``logs`` if needed. Files where the envelope mask keeps no samples
        get a NaN row, so row i always corresponds to the i-th sorted file.

    Raises:
        ValueError: if ``int(args.sr * args.dt)`` is not a positive length.
    """
    step = int(args.sr * args.dt)  # window length in samples, same for every file
    if step <= 0:
        raise ValueError('window length int(sr*dt) must be positive, got {}'.format(step))

    interpreter = tf.lite.Interpreter(model_path=args.model_fn)
    interpreter.allocate_tensors()
    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]
    verbose = getattr(args, 'verbose', False)

    # Recursive glob also returns directories, and a plain substring test would
    # match names like 'x.wav.bak'; keep only regular files ending in .wav.
    wav_paths = glob('{}/**'.format(args.src_dir), recursive=True)
    wav_paths = sorted(p.replace(os.sep, '/') for p in wav_paths
                       if p.endswith('.wav') and os.path.isfile(p))
    # NOTE(review): the index -> class-name mapping must match the ordering used
    # when the model was trained; presumably sorted entries of src_dir — confirm.
    classes = sorted(os.listdir(args.src_dir))
    results = []

    for wav_fn in tqdm(wav_paths, total=len(wav_paths)):
        rate, wav = downsample_mono(wav_fn, args.sr)
        mask, _ = envelope(wav, rate, threshold=args.threshold)
        windows = _split_into_windows(wav[mask], step)
        real_class = os.path.dirname(wav_fn).split('/')[-1]

        if windows.shape[0] == 0:
            # Nothing survived the mask: averaging zero windows gives a scalar
            # NaN (and a bogus argmax of 0) that would also break stacking the
            # results, so record an explicit NaN row of the output's shape.
            tqdm.write('No samples above threshold in {}; skipping prediction'.format(wav_fn))
            results.append(np.full(tuple(output_detail['shape']), np.nan, dtype=np.float32))
            continue

        window_scores = []
        for window in windows:
            scores = _run_tflite(interpreter, input_detail, output_detail, window)
            if verbose:
                tqdm.write('output_data')
                tqdm.write(str(scores))
            window_scores.append(scores)

        y_mean = np.mean(window_scores, axis=0)
        if verbose:
            tqdm.write('y_mean')
            tqdm.write(str(y_mean))
        pred_idx = int(np.argmax(y_mean))
        # Guard against a model with more outputs than listed classes.
        pred_class = classes[pred_idx] if pred_idx < len(classes) else str(pred_idx)
        tqdm.write('Actual class: {}, Predicted class: {}'.format(real_class, pred_class))
        results.append(y_mean)

    os.makedirs('logs', exist_ok=True)
    np.save(os.path.join('logs', args.pred_fn), np.array(results))

In [10]:


# Options for make_prediction. parse_known_args() ignores unrecognized argv
# entries (e.g. any extra arguments a notebook kernel was launched with), so
# in a notebook the defaults below are what normally take effect.
parser = argparse.ArgumentParser(description='Audio Classification Prediction')
parser.add_argument('--model_fn', type=str, default='model.tflite',
                    help='model file to make predictions')
parser.add_argument('--pred_fn', type=str, default='tflite_pred',
                    help='fn to write predictions in logs dir')
parser.add_argument('--src_dir', type=str, default='wavfiles',
                    help='directory containing wavfiles to predict')
parser.add_argument('--dt', type=float, default=2.5,
                    help='time in seconds to sample audio')
parser.add_argument('--sr', type=int, default=16000,
                    help='sample rate of clean audio')
# Must be numeric: with type=str a CLI value like "20" would reach the
# envelope threshold comparison as a string.
parser.add_argument('--threshold', type=int, default=20,
                    help='threshold magnitude for np.int16 dtype')
args, _ = parser.parse_known_args()

make_prediction(args)

  0%|          | 0/2691 [00:00<?, ?it/s]

: 

: 