In [4]:
import os
import gc
import json
import shutil
import numpy as np
import pandas as pd
import torch
import librosa
import soundfile as sf
import logging
import evaluate
from datetime import datetime, timedelta
from dataclasses import dataclass
from typing import Dict, List, Union, Any

# Hugging Face 라이브러리
from datasets import Dataset, DatasetDict, Audio, Features, Array2D, Sequence, Value, concatenate_datasets
from transformers import (
    WhisperProcessor, 
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    BitsAndBytesConfig,
    TrainerCallback,
    EarlyStoppingCallback,
    get_linear_schedule_with_warmup
)
from peft import LoraConfig, get_peft_model, PeftModelForSeq2SeqLM

# 로깅 설정
logging.getLogger("transformers").setLevel(logging.INFO)

# 환경 설정
output_dir = "./whisper-korean-ft2"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, "cache"), exist_ok=True)
os.makedirs(os.path.join(output_dir, "temp"), exist_ok=True)
os.makedirs(os.path.join(output_dir, "test_samples"), exist_ok=True)

# 데이터셋 설정
VALID_CSV_FILE = 'filtered_data_val.csv'
TRAIN_AUDIO_DIR = 'train'
VALID_AUDIO_DIR = 'valid'
TRAIN_CSV_FILES = ['filtered_data_A.csv', 'filtered_data_B.csv']

# 1. 데이터 콜레이터
@dataclass
class WhisperDataCollator:
    processor: Any

    def __call__(self, features):
        # 배치 크기 확인
        batch_size = len(features)
        
        # 입력 특징 처리
        input_features = [feature["input_features"] for feature in features]
        input_features = torch.tensor(np.array(input_features), dtype=torch.float32)
        
        # 레이블 처리
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            return_tensors="pt"
        )
        labels = labels_batch["input_ids"]
        
        # 디코더 입력 ID 생성
        decoder_input_ids = []
        for label_ids in labels:
            decoder_input_ids.append([self.processor.tokenizer.bos_token_id] + label_ids.tolist())
        
        # 패딩 적용
        max_length = max(len(ids) for ids in decoder_input_ids)
        decoder_input_ids = [
            ids + [self.processor.tokenizer.pad_token_id] * (max_length - len(ids))
            for ids in decoder_input_ids
        ]
        decoder_input_ids = torch.tensor(decoder_input_ids)
        
        # 배치 구성
        batch = {
            "input_features": input_features,
            "labels": labels,
            "decoder_input_ids": decoder_input_ids
        }
        
        # 모든 텐서의 배치 크기 확인 및 조정
        for key, tensor in batch.items():
            if tensor.size(0) != batch_size:
                print(f"Warning: {key} batch size mismatch. Expected {batch_size}, got {tensor.size(0)}")
                if tensor.size(0) > batch_size:
                    batch[key] = tensor[:batch_size]
                else:
                    padding = torch.zeros((batch_size - tensor.size(0), *tensor.size()[1:]), 
                                       dtype=tensor.dtype, 
                                       device=tensor.device)
                    batch[key] = torch.cat([tensor, padding], dim=0)
        
        return batch

# 2. 메모리 모니터링 콜백
class MemoryMonitorCallback(TrainerCallback):
    def __init__(self, threshold_gb=20):
        self.threshold_gb = threshold_gb
    
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 500 == 0:
            memory_allocated = torch.cuda.memory_allocated() / (1024**3)
            if memory_allocated > self.threshold_gb:
                print(f"경고: GPU 메모리 사용량 {memory_allocated:.2f}GB")
                gc.collect()
                torch.cuda.empty_cache()
                print("메모리 최적화 완료")
        return control

# 3. 시간 기반 체크포인트
class TimeCheckpoint(TrainerCallback):
    def __init__(self, interval=30):
        self.interval = timedelta(minutes=interval)
        self.last_save = datetime.now()
        self.model = None
        self.trainer = None
    
    def on_train_begin(self, args, state, control, model=None, trainer=None, **kwargs):
        self.model = model
        self.trainer = trainer
        
    def on_step_end(self, args, state, control, **kwargs):
        if datetime.now() - self.last_save >= self.interval:
            checkpoint_dir = os.path.join(
                args.output_dir, 
                f"checkpoint-time-{datetime.now().strftime('%Y%m%d-%H%M')}"
            )
            self.model.save_pretrained(checkpoint_dir)
            if state is not None:
                state.save_to_json(os.path.join(checkpoint_dir, "trainer_state.json"))
            self.last_save = datetime.now()
            print(f"시간 기반 체크포인트 저장: {checkpoint_dir}")
        return control

# 4. 에러 로거
class ErrorLogger:
    def __init__(self, log_path):
        self.log_path = log_path
        os.makedirs(os.path.dirname(log_path), exist_ok=True)

    def log(self, message):
        with open(self.log_path, "a") as f:
            f.write(f"{datetime.now().isoformat()} - {message}\n")

# 5. 오디오 처리
class AudioProcessor:
    def __init__(self, processor, max_seconds=30):
        self.processor = processor
        self.max_seconds = max_seconds
        
    def process_audio(self, audio):
        """외국인 발화 특성 반영 전처리"""
        try:
            if isinstance(audio, dict) and "array" in audio:
                array = audio["array"]
                sr = audio.get("sampling_rate", 16000)
            else:
                raise ValueError("잘못된 오디오 형식")
            
            array = self.adaptive_normalize(array)
            max_samples = int(self.max_seconds * sr * 1.2)
            if len(array) > max_samples:
                array = array[:max_samples]
            
            feature = self.processor(
                array, 
                sampling_rate=sr, 
                return_tensors="np", 
                truncation=False
            ).input_features[0]
            
            if feature.shape != (80, 3000):
                fixed = np.zeros((80, 3000), dtype=feature.dtype)
                h, w = feature.shape
                fixed[:h, :w] = feature[:min(h,80), :min(w,3000)]
                feature = fixed
            return feature.astype(np.float32)
            
        except Exception as e:
            print(f"오디오 처리 오류: {e}")
            raise

    def adaptive_normalize(self, waveform, target_level=-16.0, frame_length_ms=500, max_gain_db=30.0):
        if len(waveform) == 0:
            return waveform
            
        sample_rate = 16000
        frame_length = int(sample_rate * frame_length_ms / 1000)
        num_frames = max(1, len(waveform) // frame_length)
        normalized = np.zeros_like(waveform)
        
        for i in range(num_frames):
            start = i * frame_length
            end = min(start + frame_length, len(waveform))
            frame = waveform[start:end]
            rms = np.sqrt(np.mean(frame**2))
            
            if rms < 1e-8:
                normalized[start:end] = frame
                continue
                
            current_level = 20 * np.log10(rms) if rms > 0 else -100
            gain_db = target_level - current_level
            gain_db = min(max_gain_db, gain_db)
            gain_linear = 10 ** (gain_db / 20)
            normalized[start:end] = frame * gain_linear
        
        if np.max(np.abs(normalized)) > 0.99:
            normalized = normalized / np.max(np.abs(normalized)) * 0.99
            
        return normalized

    def split_audio(self, array, sr):
        max_samples = int(self.max_seconds * sr)
        return [array[i:i+max_samples] for i in range(0, len(array), max_samples)]

# 6. 데이터셋 로드 함수
def load_dataset_with_fallback():
    try:
        cache_path = os.path.join(output_dir, "processed_dataset")
        if os.path.exists(cache_path) and os.path.isfile(os.path.join(cache_path, "dataset_dict.json")):
            print("캐시된 데이터셋 사용")
            dataset_dict = DatasetDict.load_from_disk(cache_path)
            return dataset_dict, False
            
        print("CSV 파일에서 데이터셋 생성 중...")
        
        train_dfs = []
        for csv_file in TRAIN_CSV_FILES:
            print(f"훈련 CSV 파일 로드 중: {csv_file}")
            df = pd.read_csv(csv_file)
            train_dfs.append(df)
        
        train_df = pd.concat(train_dfs, ignore_index=True)
        print(f"총 {len(train_df)} 개의 훈련 샘플 로드됨")
        
        print(f"검증 CSV 파일 로드 중: {VALID_CSV_FILE}")
        valid_df = pd.read_csv(VALID_CSV_FILE)
        print(f"총 {len(valid_df)} 개의 검증 샘플 로드됨")
        
        train_df['audio'] = train_df['fileName'].apply(lambda fn: os.path.join(TRAIN_AUDIO_DIR, fn))
        valid_df['audio'] = valid_df['fileName'].apply(lambda fn: os.path.join(VALID_AUDIO_DIR, fn))
        
        train_df = train_df[['audio', 'ReadingLabelText']]
        valid_df = valid_df[['audio', 'ReadingLabelText']]
        
        train_df = train_df.rename(columns={'ReadingLabelText': 'transcripts'})
        valid_df = valid_df.rename(columns={'ReadingLabelText': 'transcripts'})
        
        train_df = train_df.dropna()
        valid_df = valid_df.dropna()
        print(f"결측치 제거 후 훈련 {len(train_df)}개, 검증 {len(valid_df)}개 샘플 남음")
        
        train_dataset = Dataset.from_pandas(train_df)
        valid_dataset = Dataset.from_pandas(valid_df)
        
        train_dataset = train_dataset.cast_column('audio', Audio(sampling_rate=16000))
        valid_dataset = valid_dataset.cast_column('audio', Audio(sampling_rate=16000))
        
        dataset_dict = DatasetDict({
            'train': train_dataset,
            'valid': valid_dataset
        })
        
        dataset_dict.save_to_disk(cache_path)
        print(f"데이터셋 처리 완료 및 저장됨: {cache_path}")
        
        return dataset_dict, False
    except Exception as e:
        print(f"데이터셋 로드 실패: {e}")
        raise

# 7. 모델 설정
def setup_model():
    model_id = "openai/whisper-small"
    processor = WhisperProcessor.from_pretrained(
        model_id,
        language="korean",
        task="transcribe"
    )
    
    model = WhisperForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float32
    )
    
    # Whisper 모델의 입력 형식에 맞게 설정
    model.config.use_cache = False
    
    return model, processor

# 8. LoRA 설정
def get_lora_config():
    """LoRA 설정 (외국인 발화 최적화)"""
    return LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="SEQ_2_SEQ_LM"
    )

# 9. 메트릭 계산
def compute_metrics(processor):
    cer_metric = evaluate.load("cer")
    wer_metric = evaluate.load("wer")
    
    def metrics_fn(pred):
        pred_str = processor.batch_decode(pred.predictions, skip_special_tokens=True)
        label_str = processor.batch_decode(pred.label_ids, skip_special_tokens=True)
        return {
            "cer": cer_metric.compute(predictions=pred_str, references=label_str),
            "wer": wer_metric.compute(predictions=pred_str, references=label_str)
        }
    return metrics_fn

# 10. 추론 함수
def transcribe_audio(model, processor, audio_path):
    try:
        audio, sr = librosa.load(audio_path, sr=16000)
        input_features = processor(
            audio, 
            sampling_rate=sr, 
            return_tensors="pt"
        ).input_features.to(model.device)
        
        gen_kwargs = {
            "max_new_tokens": 256,
            "language": "ko",
            "task": "transcribe",
            "num_beams": 5,
            "temperature": 0.0
        }
        
        with torch.no_grad():
            predicted_ids = model.generate(input_features, **gen_kwargs)
        
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        return transcription
    except Exception as e:
        print(f"추론 오류: {e}")
        return ""

def transcribe_long_audio(model, processor, audio_path):
    audio, sr = librosa.load(audio_path, sr=16000)
    chunks = AudioProcessor(processor).split_audio(audio, sr)
    
    results = []
    for i, chunk in enumerate(chunks):
        chunk_path = os.path.join(output_dir, "temp", f"temp_{i}.wav")
        sf.write(chunk_path, chunk, sr)
        result = transcribe_audio(model, processor, chunk_path)
        results.append(result)
        os.remove(chunk_path)
    
    return " ".join(results)

# 11. 검증 함수
def validate_model(model_dir):
    """훈련된 모델 단독 검증"""
    try:
        print(f"모델 로드 중: {model_dir}")
        processor = WhisperProcessor.from_pretrained(model_dir)
        model = WhisperForConditionalGeneration.from_pretrained(
            model_dir,
            torch_dtype=torch.float32,
        )
        
        test_dir = os.path.join(output_dir, "test_samples")
        os.makedirs(test_dir, exist_ok=True)
        
        test_files = []
        for file in os.listdir(test_dir):
            if file.endswith('.wav'):
                test_files.append((os.path.join(test_dir, file), file))
        
        if not test_files:
            print(f"경고: 테스트 파일이 없습니다. {test_dir} 디렉토리에 오디오 파일을 추가하세요.")
            print("기본 테스트 파일을 생성합니다...")
            sample_audio = np.zeros(16000)
            sample_path = os.path.join(test_dir, "sample.wav")
            sf.write(sample_path, sample_audio, 16000)
            test_files = [(sample_path, "sample.wav")]
            print(f"기본 테스트 파일 생성됨: {sample_path}")
        
        print("\n===== 모델 테스트 =====")
        print(f"테스트 파일 수: {len(test_files)}")
        
        for file_path, file_name in test_files:
            print(f"\n[테스트] {file_name}")
            start_time = datetime.now()
            result = transcribe_audio(model, processor, file_path)
            elapsed = (datetime.now() - start_time).total_seconds()
            print(f"결과 ({elapsed:.2f}초):")
            print(f"  {result[:100]}..." if len(result) > 100 else result)
        
        print("\n검증 완료!")
        
    except Exception as e:
        print(f"검증 중 오류 발생: {str(e)}")
        import traceback
        traceback.print_exc()

# 12. 메인 훈련 함수
def train():
    """메인 훈련 함수"""
    error_logger = ErrorLogger(os.path.join(output_dir, "error_log.txt"))
    
    try:
        dataset, is_streaming = load_dataset_with_fallback()
        audio_field, text_field = "audio", "transcripts"
        
        model, processor = setup_model()
        
        # 모델을 학습 모드로 설정
        model.train()
        
        # LoRA 설정 및 적용
        lora_config = get_lora_config()
        
        # WhisperPEFTModel로 직접 초기화
        model = WhisperPEFTModel(model, lora_config)
        model.processor = processor
        
        # 학습 가능한 파라미터 설정
        for name, param in model.named_parameters():
            if any(x in name for x in ["lora", "adapter"]):
                param.requires_grad = True
            else:
                param.requires_grad = False
        
        # 학습 가능한 파라미터 출력
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        all_params = sum(p.numel() for p in model.parameters())
        print(f"학습 가능한 파라미터: {trainable_params:,} / 전체 파라미터: {all_params:,}")
        print(f"학습 가능한 파라미터 비율: {100 * trainable_params / all_params:.2f}%")
        
        # 학습 가능한 파라미터 이름 출력
        print("\n학습 가능한 파라미터 목록:")
        for name, param in model.named_parameters():
            if param.requires_grad:
                print(f"- {name}: {param.shape}")
        
        model.print_trainable_parameters()
        
        # 모델이 학습 모드인지 확인
        model.train()
        print("\n모델 학습 모드:", model.training)
        print("학습 가능한 파라미터 수:", sum(p.requires_grad for p in model.parameters()))
        
        audio_processor = AudioProcessor(processor)
        
        def process_batch(batch):
            features = []
            labels = []
            
            for audio, text in zip(batch[audio_field], batch[text_field]):
                try:
                    feature = audio_processor.process_audio(audio)
                    features.append(feature)
                    label = processor.tokenizer(text).input_ids
                    labels.append(label)
                except Exception as e:
                    print(f"배치 처리 오류: {e}")
                    raise
            
            # 배치 크기 확인 및 조정
            if len(features) != len(labels):
                print(f"Warning: Features length ({len(features)}) != Labels length ({len(labels)})")
                min_length = min(len(features), len(labels))
                features = features[:min_length]
                labels = labels[:min_length]
            
            # 입력 특징을 numpy 배열로 변환
            features = np.array(features)
            
            return {
                "input_features": features,
                "labels": labels
            }
        
        print("데이터셋 처리 시작...")
        if is_streaming:
            remove_columns = list(next(iter(dataset["train"])).keys())
            if "__index_level_0__" not in remove_columns:
                remove_columns.append("__index_level_0__")
            processed_dataset = dataset.map(
                process_batch,
                batched=True,
                batch_size=16,
                remove_columns=remove_columns
            )
        else:
            features = Features({
                'input_features': Array2D(shape=(80, 3000), dtype='float32'),
                'labels': Sequence(feature=Value(dtype='int64'))
            })
            
            remove_columns = [col for col in dataset["train"].column_names if col not in ("input_features", "labels")]
            
            # 데이터셋을 더 작은 청크로 나누어 처리
            chunk_size = 1000  # 한 번에 처리할 샘플 수
            train_chunks = []
            valid_chunks = []
            
            print("훈련 데이터셋 처리 중...")
            for i in range(0, len(dataset["train"]), chunk_size):
                chunk = dataset["train"].select(range(i, min(i + chunk_size, len(dataset["train"]))))
                processed_chunk = chunk.map(
                    process_batch,
                    batched=True,
                    batch_size=16,
                    num_proc=1,  # 병렬 처리 비활성화
                    features=features,
                    remove_columns=remove_columns,
                    desc=f"Processing train chunk {i//chunk_size + 1}"
                )
                train_chunks.append(processed_chunk)
            
            print("검증 데이터셋 처리 중...")
            for i in range(0, len(dataset["valid"]), chunk_size):
                chunk = dataset["valid"].select(range(i, min(i + chunk_size, len(dataset["valid"]))))
                processed_chunk = chunk.map(
                    process_batch,
                    batched=True,
                    batch_size=16,
                    num_proc=1,  # 병렬 처리 비활성화
                    features=features,
                    remove_columns=remove_columns,
                    desc=f"Processing valid chunk {i//chunk_size + 1}"
                )
                valid_chunks.append(processed_chunk)
            
            # 청크 합치기
            processed_dataset = DatasetDict({
                "train": concatenate_datasets(train_chunks),
                "valid": concatenate_datasets(valid_chunks)
            })
            
            # 캐시 저장
            processed_dataset.save_to_disk(os.path.join(output_dir, "processed_dataset"))
            print("처리된 데이터셋 저장 완료")
        
        training_args = Seq2SeqTrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            gradient_accumulation_steps=2,
            learning_rate=5e-5,
            max_steps=4000,
            fp16=True,
            gradient_checkpointing=True,
            optim="adamw_torch",
            report_to=["tensorboard"],
            metric_for_best_model="eval_cer",
            greater_is_better=False,
            logging_strategy="steps",
            save_strategy="steps",
            ddp_find_unused_parameters=False,
            tf32=True,
            warmup_steps=200,
            weight_decay=0.01,
            save_total_limit=3,
            load_best_model_at_end=True,
            evaluation_strategy="steps",
            eval_steps=200,
            save_steps=200,
            dataloader_num_workers=2,  # 4에서 2로 감소
            dataloader_pin_memory=True,
            torch_compile=True,
        )
        
        # 콜백 설정
        callbacks = [
            TimeCheckpoint(),
            MemoryMonitorCallback(threshold_gb=20),
            EarlyStoppingCallback(early_stopping_patience=3)
        ]
        
        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=processed_dataset["train"],
            eval_dataset=processed_dataset["valid"],
            compute_metrics=compute_metrics(processor),
            data_collator=WhisperDataCollator(processor),
            callbacks=callbacks
        )
        
        # 학습률 스케줄러 설정
        num_training_steps = training_args.max_steps
        num_warmup_steps = training_args.warmup_steps
        
        # optimizer가 None이 아닌지 확인
        if trainer.optimizer is None:
            print("Warning: Optimizer is None, initializing default optimizer...")
            trainer.create_optimizer()
        
        scheduler = get_linear_schedule_with_warmup(
            trainer.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )
        trainer.lr_scheduler = scheduler
        
        # 최종 확인
        print("\n최종 모델 상태:")
        print("학습 모드:", model.training)
        print("학습 가능한 파라미터 수:", sum(p.requires_grad for p in model.parameters()))
        print("Optimizer:", type(trainer.optimizer).__name__)
        print("Learning rate:", trainer.optimizer.param_groups[0]['lr'])
        
        if os.path.exists(output_dir):
            checkpoints = [d for d in os.listdir(output_dir) 
                          if d.startswith("checkpoint-") and os.path.isdir(os.path.join(output_dir, d))]
            if checkpoints:
                numeric_checkpoints = [d for d in checkpoints if d.replace("checkpoint-", "").isdigit()]
                if numeric_checkpoints:
                    latest_checkpoint = sorted(numeric_checkpoints, 
                                              key=lambda x: int(x.replace("checkpoint-", "")))[-1]
                    print(f"체크포인트에서 재개: {latest_checkpoint}")
                    trainer.train(resume_from_checkpoint=os.path.join(output_dir, latest_checkpoint))
                else:
                    print("자동 감지 가능한 체크포인트가 없어 처음부터 시작합니다.")
                    trainer.train()
            else:
                print("체크포인트가 없어 처음부터 시작합니다.")
                trainer.train()
        else:
            print("출력 디렉토리가 없어 처음부터 시작합니다.")
            trainer.train()
        
        final_dir = os.path.join(output_dir, "final_model")
        model = model.merge_and_unload()
        model.save_pretrained(final_dir, safe_serialization=True)
        processor.save_pretrained(final_dir)
        print(f"모델 저장 완료: {final_dir}")

        validate_model(final_dir)

    except Exception as e:
        error_logger.log(f"Critical Error: {str(e)}")
        import traceback
        traceback.print_exc()
        raise

# 메모리 최적화 및 훈련 시작
def optimize_memory():
    gc.collect()
    torch.cuda.empty_cache()
    print("메모리 최적화 완료")

class WhisperPEFTModel(PeftModelForSeq2SeqLM):
    def forward(
        self,
        input_features=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs
    ):
        # 배치 크기 확인
        batch_size = input_features.size(0)
        
        # 모든 입력의 배치 크기 확인 및 조정
        if labels is not None:
            if labels.size(0) != batch_size:
                print(f"Warning: Labels batch size mismatch. Expected {batch_size}, got {labels.size(0)}")
                if labels.size(0) > batch_size:
                    labels = labels[:batch_size]
                else:
                    padding = torch.full((batch_size - labels.size(0), labels.size(1)), 
                                      self.processor.tokenizer.pad_token_id,
                                      dtype=labels.dtype,
                                      device=labels.device)
                    labels = torch.cat([labels, padding], dim=0)
        
        if decoder_input_ids is not None:
            if decoder_input_ids.size(0) != batch_size:
                print(f"Warning: Decoder input IDs batch size mismatch. Expected {batch_size}, got {decoder_input_ids.size(0)}")
                if decoder_input_ids.size(0) > batch_size:
                    decoder_input_ids = decoder_input_ids[:batch_size]
                else:
                    padding = torch.full((batch_size - decoder_input_ids.size(0), decoder_input_ids.size(1)),
                                      self.processor.tokenizer.pad_token_id,
                                      dtype=decoder_input_ids.dtype,
                                      device=decoder_input_ids.device)
                    decoder_input_ids = torch.cat([decoder_input_ids, padding], dim=0)
        
        # PEFT 모델의 forward 메서드 호출
        with self._enable_peft_forward_hooks(**kwargs):
            kwargs = {k: v for k, v in kwargs.items() 
                     if k not in self.special_peft_forward_args 
                     and k not in ["forced_decoder_ids", "use_cache"]}
            
            outputs = self.base_model(
                input_features=input_features,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )
            
            # 손실 계산 전 배치 크기 확인 및 조정
            if hasattr(outputs, 'logits') and labels is not None:
                logits = outputs.logits
                if logits.size(0) != labels.size(0):
                    print(f"Warning: Logits batch size ({logits.size(0)}) != Labels batch size ({labels.size(0)})")
                    min_size = min(logits.size(0), labels.size(0))
                    logits = logits[:min_size]
                    labels = labels[:min_size]
                    outputs.logits = logits
                    
                    # 손실 재계산
                    if hasattr(outputs, 'loss'):
                        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.processor.tokenizer.pad_token_id)
                        outputs.loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            
            return outputs

if __name__ == "__main__":
    optimize_memory()
    train() 

메모리 최적화 완료
CSV 파일에서 데이터셋 생성 중...
훈련 CSV 파일 로드 중: filtered_data_A.csv
훈련 CSV 파일 로드 중: filtered_data_B.csv
총 20623 개의 훈련 샘플 로드됨
검증 CSV 파일 로드 중: filtered_data_val.csv
총 576 개의 검증 샘플 로드됨
결측치 제거 후 훈련 20623개, 검증 576개 샘플 남음


Saving the dataset (16/16 shards): 100%|██████████████████| 20623/20623 [00:52<00:00, 391.17 examples/s]
Saving the dataset (1/1 shards): 100%|████████████████████████| 576/576 [00:00<00:00, 881.07 examples/s]


데이터셋 처리 완료 및 저장됨: ./whisper-korean-ft2/processed_dataset


loading configuration file preprocessor_config.json from cache at /root/.cache/huggingface/hub/models--openai--whisper-small/snapshots/973afd24965f72e36ca33b3055d56a652f456b4d/preprocessor_config.json
Feature extractor WhisperFeatureExtractor {
  "chunk_length": 30,
  "feature_extractor_type": "WhisperFeatureExtractor",
  "feature_size": 80,
  "hop_length": 160,
  "n_fft": 400,
  "n_samples": 480000,
  "nb_max_frames": 3000,
  "padding_side": "right",
  "padding_value": 0.0,
  "processor_class": "WhisperProcessor",
  "return_attention_mask": false,
  "sampling_rate": 16000
}

loading file vocab.json from cache at /root/.cache/huggingface/hub/models--openai--whisper-small/snapshots/973afd24965f72e36ca33b3055d56a652f456b4d/vocab.json
loading file tokenizer.json from cache at /root/.cache/huggingface/hub/models--openai--whisper-small/snapshots/973afd24965f72e36ca33b3055d56a652f456b4d/tokenizer.json
loading file merges.txt from cache at /root/.cache/huggingface/hub/models--openai--whisper-

학습 가능한 파라미터: 3,538,944 / 전체 파라미터: 245,273,856
학습 가능한 파라미터 비율: 1.44%

학습 가능한 파라미터 목록:
- base_model.model.model.encoder.layers.0.self_attn.k_proj.lora_A.default.weight: torch.Size([16, 768])
- base_model.model.model.encoder.layers.0.self_attn.k_proj.lora_B.default.weight: torch.Size([768, 16])
- base_model.model.model.encoder.layers.0.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 768])
- base_model.model.model.encoder.layers.0.self_attn.v_proj.lora_B.default.weight: torch.Size([768, 16])
- base_model.model.model.encoder.layers.0.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 768])
- base_model.model.model.encoder.layers.0.self_attn.q_proj.lora_B.default.weight: torch.Size([768, 16])
- base_model.model.model.encoder.layers.0.self_attn.out_proj.lora_A.default.weight: torch.Size([16, 768])
- base_model.model.model.encoder.layers.0.self_attn.out_proj.lora_B.default.weight: torch.Size([768, 16])
- base_model.model.model.encoder.layers.1.self_attn.k_proj.lora_A.default.weigh

Processing train chunk 1: 100%|██████████████████████████████| 1000/1000 [00:10<00:00, 92.74 examples/s]
Processing train chunk 2: 100%|██████████████████████████████| 1000/1000 [00:14<00:00, 70.58 examples/s]
Processing train chunk 3: 100%|██████████████████████████████| 1000/1000 [00:14<00:00, 68.71 examples/s]
Processing train chunk 4: 100%|█████████████████████████████| 1000/1000 [00:08<00:00, 121.65 examples/s]
Processing train chunk 5: 100%|█████████████████████████████| 1000/1000 [00:08<00:00, 119.28 examples/s]
Processing train chunk 6: 100%|█████████████████████████████| 1000/1000 [00:08<00:00, 124.31 examples/s]
Processing train chunk 7: 100%|█████████████████████████████| 1000/1000 [00:07<00:00, 126.65 examples/s]
Processing train chunk 8: 100%|█████████████████████████████| 1000/1000 [00:08<00:00, 123.58 examples/s]
Processing train chunk 9: 100%|█████████████████████████████| 1000/1000 [00:08<00:00, 111.43 examples/s]
Processing train chunk 10: 100%|███████████████████████

검증 데이터셋 처리 중...


Processing valid chunk 1: 100%|████████████████████████████████| 576/576 [00:06<00:00, 92.97 examples/s]
Saving the dataset (40/40 shards): 100%|██████████████████| 20623/20623 [00:30<00:00, 678.50 examples/s]
Saving the dataset (2/2 shards): 100%|████████████████████████| 576/576 [00:00<00:00, 862.44 examples/s]
PyTorch: setting up devices


처리된 데이터셋 저장 완료


max_steps is given, it will override any value given in num_train_epochs
Using auto half precision backend



최종 모델 상태:
학습 모드: True
학습 가능한 파라미터 수: 288
Optimizer: AdamW
Learning rate: 0.0
체크포인트가 없어 처음부터 시작합니다.


***** Running training *****
  Num examples = 20,623
  Num Epochs = 7
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 2
  Total optimization steps = 4,000
  Number of trainable parameters = 3,538,944
Traceback (most recent call last):
  File "/tmp/ipykernel_9822/3578381076.py", line 635, in train
    trainer.train()
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 2481, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 3579, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 3633, in comput

TorchRuntimeError: Failed running call_function <function cross_entropy at 0x7f053cd6c790>(*(FakeTensor(..., device='cuda:0', size=(800, 51865), dtype=torch.float16), FakeTensor(..., device='cuda:0', size=(784,), dtype=torch.int64), None, None, -100, None, 'mean', 0.0), **{}):
Expected input batch_size (800) to match target batch_size (784).

from user code:
   File "/usr/local/lib/python3.8/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1792, in torch_dynamo_resume_in_forward_at_1767
    loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
