Lhotse AudioToAudio dataset (supports ref recording and embedding) #8477

Merged — 22 commits from branch feature/lhotse-audio-to-audio-dataset merged on Apr 16, 2024
Changes from 12 commits
Commits (22)
40acfea
Draft for Lhotse AudioToAudio dataset (supports ref recording and emb…
pzelasko Feb 21, 2024
d903fb4
Integrate with speech enhancement models
pzelasko Feb 22, 2024
3894edf
Fix absolute path + write cuts in the output manifest
anteju Mar 5, 2024
99500e3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
66d8ec5
Support channel selectors for input, reference, and target recordings
pzelasko Mar 7, 2024
aac3db3
Support on the fly truncation and/or cutting into windows
pzelasko Mar 7, 2024
2fee158
Merge branch 'main' into feature/lhotse-audio-to-audio-dataset
pzelasko Mar 7, 2024
496fc66
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
f2d0bbe
Bump min required lhotse version
pzelasko Mar 7, 2024
864f5a4
Add copyright headers
pzelasko Mar 8, 2024
123dac5
Added unit tests checking lhotse dataloader is matching the existing …
anteju Mar 12, 2024
9b8b458
Fix batch unpacking, test_ds, use nemo logging
anteju Apr 8, 2024
496fb4e
fixed some code scanning issues
anteju Apr 8, 2024
b59dfb5
Fixed a couple CI issues
anteju Apr 9, 2024
f5bf68c
Support NeMo-style resolution of relative paths in native lhotse cuts
pzelasko Apr 11, 2024
470ba9c
Added option to leave original paths or force absolute paths in the c…
anteju Apr 11, 2024
8860979
Fix support for relative path resolution in lhotse arrays
pzelasko Apr 12, 2024
bf26822
Merge branch 'main' into feature/lhotse-audio-to-audio-dataset
pablo-garay Apr 13, 2024
ad4ad82
Merge branch 'main' into feature/lhotse-audio-to-audio-dataset
pzelasko Apr 15, 2024
319a441
Fix unit tests
pzelasko Apr 15, 2024
538e9e4
Merge branch 'main' into feature/lhotse-audio-to-audio-dataset
pzelasko Apr 15, 2024
a1628e0
Merge branch 'main' into feature/lhotse-audio-to-audio-dataset
pzelasko Apr 16, 2024
7 changes: 7 additions & 0 deletions examples/audio_tasks/audio_to_audio_eval.py
@@ -73,7 +73,9 @@
from tqdm import tqdm

from nemo.collections.asr.data import audio_to_audio_dataset
from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset
from nemo.collections.asr.metrics.audio import AudioMetricWrapper
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.common.parts.preprocessing import manifest
from nemo.core.config import hydra_runner
from nemo.utils import logging
@@ -103,6 +105,11 @@ class AudioEvaluationConfig(process_audio.ProcessConfig):
def get_evaluation_dataloader(config):
"""Prepare a dataloader for evaluation.
"""
if config.get("use_lhotse", False):
return get_lhotse_dataloader_from_config(
config, global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset()
)

dataset = audio_to_audio_dataset.get_audio_to_target_dataset(config=config)

return torch.utils.data.DataLoader(
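For orientation, a minimal usage sketch of the new Lhotse branch in get_evaluation_dataloader(): only use_lhotse and sample_rate appear in this diff, so the remaining config fields (cuts_path, batch_size) and the file name are assumptions for illustration, not keys taken from this PR.

# Hedged sketch: build an evaluation dataloader backed by a Lhotse cuts manifest.
# `cuts_path` and `batch_size` are assumed LhotseDataLoadingConfig fields; only `use_lhotse`
# and `sample_rate` are visible in the diff above.
from omegaconf import OmegaConf

eval_cfg = OmegaConf.create(
    {
        "use_lhotse": True,  # routes get_evaluation_dataloader() through get_lhotse_dataloader_from_config()
        "cuts_path": "eval_cuts.jsonl.gz",  # assumed key pointing at a Lhotse CutSet manifest
        "sample_rate": 16000,
        "batch_size": 4,  # assumed batching field
    }
)
dataloader = get_evaluation_dataloader(eval_cfg)
for batch in dataloader:
    # Lhotse batches are dicts produced by LhotseAudioToTargetDataset.__getitem__ (see below).
    print(batch["input_signal"].shape, batch["input_length"])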
2 changes: 1 addition & 1 deletion examples/audio_tasks/speech_enhancement.py
@@ -51,7 +51,7 @@ def main(cfg):
    trainer.fit(model)

    # Run on test data, if available
    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
    if hasattr(cfg.model, 'test_ds'):
        if trainer.is_global_zero:
            # Destroy the current process group and let the trainer initialize it again with a single device.
            if torch.distributed.is_initialized():
163 changes: 163 additions & 0 deletions nemo/collections/asr/data/audio_to_audio_lhotse.py
@@ -0,0 +1,163 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import numpy as np
import torch
from lhotse import AudioSource, CutSet, Recording
from lhotse.array import Array
from lhotse.audio import info
from lhotse.cut import MixedCut
from lhotse.dataset.collation import collate_audio, collate_custom_field
from lhotse.serialization import load_jsonl

from nemo.collections.common.parts.preprocessing.manifest import get_full_path

INPUT_CHANNEL_SELECTOR = "input_channel_selector"
TARGET_CHANNEL_SELECTOR = "target_channel_selector"
REFERENCE_CHANNEL_SELECTOR = "reference_channel_selector"
LHOTSE_TARGET_CHANNEL_SELECTOR = "target_recording_channel_selector"
LHOTSE_REFERENCE_CHANNEL_SELECTOR = "reference_recording_channel_selector"


class LhotseAudioToTargetDataset(torch.utils.data.Dataset):
"""
A dataset for audio-to-audio tasks where the goal is to use
an input signal to recover the corresponding target signal.

.. note:: This is a Lhotse variant of :class:`nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset`.
"""

TARGET_KEY = "target_recording"
REFERENCE_KEY = "reference_recording"
EMBEDDING_KEY = "embedding_vector"

def __getitem__(self, cuts: CutSet) -> dict[str, torch.Tensor]:
src_audio, src_audio_lens = collate_audio(cuts)
ans = {
"input_signal": src_audio,
"input_length": src_audio_lens,
}
if _key_available(cuts, self.TARGET_KEY):
tgt_audio, tgt_audio_lens = collate_audio(cuts, recording_field=self.TARGET_KEY)
ans.update(target_signal=tgt_audio, target_length=tgt_audio_lens)
if _key_available(cuts, self.REFERENCE_KEY):
ref_audio, ref_audio_lens = collate_audio(cuts, recording_field=self.REFERENCE_KEY)
ans.update(reference_signal=ref_audio, reference_length=ref_audio_lens)
if _key_available(cuts, self.EMBEDDING_KEY):
emb = collate_custom_field(cuts, field=self.EMBEDDING_KEY)
ans.update(embedding_signal=emb)
return ans


def _key_available(cuts: CutSet, key: str) -> bool:
    for cut in cuts:
        if isinstance(cut, MixedCut):
            cut = cut._first_non_padding_cut
        if cut.custom is not None and key in cut.custom:
            continue
        else:
            return False
    return True


def create_recording(path_or_paths: str | list[str]) -> Recording:
    if isinstance(path_or_paths, list):
        cur_channel_idx = 0
        sources = []
        infos = []
        for p in path_or_paths:
            i = info(p)
            infos.append(i)
            sources.append(
                AudioSource(type="file", channels=list(range(cur_channel_idx, cur_channel_idx + i.channels)), source=p)
            )
            cur_channel_idx += i.channels
        assert all(
            i.samplerate == infos[0].samplerate for i in infos[1:]
        ), f"Mismatched sampling rates for individual audio files in: {path_or_paths}"
        recording = Recording(
            id=p[0],
            sources=sources,
            sampling_rate=infos[0].samplerate,
            num_samples=infos[0].frames,
            duration=infos[0].duration,
            channel_ids=list(range(0, cur_channel_idx)),
        )
    else:
        recording = Recording.from_file(path_or_paths)
    return recording


def create_array(path: str) -> Array:
    assert path.endswith(".npy"), f"Currently only conversion of numpy files is supported (got: {path})"
    arr = np.load(path)
    parent, path = os.path.split(path)
    return Array(storage_type="numpy_files", storage_path=parent, storage_key=path, shape=list(arr.shape),)


def convert_manifest_nemo_to_lhotse(
    input_manifest: str,
    output_manifest: str,
    input_key: str = 'input_filepath',
    target_key: str = 'target_filepath',
    reference_key: str = 'reference_filepath',
    embedding_key: str = 'embedding_filepath',
):
    with CutSet.open_writer(output_manifest) as writer:
        for item in load_jsonl(input_manifest):

            # Create Lhotse recording and cut object, apply offset and duration slicing if present.
            recording = create_recording(get_full_path(audio_file=item.pop(input_key), manifest_file=input_manifest))
            cut = recording.to_cut().truncate(duration=item.pop("duration"), offset=item.pop("offset", 0.0))

            if (channels := item.pop(INPUT_CHANNEL_SELECTOR, None)) is not None:
                if cut.num_channels == 1:
                    assert (
                        len(channels) == 1 and channels[0] == 0
                    ), f"The input recording has only a single channel, but manifest specified {INPUT_CHANNEL_SELECTOR}={channels}"
                else:
                    cut = cut.with_channels(channels)

            if target_key in item:
                cut.target_recording = create_recording(
                    get_full_path(audio_file=item.pop(target_key), manifest_file=input_manifest)
                )
                if (channels := item.pop(TARGET_CHANNEL_SELECTOR, None)) is not None:
                    if cut.target_recording.num_channels == 1:
                        assert (
                            len(channels) == 1 and channels[0] == 0
                        ), f"The target recording has only a single channel, but manifest specified {TARGET_CHANNEL_SELECTOR}={channels}"
                    else:
                        cut = cut.with_custom(LHOTSE_TARGET_CHANNEL_SELECTOR, channels)

            if reference_key in item:
                cut.reference_recording = create_recording(
                    get_full_path(audio_file=item.pop(reference_key), manifest_file=input_manifest)
                )
                if (channels := item.pop(REFERENCE_CHANNEL_SELECTOR, None)) is not None:
                    if cut.reference_recording.num_channels == 1:
                        assert (
                            len(channels) == 1 and channels[0] == 0
                        ), f"The reference recording has only a single channel, but manifest specified {REFERENCE_CHANNEL_SELECTOR}={channels}"
                    else:
                        cut = cut.with_custom(LHOTSE_REFERENCE_CHANNEL_SELECTOR, channels)

            if embedding_key in item:
                cut.embedding_vector = create_array(
                    get_full_path(audio_file=item.pop(embedding_key), manifest_file=input_manifest)
                )

            if item:
                cut.custom.update(item)  # any field that's still left goes to custom fields

            writer.write(cut)
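Since convert_manifest_nemo_to_lhotse() turns an existing NeMo audio-to-audio manifest into a Lhotse cuts manifest, here is a short, hedged usage sketch; the manifest file names are hypothetical, while the keyword defaults come from the signature above.

# Hypothetical example: convert a NeMo-style audio-to-audio manifest (JSON lines with
# input_filepath, target_filepath, duration, offset, ... as handled above) into a Lhotse
# cuts manifest that the Lhotse dataloader can read.
from nemo.collections.asr.data.audio_to_audio_lhotse import convert_manifest_nemo_to_lhotse

convert_manifest_nemo_to_lhotse(
    input_manifest="train_manifest.json",   # hypothetical NeMo manifest path
    output_manifest="train_cuts.jsonl.gz",  # hypothetical output; format inferred from the extension
    input_key="input_filepath",             # defaults shown in the signature above
    target_key="target_filepath",
)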
27 changes: 25 additions & 2 deletions nemo/collections/asr/models/enhancement_models.py
@@ -24,9 +24,11 @@
from tqdm import tqdm

from nemo.collections.asr.data import audio_to_audio_dataset
from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset
from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config
from nemo.collections.asr.models.audio_to_audio_model import AudioToAudioModel
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType
from nemo.utils import logging
@@ -198,6 +200,11 @@

    def _setup_dataloader_from_config(self, config: Optional[Dict]):

        if config.get("use_lhotse", False):
            return get_lhotse_dataloader_from_config(
                config, global_rank=self.global_rank, world_size=self.world_size, dataset=LhotseAudioToTargetDataset()
            )

        is_concat = config.get('is_concat', False)
        if is_concat:
            raise NotImplementedError('Concat not implemented')
@@ -398,7 +405,15 @@

    # PTL-specific methods
    def training_step(self, batch, batch_idx):
        input_signal, input_length, target_signal, target_length = batch

        if isinstance(batch, dict):
            # lhotse batches are dictionaries
            input_signal = batch['input_signal']
            input_length = batch['input_length']
            target_signal = batch['target_signal']
            target_length = batch['target_length']
        else:
            input_signal, input_length, target_signal, target_length = batch

        # Expand channel dimension, if necessary
        # For consistency, the model uses multi-channel format, even if the channel dimension is 1
@@ -426,7 +441,15 @@
        return loss

    def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
        input_signal, input_length, target_signal, target_length = batch

        if isinstance(batch, dict):
            # lhotse batches are dictionaries
            input_signal = batch['input_signal']
            input_length = batch['input_length']
            target_signal = batch['target_signal']
            target_length = batch['target_length']
        else:
            input_signal, input_length, target_signal, target_length = batch

        # Expand channel dimension, if necessary
        # For consistency, the model uses multi-channel format, even if the channel dimension is 1
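The dict-vs-tuple unpacking above is duplicated in training_step() and evaluation_step(); a possible shared helper (a sketch only, not part of this PR) could centralize it:

# Sketch of a shared helper for the duplicated batch unpacking; not part of this PR.
def _unpack_audio_batch(batch):
    if isinstance(batch, dict):
        # Lhotse dataloaders yield dictionaries (see LhotseAudioToTargetDataset).
        return batch['input_signal'], batch['input_length'], batch['target_signal'], batch['target_length']
    # Native NeMo dataloaders yield plain tuples.
    input_signal, input_length, target_signal, target_length = batch
    return input_signal, input_length, target_signal, target_length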
32 changes: 28 additions & 4 deletions nemo/collections/common/data/lhotse/dataloader.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings
from dataclasses import dataclass
from functools import partial
@@ -33,6 +32,7 @@
from omegaconf import DictConfig, OmegaConf

from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config
from nemo.utils import logging


@dataclass
@@ -92,6 +92,15 @@ class LhotseDataLoadingConfig:
    concatenate_duration_factor: float = 1.0
    concatenate_merge_supervisions: bool = True
    db_norm: Optional[float] = -25.0  # from CodeSwitchingDataset
    # d. On-the-fly cut truncation or window slicing
    # I) truncate: select one chunk of a fixed duration for each cut
    truncate_duration: Optional[float] = None  # set this to enable
    truncate_offset_type: str = "random"  # "random" | "start" (fixed) | "end" (fixed, counted back)
    # II) cut_into_windows: convert each cut into smaller cuts using a sliding window (define hop for overlapping windows)
    cut_into_windows_duration: Optional[float] = None  # set this to enable
    cut_into_windows_hop: Optional[float] = None
    # III) common options
    keep_excessive_supervisions: bool = True  # when a cut is truncated in the middle of a supervision, should we keep them.

    # 5. Other Lhotse options.
    text_field: str = "text"  # key to read the transcript from
@@ -128,9 +137,6 @@ def get_lhotse_dataloader_from_config(
    # Resample as a safeguard; it's a no-op when SR is already OK
    cuts = cuts.resample(config.sample_rate)

    # Duration filtering, same as native NeMo dataloaders.
    cuts = cuts.filter(DurationFilter(config.min_duration, config.max_duration))

    # Expands cuts if multiple translations are provided.
    cuts = CutSet(LazyFlattener(cuts.map(_flatten_alt_text)))

@@ -149,6 +155,24 @@
    if config.perturb_speed:
        cuts = CutSet.mux(cuts, cuts.perturb_speed(0.9), cuts.perturb_speed(1.1),)

    # 2.d: truncation/slicing
    if config.truncate_duration is not None:
        cuts = cuts.truncate(
            max_duration=config.truncate_duration,
            offset_type=config.truncate_offset_type,
            keep_excessive_supervisions=config.keep_excessive_supervisions,
        )
    if config.cut_into_windows_duration is not None:
        cuts = cuts.cut_into_windows(
            duration=config.cut_into_windows_duration,
            hop=config.cut_into_windows_hop,
            keep_excessive_supervisions=config.keep_excessive_supervisions,
        )

    # Duration filtering, same as native NeMo dataloaders.
    # We can filter after the augmentations because they are applied only when calling load_audio().
    cuts = cuts.filter(DurationFilter(config.min_duration, config.max_duration))

    # 3. The sampler.
    if config.use_bucketing:
        # Bucketing. Some differences from NeMo's native bucketing:
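For reference, a hedged sketch of a dataloader config exercising the truncation and windowing options added to LhotseDataLoadingConfig above. The truncate_*, cut_into_windows_*, keep_excessive_supervisions, sample_rate, and min/max_duration fields come from this diff; cuts_path and batch_size are assumed fields used only for illustration. Typically only one of the two modes would be enabled at a time, and because duration filtering now runs after truncation/windowing, min_duration/max_duration apply to the resulting chunks.

# Hedged config sketch; cuts_path and batch_size are assumptions not shown in this diff.
from omegaconf import OmegaConf

from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config

loader_cfg = OmegaConf.create(
    {
        "use_lhotse": True,
        "cuts_path": "train_cuts.jsonl.gz",  # assumed input field
        "sample_rate": 16000,
        "batch_size": 8,                     # assumed batching field
        # I) fixed-duration truncation: one 4-second chunk per cut, at a random offset
        "truncate_duration": 4.0,
        "truncate_offset_type": "random",
        # II) alternative: slide a 4-second window with a 2-second hop instead
        # "cut_into_windows_duration": 4.0,
        # "cut_into_windows_hop": 2.0,
        "keep_excessive_supervisions": True,
        # applied after truncation/windowing, so these limits see the chunked cuts
        "min_duration": 0.5,
        "max_duration": 20.0,
    }
)
dataloader = get_lhotse_dataloader_from_config(
    loader_cfg, global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset()
)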
2 changes: 1 addition & 1 deletion requirements/requirements_asr.txt
@@ -5,7 +5,7 @@ ipywidgets
jiwer
kaldi-python-io
kaldiio
lhotse>=1.20.0
lhotse>=1.22.0
librosa>=0.10.0
marshmallow
matplotlib