Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lhotse AudioToAudio dataset (supports ref recording and embedding) #8477

Merged
merged 22 commits into from Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
40acfea
Draft for Lhotse AudioToAudio dataset (supports ref recording and emb…
pzelasko Feb 21, 2024
d903fb4
Integrate with speech enhancement models
pzelasko Feb 22, 2024
3894edf
Fix absolute path + write cuts in the output manifest
anteju Mar 5, 2024
99500e3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
66d8ec5
Support channel selectors for input, reference, and target recordings
pzelasko Mar 7, 2024
aac3db3
Support on the fly truncation and/or cutting into windows
pzelasko Mar 7, 2024
2fee158
Merge branch 'main' into feature/lhotse-audio-to-audio-dataset
pzelasko Mar 7, 2024
496fc66
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
f2d0bbe
Bump min required lhotse version
pzelasko Mar 7, 2024
864f5a4
Add copyright headers
pzelasko Mar 8, 2024
123dac5
Added unit tests checking lhotse dataloader is matching the existing …
anteju Mar 12, 2024
9b8b458
Fix batch unpacking, test_ds, use nemo logging
anteju Apr 8, 2024
496fb4e
fixed some code scanning issues
anteju Apr 8, 2024
b59dfb5
Fixed a couple CI issues
anteju Apr 9, 2024
f5bf68c
Support NeMo-style resolution of relative paths in native lhotse cuts
pzelasko Apr 11, 2024
470ba9c
Added option to leave original paths or force absolute paths in the c…
anteju Apr 11, 2024
8860979
Fix support for relative path resolution in lhotse arrays
pzelasko Apr 12, 2024
bf26822
Merge branch 'main' into feature/lhotse-audio-to-audio-dataset
pablo-garay Apr 13, 2024
ad4ad82
Merge branch 'main' into feature/lhotse-audio-to-audio-dataset
pzelasko Apr 15, 2024
319a441
Fix unit tests
pzelasko Apr 15, 2024
538e9e4
Merge branch 'main' into feature/lhotse-audio-to-audio-dataset
pzelasko Apr 15, 2024
a1628e0
Merge branch 'main' into feature/lhotse-audio-to-audio-dataset
pzelasko Apr 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions examples/audio_tasks/audio_to_audio_eval.py
Expand Up @@ -73,7 +73,9 @@
from tqdm import tqdm

from nemo.collections.asr.data import audio_to_audio_dataset
from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset
from nemo.collections.asr.metrics.audio import AudioMetricWrapper
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.common.parts.preprocessing import manifest
from nemo.core.config import hydra_runner
from nemo.utils import logging
Expand Down Expand Up @@ -103,6 +105,11 @@ class AudioEvaluationConfig(process_audio.ProcessConfig):
def get_evaluation_dataloader(config):
"""Prepare a dataloader for evaluation.
"""
if config.get("use_lhotse", False):
return get_lhotse_dataloader_from_config(
config, global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset()
)

dataset = audio_to_audio_dataset.get_audio_to_target_dataset(config=config)

return torch.utils.data.DataLoader(
Expand Down
2 changes: 1 addition & 1 deletion examples/audio_tasks/speech_enhancement.py
Expand Up @@ -51,7 +51,7 @@ def main(cfg):
trainer.fit(model)

# Run on test data, if available
if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
if hasattr(cfg.model, 'test_ds'):
if trainer.is_global_zero:
# Destroy the current process group and let the trainer initialize it again with a single device.
if torch.distributed.is_initialized():
Expand Down
199 changes: 199 additions & 0 deletions nemo/collections/asr/data/audio_to_audio_lhotse.py
@@ -0,0 +1,199 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import numpy as np
import torch
from lhotse import AudioSource, CutSet, Recording
from lhotse.array import Array
from lhotse.audio import info
from lhotse.cut import MixedCut
from lhotse.dataset.collation import collate_audio, collate_custom_field
from lhotse.serialization import load_jsonl

from nemo.collections.common.parts.preprocessing.manifest import get_full_path

# Keys of the optional channel-selector fields that `convert_manifest_nemo_to_lhotse`
# pops from each NeMo manifest entry.
INPUT_CHANNEL_SELECTOR = "input_channel_selector"
TARGET_CHANNEL_SELECTOR = "target_channel_selector"
REFERENCE_CHANNEL_SELECTOR = "reference_channel_selector"
# Names of the custom cut fields under which the converter stores the channel selection
# for the target and reference recordings of a Lhotse cut.
LHOTSE_TARGET_CHANNEL_SELECTOR = "target_recording_channel_selector"
LHOTSE_REFERENCE_CHANNEL_SELECTOR = "reference_recording_channel_selector"


class LhotseAudioToTargetDataset(torch.utils.data.Dataset):
    """
    Dataset for audio-to-audio tasks, where an input signal is used
    to recover the corresponding target signal.

    Indexing the dataset with a batch of cuts returns a dictionary of collated tensors.
    The input signal is always present; target, reference and embedding entries are
    added only when every cut in the batch provides the corresponding custom field.

    .. note:: This is a Lhotse variant of :class:`nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset`.
    """

    TARGET_KEY = "target_recording"
    REFERENCE_KEY = "reference_recording"
    EMBEDDING_KEY = "embedding_vector"

    def __getitem__(self, cuts: CutSet) -> dict[str, torch.Tensor]:
        """Collate ``cuts`` into a batch dictionary of signals and their lengths."""
        input_signal, input_length = collate_audio(cuts)
        batch = {"input_signal": input_signal, "input_length": input_length}

        if _key_available(cuts, self.TARGET_KEY):
            target_signal, target_length = collate_audio(cuts, recording_field=self.TARGET_KEY)
            batch["target_signal"] = target_signal
            batch["target_length"] = target_length

        if _key_available(cuts, self.REFERENCE_KEY):
            reference_signal, reference_length = collate_audio(cuts, recording_field=self.REFERENCE_KEY)
            batch["reference_signal"] = reference_signal
            batch["reference_length"] = reference_length

        if _key_available(cuts, self.EMBEDDING_KEY):
            batch["embedding_signal"] = collate_custom_field(cuts, field=self.EMBEDDING_KEY)

        return batch


def _key_available(cuts: CutSet, key: str) -> bool:
    """
    Check whether ``key`` is present in the custom fields of every cut in ``cuts``.

    For a ``MixedCut`` the check is performed on its first non-padding cut.
    Returns ``True`` for an empty collection of cuts.
    """

    def has_key(cut) -> bool:
        inspected = cut._first_non_padding_cut if isinstance(cut, MixedCut) else cut
        return inspected.custom is not None and key in inspected.custom

    return all(has_key(cut) for cut in cuts)


def create_recording(path_or_paths: str | list[str]) -> Recording:
    """
    Create a Lhotse ``Recording`` from a single audio file or from a list of audio files.

    When a list is given, each file becomes one audio source and the channels of
    consecutive files are numbered one after another, so that the resulting recording
    exposes all of them as a single multi-channel recording.

    Args:
        path_or_paths: Path to an audio file, or a non-empty list of such paths.

    Returns:
        The recording describing the given file(s).

    Raises:
        ValueError: If an empty list of paths is provided.
        AssertionError: If the listed files do not share the same sampling rate.
    """
    if not isinstance(path_or_paths, list):
        return Recording.from_file(path_or_paths)

    if not path_or_paths:
        raise ValueError("Cannot create a recording from an empty list of audio paths.")

    cur_channel_idx = 0
    sources = []
    infos = []
    for p in path_or_paths:
        i = info(p)
        infos.append(i)
        # Channels of this file are appended after the channels of the previous files.
        sources.append(
            AudioSource(type="file", channels=list(range(cur_channel_idx, cur_channel_idx + i.channels)), source=p)
        )
        cur_channel_idx += i.channels
    assert all(
        i.samplerate == infos[0].samplerate for i in infos[1:]
    ), f"Mismatched sampling rates for individual audio files in: {path_or_paths}"

    # NOTE(review): the length and duration are taken from the first file only,
    # i.e. all files are presumed to be equally long -- confirm with the data.
    return Recording(
        # Identify the recording by its first file (previously the first character
        # of the last path was used, which is not a meaningful identifier).
        id=path_or_paths[0],
        sources=sources,
        sampling_rate=infos[0].samplerate,
        num_samples=infos[0].frames,
        duration=infos[0].duration,
        channel_ids=list(range(0, cur_channel_idx)),
    )


def create_array(path: str) -> Array:
    """
    Create a Lhotse ``Array`` manifest describing a numpy file.

    The file is loaded once to read its shape; the directory of the file is stored as
    the storage path and the file name as the storage key.

    Args:
        path: Path to a ``.npy`` file.

    Returns:
        An ``Array`` with ``storage_type="numpy_files"`` pointing to the file.

    Raises:
        AssertionError: If ``path`` does not end with ``.npy``.
    """
    assert path.endswith(".npy"), f"Currently only conversion of numpy files is supported (got: {path})"
    arr = np.load(path)
    # Keep the `path` argument intact and use dedicated names for its two components.
    parent_dir, file_name = os.path.split(path)
    return Array(
        storage_type="numpy_files",
        storage_path=parent_dir,
        storage_key=file_name,
        shape=list(arr.shape),
    )


def _restore_manifest_paths(recording: Recording, manifest_value: str | list[str]) -> None:
    """Overwrite the source paths of ``recording`` with the path(s) as written in the original manifest."""
    original_paths = manifest_value if isinstance(manifest_value, list) else [manifest_value]
    # A recording created from a list of files has one source per file, in the same order.
    for audio_source, original_path in zip(recording.sources, original_paths):
        audio_source.source = original_path


def convert_manifest_nemo_to_lhotse(
    input_manifest: str,
    output_manifest: str,
    input_key: str = 'input_filepath',
    target_key: str = 'target_filepath',
    reference_key: str = 'reference_filepath',
    embedding_key: str = 'embedding_filepath',
    force_absolute_paths: bool = False,
):
    """
    Convert an audio-to-audio manifest from NeMo format to Lhotse format.

    Every entry of the input manifest is turned into a Lhotse cut which is written to
    the output manifest. Fields of the entry which are not consumed by the conversion
    are stored in the custom fields of the cut.

    Args:
        input_manifest: Path to the input NeMo manifest.
        output_manifest: Path where we'll write the output Lhotse manifest (supported extensions: .jsonl.gz and .jsonl).
        input_key: Key of the input recording, mapped to Lhotse's 'Cut.recording'.
        target_key: Key of the target recording, mapped to Lhotse's 'Cut.target_recording'.
        reference_key: Key of the reference recording, mapped to Lhotse's 'Cut.reference_recording'.
        embedding_key: Key of the embedding, mapped to Lhotse's 'Cut.embedding_vector'.
        force_absolute_paths: If True, the paths in the output manifest will be absolute.
    """
    with CutSet.open_writer(output_manifest) as writer:
        for item in load_jsonl(input_manifest):

            # Create Lhotse recording and cut object, apply offset and duration slicing if present.
            item_input_key = item.pop(input_key)
            recording = create_recording(get_full_path(audio_file=item_input_key, manifest_file=input_manifest))
            cut = recording.to_cut().truncate(duration=item.pop("duration"), offset=item.pop("offset", 0.0))

            if not force_absolute_paths:
                # Use the same format for paths as in the original manifest
                _restore_manifest_paths(cut.recording, item_input_key)

            if (channels := item.pop(INPUT_CHANNEL_SELECTOR, None)) is not None:
                if cut.num_channels == 1:
                    assert (
                        len(channels) == 1 and channels[0] == 0
                    ), f"The input recording has only a single channel, but manifest specified {INPUT_CHANNEL_SELECTOR}={channels}"
                else:
                    cut = cut.with_channels(channels)

            if target_key in item:
                item_target_key = item.pop(target_key)
                cut.target_recording = create_recording(
                    get_full_path(audio_file=item_target_key, manifest_file=input_manifest)
                )

                if not force_absolute_paths:
                    # Use the same format for paths as in the original manifest
                    _restore_manifest_paths(cut.target_recording, item_target_key)

                if (channels := item.pop(TARGET_CHANNEL_SELECTOR, None)) is not None:
                    if cut.target_recording.num_channels == 1:
                        assert (
                            len(channels) == 1 and channels[0] == 0
                        ), f"The target recording has only a single channel, but manifest specified {TARGET_CHANNEL_SELECTOR}={channels}"
                    else:
                        cut = cut.with_custom(LHOTSE_TARGET_CHANNEL_SELECTOR, channels)

            if reference_key in item:
                item_reference_key = item.pop(reference_key)
                cut.reference_recording = create_recording(
                    get_full_path(audio_file=item_reference_key, manifest_file=input_manifest)
                )

                if not force_absolute_paths:
                    # Use the same format for paths as in the original manifest
                    _restore_manifest_paths(cut.reference_recording, item_reference_key)

                if (channels := item.pop(REFERENCE_CHANNEL_SELECTOR, None)) is not None:
                    if cut.reference_recording.num_channels == 1:
                        assert (
                            len(channels) == 1 and channels[0] == 0
                        ), f"The reference recording has only a single channel, but manifest specified {REFERENCE_CHANNEL_SELECTOR}={channels}"
                    else:
                        cut = cut.with_custom(LHOTSE_REFERENCE_CHANNEL_SELECTOR, channels)

            if embedding_key in item:
                item_embedding_key = item.pop(embedding_key)
                cut.embedding_vector = create_array(
                    get_full_path(audio_file=item_embedding_key, manifest_file=input_manifest)
                )

                if not force_absolute_paths:
                    # Use the same format for paths as in the original manifest
                    cut.embedding_vector.storage_path = os.path.dirname(item_embedding_key)

            if item:
                # Any field that's still left goes to custom fields. The cut has no custom
                # dictionary yet when none of the optional fields above were present.
                if cut.custom is None:
                    cut.custom = {}
                cut.custom.update(item)

            writer.write(cut)
25 changes: 23 additions & 2 deletions nemo/collections/asr/models/enhancement_models.py
Expand Up @@ -24,9 +24,11 @@
from tqdm import tqdm

from nemo.collections.asr.data import audio_to_audio_dataset
from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset
from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config
from nemo.collections.asr.models.audio_to_audio_model import AudioToAudioModel
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType
from nemo.utils import logging
Expand Down Expand Up @@ -198,6 +200,11 @@ def process(

def _setup_dataloader_from_config(self, config: Optional[Dict]):

if config.get("use_lhotse", False):
return get_lhotse_dataloader_from_config(
config, global_rank=self.global_rank, world_size=self.world_size, dataset=LhotseAudioToTargetDataset()
)

is_concat = config.get('is_concat', False)
if is_concat:
raise NotImplementedError('Concat not implemented')
Expand Down Expand Up @@ -398,7 +405,14 @@ def forward(self, input_signal, input_length=None):

# PTL-specific methods
def training_step(self, batch, batch_idx):
input_signal, input_length, target_signal, target_length = batch

if isinstance(batch, dict):
# lhotse batches are dictionaries
input_signal = batch['input_signal']
input_length = batch['input_length']
target_signal = batch['target_signal']
else:
input_signal, input_length, target_signal, _ = batch

# Expand channel dimension, if necessary
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
Expand Down Expand Up @@ -426,7 +440,14 @@ def training_step(self, batch, batch_idx):
return loss

def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
input_signal, input_length, target_signal, target_length = batch

if isinstance(batch, dict):
# lhotse batches are dictionaries
input_signal = batch['input_signal']
input_length = batch['input_length']
target_signal = batch['target_signal']
else:
input_signal, input_length, target_signal, _ = batch

# Expand channel dimension, if necessary
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
Expand Down
43 changes: 41 additions & 2 deletions nemo/collections/common/data/lhotse/cutset.py
Expand Up @@ -14,13 +14,17 @@

import logging
import warnings
from functools import partial
from itertools import repeat
from pathlib import Path
from typing import Sequence, Tuple

from lhotse import CutSet
from lhotse import CutSet, Features, Recording
from lhotse.array import Array, TemporalArray
from lhotse.cut import Cut, MixedCut, PaddingCut

from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoIterator, LazyNeMoTarredIterator
from nemo.collections.common.parts.preprocessing.manifest import get_full_path


def read_cutset_from_config(config) -> Tuple[CutSet, bool]:
Expand Down Expand Up @@ -98,10 +102,45 @@ def read_lhotse_manifest(config, is_tarred: bool) -> CutSet:
cuts = mux(*cutsets, weights=weights, max_open_streams=config.max_open_streams, seed=config.shard_seed)
else:
# Regular Lhotse manifest points to individual audio files (like native NeMo manifest).
cuts = CutSet.from_file(config.cuts_path)
path = config.cuts_path
cuts = CutSet.from_file(path).map(partial(resolve_relative_paths, manifest_path=path))
return cuts


def resolve_relative_paths(cut: Cut, manifest_path: str) -> Cut:
    """
    Resolve the paths stored in ``cut`` using ``get_full_path`` and the manifest location.

    The cut is modified in place and returned. Padding cuts are returned untouched and
    mixed cuts are handled by resolving each of their tracks. For other cuts, the paths
    of file-type audio sources, of the features, and of any recordings or arrays found
    in the custom fields are resolved.

    Args:
        cut: The cut whose paths should be resolved.
        manifest_path: Path of the manifest the cut was read from; passed to
            ``get_full_path`` as ``manifest_file``.

    Returns:
        The same cut object, with resolved paths.
    """
    if isinstance(cut, PaddingCut):
        return cut

    if isinstance(cut, MixedCut):
        for track in cut.tracks:
            track.cut = resolve_relative_paths(track.cut, manifest_path)
        return cut

    def resolve_recording(value):
        # Only sources of type "file" hold paths; other source types are left as they are.
        for audio_source in value.sources:
            if audio_source.type == "file":
                audio_source.source = get_full_path(audio_source.source, manifest_file=manifest_path)

    def resolve_array(value):
        # A temporal array wraps the array that actually holds the storage path.
        holder = value.array if isinstance(value, TemporalArray) else value
        # Pass the manifest location, consistently with `resolve_recording`, so that
        # the storage path is resolved the same way as the audio paths.
        holder.storage_path = get_full_path(holder.storage_path, manifest_file=manifest_path)

    if cut.has_recording:
        resolve_recording(cut.recording)
    if cut.has_features:
        resolve_array(cut.features)
    if cut.custom is not None:
        for value in cut.custom.values():
            if isinstance(value, Recording):
                resolve_recording(value)
            elif isinstance(value, (Array, TemporalArray, Features)):
                resolve_array(value)

    return cut


def read_nemo_manifest(config, is_tarred: bool) -> CutSet:
common_kwargs = {
"text_field": config.text_field,
Expand Down