In [69]:
import glob
import os
import random
from typing import Union, List

import librosa
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer,  Wav2Vec2FeatureExtractor
import yaml
import json
from tqdm import tqdm
import pickle
import datasets

In [2]:
data_path = "asr_data.json"
with open(data_path, "r") as f:
    asr_data = json.load(f)
random.shuffle(asr_data)
global train_data, valid_data
train_data = asr_data[:int(len(asr_data)*0.9)]
valid_data = asr_data[int(len(asr_data)*0.9):]

train = {}
train["file"]=[]
train["audio"]=[]
train["text"]=[]
for pair_data in tqdm(train_data):
    file, trans = pair_data
    train["file"].append(file)
    audio, _ = librosa.load(file, sr = 16000)
    train["audio"].append(audio)
    train["text"].append(trans)
train_set = datasets.Dataset.from_dict(train)

test = {}
test["file"]=[]
test["audio"]=[]
test["text"]=[]
for pair_data in tqdm(valid_data):
    file, trans = pair_data
    test["file"].append(file)
    audio, _ = librosa.load(file, sr = 16000)
    test["audio"].append(audio)
    test["text"].append(trans)
test_set = datasets.Dataset.from_dict(test)
food_dataset = {}
food_dataset["train"] = train_set
food_dataset["test"] = test_set

food_dataset = datasets.DatasetDict(food_dataset)
food_dataset.save_to_disk("food_dataset")

In [41]:
food_dataset = datasets.load_from_disk("food_dataset")



In [43]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

In [42]:
import re
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"]'

def remove_special_characters(batch):
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower() + " "
    return batch
food_dataset = food_dataset.map(remove_special_characters)

  0%|          | 0/3628 [00:00<?, ?ex/s]

  0%|          | 0/404 [00:00<?, ?ex/s]

In [45]:
show_random_elements(food_dataset["train"].remove_columns(["audio", "file"]))

Unnamed: 0,text
0,there is one egg in a blue background
1,two bananas in a white background
2,there are two kiwi fruits in a brown background
3,there are two cucumbers in a white background
4,there are two carrots in a brown background
5,three eggplants in a blue background
6,there is one banana in a blue background
7,there are two strawberries in a white background
8,there are two lemons in a white background
9,there are three white radishes in a blue background


In [46]:
def extract_all_chars(batch):
  all_text = " ".join(batch["text"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

In [52]:
vocabs = food_dataset.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=food_dataset.column_names["train"])

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [53]:
vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))

In [59]:
vocab_dict = {v: k for k, v in enumerate(vocab_list)}

In [61]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

In [63]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

25

In [64]:
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [67]:
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

In [70]:
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)

In [100]:
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [101]:
def prepare_dataset(batch):
    audio = batch["audio"]

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(audio, sampling_rate=16000).input_values[0]
    batch["input_length"] = len(batch["input_values"])
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["text"]).input_ids
    return batch

In [103]:
food_dataset = food_dataset.map(prepare_dataset, remove_columns=food_dataset.column_names["train"], num_proc=4)

        

#1:   0%|          | 0/907 [00:00<?, ?ex/s]

#3:   0%|          | 0/907 [00:00<?, ?ex/s]

#2:   0%|          | 0/907 [00:00<?, ?ex/s]

#0:   0%|          | 0/907 [00:00<?, ?ex/s]

        

#2:   0%|          | 0/101 [00:00<?, ?ex/s]

#0:   0%|          | 0/101 [00:00<?, ?ex/s]

#1:   0%|          | 0/101 [00:00<?, ?ex/s]

#3:   0%|          | 0/101 [00:00<?, ?ex/s]

In [104]:
max_input_length_in_sec = 6.0
food_dataset["train"] = food_dataset["train"].filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])

  0%|          | 0/4 [00:00<?, ?ba/s]

In [105]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [106]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [108]:
wer_metric = datasets.load_metric("wer")

  """Entry point for launching an IPython kernel.


In [109]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [110]:
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
)

  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForCTC: ['project_q.weight', 'quantizer.weight_proj.bias', 'project_hid.weight', 'project_hid.bias', 'quantizer.weight_proj.weight', 'project_q.bias', 'quantizer.codevectors']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.weight', 'lm

In [111]:
model.freeze_feature_encoder()

AttributeError: 'Wav2Vec2ForCTC' object has no attribute 'freeze_feature_encoder'

In [121]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./asr_model",
  group_by_length=True,
  per_device_train_batch_size=32,
  evaluation_strategy="steps",
  num_train_epochs=30,
  fp16=True,
  gradient_checkpointing=True,
  save_steps=500,
  eval_steps=500,
  logging_steps=500,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  save_total_limit=2,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [122]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=food_dataset["train"],
    eval_dataset=food_dataset["test"],
    tokenizer=processor.feature_extractor,
)

Using amp half precision backend


In [123]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running training *****
  Num examples = 3625
  Num Epochs = 30
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 3420


Step,Training Loss,Validation Loss,Wer
500,1.5999,0.332024,0.29304
1000,0.2055,0.069532,0.182418
1500,0.0924,0.034597,0.165934
2000,0.0523,0.030913,0.159341
2500,0.0335,0.018203,0.153114
3000,0.022,0.014778,0.147985


The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running Evaluation *****
  Num examples = 404
  Batch size = 8
Saving model checkpoint to ./asr_model/checkpoint-500
Configuration saved in ./asr_model/checkpoint-500/config.json
Model weights saved in ./asr_model/checkpoint-500/pytorch_model.bin
Configuration saved in ./asr_model/checkpoint-500/preprocessor_config.json
The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running Evaluation *****
  Num examples = 404
  Batch size = 8
Saving model checkpoint to ./asr_model/checkpoint-1000
Configuration saved in ./asr_model/checkpoint-1000/config.json
Model weights saved in ./asr_model/checkpoint-1000/pytorch_model.bin
Configuration saved in ./asr_model/checkpoint-1000/preprocessor_config.json
The following columns in the evaluation set  don

TrainOutput(global_step=3420, training_loss=0.29487965971405744, metrics={'train_runtime': 2493.3666, 'train_samples_per_second': 43.616, 'train_steps_per_second': 1.372, 'total_flos': 3.621757166136115e+18, 'train_loss': 0.29487965971405744, 'epoch': 30.0})