In [None]:
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments, Trainer
from datasets import load_from_disk

# Load the preprocessed dataset (previously written with `save_to_disk`).
# NOTE(review): each example is assumed to carry "input_values" (processor
# output for the audio) and "labels" (tokenizer ids for the transcript) —
# these are the inputs Wav2Vec2ForCTC.forward consumes; confirm against the
# preprocessing step.
dataset = load_from_disk("processed_audio_dataset")
# `load_from_disk` returns a DatasetDict (a dict subclass) when splits were
# saved; Trainer needs a single Dataset, so use the "train" split then.
train_dataset = dataset["train"] if isinstance(dataset, dict) else dataset

# Load the pretrained model and its processor (feature extractor + tokenizer).
model_name = "kresnik/wav2vec2-large-xlsr-korean"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

# Device setup (Trainer also handles placement; the print is for visibility).
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using device: {device}")
model.to(device)


class DataCollatorCTCWithPadding:
    """Pad audio inputs and CTC labels independently for one batch.

    Audio features and label sequences have unrelated lengths, so they are
    padded separately: ``input_values`` via the feature extractor and
    ``labels`` via the tokenizer.  Label padding is replaced with -100 so
    Wav2Vec2ForCTC's loss ignores those positions instead of treating the
    pad token as a real target.
    """

    def __init__(self, processor, padding=True):
        self.processor = processor
        self.padding = padding

    def __call__(self, features):
        input_features = [{"input_values": f["input_values"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.feature_extractor.pad(
            input_features, padding=self.padding, return_tensors="pt"
        )
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            return_attention_mask=True,
            return_tensors="pt",
        )
        # -100 marks positions the CTC loss must ignore.
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )
        return batch


data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

# Training configuration.
training_args = TrainingArguments(
    output_dir="./wav2vec2-korean-stt",
    group_by_length=True,
    per_device_train_batch_size=8,
    save_strategy="steps",
    save_steps=500,
    num_train_epochs=3,
    # fp16 mixed precision needs CUDA; enabling it on the CPU fallback fails.
    fp16=use_cuda,
    logging_steps=50,
    remove_unused_columns=False,
)

# Trainer setup: the explicit CTC collator replaces the default padding
# collator, which does not mask label padding with -100.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    tokenizer=processor,
)

# Start training.
trainer.train()

# Save the fine-tuned model (the processor passed as `tokenizer` is saved too).
trainer.save_model("./final-wav2vec2-stt")
