In [None]:
from datasets import Dataset
import numpy as np
import torch
from lsync.config import TARGET_SR
import librosa
from transformers import Wav2Vec2Processor, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer

import evaluate

# WER metric used by compute_metrics during evaluation.
metrics = evaluate.load("wer")

# Backbone to fine-tune. Previously tried: "patrickvonplaten/wavlm-libri-clean-100h-base-plus".
MODEL_ID = "facebook/wav2vec2-large-960h-lv60-self"

# The feature extractor is configured for 16 kHz input and rejects audio passed with a
# different sampling_rate; audio_processing() resamples to TARGET_SR, so fail fast here
# rather than deep inside the dataset-map worker processes.
SAMPLING_RATE = 16000
assert TARGET_SR == SAMPLING_RATE, (
    f"TARGET_SR={TARGET_SR}, but the feature extractor expects {SAMPLING_RATE} Hz audio"
)

# Character-level CTC tokenizer built from the project's own vocabulary file.
tokenizer = Wav2Vec2CTCTokenizer("vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
# return_attention_mask=True: per the transformers docs, lv60-style checkpoints (like
# MODEL_ID above) should receive an attention mask for zero-padded batches, whereas
# wav2vec2-base-style checkpoints should not. Revisit this if MODEL_ID changes.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=SAMPLING_RATE,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

### Dataset

In [2]:
def audio_processing(fp):
    """Load an audio file at ``TARGET_SR`` and strip leading/trailing silence.

    Args:
        fp: Path of the audio file to load (anything ``librosa.load`` accepts).

    Returns:
        The resampled waveform with edge regions quieter than 30 dB below the
        peak removed (mono, per librosa's default).
    """
    waveform = librosa.load(fp, sr=TARGET_SR)[0]
    # trim() returns (signal, (start, end)); only the trimmed signal is needed.
    trimmed, _interval = librosa.effects.trim(waveform, top_db=30)
    return trimmed


def data_preprocess(data):
    """Turn one metadata row into model inputs and CTC label ids.

    Args:
        data: A single dataset example with a ``"path"`` column (audio file path)
            and a ``"text"`` column (transcript).

    Returns:
        dict with ``"input_values"`` (feature-extractor output for this one clip)
        and ``"labels"`` (tokenizer ids for the transcript).
    """
    audio = audio_processing(data["path"])
    result = {}
    # The processor returns a batch; index 0 is the single clip passed in.
    result["input_values"] = processor(
        audio, sampling_rate=TARGET_SR).input_values[0]
    # Tokenize through the ``text=`` keyword; the ``as_target_processor()``
    # context manager used previously is deprecated in transformers.
    result["labels"] = processor(text=data["text"]).input_ids
    return result


def prepare_dataset(batch):
    """Featurise one example whose ``"path"`` column already holds decoded audio.

    NOTE(review): unlike ``data_preprocess``, this expects ``batch["path"]`` to be a
    decoded-audio dict with ``"array"`` and ``"sampling_rate"`` keys (e.g. after
    casting the column to ``datasets.Audio``); the raw CSV column is a file path
    string. It is not called in the cells shown here.

    Args:
        batch: A single dataset example with ``"path"`` (decoded audio) and ``"text"``.

    Returns:
        The same example with ``"input_values"`` and ``"labels"`` added.
    """
    audio = batch["path"]

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(
        audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    # The transcript must go through ``text=``: a positional argument is treated as
    # *audio* by Wav2Vec2Processor and would never be tokenized.
    batch["labels"] = processor(text=batch["text"]).input_ids
    return batch

def get_datasets(
    dataset_csv_path="/home/kangyi/Lyrics-audio-Alignment/dataset/output-en/metadata new.csv",
    num_proc=8,
):
    """Load the lyrics metadata CSV and return preprocessed train/test splits.

    NOTE(review): a later cell redefines ``get_datasets`` (adding a minimum-length
    filter); when the notebook runs top to bottom, that later definition is the one
    ``get_datasets()`` calls resolve to.

    Args:
        dataset_csv_path: CSV with at least ``"path"`` and ``"text"`` columns.
        num_proc: Number of worker processes for ``Dataset.map``.

    Returns:
        ``(train_dataset, test_dataset)``: a 95 / 5 split (seed 41) whose rows
        carry ``input_values`` and ``labels`` from ``data_preprocess``.
    """
    dataset = Dataset.from_csv(dataset_csv_path)
    # Empty CSV cells come back as missing values (None) under pandas' default NA
    # handling, so ``!= ""`` alone lets them through; also drop whitespace-only text.
    dataset = dataset.filter(lambda x: bool(x["text"] and x["text"].strip()))
    dataset.cleanup_cache_files()
    splits = dataset.train_test_split(test_size=0.05, seed=41)

    train_dataset = splits["train"].map(data_preprocess, num_proc=num_proc)
    test_dataset = splits["test"].map(data_preprocess, num_proc=num_proc)
    return (train_dataset, test_dataset)

In [3]:
# Define a target mask_length (ensure this is smaller than your shortest valid sequence length)
MASK_LENGTH = 10  # Adjust this value as needed

def filter_short_sequences(batch):
    """
    Filter out audio sequences that are shorter than a specified mask length.
    """
    audio = audio_processing(batch["path"])  # Load audio
    input_values = processor(audio, sampling_rate=TARGET_SR).input_values[0]  # Process audio to get input values
    sequence_length = len(input_values)  # Get the sequence length

    # Keep sequences longer than the specified mask length
    return sequence_length > MASK_LENGTH

def get_datasets(
    dataset_csv_path="/home/kangyi/Lyrics-audio-Alignment/dataset/output-en/metadata new.csv",
    num_proc=8,
):
    """Load the lyrics metadata CSV and return preprocessed, length-filtered splits.

    Args:
        dataset_csv_path: CSV with at least ``"path"`` and ``"text"`` columns.
        num_proc: Number of worker processes for ``Dataset.map`` / ``Dataset.filter``.

    Returns:
        ``(train_dataset, test_dataset)``: a 95 / 5 split (seed 41) whose rows carry
        ``input_values`` and ``labels`` from ``data_preprocess``; rows rejected by
        ``filter_short_sequences`` are removed.
    """
    dataset = Dataset.from_csv(dataset_csv_path)

    # Remove empty transcripts. Empty CSV cells come back as missing values (None)
    # under pandas' default NA handling, so ``!= ""`` alone lets them through; also
    # drop whitespace-only text, which would yield empty CTC targets.
    dataset = dataset.filter(lambda x: bool(x["text"] and x["text"].strip()))

    # Clean up cache files
    dataset.cleanup_cache_files()

    # Split into train and test sets
    splits = dataset.train_test_split(test_size=0.05, seed=41)

    def _prepare(split):
        # Featurise every row, then drop clips too short for time masking.
        featurised = split.map(data_preprocess, num_proc=num_proc)
        return featurised.filter(filter_short_sequences, num_proc=num_proc)

    train_dataset = _prepare(splits["train"])
    test_dataset = _prepare(splits["test"])
    return train_dataset, test_dataset


In [None]:
# Build the preprocessed splits (decodes and featurises every clip, so this is slow).
(train_dataset, test_dataset) = get_datasets()

In [5]:
# Get sequence lengths
# sequence_lengths = []

# for batch in train_dataset:
#     sequence_lengths.append(len(batch["input_values"]))

# print(f"Minimum sequence length: {min(sequence_lengths)}")

### Train

In [6]:
from dataclasses import dataclass
from typing import Union,Optional, List, Dict
@dataclass
class DataCollatorCTCWithPadding:
    """Dynamically pad a list of examples into one CTC training batch.

    Audio ``input_values`` are padded via ``processor.pad``; label ids are padded
    with the processor's tokenizer, and padded label positions are replaced by
    ``-100`` so the CTC loss skips them.
    """

    # Forward reference: the processor supplies ``pad`` (audio) and ``tokenizer.pad`` (labels).
    processor: "Wav2Vec2Processor"
    # Padding strategy forwarded to both pad calls (True pads to the longest in the batch).
    padding: Union[bool, str] = True
    # Optional fixed lengths / length multiples for the audio side.
    max_length: Optional[int] = None
    # Optional fixed lengths / length multiples for the label side.
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``features`` (each with ``"input_values"`` and ``"labels"``) into tensors.

        Returns:
            The padded audio batch from ``processor.pad`` plus a ``"labels"`` tensor
            in which padding positions are ``-100``.
        """
        # Audio and labels have different lengths and padding rules, so pad them separately.
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Pad labels with the tokenizer directly; the ``as_target_processor()``
        # context manager used previously is deprecated in transformers.
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            max_length=self.max_length_labels,
            pad_to_multiple_of=self.pad_to_multiple_of_labels,
            return_tensors="pt",
        )

        # -100 marks padded label positions so they are excluded from the loss.
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100)

        return batch


def compute_metrics(pred):
    """Word error rate of greedy CTC decoding for one evaluation pass.

    Args:
        pred: Trainer evaluation output; ``pred.predictions`` holds per-frame logits
            (argmax'd over the last axis) and ``pred.label_ids`` the label ids, with
            ``-100`` at padded positions.

    Returns:
        ``{"wer": value}`` computed with the module-level ``metrics`` object.
    """
    # Greedy decoding: most likely token per frame; batch_decode then groups CTC repeats.
    pred_ids = np.argmax(pred.predictions, axis=-1)
    # Map the -100 ignore index back to the pad token so labels can be decoded, without
    # mutating the Trainer-owned label array in place.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)
    pred_str = processor.batch_decode(pred_ids)
    # group_tokens=False: references must keep genuine repeated characters.
    label_str = processor.batch_decode(label_ids, group_tokens=False)
    wer = metrics.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

In [None]:
# Train
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC
from transformers import TrainingArguments
from transformers import Trainer
# Train
# Pads each batch on the fly; padded label positions become -100.
data_collator = DataCollatorCTCWithPadding(
    processor=processor, padding=True)

# The CTC head must match the custom vocabulary in vocab.json, which need not match the
# checkpoint's own vocabulary: size it from the tokenizer and let transformers
# re-initialise the head if the sizes differ.
model = Wav2Vec2ForCTC.from_pretrained(
    MODEL_ID,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    ignore_mismatched_sizes=True,
)
# Keep the pretrained convolutional feature encoder fixed; the rest is fine-tuned.
model.freeze_feature_encoder()
torch.cuda.empty_cache()

training_args = TrainingArguments(
    output_dir="model960",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    fp16=True,
    save_steps=5000,
    eval_steps=1000,
    logging_steps=1000,
    learning_rate=5e-4,
    weight_decay=0.0001,
    warmup_steps=1000,
    save_total_limit=8,
    eval_strategy="steps",
    gradient_checkpointing=True,
    # load_best_model_at_end=True,
    # metric_for_best_model="eval_wer"
)

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    # Save the whole processor (feature extractor + tokenizer/vocab) with each checkpoint
    # so checkpoints can be reloaded for inference on their own.
    processing_class=processor,
    # Without this, evaluation reports only loss; compute_metrics adds WER.
    compute_metrics=compute_metrics,
)

trainer.train()