In [3]:
from datasets import load_dataset, load_metric

# Load the TIMIT ASR corpus (reused from the local Hugging Face cache when present).
timit = load_dataset("timit_asr")

# Show the available splits with their column names and row counts.
print(timit)

Reusing dataset timit_asr (/home/slzhang/.cache/huggingface/datasets/timit_asr/clean/2.0.1/66d672b7070257726bc9e76100d06d3e342aa6c37aed742803302c293540e96d)


DatasetDict({
    train: Dataset({
        features: ['file', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
        num_rows: 4620
    })
    test: Dataset({
        features: ['file', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
        num_rows: 1680
    })
})


In [4]:
# Keep only the audio path ("file") and the transcript ("text"); the metadata columns
# dropped here are not referenced by the later cells shown in this notebook.
timit = timit.remove_columns(["phonetic_detail", "word_detail", "dialect_region", "id", "sentence_type", "speaker_id"])

from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    """Render a random selection of distinct rows of ``dataset`` as an HTML table.

    Args:
        dataset: Object supporting ``len()`` and indexing with a list of row
            indices, returning something ``pd.DataFrame`` accepts.
        num_examples: How many distinct rows to show (default 10).

    Raises:
        AssertionError: If ``num_examples`` exceeds the number of rows.
        ValueError: If ``num_examples`` is negative (raised by ``random.sample``).
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."

    # random.sample draws without replacement, so the indices are unique by
    # construction; no retry loop is needed.
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

# Preview ten random raw transcripts; the "file" column is dropped so only the text is shown.
show_random_elements(timit["train"].remove_columns(["file"]))

Unnamed: 0,text
0,Movies never have enough villains.
1,We've got plenty of time to think about that.
2,Leave me your address.
3,"But briefly, the topping configuration must be examined for its inferences."
4,Naturally no woman can ever completely monopolize the sexual initiative.
5,Guerrillas were racing toward him.
6,She always jokes about too much garlic in his food.
7,Don't ask me to carry an oily rag like that.
8,"They were both walking towards each other, unhurried."
9,She had your dark suit in greasy wash water all year.


In [5]:
import re
# Punctuation stripped from every transcript. Apostrophes are not in the class,
# so contractions such as "it's" are kept.
# Raw string: the backslashes are regex escapes. In a normal string literal
# sequences like "\," are invalid Python escapes and trigger a
# DeprecationWarning/SyntaxWarning on recent interpreters.
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"]'

def remove_special_characters(batch):
    """Strip punctuation from ``batch["text"]`` and lowercase it.

    Args:
        batch: A single example (mapping) with a ``"text"`` string.

    Returns:
        The same mapping, with ``"text"`` replaced by the normalised transcript.
    """
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower()
    return batch

# Normalise the transcript of every example in the dataset.
timit = timit.map(remove_special_characters)

Loading cached processed dataset at /home/slzhang/.cache/huggingface/datasets/timit_asr/clean/2.0.1/66d672b7070257726bc9e76100d06d3e342aa6c37aed742803302c293540e96d/cache-b88dbc8ee61360f1.arrow
Loading cached processed dataset at /home/slzhang/.cache/huggingface/datasets/timit_asr/clean/2.0.1/66d672b7070257726bc9e76100d06d3e342aa6c37aed742803302c293540e96d/cache-957c31b128545f2d.arrow


In [6]:
# Preview ten random transcripts again to check the normalisation (lowercase, punctuation removed).
show_random_elements(timit["train"].remove_columns(["file"]))

Unnamed: 0,text
0,get a calico cat to keep
1,the irate actor stomped away idiotically
2,it's impossible to deal with bureaucracy
3,the news agency hired a great journalist
4,quince seed gum is the main ingredient in wavesetting lotions
5,birthday parties have cupcakes and ice cream
6,she had your dark suit in greasy wash water all year
7,we got drenched from the uninterrupted rain
8,don't ask me to carry an oily rag like that
9,they find deep pessimism in them


In [7]:
def extract_all_chars(batch):
  """Collect the distinct characters occurring in a batch of transcripts.

  Args:
      batch: Batched examples; ``batch["text"]`` is a list of strings.

  Returns:
      A dict with two single-element lists: ``"vocab"`` holds the list of
      distinct characters (in no particular order) and ``"all_text"`` holds
      the transcripts joined with single spaces.
  """
  joined = " ".join(batch["text"])
  distinct_chars = [*set(joined)]
  return dict(vocab=[distinct_chars], all_text=[joined])

# batch_size=-1 hands each split to extract_all_chars as one batch, so every
# split yields exactly one row holding its character set.
vocabs = timit.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=timit.column_names["train"])

# Sort the union of both splits' characters. Set iteration order for strings can
# change between interpreter runs, which would assign different ids to the same
# characters on every kernel restart and silently invalidate saved checkpoints.
vocab_list = sorted(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))

vocab_dict = {char: idx for idx, char in enumerate(vocab_list)}

# Represent the space as a visible word-delimiter token, keeping its id.
vocab_dict["|"] = vocab_dict.pop(" ")
# Special tokens take the next free ids.
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
print(len(vocab_dict))

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

30


In [8]:
import json
# Persist the character -> id mapping so the tokenizer can be built from a file.
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

from transformers import Wav2Vec2CTCTokenizer

# The special-token strings must match the keys that were added to vocab_dict.
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

In [9]:
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor

# Feature extractor for single-channel 16 kHz input that returns no attention mask.
# NOTE(review): these settings presumably need to match the pretrained checkpoint's
# preprocessing — confirm against the checkpoint's preprocessor config.
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
# Bundle audio preprocessing and text tokenisation behind one object.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [10]:
import soundfile as sf

def speech_file_to_array_fn(batch):
    """Decode one example's audio file and attach the waveform to the example.

    Args:
        batch: A single example with ``"file"`` (audio path) and ``"text"``.

    Returns:
        The same mapping with three keys added: ``"speech"`` (samples returned
        by ``sf.read``), ``"sampling_rate"`` and ``"target_text"`` (a copy of
        ``"text"``).
    """
    # sf.read yields (samples, sampling_rate); store both on the example.
    batch["speech"], batch["sampling_rate"] = sf.read(batch["file"])
    batch["target_text"] = batch["text"]
    return batch

# Decode every audio file using 4 worker processes. The original columns are dropped, so
# the result holds only `speech`, `sampling_rate` and `target_text`.
# NOTE(review): this rebinds `timit`; re-running the cell on the same kernel would
# presumably fail because the `file` column it reads is gone — rerun from the loading cell.
timit = timit.map(speech_file_to_array_fn, remove_columns=timit.column_names["train"], num_proc=4)

Loading cached processed dataset at /home/slzhang/.cache/huggingface/datasets/timit_asr/clean/2.0.1/66d672b7070257726bc9e76100d06d3e342aa6c37aed742803302c293540e96d/cache-74d3b8e6dc352907.arrow
Loading cached processed dataset at /home/slzhang/.cache/huggingface/datasets/timit_asr/clean/2.0.1/66d672b7070257726bc9e76100d06d3e342aa6c37aed742803302c293540e96d/cache-6c7d2671723cd6e0.arrow
Loading cached processed dataset at /home/slzhang/.cache/huggingface/datasets/timit_asr/clean/2.0.1/66d672b7070257726bc9e76100d06d3e342aa6c37aed742803302c293540e96d/cache-4e158e336671bd11.arrow
Loading cached processed dataset at /home/slzhang/.cache/huggingface/datasets/timit_asr/clean/2.0.1/66d672b7070257726bc9e76100d06d3e342aa6c37aed742803302c293540e96d/cache-cd59151092b0bcee.arrow
Loading cached processed dataset at /home/slzhang/.cache/huggingface/datasets/timit_asr/clean/2.0.1/66d672b7070257726bc9e76100d06d3e342aa6c37aed742803302c293540e96d/cache-2a192eaa3342b134.arrow
Loading cached processed datas

In [11]:
import IPython.display as ipd
import numpy as np
import random

# randrange excludes its upper bound, so the index is always valid.
# random.randint(0, n) is inclusive on both ends and could return n, which
# raises IndexError when used to index the split.
rand_int = random.randrange(len(timit["train"]))

# Listen to one random training utterance.
ipd.Audio(data=np.asarray(timit["train"][rand_int]["speech"]), autoplay=True, rate=16000)

In [12]:
# randrange excludes its upper bound, so the index is always valid.
# random.randint(0, n) is inclusive on both ends and could return n, which
# raises IndexError when used to index the split.
rand_int = random.randrange(len(timit["train"]))

# Fetch the row once instead of decoding it again for every print.
sample = timit["train"][rand_int]

print("Target text:", sample["target_text"])
print("Input array shape:", np.asarray(sample["speech"]).shape)
print("Sampling rate:", sample["sampling_rate"])

Target text: tim takes sheila to see movies twice a week
Input array shape: (53248,)
Sampling rate: 16000


In [13]:
def prepare_dataset(batch):
    """Turn a batch of raw audio and transcripts into model inputs and label ids.

    Args:
        batch: Batched examples with ``"speech"`` (list of waveforms),
            ``"sampling_rate"`` (list of ints) and ``"target_text"`` (list of
            strings).

    Returns:
        The same mapping with ``"input_values"`` and ``"labels"`` added.

    Raises:
        AssertionError: If any example's sampling rate differs from the rate the
            feature extractor is configured for.
    """
    expected_rate = processor.feature_extractor.sampling_rate

    # check that all files have the correct sampling rate.
    # Comparing against the extractor's rate (not just checking that the batch is
    # internally consistent) is what the assertion message promises.
    assert (
        set(batch["sampling_rate"]) == {expected_rate}
    ), f"Make sure all inputs have the same sampling rate of {expected_rate}."

    batch["input_values"] = processor(batch["speech"], sampling_rate=expected_rate).input_values

    # Inside this context the processor is used on the target text, producing label ids.
    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch

# Build model inputs and label ids, 8 examples per call on 4 worker processes. The raw
# columns are dropped so only `input_values` and `labels` remain.
timit_prepared = timit.map(prepare_dataset, remove_columns=timit.column_names["train"], batch_size=8, num_proc=4, batched=True)


  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)


In [14]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Collate prepared examples into one dynamically padded batch for CTC training.

    Audio inputs and label ids have different lengths, so they are padded
    separately, each with its own length settings.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Padding strategy applied to both inputs and labels:
            * :obj:`True` or :obj:`'longest'`: pad to the longest sequence in the batch (no padding when only a
              single sequence is provided).
            * :obj:`'max_length'`: pad to the length given by :obj:`max_length` / :obj:`max_length_labels`, or to
              the maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: no padding (the batch may hold sequences of different
              lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the returned ``input_values`` and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the returned ``labels`` and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, pad ``input_values`` to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            Same as :obj:`pad_to_multiple_of`, applied to ``labels``.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Inputs and labels differ in length and padding method, so split them
        # into two parallel lists before padding.
        audio_items = []
        label_items = []
        for example in features:
            audio_items.append({"input_values": example["input_values"]})
            label_items.append({"input_ids": example["labels"]})

        batch = self.processor.pad(
            audio_items,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Label ids are padded while the processor is in target mode.
        with self.processor.as_target_processor():
            padded_labels = self.processor.pad(
                label_items,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # Padded label positions become -100 so the loss ignores them.
        is_padding = padded_labels.attention_mask.ne(1)
        batch["labels"] = padded_labels["input_ids"].masked_fill(is_padding, -100)

        return batch


In [15]:
# Collator that pads every batch to its longest example (padding=True).
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
# Word error rate metric, used by compute_metrics below.
wer_metric = load_metric("wer")

def compute_metrics(pred):
    """Compute the word error rate for one evaluation pass.

    Args:
        pred: Object exposing ``predictions`` (logits whose last axis is the
            vocabulary) and ``label_ids`` (array in which ``-100`` marks
            positions to ignore).

    Returns:
        ``{"wer": <value returned by wer_metric.compute>}``.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Replace the -100 ignore marker with the pad token id so the labels can be
    # decoded. np.where builds a new array, leaving ``pred.label_ids`` untouched
    # instead of overwriting the caller's data in place.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [16]:
from transformers import Wav2Vec2Processor, Wav2Vec2CTCTokenizer
from transformers.modeling_outputs import CausalLMOutput
from wav2vec2_model import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2Model, Wav2Vec2PreTrainedModel, Wav2Vec2Config, Wav2Vec2Encoder, Wav2Vec2EncoderLayer
from wav2vec2_model import Wav2Vec2RawEncoder, Wav2Vec2_with_exit
import torch.nn as nn
import copy


In [23]:
# Plain CTC model loaded from the "facebook/wav2vec2-base" checkpoint; the pad token id is
# taken from the tokenizer built earlier in this notebook.
# model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base", 
    gradient_checkpointing=True, 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
)

# Variant from the local `wav2vec2_model` module, loaded with the same checkpoint and options.
# NOTE(review): `model_CTC` is not referenced again in the cells visible here — the Trainer
# is given `model` — so this second copy only occupies memory unless it is meant to be the
# one that gets trained. Confirm which model is intended.
model_CTC = Wav2Vec2_with_exit.from_pretrained(
    "facebook/wav2vec2-base", 
    gradient_checkpointing=True, 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
)
# model_CTC = Wav2Vec2_with_exit.from_pretrained("facebook/wav2vec2-base-960h")
# presumably attaches an exit branch at encoder layer 7 — confirm in wav2vec2_model
model_CTC.add_exit(7)
model_CTC.freeze_feature_extractor()

loading configuration file https://huggingface.co/facebook/wav2vec2-base/resolve/main/config.json from cache at /home/slzhang/.cache/huggingface/transformers/c7746642f045322fd01afa31271dd490e677ea11999e68660a92619ec7c892b4.02212753c42f07ecd65bbe35175ac4866badb735f9dae5bf2ae455c57db4dbb7
Model config Wav2Vec2Config {
  "activation_dropout": 0.0,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForPreTraining"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "codevector_dim": 256,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": false,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "mean",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": false,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_no

In [24]:
# model.freeze_feature_extractor()

from transformers import TrainingArguments, Trainer

import os

# Disable the Weights & Biases integration *before* TrainingArguments/Trainer are built.
# The captured output of this cell shows the `report_to` / WANDB_DISABLED messages being
# emitted here, i.e. the reporting integrations are resolved at construction time.
# Setting the variable only in the later training cell therefore works on a warm kernel
# but not after "Restart & Run All".
# NOTE(review): the library reports this variable as deprecated in favour of the
# `report_to` argument; switch to that once the set of wanted integrations is decided.
os.environ["WANDB_DISABLED"] = "true"

training_args = TrainingArguments(
  # output_dir="/content/gdrive/MyDrive/wav2vec2-base-timit-demo",
  output_dir="./wav2vec2-base-timit-demo",
  group_by_length=True,
  per_device_train_batch_size=8,
  evaluation_strategy="steps",
  num_train_epochs=30,
  fp16=True,
  save_steps=500,
  eval_steps=50,
  logging_steps=50,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  save_total_limit=2,
)

# The feature extractor is passed as `tokenizer`, as in the original cell.
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=timit_prepared["train"],
    eval_dataset=timit_prepared["test"],
    tokenizer=processor.feature_extractor,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
Using the `WAND_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using amp fp16 backend


In [25]:
import os
# NOTE(review): the Trainer has already been constructed by the time this line runs, and the
# captured output of the Trainer cell already mentions this variable. Setting it here looks
# too late to affect which reporting integrations are selected on a fresh kernel — confirm.
os.environ["WANDB_DISABLED"] = "true"
# Start fine-tuning. The captured run below appears to have been interrupted manually
# (KeyboardInterrupt) around the first evaluation at step 50.
trainer.train()

***** Running training *****
  Num examples = 4620
  Num Epochs = 30
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 4350


Step,Training Loss,Validation Loss,Wer
50,8.648,14.203657,1.0


***** Running Evaluation *****
  Num examples = 1680
  Batch size = 32


KeyboardInterrupt: 