Skip to content

Commit

Permalink
Merge pull request #277 from CoEDL/resampling-utility
Browse files Browse the repository at this point in the history
Create resampling utility
  • Loading branch information
harrykeightley committed Jun 7, 2022
2 parents 4a1446a + 954abe0 commit afa94ce
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 71 deletions.
67 changes: 67 additions & 0 deletions elpis/engines/common/utilities/resampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from pathlib import Path
from typing import Any, Dict, Tuple

import numpy as np
import soundfile as sf
import librosa

from werkzeug.datastructures import FileStorage


ORIGINAL_SOUND_FILE_DIRECTORY = Path('/tmp/origial_sound_files/')


def load_audio(file: Path, target_sample_rate: int = None) -> Tuple[np.ndarray, int]:
"""Loads a file and returns the data wtihin.
Parameters:
file (Path): The path of the file to load.
target_sample_rate (int): An optional sample rate with which to load the
audio data.
Returns:
(Tuple<np.ndarray, int>): A tuple containing the numpy array of
the audio data, and the native sample rate of the file.
"""
return librosa.load(file, sr=target_sample_rate)

def resample_audio(file: Path, destination: Path, target_sample_rate: int) -> None:
"""Writes a resampled audio file to the supplied destination, with a supplied
sample rate.
Parameters:
file (Path): The path of the file to resample
destination (Path): The destination at which to create the resampled file
target_sample_rate (int): The target sample rate for the resampled audio.
"""
data, sample_rate = librosa.load(file, sr=None)

# Create temporary directory if it hasn't already been created
ORIGINAL_SOUND_FILE_DIRECTORY.mkdir(parents=True, exist_ok=True)

# Copy audio to the temporary path
original = ORIGINAL_SOUND_FILE_DIRECTORY / file.name
sf.write(original, data, sample_rate)

# Resample and overwrite
sf.write(destination, data, target_sample_rate)


def resample_from_file_storage(file: FileStorage, destination: Path, target_sample_rate: int) -> Dict:
""" Performs audio resampling from a flask request FileStorage file, and
returns some information about the original file.
"""
# Create temporary directory if it hasn't already been created
ORIGINAL_SOUND_FILE_DIRECTORY.mkdir(parents=True, exist_ok=True)

original = ORIGINAL_SOUND_FILE_DIRECTORY / file.filename
with original.open(mode='wb') as fout:
fout.write(file.read())

info = {
'duration': librosa.get_duration(filename=original)
}

resample_audio(original, destination, target_sample_rate)
return info
12 changes: 9 additions & 3 deletions elpis/engines/hft/objects/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,23 +445,29 @@ def prepare_speech(self):
speech = {}
audio_paths = set()
rejected_count = 0

for utt in self.hft_dataset['train']:
audio_paths.add((utt['path'], utt['text'], utt['start_ms'], utt['stop_ms']))

for utt in self.hft_dataset['dev']:
audio_paths.add((utt['path'], utt['text'], utt['start_ms'], utt['stop_ms']))

for utt in self.hft_dataset['test']:
audio_paths.add((utt['path'], utt['text'], utt['start_ms'], utt['stop_ms']))

for path, text, start_ms, stop_ms in audio_paths:
audio_metadata = torchaudio.info(path)

start_frame = int(start_ms * (audio_metadata.sample_rate/1000))
end_frame = int(stop_ms * (audio_metadata.sample_rate/1000))
num_frames = end_frame - start_frame

dur_ms = stop_ms - start_ms
speech_array, sample_rate = torchaudio.load(filepath=path, frame_offset=start_frame, num_frames=num_frames)
# Check that frames exceeds number of characters, wav file is not all zeros, and duration between min, max
if int(audio_metadata.num_frames) >= len(text) \
and speech_array.count_nonzero() \
and float(self.settings['min_duration_s']) < dur_ms/1000 < float(self.settings['max_duration_s']):
if (int(audio_metadata.num_frames) >= len(text)
and speech_array.count_nonzero()
and float(self.settings['min_duration_s']) < dur_ms/1000 < float(self.settings['max_duration_s'])):
# Resample if required
if sample_rate != HFTModel.SAMPLING_RATE:
logger.info(f'Resample from {sample_rate} to {HFTModel.SAMPLING_RATE} | '
Expand Down
45 changes: 10 additions & 35 deletions elpis/engines/hft/objects/transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from typing import List, Tuple
from pprint import pprint
from itertools import groupby
import elpis.engines.common.utilities.resampling as resampler

import librosa
import pympi
import soundfile as sf
import torch
from transformers import (
Wav2Vec2ForCTC,
Expand All @@ -18,6 +17,8 @@
from elpis.engines.common.objects.transcription import Transcription as BaseTranscription
from elpis.engines.hft.objects.model import HFTModel

from werkzeug.datastructures import FileStorage


LOAD_AUDIO = 'load_audio'
PROCESS_INPUT = 'process_input'
Expand Down Expand Up @@ -62,8 +63,9 @@ def transcribe(self, on_complete: callable = None) -> None:

# Load audio
self._set_stage(LOAD_AUDIO)
logger.info('==== Load audio ====')
audio_input, sample_rate = self._load_audio(self.audio_file_path)
logger.info('=== Load audio')
audio_input, _ = resampler.load_audio(self.audio_file_path,
target_sample_rate=HFTTranscription.SAMPLING_RATE)
self._set_stage(LOAD_AUDIO, complete=True)

# Pad input values and return pt tensor
Expand Down Expand Up @@ -141,8 +143,8 @@ def _generate_utterances(self,
predicted_ids = predicted_ids[0].tolist()

# Determine original sample rate
original_file = Path(f'/tmp/{self.hash}/original.wav')
_, sample_rate = self._load_audio(original_file)
original_file = resampler.ORIGINAL_SOUND_FILE_DIRECTORY / self.audio_file_path.name
_, sample_rate = resampler.load_audio(original_file)

# Add times to ids
duration_sec = input_values.shape[1] / sample_rate
Expand Down Expand Up @@ -205,39 +207,12 @@ def _save_utterances(self, utterances) -> None:

pympi.Elan.to_eaf(self.elan_path, result)

def _load_audio(self, file: Path) -> Tuple:
audio, sample_rate = librosa.load(file, sr=HFTTranscription.SAMPLING_RATE)
return audio, sample_rate

def prepare_audio(self, audio: Path, on_complete: callable = None):
logger.info(f'==== Prepare audio {audio} {self.audio_file_path} ====')
self._resample_audio_file(audio, self.audio_file_path)
logger.info(f'=== Prepare audio {audio} {self.audio_file_path}')
resampler.resample_from_file_storage(audio, self.audio_file_path, HFTModel.SAMPLING_RATE)
if on_complete is not None:
on_complete()

def _resample_audio_file(self, audio: Path, dest: Path):
"""
Resamples the audio file to be the same sampling rate as the model
was trained with.
Target sampling rate is taken from the HFTModel class.
Parameters:
audio (Path): A path to a soundfile
dest (Path): The destination path at which to write the resampled
audio.
"""
data, sample_rate = self._load_audio(audio)

# Copy to temporary path
temporary_path = Path(f'/tmp/{self.hash}')
temporary_path.mkdir(parents=True, exist_ok=True)
sound_copy = temporary_path.joinpath('original.wav')
sf.write(sound_copy, data, sample_rate)

# Resample and overwrite
sf.write(dest, data, HFTModel.SAMPLING_RATE)

def _set_finished_transcription(self, has_finished: bool) -> None:
self.status = FINISHED if has_finished else UNFINISHED

Expand Down
77 changes: 44 additions & 33 deletions elpis/engines/kaldi/objects/transcription.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from pathlib import Path
from elpis.engines.common.input.resample import resample
from elpis.engines.common.objects.command import run
from elpis.engines.common.objects.transcription import Transcription as BaseTranscription
import subprocess
import elpis.engines.common.utilities.resampling as resampler
from typing import Callable, Dict
import os
import shutil
Expand All @@ -15,8 +14,13 @@
from csv import reader
import codecs

from werkzeug.datastructures import FileStorage


class KaldiTranscription(BaseTranscription):

SAMPLE_RATE = 16_000

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.audio_filename = None
Expand All @@ -31,43 +35,37 @@ def load(cls, base_path: Path):
def _build_spk2utt_file(self, spk_id: str, utt_id: str):
spk2utt_path = self.path.joinpath('spk2utt')
with spk2utt_path.open(mode='w') as fout:
fout.write(f'{spk_id} {utt_id}\n')
fout.write(f'{spk_id} {utt_id}\n')

def _build_utt2spk_file(self, utt_id: str, spk_id: str):
utt2spk_path = self.path.joinpath('utt2spk')
with utt2spk_path.open(mode='w') as fout:
fout.write(f'{utt_id} {spk_id}\n')
fout.write(f'{utt_id} {spk_id}\n')

def _build_segments_file(self, utt_id: str, rec_id: str, start_ms: float, stop_ms: float):
segments_path = self.path.joinpath('segments')
with segments_path.open(mode='w') as fout:
fout.write(f'{utt_id} {rec_id} {start_ms} {stop_ms}\n')
fout.write(f'{utt_id} {rec_id} {start_ms} {stop_ms}\n')

def _build_wav_scp_file(self, rec_id: str, rel_audio_file_path: Path):
wav_scp_path = self.path.joinpath('wav.scp')
with wav_scp_path.open(mode='w') as fout:
fout.write(f'{rec_id} {rel_audio_file_path}\n')
fout.write(f'{rec_id} {rel_audio_file_path}\n')

# Write audio filename to a file so the shell scripts can read it. Maybe safer than setting ENV?
def _build_audio_meta(self, audio_filename: Path):
audio_meta_path = self.path.joinpath('audio_meta.txt')
with audio_meta_path.open(mode='w') as fout:
fout.write(f'{audio_filename}\n')
fout.write(f'{audio_filename}\n')

def _process_audio_file(self, audio):
# TODO: why save to tmp and not just resample to self.path location?
# copy audio to the tmp folder for resampling
def _process_audio_file(self, audio: FileStorage):
print("========= process audio for transcription", self.path)
tmp_path = Path(f'/tmp/{self.hash}')
tmp_path.mkdir(parents=True, exist_ok=True)
tmp_file_path = tmp_path.joinpath(audio.filename)
with tmp_file_path.open(mode='wb') as fout:
fout.write(audio.read())
# resample the audio file
resample(tmp_file_path, self.path.joinpath(audio.filename))
self.audio_filename = audio.filename
self.audio_file_path = self.path.joinpath(self.audio_filename)
self.audio_duration = librosa.get_duration(filename=tmp_file_path)

info = resampler.resample_from_file_storage(
audio, self.audio_file_path, KaldiTranscription.SAMPLE_RATE)
self.audio_duration = info['duration']

# Prepare the files we need for inference, based on the audio we receive
def _generate_inference_files(self):
Expand All @@ -89,7 +87,8 @@ def _generate_inference_files(self):
# Rec id is arbitrary, use anything you like here
rec_id = 'decode'
# Path to the audio, relative to kaldi working dir
rel_audio_file_path = os.path.join('data', 'infer', self.audio_filename)
rel_audio_file_path = os.path.join(
'data', 'infer', self.audio_filename)
# Generate the files
self._build_spk2utt_file(spk_id, utt_id)
self._build_utt2spk_file(utt_id, spk_id)
Expand All @@ -111,7 +110,7 @@ def transcribe(self, on_complete: Callable = None):

print("========= reset exp dir")
# wipe previous exp dir to avoid file_exists errors
exp_path = self.model.path.joinpath('kaldi','exp','tri1_online')
exp_path = self.model.path.joinpath('kaldi', 'exp', 'tri1_online')
if exp_path.exists():
shutil.rmtree(f'{exp_path}')
exp_path.mkdir(parents=True, exist_ok=True)
Expand All @@ -136,12 +135,14 @@ def transcribe(self, on_complete: Callable = None):
"gmm-decode-conf.sh": "transcribing"
}
# Move the relevant templates into the kaldi/data/infer dir.
template_dir_abs_path = Path('/elpis/elpis/engines/kaldi/inference/').joinpath(template_dir_path)
template_dir_abs_path = Path(
'/elpis/elpis/engines/kaldi/inference/').joinpath(template_dir_path)
# Build status for logging
super().build_stage_status(stage_names)

# Provide a helper script for both methods
shutil.copy(Path('/elpis/elpis/engines/kaldi/templates').joinpath('make_split.sh'), f"{local_kaldi_path}")
shutil.copy(Path('/elpis/elpis/engines/kaldi/templates')
.joinpath('make_split.sh'), f"{local_kaldi_path}")
os.chmod(local_kaldi_path.joinpath('make_split.sh'), 0o774)

# Prepare (dump, recreate) main transcription log file
Expand All @@ -154,34 +155,41 @@ def transcribe(self, on_complete: Callable = None):
transcription_log_dir = self.path.joinpath('transcription-logs')
if os.path.exists(transcription_log_dir):
shutil.rmtree(transcription_log_dir)
os.mkdir(transcription_log_dir )
os.mkdir(transcription_log_dir)

stage_count = 0

# Build stage scripts
dir_util.copy_tree(f'{self.path}', f"{kaldi_infer_path}")
file_util.copy_file(f'{self.audio_file_path}', f"{self.model.path.joinpath('kaldi', self.audio_filename)}")
file_util.copy_file(f'{self.audio_file_path}',
f"{self.model.path.joinpath('kaldi', self.audio_filename)}")
# Copy parts of transcription process and chmod
os.makedirs(f"{kaldi_infer_path.joinpath(template_dir_path)}", exist_ok=True)
dir_util.copy_tree(f'{template_dir_abs_path}', f"{kaldi_infer_path.joinpath(template_dir_path)}")
os.makedirs(
f"{kaldi_infer_path.joinpath(template_dir_path)}", exist_ok=True)
dir_util.copy_tree(f'{template_dir_abs_path}',
f"{kaldi_infer_path.joinpath(template_dir_path)}")
stages = os.listdir(kaldi_infer_path.joinpath(template_dir_path))
for file in stages:
os.chmod(kaldi_infer_path.joinpath(template_dir_path).joinpath(file), 0o774)
os.chmod(kaldi_infer_path.joinpath(
template_dir_path).joinpath(file), 0o774)
for stage in sorted(stages):
print(f"Stage {stage} starting")
self.stage_status = (stage, 'in-progress', '')

# Create log file
stage_log_path = self.path.joinpath(os.path.join(transcription_log_dir, f'stage_{stage_count}.log'))
stage_log_path = self.path.joinpath(os.path.join(
transcription_log_dir, f'stage_{stage_count}.log'))
with open(stage_log_path, 'w+') as file:
print('starting log', file=file)
pass

# Run the command, log output. Also redirect Kaldi sterr output to log. These are often not errors :-(
# These scripts must run from the kaldi dir (so set cwd)
try:
script_path = kaldi_infer_path.joinpath(template_dir_path, stage)
stage_process = run(f"sh {script_path} >> {stage_log_path}", cwd=f"{local_kaldi_path}")
script_path = kaldi_infer_path.joinpath(
template_dir_path, stage)
stage_process = run(
f"sh {script_path} >> {stage_log_path}", cwd=f"{local_kaldi_path}")
with open(stage_log_path, 'a+') as file:
print('stdout', stage_process.stdout, file=file)
print('stderr', stage_process.stderr, file=file)
Expand All @@ -206,9 +214,12 @@ def transcribe(self, on_complete: Callable = None):
outfile.write(infile.read())
outfile.write("\n")

file_util.copy_file(f"{kaldi_infer_path.joinpath('one-best-hypothesis.txt')}", f'{self.path}/one-best-hypothesis.txt')
file_util.copy_file(f"{kaldi_infer_path.joinpath('utterance-0.eaf')}", f'{self.path}/{self.hash}.eaf')
file_util.copy_file(f"{kaldi_infer_path.joinpath('ctm_with_conf.ctm')}", f'{self.path}/ctm_with_conf.ctm')
file_util.copy_file(
f"{kaldi_infer_path.joinpath('one-best-hypothesis.txt')}", f'{self.path}/one-best-hypothesis.txt')
file_util.copy_file(
f"{kaldi_infer_path.joinpath('utterance-0.eaf')}", f'{self.path}/{self.hash}.eaf')
file_util.copy_file(
f"{kaldi_infer_path.joinpath('ctm_with_conf.ctm')}", f'{self.path}/ctm_with_conf.ctm')

self.status = "transcribed"

Expand Down

0 comments on commit afa94ce

Please sign in to comment.