In [3]:
import os
import pandas as pd
import torch
import torchaudio
import numpy as np
import re
import json
import logging
import gc
from pathlib import Path
from typing import Dict, List, Optional
from datetime import datetime
import warnings
warnings.filterwarnings("ignore")

from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from torch.utils.data import Dataset
from datasets import Dataset as HFDataset
from jiwer import wer, cer

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('whisper_finetuning.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

class KoreanWhisperDataPreprocessor:
    def __init__(self, audio_dir: str, csv_files: List[str], output_dir: str, feature_dir: str, chunk_size: int = 1000):
        self.audio_dir = Path(audio_dir)
        self.csv_files = csv_files
        self.output_dir = Path(output_dir)
        self.feature_dir = Path(feature_dir)
        self.output_dir.mkdir(exist_ok=True)
        self.feature_dir.mkdir(exist_ok=True)
        self.processor = WhisperProcessor.from_pretrained("openai/whisper-small")
        self.chunk_size = chunk_size

    def clean_text(self, text: str) -> str:
        text = re.sub(r'\([^)]*\)', '', text)
        text = re.sub(r'[",\'\"]', '', text)
        text = re.sub(r'[^\w\s.!?ㄱ-ㅎㅏ-ㅣ가-힣]', '', text)
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def load_and_merge_data(self) -> pd.DataFrame:
        dataframes = []
        for csv_file in self.csv_files:
            df = pd.read_csv(csv_file)
            dataframes.append(df)
            logger.info(f"Loaded {len(df)} samples from {csv_file}")
        merged_df = pd.concat(dataframes, ignore_index=True)
        logger.info(f"Total samples after merging: {len(merged_df)}")
        return merged_df

    def validate_audio_files(self, df: pd.DataFrame) -> pd.DataFrame:
        valid_indices = []
        for idx, row in df.iterrows():
            audio_path = self.audio_dir / row['fileName']
            if audio_path.exists():
                valid_indices.append(idx)
            else:
                logger.warning(f"Audio file not found: {audio_path}")
        valid_df = df.loc[valid_indices].reset_index(drop=True)
        logger.info(f"Valid audio files: {len(valid_df)}/{len(df)}")
        return valid_df

    def process_and_save_feature(self, audio_path: Path, feature_path: Path) -> bool:
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                waveform = resampler(waveform)
            audio_array = waveform.squeeze().numpy().astype(np.float32)
            max_length = 30 * 16000
            if len(audio_array) > max_length:
                audio_array = audio_array[:max_length]
                logger.warning(f"Audio truncated: {audio_path}")
            # (1, 80, N) → (80, N)
            features = self.processor.feature_extractor(audio_array, sampling_rate=16000).input_features.squeeze(0)
            np.save(feature_path, features)
            del waveform, audio_array, features
            gc.collect()
            return True
        except Exception as e:
            logger.error(f"Error processing audio {audio_path}: {e}")
            return False

    def create_dataset(self) -> (HFDataset, HFDataset):
        train_meta_path = self.output_dir / "train_metadata.json"
        val_meta_path = self.output_dir / "val_metadata.json"

        if train_meta_path.exists() and val_meta_path.exists():
            logger.info("Loading existing metadata files...")
            with open(train_meta_path, 'r', encoding='utf-8') as f:
                train_meta = json.load(f)
            with open(val_meta_path, 'r', encoding='utf-8') as f:
                val_meta = json.load(f)
        else:
            logger.info("Creating new metadata files and extracting features...")
            df = self.load_and_merge_data()
            df = self.validate_audio_files(df)
            df['cleaned_text'] = df['Reading'].apply(self.clean_text)
            df = df[df['cleaned_text'].str.len() > 0].reset_index(drop=True)
            train_meta = []
            val_meta = []
            meta_chunks = []
            for idx, row in df.iterrows():
                if idx % self.chunk_size == 0 and idx > 0:
                    logger.info(f"Processed {idx}/{len(df)} samples, writing chunk to disk and clearing memory...")
                    meta_chunks.append((train_meta.copy(), val_meta.copy()))
                    train_meta.clear()
                    val_meta.clear()
                    gc.collect()
                audio_path = self.audio_dir / row['fileName']
                feature_path = self.feature_dir / (row['fileName'] + ".npy")
                if not feature_path.exists():
                    success = self.process_and_save_feature(audio_path, feature_path)
                    if not success:
                        continue
                meta = {
                    "feature_path": str(feature_path),
                    "text": row['cleaned_text'],
                    "file_name": row['fileName'],
                    "duration": row['recordTime']
                }
                if idx < int(len(df) * 0.8):
                    train_meta.append(meta)
                else:
                    val_meta.append(meta)
            meta_chunks.append((train_meta, val_meta))
            all_train_meta = []
            all_val_meta = []
            for t, v in meta_chunks:
                all_train_meta.extend(t)
                all_val_meta.extend(v)
            with open(train_meta_path, 'w', encoding='utf-8') as f:
                json.dump(all_train_meta, f, ensure_ascii=False, indent=2)
            with open(val_meta_path, 'w', encoding='utf-8') as f:
                json.dump(all_val_meta, f, ensure_ascii=False, indent=2)
            logger.info(f"Saved metadata - Train: {len(all_train_meta)}, Val: {len(all_val_meta)}")
            train_meta = all_train_meta
            val_meta = all_val_meta

        train_dataset = HFDataset.from_list(train_meta)
        val_dataset = HFDataset.from_list(val_meta)
        return train_dataset, val_dataset

class WhisperDataset(Dataset):
    def __init__(self, dataset: HFDataset, processor: WhisperProcessor, max_length: int = 448):
        self.dataset = dataset
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        features = np.load(item["feature_path"])  # (80, N)
        assert features.shape[0] == 80, f"Invalid feature shape: {features.shape}"
        features = torch.tensor(features, dtype=torch.float)  # [80, N]
        labels = self.processor.tokenizer(
            item["text"],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length
        ).input_ids.squeeze()
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        return {
            "input_features": features,  # [80, N]
            "labels": labels
        }
        
def whisper_collate_fn(batch):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 입력 패딩 (CPU에서 생성)
    max_len = max(f["input_features"].shape[1] for f in batch)
    batch_size = len(batch)
    input_tensor = torch.zeros((batch_size, 80, max_len), dtype=torch.float32)  # device 지정 X → CPU
    
    for i, item in enumerate(batch):
        feat = item["input_features"].cpu()  # CPU로 이동
        input_tensor[i, :, :feat.shape[1]] = feat
    
    # 라벨 패딩 (CPU에서 생성)
    max_label_len = max(item["labels"].size(0) for item in batch)
    label_tensor = torch.full((batch_size, max_label_len), -100, dtype=torch.long)
    
    for i, item in enumerate(batch):
        l = item["labels"].cpu()  # CPU로 이동
        label_tensor[i, :l.size(0)] = l
    
    return {
        "input_features": input_tensor.to('cuda'),  # CPU 텐서
        "labels": label_tensor.to('cuda')
    }


class WhisperMetrics:
    def __init__(self, processor: WhisperProcessor):
        self.processor = processor

    def compute_metrics(self, eval_pred):
        predictions, labels = eval_pred
        predictions = np.where(predictions != -100, predictions, self.processor.tokenizer.pad_token_id)
        decoded_preds = self.processor.tokenizer.batch_decode(predictions, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, self.processor.tokenizer.pad_token_id)
        decoded_labels = self.processor.tokenizer.batch_decode(labels, skip_special_tokens=True)
        wer_score = wer(decoded_labels, decoded_preds)
        cer_score = cer(decoded_labels, decoded_preds)
        return {
            "wer": wer_score,
            "cer": cer_score
        }

class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.log_history = []

    def log(self, logs: Dict[str, float]) -> None:
        super().log(logs)
        self.log_history.append({
            "step": self.state.global_step,
            "epoch": self.state.epoch,
            "timestamp": datetime.now().isoformat(),
            **logs
        })
        if "loss" in logs:
            logger.info(f"Step {self.state.global_step}: Loss = {logs['loss']:.4f}")
        if "eval_loss" in logs:
            logger.info(f"Validation - Loss: {logs['eval_loss']:.4f}, "
                        f"WER: {logs.get('eval_wer', 0):.4f}, "
                        f"CER: {logs.get('eval_cer', 0):.4f}")

    def save_log_history(self, output_dir: str):
        log_path = Path(output_dir) / "training_logs.json"
        with open(log_path, 'w', encoding='utf-8') as f:
            json.dump(self.log_history, f, ensure_ascii=False, indent=2)
        logger.info(f"Training logs saved to {log_path}")

def train_whisper_korean():
    audio_dir = "train"
    csv_files = ["filtered_data_A.csv", "filtered_data_B.csv"]
    output_dir = "preprocessed_whisper"
    feature_dir = "preprocessed_whisper/features"
    model_output_dir = "whisper-small-korean-finetuned"
    logger.info("Starting Whisper Korean fine-tuning...")

    preprocessor = KoreanWhisperDataPreprocessor(audio_dir, csv_files, output_dir, feature_dir, chunk_size=500)
    train_dataset, val_dataset = preprocessor.create_dataset()

    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    model.config._attn_implementation = "eager"

    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []

    train_torch_dataset = WhisperDataset(train_dataset, processor)
    val_torch_dataset = WhisperDataset(val_dataset, processor)

    metrics = WhisperMetrics(processor)

    training_args = TrainingArguments(
        output_dir=model_output_dir,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=2,
        learning_rate=1e-5,
        warmup_steps=500,
        max_steps=5000,
        gradient_checkpointing=True,
        fp16=False,
        eval_strategy="steps",
        eval_steps=250,
        save_strategy="steps",
        save_steps=500,
        logging_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="eval_cer",
        greater_is_better=False,
        save_total_limit=3,
        remove_unused_columns=False, 
        dataloader_pin_memory=False, 
        dataloader_num_workers=0,
        report_to=None,
        run_name=f"whisper-korean-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    )

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=train_torch_dataset,
        eval_dataset=val_torch_dataset,
        tokenizer=processor.feature_extractor,
        compute_metrics=metrics.compute_metrics,
        data_collator=whisper_collate_fn,  # 중요!
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )

    checkpoint_dir = None
    if os.path.exists(model_output_dir):
        checkpoints = [d for d in os.listdir(model_output_dir) if d.startswith("checkpoint-")]
        if checkpoints:
            latest_checkpoint = max(checkpoints, key=lambda x: int(x.split("-")[1]))
            checkpoint_dir = os.path.join(model_output_dir, latest_checkpoint)
            logger.info(f"Resuming from checkpoint: {checkpoint_dir}")

    try:
        logger.info("Starting training...")
        trainer.train(resume_from_checkpoint=checkpoint_dir)
        trainer.save_model()
        processor.save_pretrained(model_output_dir)
        trainer.save_log_history(model_output_dir)
        logger.info("Running final evaluation...")
        eval_results = trainer.evaluate()
        logger.info("Training completed successfully!")
        logger.info(f"Final evaluation results: {eval_results}")
        config_for_upload = {
            "model_type": "whisper",
            "task": "automatic-speech-recognition",
            "language": "korean",
            "dataset_size": len(train_dataset) + len(val_dataset),
            "training_steps": training_args.max_steps,
            "final_cer": eval_results.get("eval_cer", 0),
            "final_wer": eval_results.get("eval_wer", 0)
        }
        with open(f"{model_output_dir}/training_info.json", 'w', encoding='utf-8') as f:
            json.dump(config_for_upload, f, ensure_ascii=False, indent=2)
        logger.info(f"Model ready for HuggingFace upload at: {model_output_dir}")
    except Exception as e:
        logger.error(f"Training failed: {e}")
        raise

if __name__ == "__main__":
    # 멀티프로세싱 시작 방식 설정 (CUDA 사용 시 필수)
    torch.multiprocessing.set_start_method('spawn', force=True)
    
    if torch.cuda.is_available():
        logger.info(f"Using GPU: {torch.cuda.get_device_name()}")
        logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        logger.warning("GPU not available, using CPU")
    train_whisper_korean()


2025-06-14 00:46:57,853 - INFO - Using GPU: NVIDIA GeForce RTX 3090
2025-06-14 00:46:57,854 - INFO - GPU Memory: 25.4 GB
2025-06-14 00:46:57,854 - INFO - Starting Whisper Korean fine-tuning...
2025-06-14 00:46:59,073 - INFO - Loading existing metadata files...
2025-06-14 00:47:01,055 - INFO - model, processor loaded... 
max_steps is given, it will override any value given in num_train_epochs
2025-06-14 00:47:01,364 - INFO - Starting training...


Step,Training Loss
100,2.1648
200,0.7581
300,0.5202
400,0.213
500,0.162
600,0.13
700,0.1096
800,0.097
900,0.0811
1000,0.0735


2025-06-14 00:55:58,095 - INFO - Step 100: Loss = 2.1648
2025-06-14 01:05:36,697 - INFO - Step 200: Loss = 0.7581
2025-06-14 01:15:28,264 - INFO - Step 300: Loss = 0.5202
2025-06-14 01:25:29,612 - INFO - Step 400: Loss = 0.2130
2025-06-14 01:35:18,793 - INFO - Step 500: Loss = 0.1620
2025-06-14 01:45:42,893 - INFO - Step 600: Loss = 0.1300
2025-06-14 01:56:33,585 - INFO - Step 700: Loss = 0.1096
2025-06-14 02:07:16,559 - INFO - Step 800: Loss = 0.0970
2025-06-14 02:17:44,410 - INFO - Step 900: Loss = 0.0811
2025-06-14 02:27:54,074 - INFO - Step 1000: Loss = 0.0735
2025-06-14 02:37:57,285 - INFO - Step 1100: Loss = 0.0677
2025-06-14 02:48:01,930 - INFO - Step 1200: Loss = 0.0545
2025-06-14 02:58:25,987 - INFO - Step 1300: Loss = 0.0309
2025-06-14 03:08:51,549 - INFO - Step 1400: Loss = 0.0323
2025-06-14 03:19:23,027 - INFO - Step 1500: Loss = 0.0270
2025-06-14 03:29:08,600 - INFO - Step 1600: Loss = 0.0294
2025-06-14 03:38:53,034 - INFO - Step 1700: Loss = 0.0273
2025-06-14 03:48:35,154

2025-06-14 05:00:50,322 - ERROR - Training failed: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
import torch
import json
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import Dataset
from jiwer import wer, cer
import logging
from tqdm import tqdm
import gc
from google.colab import files
import time
import os

# 로깅 설정
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

class WhisperDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, processor, max_length=448):
        self.dataset = dataset
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        features = np.load(item["feature_path"])
        assert features.shape[0] == 80, f"Invalid feature shape: {features.shape}"
        features = torch.tensor(features, dtype=torch.float)
        labels = self.processor.tokenizer(
            item["text"],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length
        ).input_ids.squeeze()
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        return {
            "input_features": features,
            "labels": labels,
            "text": item["text"],
            "file_name": item["file_name"]
        }

def evaluate_single_model(model_name, val_metadata_path, features_dir, num_samples=5, batch_size=4):
    # GPU 사용 가능 여부 확인
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # GPU 메모리 확인
    if device == "cuda":
        logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        logger.info(f"Available GPU Memory: {torch.cuda.memory_allocated() / 1e9:.1f} GB")

    # Validation 데이터 로드
    logger.info("Loading validation metadata...")
    with open(val_metadata_path, 'r', encoding='utf-8') as f:
        val_metadata = json.load(f)

    # feature_path 수정
    for item in val_metadata:
        # 파일 이름만 추출
        file_name = item["feature_path"].split("/")[-1]
        # 새로운 경로로 수정
        item["feature_path"] = f"/content/drive/MyDrive/preprocessed_whisper/features/{file_name}"

    val_dataset = Dataset.from_list(val_metadata)
    logger.info(f"Loaded {len(val_dataset)} validation samples")

    start_time = time.time()

    # 모델과 프로세서 로드
    logger.info(f"Loading model and processor: {model_name}")
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name)

    model = model.to(device)
    model.eval()
    logger.info("Model loaded successfully")

    # 데이터셋 준비
    val_torch_dataset = WhisperDataset(val_dataset, processor)

    # 평가 결과 저장
    all_predictions = []
    all_references = []
    sample_results = []

    # 배치 단위로 평가 진행
    logger.info("Starting evaluation...")
    with torch.no_grad():
        pbar = tqdm(range(0, len(val_torch_dataset), batch_size),
                   desc=f"Evaluating {model_name}",
                   ncols=100)

        for idx in pbar:
            batch_items = [val_torch_dataset[i] for i in range(idx, min(idx + batch_size, len(val_torch_dataset)))]

            # 배치 데이터 준비
            input_features = torch.stack([item["input_features"] for item in batch_items]).to(device)

            # 예측
            predicted_ids = model.generate(input_features)
            transcriptions = processor.batch_decode(predicted_ids, skip_special_tokens=True)

            # 결과 저장
            for i, (transcription, item) in enumerate(zip(transcriptions, batch_items)):
                all_predictions.append(transcription)
                all_references.append(item["text"])

                if idx + i < num_samples:
                    sample_results.append({
                        "file_name": item["file_name"],
                        "reference": item["text"],
                        "prediction": transcription
                    })

            # 진행 상황 업데이트
            pbar.set_postfix({
                'processed': f"{min(idx + batch_size, len(val_torch_dataset))}/{len(val_torch_dataset)}",
                'memory': f"{torch.cuda.memory_allocated() / 1e9:.1f}GB"
            })

            # 메모리 정리
            del input_features, predicted_ids
            torch.cuda.empty_cache()
            gc.collect()

    # 메트릭 계산
    logger.info("Calculating metrics...")
    wer_score = wer(all_references, all_predictions)
    cer_score = cer(all_references, all_predictions)

    results = {
        "model_name": model_name,
        "wer": wer_score,
        "cer": cer_score,
        "samples": sample_results,
        "duration": time.time() - start_time
    }

    logger.info(f"\nEvaluation completed in {results['duration']:.2f} seconds")
    logger.info(f"WER: {wer_score:.4f}")
    logger.info(f"CER: {cer_score:.4f}")

    # 샘플 결과 출력
    logger.info("\nSample Results:")
    for i, sample in enumerate(sample_results, 1):
        logger.info(f"\nSample {i}:")
        logger.info(f"File: {sample['file_name']}")
        logger.info(f"Reference: {sample['reference']}")
        logger.info(f"Prediction: {sample['prediction']}")

    # 메모리 정리
    del model
    torch.cuda.empty_cache()
    gc.collect()

    return results

def compare_results(results_dir):
    """여러 모델의 평가 결과를 비교"""
    all_results = {}
    for result_file in os.listdir(results_dir):
        if result_file.endswith('_results.json'):
            with open(os.path.join(results_dir, result_file), 'r', encoding='utf-8') as f:
                results = json.load(f)
                model_name = results['model_name']
                all_results[model_name] = {
                    'wer': results['wer'],
                    'cer': results['cer'],
                    'duration': results['duration']
                }

    # 결과 비교 출력
    logger.info("\nModel Comparison Summary:")
    logger.info("="*70)
    logger.info(f"{'Model Name':<30} {'WER':<10} {'CER':<10} {'Duration (s)':<15}")
    logger.info("-"*70)
    for model_name, metrics in all_results.items():
        logger.info(f"{model_name:<30} {metrics['wer']:<10.4f} {metrics['cer']:<10.4f} {metrics['duration']:<15.2f}")
    logger.info("="*70)

def main():

    # GPU 메모리 최적화
    logger.info("Optimizing GPU memory...")
    torch.cuda.empty_cache()
    gc.collect()

    # 평가할 모델들
    model_names = [
        "urewui/ktf",
        "openai/whisper-small"
    ]


    # 결과 저장 디렉토리 생성
    results_dir = "evaluation_results"
    os.makedirs(results_dir, exist_ok=True)

    # 각 모델별로 개별 평가
    for model_name in model_names:
        logger.info(f"\n{'='*50}")
        logger.info(f"Starting evaluation for {model_name}")
        logger.info(f"{'='*50}")

        # 모델 평가
        results = evaluate_single_model(
            model_name=model_name,
            val_metadata_path="val_metadata.json",
            features_dir="preprocessed_whisper/features",
            num_samples=5,
            batch_size=2
        )

        # 결과 저장
        result_file = os.path.join(results_dir, f"{model_name.replace('/', '_')}_results.json")
        with open(result_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        # 결과 파일 다운로드
        files.download(result_file)

    # 모든 결과 비교
    compare_results(results_dir)

    logger.info("All evaluations completed!")

if __name__ == "__main__":
    main()

In [None]:
import os
import torch
import numpy as np
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import json

# val 데이터셋 메타데이터 경로
val_dataset_path = "preprocessed_whisper/val_metadata.json"

# 체크포인트 디렉토리 및 리스트
checkpoint_dir = "whisper-small-korean-finetuned"
checkpoints = ["checkpoint-500", "checkpoint-1000", "checkpoint-1500"]

# val 메타데이터 로드
with open(val_dataset_path, 'r', encoding='utf-8') as f:
    val_meta = json.load(f)

# WhisperProcessor 로드
processor = WhisperProcessor.from_pretrained(checkpoint_dir)

def transcribe_checkpoint(checkpoint_path, val_meta, processor, device="cuda" if torch.cuda.is_available() else "cpu", max_samples=None):
    model = WhisperForConditionalGeneration.from_pretrained(checkpoint_path).to(device)
    model.eval()

    transcriptions = []

    for i, item in enumerate(val_meta):
        if max_samples is not None and i >= max_samples:
            break
        feature_path = item["feature_path"]
        text = item["text"]
        file_name = item["file_name"]

        features = np.load(feature_path)  # (80, N)
        features = torch.tensor(features, dtype=torch.float).unsqueeze(0).to(device)  # [1, 80, N]

        with torch.no_grad():
            predicted_ids = model.generate(features)
        transcription = processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]

        transcriptions.append({
            "file_name": file_name,
            "reference": text,
            "transcription": transcription
        })

        if (i + 1) % 10 == 0:
            print(f"Processed {i+1} samples...")

    return transcriptions

all_transcriptions = {}
for ckpt in checkpoints:
    ckpt_path = os.path.join(checkpoint_dir, ckpt)
    if os.path.exists(ckpt_path):
        print(f"\nTranscribing with {ckpt} ...")
        transcriptions = transcribe_checkpoint(ckpt_path, val_meta, processor)
        all_transcriptions[ckpt] = transcriptions
    else:
        print(f"Checkpoint {ckpt} not found.")
        all_transcriptions[ckpt] = None

# 결과 저장
with open("transcription_results.json", "w", encoding="utf-8") as f:
    json.dump(all_transcriptions, f, ensure_ascii=False, indent=2)

print("\n=== Transcription Results ===")
for ckpt, results in all_transcriptions.items():
    print(f"\n--- Checkpoint: {ckpt} ---")
    for i, item in enumerate(results[:5]):  # 예시로 5개만 출력
        print(f"File: {item['file_name']}")
        print(f"Reference: {item['reference']}")
        print(f"Transcription: {item['transcription']}\n")
