## Импорт основных библиотек и вспомогательных модулей

In [1]:
from pathlib import Path

import pandas as pd
import torch
from transformers import (
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
    Wav2Vec2FeatureExtractor,
    TrainingArguments,
    Trainer
)

import asr_utils
import asr_inference

## Определение путей к файлам словаря и датасета, загрузка датасета и построение словаря

In [2]:
# Paths to the vocabulary file (created below) and to the dataset CSV.
vocab, dsat = "./vocab.json", "./asr_dataset.csv"

# NOTE(review): `device` is not referenced by any cell shown in this notebook;
# the model is never moved to it explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

df = pd.read_csv(dsat)
# Normalise Windows-style separators to '/' so the audio paths resolve on POSIX too.
# regex=False: treat the backslash as a literal character instead of relying on
# pandas' version-dependent default for `regex` and single-character patterns.
df['audio_path'] = df['audio_path'].str.replace('\\', '/', regex=False)
# Presumably builds the vocabulary from `df` and writes it to the `vocab` path;
# the next cell loads the tokenizer from that same path.
vocab_dict = asr_utils.create_vocab(df, vocab)

## Инициализация токенизатора, процессора и модели, замена CTC-головы под новый словарь

In [3]:
model_name = "facebook/wav2vec2-large-960h"
# Seed for the freshly initialised CTC head, so re-running this cell gives the same weights.
LM_HEAD_SEED = 42

# Tokenizer built from the vocabulary file written in the previous cell.
tokenizer = asr_utils.get_tokenizer(vocab)
# Keep the pretrained checkpoint's feature extractor but pair it with the new tokenizer.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

model = Wav2Vec2ForCTC.from_pretrained(model_name)

# The checkpoint's CTC head predicts its original (English) vocabulary; replace it
# with a new head sized for our tokenizer. The input width is taken from the existing
# head rather than `config.hidden_size`: the two differ when the config uses an
# adapter (the head then consumes `output_hidden_size` features).
model.config.vocab_size = len(tokenizer)
torch.manual_seed(LM_HEAD_SEED)
model.lm_head = torch.nn.Linear(model.lm_head.in_features, len(tokenizer))
model.config.ctc_loss_reduction = "mean"
# Wav2Vec2ForCTC uses config.pad_token_id as the CTC blank label.
model.config.pad_token_id = processor.tokenizer.pad_token_id

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Подготовка функции преобразования примеров и разбиение датасета на обучающую и валидационную выборки

In [4]:
def prepare_dataset(batch):
    """Turn one dataset row into model inputs and CTC targets.

    Loads the clip referenced by ``batch["audio_path"]`` and passes it, together
    with the transcript ``batch["text"]``, through the module-level ``processor``.
    The results are written back onto the same row object.

    Args:
        batch: a single example (mapping) with ``"audio_path"`` and ``"text"`` keys.

    Returns:
        The same mapping, now also holding ``"input_values"`` (the features of the
        clip) and ``"labels"`` (the processor's label ids for the transcript).
    """
    waveform = asr_utils.load_audio(batch["audio_path"])
    # The clip is declared as 16 kHz — assumes load_audio returns 16 kHz audio; TODO confirm.
    features = processor(waveform, text=batch["text"], sampling_rate=16000)

    # A single clip comes back as a batch of one, so unwrap the first entry.
    clip_values = features.input_values[0]
    batch["input_values"] = clip_values
    transcript_ids = features.labels
    batch["labels"] = transcript_ids
    return batch

# Split `df` into train/eval sets, presumably mapping every row through
# `prepare_dataset` (see the Map progress output below). Also get the collator
# that the Trainer uses to batch examples.
(
    train_dataset,
    eval_dataset,
    data_collator,
) = asr_utils.prepare_split(df, prepare_dataset, processor)

Map:   0%|          | 0/885 [00:00<?, ? examples/s]

## Параметры обучения и Trainer (раскомментируйте, чтобы запустить обучение модели)

In [None]:
# Training configuration (commented out; uncomment to train).
# dir_to_save_checkpoints = "./trained/wav2vec2-960h-chukchi-finetuned"
# training_args = TrainingArguments(
#     output_dir=dir_to_save_checkpoints,
#     group_by_length=True,
#     per_device_train_batch_size=12,
#     per_device_eval_batch_size=12,
#     gradient_accumulation_steps=1,
#     eval_strategy="steps",
#     num_train_epochs=30,  # has no effect: max_steps below takes precedence
#     fp16=torch.cuda.is_available(),  # fp16 mixed precision requires a GPU
#     gradient_checkpointing=True,
#     save_steps=100,  # must be a multiple of eval_steps for load_best_model_at_end
#     eval_steps=100,
#     logging_steps=25,
#     learning_rate=1e-4,
#     warmup_steps=200,
#     weight_decay=0.01,
#     lr_scheduler_type="cosine_with_restarts",
#     save_total_limit=3,
#     dataloader_num_workers=2,
#     dataloader_pin_memory=True,
#     load_best_model_at_end=True,
#     metric_for_best_model="eval_loss",
#     greater_is_better=False,
#     report_to="none",  # the string "none" disables integrations; None means "all" in transformers 4.x
#     max_steps=4000,
# )

# Only the feature extractor is passed as processing_class; the full processor
# (with the tokenizer) is saved separately after training.
# trainer = Trainer(
#     model=model,
#     data_collator=data_collator,
#     args=training_args,
#     compute_metrics=asr_utils.create_compute_metrics(processor),
#     train_dataset=train_dataset,
#     eval_dataset=eval_dataset,
#     processing_class=processor.feature_extractor,
# )

## Папка назначения для финальной сохраненной модели

In [6]:
# Destination for the final model: the (commented-out) training cell saves the model
# and processor here, and this path is passed to the evaluation cell.
dir_to_save_best = "./trained/wav2vec2-960h-chukchi-final"

## Обучение и сохранение результатов (раскомментируйте, чтобы запустить обучение модели)

In [None]:
# Uncomment to train. With load_best_model_at_end=True in the training arguments,
# the model held by `trainer` after train() is the best checkpoint by eval_loss.
# trainer.train()

# Save the model and the full processor (feature extractor + tokenizer) together,
# so the directory can be loaded back for inference.
# trainer.save_model(dir_to_save_best)
# processor.save_pretrained(dir_to_save_best)

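## Оценка модели на валидационной выборке (расчёт WER/CER)

Примечание: в сохранённом выводе ниже WER ≈ 1.0, а CER ≈ 2.2. CER больше 1 означает, что число символьных ошибок (включая вставки) превышает суммарную длину эталонных транскрипций. Иными словами, гипотезы модели практически не совпадают с эталоном, и такой результат стоит перепроверить, прежде чем делать выводы.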

In [None]:
path_to_save_results = "./results/wav2vec2-results.txt"

# The training cells are commented out, so the trained model must already exist on
# disk. Fail with an actionable message instead of whatever error the loader raises.
if not Path(dir_to_save_best).is_dir():
    raise FileNotFoundError(
        f"Trained model not found at {dir_to_save_best!r}; "
        "uncomment and run the training cells first."
    )
# Ensure the results directory exists; evaluate_model presumably writes the metrics there.
Path(path_to_save_results).parent.mkdir(parents=True, exist_ok=True)

asr_inference.evaluate_model(eval_dataset, dir_to_save_best, path_to_save_results)

WER: 1.006
CER: 2.192
