Support dataloader as input to audio for transcription #9201

Merged
titu1994 merged 5 commits into main from transcribe_dataloader on May 17, 2024

Conversation

titu1994
Collaborator

What does this PR do?

Enables the use of a pre-constructed DataLoader as input to the model.transcribe() function.
This provides a fast path that bypasses all internal manifest and tensor handling, leaving that to the user, and runs only the model forward pass and the steps after it.

Collection: [ASR]

Changelog

  • Allows the user to provide a DataLoader object, which bypasses the internal manifest processing and dataset construction.
  • Places implicit trust in user-provided input: a user who chooses to provide a DataLoader is responsible for formatting its batches and supplying all arguments so that they match the ASR model's forward arguments.

Usage

import os
from typing import Dict, List

import numpy as np
import soundfile as sf
import torch
from torch.utils.data import DataLoader, Dataset

from nemo.collections.asr.data.audio_to_text import _speech_collate_fn
from nemo.collections.asr.models import ASRModel

model = ASRModel.from_pretrained("stt_en_conformer_ctc_small")

# Load audio files (test_data_dir is assumed to point at the root of the NeMo test data)
audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav")
audio, sr = sf.read(audio_file, dtype='float32')

audio_file2 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an152-mwhw-b.wav")
audio2, sr = sf.read(audio_file2, dtype='float32')

# Create a dummy dataset to hold the tensor values
class DummyDataset(Dataset):
    def __init__(self, audio_tensors: List[np.ndarray], config: Dict = None):
        self.audio_tensors = audio_tensors
        self.config = config

    def __getitem__(self, index):
        data = self.audio_tensors[index]
        samples = torch.tensor(data)
        # Calculate seq length
        seq_len = torch.tensor(samples.shape[0], dtype=torch.long)

        # Dummy text tokens
        text_tokens = torch.tensor([0], dtype=torch.long)
        text_tokens_len = torch.tensor(1, dtype=torch.long)

        # Return values in a form that the ASR model's forward function can consume
        return (samples, seq_len, text_tokens, text_tokens_len)

    def __len__(self):
        return len(self.audio_tensors)

# Wrap the dataset into a data loader with proper collate function
dataset = DummyDataset([audio, audio2])
collate_fn = lambda x: _speech_collate_fn(x, pad_id=0)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn)

# DataLoader as input to audio
outputs = model.transcribe(dataloader, batch_size=1)

assert len(outputs) == 2
assert isinstance(outputs[0], str)
assert isinstance(outputs[1], str)

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

@github-actions github-actions bot added the ASR label May 15, 2024
@titu1994 titu1994 requested review from galv and pzelasko May 15, 2024 07:04
@titu1994
Collaborator Author

Fyi @nithinraok - this should enable fastpath execution for HF too

@titu1994
Collaborator Author

@galv @pzelasko actual change is just these 5 lines, everything else is just black formatter and a unittest at the bottom - https://github.com/NVIDIA/NeMo/pull/9201/files#diff-04e10f8fb8f7afddb360c7ea0ff9c831613c754085d2fd06961bb80a9f932a25R372-R376
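
The fast path described in this comment can be illustrated with a minimal sketch. This is not the NeMo implementation: `resolve_dataloader` and `ListDataset` are hypothetical names introduced here for illustration only. The sketch only shows the general idea of returning a caller-supplied DataLoader untouched and building one internally for any other input.

```python
from torch.utils.data import DataLoader, Dataset


class ListDataset(Dataset):
    """Hypothetical stand-in for the dataset a model would normally build internally."""

    def __init__(self, items):
        self.items = list(items)

    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)


def resolve_dataloader(audio, batch_size=4):
    """Return `audio` unchanged if it is already a DataLoader, else wrap it in one."""
    if isinstance(audio, DataLoader):
        # Fast path: trust the caller's batching, collation and ordering.
        return audio
    if not isinstance(audio, (list, tuple)):
        audio = [audio]
    return DataLoader(ListDataset(audio), batch_size=batch_size, shuffle=False)


# A pre-built loader is passed through as-is; a plain list gets wrapped.
prebuilt = DataLoader(ListDataset([0.1, 0.2, 0.3]), batch_size=3)
same = resolve_dataloader(prebuilt, batch_size=1)
built = resolve_dataloader([0.1, 0.2, 0.3], batch_size=2)
```

In this sketch the `batch_size` argument has no effect on a pre-built loader, because the loader's own batch size governs how samples are grouped.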

audio2, sr = sf.read(audio_file2, dtype='float32')

dataset = DummyDataset([audio, audio2])
collate_fn = lambda x: _speech_collate_fn(x, pad_id=0)

Check notice

Code scanning / CodeQL

Returning tuples with varying lengths Note

TestTranscriptionMixin.test_transcribe_dataloader.lambda returns tuple of size 4 and tuple of size 5.
pzelasko
pzelasko previously approved these changes May 15, 2024
Collaborator

@pzelasko pzelasko left a comment

LGTM.

@titu1994 @nithinraok @galv I'd like to point out that if you're going to use a TensorDataset, make it an IterableDataset that yields collated mini-batches, rather than relying on DataLoader's batch_size and collate_fn to do it for you. This way you'll amortize the overhead of collation as it will happen in the background process. This change will matter for super-high RTFx models.
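
The suggestion above can be sketched roughly as follows. This is an illustration under assumptions, not code from the PR: `PreBatchedAudioDataset` is a hypothetical name, and the four-element batch layout (padded audio, audio lengths, dummy tokens, token lengths) simply mirrors the tuple returned by `DummyDataset` in the usage example. Passing `batch_size=None` to the DataLoader disables its automatic batching, so each item the dataset yields is already a complete mini-batch.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, IterableDataset


class PreBatchedAudioDataset(IterableDataset):
    """Hypothetical dataset that yields already-collated mini-batches."""

    def __init__(self, waveforms, batch_size):
        self.waveforms = [torch.as_tensor(w, dtype=torch.float32) for w in waveforms]
        self.batch_size = batch_size

    def __iter__(self):
        for start in range(0, len(self.waveforms), self.batch_size):
            chunk = self.waveforms[start:start + self.batch_size]
            # Collation happens here, inside the dataset, not in DataLoader.
            lengths = torch.tensor([w.shape[0] for w in chunk], dtype=torch.long)
            padded = pad_sequence(chunk, batch_first=True)  # zero-pads to the longest
            # Dummy text tokens, one per utterance.
            tokens = torch.zeros(len(chunk), 1, dtype=torch.long)
            token_lengths = torch.ones(len(chunk), dtype=torch.long)
            yield padded, lengths, tokens, token_lengths


waveforms = [torch.zeros(160), torch.zeros(120), torch.zeros(80)]
dataset = PreBatchedAudioDataset(waveforms, batch_size=2)

# batch_size=None: the loader forwards each yielded batch without re-batching it.
loader = DataLoader(dataset, batch_size=None, num_workers=0)
batches = list(loader)
```

The sketch uses `num_workers=0` so that it runs anywhere. Setting `num_workers=1` moves the padding into a background worker process. With more than one worker, every worker would replay the full stream unless the dataset shards its data, for example with `torch.utils.data.get_worker_info()`.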

Collaborator

@nithinraok nithinraok left a comment

minor comments. For huggingface, we need to depend on Iterable dataset I believe.

@@ -403,7 +403,8 @@ def transcribe(
"""
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
Args:
-audio: (a list) of paths to audio files. \
+audio: (a single or list) of paths to audio files or a np.ndarray audio array.
+Can also be a dataloader object that provides values that can be consumed by the model.
Collaborator

Add torch.dataloader to audio types in the func signature.

Collaborator Author

Done

@@ -364,7 +365,8 @@ def transcribe(
Generate class labels for provided audio files. Use this method for debugging and prototyping.

Args:
-audio: (a single or list) of paths to audio files or a np.ndarray audio sample. \
+audio: (a single or list) of paths to audio files or a np.ndarray audio array.
Collaborator

Add torch.dataloader to audio types in the func signature.

Collaborator Author

Done

@galv
Collaborator

galv commented May 15, 2024

@nithinraok can you investigate whether this change will fix your performance degradation issues when using transcribe() within the huggingface open asr leaderboard RTFx measurements?

@nithinraok
Collaborator

@galv yeah, I am thinking to use iterable dataset (looks like streaming is the option for huggingface datasets to get iterable dataset) and run the HF evals.

titu1994 and others added 4 commits May 15, 2024 15:26
Signed-off-by: smajumdar <titu1994@gmail.com>
Signed-off-by: titu1994 <titu1994@users.noreply.github.com>
Signed-off-by: smajumdar <titu1994@gmail.com>
Signed-off-by: smajumdar <titu1994@gmail.com>
Signed-off-by: titu1994 <titu1994@users.noreply.github.com>
@titu1994 titu1994 merged commit 67401ed into main May 17, 2024
133 checks passed
@titu1994 titu1994 deleted the transcribe_dataloader branch May 17, 2024 16:18
titu1994 added a commit that referenced this pull request May 17, 2024
* Support dataloader as input to `audio` for transcription

Signed-off-by: smajumdar <titu1994@gmail.com>

* Apply isort and black reformatting

Signed-off-by: titu1994 <titu1994@users.noreply.github.com>

* Support dataloader as input to `audio` for transcription

Signed-off-by: smajumdar <titu1994@gmail.com>

* Update transcribe signatures

Signed-off-by: smajumdar <titu1994@gmail.com>

* Apply isort and black reformatting

Signed-off-by: titu1994 <titu1994@users.noreply.github.com>

---------

Signed-off-by: smajumdar <titu1994@gmail.com>
Signed-off-by: titu1994 <titu1994@users.noreply.github.com>
(cherry picked from commit 67401ed)
BoxiangW pushed a commit to BoxiangW/NeMo that referenced this pull request Jun 5, 2024
Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>
janekl pushed a commit that referenced this pull request Jun 12, 2024
Signed-off-by: Jan Lasek <janek.lasek@gmail.com>
rohitrango pushed a commit to rohitrango/NeMo that referenced this pull request Jun 25, 2024