In [None]:
import numpy as np
import torch
from datasets import Dataset, Audio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Trainer, TrainingArguments, TrainerCallback
from dataclasses import dataclass
from typing import Dict, List, Union

# Load the pretrained Korean wav2vec2 (XLSR) checkpoint together with its processor.
MODEL_ID = "kresnik/wav2vec2-large-xlsr-korean"

processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(
    MODEL_ID,
    pad_token_id=processor.tokenizer.pad_token_id,  # model config shares the tokenizer's pad id
    ctc_loss_reduction="mean",  # mean-reduce the CTC loss
)
# The convolutional feature encoder starts frozen (UnfreezeFeatureEncoderCallback
# unfreezes it later in training).
model.freeze_feature_encoder()

In [None]:
# Dataset preparation: wrap the DataFrames and decode the "file_path" column as audio.
def prepare_dataset(df, sampling_rate=16000):
    """Wrap a pandas DataFrame as a ``datasets.Dataset`` with a decoded audio column.

    Args:
        df: DataFrame with a ``"file_path"`` column of audio file paths, plus any
            other columns needed downstream (e.g. ``"normalized_text"``).
        sampling_rate: rate the audio is resampled to when decoded. Defaults to
            16000.

    Returns:
        A ``Dataset`` whose ``"file_path"`` column is cast to ``Audio``, so each
        row yields a decoded audio dict (with an ``"array"`` entry) on access.
    """
    # preserve_index=False stops a filtered/shuffled DataFrame's index from being
    # stored as an extra "__index_level_0__" column.
    dataset = Dataset.from_pandas(df, preserve_index=False)
    return dataset.cast_column("file_path", Audio(sampling_rate=sampling_rate))


train_dataset = prepare_dataset(train_df)
valid_dataset = prepare_dataset(valid_df)

# Per-example preprocessing: decoded audio -> model inputs, transcript -> label ids.
def prepare_dataset_for_model(batch):
    """Turn one decoded example into ``input_values`` and ``labels``.

    Reads the decoded audio dict stored under ``"file_path"`` and the transcript
    under ``"normalized_text"``. Writes:
      * ``input_values``: processor output for the peak-normalized waveform at 16 kHz;
      * ``labels``: tokenizer ids for the transcript.

    Returns the same (mutated) ``batch`` mapping, as ``Dataset.map`` expects.
    """
    audio = batch["file_path"]
    array = audio["array"]
    # Peak-normalize to [-1, 1]; silent clips (peak 0) are left untouched to avoid
    # dividing by zero.
    # NOTE(review): if the feature extractor has do_normalize=True it re-normalizes
    # to zero mean / unit variance, which would make this scaling redundant.
    peak = np.max(np.abs(array))
    if peak > 0:
        array = array / peak
    batch["input_values"] = processor(array, sampling_rate=16000).input_values[0]
    # Tokenize the transcript with the tokenizer directly. This replaces the
    # deprecated `processor.as_target_processor()` context, which only routed the
    # call to the tokenizer.
    batch["labels"] = processor.tokenizer(batch["normalized_text"]).input_ids
    return batch


train_dataset = train_dataset.map(prepare_dataset_for_model, remove_columns=train_dataset.column_names)
valid_dataset = valid_dataset.map(prepare_dataset_for_model, remove_columns=valid_dataset.column_names)

# Data collator: dynamically pads audio inputs and label sequences per batch.
@dataclass
class DataCollatorCTCWithPadding:
    """Pad a list of pre-processed examples into a CTC training batch.

    ``input_values`` are padded by the processor's feature extractor and the
    label id lists by its tokenizer. The two are padded independently because
    audio and transcripts have unrelated lengths.

    Padded label positions are set to -100. ``Wav2Vec2ForCTC`` keeps only label
    ids >= 0 as CTC targets, and its pad id doubles as the CTC blank. If the pad
    id were left in place, it would be scored as part of the target sequence.
    ``compute_metrics`` maps -100 back to the pad id before decoding.
    """

    # Processor bundling the feature extractor (audio) and the tokenizer (labels).
    processor: Wav2Vec2Processor
    # Padding strategy forwarded to both pad calls (True = pad to the longest in the batch).
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` (dicts with ``input_values`` and ``labels``) into tensors."""
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(input_features, padding=self.padding, return_tensors="pt")
        labels_batch = self.processor.tokenizer.pad(label_features, padding=self.padding, return_tensors="pt")

        # Mask padded label positions with -100 so the CTC loss ignores them.
        labels = labels_batch["input_ids"].masked_fill(labels_batch["attention_mask"].ne(1), -100)
        batch["labels"] = labels
        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [None]:
# Evaluation metrics (WER / CER) and training callback
import evaluate

# Load the word- and character-error-rate metrics, in that order.
wer_metric, cer_metric = (evaluate.load(metric_name) for metric_name in ("wer", "cer"))

def compute_metrics(pred):
    """Compute word and character error rates for one evaluation pass.

    Args:
        pred: the Trainer's ``EvalPrediction``. ``predictions`` holds CTC logits
            (assumed (batch, frames, vocab): argmax is taken over the last axis)
            and ``label_ids`` holds label ids padded with -100.

    Returns:
        dict with ``"wer"`` and ``"cer"``.
    """
    pred_logits = pred.predictions
    # If the model returned extra outputs, the logits come first.
    if isinstance(pred_logits, tuple):
        pred_logits = pred_logits[0]
    pred_ids = np.argmax(pred_logits, axis=-1)

    # Map the -100 ignore index back to the pad token so the labels decode cleanly.
    # This works on a copy rather than mutating the Trainer's EvalPrediction in place.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    # group_tokens=False: references are plain token sequences, so repeated
    # characters must not be collapsed the way CTC output is.
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer, "cer": cer}

class UnfreezeFeatureEncoderCallback(TrainerCallback):
    """Unfreeze the wav2vec2 convolutional feature encoder once training reaches a given epoch.

    The encoder is frozen before the Trainer builds its optimizer, and the
    Trainer only hands trainable parameters to the optimizer. Setting
    ``requires_grad`` alone would compute gradients that are never applied. This
    callback therefore:
      * adds the unfrozen parameters to the optimizer as new param groups;
      * after every step, copies the scheduled learning rate onto those groups,
        since the LR scheduler only adjusts the groups that existed when it was
        created.
    It also empties the CUDA cache at the end of every epoch.

    NOTE(review): a checkpoint saved after the unfreeze has more optimizer param
    groups than a freshly built optimizer, so resuming from it is expected to
    fail when the optimizer state is loaded. Confirm before relying on resume.
    NOTE(review): under DistributedDataParallel, parameters that were frozen when
    the model was wrapped are presumably not gradient-synchronized after
    unfreezing. Verify before multi-GPU training.
    """

    def __init__(self, unfreeze_epoch: int = 5):
        """
        Args:
            unfreeze_epoch: epoch (0-based, as reported by ``state.epoch`` at the
                start of an epoch) at whose beginning the encoder is unfrozen.
        """
        super().__init__()
        self.unfreeze_epoch = unfreeze_epoch
        self._unfrozen = False  # one-shot guard: unfreeze at most once
        self._added_groups = []  # optimizer param groups created by this callback

    def on_epoch_begin(self, args, state, control, **kwargs):
        # `state.epoch` is a float that can land slightly below the integer boundary
        # (e.g. when an epoch's last batches are accumulated into the next optimizer
        # step). So round it instead of testing exact equality. `>=` also covers a
        # run resumed after the target epoch.
        if self._unfrozen or state.epoch is None or round(state.epoch) < self.unfreeze_epoch:
            return
        model = kwargs.get('model', None)
        if model is None:
            return
        encoder = model.wav2vec2.feature_extractor
        # Restore the flag that `_freeze_parameters()` cleared. Do not assign to
        # `_freeze_parameters` itself: it is a method.
        encoder._requires_grad = True
        for param in encoder.parameters():
            param.requires_grad = True
        self._register_with_optimizer(kwargs.get('optimizer', None), encoder, args.weight_decay)
        self._unfrozen = True
        print("\n특징 추출기(Feature Encoder)가 언프리즈 되었습니다!")

    def _register_with_optimizer(self, optimizer, encoder, weight_decay):
        """Add the encoder parameters the optimizer does not already hold as new param groups."""
        if optimizer is None:
            print("\nWarning: no optimizer available; unfrozen feature encoder parameters will not be updated.")
            return
        already_optimized = {id(p) for group in optimizer.param_groups for p in group["params"]}
        decay_params, no_decay_params = [], []
        for name, param in encoder.named_parameters():
            if id(param) in already_optimized:
                continue  # a parameter in two groups makes add_param_group raise
            # Usual convention: no weight decay on biases or normalization weights.
            if "bias" in name or "norm" in name:
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        current_lr = optimizer.param_groups[0]["lr"]
        for params, group_weight_decay in ((decay_params, weight_decay), (no_decay_params, 0.0)):
            if params:
                optimizer.add_param_group({"params": params, "lr": current_lr, "weight_decay": group_weight_decay})
                self._added_groups.append(optimizer.param_groups[-1])

    def on_step_end(self, args, state, control, **kwargs):
        # Keep the groups added at unfreeze time on the same schedule as the rest:
        # copy the lr the scheduler just set on the first group.
        optimizer = kwargs.get('optimizer', None)
        if optimizer is None or not self._added_groups:
            return
        scheduled_lr = optimizer.param_groups[0]["lr"]
        for group in self._added_groups:
            group["lr"] = scheduled_lr

    def on_epoch_end(self, args, state, control, **kwargs):
        # Release cached, unused GPU memory between epochs.
        torch.cuda.empty_cache()
        print(f"\n에폭 {state.epoch} 완료, GPU 캐시 정리됨")

In [None]:
# Training arguments and Trainer
training_args = TrainingArguments(
    output_dir="./wav2vec2-korean-asr",
    # Schedule and batching: 2 examples x 8 accumulation steps = 16 per optimizer step per device.
    num_train_epochs=30,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    per_device_eval_batch_size=2,
    group_by_length=True,
    fp16=True,
    # Optimizer / LR schedule.
    learning_rate=3e-4,
    warmup_steps=1000,
    weight_decay=0.005,
    # Evaluate and checkpoint every 500 steps; keep the checkpoint with the lowest WER.
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=data_collator,
    # NOTE(review): recent transformers versions deprecate `tokenizer=` in favour of
    # `processing_class=`; switch once the pinned version supports it.
    tokenizer=processor,
    compute_metrics=compute_metrics,
    callbacks=[UnfreezeFeatureEncoderCallback()],
)

In [None]:
# 학습 시작 (필요시 주석 해제)
# trainer.train()