Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Żelasko <petezor@gmail.com>
  • Loading branch information
pzelasko committed Apr 15, 2024
1 parent ad4ad82 commit 319a441
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions nemo/collections/asr/data/audio_to_audio_lhotse.py
Expand Up @@ -88,7 +88,7 @@ def create_recording(path_or_paths: str | list[str]) -> Recording:
i.samplerate == infos[0].samplerate for i in infos[1:]
), f"Mismatched sampling rates for individual audio files in: {path_or_paths}"
recording = Recording(
id=p[0],
id=path_or_paths[0],
sources=sources,
sampling_rate=infos[0].samplerate,
num_samples=infos[0].frames,
Expand Down Expand Up @@ -136,9 +136,7 @@ def convert_manifest_nemo_to_lhotse(
recording = create_recording(get_full_path(audio_file=item_input_key, manifest_file=input_manifest))
cut = recording.to_cut().truncate(duration=item.pop("duration"), offset=item.pop("offset", 0.0))

if not force_absolute_paths:
# Use the same format for paths as in the original manifest
cut.recording.sources[0].source = item_input_key
_as_relative(cut.recording, item_input_key, enabled=not force_absolute_paths)

if (channels := item.pop(INPUT_CHANNEL_SELECTOR, None)) is not None:
if cut.num_channels == 1:
Expand All @@ -154,9 +152,7 @@ def convert_manifest_nemo_to_lhotse(
get_full_path(audio_file=item_target_key, manifest_file=input_manifest)
)

if not force_absolute_paths:
# Use the same format for paths as in the original manifest
cut.target_recording.sources[0].source = item_target_key
_as_relative(cut.target_recording, item_target_key, enabled=not force_absolute_paths)

if (channels := item.pop(TARGET_CHANNEL_SELECTOR, None)) is not None:
if cut.target_recording.num_channels == 1:
Expand All @@ -172,9 +168,7 @@ def convert_manifest_nemo_to_lhotse(
get_full_path(audio_file=item_reference_key, manifest_file=input_manifest)
)

if not force_absolute_paths:
# Use the same format for paths as in the original manifest
cut.reference_recording.sources[0].source = item_reference_key
_as_relative(cut.reference_recording, item_target_key, enabled=not force_absolute_paths)

if (channels := item.pop(REFERENCE_CHANNEL_SELECTOR, None)) is not None:
if cut.reference_recording.num_channels == 1:
Expand All @@ -192,10 +186,22 @@ def convert_manifest_nemo_to_lhotse(

if not force_absolute_paths:
# Use the same format for paths as in the original manifest
parent, path = os.path.split(item_embedding_key)
cut.embedding_vector.storage_path = parent
cut.embedding_vector.storage_path = ""
cut.embedding_vector.storage_key = item_embedding_key

if item:
cut.custom.update(item) # any field that's still left goes to custom fields

writer.write(cut)


def _as_relative(recording: Recording, paths: list[str] | str, enabled: bool) -> None:
if not enabled:
return
if isinstance(paths, str):
paths = [paths]
assert len(recording.sources) == len(
paths
), f"Mismatched number of sources for lhotse Recording and the override list. Got {recording=} and {paths=}"
for source, path in zip(recording.sources, paths):
source.source = path

0 comments on commit 319a441

Please sign in to comment.