-
Notifications
You must be signed in to change notification settings - Fork 241
/
asr_server.py
executable file
·122 lines (98 loc) · 3.7 KB
/
asr_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/env python3
import json
import os
import sys
import asyncio
import pathlib
import websockets
import concurrent.futures
import logging
from vosk import Model, SpkModel, KaldiRecognizer
def process_chunk(rec, message):
    """Feed one websocket message to the recognizer and return its reply.

    Text messages are inspected for the two control commands ``eof`` and
    ``reset``; every other message is passed to ``rec.AcceptWaveform``.

    Args:
        rec: Recognizer object exposing ``AcceptWaveform``, ``Result``,
            ``PartialResult`` and ``FinalResult``.
        message: One websocket frame, either a JSON control string such as
            ``{"eof" : 1}`` / ``{"reset" : 1}`` or data for the recognizer.

    Returns:
        A ``(response, stop)`` tuple. ``response`` is the value returned by
        the recognizer method that was called; ``stop`` is True only for an
        ``eof`` command.
    """
    if isinstance(message, str):
        # Parse the command instead of comparing against one exact spelling,
        # so '{"eof":1}' and '{"eof": 1}' are honoured like '{"eof" : 1}'.
        try:
            command = json.loads(message)
        except ValueError:
            command = None
        if isinstance(command, dict):
            if command.get('eof') == 1:
                return rec.FinalResult(), True
            if command.get('reset') == 1:
                return rec.FinalResult(), False

    # AcceptWaveform's truthy return selects Result over PartialResult.
    if rec.AcceptWaveform(message):
        return rec.Result(), False
    return rec.PartialResult(), False
async def recognize(websocket, path=None):
    """Handle one websocket client: apply its config and transcribe its data.

    Reads the module-level ``model``, ``spk_model``, ``args`` and ``pool``
    set up by ``start()``. The loop ends after the reply to an ``eof``
    command has been sent.

    Args:
        websocket: The client connection; must provide ``recv()``,
            ``send()`` and ``remote_address``.
        path: Unused. Optional so the handler can be called either with
            ``(websocket, path)`` or with the connection alone, depending on
            the installed websockets version.
    """
    loop = asyncio.get_running_loop()

    rec = None
    # The model is tracked per connection: a client-supplied model must not
    # replace the shared default used by every other connection.
    active_model = model
    model_changed = False
    phrase_list = None
    sample_rate = args.sample_rate
    show_words = args.show_words
    max_alternatives = args.max_alternatives

    logging.info('Connection from %s', websocket.remote_address)

    while True:
        message = await websocket.recv()

        # Load configuration if provided
        if isinstance(message, str) and 'config' in message:
            jobj = json.loads(message)['config']
            logging.info("Config %s", jobj)
            if 'phrase_list' in jobj:
                phrase_list = jobj['phrase_list']
            if 'sample_rate' in jobj:
                sample_rate = float(jobj['sample_rate'])
            if 'model' in jobj:
                # NOTE(review): the model path comes straight from the
                # client; restrict it if clients are not trusted.
                active_model = Model(jobj['model'])
                model_changed = True
            if 'words' in jobj:
                show_words = bool(jobj['words'])
            if 'max_alternatives' in jobj:
                max_alternatives = int(jobj['max_alternatives'])
            continue

        # Create the recognizer on the first non-config message, and again
        # whenever the client has switched models. Other config changes only
        # take effect when the recognizer is (re)created.
        if rec is None or model_changed:
            model_changed = False
            if phrase_list:
                rec = KaldiRecognizer(
                    active_model,
                    sample_rate,
                    json.dumps(phrase_list, ensure_ascii=False),
                )
            else:
                rec = KaldiRecognizer(active_model, sample_rate)
            rec.SetWords(show_words)
            rec.SetMaxAlternatives(max_alternatives)
            if spk_model:
                rec.SetSpkModel(spk_model)

        # Run the recognizer call on the thread pool so it does not block
        # the event loop.
        response, stop = await loop.run_in_executor(pool, process_chunk, rec, message)
        await websocket.send(response)
        if stop:
            break
async def start():
    """Configure the server from the environment and serve websocket clients.

    Sets the module-level ``args``, ``model``, ``spk_model`` and ``pool``
    that ``recognize`` reads, then serves until the task is cancelled.

    Environment variables (defaults in parentheses):
        VOSK_SERVER_INTERFACE ('0.0.0.0'), VOSK_SERVER_PORT (2700),
        VOSK_MODEL_PATH ('model'), VOSK_SPK_MODEL_PATH (unset, no speaker
        model), VOSK_SAMPLE_RATE (8000), VOSK_ALTERNATIVES (0),
        VOSK_SHOW_WORDS (enabled; '', '0', 'false', 'no' and 'off' disable).

    A first command-line argument, if present, overrides the model path.
    """
    global model
    global spk_model
    global args
    global pool

    # Enable logging if needed
    #
    # logger = logging.getLogger('websockets')
    # logger.setLevel(logging.INFO)
    # logger.addHandler(logging.StreamHandler())
    logging.basicConfig(level=logging.INFO)

    # Plain attribute container for the settings below.
    args = type('', (), {})()

    args.interface = os.environ.get('VOSK_SERVER_INTERFACE', '0.0.0.0')
    args.port = int(os.environ.get('VOSK_SERVER_PORT', 2700))
    args.model_path = os.environ.get('VOSK_MODEL_PATH', 'model')
    args.spk_model_path = os.environ.get('VOSK_SPK_MODEL_PATH')
    args.sample_rate = float(os.environ.get('VOSK_SAMPLE_RATE', 8000))
    args.max_alternatives = int(os.environ.get('VOSK_ALTERNATIVES', 0))
    # bool() on an environment string is True for any non-empty value, so
    # "false" or "0" could never turn the option off; parse the text instead.
    show_words_flag = os.environ.get('VOSK_SHOW_WORDS', 'true').strip().lower()
    args.show_words = show_words_flag not in ('', '0', 'false', 'no', 'off')

    if len(sys.argv) > 1:
        args.model_path = sys.argv[1]

    # Gpu part, uncomment if vosk-api has gpu support
    #
    # from vosk import GpuInit, GpuInstantiate
    # GpuInit()
    # def thread_init():
    #     GpuInstantiate()
    # pool = concurrent.futures.ThreadPoolExecutor(initializer=thread_init)

    model = Model(args.model_path)
    spk_model = SpkModel(args.spk_model_path) if args.spk_model_path else None

    pool = concurrent.futures.ThreadPoolExecutor((os.cpu_count() or 1))

    async with websockets.serve(recognize, args.interface, args.port):
        # Wait on a future that is never resolved to keep the server running.
        await asyncio.Future()
# Start the server only when this file is executed as a script.
if __name__ == '__main__':
    server = start()
    asyncio.run(server)