Availability of OWSM-CTC #5683
Thanks a lot. @pyf98, can you make a PR for OWSM-CTC and add the information to our webpage?
Hi @wgb14, thanks for your interest in our work! As Shinji mentioned, I have prepared that webpage to collect OWSM-related papers and models. However, we sometimes cannot update it immediately because of anonymity requirements at certain venues. Merging OWSM-CTC into the master branch will take some time, but the model and code are already public:
An example script to run short-form ASR/ST:

```python
import librosa
import soundfile as sf
from espnet2.bin.s2t_inference_ctc import Speech2TextGreedySearch

s2t = Speech2TextGreedySearch.from_pretrained(
    "pyf98/owsm_ctc_v3.1_1B",
    device="cuda",
    generate_interctc_outputs=False,
    lang_sym='<eng>',
    task_sym='<asr>',
)

speech, rate = sf.read("xxx.wav")
# The model works on 16 kHz audio; resample if the file uses another rate.
if rate != 16000:
    speech = librosa.resample(speech, orig_sr=rate, target_sr=16000)
# Short-form input is padded (or trimmed) to exactly 30 s.
speech = librosa.util.fix_length(speech, size=(16000 * 30))

res = s2t(speech)[0]
print(res)
```

An example script to run long-form ASR:

```python
import librosa
import soundfile as sf
import torch
from espnet2.bin.s2t_inference_ctc import Speech2TextGreedySearch

if __name__ == "__main__":
    context_len_in_secs = 4  # left and right context when doing buffered inference
    batch_size = 32  # depends on the GPU memory

    s2t = Speech2TextGreedySearch.from_pretrained(
        "pyf98/owsm_ctc_v3.1_1B",
        device='cuda' if torch.cuda.is_available() else 'cpu',
        generate_interctc_outputs=False,
        lang_sym='<eng>',
        task_sym='<asr>',
    )

    speech, rate = sf.read("xxx.wav")
    # The model works on 16 kHz audio; resample if the file uses another rate.
    if rate != 16000:
        speech = librosa.resample(speech, orig_sr=rate, target_sr=16000)

    text = s2t.decode_long_batched_buffered(
        speech,
        batch_size=batch_size,
        context_len_in_secs=context_len_in_secs,
        frames_per_sec=12.5,  # 80 ms shift, model-dependent, don't change
    )
    print(text)
```
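To make the long-form parameters more concrete, here is a rough, hypothetical sketch of how buffered inference with left and right context might split a long recording into overlapping windows. It is **not** the ESPnet implementation of `decode_long_batched_buffered`; the helper names `plan_buffered_windows` and `samples_to_frames`, and the 30 s maximum window length, are assumptions made for illustration. The idea: each window keeps the output only for its central region, while the extra `context_len_in_secs` on each side gives the model surrounding acoustic context. With `frames_per_sec=12.5`, one output frame covers 80 ms, i.e. 1280 samples at 16 kHz.

```python
def plan_buffered_windows(num_samples, sample_rate=16000,
                          window_len_in_secs=30.0, context_len_in_secs=4.0):
    """Hypothetical helper: plan overlapping windows for buffered inference.

    Returns a list of (win_start, win_end, keep_start, keep_end) sample
    indices. Each window is at most window_len_in_secs long (an assumption);
    only the central [keep_start, keep_end) region's output is kept, so the
    kept regions tile the whole input without gaps or overlap.
    """
    window = int(window_len_in_secs * sample_rate)
    context = int(context_len_in_secs * sample_rate)
    step = window - 2 * context  # length of the kept (central) region
    if step <= 0:
        raise ValueError("context is too long for the window length")
    plans = []
    for keep_start in range(0, num_samples, step):
        keep_end = min(keep_start + step, num_samples)
        win_start = max(keep_start - context, 0)  # left context, clipped at 0
        win_end = min(keep_end + context, num_samples)  # right context, clipped
        plans.append((win_start, win_end, keep_start, keep_end))
    return plans


def samples_to_frames(num_samples, sample_rate=16000, frames_per_sec=12.5):
    """Convert a sample count to output frames (12.5 frames/s = 80 ms shift)."""
    samples_per_frame = sample_rate / frames_per_sec  # 1280 samples at 16 kHz
    return int(num_samples // samples_per_frame)


if __name__ == "__main__":
    # A 70 s recording at 16 kHz with 4 s of context on each side.
    for plan in plan_buffered_windows(70 * 16000, context_len_in_secs=4.0):
        print(plan)
```

In this sketch, `batch_size` would simply control how many of the planned windows are sent through the model at once, which is why it depends on GPU memory rather than on the audio itself.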
Thanks for your prompt responses! These help a lot.
Are there any fine-tuning examples or scripts, or are there plans to release any in the future?
Hi @teinhonglo, thanks for your question! Fine-tuning works much like the standard ESPnet setup, if you are familiar with ESPnet.
Hi ESPnet team,
Thank you for your amazing work on OWSM; it greatly helps the open-source community. I am truly grateful for your efforts.
I assume this repo is the proper place to discuss OWSM-related topics, and I would really like to know whether you plan to release OWSM-CTC recipes and models here as well.
By the way, where can I keep up to date on OWSM? For now, I follow the issues and PRs in this repo and the papers on your lab page. I even had to find the latest OWSM-CTC paper through Google Scholar.
Thanks in advance!