In [1]:
import sys
import io
import json
from flask import Flask, jsonify, request
from flask_cors import CORS, cross_origin
sys.path.append('../datasets/')
sys.path.append('../models/')

from prepare_data import prepareData
from prepare_sequences import getSequences, classes18
from prepare_individuals import getIndividuals

import soundfile as sf
import torch
import librosa
import numpy as np

# Flask application object; the /predict route further down is registered on it.
app = Flask(__name__)
# Wrap the app with flask_cors so that cross-origin requests are permitted.
cors = CORS(app)


# Base target rate (Hz) handed to librosa.resample in prepare(); input flagged as
# not time-expanded is resampled to ten times this value instead.
sample_rate = 22050          # recordings are in 96 kHz, 24 bit depth, 1:10 TE (mic sr 960 kHz), 22050 Hz = 44100 Hz TE


def prepare(audio_bytes, expanded):
    """Decode an uploaded recording and turn it into the model input.

    Args:
        audio_bytes: Raw bytes of an audio file in a format that
            ``soundfile.read`` can decode.
        expanded: The string ``"false"`` when the recording is NOT
            time-expanded; any other value is treated as time-expanded.

    Returns:
        The result of ``prepareData`` applied to the resampled signal
        (per the original inline note: filter, spectrogram, denoise).
    """
    # Decode the in-memory bytes as an audio signal.
    with io.BytesIO(audio_bytes) as buffer:
        data, sr = sf.read(buffer, dtype='float32')
    # soundfile documents multichannel data as (frames, channels); librosa
    # resamples along the last axis by default, so put time last.
    data = data.T

    # Recordings that are not time-expanded are brought to ten times the base
    # rate; time-expanded ones to the base rate itself.
    target_sr = sample_rate * 10 if expanded == "false" else sample_rate

    # Sample rates are passed by keyword: librosa >= 0.10 made them
    # keyword-only, and older releases accept the same names.
    y = librosa.resample(data, orig_sr=sr, target_sr=target_sr)
    return prepareData(y)  # filter, spectrogram, denoise


def get_prediction(audio_bytes, selected_model, expanded):
    """Classify a recording with the selected model.

    Args:
        audio_bytes: Raw bytes of the uploaded audio file.
        selected_model: Display name of the model to use. Recognised values
            are ``"BAT-1: 18 european bats"`` and
            ``"ResNet-50: 18 european bats"``.
        expanded: Forwarded unchanged to ``prepare``.

    Returns:
        A ``(scores, classes)`` tuple. ``scores`` is a list of floats: the
        softmax of the model outputs averaged over the batch, or all zeros
        when the model name is not recognised or no model input could be
        extracted from the recording. ``classes`` is ``classes18``.
    """
    bat_1 = "BAT-1: 18 european bats"
    resnet_50 = "ResNet-50: 18 european bats"

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # select model
    if selected_model == bat_1:
        classes = classes18
        patch_len = 44               # = 250ms ~ 25ms
        patch_skip = 22              # = 150ms ~ 15ms
        seq_len = 60                 # = 500ms with ~ 5 calls
        seq_skip = 15
        resize = (88, 44)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model = torch.jit.load('../models/bat_1.pt', map_location=device)
    elif selected_model == resnet_50:
        classes = classes18
        patch_len = 44               # = 250ms ~ 25ms
        model = torch.jit.load('../models/baseline.pt', map_location=device)
    else:
        # Unknown model name: previously `classes` was read before being
        # assigned here, which raised UnboundLocalError instead of returning.
        return np.zeros(len(classes18)).tolist(), classes18

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # `nn` is not imported in this file; reference it through `torch`.
        model = torch.nn.DataParallel(model, device_ids=[0, 1])
    model.to(device)
    model.eval()

    S_db = prepare(audio_bytes, expanded)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        if selected_model == bat_1:
            batch = np.asarray(getSequences(S_db, patch_len, patch_skip, seq_len, seq_skip, resize))
        else:
            call_nocall_model = torch.jit.load('../models/call_nocall.pt', map_location=device)
            call_nocall_model.to(device)
            call_nocall_model.eval()
            inds = np.asarray(getIndividuals(S_db, patch_len, resize=None, threshold=0,
                                             ml=True, model=call_nocall_model, device=device))
            # Use at most the first 64 detections and add a singleton axis at
            # position 1 (presumably the channel axis - TODO confirm).
            batch = np.expand_dims(inds[:64], axis=1)

        # Only run the classifier when there is something to classify; the
        # same guard now covers both models.
        if batch.shape[0] > 0:
            outputs = model(torch.Tensor(batch).to(device))
            outputs = torch.nn.functional.softmax(outputs.mean(dim=0), dim=0)
            return outputs.tolist(), classes

    return np.zeros(len(classes)).tolist(), classes

@app.route('/predict', methods=['POST'])
@cross_origin()
def predict():
    """Handle an upload to /predict and reply with the scores as JSON.

    Reads the form fields ``model`` and ``expanded`` and the uploaded file
    ``file`` from the request, passes them to ``get_prediction`` and returns
    a JSON object with the keys ``prediction`` and ``classes``.
    """
    # Anything other than POST yields no response body from this view.
    if request.method != 'POST':
        return None

    # Lookups stay in the order model -> file -> expanded.
    form = request.form
    selected_model = form['model']
    upload = request.files['file']
    expanded = form['expanded']

    scores, labels = get_prediction(upload.read(), selected_model, expanded)
    return jsonify(prediction=scores, classes=list(labels))


# Start Flask's built-in server only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    # host='0.0.0.0' listens on all network interfaces; the service port is 8888.
    app.run(host='0.0.0.0', port=8888)

 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:8888
 * Running on http://134.60.40.108:8888 (Press CTRL+C to quit)
80.130.173.47 - - [21/Apr/2022 23:43:37] "OPTIONS /predict HTTP/1.1" 200 -
80.130.173.47 - - [21/Apr/2022 23:44:15] "OPTIONS /predict HTTP/1.1" 200 -
80.130.173.47 - - [21/Apr/2022 23:44:56] "POST /predict HTTP/1.1" 200 -
80.130.173.47 - - [21/Apr/2022 23:45:21] "OPTIONS /predict HTTP/1.1" 200 -
80.130.173.47 - - [21/Apr/2022 23:45:59] "POST /predict HTTP/1.1" 200 -
80.130.173.47 - - [21/Apr/2022 23:49:58] "OPTIONS /predict HTTP/1.1" 200 -
80.130.173.47 - - [21/Apr/2022 23:50:31] "POST /predict HTTP/1.1" 200 -
80.130.173.47 - - [21/Apr/2022 23:52:40] "OPTIONS /predict HTTP/1.1" 200 -
80.130.173.47 - - [21/Apr/2022 23:52:47] "POST /predict HTTP/1.1" 200 -
80.130.173.47 - - [21/Apr/2022 23:53:32] "OPTIONS /predict HTTP/1.1" 200 -
80.130.173.47 - - [21/Apr/2022 23:53:37] "POST /predict HTTP/1.1" 200 -
80.130.173.47 - - [21/Apr/2022 23:54:53] "OPTIONS /pr