In [1]:
import sounddevice as sd
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile

from scipy import signal
import os

from utils.preproc import create_spectrogram_from_wav_file
from core import load_audio, write_audio
import uuid
import subprocess
import pickle
import json

In [2]:
def predict(audio_path = './test_audio_folder/', output_path = './output/', config_path = 'config.json', segemt_sec = 5, FS = 44100, plot_spect = False):
    """Detect labelled events in .wav files and overlay them on the waveform.

    Pipeline:
      1. Collect the ``.wav`` inputs (a directory is scanned non-recursively).
      2. Load each file at ``FS``, cut it into consecutive ``segemt_sec``-second
         segments (the last partial one is padded) and save each segment's
         log-spectrogram as a JPEG in a fresh uuid-named folder under
         ``output_path``.
      3. Run the external ``predict.py`` script (resolved relative to the
         current working directory) on that folder. For every image it is
         expected to leave ``<image>.pkl``: a pickled iterable of dicts with
         ``xmin``, ``xmax``, ``score`` and ``label`` keys.
      4. Plot each waveform and highlight every box scoring above 0.1.

    Args:
        audio_path: directory to scan for ``.wav`` files, or a single file path.
        output_path: parent directory for the per-run output folder
            (created if it does not exist).
        config_path: JSON config. ``config['model']['labels']`` maps a box's
            integer label to its display name; also passed to ``predict.py``.
        segemt_sec: segment length in seconds (misspelled name kept for
            backward compatibility with keyword callers).
        FS: sample rate requested from ``load_audio``.
        plot_spect: if True, display each segment spectrogram as it is saved.

    Raises:
        subprocess.CalledProcessError: if ``predict.py`` exits non-zero.
    """
    with open(config_path) as config_buffer:
        config = json.load(config_buffer)
    labels_str = config['model']['labels']

    if os.path.isdir(audio_path):
        # os.path.join works whether or not audio_path ends with '/';
        # sorting gives a deterministic processing/plotting order.
        audio_paths = [os.path.join(audio_path, inp_file)
                       for inp_file in sorted(os.listdir(audio_path))]
    else:
        audio_paths = [audio_path]
    audio_paths = [inp_file for inp_file in audio_paths if inp_file.endswith('.wav')]

    # One fresh folder per call so outputs of different runs never mix.
    dir_name = os.path.join(output_path, uuid.uuid4().hex)
    os.makedirs(dir_name)
    # Keep the trailing '/': image names are appended to it directly and it is
    # handed to predict.py exactly as before.
    dir_name = dir_name + '/'

    tracks = []

    for inp_audio in audio_paths:
        # NOTE(review): assumes load_audio returns a 1-D (mono) sample array
        # plus its sample rate - TODO confirm.
        data, fs = load_audio(inp_audio, FS)
        base_name = os.path.basename(inp_audio)
        track = {'audio': data, 'fs': fs, 'name': base_name}
        # 5 ms analysis window with a 2.5 ms (50 %) hop overlap.
        nfft = int(fs*0.005)
        noverlap = int(fs*0.0025)

        # Integer sample arithmetic avoids float // and % rounding surprises.
        seg_len = int(round(fs * segemt_sec))
        n_samples = data.shape[0]
        n_segments, remainder = divmod(n_samples, seg_len)
        if remainder:
            # Fill the tail of the last segment with very low-level noise so the
            # padded region is not exactly zero (np.log below would give -inf).
            data_pad = np.random.rand((n_segments + 1) * seg_len) * 0.00005
            data_pad[:n_samples] = data + 0.00005
            n_segments += 1
        else:
            data_pad = data
        data_pad = data_pad.reshape((n_segments, seg_len))

        segments = []
        for i, segment in enumerate(data_pad):
            frequencies, times, spectrogram = signal.spectrogram(segment, fs, nfft = nfft, noverlap = noverlap, nperseg=nfft)

            # Segment files are numbered from 1 on disk.
            name = dir_name + base_name + '_seg_' + str(i+1) + '.jpg'

            fig = plt.figure(figsize=(13.22, 13.57))
            ax = fig.add_subplot(1, 1, 1)
            ax.axis('off')
            ax.pcolormesh(times, frequencies, np.log(spectrogram))
            # Save only the axes area. With matplotlib's default subplot margins
            # (left=0.125, right=0.9) that is about 0.775 * 1322 ~= 1024 px wide
            # at dpi=100, which the box -> time conversion below relies on.
            extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
            fig.savefig(name, bbox_inches=extent, dpi = 100)
            segments.append(name)

            if plot_spect:
                plt.show()

            plt.close(fig)

        track['segments'] = segments
        tracks.append(track)

    # Argument list instead of a shell string: paths are passed verbatim, and
    # the model uses the same config the labels were read from.
    # check=True surfaces a failed run here instead of as a missing .pkl later.
    subprocess.run(['python3', 'predict.py', '-c', config_path, '-i', dir_name, '-o', dir_name],
                   check=True)

    img_width_px = 1024    # pixel width of each saved spectrogram (see above)
    score_threshold = 0.1  # boxes at or below this score are not drawn

    for track in tracks:
        audio = track['audio']
        fs = track['fs']
        t_axis = np.linspace(0, audio.shape[0] / fs, audio.shape[0])

        plt.figure(figsize=(20,10))
        plt.title(track['name'])
        plt.plot(t_axis, audio)

        for seg_idx, segment in enumerate(track['segments']):
            # Pickles written by predict.py for this run; never point this at
            # untrusted files (pickle.load can execute arbitrary code).
            with open(segment+'.pkl', 'rb') as f:
                boxes = pickle.load(f)
            # 0-based index -> start time of this segment within the track.
            seg_start = seg_idx * segemt_sec
            for box in boxes:
                if box['score'] <= score_threshold:
                    continue
                # Box x-coordinates are pixels across a segment-wide image.
                tmin = seg_start + box['xmin'] / img_width_px * segemt_sec
                tmax = seg_start + box['xmax'] / img_width_px * segemt_sec
                lo, hi = int(tmin*fs), int(tmax*fs)
                plt.plot(t_axis[lo:hi], audio[lo:hi],
                         label = labels_str[int(box['label'])], lw = 5)
        plt.legend()
        plt.xlabel('time (s)')