-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
__init__.py
61 lines (54 loc) · 1.99 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Copyright (c) 2019 NVIDIA Corporation
import os
from flask import Flask
from ruamel.yaml import YAML
import nemo
import nemo_asr
# Flask application object for this service. It is created before the
# `from app import routes` statement at the bottom of this file runs --
# presumably so that `routes` can import it; confirm against routes.py.
app = Flask(__name__)
# Working directory of the service: received .wav files and the .json
# manifests constructed for them are stored in this folder. It must already
# exist before the service is called.
WORK_DIR: str = "<PATH_TO_YOUR_WORKDIR>"
# Model definition file, read once when this module is imported.
MODEL_YAML: str = "<PATH_TO_YOUR_YAML>"
# Checkpoints restored into the encoder and the decoder, respectively.
CHECKPOINT_ENCODER: str = "<PATH_TO_ENCODER_CHECKPOINT>"
CHECKPOINT_DECODER: str = "<PATH_TO_DECODER_CHECKPOINT>"
# Switch to True to enable the beam search decoder.
ENABLE_NGRAM: bool = False
# KenLM binary; only required when ENABLE_NGRAM is True. Otherwise set it to
# an empty string.
LM_PATH: str = "<PATH_TO_KENLM_BINARY>"
# Read the model YAML once, at import time.
yaml = YAML(typ="safe")
# Decode explicitly as UTF-8 rather than relying on the platform's
# locale-dependent default encoding, which can mis-decode non-ASCII content
# (e.g. label characters) on systems whose locale is not UTF-8.
with open(MODEL_YAML, encoding="utf-8") as f:
    jasper_model_definition = yaml.load(f)
# `labels` is used below as `num_classes=len(labels)` for the CTC decoder and
# as `vocab` for the beam search decoder.
labels = jasper_model_definition['labels']
# Build the Neural Modules needed for inference.
# No data layer is created here.
neural_factory = nemo.core.NeuralModuleFactory(
    placement=nemo.core.DeviceType.GPU,
    backend=nemo.core.Backend.PyTorch,
)

# NOTE(review): the preprocessor receives only the factory, while the encoder
# below takes its `feat_in` from the YAML 'features' entry -- presumably the
# preprocessor defaults must agree with that value; confirm.
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
    factory=neural_factory,
)

# Pull the encoder settings out of the model definition, in the same order
# in which they are passed to the constructor.
_jasper_blocks = jasper_model_definition['JasperEncoder']['jasper']
_activation = jasper_model_definition['JasperEncoder']['activation']
_num_features = jasper_model_definition[
    'AudioToMelSpectrogramPreprocessor']['features']
jasper_encoder = nemo_asr.JasperEncoder(
    jasper=_jasper_blocks,
    activation=_activation,
    feat_in=_num_features,
)
jasper_encoder.restore_from(CHECKPOINT_ENCODER, local_rank=0)

# NOTE(review): feat_in is hard-coded to 1024 -- presumably it has to match
# the encoder's output size for the model described by MODEL_YAML; confirm.
jasper_decoder = nemo_asr.JasperDecoderForCTC(
    feat_in=1024,
    num_classes=len(labels),
)
jasper_decoder.restore_from(CHECKPOINT_DECODER, local_rank=0)

greedy_decoder = nemo_asr.GreedyCTCDecoder()
# Optionally build the beam search decoder with a language model. It is only
# created when ENABLE_NGRAM is set AND the file at LM_PATH exists.
if ENABLE_NGRAM and os.path.isfile(LM_PATH):
    beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
        vocab=labels,
        beam_width=64,
        alpha=2.0,
        beta=1.0,
        lm_path=LM_PATH,
        # os.cpu_count() returns None when the CPU count cannot be
        # determined; max(None, 1) would raise TypeError, so fall back to 1.
        num_cpus=os.cpu_count() or 1)
else:
    # NOTE: `beam_search_with_lm` stays unbound in this branch (behavior
    # unchanged), so any code importing it must handle its absence.
    print("Beam search is not enabled")
# Imported at the very end, after `app` and the module objects above have
# been created -- presumably because `routes` imports names from this
# package; confirm against routes.py. The import is otherwise unused here,
# hence the `noqa`.
from app import routes  # noqa

# Launch the Flask app via `app.run()` only when this file is executed as a
# script, not when the package is imported.
if __name__ == "__main__":
    app.run()