In [7]:
import torchaudio
from datasets import load_dataset
import evaluate
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, TrainingArguments, Trainer
import torch
import numpy as np
# import os

In [8]:
# Load the Google Speech Commands dataset (v0.02) and make a reproducible split.
# `trust_remote_code=True` lets `datasets` execute the dataset's own loading script;
# only keep it for sources you trust.
SEED = 42  # fixes the train/validation split (and therefore the subsets below)
N_TRAIN_SAMPLES = 100  # small subset for quick experiments
N_VALIDATION_SAMPLES = 20

raw_dataset = load_dataset("speech_commands", "v0.02", split="train+test", trust_remote_code=True)

# Re-split the combined data 80/20 into "train" and "test" (the latter is used as validation).
# NOTE(review): this split is random per clip, so the same `speaker_id` can end up on
# both sides; consider a speaker-disjoint split if that matters for evaluation.
dataset = raw_dataset.train_test_split(test_size=0.2, seed=SEED)

# train_test_split shuffles by default, so the first N rows are a random sample.
# min(...) keeps select() from raising when a split is smaller than requested.
train_subset = dataset["train"].select(range(min(N_TRAIN_SAMPLES, len(dataset["train"]))))
validation_subset = dataset["test"].select(range(min(N_VALIDATION_SAMPLES, len(dataset["test"]))))

In [None]:
# Preprocessing: attach the waveform and the transcript text to every example.
#
# `datasets` already decodes the `audio` column (see the `train_subset[0]` output:
# it holds `array` and `sampling_rate`), so the .wav files are not re-opened here.
# Re-opening them via `torchaudio.load(batch["file"])` fails with
# "LibsndfileError: Error opening 'yes/1cb788bc_nohash_0.wav'" because `file` is a
# relative path (presumably relative to the extracted archive, not the working dir).


def speech_file_to_array_fn(batch, label_names=None):
    """Add ``speech`` and ``target_text`` to a single dataset example.

    Args:
        batch: one example (``Dataset.map`` is called without ``batched=True``)
            containing an ``audio`` dict with a decoded ``array`` and an integer
            ``label`` class id.
        label_names: optional sequence mapping class ids to words, e.g.
            ``dataset.features["label"].names``. When given, ``target_text`` is the
            upper-cased word, so the CTC tokenizer in ``prepare_dataset`` receives
            text rather than an integer. When omitted, the raw ``label`` value is
            copied unchanged.

    Returns:
        The same example with ``speech`` (the decoded waveform) and ``target_text``.
    """
    # Store the waveform under the key `prepare_dataset` reads it from.
    batch["speech"] = batch["audio"]["array"]
    label = batch["label"]
    # NOTE(review): upper-casing assumes the facebook/wav2vec2-base-960h vocabulary is
    # upper-case letters; confirm with processor.tokenizer.get_vocab(). Non-word
    # classes (e.g. a silence class, if present) may contain characters outside it.
    batch["target_text"] = label if label_names is None else label_names[label].upper()
    return batch


# ClassLabel names map each integer `label` to its spoken word.
label_names = train_subset.features["label"].names
train_subset = train_subset.map(speech_file_to_array_fn, fn_kwargs={"label_names": label_names})
validation_subset = validation_subset.map(speech_file_to_array_fn, fn_kwargs={"label_names": label_names})


Map:   0%|          | 0/100 [00:00<?, ? examples/s]

LibsndfileError: Error opening 'yes/1cb788bc_nohash_0.wav': System error.

In [None]:
# Load the pre-trained processor (feature extractor + tokenizer) and the CTC model.
MODEL_CHECKPOINT = "facebook/wav2vec2-base-960h"

processor = Wav2Vec2Processor.from_pretrained(MODEL_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(
    MODEL_CHECKPOINT,
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.1,
    mask_time_prob=0.05,
    layerdrop=0.1,
    ctc_loss_reduction="mean",
    # Wav2Vec2ForCTC uses pad_token_id as the CTC blank, so keep it in sync with the tokenizer.
    pad_token_id=processor.tokenizer.pad_token_id,
)
# The convolutional feature encoder is already well pre-trained. Freezing it keeps a
# tiny fine-tuning run (~100 examples) from disturbing it and saves compute/memory.
model.freeze_feature_encoder()

In [None]:
# Turn each example into model inputs (`input_values`) and CTC targets (`labels`).
def prepare_dataset(batch):
    """Encode one example's waveform and transcript for Wav2Vec2ForCTC.

    Reads ``batch["speech"]`` (waveform, assumed to be 16 kHz) and
    ``batch["target_text"]``. Writes ``input_values`` (the feature extractor output
    for this single example) and ``labels`` (the tokenizer's ``input_ids``).
    """
    extracted = processor(batch["speech"], sampling_rate=16000)
    # The processor returns a batch of one; keep that single row.
    batch["input_values"] = extracted.input_values[0]
    encoded = processor.tokenizer(batch["target_text"])
    batch["labels"] = encoded.input_ids
    return batch

def _encode_split(split):
    """Map `prepare_dataset` over `split`, keeping only the columns it adds."""
    return split.map(prepare_dataset, remove_columns=split.column_names)


train_subset = _encode_split(train_subset)
validation_subset = _encode_split(validation_subset)

In [None]:
# Data Collator
def data_collator(batch, input_padding_value=0.0, label_padding_value=-100):
    """Pad a list of encoded examples into a batch for Wav2Vec2ForCTC.

    Waveforms and label sequences are not guaranteed to share a length, and
    ``torch.tensor`` raises on ragged nested lists, so each field is padded on
    the right to the longest sequence in the batch.

    Args:
        batch: list of dicts with ``input_values`` (sequence of floats) and
            ``labels`` (sequence of token ids).
        input_padding_value: value used to pad waveforms.
        label_padding_value: value used to pad labels; Wav2Vec2ForCTC ignores
            negative label ids (conventionally -100) when computing CTC loss.

    Returns:
        ``{"input_values": float32 (batch, max_len), "labels": int64 (batch, max_label_len)}``.
    """
    input_values = torch.nn.utils.rnn.pad_sequence(
        [torch.as_tensor(item["input_values"], dtype=torch.float32) for item in batch],
        batch_first=True,
        padding_value=input_padding_value,
    )
    labels = torch.nn.utils.rnn.pad_sequence(
        [torch.as_tensor(item["labels"], dtype=torch.long) for item in batch],
        batch_first=True,
        padding_value=label_padding_value,
    )
    return {"input_values": input_values, "labels": labels}


In [None]:
# Define Metrics
wer_metric = evaluate.load("wer")


def compute_metrics(pred):
    """Compute word error rate (WER) for one Trainer evaluation pass.

    Args:
        pred: the prediction object the Trainer passes in; ``predictions`` holds the
            CTC logits and ``label_ids`` the reference token ids, padded with -100.

    Returns:
        ``{"wer": float}``.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)
    # -100 marks padding (from the collator / Trainer batch concatenation); map it to the
    # pad token so it decodes to nothing. np.where avoids mutating pred.label_ids in place.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)
    pred_str = processor.batch_decode(pred_ids)
    # References are exact token sequences, not per-frame CTC output: don't merge
    # repeated characters (otherwise e.g. "OFF" would decode as "OF").
    label_str = processor.batch_decode(label_ids, group_tokens=False)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

In [None]:
# Training configuration, grouped by purpose.
# Effective train batch = per-device batch (16) x gradient accumulation (2) = 32.
# NOTE(review): with the 100-example `train_subset`, max_steps=4000 amounts to on the
# order of a thousand passes over the data; confirm this is intended.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; update it if the installed version warns or rejects it.
training_args = TrainingArguments(
    output_dir="./wav2vec2-finetuned",
    # Optimisation schedule
    learning_rate=3e-4,
    warmup_steps=500,
    max_steps=4000,
    # Batching
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,
    # Evaluation, checkpointing and logging cadence (every 500 steps)
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=2,
    logging_steps=500,
    report_to="tensorboard",
    # fp16=True,
)

In [None]:
# Assemble the Trainer from the model, arguments, data pipeline and metric function.
# The feature extractor goes in as `tokenizer=` so the Trainer saves it with the model.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; switch if the installed version warns.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_subset,
    eval_dataset=validation_subset,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

In [None]:
# Start Training
trainer.train()

# Save the fine-tuned model (plus the feature extractor the Trainer was given as
# `tokenizer=`) and the full processor, so that both Wav2Vec2ForCTC.from_pretrained
# and Wav2Vec2Processor.from_pretrained can reload from the same directory.
# Reuse training_args.output_dir instead of repeating the path literal.
trainer.save_model(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)


In [11]:
# Inspect the training subset's columns and row count.
# NOTE(review): this cell ran (In [11]) after the preprocessing cell raised, so the
# saved output shows the raw columns. In a fresh Restart & Run All, `train_subset`
# has already been rewritten by the `.map(..., remove_columns=...)` calls above,
# so the displayed columns will differ.
train_subset

Dataset({
    features: ['file', 'audio', 'label', 'is_unknown', 'speaker_id', 'utterance_id'],
    num_rows: 100
})

In [12]:
# Peek at one raw example: `audio` is already decoded into `array` + `sampling_rate`,
# and `label` is an integer class id (its word is also the folder prefix of `file`).
train_subset[0]

{'file': 'yes/1cb788bc_nohash_0.wav',
 'audio': {'path': 'yes/1cb788bc_nohash_0.wav',
  'array': array([0.03964233, 0.03878784, 0.03799438, ..., 0.04742432, 0.04766846,
         0.04782104]),
  'sampling_rate': 16000},
 'label': 0,
 'is_unknown': False,
 'speaker_id': '1cb788bc',
 'utterance_id': 0}

In [13]:
# Peek at another raw example from a different class (`left`, label 4, in the saved output).
train_subset[50]

{'file': 'left/b308773d_nohash_2.wav',
 'audio': {'path': 'left/b308773d_nohash_2.wav',
  'array': array([0.00021362, 0.00021362, 0.00021362, ..., 0.0015564 , 0.00201416,
         0.00238037]),
  'sampling_rate': 16000},
 'label': 4,
 'is_unknown': False,
 'speaker_id': 'b308773d',
 'utterance_id': 2}