In [None]:
# Asteroid provides the BaseModel.from_pretrained loader used by the training cell.
!pip install asteroid

In [None]:
# NOTE(review): torchcodec is not imported anywhere in the visible notebook;
# presumably it is needed as the audio decoding backend for torchaudio.load — confirm.
# NOTE(review): pyroomacoustics is imported by the training cell but no visible
# cell installs it — confirm it is already available in the runtime.
!pip install torchcodec

In [None]:
# Mount Google Drive so the dataset and checkpoint folders under MyDrive are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy the train/ and val/ folders from Drive into /content/realdata on the
# Colab local disk; the training cell reads its data from that local copy.
print("구글 드라이브에서 코랩 로컬 디스크로 데이터 복사 중...")
!mkdir -p /content/realdata
!cp -r /content/drive/MyDrive/dprnn_project/train /content/realdata/
!cp -r /content/drive/MyDrive/dprnn_project/val /content/realdata/
print("데이터 복사 완료.")

In [None]:
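# -----------------------------------------------------------------
# Subway announcement STT-optimised training code (Complete Version)
# - Target: announcement convolved with only the early part of the RIR
#   (everything up to 50 ms after the direct-path peak; late reverb removed)
# - Input: announcement convolved with the full RIR (100% reverb) + noise
# - Preprocessing: 100 Hz high-pass filter (removes low-frequency rumble)
# - Loss: Multi-Resolution STFT + SI-SNR + STOI
# -----------------------------------------------------------------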
import os
import random
import numpy as np
import scipy.signal
import torch
import torchaudio
import torchaudio.transforms as T
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from asteroid.models import BaseModel
from torch_stoi import NegSTOILoss
import pyroomacoustics as pra
import traceback

# -----------------------------------------------------------------
# [PART 1] 지하철 특화 RIR 생성 함수들
# -----------------------------------------------------------------

def _identity_rir(num_taps=256):
    """Return a unit impulse (pass-through RIR) of ``num_taps`` samples."""
    rir = np.zeros(num_taps, dtype=float)
    rir[0] = 1.0
    return rir


def _simulate_subway_rir(
    sr,
    *,
    label,
    length_range,
    width_range,
    height_range,
    absorption_range,
    max_order,
    end_margin,
    side_margin,
    source_min_height,
    mic_height_range,
):
    """Simulate one shoebox-room RIR with randomly drawn geometry.

    Shared implementation behind the three public ``generate_subway_*_rir``
    functions, which differ only in the numeric ranges they pass in.

    The ``np.random.uniform`` draws happen in a fixed order (length, width,
    height, absorption, source x/y/z, microphone x/y/z) so that a seeded NumPy
    RNG yields the same room as the original per-environment implementations.

    Args:
        sr: Sampling rate in Hz handed to the room simulator.
        label: Name printed in the error message when simulation fails.
        length_range, width_range, height_range: ``(low, high)`` bounds in
            metres for the room dimensions.
        absorption_range: ``(low, high)`` bounds for the wall energy
            absorption coefficient.
        max_order: Maximum reflection order passed to ``pra.ShoeBox``.
        end_margin: Minimum distance of source/microphone from the two walls
            along the room length.
        side_margin: Minimum distance of source/microphone from the two walls
            along the room width.
        source_min_height: Lowest source height; the highest is
            ``height - 0.1``.
        mic_height_range: ``(low, high)`` bounds for the microphone height.

    Returns:
        1-D float ``numpy.ndarray`` with the simulated RIR, or a 256-tap unit
        impulse if anything in the simulation raises.
    """
    try:
        length = float(np.random.uniform(*length_range))
        width = float(np.random.uniform(*width_range))
        height = float(np.random.uniform(*height_range))
        e_absorption = float(np.random.uniform(*absorption_range))

        room = pra.ShoeBox(
            [length, width, height],
            fs=int(sr),
            materials=pra.Material(e_absorption),
            max_order=max_order,
        )

        # The source is drawn before the microphone; keep this order so seeded
        # runs stay reproducible.
        source_pos = np.array([
            np.random.uniform(end_margin, length - end_margin),
            np.random.uniform(side_margin, width - side_margin),
            np.random.uniform(source_min_height, height - 0.1),
        ], dtype=float)

        mic_pos = np.array([
            np.random.uniform(end_margin, length - end_margin),
            np.random.uniform(side_margin, width - side_margin),
            np.random.uniform(*mic_height_range),
        ], dtype=float).reshape(3, 1)

        room.add_source(source_pos.tolist())
        room.add_microphone_array(pra.MicrophoneArray(mic_pos, room.fs))

        room.compute_rir()
        # room.rir is indexed [microphone][source]; there is exactly one of each.
        return np.array(room.rir[0][0], dtype=float)

    except Exception as e:
        # Best effort: a failed simulation must not abort data loading, so fall
        # back to a pass-through impulse response.
        print(f"{label} error: {e}")
        return _identity_rir()


def generate_subway_car_rir(sr=16000):
    """Simulate an RIR for a subway car (15-20 m x 2.8-3.2 m x 2.3-2.5 m).

    Args:
        sr: Sampling rate in Hz.

    Returns:
        1-D float ``numpy.ndarray``; a 256-tap unit impulse on failure.
    """
    return _simulate_subway_rir(
        sr,
        label="generate_subway_car_rir",
        length_range=(15, 20),
        width_range=(2.8, 3.2),
        height_range=(2.3, 2.5),
        absorption_range=(0.15, 0.25),
        max_order=3,
        end_margin=2,
        side_margin=0.5,
        source_min_height=2.0,
        mic_height_range=(1.2, 1.7),
    )


def generate_subway_platform_rir(sr=16000):
    """Simulate an RIR for a subway platform (100-150 m x 10-15 m x 3.5-4.5 m).

    Args:
        sr: Sampling rate in Hz.

    Returns:
        1-D float ``numpy.ndarray``; a 256-tap unit impulse on failure.
    """
    return _simulate_subway_rir(
        sr,
        label="generate_subway_platform_rir",
        length_range=(100, 150),
        width_range=(10, 15),
        height_range=(3.5, 4.5),
        absorption_range=(0.05, 0.12),
        max_order=5,
        end_margin=10,
        side_margin=2,
        source_min_height=3.0,
        mic_height_range=(1.4, 1.8),
    )


def generate_subway_corridor_rir(sr=16000):
    """Simulate an RIR for a subway corridor (30-50 m x 3-5 m x 2.5-3.0 m).

    Args:
        sr: Sampling rate in Hz.

    Returns:
        1-D float ``numpy.ndarray``; a 256-tap unit impulse on failure.
    """
    return _simulate_subway_rir(
        sr,
        label="generate_subway_corridor_rir",
        length_range=(30, 50),
        width_range=(3, 5),
        height_range=(2.5, 3.0),
        absorption_range=(0.10, 0.18),
        max_order=4,
        end_margin=5,
        side_margin=0.5,
        source_min_height=2.0,
        mic_height_range=(1.4, 1.7),
    )

def generate_random_subway_rir(sr=16000):
    """Return an RIR from one subway environment picked uniformly at random.

    The candidates are the car, platform and corridor generators; the pick
    uses the ``random`` module.

    Args:
        sr: Sampling rate in Hz, forwarded to the chosen generator.

    Returns:
        Whatever the chosen generator returns (a 1-D float array).
    """
    generator = random.choice((
        generate_subway_car_rir,
        generate_subway_platform_rir,
        generate_subway_corridor_rir,
    ))
    return generator(sr)

def apply_rir_split(waveform, rir, sr=16000, early_ms=50):
    """Convolve a waveform with an RIR twice: full response and early part only.

    Both outputs go through the same leading part of the RIR, so they share
    the same onset delay; the "early" output simply lacks everything later
    than ``early_ms`` after the strongest tap.

    Args:
        waveform: Dry signal, a ``torch.Tensor`` or array-like. It is squeezed
            to 1-D before convolution.
        rir: Room impulse response as a 1-D NumPy array.
        sr: Sampling rate in Hz, used to convert ``early_ms`` to samples.
        early_ms: How many milliseconds after the strongest RIR tap are kept
            in the early response.

    Returns:
        ``(out_full, out_early)``: two float tensors of shape ``(1, N)`` where
        ``N`` is the length of the squeezed input. ``out_full`` uses the whole
        energy-normalised RIR, ``out_early`` uses the truncated one. Both are
        clipped to ``[-1, 1]``. If anything raises, the error is printed and
        the input itself is returned for both outputs (with a leading
        dimension added when it is 1-D).
    """
    try:
        # Bring the input to a 1-D float32 NumPy array.
        if isinstance(waveform, torch.Tensor):
            dry = waveform.detach().cpu().squeeze().numpy().astype(np.float32)
        else:
            dry = np.array(waveform).squeeze().astype(np.float32)

        # Normalise the RIR to (approximately) unit energy.
        rir = rir / np.sqrt(np.sum(rir**2) + 1e-8)

        # Keep every tap up to `early_ms` past the strongest one (taken as the
        # direct sound); later taps are zeroed.
        direct_idx = np.argmax(np.abs(rir))
        cutoff = min(len(rir), direct_idx + int(sr * (early_ms / 1000.0)))
        early_rir = np.zeros_like(rir)
        early_rir[:cutoff] = rir[:cutoff]

        # 'full' mode keeps the start of the signal intact; the tail beyond
        # the input length is cut off afterwards.
        wet_full = scipy.signal.fftconvolve(dry, rir, mode='full')
        wet_early = scipy.signal.fftconvolve(dry, early_rir, mode='full')
        num_samples = len(dry)

        def _finalise(wet):
            # Trim to the input length, clip, and return as a (1, N) tensor.
            trimmed = np.clip(wet[:num_samples], -1.0, 1.0)
            return torch.from_numpy(trimmed).float().unsqueeze(0)

        out_full = _finalise(wet_full)
        out_early = _finalise(wet_early)
        return out_full, out_early

    except Exception as e:
        print(f"apply_rir_split error: {e}")
        # Fall back to the unprocessed input for both outputs.
        if isinstance(waveform, torch.Tensor):
            passthrough = waveform
        else:
            passthrough = torch.tensor(waveform).float()
        if passthrough.dim() == 1:
            passthrough = passthrough.unsqueeze(0)
        return passthrough, passthrough

# -----------------------------------------------------------------
# [PART 2] Babble Noise 생성
# -----------------------------------------------------------------

def create_babble_noise(speech_files, target_len, sr=16000, num_speakers_range=(3, 7)):
    """Mix several randomly chosen speech files into one babble-noise track.

    Args:
        speech_files: List of audio file paths to draw speakers from.
        target_len: Length of the returned noise in samples.
        sr: Sampling rate in Hz; files at another rate are resampled.
        num_speakers_range: Inclusive ``(min, max)`` for the number of
            speakers drawn with ``random.randint``.

    Returns:
        Float tensor of shape ``(1, target_len)``, peak-normalised to 0.9.
        All zeros when ``speech_files`` is empty or no file could be loaded.
    """
    if not speech_files:
        return torch.zeros(1, target_len)

    num_speakers = random.randint(*num_speakers_range)
    if len(speech_files) < num_speakers:
        # Too few files for a draw without replacement: allow repeats.
        selected_files = random.choices(speech_files, k=num_speakers)
    else:
        selected_files = random.sample(speech_files, num_speakers)

    mixed_babble = torch.zeros(1, target_len)
    speaker_count = 0

    for file_path in selected_files:
        try:
            wav, file_sr = torchaudio.load(file_path)
            if file_sr != sr:
                resampler = T.Resample(file_sr, sr)
                wav = resampler(wav)
            if wav.dim() == 2 and wav.size(0) > 1:
                # Down-mix multi-channel audio to mono.
                wav = wav.mean(dim=0, keepdim=True)

            # Random crop when the file is long enough, zero-pad otherwise.
            if wav.size(1) > target_len:
                start = random.randint(0, wav.size(1) - target_len)
                chunk = wav[:, start : start + target_len]
            else:
                chunk = F.pad(wav, (0, max(0, target_len - wav.size(1))))

            # Speed perturbation by linear resampling of the chunk. A chunk
            # that gets shorter (speed > 1) is zero-padded at the end; one
            # that gets longer is truncated.
            speed_change = random.uniform(0.9, 1.1)
            if abs(speed_change - 1.0) > 1e-6:
                new_len = max(1, int(target_len / speed_change))
                chunk = F.interpolate(chunk.unsqueeze(0), size=new_len, mode='linear', align_corners=False).squeeze(0)
                if chunk.size(1) > target_len:
                    chunk = chunk[:, :target_len]
                else:
                    chunk = F.pad(chunk, (0, target_len - chunk.size(1)))

            # Bring every speaker to unit RMS before applying a random gain.
            chunk_rms = torch.sqrt(torch.mean(chunk**2) + 1e-8)
            chunk = chunk / (chunk_rms + 1e-8)

            gain = random.uniform(0.5, 1.0)
            mixed_babble += (chunk * gain)
            speaker_count += 1

        except Exception as e:
            # Best effort: one unreadable file must not abort the sample, but
            # report it so broken data does not go unnoticed.
            print(f"create_babble_noise: skipping {file_path}: {e}")
            continue

    if speaker_count > 0:
        mixed_babble = mixed_babble / speaker_count

    # Peak-normalise unless the mix is (near) silent.
    max_val = torch.max(torch.abs(mixed_babble))
    if max_val > 1e-6:
        mixed_babble = mixed_babble / (max_val + 1e-8) * 0.9

    return mixed_babble

# -----------------------------------------------------------------
# [PART 3] 손실 함수 (Multi-Resolution STFT + SI-SNR + STOI)
# -----------------------------------------------------------------

# Sampling rate in Hz handed to the module-level STOI loss instance.
SAMPLE_RATE = 16000

# 1. Multi-Resolution STFT Loss Class
class MultiResolutionSTFTLoss(nn.Module):
    """Spectral-convergence + log-magnitude L1 loss averaged over several STFTs.

    Args:
        fft_sizes: FFT size per resolution.
        hop_sizes: Hop length per resolution.
        win_lengths: Hann window length per resolution.

    Raises:
        ValueError: If the three sequences are empty or differ in length.
    """

    def __init__(self, fft_sizes=(1024, 2048, 512), hop_sizes=(120, 240, 60), win_lengths=(600, 1200, 240)):
        super().__init__()
        if not (len(fft_sizes) == len(hop_sizes) == len(win_lengths)):
            raise ValueError(
                "fft_sizes, hop_sizes and win_lengths must have the same length, "
                f"got {len(fft_sizes)}, {len(hop_sizes)} and {len(win_lengths)}"
            )
        if len(fft_sizes) == 0:
            raise ValueError("at least one STFT resolution is required")
        # Store private copies so instances never share (or mutate) the
        # caller's sequences or the defaults.
        self.fft_sizes = list(fft_sizes)
        self.hop_sizes = list(hop_sizes)
        self.win_lengths = list(win_lengths)

    def stft(self, x, fft_size, hop_size, win_length):
        """Return the complex STFT of ``x`` using a Hann window."""
        # Create the window directly on the input's device instead of
        # building it on CPU and copying it over on every call.
        window = torch.hann_window(win_length, device=x.device)
        return torch.stft(x, n_fft=fft_size, hop_length=hop_size, win_length=win_length,
                          window=window, return_complex=True)

    def forward(self, x, y):
        """Compute the loss between estimate ``x`` and target ``y``.

        Both are passed to ``torch.stft`` as-is, so they must be waveforms it
        accepts and long enough for the largest FFT size.

        Returns:
            Scalar tensor: mean over resolutions of spectral convergence plus
            log-magnitude L1.
        """
        loss = 0.0
        for fs, hs, wl in zip(self.fft_sizes, self.hop_sizes, self.win_lengths):
            x_stft = self.stft(x, fs, hs, wl)
            y_stft = self.stft(y, fs, hs, wl)

            # Small offset keeps the logarithm below finite.
            x_mag = torch.abs(x_stft) + 1e-7
            y_mag = torch.abs(y_stft) + 1e-7

            # Spectral convergence: relative Frobenius error of the magnitudes,
            # taken over the whole input at once.
            sc_loss = torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + 1e-7)
            # Log-magnitude L1 distance.
            mag_loss = F.l1_loss(torch.log(y_mag), torch.log(x_mag))

            loss += sc_loss + mag_loss
        return loss / len(self.fft_sizes)

# Module-level loss instances; main_train() moves them to the training device.
mr_stft_loss_func = MultiResolutionSTFTLoss()
stoi_loss_func = NegSTOILoss(sample_rate=SAMPLE_RATE)

def si_snr(estimate, target, epsilon=1e-8):
    """Scale-invariant signal-to-noise ratio in dB, computed along the last axis.

    Args:
        estimate: Estimated waveform(s); a 1-D input is treated as a batch of one.
        target: Reference waveform(s), same layout as ``estimate``.
        epsilon: Stabiliser added to denominators and inside the logarithm.

    Returns:
        Tensor of SI-SNR values with all singleton dimensions squeezed away
        (a 0-d tensor for a single signal).
    """
    est = estimate.unsqueeze(0) if estimate.dim() == 1 else estimate
    ref = target.unsqueeze(0) if target.dim() == 1 else target

    # Remove the DC offset from both signals.
    est = est - est.mean(dim=-1, keepdim=True)
    ref = ref - ref.mean(dim=-1, keepdim=True)

    # Project the estimate onto the reference; the remainder counts as noise.
    inner = (est * ref).sum(dim=-1, keepdim=True)
    ref_power = ref.pow(2).sum(dim=-1, keepdim=True)
    projected = (inner / (ref_power + epsilon)) * ref
    residual = est - projected

    ratio = projected.pow(2).sum(dim=-1) / (residual.pow(2).sum(dim=-1) + epsilon)
    return (10 * torch.log10(ratio + epsilon)).squeeze()

def combined_loss_stt_optimized(estimate, target, epsilon=1e-8):
    """STT-oriented training loss with separate handling of silent targets.

    Samples whose target energy is below ``1e-6`` are treated as silent and
    scored with plain L1 (push the output towards silence). All other samples
    are scored with

        0.5 * multi-resolution STFT + 0.2 * (-SI-SNR) + 0.3 * (negative STOI)

    The speech terms are evaluated only on the non-silent samples and the L1
    term only on the silent ones. Previously both were computed over the whole
    batch and then selected per sample, so silent targets leaked into the
    speech terms (their SI-SNR evaluates to roughly -80 dB) and speech samples
    leaked into the L1 term.

    Args:
        estimate: Model output; a 1-D input is treated as a batch of one.
        target: Reference waveform(s), same layout as ``estimate``.
        epsilon: Stabiliser forwarded to ``si_snr``.

    Returns:
        Scalar tensor: the two group losses weighted by how many samples fall
        into each group, divided by the number of samples.
    """
    if estimate.dim() == 1:
        estimate = estimate.unsqueeze(0)
    if target.dim() == 1:
        target = target.unsqueeze(0)

    target_energy = torch.sum(target**2, dim=-1)
    is_silent_mask = (target_energy < 1e-6)
    is_speech_mask = ~is_silent_mask

    total = estimate.new_zeros(())

    if is_speech_mask.any():
        speech_est = estimate[is_speech_mask]
        speech_target = target[is_speech_mask]

        # Spectral loss on magnitudes.
        loss_freq = mr_stft_loss_func(speech_est, speech_target)
        # Negative SI-SNR (higher SI-SNR is better).
        loss_sisnr = -torch.mean(si_snr(speech_est, speech_target, epsilon))
        # Negative-STOI loss from the module-level instance.
        loss_stoi = torch.mean(stoi_loss_func(speech_est, speech_target))

        loss_speech = (0.5 * loss_freq) + (0.2 * loss_sisnr) + (0.3 * loss_stoi)
        total = total + loss_speech * is_speech_mask.sum()

    if is_silent_mask.any():
        loss_l1 = F.l1_loss(estimate[is_silent_mask], target[is_silent_mask])
        total = total + loss_l1 * is_silent_mask.sum()

    # Same weighting as a per-sample mean over the batch.
    return total / is_silent_mask.numel()

# -----------------------------------------------------------------
# [PART 4] Dataset (High-pass Filter + Clean Target)
# -----------------------------------------------------------------

class SubwaySTTDataset(Dataset):
    """On-the-fly synthetic dataset of noisy reverberant announcements.

    Every item is generated randomly: an announcement chunk is convolved with
    a simulated subway RIR (full response for the input, early part only for
    the target), mixed with environment and/or babble noise at a random SNR,
    and both signals are normalised to unit RMS. A fraction of the items
    contain noise only, with an all-zero target.

    Args:
        announcement_dir: Folder with clean announcement audio (.wav/.mp3).
        noise_dir: Folder with environment-noise audio.
        speech_dir: Folder with speech used for babble noise; may be empty or
            falsy, in which case no babble is generated.
        sample_rate: Working sampling rate in Hz.
        chunk_seconds: Length of every generated item in seconds.
        snr_range: ``(low, high)`` SNR in dB drawn uniformly per item.
        epoch_len: Value reported by ``__len__``.
        noise_only_ratio: Probability of generating a noise-only item.
        target_rir_strength: Stored but not used by the current pipeline.
        input_rir_strength: Stored but not used by the current pipeline.
        is_validation: When True, item ``idx`` is generated from RNGs seeded
            with ``idx`` so the same index always yields the same item.
    """

    def __init__(
        self,
        announcement_dir,
        noise_dir,
        speech_dir,
        sample_rate=16000,
        chunk_seconds=10.0,
        snr_range=(-5.0, 5.0),
        epoch_len=2000,
        noise_only_ratio=0.1,
        target_rir_strength=0.0,
        input_rir_strength=1.0,
        is_validation=False
    ):
        self.announcement_files = self._get_files(announcement_dir)
        self.noise_files = self._get_files(noise_dir)
        self.speech_files = self._get_files(speech_dir) if speech_dir else []
        self.has_speech = len(self.speech_files) > 0

        self.sample_rate = int(sample_rate)
        self.chunk_seconds = float(chunk_seconds)
        self.chunk_samples = int(self.chunk_seconds * self.sample_rate)
        self.snr_range = snr_range
        self.epoch_len = int(epoch_len)
        self.noise_only_ratio = float(noise_only_ratio)
        # Kept for interface compatibility; nothing below reads these two.
        self.target_rir_strength = float(target_rir_strength)
        self.input_rir_strength = float(input_rir_strength)
        self.is_validation = is_validation

        # The 100 Hz high-pass filter is applied per file in _load_wav_chunk.

        assert len(self.noise_files) > 0, f"Noise 파일 없음: {noise_dir}"
        assert len(self.announcement_files) > 0, f"Announcement 파일 없음: {announcement_dir}"

    def _get_files(self, path):
        """Return the sorted .mp3/.wav file paths directly inside ``path``.

        Sorting matters: ``os.listdir`` order is unspecified, and the seeded
        ``random.choice`` used for validation is only reproducible when the
        file list has a stable order.
        """
        if path and os.path.exists(path):
            return sorted(
                os.path.join(path, f)
                for f in os.listdir(path)
                if f.lower().endswith((".mp3", ".wav"))
            )
        return []

    def __len__(self):
        """Return the configured number of items per epoch."""
        return self.epoch_len

    def _load_wav_chunk(self, path, length):
        """Load ``path`` as mono audio of exactly ``length`` samples.

        The file is resampled to ``self.sample_rate`` if needed, down-mixed to
        mono, high-pass filtered at 100 Hz, then randomly cropped or
        zero-padded.

        Returns:
            Tensor of shape ``(1, length)``; all zeros if loading fails.
        """
        try:
            wav, file_sr = torchaudio.load(path)

            if file_sr != self.sample_rate:
                wav = T.Resample(file_sr, self.sample_rate)(wav)

            if wav.dim() == 2 and wav.size(0) > 1:
                wav = wav.mean(dim=0, keepdim=True)

            # Remove low-frequency rumble that is of no use for STT.
            wav = torchaudio.functional.highpass_biquad(wav, self.sample_rate, cutoff_freq=100)

            if wav.size(1) > length:
                start = random.randint(0, wav.size(1) - length)
                wav = wav[:, start:start + length]
            elif wav.size(1) < length:
                wav = F.pad(wav, (0, length - wav.size(1)))

            return wav

        except Exception as e:
            print(f"_load_wav_chunk failed for {path}: {e}")
            return torch.zeros(1, length)

    def _scale_mix(self, clean, noise, snr_db):
        """Return ``noise`` rescaled so that clean-to-noise RMS ratio is ``snr_db`` dB."""
        rms_c = torch.sqrt(torch.mean(clean ** 2) + 1e-8)
        rms_n = torch.sqrt(torch.mean(noise ** 2) + 1e-8)
        scale = (rms_c / (10 ** (snr_db / 20))) / (rms_n + 1e-8)
        return noise * scale

    def __getitem__(self, idx):
        """Generate one ``(mixture, target, target_rms)`` item.

        In validation mode the ``random``, NumPy and torch (CPU) generators
        are seeded with ``idx`` for the duration of the call and restored
        afterwards. Without the restore, running validation in the main
        process would leave the global RNGs in the same state after every
        epoch and training would replay the same random stream each time.

        Returns:
            mixture: 1-D tensor, unit-RMS noisy reverberant input.
            target: 1-D tensor, unit-RMS early-reverb announcement, or zeros
                for a noise-only item.
            target_rms: 0-d tensor, about 1 for items with an announcement and
                0 for noise-only items.
        """
        if not self.is_validation:
            return self._build_example()

        py_state = random.getstate()
        np_state = np.random.get_state()
        torch_state = torch.get_rng_state()
        try:
            seed = idx
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            return self._build_example()
        finally:
            # NOTE(review): torch.manual_seed may also reseed CUDA generators;
            # only the CPU generator state is restored here.
            random.setstate(py_state)
            np.random.set_state(np_state)
            torch.set_rng_state(torch_state)

    def _build_example(self):
        """Draw one random item using the current global RNG state."""
        is_announcement = random.random() > self.noise_only_ratio

        # 1. Announcement (already high-pass filtered by the loader).
        if is_announcement:
            ann_path = random.choice(self.announcement_files)
            clean_ann = self._load_wav_chunk(ann_path, self.chunk_samples)

            shared_rir = generate_random_subway_rir(self.sample_rate)

            # Input gets the full RIR; target keeps only the first 50 ms after
            # the direct sound, so both share the same onset delay.
            input_ann, target_chunk = apply_rir_split(
                clean_ann, shared_rir, sr=self.sample_rate, early_ms=50
            )

            target_rms = torch.sqrt(torch.mean(target_chunk ** 2)).squeeze()
        else:
            target_chunk = torch.zeros(1, self.chunk_samples)
            input_ann = torch.zeros_like(target_chunk)
            target_rms = torch.tensor(0.0)

        # 2. Noise (already high-pass filtered by the loader).
        r_case = random.random()
        env_path = random.choice(self.noise_files)
        noise_env = self._load_wav_chunk(env_path, self.chunk_samples)

        noise_speech = torch.zeros_like(noise_env)
        if self.has_speech:
            noise_speech = create_babble_noise(self.speech_files, self.chunk_samples, self.sample_rate)

            # Babble gets its own room; only the fully reverberant output is used.
            noise_rir = generate_random_subway_rir(self.sample_rate)
            noise_speech, _ = apply_rir_split(noise_speech, noise_rir, sr=self.sample_rate)

        # 40%: environment only, 50%: environment + babble, 10%: babble only.
        if r_case < 0.4:
            final_noise = noise_env
        elif r_case < 0.9:
            final_noise = noise_env + (noise_speech * random.uniform(0.5, 1.2))
        else:
            final_noise = noise_speech

        # 3. Mixing.
        if is_announcement:
            snr_db = random.uniform(*self.snr_range)
            noise_scaled = self._scale_mix(input_ann, final_noise, snr_db)
            mixture = input_ann + noise_scaled
        else:
            # The added 1e-8 keeps rms_noise strictly positive, so the
            # division is always safe.
            rms_noise = torch.sqrt(torch.mean(final_noise ** 2) + 1e-8)
            mixture = final_noise * (random.uniform(0.5, 1.0) / rms_noise)

        # 4. Peak limiting; the target is scaled by the same factor.
        max_val = torch.max(torch.abs(mixture))
        if max_val > 0.99:
            scale = 0.99 / max_val
            mixture = mixture * scale
            target_chunk = target_chunk * scale
            target_rms = target_rms * scale

        # 5. Independent RMS normalisation of mixture and target.
        eps = 1e-8
        mixture_rms = torch.sqrt(torch.mean(mixture**2) + eps)
        target_rms_val = torch.sqrt(torch.mean(target_chunk**2) + eps)

        mixture = mixture / mixture_rms
        target_chunk = target_chunk / target_rms_val
        target_rms = target_rms / (target_rms_val + eps)

        return mixture.squeeze(0), target_chunk.squeeze(0), target_rms

# -----------------------------------------------------------------
# [PART 5] 학습 및 검증 Main Loop
# -----------------------------------------------------------------

# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Root of the local dataset copy; main_train() reads train/ and val/ below it.
BASE_DRIVE_PATH = "/content/realdata"
# Folder on Google Drive that receives the best model and per-epoch checkpoints.
MODEL_SAVE_PATH = "/content/drive/MyDrive/dprnn_project/realcheckpoints_stt_opt"
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

BATCH_SIZE = 4
LEARNING_RATE = 5e-5
EPOCHS = 100
# Number of generated items per training / validation epoch.
TRAIN_EPOCH_LEN = 2000
VAL_EPOCH_LEN = 200
# When True, main_train() resumes from the highest-numbered checkpoint in MODEL_SAVE_PATH.
LOAD_CHECKPOINT = True

def _extract_waveform(model_output):
    """Reduce a model's forward output to a waveform tensor.

    Dict outputs yield their "waveform" entry (or the first value), list/tuple
    outputs yield their first element, and a singleton channel dimension of a
    3-D tensor is squeezed away.
    """
    if isinstance(model_output, dict):
        model_output = model_output.get("waveform", list(model_output.values())[0])
    elif isinstance(model_output, (list, tuple)):
        model_output = model_output[0]
    if model_output.dim() == 3 and model_output.size(1) == 1:
        model_output = model_output.squeeze(1)
    return model_output


def main_train():
    """Fine-tune the pre-trained DCCRNet model on synthetic subway data.

    Builds the training and validation datasets, optionally resumes from the
    latest checkpoint in ``MODEL_SAVE_PATH``, then for every epoch trains,
    validates (loss, SI-SNR, SI-SNR improvement and STOI on items that contain
    speech), steps the LR scheduler on the validation loss, saves the best
    model by validation loss and writes a full checkpoint.
    """
    print(f"사용 장치: {device}")
    print("타깃: Clean (무잔향 0%)")
    print("입력: 강한 잔향 (100%) + 소음 + HPF Filter(100Hz)")
    print("손실함수: STFT(0.5) + SI-SNR(0.2) + STOI(0.3)")

    train_dataset = SubwaySTTDataset(
        f"{BASE_DRIVE_PATH}/train/announcement_clean",
        f"{BASE_DRIVE_PATH}/train/noise_environment",
        f"{BASE_DRIVE_PATH}/train/noise_speech",
        epoch_len=TRAIN_EPOCH_LEN,
        noise_only_ratio=0.1,
        target_rir_strength=0.0,
        input_rir_strength=1.0
    )
    val_dataset = SubwaySTTDataset(
        f"{BASE_DRIVE_PATH}/val/announcement_clean",
        f"{BASE_DRIVE_PATH}/val/noise_environment",
        f"{BASE_DRIVE_PATH}/val/noise_speech",
        epoch_len=VAL_EPOCH_LEN,
        noise_only_ratio=0.1,
        target_rir_strength=0.0,
        input_rir_strength=1.0,
        is_validation=True
    )

    import platform
    # NOTE(review): this tests for an environment variable named exactly
    # "COLAB"; if the Colab runtime only sets prefixed names (COLAB_*), the
    # check never matches there and 2 workers are used — confirm the intent.
    num_workers = 0 if "COLAB" in os.environ or platform.system() == "Windows" else 2

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=num_workers, pin_memory=torch.cuda.is_available()
    )
    val_loader = DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=num_workers, pin_memory=torch.cuda.is_available()
    )

    # Model
    print("-> Pre-trained Model 로드 중...")
    model = BaseModel.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k").to(device)

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)

    # The loss modules are module-level globals read by the loss function, so
    # they are moved to the device in place of being passed around.
    global mr_stft_loss_func, stoi_loss_func
    mr_stft_loss_func = mr_stft_loss_func.to(device)
    stoi_loss_func = stoi_loss_func.to(device)

    best_val_loss = float('inf')
    start_epoch = 1

    if LOAD_CHECKPOINT:
        checkpoint_files = [f for f in os.listdir(MODEL_SAVE_PATH) if f.startswith("checkpoint_epoch_")]
        if checkpoint_files:
            try:
                # File names end in "_<epoch>.pth"; resume from the highest epoch.
                latest_epoch = max(int(f.split('_')[-1].split('.')[0]) for f in checkpoint_files)
                latest_checkpoint_path = os.path.join(MODEL_SAVE_PATH, f"checkpoint_epoch_{latest_epoch}.pth")

                checkpoint = torch.load(latest_checkpoint_path, map_location=device)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                best_val_loss = checkpoint['best_val_loss']
                start_epoch = checkpoint['epoch'] + 1

                print(f"체크포인트 로드: Epoch {checkpoint['epoch']} → {start_epoch}부터 재개")
            except Exception as e:
                print(f"체크포인트 로드 실패: {e}")

    for epoch in range(start_epoch, EPOCHS + 1):
        # 1. Training
        model.train()
        train_loss_acc = 0.0
        train_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch} [Train]")
        for mixture, target, target_rms in pbar:
            mixture = mixture.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            # The model is fed (batch, 1, time).
            model_in = mixture.unsqueeze(1) if mixture.dim() == 2 else mixture
            estimated = _extract_waveform(model(model_in))

            loss = combined_loss_stt_optimized(estimated, target)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            train_loss_acc += loss.item()
            train_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        avg_train_loss = train_loss_acc / (train_batches if train_batches > 0 else 1)

        # 2. Validation
        model.eval()
        val_loss_acc = 0.0
        sisnr_sum, sisnri_sum, stoi_sum = 0.0, 0.0, 0.0
        speech_sample_count = 0
        # STOI is counted separately so that a batch whose STOI computation
        # fails does not dilute the average.
        stoi_sample_count = 0
        val_batches = 0

        with torch.no_grad():
            for mixture, target, target_rms in tqdm(val_loader, desc=f"Epoch {epoch} [Val]"):
                mixture = mixture.to(device)
                target = target.to(device)
                target_rms = target_rms.to(device)

                model_in = mixture.unsqueeze(1) if mixture.dim() == 2 else mixture
                estimated = _extract_waveform(model(model_in))

                loss = combined_loss_stt_optimized(estimated, target)
                val_loss_acc += loss.item()

                # Metrics are reported on items that contain speech only.
                is_speech_mask = (target_rms > 1e-6)
                if is_speech_mask.any():
                    speech_est = estimated[is_speech_mask]
                    speech_target = target[is_speech_mask]
                    speech_mix = mixture[is_speech_mask]

                    # SI-SNR of the estimate and its improvement over the raw mixture.
                    sisnr_est = si_snr(speech_est, speech_target)
                    sisnr_init = si_snr(speech_mix, speech_target)

                    sisnr_sum += float(torch.sum(sisnr_est).item())
                    sisnri_sum += float(torch.sum(sisnr_est - sisnr_init).item())

                    # STOI (the loss module returns its negative).
                    try:
                        stoi_vals = -stoi_loss_func(speech_est, speech_target)
                    except Exception as e:
                        print(f"STOI computation failed; batch excluded from STOI average: {e}")
                    else:
                        if isinstance(stoi_vals, torch.Tensor):
                            stoi_sum += float(torch.sum(stoi_vals).item())
                        else:
                            stoi_sum += float(stoi_vals)
                        stoi_sample_count += int(speech_target.size(0))

                    speech_sample_count += int(speech_target.size(0))
                val_batches += 1

        # Averages
        avg_loss = val_loss_acc / (val_batches if val_batches > 0 else 1)
        if speech_sample_count > 0:
            avg_sisnr = sisnr_sum / speech_sample_count
            avg_sisnri = sisnri_sum / speech_sample_count
        else:
            avg_sisnr = avg_sisnri = 0.0
        avg_stoi = stoi_sum / stoi_sample_count if stoi_sample_count > 0 else 0.0

        print(f"\n[Epoch {epoch}] Results:")
        print(f"    Train Loss: {avg_train_loss:.4f}")
        print(f"    Total Loss: {avg_loss:.4f}")
        print(f"    SI-SNR: {avg_sisnr:.2f} dB")
        print(f"    SI-SNRi: {avg_sisnri:.2f} dB")
        print(f"    STOI: {avg_stoi:.4f}")

        scheduler.step(avg_loss)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"    Current LR: {current_lr:.2e}")

        # Keep the model with the lowest validation loss.
        if avg_loss < best_val_loss:
            best_val_loss = avg_loss
            torch.save(model.state_dict(), os.path.join(MODEL_SAVE_PATH, "best_model_stt_final.pth"))
            print(f"    --> ★ Best Model Saved! (Loss: {avg_loss:.4f})")

        # Full checkpoint for resuming.
        checkpoint_path = os.path.join(MODEL_SAVE_PATH, f"checkpoint_epoch_{epoch}.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_val_loss,
        }, checkpoint_path)
        print(f"    Checkpoint saved: {checkpoint_path}")

# Start training when this code is executed as the main module.
if __name__ == "__main__":
    main_train()