In [None]:
# Load the training dataset in streaming mode
from datasets import load_dataset, Audio

# Alternative (currently disabled): stream two datasets and merge them with
# interleave_datasets. Re-enabling it also requires importing
# interleave_datasets from `datasets`.
#ds1 = load_dataset("jwh1449/AIhub_foreign_dataset", split="train", streaming=True)
#ds2 = load_dataset("jwh1449/AIhub_foreign_dataset3", split="train", streaming=True)
#dataset = interleave_datasets([ds1, ds2])

# Stream the single active dataset and cast its "audio" column in the same
# expression so the sampling rate is unified at 16000 Hz.
dataset = load_dataset(
    "jwh1449/AIhub_foreign_dataset3",
    split="train",
    streaming=True,
).cast_column("audio", Audio(sampling_rate=16000))

In [None]:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from peft import LoraConfig, get_peft_model

# Checkpoint to fine-tune.
model_id = "ghost613/whisper-large-v3-turbo-korean"

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the speech seq2seq checkpoint in float32, then move it onto `device`.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
model = model.to(device)

# LoRA settings: adapters are attached to the modules named q_proj and v_proj.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# The processor is loaded from the same checkpoint id as the model.
processor = AutoProcessor.from_pretrained(model_id)

2025-05-11 10:40:11.802530: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-11 10:40:11.846400: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# Dataset preprocessing: convert one raw example into Whisper model inputs
def prepare_features(example):
    """Add Whisper ``input_features`` and ``labels`` to a single example.

    Args:
        example: Mapping with an ``"audio"`` entry (providing ``"array"`` and
            ``"sampling_rate"``) and a transcript stored under ``"text"``,
            ``"transcripts"`` or ``"label"``. The keys are checked in that
            order and the first non-empty value is used.

    Returns:
        The same ``example`` object with two entries added:
        ``"input_features"`` (float32 tensor taken from the processor output)
        and ``"labels"`` (1-D tensor of token ids for the transcript).

    Raises:
        ValueError: If none of the transcript keys holds a non-empty value.
    """
    text = example.get("text") or example.get("transcripts") or example.get("label")
    if not text:
        # Fail with an explicit message instead of handing None to the tokenizer.
        raise ValueError(
            "Example has no transcript under 'text', 'transcripts' or 'label'"
        )

    audio = example["audio"]
    inputs = processor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
        return_tensors="pt"
    )

    # The transcript is tokenized on its own, without padding, so there are no
    # pad positions to mask here; variable-length labels are padded with -100
    # when the batch is collated. Replacing pad_token_id with -100 at this point
    # is harmful rather than neutral: Whisper tokenizers usually use the same id
    # for the pad token and the end-of-text token (verify with
    # processor.tokenizer.pad_token_id == processor.tokenizer.eos_token_id), so
    # the end-of-text target would be dropped from the loss.
    # NOTE(review): the tokenizer output presumably starts with the
    # start-of-transcript token, which the model may prepend again when it
    # builds decoder inputs from labels - confirm whether it should be stripped.
    labels = processor.tokenizer(text, return_tensors="pt").input_ids[0]

    example["input_features"] = inputs.input_features[0].to(torch.float32)
    example["labels"] = labels
    return example

from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Merge prepared examples into one batch dictionary.

    Args:
        batch: Sequence of mappings, each holding an ``"input_features"``
            tensor and a 1-D ``"labels"`` tensor.

    Returns:
        Dict with ``"input_features"`` (the per-example tensors stacked along a
        new leading dimension) and ``"labels"`` (the label tensors padded to a
        common length with -100, batch dimension first).
    """
    feature_list, label_list = [], []
    for sample in batch:
        feature_list.append(sample["input_features"])
        label_list.append(sample["labels"])
    return {
        "input_features": torch.stack(feature_list),
        "labels": pad_sequence(label_list, batch_first=True, padding_value=-100),
    }

def batch_iterator(dataset, batch_size=8):
    """Lazily yield collated batches from an iterable of raw examples.

    Every example is passed through ``prepare_features``; each time
    ``batch_size`` prepared examples have accumulated they are handed to
    ``collate_fn`` and the result is yielded. Examples left over when the
    iterable ends are yielded as a final, smaller batch.

    Args:
        dataset: Iterable of raw examples.
        batch_size: Number of examples per full batch.

    Yields:
        Whatever ``collate_fn`` returns for each group of prepared examples.
    """
    pending = []
    for raw_example in dataset:
        pending.append(prepare_features(raw_example))
        if len(pending) != batch_size:
            continue
        yield collate_fn(pending)
        pending = []
    # Flush the trailing partial batch, if any.
    if len(pending) > 0:
        yield collate_fn(pending)

In [None]:
# Training loop (PEFT/LoRA; sized for an RTX 3090)
import os

# transformers.AdamW was deprecated and is no longer shipped in recent
# releases, so the PyTorch optimizer is used instead.
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

save_dir = "./model_saved"
os.makedirs(save_dir, exist_ok=True)

model.train()
# Hand the optimizer only the parameters that require gradients (after
# get_peft_model these are presumably just the LoRA adapter weights).
# eps and weight_decay are pinned to the defaults of the former
# transformers.AdamW so the update rule stays the same; torch.optim.AdamW would
# otherwise default to eps=1e-8 and weight_decay=0.01.
optimizer = AdamW(
    [param for param in model.parameters() if param.requires_grad],
    lr=1e-5,
    eps=1e-6,
    weight_decay=0.0,
)
num_training_steps = 10000
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=500, num_training_steps=num_training_steps
)

# 1. Sanity check: can the stream produce a sample at all?
try:
    first = next(iter(dataset))
    # Show field names and value types only; printing the sample itself would
    # dump the full raw audio array into the cell output.
    print("First sample:", {key: type(value).__name__ for key, value in first.items()})
except Exception as e:
    print("Error reading first sample:", e)

# 2. Sanity check: build a few batches before starting to train.
for i, batch in enumerate(batch_iterator(dataset, batch_size=8)):
    print(f"Batch {i} loaded", flush=True)
    if i >= 2:
        break

# 3. Training. The loop ends after num_training_steps optimizer updates or when
#    the streamed dataset is exhausted, whichever comes first.
try:
    for step, batch in enumerate(batch_iterator(dataset, batch_size=8)):
        input_features = batch["input_features"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_features=input_features, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        print(f"Step {step} | Loss: {loss.item():.4f}", flush=True)

        if step % 200 == 0 and step > 0:
            # Save the LoRA adapter (Hugging Face standard format) together
            # with the processor.
            checkpoint_dir = f"{save_dir}/lora_step{step}"
            model.save_pretrained(checkpoint_dir)
            processor.save_pretrained(checkpoint_dir)
            print(f"Checkpoint saved at step {step}: {checkpoint_dir}", flush=True)

        # `step` is zero-based: stop once exactly num_training_steps updates
        # have been applied, which is also the horizon the scheduler was
        # configured with.
        if step + 1 >= num_training_steps:
            break
except Exception as e:
    # Report the failure but do not swallow it: later cells save the model and
    # must not run as if training had completed.
    print("Training error:", e)
    raise


In [None]:
# Save the final model and processor side by side in one directory
for artifact in (model, processor):
    artifact.save_pretrained("./whisper-ktf-ver1")


In [None]:
# Evaluation setup (inference followed by WER/CER scoring)
import evaluate

# Load the two metrics in order: "wer" first, then "cer".
wer_metric, cer_metric = (evaluate.load(metric_name) for metric_name in ("wer", "cer"))

def transcribe_and_evaluate(test_dataset, num_samples=100):
    """Transcribe examples with the current model and report WER/CER.

    Iterates over at most the first ``num_samples`` examples of
    ``test_dataset``. Each example needs an ``"audio"`` entry (with ``"array"``
    and ``"sampling_rate"``) and a reference transcript under ``"text"``,
    ``"transcripts"`` or ``"label"`` (first non-empty value wins). Examples
    without a reference are skipped because they cannot be scored. Every tenth
    example's reference and prediction are printed for inspection.

    Args:
        test_dataset: Iterable of example mappings.
        num_samples: Maximum number of examples to read from ``test_dataset``.

    Returns:
        Dict with ``"wer"`` and ``"cer"`` (the values returned by the metric
        objects) and ``"num_samples"`` (how many examples were scored). When
        nothing could be scored, ``"wer"`` and ``"cer"`` are ``None`` and
        ``"num_samples"`` is 0.
    """
    model.eval()
    preds, refs = [], []
    for i, example in enumerate(test_dataset):
        if i >= num_samples:
            break
        text = example.get("text") or example.get("transcripts") or example.get("label")
        if not text:
            # A missing reference would otherwise reach the metrics as None.
            print(f"Sample {i}: no reference transcript, skipped")
            continue
        audio = example["audio"]
        inputs = processor(
            audio["array"],
            sampling_rate=audio["sampling_rate"],
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            # Passed by keyword so the call does not depend on how a wrapping
            # model forwards positional arguments to generate().
            generated_ids = model.generate(input_features=inputs.input_features)
        pred = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        preds.append(pred)
        refs.append(text)
        if i % 10 == 0:
            print(f"Sample {i}:")
            print("  GT:", text)
            print("  STT:", pred)

    if not refs:
        # Avoid calling the metrics with empty lists.
        print("\n[TEST] No samples with a reference transcript were evaluated.")
        return {"wer": None, "cer": None, "num_samples": 0}

    wer = wer_metric.compute(predictions=preds, references=refs)
    cer = cer_metric.compute(predictions=preds, references=refs)
    print(f"\n[TEST] WER: {wer*100:.2f}% | CER: {cer*100:.2f}%")
    return {"wer": wer, "cer": cer, "num_samples": len(refs)}

# Example: evaluate on up to 100 streamed samples of the validation split
test_ds = load_dataset(
    "jwh1449/AIhub_foreign_dataset",
    split="validation",
    streaming=True,
).cast_column("audio", Audio(sampling_rate=16000))
transcribe_and_evaluate(test_ds, num_samples=100)