In [10]:
import sys
import io
import json
from flask import Flask, jsonify, request
from flask_cors import CORS, cross_origin
sys.path.append('../datasets/')
sys.path.append('../models/')

from prepare_data import prepareData
from prepare_sequences import slideWindow, classes18

import soundfile as sf
import torch
import librosa

# Flask application object that the prediction route is registered on.
app = Flask(import_name=__name__)
# flask_cors applied app-wide with no options, so its defaults govern every
# route. NOTE(review): flask_cors defaults are permissive (any origin) —
# confirm that is intended for this service.
cors = CORS(app=app)


def prepare(audio_bytes, sample_rate, patch_len, patch_skip, resize):
    """Decode an in-memory audio file and cut its spectrogram into patches.

    Args:
        audio_bytes: raw bytes of an audio file readable by ``soundfile``.
        sample_rate: sample rate (Hz) the decoded signal is resampled to.
        patch_len: forwarded unchanged to ``slideWindow``.
        patch_skip: forwarded unchanged to ``slideWindow``.
        resize: forwarded unchanged to ``slideWindow``.
        (Presumably window length, hop and output patch shape — confirm
        against ``prepare_sequences.slideWindow``.)

    Returns:
        The patch sequence produced by ``slideWindow``, with its final
        element dropped.
    """
    # soundfile accepts file-like objects, so decode straight from memory.
    data, sr = sf.read(io.BytesIO(audio_bytes), dtype='float32')
    # soundfile yields (frames, channels) for multichannel files, while
    # librosa treats the last axis as time, hence the transpose.
    data = data.T
    # orig_sr/target_sr are keyword-only since librosa 0.10 (positional use
    # raises TypeError); passing them by keyword also works on older versions.
    y = librosa.resample(data, orig_sr=sr, target_sr=sample_rate)

    S_db = prepareData(y)  # filter, spectrogram, denoise
    # The last window is not full, so it is discarded.
    return slideWindow(S_db, patch_len, patch_skip, resize)[:-1]


def get_prediction(audio_bytes, selected_model):
    """Run the selected TorchScript classifier on an uploaded recording.

    Args:
        audio_bytes: raw bytes of an audio file (decoded by ``prepare``).
        selected_model: display name of the model to use. Only
            ``"BAT-1: 18 european bats"`` is currently supported.

    Returns:
        ``(prediction, classes)``: the model output for the single batch
        item converted with ``tolist()``, and the class list of the model.

    Raises:
        ValueError: if ``selected_model`` is unknown, or if the recording
            yields no patches (e.g. it is too short).
    """
    # select model
    if selected_model == "BAT-1: 18 european bats":
        classes = classes18
        sample_rate = 22050          # recordings are in 96 kHz, 24 bit depth, 1:10 TE (mic sr 960 kHz), 22050 Hz = 44100 Hz TE
        patch_len = 44               # = 250ms ~ 25ms
        patch_skip = 22              # = 150ms ~ 15ms
        resize = (44, 44)
        model_path = '../models/bat_1.pt'
    else:
        # Without this branch an unknown name fell through and crashed later
        # with UnboundLocalError; report the real problem instead.
        raise ValueError(f"unknown model: {selected_model!r}")

    # NOTE(review): the model is reloaded from disk on every request;
    # consider caching it if request latency matters.
    model = torch.jit.load(model_path)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # `nn` is not imported in this module; use the qualified name.
        model = torch.nn.DataParallel(model, device_ids=[0, 1])
    model.to(device)
    model.eval()

    # Keep at most the first 60 patches (presumably the model's maximum
    # sequence length — confirm against the training setup).
    sequence = prepare(audio_bytes, sample_rate, patch_len, patch_skip, resize)[:60]
    # len() rather than truthiness: `sequence` may be a numpy array, whose
    # truth value is ambiguous.
    if len(sequence) == 0:
        raise ValueError("recording is too short to produce a single full patch")

    tensor = torch.Tensor([sequence]).to(device)
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(tensor)
    return outputs[0].tolist(), classes


@app.route('/predict', methods=['POST'])
@cross_origin()
def predict():
    """Classify an uploaded recording.

    Expects a multipart POST carrying a ``model`` form field (the model's
    display name) and a ``file`` upload with the audio. Responds with JSON
    ``{"prediction": ..., "classes": [...]}``.
    """
    # The route is registered for POST only, so no method check is needed.
    model_name = request.form['model']
    upload = request.files['file']
    raw_audio = upload.read()

    scores, labels = get_prediction(raw_audio, model_name)
    payload = {'prediction': scores, 'classes': list(labels)}
    return jsonify(payload)


if __name__ == '__main__':
    # Start Flask's built-in development server on every interface, port 8888.
    # Flask itself warns at startup to use a production WSGI server instead.
    server_options = {'host': '0.0.0.0', 'port': 8888}
    app.run(**server_options)

 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:8888
 * Running on http://134.60.40.108:8888 (Press CTRL+C to quit)
93.184.191.83 - - [21/Apr/2022 14:04:35] "POST /predict HTTP/1.1" 200 -
