Skip to content

Commit

Permalink
fix asr import
Browse files Browse the repository at this point in the history
  • Loading branch information
abdeladim-s committed Jun 2, 2023
1 parent 4305278 commit 25fb306
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions easymms/models/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
__copyright__ = "Copyright 2023,"


import os
import atexit
import json
import re
Expand All @@ -22,19 +23,11 @@
from pydub import AudioSegment
from pydub.utils import mediainfo

# fix importing from fairseq.examples
import site
sys.path.append(str(Path(site.getsitepackages()[0]) / 'fairseq'))
try:
from fairseq.examples.speech_recognition.new.infer import hydra_main
except ImportError:
from examples.speech_recognition.new.infer import hydra_main


from easymms import utils
from easymms import utils as easymms_utils
from easymms._logger import set_log_level
from easymms.models.alignment import AlignmentModel
from easymms.constants import CFG, HYPO_WORDS_FILE, MMS_LANGS_FILE
from easymms import constants

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -65,7 +58,7 @@ def __init__(self,
:param log_level: log level
"""
set_log_level(log_level)
self.cfg = CFG.copy()
self.cfg = constants.CFG.copy()
self.model = Path(model)
self.cfg['common_eval']['path'] = str(self.model.resolve())

Expand All @@ -75,6 +68,10 @@ def __init__(self,

self.wer = None

# clone Fairseq
easymms_utils.clone(constants.FAIRSEQ_URL, constants.FAIRSEQ_DIR)
sys.path.append(str(constants.FAIRSEQ_DIR.resolve()))

def _cleanup(self) -> None:
"""
cleans up the temp_dir
Expand Down Expand Up @@ -145,6 +142,12 @@ def transcribe(self,
:return: List of transcription text in the same order as input files
"""
# import
cwd = os.getcwd()
os.chdir(constants.FAIRSEQ_DIR)
from examples.speech_recognition.new.infer import hydra_main
os.chdir(cwd)

processed_files = self._prepare_media_files(media_files)
self._setup_tmp_dir(processed_files)
# edit cfg
Expand All @@ -166,7 +169,7 @@ def transcribe(self,

self.wer = hydra_main(cfg)
# get results: will just read from hypo.word as I don't want to change fairseq repo to get the hypo array
hypo_file = self.tmp_dir_path / HYPO_WORDS_FILE
hypo_file = self.tmp_dir_path / constants.HYPO_WORDS_FILE
res = []
with open(hypo_file) as hw:
hypos = hw.readlines()
Expand All @@ -176,7 +179,7 @@ def transcribe(self,
align_model = AlignmentModel()
for i in range(len(transcripts)):
media_file = processed_files[i][0]
transcript = utils.get_transcript_segments(transcripts[i], timestamps_type, max_segment_len=max_segment_len)
transcript = easymms_utils.get_transcript_segments(transcripts[i], timestamps_type, max_segment_len=max_segment_len)
segments = align_model.align(media_file=media_file,
transcript=transcript,
lang=lang,
Expand All @@ -193,7 +196,7 @@ def get_supported_langs() -> List[str]:
Source <https://dl.fbaipublicfiles.com/mms/misc/language_coverage_mms.html>
:return: list of supported languages
"""
with open(MMS_LANGS_FILE) as f:
with open(constants.MMS_LANGS_FILE) as f:
data = json.load(f)
return [key for key in data if data[key]['ASR']]

Expand Down

0 comments on commit 25fb306

Please sign in to comment.