Skip to content

Commit

Permalink
Merge pull request #1710 from Honei/deepspeech_server
Browse files Browse the repository at this point in the history
[asr][websocket]fix the ws send bug, cache buffer,  text=doc
  • Loading branch information
zh794390558 committed Apr 16, 2022
2 parents cf1a395 + 3ce4301 commit 0cde9f8
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 38 deletions.
2 changes: 1 addition & 1 deletion demos/speech_recognition/run.sh
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ paddlespeech asr --input ./zh.wav


# asr + punc
paddlespeech asr --input ./zh.wav | paddlespeech text --task punc
paddlespeech asr --input ./zh.wav | paddlespeech text --task punc
2 changes: 1 addition & 1 deletion paddlespeech/s2t/exps/u2/bin/test_wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def run(self):

ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)

decode_config = self.config.decode
result_transcripts = self.model.decode(
xs,
Expand Down Expand Up @@ -129,6 +128,7 @@ def main(config, args):
args = parser.parse_args()

config = CfgNode(new_allowed=True)

if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/server/conf/ws_application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8091
port: 8090

# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
Expand Down
47 changes: 34 additions & 13 deletions paddlespeech/server/tests/asr/online/websocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
# -*- coding: UTF-8 -*-
import argparse
import asyncio
import codecs
import json
import logging
import os

import numpy as np
import soundfile
Expand All @@ -32,34 +34,30 @@ def __init__(self, url="127.0.0.1", port=8090):
def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples)
chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
# chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz

if (x_len - chunk_size) % chunk_stride != 0:
padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride
if x_len % chunk_size != 0:
padding_len_x = chunk_size - x_len % chunk_size
else:
padding_len_x = 0

padding = np.zeros((padding_len_x), dtype=samples.dtype)
padded_x = np.concatenate([samples, padding], axis=0)

num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk)

for i in range(0, num_chunk):
start = i * chunk_stride
start = i * chunk_size
end = start + chunk_size
x_chunk = padded_x[start:end]
yield x_chunk

async def run(self, wavfile_path: str):
logging.info("send a message to the server")
# 读取音频
# self.read_wave()
# 发送 websocket 的 handshake 协议头
async with websockets.connect(self.url) as ws:
# server 端已经接收到 handshake 协议头
# 发送开始指令
audio_info = json.dumps(
{
"name": "test.wav",
Expand All @@ -77,8 +75,10 @@ async def run(self, wavfile_path: str):
for chunk_data in self.read_wave(wavfile_path):
await ws.send(chunk_data.tobytes())
msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))

result = msg
# finished
audio_info = json.dumps(
{
Expand All @@ -91,16 +91,35 @@ async def run(self, wavfile_path: str):
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))

return result


def main(args):
logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091)
handler = ASRAudioHandler("127.0.0.1", 8090)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(args.wavfile))
logging.info("asr websocket client finished")

# support to process single audio file
if args.wavfile and os.path.exists(args.wavfile):
logging.info(f"start to process the wavscp: {args.wavfile}")
result = loop.run_until_complete(handler.run(args.wavfile))
result = result["asr_results"]
logging.info(f"asr websocket client finished : {result}")

# support to process batch audios from wav.scp
if args.wavscp and os.path.exists(args.wavscp):
logging.info(f"start to process the wavscp: {args.wavscp}")
with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\
codecs.open("result.txt", 'w', encoding='utf-8') as w:
for line in f:
utt_name, utt_path = line.strip().split()
result = loop.run_until_complete(handler.run(utt_path))
result = result["asr_results"]
w.write(f"{utt_name} {result}\n")


if __name__ == "__main__":
Expand All @@ -110,6 +129,8 @@ def main(args):
action="store",
help="wav file path ",
default="./16_audio.wav")
parser.add_argument(
"--wavscp", type=str, default=None, help="The batch audios dict text")
args = parser.parse_args()

main(args)
49 changes: 34 additions & 15 deletions paddlespeech/server/utils/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,38 @@ def __init__(self, bytes, timestamp, duration):

class ChunkBuffer(object):
def __init__(self,
frame_duration_ms=80,
shift_ms=40,
window_n=7,
shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=16000,
sample_width=2):
self.sample_rate = sample_rate
self.frame_duration_ms = frame_duration_ms
"""audio sample data point buffer
Args:
window_n (int, optional): decode window frame length. Defaults to 7 frame.
shift_n (int, optional): decode shift frame length. Defaults to 4 frame.
window_ms (int, optional): frame length, ms. Defaults to 20 ms.
shift_ms (int, optional): shift length, ms. Defaults to 10 ms.
sample_rate (int, optional): audio sample rate. Defaults to 16000.
sample_width (int, optional): sample point bytes. Defaults to 2 bytes.
"""
self.window_n = window_n
self.shift_n = shift_n
self.window_ms = window_ms
self.shift_ms = shift_ms
self.remained_audio = b''
self.sample_rate = sample_rate
self.sample_width = sample_width # int16 = 2; float32 = 4
self.remained_audio = b''

self.window_sec = float((self.window_n - 1) * self.shift_ms +
self.window_ms) / 1000.0
self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)

self.window_bytes = int(self.window_sec * self.sample_rate *
self.sample_width)
self.shift_bytes = int(self.shift_sec * self.sample_rate *
self.sample_width)

def frame_generator(self, audio):
"""Generates audio frames from PCM audio data.
Expand All @@ -43,17 +66,13 @@ def frame_generator(self, audio):
audio = self.remained_audio + audio
self.remained_audio = b''

n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) *
self.sample_width)
shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) *
self.sample_width)
offset = 0
timestamp = 0.0
duration = (float(n) / self.sample_rate) / self.sample_width
shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width
while offset + n <= len(audio):
yield Frame(audio[offset:offset + n], timestamp, duration)
timestamp += shift_duration
offset += shift_n

while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp,
self.window_sec)
timestamp += self.shift_sec
offset += self.shift_bytes

self.remained_audio += audio[offset:]
9 changes: 4 additions & 5 deletions paddlespeech/server/ws/asr_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ async def websocket_endpoint(websocket: WebSocket):
# init buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer(
window_n=7,
shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=chunk_buffer_conf['sample_rate'],
sample_width=chunk_buffer_conf['sample_width'])
# init vad
Expand Down Expand Up @@ -75,11 +79,6 @@ async def websocket_endpoint(websocket: WebSocket):
elif "bytes" in message:
message = message["bytes"]

# vad for input bytes audio
vad.add_audio(message)
message = b''.join(f for f in vad.vad_collector()
if f is not None)

engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_results = ""
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/vector/cluster/diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
"""
import argparse
import warnings
from distutils.util import strtobool

import numpy as np
import scipy
import sklearn
from distutils.util import strtobool
from scipy import sparse
from scipy.sparse.csgraph import connected_components
from scipy.sparse.csgraph import laplacian as csgraph_laplacian
Expand Down
2 changes: 1 addition & 1 deletion utils/DER.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
import os
import re
import subprocess
from distutils.util import strtobool

import numpy as np
from distutils.util import strtobool

FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")
Expand Down

0 comments on commit 0cde9f8

Please sign in to comment.