<a href="https://colab.research.google.com/github/BhagatSurya/wav2vec_librispeech_fine_tuned/blob/main/wav2vec_librispeech_asr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**@Bhagat Surya**

In [None]:
# Install pinned versions of the Hugging Face libraries used throughout this
# notebook. jiwer is presumably required by the "wer" metric loaded further
# down -- TODO confirm.
!pip install datasets==1.18.3
!pip install transformers==4.17.0
!pip install jiwer

import random

import numpy as np
import pandas as pd
from datasets import ClassLabel
from IPython.display import display, HTML

from google.colab import drive
# Mount Google Drive at /content/gdrive/; the processor and the training
# output are written to a folder under this mount later in the file.
drive.mount('/content/gdrive/')


from datasets import load_dataset

# Load the "train.100" and "test" splits of the "clean" LibriSpeech configuration.
train =load_dataset("librispeech_asr","clean",split="train.100")
test = load_dataset("librispeech_asr","clean",split="test")

# Drop the metadata columns that nothing later in this file reads.
train= train.remove_columns(["speaker_id","chapter_id","id"])
test = test.remove_columns(["speaker_id","chapter_id","id"])

def extract_all_character(batch):
  """Collect the distinct characters that occur in a batch of transcripts.

  Args:
    batch: Mapping whose "text" entry is an iterable of transcript strings.

  Returns:
    A dict with two single-element lists: "vocabulary" holds the list of
    unique characters (in arbitrary order) and "text" holds all transcripts
    joined with single spaces.
  """
  joined_text = " ".join(batch["text"])
  unique_chars = [*set(joined_text)]
  # Both values are wrapped in a list so the result has exactly one row.
  return {"vocabulary": [unique_chars], "text": [joined_text]}


# Collect the character set of each split. extract_all_character is called
# directly on the "text" column, so only the transcripts are read and the
# audio column is never touched.
train_voab = extract_all_character({"text": train["text"]})
test_voab = extract_all_character({"text": test["text"]})

# Sort the merged character set so that token ids are identical on every run;
# plain set ordering of strings changes between Python sessions.
voab_list = sorted(set(train_voab["vocabulary"][0]) | set(test_voab["vocabulary"][0]))

vocabulary_dict = {char: index for index, char in enumerate(voab_list)}

# The tokenizer below uses "|" as its word delimiter, so the space character
# hands its id over to "|".
vocabulary_dict["|"] = vocabulary_dict.pop(" ")

# Append the special tokens. Their spelling must match the unk_token /
# pad_token strings given to the tokenizer ("[UNK]" and "[PAD]").
vocabulary_dict["[UNK]"] = len(vocabulary_dict)
vocabulary_dict["[PAD]"] = len(vocabulary_dict)

import json
with open("vocabulary.json", "w") as vocabulary_file:
  json.dump(vocabulary_dict, vocabulary_file)

# CTC tokenizer: built from the vocabulary file written above. The unknown,
# padding and word-delimiter tokens named here are expected to be keys of
# that file.
from transformers import  Wav2Vec2CTCTokenizer
tokenizer = Wav2Vec2CTCTokenizer("./vocabulary.json",unk_token="[UNK]",pad_token="[PAD]",word_delimiter_token="|")

# Feature extractor: configured for a sampling rate of 16000 with
# normalisation switched on and no attention mask returned.
from transformers import Wav2Vec2FeatureExtractor
feature_extraction = Wav2Vec2FeatureExtractor(feature_size=1,sampling_rate=16000,padding_value=0.0,do_normalize=True,return_attention_mask=False)

# The processor bundles the tokenizer and the feature extractor into one object.
from transformers import Wav2Vec2Processor
processor = Wav2Vec2Processor(tokenizer=tokenizer,feature_extractor=feature_extraction)

# Output folder on the mounted Drive; also passed as output_dir to
# TrainingArguments further down.
# NOTE(review): `dir` shadows the built-in of the same name; it is left as is
# because later statements read it. The folder name says "large-xlsr" while
# the checkpoint loaded below is "facebook/wav2vec2-base" -- confirm the
# intended name.
dir="/content/gdrive/MyDrive/wav2vec2-large-xlsr-English-librispeech_asr"

# Save the processor files into the output folder.
processor.save_pretrained(dir)


# Spot-check one random training example.
# random.randint includes both end points, so the upper bound has to be
# len(train) - 1 to stay inside the dataset.
rand_int = random.randint(0, len(train) - 1)
# Fetch the row once instead of once per print statement.
random_example = train[rand_int]
print("Target text:", random_example["text"])
# The "audio" entry is a dict; the waveform itself lives under its "array" key.
print("Input array shape:", np.asarray(random_example["audio"]["array"]).shape)
print("Sampling rate:", random_example["audio"]["sampling_rate"])


def prepare_dataset(batch):
    """Turn one raw example into model inputs and label ids.

    Reads the example's "audio" dict (waveform under "array" plus its
    "sampling_rate") and its "text" transcript, then stores the processed
    waveform under "input_values" and the encoded transcript under "labels".

    Args:
        batch: A single dataset row with "audio" and "text" entries.

    Returns:
        The same mapping with "input_values" and "labels" added.
    """
    waveform = batch["audio"]
    processed_audio = processor(
        waveform["array"],
        sampling_rate=waveform["sampling_rate"],
    )
    # The processor returns a batch; keep the only element.
    batch["input_values"] = processed_audio.input_values[0]

    # Inside this context the processor handles target text instead of audio.
    with processor.as_target_processor():
        encoded_text = processor(batch["text"])
        batch["labels"] = encoded_text.input_ids

    return batch

# Apply prepare_dataset to both splits using 4 worker processes. All original
# columns are removed, so only the entries added by prepare_dataset
# ("input_values" and "labels") remain.
train = train.map(prepare_dataset, remove_columns=train.column_names, num_proc=4)
test = test.map(prepare_dataset, remove_columns=test.column_names,num_proc=4)


import torch 
from dataclasses import dataclass,field
from typing import  Any, Dict, List, Optional, Union
@dataclass
class DatacollatorCTCwithPadding:
    """Collate preprocessed examples into one padded batch of tensors.

    Audio inputs and label ids are padded separately, each with its own
    length settings, through the wrapped processor.
    """

    # Processor whose ``pad`` method is used for both inputs and labels.
    processor: Wav2Vec2Processor
    # Padding strategy forwarded unchanged to ``processor.pad``.
    padding: Union[bool, str] = True
    # Optional length limit for the audio inputs.
    max_length: Optional[int] = None
    # Optional length limit for the label sequences.
    max_length_labels: Optional[int] = None
    # Optional multiple the padded audio length is rounded up to.
    pad_to_multiple_of: Optional[int] = None
    # Same as above for the labels (the spelling is part of the public interface).
    pad_to_multiple_of_lables: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``features`` and return the batch with a "labels" entry.

        Args:
            features: Examples that each carry "input_values" and "labels".

        Returns:
            The padded input batch, extended by a "labels" tensor in which
            every padded position is set to -100.
        """
        # Inputs and labels are padded by different components, so they are
        # separated into two lists first.
        audio_items = [{"input_values": item["input_values"]} for item in features]
        label_items = [{"input_ids": item["labels"]} for item in features]

        batch = self.processor.pad(
            audio_items,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Within this context the processor pads label ids rather than audio.
        with self.processor.as_target_processor():
            padded_labels = self.processor.pad(
                label_items,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_lables,
                return_tensors="pt",
            )

        # Positions whose attention mask is not 1 are padding; mark them with
        # -100 (presumably so that the loss skips them).
        is_padding = padded_labels.attention_mask.ne(1)
        batch["labels"] = padded_labels["input_ids"].masked_fill(is_padding, -100)

        return batch


# Collator instance that is handed to the Trainer further down.
data_collator = DatacollatorCTCwithPadding(processor=processor, padding=True)


from datasets.load import load_metric
# Word-error-rate metric, used by compute_metrics and by the final test evaluation.
wer_metric = load_metric("wer")



def compute_metrics(pred):
    """Compute the word error rate for one evaluation pass.

    Args:
        pred: Object exposing ``predictions`` (logits whose last axis ranges
            over the vocabulary) and ``label_ids`` (reference ids in which
            padded positions are marked with -100).

    Returns:
        A dict with the single key "wer".
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Build a new array instead of overwriting ``pred.label_ids`` in place, so
    # the caller's labels are left untouched. -100 is mapped back to the pad
    # token id so the sequences can be decoded.
    label_ids = np.where(
        pred.label_ids == -100,
        processor.tokenizer.pad_token_id,
        pred.label_ids,
    )

    pred_str = processor.batch_decode(pred_ids)
    # The references are decoded with token grouping switched off.
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}


# Model
from transformers import Wav2Vec2ForCTC

# Load the "facebook/wav2vec2-base" checkpoint into the CTC model class. The
# CTC loss is averaged (ctc_loss_reduction="mean") and the model's padding id
# is set to the tokenizer's pad token id.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base", 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
)



from transformers import TrainingArguments

# Training configuration. Output goes to the Drive folder defined earlier;
# evaluation, checkpointing and logging all happen every 500 steps, and
# save_total_limit=2 caps the number of checkpoints kept.
training_args = TrainingArguments(
  output_dir=dir,
  group_by_length=True,
  per_device_train_batch_size=32,
  evaluation_strategy="steps",
  num_train_epochs=30,
  # Mixed precision and gradient checkpointing are both switched on.
  fp16=True,
  gradient_checkpointing=True, 
  save_steps=500,
  eval_steps=500,
  logging_steps=500,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  save_total_limit=2,
)


from transformers import Trainer

# Wire the model, collator, metric function and both splits into the Trainer.
# The feature extractor is passed in the `tokenizer` slot -- presumably so
# that it is saved together with the checkpoints; TODO confirm.
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train,
    eval_dataset=test,
    tokenizer=processor.feature_extractor,
)


trainer.train()

# Write the final model into the root of the output folder. During training
# the Trainer only stores checkpoints in sub-folders, so without this call the
# from_pretrained() below would find the processor files but no model there.
trainer.save_model()

from transformers import AutoModelForCTC, Wav2Vec2Processor

# Reload the fine-tuned model (moved to the GPU) and its processor from the
# output folder for evaluation.
model = Wav2Vec2ForCTC.from_pretrained("/content/gdrive/MyDrive/wav2vec2-large-xlsr-English-librispeech_asr").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("/content/gdrive/MyDrive/wav2vec2-large-xlsr-English-librispeech_asr")

In [None]:
from numpy import number
def show_random_elemnts(dataset,elemnts_number=10):
  """Display a random selection of distinct rows of ``dataset`` as an HTML table.

  Args:
    dataset: Object supporting ``len()`` and indexing with a list of row
      numbers.
    elemnts_number: How many distinct rows to show; must not exceed
      ``len(dataset)``.

  Raises:
    AssertionError: If more rows are requested than the dataset holds.
    ValueError: If ``elemnts_number`` is negative.
  """
  assert elemnts_number <= len(dataset),"please enter valid"
  # random.sample draws without replacement, which replaces the former
  # retry-until-unique loop and its repeated list membership scans.
  sample = random.sample(range(len(dataset)), elemnts_number)

  df = pd.DataFrame(dataset[sample])
  display(HTML(df.to_html()))

# Preview random training rows without the "file" and "audio" columns.
# NOTE(review): the prepare_dataset map further up replaces every original
# column of `train`, so this call only works if it runs before that map --
# confirm the intended cell order.
show_random_elemnts(train.remove_columns(["file","audio"]))

def map_to_result(batch):
  """Transcribe one preprocessed example and decode its reference labels.

  Uses the module-level ``model`` and ``processor``.

  Args:
    batch: A single row holding "input_values" and "labels".

  Returns:
    The same mapping with "pred_str" (decoded prediction) and "text"
    (decoded reference) added.
  """
  # Add a leading batch dimension; the tensor is created on the GPU.
  model_input = torch.tensor(batch["input_values"], device="cuda").unsqueeze(0)

  # Inference only, so gradient tracking is disabled for the forward pass.
  with torch.no_grad():
    logits = model(model_input).logits

  # Pick the highest-scoring token id at every position.
  pred_ids = logits.argmax(dim=-1)
  decoded_predictions = processor.batch_decode(pred_ids)
  batch["pred_str"] = decoded_predictions[0]
  # The reference labels are decoded with token grouping switched off.
  batch["text"] = processor.decode(batch["labels"], group_tokens=False)

  return batch

# Transcribe the whole test split. The split's own columns are removed, which
# leaves the "pred_str" and "text" entries added by map_to_result.
results = test.map(map_to_result, remove_columns=test.column_names)
print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))


# Show a few prediction/reference pairs. No column named "speech" or
# "sampling_rate" is ever created in this file, so there is nothing to remove
# first (asking to remove them would fail).
show_random_elemnts(results)