In [None]:
from transformers import Wav2Vec2ForCTC, AutoProcessor
import torch
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
import time
import librosa

In [None]:
# Checkpoint identifier handed to from_pretrained() for both the processor and the CTC model.
model_id = "facebook/mms-1b-all"

# Module-level globals shared by transcribe() and transcribe_lang().
# NOTE(review): from_pretrained() presumably downloads and caches the checkpoint on
# first use, so this cell may be slow on a fresh environment — confirm.
processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)

In [None]:
# Checkpoint identifier for the sequence-classification model used for language identification.
model_id_lid = "facebook/mms-lid-126"

# Module-level globals used by detect_language(): the feature extractor that prepares
# the waveform and the classifier whose config.id2label maps class ids to language labels.
processor_lid = AutoFeatureExtractor.from_pretrained(model_id_lid)
model_lid = Wav2Vec2ForSequenceClassification.from_pretrained(model_id_lid)

In [None]:
def transcribe(audio):
    """Transcribe an audio file with the shared CTC model.

    The module-level ``processor`` and ``model`` are used as they are currently
    configured; this function does not select a language itself.

    Parameters
    ----------
    audio
        Anything ``librosa.load`` accepts as its first argument (the Gradio
        inputs in this notebook pass a file path).

    Returns
    -------
    tuple
        ``(transcription, seconds)`` where ``seconds`` is the wall-clock
        duration of the model forward pass only (loading, feature extraction
        and decoding are not timed).
    """
    sample_rate = 16_000
    # librosa.load returns (samples, rate); only the samples are needed.
    waveform, _ = librosa.load(audio, sr=sample_rate, mono=True)
    features = processor(waveform, sampling_rate=sample_rate, return_tensors="pt")

    with torch.no_grad():
        started = time.time()
        logits = model(**features).logits
        finished = time.time()

    # Greedy decoding: highest-scoring class per frame, first batch item.
    token_ids = logits.argmax(dim=-1)[0]
    return processor.decode(token_ids), (finished - started)

In [None]:
def detect_language(audio):
    """Identify the spoken language of an audio file.

    Uses the module-level ``processor_lid`` feature extractor and ``model_lid``
    classifier.

    Parameters
    ----------
    audio
        Anything ``librosa.load`` accepts as its first argument (the Gradio
        inputs in this notebook pass a file path).

    Returns
    -------
    tuple
        ``(label, seconds)`` where ``label`` is looked up in
        ``model_lid.config.id2label`` for the highest-scoring class and
        ``seconds`` is the wall-clock duration of the model forward pass only.
    """
    waveform = librosa.load(audio, sr=16_000, mono=True)[0]
    inputs_lid = processor_lid(waveform, sampling_rate=16_000, return_tensors="pt")
    with torch.no_grad():
        start_time_lid = time.time()
        outputs_lid = model_lid(**inputs_lid).logits
        # Must be the same name the return statement reads; the previous
        # ``end_time`` binding made every call fail with NameError.
        end_time_lid = time.time()
    # Highest-scoring class for the first (only) batch item, as a plain int.
    lang_id = torch.argmax(outputs_lid, dim=-1)[0].item()
    detected_lang = model_lid.config.id2label[lang_id]
    return detected_lang, (end_time_lid - start_time_lid)

In [None]:
def transcribe_lang(audio,lang):
    """Transcribe an audio file after switching the shared model to ``lang``.

    Side effects: sets the target language on the module-level
    ``processor.tokenizer``, loads the matching adapter into the module-level
    ``model`` (both stay switched after the call returns), and prints ``lang``.

    Parameters
    ----------
    audio
        Anything ``librosa.load`` accepts as its first argument (the Gradio
        inputs in this notebook pass a file path).
    lang
        Language identifier forwarded unchanged to
        ``processor.tokenizer.set_target_lang`` and ``model.load_adapter``.

    Returns
    -------
    tuple
        ``(transcription, seconds)`` where ``seconds`` is the wall-clock
        duration of the model forward pass only.
    """
    sample_rate = 16_000
    waveform, _ = librosa.load(audio, sr=sample_rate, mono=True)

    # Switch tokenizer vocabulary and model adapter together so they agree.
    processor.tokenizer.set_target_lang(lang)
    model.load_adapter(lang)
    print(lang)

    features = processor(waveform, sampling_rate=sample_rate, return_tensors="pt")
    with torch.no_grad():
        started = time.time()
        logits = model(**features).logits
        finished = time.time()

    # Greedy decoding: highest-scoring class per frame, first batch item.
    token_ids = logits.argmax(dim=-1)[0]
    return processor.decode(token_ids), (finished - started)

In [None]:
import gradio as gr
# NOTE(review): this import rebinds transcribe, detect_language and transcribe_lang,
# so the functions defined in the cells above are replaced by whatever the `asr`
# module provides, and the notebook needs that module importable to run — confirm
# this is intended.
from asr import transcribe,detect_language,transcribe_lang

# One Interface per function. Each takes microphone audio delivered as a file path
# and shows two text outputs, matching the 2-tuples the notebook-defined functions return.
demo = gr.Interface(transcribe,
                   gr.Audio(source="microphone", type="filepath", label="Use mic"),
                   outputs=["text","text"])
demo2 = gr.Interface(detect_language,
                   gr.Audio(source="microphone", type="filepath", label="Use mic"),
                   outputs=["text","text"])
# Third interface adds a free-text input that is passed as the `lang` argument.
demo3 = gr.Interface(transcribe_lang,
                   inputs=[gr.Audio(source="microphone", type="filepath", label="Use mic"),"text"],
                   outputs=["text","text"])

# NOTE(review): the first tab is labelled "Transcribe by auto detecting language" but is
# wired to `transcribe`; the notebook-defined `transcribe` never calls detect_language —
# verify the label against the implementation actually imported.
tabbed_interface = gr.TabbedInterface([demo,demo2,demo3],["Transcribe by auto detecting language","Detect language","Transcribe by providing language"])

# Render the tabbed UI inside a Blocks app, enable queuing, then start the server.
# NOTE(review): `source=` on gr.Audio and `concurrency_count=` on queue() presumably
# depend on the installed Gradio major version — confirm before upgrading.
with gr.Blocks() as asr:
    tabbed_interface.render()
asr.queue(concurrency_count=3)
asr.launch()