In [1]:
import pandas as pd
import torch
from transformers import (
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
    Wav2Vec2FeatureExtractor,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
import asr_utils
import asr_inference

In [None]:
# Input artefacts: the tokenizer vocabulary JSON and the CSV manifest of the dataset.
vocab, dsat = "./data/vocab.json", "./data/asr_dataset.csv"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
df = pd.read_csv(dsat)
# Normalise Windows-style path separators. regex=False requests a literal
# replacement explicitly: the pandas default for `regex` changed between
# major versions, and a lone backslash is not a valid regular expression.
df['audio_path'] = df['audio_path'].str.replace('\\', '/', regex=False)
# Presumably builds the vocabulary from the transcripts and stores it at `vocab`
# -- confirm in asr_utils.create_vocab.
vocab_dict = asr_utils.create_vocab(df, vocab)

In [3]:
model_name = "facebook/mms-1b-all"
# Encoder layers with an index below this value are frozen.
num_frozen_encoder_layers = 30

tokenizer = asr_utils.get_tokenizer(vocab)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

model = Wav2Vec2ForCTC.from_pretrained(model_name)

# Freeze the feature encoder through the public API rather than the private
# `_freeze_parameters()` hook on the inner module.
model.freeze_feature_encoder()

# Freeze the first `num_frozen_encoder_layers` encoder layers.
for layer in model.wav2vec2.encoder.layers[:num_frozen_encoder_layers]:
    layer.requires_grad_(False)

# Swap in a fresh output head sized to the custom tokenizer and keep the
# config in sync with it.
model.config.vocab_size = len(tokenizer)
model.lm_head = torch.nn.Linear(model.config.hidden_size, len(tokenizer))
model.config.ctc_loss_reduction = "mean"
model.config.pad_token_id = processor.tokenizer.pad_token_id

In [4]:
def prepare_dataset(batch):
    """Add model inputs and labels to a single dataset row.

    Reads ``batch["audio_path"]`` and ``batch["text"]``, runs the audio and
    the transcript through the module-level ``processor``, and stores the
    result under ``batch["input_values"]`` and ``batch["labels"]``.

    Returns the same ``batch`` mapping that was passed in.
    """
    waveform = asr_utils.load_audio(batch["audio_path"])
    transcript = batch["text"]
    encoded = processor(waveform, text=transcript, sampling_rate=16000)

    # Only the first entry of input_values is kept (presumably the processor
    # returns a batch of one here -- confirm); labels are stored as returned.
    batch["input_values"] = encoded.input_values[0]
    batch["labels"] = encoded.labels

    return batch

# Build the train/eval datasets and the collator from the manifest, using
# prepare_dataset as the per-row transform (split details live in asr_utils).
train_dataset, eval_dataset, data_collator = asr_utils.prepare_split(
    df,
    prepare_dataset,
    processor,
)

Map:   0%|          | 0/885 [00:00<?, ? examples/s]

In [5]:
# dir_to_save_checkpoints = "./trained/mms-1b-chukchi-frozen-finetuned"
# training_args = TrainingArguments(
#     output_dir=dir_to_save_checkpoints,
#     group_by_length=True,
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=8,
#     gradient_accumulation_steps=4,
#     eval_strategy="steps",
#     num_train_epochs=30,
#     fp16=True,
#     gradient_checkpointing=True,
#     save_steps=150,
#     eval_steps=150,
#     logging_steps=25,
#     learning_rate=5e-5,
#     weight_decay=0.01,
#     warmup_steps=200,
#     lr_scheduler_type="cosine",
#     save_total_limit=5,
#     dataloader_num_workers=2,
#     dataloader_pin_memory=True,
#     load_best_model_at_end=True,
#     metric_for_best_model="wer",
#     greater_is_better=False,
#     report_to="none",  # the string "none" disables logging integrations; Python None falls back to the library default
#     remove_unused_columns=False,
#     max_steps=10000,  # NOTE: a positive max_steps overrides num_train_epochs above
# )

# trainer = Trainer(
#     model=model,
#     data_collator=data_collator,
#     args=training_args,
#     compute_metrics=asr_utils.create_compute_metrics(processor),
#     train_dataset=train_dataset,
#     eval_dataset=eval_dataset,
#     processing_class=processor.feature_extractor,
#     callbacks=[EarlyStoppingCallback(early_stopping_patience=10, early_stopping_threshold=0.001)]
# )

# print(f"{100*sum(p.numel() for p in model.parameters() if p.requires_grad)/sum(p.numel() for p in model.parameters()):.1f}% trainable params")

In [6]:
# Directory of the final fine-tuned checkpoint. The (commented-out) training
# cell saves the model and processor here, and the evaluation cell passes it
# to asr_inference.evaluate_model.
dir_to_save_best = "./trained/mms-1b-chukchi-final"

In [7]:
# trainer.train()

# trainer.save_model(dir_to_save_best)
# processor.save_pretrained(dir_to_save_best)

In [9]:
# Evaluate the checkpoint stored in dir_to_save_best on the eval split.
# Presumably the metrics are also written to path_to_save_results -- confirm
# in asr_inference.evaluate_model.
path_to_save_results = "./results/mms-1b-results.txt"
asr_inference.evaluate_model(
    eval_dataset,
    dir_to_save_best,
    path_to_save_results,
)

WER: 0.841
CER: 0.190
