In [None]:
%%capture
!pip install transformers py3nvml jiwer datasets torchaudio tqdm 

In [3]:
from transformers import Wav2Vec2ForCTC,Wav2Vec2Processor

In [4]:
model = Wav2Vec2ForCTC.from_pretrained('OthmaneJ/distil-wav2vec2')
processor = Wav2Vec2Processor.from_pretrained('OthmaneJ/distil-wav2vec2')


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1592.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=354965335.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=215.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=291.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=505.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=85.0, style=ProgressStyle(description_w…




## Preprocessing pipeline

In [5]:
import torch
import torchaudio
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torch.nn.utils.rnn import pad_sequence

from tqdm import tqdm
from jiwer import wer
import numpy as np

def postprocess_features(feats, sample_rate):
    if feats.dim() == 2: feats = feats.mean(-1)
    assert feats.dim() == 1, feats.dim()
    with torch.no_grad():
        feats = F.layer_norm(feats, feats.shape)
    return feats

def get_feature(batch_sample):
    return postprocess_features(batch_sample[0][0], batch_sample[1])

def get_padding_mask(batch_sample):
    return torch.BoolTensor(batch_sample[0].size(1)).fill_(False)

def get_batch_encoder_input(batch_samples):
    ground_truth = [batch_sample[2] for batch_sample in batch_samples]
    features = [get_feature(batch_sample) for batch_sample in batch_samples]
    features = torch.nn.utils.rnn.pad_sequence(features, batch_first=True, padding_value=0)
    padding_masks = [get_padding_mask(batch_sample) for batch_sample in batch_samples]
    padding_masks = torch.nn.utils.rnn.pad_sequence(padding_masks, batch_first=True, padding_value=True)
    mask = False
    features_only = True
    return features, padding_masks, mask, features_only, ground_truth

## Test dataloader

In [17]:
val_data_path ="/content/sample_data"
test_dataset = torchaudio.datasets.LIBRISPEECH(val_data_path, url='test-clean', download=True)
val_batch_size = 10
num_workers = 2
test_dataloader = DataLoader(test_dataset,
                            batch_size = val_batch_size,
                            shuffle = False,
                            num_workers = num_workers,
                            collate_fn = get_batch_encoder_input,
                                )

HBox(children=(FloatProgress(value=0.0, max=346663984.0), HTML(value='')))




# Beam Search decoder

In [None]:
%cd /content 
!git clone --recursive https://github.com/parlance/ctcdecode.git  
!pip install /content/ctcdecode #takes about 5 minutes on google colab

In [None]:
%%capture
!wget https://www.openslr.org/resources/11/3-gram.pruned.3e-7.arpa.gz
!gunzip '/content/3-gram.pruned.3e-7.arpa.gz'

In [10]:
vocab_dict = processor.tokenizer.get_vocab()
sort_vocab = sorted((value, key) for (key,value) in vocab_dict.items())
vocab = [x[1].replace("|", " ") for x in sort_vocab]

In [11]:
from ctcdecode import CTCBeamDecoder


alpha = 1 # LM Weight
beta = 2 # LM Usage Reward , very important - never set to 0

decoder = CTCBeamDecoder(
    vocab,
    model_path='/content/3-gram.pruned.3e-7.arpa',
    alpha=alpha,
    beta=beta,
    cutoff_top_n=20,
    cutoff_prob=1.0,
    beam_width=64,
    num_processes=2,
    blank_id=0,
    log_probs_input=True,
)

In [15]:
from datasets import load_metric
wer_metric = load_metric('wer')

def inference_pipeline_lm(model, data_loader,decoder = None):

    predictions, ground_truths, wer_list = [], [], []
    itr = 0
    for batch in tqdm(data_loader):
        itr+=1

        with torch.no_grad():
            # tokenize
            input_values = processor.feature_extractor(batch[0], return_tensors="pt", padding="longest",sampling_rate=16000).input_values  
            input_values = input_values.cuda()

            
            # retrieve logits
            logits = model(input_values[0]).logits

            # take argmax and decode
            logits = logits.cpu()
            # return logits
             
            if decoder : 
                # beam search with n-gram language model 
                beam_results, beam_scores, timesteps, out_lens = decoder.decode(logits,)
                beam_results = beam_results[:,0,:] # taking only top beam
                beam_results = beam_results.clip(0,32)
                beam_results[beam_results==32] = 0
                transcription = processor.batch_decode(beam_results,
                                                        skip_special_tokens=True,
                                                        clean_up_tokenization_spaces=False,
                                                        group_tokens=False,
                                                        )
            else :
                # greedy search 
                predicted_ids = torch.argmax(logits, dim=-1)
                transcription = processor.tokenizer.batch_decode(predicted_ids)


            wer_validation = wer_metric.compute(predictions=transcription, references=batch[-1])
            wer_list.append(wer_validation)
        
    return wer_list 


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1947.0, style=ProgressStyle(description…




In [18]:
model.cuda()
model.eval()
wer_score = inference_pipeline_lm(model,test_dataloader,decoder)
print("WER: ",sum(wer_score)/len(wer_score))

  2%|▏         | 4/262 [00:03<03:15,  1.32it/s]

WER:  0.10673644537280902



