In [None]:
# %pip install jiwer

In [1]:
# Libararies
import os
import re
import csv
import torch
import evaluate
from datasets import load_dataset, Audio
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC
from transformers import TrainingArguments
from transformers import Trainer
import tqdm as notebook_tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
dataset_path = '/home/kangyi/Lyrics-audio-Alignment/dataset/songs_en'
matedata_path = '/home/kangyi/Lyrics-audio-Alignment/dataset/output-en/metadata.csv'
lang = "en-US"

PRE_PROCESSED = True
ORIGINAL_SR = 44100
TARGET_SR = 16000

### Dataset

In [2]:
###############################################################################
# Step 1: Load the CSV dataset
###############################################################################

dataset = load_dataset("csv", data_files=matedata_path)["train"]

###############################################################################
# Step 2: Split into train, validation, and test sets
###############################################################################
# Example: 80% train, 10% validation, 10% test

# First create test split (10% of total)
split_dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_val = split_dataset["train"]
test_dataset = split_dataset["test"]

# Now split the remaining 90% into 80% train and 10% val
split_train_val = train_val.train_test_split(test_size=0.1111, seed=42)  # 0.1111 of 90% ~ 10%
train_dataset = split_train_val["train"]
eval_dataset = split_train_val["test"]

print("Train samples:", len(train_dataset))
print("Validation samples:", len(eval_dataset))
print("Test samples:", len(test_dataset))

In [3]:
###############################################################################
# Step 3: Convert file_name column to Audio feature
###############################################################################
# This will decode and resample the audio at 16kHz on-the-fly
train_dataset = train_dataset.cast_column("file_name", Audio(sampling_rate=16000))
eval_dataset = eval_dataset.cast_column("file_name", Audio(sampling_rate=16000))
test_dataset = test_dataset.cast_column("file_name", Audio(sampling_rate=16000))


In [None]:
###############################################################################
# Step 4: Load a Wav2Vec2 Processor
###############################################################################
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

def normalize_text(text):
    # Lowercase
    if text == None:
        return ''
    text = text.lower()
    # Remove punctuation except for apostrophes needed for words like "don't"
    text = re.sub(r"[^a-zA-Z0-9\s']", "", text)
    text = text.strip()
    return text

In [4]:
###############################################################################
# Step 5: Preprocessing function
###############################################################################
def prepare_batch(batch):
    audio = batch["file_name"]
    text = batch["text"]
    
    # Normalize text if needed
    text = normalize_text(text)

    # Process audio
    inputs = processor(
        audio["array"], 
        sampling_rate=audio["sampling_rate"], 
        return_attention_mask=True,
        padding=True  # Add padding here
    )

    # Encode labels (text)
    with processor.as_target_processor():
        labels = processor.tokenizer(
        text, 
        padding=True,  # Padding for labels
        truncation=True  # Optional: truncate labels if needed
    ).input_ids

    batch["input_values"] = inputs["input_values"][0]
    batch["attention_mask"] = inputs["attention_mask"][0]
    batch["labels"] = labels
    print(batch)
    return batch

# Apply the preprocessing to datasets
train_dataset = train_dataset.map(prepare_batch, remove_columns=train_dataset.column_names)
eval_dataset = eval_dataset.map(prepare_batch, remove_columns=eval_dataset.column_names)
test_dataset = test_dataset.map(prepare_batch, remove_columns=test_dataset.column_names)

  normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
  ret = ret.dtype.type(ret / rcount)
  normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)
  normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
  ret = ret.dtype.type(ret / rcount)
  normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)
  normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
  ret = ret.dtype.type(ret / rcount)
  normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)
  normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_v

### Train

In [7]:
###############################################################################
# Step 6: Load the Pre-trained Model
###############################################################################
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

###############################################################################
# Step 7: Define Metrics (WER)
###############################################################################
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions.argmax(-1)
    # Decode predictions
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # Decode labels
    label_ids = pred.label_ids
    # Replace -100 with pad token
    label_ids = [[l for l in label if l != -100] for label in label_ids]
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

###############################################################################
# Step 8: Training Arguments
###############################################################################
training_args = TrainingArguments(
    output_dir="./wav2vec2-finetuned-asr",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    num_train_epochs=5,
    eval_strategy="steps", #evaluation_strategy
    logging_steps=100,
    save_steps=500,
    learning_rate=1e-4,
    warmup_steps=500,
    fp16=True,
    report_to="none"
)

###############################################################################
# Step 9: Initialize Trainer
###############################################################################
from transformers import DataCollatorWithPadding

# Create a data collator with padding
data_collator = DataCollatorWithPadding(
    tokenizer=processor.feature_extractor,  # Use the feature extractor as the tokenizer
    padding=True,  # Ensure padding
    return_tensors="pt",  # Use PyTorch tensors
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,  # Add the data collator
    compute_metrics=compute_metrics
)

###############################################################################
# Step 10: Train the Model
###############################################################################
trainer.train()

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h-lv60-self and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss
1000,4.9762,2.859481
2000,2.7984,2.561285
3000,2.5361,2.298822
4000,2.3667,2.132559
5000,2.2427,2.026144
6000,2.1165,1.931902
7000,2.0072,1.794211
8000,1.9446,1.695302
9000,1.8092,1.655811
10000,1.6702,1.580696




TrainOutput(global_step=12822, training_loss=2.2515898463395434, metrics={'train_runtime': 3560.4181, 'train_samples_per_second': 14.407, 'train_steps_per_second': 3.601, 'total_flos': 5.259924500788864e+18, 'train_loss': 2.2515898463395434, 'epoch': 2.999649081763949})