In [None]:
# 1. Install libraries
# Notebook shell command (leading "!"): quietly (-q) pip-installs the
# third-party packages listed on the line below.
!pip install -q transformers datasets peft bitsandbytes accelerate jiwer librosa soundfile

import os
import json
import shutil
import gc
import torch
import librosa
import numpy as np
from dataclasses import dataclass
from typing import Any
from datasets import Dataset
from google.colab import drive

from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    BitsAndBytesConfig,
    EarlyStoppingCallback
)
from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model,
    PeftModel
)

# =========================================================
# Configuration
# =========================================================
class Config:
    """Paths and model identifiers shared by every step of this script.

    All four path attributes ship empty and have to be filled in by the
    user before the script is run.
    """

    # --- Model selection ---
    MODEL_ID = "openai/whisper-large-v3"  # checkpoint id passed to from_pretrained
    LANGUAGE = "ko"                       # language handed to WhisperProcessor
    TASK = "transcribe"                   # task handed to WhisperProcessor

    # --- Dataset locations (must be set) ---
    ZIP_PATH = ""   # archive that holds the dataset
    DATA_ROOT = ""  # directory the archive is unpacked into

    # --- Output locations (must be set) ---
    LORA_OUTPUT_DIR = ""    # where the trained LoRA adapter is saved
    MERGED_OUTPUT_DIR = ""  # where the final merged model is saved

# =========================================================
# GPU check
# =========================================================
gpu_index = 0
print("GPU 정보:", torch.cuda.get_device_name(gpu_index))

# Total device memory is reported in bytes; convert it to GiB for display.
bytes_per_gib = 1024**3
total_vram = torch.cuda.get_device_properties(gpu_index).total_memory / bytes_per_gib
print(f"총 VRAM: {total_vram:.2f} GB")

# Informational only: this branch prints a message and changes no setting.
is_small_gpu = total_vram <= 16
if is_small_gpu:
    print("T4 (16GB) 환경 감지 -> 4bit QLoRA 설정 적용")

# =========================================================
# Step 1: Prepare the dataset
# =========================================================
print("Step 1: 데이터셋 준비 중...")

# Mount Google Drive unless the mount point is already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Always extract into a fresh directory: remove any previous extraction
# first, then unpack the archive.
if os.path.exists(Config.DATA_ROOT):
    shutil.rmtree(Config.DATA_ROOT)
shutil.unpack_archive(Config.ZIP_PATH, Config.DATA_ROOT)

# metadata.json is iterated as a sequence of records that each provide a
# 'file_name' and a 'text' entry.
with open(f"{Config.DATA_ROOT}/metadata.json", 'r', encoding='utf-8') as meta_file:
    metadata = json.load(meta_file)

# Audio files are resolved relative to "<DATA_ROOT>/audio".
audio_paths = [
    os.path.join(Config.DATA_ROOT, "audio", entry['file_name']) for entry in metadata
]
transcripts = [entry['text'] for entry in metadata]
dataset = Dataset.from_dict({"audio": audio_paths, "sentence": transcripts})

# The processor bundles the feature extractor and the tokenizer; both are
# also exposed as module-level names for the preprocessing function.
processor = WhisperProcessor.from_pretrained(
    Config.MODEL_ID,
    language=Config.LANGUAGE,
    task=Config.TASK,
)
feature_extractor = processor.feature_extractor
tokenizer = processor.tokenizer

def prepare_dataset(batch):
    """Turn one dataset example into model inputs.

    Reads the audio file at ``batch["audio"]`` resampled to 16 kHz, stores
    the extracted features under ``"input_features"`` and the tokenized
    ``batch["sentence"]`` under ``"labels"``, then returns the same mapping.

    If the audio cannot be loaded, the error is printed and 16000 zero
    samples (one second at 16 kHz) are used in its place.
    """
    sampling_rate = 16000

    try:
        waveform, _ = librosa.load(batch["audio"], sr=sampling_rate)
    except Exception as e:
        # NOTE(review): the example keeps its transcript while the audio
        # becomes silence -- consider dropping such examples instead.
        print(f"Error loading {batch['audio']}: {e}")
        waveform = np.zeros(sampling_rate)

    extracted = feature_extractor(waveform, sampling_rate=sampling_rate)
    batch["input_features"] = extracted.input_features[0]

    encoded = tokenizer(batch["sentence"])
    batch["labels"] = encoded.input_ids
    return batch

# Preprocess every example in a single process, then hold out 10% of the
# examples as the evaluation split.
dataset = dataset.map(prepare_dataset, num_proc=1).train_test_split(test_size=0.1)

train_count = len(dataset['train'])
test_count = len(dataset['test'])
print(f"데이터 준비 완료: Train {train_count}개 / Test {test_count}개")

@dataclass
class DataCollator:
    """Collate preprocessed examples into one padded training batch.

    ``processor`` must expose ``feature_extractor.pad``, ``tokenizer.pad``
    and ``tokenizer.bos_token_id``.
    """

    processor: Any

    def __call__(self, features):
        """Pad audio features and labels separately and return one batch.

        Each element of ``features`` provides ``"input_features"`` and
        ``"labels"``. The returned batch is whatever the feature
        extractor's ``pad`` produced, with a ``"labels"`` tensor added.
        """
        # Audio side: padded by the feature extractor.
        audio_items = [{"input_features": item["input_features"]} for item in features]
        batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")

        # Text side: padded by the tokenizer.
        token_items = [{"input_ids": item["labels"]} for item in features]
        padded_tokens = self.processor.tokenizer.pad(token_items, return_tensors="pt")

        # Padding positions are overwritten with -100, the default
        # ignore_index of PyTorch's cross-entropy loss.
        is_padding = padded_tokens.attention_mask.ne(1)
        labels = padded_tokens["input_ids"].masked_fill(is_padding, -100)

        # Strip the first column only when every sequence starts with BOS.
        # NOTE(review): for Whisper tokenizers bos_token_id may differ from
        # the start-of-transcript id, in which case this never triggers --
        # confirm against the model's decoder_start_token_id.
        leading_is_bos = labels[:, 0] == self.processor.tokenizer.bos_token_id
        if leading_is_bos.all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# =========================================================
# Step 2: Training
# =========================================================
print("Step 2: 모델 로딩 및 학습 시작 (4bit QLoRA)")

# 4-bit NF4 quantization with double quantization; compute runs in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)

model = WhisperForConditionalGeneration.from_pretrained(
    Config.MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto"
)

# PEFT helper that readies the quantized model for k-bit training.
model = prepare_model_for_kbit_training(model)

# LoRA adapters (rank 32, alpha 64) are attached to the modules named
# q_proj and v_proj; bias terms are not trained.
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none"
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Turn off Weights & Biases reporting for this run.
os.environ["WANDB_DISABLED"] = "true"

training_args = Seq2SeqTrainingArguments(
    output_dir=Config.LORA_OUTPUT_DIR,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # 2 x 8 = 16 examples per optimizer step per device
    learning_rate=1e-4,
    warmup_steps=50,
    max_steps=500,  # adjust as needed
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=2,
    predict_with_generate=True,
    generation_max_length=225,
    # Saving and evaluating on the same step interval keeps a checkpoint
    # available for every evaluated step, which best-model loading needs.
    save_steps=50,
    eval_steps=50,
    logging_steps=25,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    # The collator reads its columns straight from the dataset rows, so the
    # Trainer must not prune columns based on the wrapped model's signature.
    remove_unused_columns=False,
    # The PEFT wrapper hides the base model's forward signature, so the
    # Trainer may be unable to infer the label column on its own. Naming it
    # explicitly guarantees that eval loss is computed; best-model selection
    # and early stopping both depend on that metric.
    label_names=["labels"],
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=DataCollator(processor=processor),
    # Stop once 5 consecutive evaluations fail to improve the tracked metric.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
)

trainer.train()

# Save the training result (the LoRA adapter) together with the processor
# and tokenizer files.
print(f"LoRA 어댑터 저장 중: {Config.LORA_OUTPUT_DIR}")
model.save_pretrained(Config.LORA_OUTPUT_DIR)
processor.save_pretrained(Config.LORA_OUTPUT_DIR)
tokenizer.save_pretrained(Config.LORA_OUTPUT_DIR)

print("학습 완료.")

# =========================================================
# Step 3: Free memory (required before merging)
# =========================================================
print("Step 3: Merge를 위해 메모리 정리 중...")

# Drop the references to the objects used for training.
del model
del trainer

# Collect garbage first so tensors that were only kept alive by reference
# cycles are actually destroyed; only then can empty_cache() hand their
# cached blocks back to the GPU. The reverse order would leave that memory
# sitting in PyTorch's caching allocator.
gc.collect()
torch.cuda.empty_cache()

print(f"VRAM 확보 완료. 현재 메모리: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

# =========================================================
# Step 4: Merge the model (LoRA + base)
# =========================================================
print("Step 4: 모델 병합 (Base Model + LoRA Adapter)")

# 4-1. Load the base model in FP16 (unquantized, unlike the training copy).
print("Base Whisper 모델 로딩 중 (FP16)...")
base_load_options = {"torch_dtype": torch.float16, "device_map": "auto"}
base_model = WhisperForConditionalGeneration.from_pretrained(
    Config.MODEL_ID, **base_load_options
)

# 4-2. Attach the saved LoRA adapter to the base model.
print(f"LoRA 어댑터 로딩 중: {Config.LORA_OUTPUT_DIR}")
lora_model = PeftModel.from_pretrained(base_model, Config.LORA_OUTPUT_DIR)

# 4-3. Fold the adapter into the base model and drop the PEFT wrapper.
print("Merge 진행 중...")
merged_model = lora_model.merge_and_unload()

# 4-4. Save the final model in safetensors format.
print(f"Merge된 모델 저장 중: {Config.MERGED_OUTPUT_DIR}")
merged_model.save_pretrained(Config.MERGED_OUTPUT_DIR, safe_serialization=True)

# Store the processor and tokenizer files next to the weights as well
# (required to use the merged model on its own).
for component in (processor, tokenizer):
    component.save_pretrained(Config.MERGED_OUTPUT_DIR)

print(f"모든 작업 완료, 병합된 모델 경로: {Config.MERGED_OUTPUT_DIR}")