In this notebook, we use a pre-trained ASR model from Hugging Face to transcribe each of the audio chunks that were segmented and stored earlier.

In [None]:
!pip install git+https://github.com/huggingface/transformers.git
!pip install torchaudio

In [None]:
import json
import torch
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import torchaudio.transforms as transforms
import os

In [None]:
# Hugging Face Hub id of the checkpoint; the processor and the model share it.
model_checkpoint = "m3hrdadfi/wav2vec2-large-xlsr-persian"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The processor prepares raw audio for the model and decodes predicted ids to text.
processor = Wav2Vec2Processor.from_pretrained(model_checkpoint)

# Load the CTC model, switch it to inference mode and move it to the chosen device.
model = Wav2Vec2ForCTC.from_pretrained(model_checkpoint)
model.eval()
model = model.to(device)

In [None]:
prefix_url = 'https://traffic.libsyn.com/secure/radiomarz/'

with open('./persian_dataset_chunks.json', 'r') as inFile:
  dataset = json.load(inFile)

utterances =  list(dataset.keys())

for utterance in utterances:

    if not os.path.exists(utterance):
        !wget '{prefix_url}{utterance}'

    waveform, sr = torchaudio.load(utterance)

    sample_rate = 16000
    resample_transform = transforms.Resample(sr,sample_rate)
    waveform = resample_transform(waveform).squeeze().numpy()
    
    for i, chunk in enumerate(dataset[utterance]):
        
        start_time = chunk["start_time"]
        end_time = chunk["end_time"]
        

        #Skip if transcript already obtained
        if 'transcription' in chunk.keys():
            continue 
            
        start_sample = int(start_time * sample_rate)
        end_sample = int(end_time * sample_rate)
        audio_segment = waveform[start_sample:end_sample]

        input_values = processor(audio_segment, sampling_rate=sample_rate, return_tensors="pt").input_values
        input_values = input_values.to(device)

        with torch.no_grad():
            logits = model(input_values).logits

        predicted_ids = np.argmax(logits.cpu().detach().numpy(), axis=-1)
        transcription = processor.decode(predicted_ids[0])
        
        dataset[utterance][i]["transcription"] = transcription


In [None]:
# Persist the manifest, now carrying the transcriptions, as UTF-8 JSON.
# ensure_ascii=False keeps non-ASCII characters as-is in the file instead of
# escaping them as \uXXXX sequences.
with open("persian_dataset.json", mode="w", encoding='utf-8') as out_file:
    json.dump(
        dataset,
        out_file,
        indent=1,
        ensure_ascii=False,
    )