I have some questions about RNNT loss. #3750

Open
girlsending0 opened this issue Feb 26, 2024 · 6 comments

Comments

@girlsending0

Hello,
I would like to ask a question that may be somewhat trivial.
The logits passed to the RNN-T loss have shape (batch, max_seq_len, max_target_len + 1, num_classes).
Why is the third dimension max_target_len + 1?
Shouldn't the +1 apply to the number of classes instead, i.e. vocab size + 1, because the blank is included?
I don't understand this at all.
Can anyone help?

https://pytorch.org/audio/main/generated/torchaudio.functional.rnnt_loss.html

@csukuangfj
Collaborator

max_target_len+1 is not the vocab size. They are two different things.

You can find my implementation at
https://github.com/csukuangfj/optimized_transducer/blob/master/optimized_transducer/csrc/cpu.cc#L83

@girlsending0
Author

@csukuangfj Thank you.

I phrased my question in a misleading way.

What I'm curious about is why the third dimension of the logits passed to the RNN-T loss must be target_length + 1. Looking at your code, I noticed that you use target length + 1 because it includes a blank label.

Isn't the blank already included in n_class? (When setting n_class, I think it should be len(vocab) + 1, as with CTC loss.)

I don't quite understand.

@csukuangfj
Collaborator

You need to differentiate between the target length and the number of classes.

The transcript of an utterance is converted to tokens. The target length is the number of tokens in the transcript. It is not the number of classes. Each token takes a value in the range [1, num_of_classes - 1] (assuming ID 0 is reserved for the blank).
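
To make the distinction concrete, here is a minimal sketch of the bookkeeping. The toy vocabulary, the transcript, and the choice of index 0 for the blank are assumptions made up for illustration; the blank index is configurable in most implementations, so check the one you use.

```python
# Hypothetical toy vocabulary; index 0 is ASSUMED to be the blank symbol.
vocab = ["<blk>", "a", "b", "c"]
num_classes = len(vocab)  # the blank is already counted here

# Convert a transcript to token IDs. Blank never appears in the targets.
transcript = "cab"
targets = [vocab.index(ch) for ch in transcript]
target_len = len(targets)  # number of tokens, unrelated to num_classes

# The prediction network is fed the blank followed by the target tokens,
# so it produces target_len + 1 outputs.
predictor_input = [0] + targets

batch = 1
num_frames = 5  # hypothetical number of encoder frames
logits_shape = (batch, num_frames, len(predictor_input), num_classes)
print(logits_shape)
```

The third dimension grows with the length of the transcript, while the last dimension stays fixed at the number of classes no matter how long the transcript is.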

@girlsending0
Author

So the number of classes should be len(vocab)?
I understand now.
I had misunderstood the mechanism of the RNN-Transducer.
Since the model starts from a blank label, the dimension should be target_length + 1.
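
For readers who want to see where the extra position shows up in the loss itself, below is a rough pure-Python sketch of the standard RNN-T forward recursion for a single utterance. It is not the torchaudio or optimized_transducer implementation; the function names (`rnnt_nll`, `log_softmax`, `logaddexp`) and the default `blank=0` are assumptions for illustration. The index `u` counts how many target tokens have been emitted so far, so it runs over 0..U, which is U + 1 values.

```python
import math

def log_softmax(row):
    # Numerically stable log-softmax over one list of class scores.
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]

def logaddexp(a, b):
    # log(exp(a) + exp(b)) that tolerates -inf inputs.
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def rnnt_nll(logits, targets, blank=0):
    """Negative log-likelihood for one utterance.

    logits is a nested list of shape (T, U + 1, C); targets has length U.
    """
    T, U = len(logits), len(targets)
    assert all(len(row) == U + 1 for row in logits)
    logp = [[log_softmax(logits[t][u]) for u in range(U + 1)] for t in range(T)]

    # alpha[t][u]: log-probability of having emitted the first u tokens
    # by the time frame t is reached.
    alpha = [[-math.inf] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t > 0:  # arrived by emitting blank at (t - 1, u)
                alpha[t][u] = logaddexp(
                    alpha[t][u], alpha[t - 1][u] + logp[t - 1][u][blank])
            if u > 0:  # arrived by emitting token u - 1 at (t, u - 1)
                alpha[t][u] = logaddexp(
                    alpha[t][u], alpha[t][u - 1] + logp[t][u - 1][targets[u - 1]])
    # A final blank from the last node (T - 1, U) ends the sequence.
    return -(alpha[T - 1][U] + logp[T - 1][U][blank])

# Uniform scores with T = 2 frames, U = 1 token, C = 3 classes.
uniform = [[[0.0] * 3 for _ in range(2)] for _ in range(2)]
loss = rnnt_nll(uniform, [1])
```

In this toy case there are two alignment paths, each with three emissions of probability 1/3, so the loss equals log(27/2). The node at `u = U` is needed because the model must still emit blanks after the last token has been produced.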

@csukuangfj
Collaborator

Great to hear that it resolved your issue.

@girlsending0
Author

@csukuangfj
Thank you for your kindness.
