<a href="https://colab.research.google.com/github/MauriVass/ML4IoTCourse/blob/master/Lab5/RESTfulInfKeywordSpottingService.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the CherryPy web framework (IPython shell escape: valid only in a notebook cell).
!pip install cherrypy

In [None]:
# import cherrypy
# import sys

# class HelloWorld:
#     def index(self):
#         return "Hello World!"
#     index.exposed = True
# if __name__ == '__main__':
#    config = {'server.socket_host': '0.0.0.0','server.socket_port' : 18888}
#    cherrypy.config.update(config)
#    cherrypy.quickstart(HelloWorld())

In [None]:
import base64
import json
import random
import string

import cherrypy
import tensorflow as tf

class KSInferenceService(object):
  """RESTful keyword-spotting inference service for CherryPy.

  Mounted with ``cherrypy.dispatch.MethodDispatcher``, so each HTTP verb is
  routed to the method of the same name. ``PUT /<model>`` runs the named Keras
  model on a base64-encoded WAV clip and responds with the JSON object
  ``{"label": <keyword>, "probability": <percent>}``.

  Expected request body (SenML-like JSON)::

      {"e": [{"n": "audio", "vb": "<base64-encoded WAV bytes>"}, ...]}
  """
  # Required so CherryPy exposes this object (and its verb methods) over HTTP.
  exposed = True

  # Feature-extraction parameters. They must match the preprocessing the models
  # were trained with (TODO confirm against the training notebook).
  SAMPLING_RATE = 16000
  FRAME_LENGTH = 640
  FRAME_STEP = 320
  NUM_MEL_BINS = 40
  LOWER_FREQUENCY = 20
  UPPER_FREQUENCY = 4000
  NUM_MFCCS = 10

  def __init__(self):
    """Load the three trained models and precompute the mel filterbank."""
    # NOTE(review): the directory is spelled "moldes" -- confirm it is not a
    # typo for "models" before renaming anything on disk.
    mlp = tf.keras.models.load_model('./moldes/kws_mlp_True')
    cnn = tf.keras.models.load_model('./moldes/kws_cnn_True')
    dscnn = tf.keras.models.load_model('./moldes/kws_dscnn_True')

    # Maps the URL path segment to the model that serves it.
    self.models = {'mlp': mlp, 'cnn': cnn, 'dscnn': dscnn}
    # Model output index -> keyword; presumably the order used at training time.
    self.LABELS = ["go", "left", "no", "right", "stop", "up", "yes"]
    # An STFT with fft_length N produces N // 2 + 1 frequency bins (321 here);
    # deriving it keeps the filterbank consistent with FRAME_LENGTH.
    num_spectrogram_bins = self.FRAME_LENGTH // 2 + 1
    self.linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
        self.NUM_MEL_BINS, num_spectrogram_bins, self.SAMPLING_RATE,
        self.LOWER_FREQUENCY, self.UPPER_FREQUENCY)

  def preprocess(self, audio):
    """Convert encoded WAV bytes into a model-ready MFCC tensor.

    Args:
      audio: raw bytes of a WAV file. Assumed mono and sampled at
        SAMPLING_RATE; the decoded sample rate is not checked.

    Returns:
      A tensor of shape (1, num_frames, NUM_MFCCS, 1): a batch of one clip,
      time frames, MFCC coefficients, channel.
    """
    audio, _ = tf.audio.decode_wav(audio)
    # decode_wav yields (samples, channels); drop the single channel axis.
    audio = tf.squeeze(audio, axis=1)

    # Short-time Fourier transform -> magnitude spectrogram.
    stft = tf.signal.stft(audio, frame_length=self.FRAME_LENGTH,
                          frame_step=self.FRAME_STEP,
                          fft_length=self.FRAME_LENGTH)
    spectrogram = tf.abs(stft)

    # Mel projection, log compression (the epsilon avoids log(0)), MFCCs.
    mel_spectrogram = tf.tensordot(spectrogram,
                                   self.linear_to_mel_weight_matrix, 1)
    log_mel_spectrogram = tf.math.log(mel_spectrogram + 1.e-6)
    mfccs = tf.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrogram)
    mfccs = mfccs[..., :self.NUM_MFCCS]

    # Add the channel and batch dimensions the models are called with.
    mfccs = tf.expand_dims(mfccs, -1)
    mfccs = tf.expand_dims(mfccs, 0)
    return mfccs

  @staticmethod
  def _extract_audio(body):
    """Return the decoded WAV bytes carried by a SenML-like request body.

    The first record in ``body["e"]`` whose ``n`` is ``"audio"`` is used and
    its ``vb`` field is base64-decoded; other records are ignored.

    Args:
      body: the raw request body (bytes or str) containing JSON.

    Returns:
      The decoded audio bytes.

    Raises:
      cherrypy.HTTPError(400): the body is not JSON, has no ``e`` list,
        carries no audio record, or the audio is empty / not valid base64.
    """
    try:
      events = json.loads(body)['e']
    except (ValueError, KeyError, TypeError) as err:
      raise cherrypy.HTTPError(400, f"Malformed request body: {err}") from err
    if not isinstance(events, list):
      raise cherrypy.HTTPError(400, "Malformed request body: 'e' must be a list")

    audio_string = None
    for event in events:
      if isinstance(event, dict) and event.get('n') == 'audio':
        audio_string = event.get('vb')
        break
    if audio_string is None:
      raise cherrypy.HTTPError(400, "No audio file")
    if not audio_string:
      raise cherrypy.HTTPError(400, "Empty audio file")

    try:
      # The client sends base64 text; the model needs the original WAV bytes.
      return base64.b64decode(audio_string)
    except (ValueError, TypeError) as err:
      raise cherrypy.HTTPError(400, f"Audio is not valid base64: {err}") from err

  def PUT(self, *path, **query):
    """Classify the audio clip in the request body with the model in the URL.

    Args:
      path: URL path segments; exactly one, the model name (a key of
        ``self.models``: 'mlp', 'cnn' or 'dscnn').
      query: query-string parameters (unused).

    Returns:
      A JSON string ``{"label": <keyword>, "probability": <percent 0-100>}``.

    Raises:
      cherrypy.HTTPError(404): wrong number of path segments or unknown model.
      cherrypy.HTTPError(400): malformed request body.
    """
    if len(path) != 1:
      raise cherrypy.HTTPError(404, f"Use only 1 model: {self.models.keys()}. Used: {path}")
    model_name = path[0]
    # Validate the name before indexing, so an unknown model yields a 404
    # instead of an uncaught KeyError (HTTP 500).
    if model_name not in self.models:
      raise cherrypy.HTTPError(404, f"Model not recognized. Use: {self.models.keys()}. Used: {model_name}")
    model = self.models[model_name]

    audio_bytes = self._extract_audio(cherrypy.request.body.read())
    mfccs = self.preprocess(audio_bytes)

    logits = model.predict(mfccs)
    probs = tf.nn.softmax(logits)
    # Convert to built-in float/int: json cannot serialize numpy scalars.
    prob = float(tf.reduce_max(probs).numpy()) * 100
    label_ind = int(tf.argmax(probs, 1).numpy()[0])
    label = self.LABELS[label_ind]

    output = {'label': label, 'probability': prob}
    return json.dumps(output)

# 'request.dispatch': cherrypy.dispatch.MethodDispatcher() switches from the
# default URL-to-method mapping to an HTTP-compliant one, where each HTTP verb
# (GET, PUT, ...) is routed to the handler method of the same name.
conf = {'/': {'request.dispatch': cherrypy.dispatch.MethodDispatcher()}}
cherrypy.tree.mount(KSInferenceService(), '/', conf)

# Listen on all interfaces on port 8888 (the port must be an int).
cherrypy.config.update({'server.socket_host': '0.0.0.0',
                        'server.socket_port': 8888})

### ### ###
## Start the server in Colab
# Bind port 8888 and get a public web link to access it
# from google.colab.output import eval_js
# print(eval_js("google.colab.kernel.proxyPort(8888)"))

# #run the script/API in the background
# import subprocess
# subprocess.Popen(["python", "/", "8888"]) 
# cherrypy.quickstart(KSInferenceService())
### ### ###

#cherrypy.engine.start()
#cherrypy.engine.block()