# Training notebook for Huggingface XLSR-Wav2Vec2 Model
## (https://huggingface.co/transformers/model_doc/wav2vec2.html)
## Steps
1. Data preprocessing and preparation
2. Dataset class and dataloader
3. Model preparation
4. training Loop
5. model evaluation with and without language model
### Note old script based on (https://huggingface.co/blog/fine-tune-xlsr-wav2vec2)

# Step 1 Dataprocessing
- Data is assumed to be in csv format with columns (path, sentence)
- we have from different sources (listed below)
- some preprocessing has already been done for data from different sources

In [1]:
from datasets import load_dataset, load_metric
import pandas as pd

#Set all sources of data
commonvoice = "data/commonvoice/train.csv"
singlespeaker = "data/singlespeaker/train.csv"
speechcollector = "data/speechcollector/train.csv"
voxpopuli = "data/fi/train.csv"
#eduskunta_1 = "data/eduskunnanpuheet/uudetpuheet/dev-eval/train.csv"
#eduskunta_2 = "data/eduskunnanpuheet/uudetpuheet/2008-2016set/train.csv"

test1 = "data/commonvoice/test.csv"
test2 = "data/eduskunnanpuheet/uudetpuheet/dev-eval/test.csv"

train_df = pd.concat([pd.read_csv(commonvoice), pd.read_csv(singlespeaker), pd.read_csv(speechcollector), pd.read_csv(voxpopuli)])
test_df = pd.concat([pd.read_csv(test1), pd.read_csv(test2)])
test_small_df = pd.read_csv(test1)

print(f"Training set contains {len(train_df)} Samples")
print(f"test set contains {len(test_df)} Samples")
train_df.head()

Training set contains 15920 Samples
test set contains 1976 Samples


Unnamed: 0,path,sentence
0,/home/sampo/.cache/huggingface/datasets/downlo...,Mitä nyt tekisimme?
1,/home/sampo/.cache/huggingface/datasets/downlo...,Äänestämme tämän vuoksi toisin kuin maataloude...
2,/home/sampo/.cache/huggingface/datasets/downlo...,"Rupeatko remmiin, vai et?"
3,/home/sampo/.cache/huggingface/datasets/downlo...,Äänestin näin ollen mietinnön puolesta.
4,/home/sampo/.cache/huggingface/datasets/downlo...,"Kiitos, että tulitte ja opetitte meille viisau..."


# Remove these specific characters and lower case transcriptions

In [2]:
import random
import pandas as pd
from IPython.display import display, HTML
import re
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\...\…\–\é]'

def custom_remove_special_characters(sent):
    sent = re.sub(chars_to_ignore_regex, '', sent).lower() + " "
    return sent

train_df['sentence'] = train_df['sentence'].apply(custom_remove_special_characters)
test_df['sentence'] = test_df['sentence'].apply(custom_remove_special_characters)
test_small_df['sentence'] = test_small_df['sentence'].apply(custom_remove_special_characters)

train_df.head()

Unnamed: 0,path,sentence
0,/home/sampo/.cache/huggingface/datasets/downlo...,mitä nyt tekisimme
1,/home/sampo/.cache/huggingface/datasets/downlo...,äänestämme tämän vuoksi toisin kuin maataloude...
2,/home/sampo/.cache/huggingface/datasets/downlo...,rupeatko remmiin vai et
3,/home/sampo/.cache/huggingface/datasets/downlo...,äänestin näin ollen mietinnön puolesta
4,/home/sampo/.cache/huggingface/datasets/downlo...,kiitos että tulitte ja opetitte meille viisaut...


# Create vocabulary of characters in the dataset
- (if there are characters you dont want revise the regex in the previous step)

In [3]:
import itertools

def get_chars(df):
    return set(itertools.chain(*[list(x) for x in df['sentence'].values]))

vocab_list = list(get_chars(train_df).union(get_chars(test_df)))
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
print(vocab_dict)

{'ä': 0, ' ': 1, 'u': 2, 'v': 3, 'c': 4, 'm': 5, 'p': 6, 'r': 7, 'e': 8, 'k': 9, 'z': 10, 'a': 11, 'b': 12, 't': 13, 'q': 14, 'd': 15, 'w': 16, 's': 17, 'g': 18, 'x': 19, 'l': 20, 'i': 21, 'ö': 22, 'f': 23, 'j': 24, 'å': 25, 'n': 26, 'o': 27, 'y': 28, 'h': 29}


# Add special tokens into the vocab and save

In [4]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

#for key in vocab_dict.keys():
#    if key != "[PAD]":
#        vocab_dict[key] +=1

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

# Create Hugginface Processor from vocab
- Notice that voxpopuli model assumes clips are sampled at 16000Hz
- used for preprocess, encode and decode inputs

In [5]:
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor

tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# PyTorch Dataset class and Dataloader
- default mode loads and resamples audio files on the fly to save RAM
- loading on the fly does not slow training much
- training samples are sorted according to transcription length to reduce infinities on CTC Loss
- If you dont sort the samples remember to change model flag ctc_zero_infinity to True
- collate function handles padding and batching
- audio files are very memory intensive peak VRAM comsumtion with batch_size = 4 is 18GB

In [11]:
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchaudio
import librosa

def resample(audio, source_sr, target_sr = 16000):
    audio = librosa.resample(np.asarray(audio), source_sr, target_sr)
    return audio


class CTCDataset(Dataset):
    """
    Dataset class used for Speech recognition with ctc loss
    enables precomputing data as arrays or transforming on the fly
    if dataset does not fit into ram
    """
    def __init__(self, dataframe, processor, mode="otf"):
        
        self.data = dataframe
        self.data.sort_values(by="sentence", key=lambda x: x.str.len(), inplace=True, ascending=False)
        self.processor = processor
        self.mode = mode
        if mode!="otf":
            raise NotImplemented
    
    def _processaudio(self, path):
        data, sr = torchaudio.load(path)
        data = data[0].numpy()
        data = resample(data, sr, 16000)
        
        return data
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):     
        if self.mode == 'otf':
            sent = self.data.iloc[idx, 1]
            data = self._processaudio(self.data.iloc[idx, 0])
            return data, sent
        
    def _precompute(self):
        pass
    
    def reorder_df(self):
        pass
        
    
def collate_fn_otf_train(batch):
    """
    collate function used for training and loading audio data on the fly
    """
    
    lists = list(zip(*batch))
    inputs = processor(lists[0], sampling_rate=16_000, return_tensors="pt", padding=True, pad_to_multiple_of=8)
    with processor.as_target_processor():
        labels = processor(lists[1], padding=True, return_tensors="pt", pad_to_multiple_of=8).input_ids
    return inputs.input_values, inputs.attention_mask, labels

def collate_fn_otf(batch):
    """
    collate function used for training and loading audio data on the fly
    """
    
    lists = list(zip(*batch))
    inputs = processor(lists[0], sampling_rate=16_000, return_tensors="pt", padding=True)
    with processor.as_target_processor():
        labels = processor(lists[1], padding=True, return_tensors="pt").input_ids
    return inputs.input_values, inputs.attention_mask, labels



trainset = CTCDataset(train_df, processor)
testset = CTCDataset(test_df, processor)
testset_small = CTCDataset(test_small_df, processor)

trainloader = DataLoader(trainset, batch_size = 4, collate_fn = collate_fn_otf_train, num_workers=8)
testloader = DataLoader(testset, batch_size=1, collate_fn = collate_fn_otf, num_workers=4)
testloader_small = DataLoader(testset_small, batch_size=1, collate_fn = collate_fn_otf, num_workers=4)

# Load pretrained model from huggingface
- currently using voxpopuli (https://github.com/facebookresearch/voxpopuli)

In [7]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-100k-voxpopuli",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    ctc_zero_infinity=False
)

#Freeze the weights of the pretrained feature extractor
model.freeze_feature_extractor()

Some weights of the model checkpoint at facebook/wav2vec2-large-100k-voxpopuli were not used when initializing Wav2Vec2ForCTC: ['quantizer.codevectors', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'project_q.weight', 'project_q.bias', 'project_hid.weight', 'project_hid.bias']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-100k-voxpopuli and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task t

# see documentation of model to set parameters

In [8]:
help(Wav2Vec2ForCTC)

Help on class Wav2Vec2ForCTC in module transformers.models.wav2vec2.modeling_wav2vec2:

class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel)
 |  Wav2Vec2ForCTC(config)
 |  
 |  Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). 
 |  Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
 |  <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
 |  
 |  This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
 |  methods the library implements for all its model (such as downloading or saving etc.).
 |  
 |  This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
 |  it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
 |  behavior.
 |  
 |  Parameters:
 |      config

# Define utility functions for training

In [9]:
from datasets import load_metric
from tqdm.notebook import tqdm


def decode_output(logits):
    pred_ids = torch.argmax(logits, dim=-1)
    pred = processor.batch_decode(pred_ids)
    return pred

@torch.no_grad()
def evaluation_func(model, dataloader, ref_sentences ,use_amp, device="cuda"):
    """
    return tuple (loss, wer)
    """
    wer = load_metric("wer")
    
    model.eval()
    preds_amp = []
    preds = []
    losses_amp = []
    losses = []
        
    for batch in tqdm(dataloader):
        inputs, masks, labels = batch
        
        output = model(inputs.to(device), masks.to(device), labels = labels.to(device))
        loss = output.loss.item()
        
        with torch.cuda.amp.autocast(enabled=use_amp):
            output_amp = model(inputs.to(device), masks.to(device), labels = labels.to(device))
            loss_amp = output_amp.loss.item()
            
        losses.append(loss)
        losses_amp.append(loss_amp)
        
        logits_amp = output_amp.logits
        logits = output.logits
        
        pred = decode_output(logits)
        pred_amp = decode_output(logits_amp)
        
        preds.extend(pred)
        preds_amp.extend(pred_amp)
        
    return sum(losses)/len(losses), sum(losses_amp)/len(losses_amp),wer.compute(predictions=preds, references=ref_sentences) ,wer.compute(predictions=preds_amp, references=ref_sentences) 
    
def checkpoint_func(model, save_dir):
    model.save_pretrained(save_directory=save_dir)
    return

def calculate_wer(preds, references):
    wer = load_metric("wer")
    return 100*wer.compute(predictions=preds, references=references)

def calculate_cer(preds, references):
    cer = load_metric("cer")
    return 100*cer.compute(predictions=preds, references=references)
    
#evaluation_func(model, testloader, testset.data.sentence)
#checkpoint_func(model, "testi/")

# set up parameters and training loop
- losses, step_size and WER logged into tensorboard
- note WER evaluation wer is obtained without language model
- one epoch takes considerably less than tqdm estimates at first because samples are sorted from longest to shortest
# Parameters to set
- use_amp (use mixed precision training)
- num_epochs
- lr (initial learning rate)
- step_interval (how many batches accumulated before gradient step)
- eval_interval (how often is model evaluated)
- save_dir (directoriy to save model)
- device (use cuda or no)

In [9]:
import transformers
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

save_dir = "test_run1"
device = "cuda"
losses = []
training_losses = []

model.to(device)

use_amp  = True
num_epochs = 40
lr = 0.00025
step_interval = 2
eval_interval = (len(trainloader)-1)/4.0
#eval_interval = 10
steps = 0

#setup optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
#scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=200, num_training_steps=len(trainloader)*num_epochs/step_interval)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 4, gamma=0.5)


scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

best_wer = 1.0
eval_losses = []
eval_wers = []
eval_step = 0

print("starting training loop")
for epoch in range(num_epochs):
    print(f"starting epoch: {epoch+1}")
    model.train()
    losses = []
    i = 0
    for batch in tqdm(trainloader):
        
        inputs, masks, labels = batch
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(inputs.to(device), masks.to(device), labels=labels.to(device))
            loss = output.loss/step_interval
        
        scaler.scale(loss).backward()
        #accumulate gradients for step_interval batches
        if (i+1)%step_interval == 0:
            #optimizer.step()
            scaler.step(optimizer)
            scaler.update()
            #scheduler.step()
            optimizer.zero_grad()
            steps+=1
        losses.append(output.loss.item())
        #if i%30==0:
        #    print(output.loss.item())
        #evaluate model and save best WER
        if (i+1)%eval_interval == 0:
            eval_loss,eval_loss_amp, wer, wer_amp = evaluation_func(model, testloader, testset.data.sentence, use_amp, device)
            
            writer.add_scalar('eval/loss', eval_loss, eval_step)
            writer.add_scalar('eval/wer', wer, eval_step)
            writer.add_scalar('eval/loss_amp', eval_loss_amp, eval_step)
            writer.add_scalar('eval/wer_amp', wer_amp, eval_step)
            writer.add_scalar('lr', scheduler.get_last_lr()[0], eval_step)
            eval_losses.append(eval_loss)
            eval_wers.append(wer)
            eval_step +=1
            if wer < best_wer:
                #save model with best test WER
                checkpoint_func(model, save_dir)
                best_wer = wer
            else:
                scheduler.step()
            model.train()
            
        i+=1
    #end of epoch
    epoch_loss = sum(losses)/len(losses)
    training_losses.append(epoch_loss)
    writer.add_scalar('train/epoch_loss', epoch_loss, epoch)
#check for final improvements
eval_loss,eval_loss_amp, wer, wer_amp = evaluation_func(model, testloader, testset.data.sentence, use_amp, device)
if wer < best_wer:
    best_wer = wer
    checkpoint_func(model, save_dir)
print(f"training finished, best WER: {best_wer}")
writer.close()

starting training loop
starting epoch: 1


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 2


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 3


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 4


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 5


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 6


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 7


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 8


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 9


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 10


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 11


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 12


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 13


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 14


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 15


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 16


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 17


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 18


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 19


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 20


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 21


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 22


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 23


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 24


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 25


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 26


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 27


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 28


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 29


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 30


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 31


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 32


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 33


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 34


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 35


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 36


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 37


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 38


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 39


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))



starting epoch: 40


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))





HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))


training finished, best WER: 0.2428591480103867


# Evaluation on models
- SpeechRecognizer class takes directory of model files as input 
- evaluated with greedy decoding and with CTCBeamsearch using a Ngram language model

In [7]:
from SpeechRecognizer import SpeechRecognizer, CTCDecoder

recognizer = SpeechRecognizer("best_model/")
labels, blank = recognizer.get_labels()
lm_path = "best_model/model2.bin"
decoder = CTCDecoder(labels, lm_path=lm_path, alpha=1.5, beta=0.8, blank_id=blank, beam_width=256, cutoff_top_n=15)

Initializing Decoder
Decoder ready


In [12]:
wer = load_metric("wer")
cer = load_metric("cer")
device = "cuda"

preds = []
LMpreds = []

with torch.no_grad():
    for batch in tqdm(testloader):
        inputs, masks, labels = batch
        logits = recognizer.model(inputs.to(device), masks.to(device), ).logits
        
        #calculate no lm preds
        pred_ids = torch.argmax(logits, dim=-1)
        pred = recognizer.processor.batch_decode(pred_ids)
        preds.append(pred[0])
        
        probs = logits.softmax(dim=2).cpu()
        text = decoder.decode(probs)
        LMpreds.append(text)
    
    lmwer = calculate_wer(LMpreds, testset.data.sentence)
    wer_c = calculate_wer(preds, testset.data.sentence)
    lmcer = calculate_cer(LMpreds, testset.data.sentence)
    cer_c = calculate_cer(preds, testset.data.sentence)
    print("stats for full testset")
    print(f"lm stats:{lmwer} {lmcer},      no-lm stats:{wer_c} {cer_c}")
    
    #calculate stats for small test set
    LMpreds = []
    preds = []
    for batch in tqdm(testloader_small):
        inputs, masks, labels = batch
        logits = recognizer.model(inputs.to(device), masks.to(device)).logits

        #calculate no lm preds
        pred_ids = torch.argmax(logits, dim=-1)
        pred = recognizer.processor.batch_decode(pred_ids)
        preds.append(pred[0])

        probs = logits.softmax(dim=2).cpu()
        text = decoder.decode(probs)
        LMpreds.append(text)
    
    lmwer = calculate_wer(LMpreds, testset_small.data.sentence)
    wer_c = calculate_wer(preds, testset_small.data.sentence)
    lmcer = calculate_cer(LMpreds, testset_small.data.sentence)
    cer_c = calculate_cer(preds, testset_small.data.sentence)
    print("stats for small testset")
    print(f"lm stats:{lmwer} {lmcer},      no-lm stats:{wer_c} {cer_c}")

Couldn't find file locally at cer/cer.py, or remotely at https://raw.githubusercontent.com/huggingface/datasets/1.5.0/metrics/cer/cer.py.
The file was picked from the master branch on github instead at https://raw.githubusercontent.com/huggingface/datasets/master/metrics/cer/cer.py.


HBox(children=(FloatProgress(value=0.0, max=1976.0), HTML(value='')))




Couldn't find file locally at cer/cer.py, or remotely at https://raw.githubusercontent.com/huggingface/datasets/1.5.0/metrics/cer/cer.py.
The file was picked from the master branch on github instead at https://raw.githubusercontent.com/huggingface/datasets/master/metrics/cer/cer.py.
Couldn't find file locally at cer/cer.py, or remotely at https://raw.githubusercontent.com/huggingface/datasets/1.5.0/metrics/cer/cer.py.
The file was picked from the master branch on github instead at https://raw.githubusercontent.com/huggingface/datasets/master/metrics/cer/cer.py.


stats for full testset
lm stats:17.752122955996914 6.321402692746593,      no-lm stats:22.27524738578146 5.96821292049759


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




Couldn't find file locally at cer/cer.py, or remotely at https://raw.githubusercontent.com/huggingface/datasets/1.5.0/metrics/cer/cer.py.
The file was picked from the master branch on github instead at https://raw.githubusercontent.com/huggingface/datasets/master/metrics/cer/cer.py.
Couldn't find file locally at cer/cer.py, or remotely at https://raw.githubusercontent.com/huggingface/datasets/1.5.0/metrics/cer/cer.py.
The file was picked from the master branch on github instead at https://raw.githubusercontent.com/huggingface/datasets/master/metrics/cer/cer.py.


stats for small testset
lm stats:8.896525391370751 1.9373555588336528,      no-lm stats:15.54028255059183 2.6995132025372475
