In [None]:
# -----------------------------------------------------------------------------
# 1. Install and import libraries
# -----------------------------------------------------------------------------
print("--- 1. 라이브러리 설치 및 임포트 중... ---")
!pip install -q asteroid pyunpack patool torch-stoi

import os
import random
import numpy as np
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader, Subset
from pathlib import Path
from tqdm import tqdm
from pyunpack import Archive

# Google Drive
from google.colab import drive

# Asteroid & Custom Loss
from asteroid.models import ConvTasNet
from asteroid.losses import PITLossWrapper
from torch_stoi import NegSTOILoss

print("라이브러리 준비 완료")

# -----------------------------------------------------------------------------
# 2. Paths and hyperparameters (global configuration)
# -----------------------------------------------------------------------------
# Mount Google Drive (skipped when /content/drive already exists)
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# --- Paths ---
# These must be edited to match your own environment.
# NOTE: os.makedirs(MODEL_PATH) further down raises FileNotFoundError while
# MODEL_PATH is still the empty-string placeholder.
DRIVE_PATH = "" # path must be set (not referenced anywhere else in the code shown here)
MODEL_PATH = "" # path must be set; checkpoints are written here and ZIPs are read from MODEL_PATH/zips
ZIP_PATH = os.path.join(MODEL_PATH, "zips")
LOCAL_DATA_PATH = "/content/"
TRAIN_DATA_PATH = os.path.join(LOCAL_DATA_PATH, "train/")
VAL_DATA_PATH = os.path.join(LOCAL_DATA_PATH, "val/")

# File names used when saving the model
CHECKPOINT_PATH = os.path.join(MODEL_PATH, "convtasnet_latest_checkpoint.pth")
BEST_MODEL_PATH = os.path.join(MODEL_PATH, "convtasnet_best_model.pth")

# --- Hyperparameters ---
NUM_EPOCHS = 50
BATCH_SIZE = 2
SAMPLE_RATE = 16000
SEGMENT_DURATION_SEC = 3
TRAIN_STEPS_PER_EPOCH = 1000   # number of batches generated per training epoch
MAX_VAL_SAMPLES = 200          # fixed number of samples used for validation
EPSILON = 1e-8
LOSS_LAMBDA_V3 = 0.8           # weight of the SI-SNR term (the STOI term gets 1 - this)

# GPU setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"사용할 장치: {device}")

# Create the local folders
os.makedirs(TRAIN_DATA_PATH, exist_ok=True)
os.makedirs(VAL_DATA_PATH, exist_ok=True)
os.makedirs(MODEL_PATH, exist_ok=True)

# -----------------------------------------------------------------------------
# 3. 데이터 압축 해제 (ZIP 처리)
# -----------------------------------------------------------------------------
def extract_zip_safe_py(zip_file):
    """Extract one archive into a data directory chosen from its file name.

    Routing rules (first match wins):
      * ``train_안내방송.zip``          -> ``TRAIN_DATA_PATH/안내방송``
      * name contains "train" (any case) -> ``TRAIN_DATA_PATH``
      * name contains "val" (any case)   -> ``VAL_DATA_PATH``
      * anything else                    -> ``LOCAL_DATA_PATH``

    Extraction is best-effort: a failure is reported with a warning and the
    function returns normally so the remaining archives can still be handled.

    Args:
        zip_file: Path of the archive to extract.
    """
    import unicodedata

    # Hangul file names may arrive in decomposed (NFD) form from some file
    # systems / sync tools; such a name never compares equal to a composed
    # literal. Normalise both sides to NFC before the exact-match test.
    zip_name = unicodedata.normalize("NFC", os.path.basename(zip_file))
    announcement_zip = unicodedata.normalize("NFC", "train_안내방송.zip")

    if zip_name == announcement_zip:
        dst_dir = os.path.join(TRAIN_DATA_PATH, "안내방송")
    else:
        zip_name_lower = zip_name.lower()
        if "train" in zip_name_lower:
            dst_dir = TRAIN_DATA_PATH
        elif "val" in zip_name_lower:
            dst_dir = VAL_DATA_PATH
        else:
            dst_dir = LOCAL_DATA_PATH

    os.makedirs(dst_dir, exist_ok=True)
    try:
        Archive(zip_file).extractall(dst_dir)
    except Exception as e:
        # Deliberately broad: one broken archive must not abort the whole run.
        print(f"[Warning] 압축 해제 실패: {zip_name} / 에러: {e}")

# Extract every *.zip found directly inside ZIP_PATH; a missing folder or an
# empty folder only produces a warning.
if not os.path.exists(ZIP_PATH):
    print(f"[Warning] ZIP 경로 없음: {ZIP_PATH} (데이터가 이미 풀려있다면 무시하세요)")
else:
    zip_files = [entry for entry in os.listdir(ZIP_PATH) if entry.endswith(".zip")]
    if zip_files:
        print(f"총 {len(zip_files)}개의 ZIP 파일 압축 해제 시작...")
        for zip_name in tqdm(zip_files, desc="압축 해제 중"):
            extract_zip_safe_py(os.path.join(ZIP_PATH, zip_name))
        print("모든 데이터 압축 해제 완료")
    else:
        print("[Warning] ZIP 파일을 찾을 수 없습니다.")


# -----------------------------------------------------------------------------
# 4. 오디오 처리 헬퍼 함수 및 Dataset 정의
# -----------------------------------------------------------------------------
def load_audio(path, sample_rate=16000):
    """Load an audio file as a mono waveform at ``sample_rate``.

    Args:
        path: Audio file to read with ``torchaudio.load``.
        sample_rate: Target sampling rate in Hz. It is also the length (in
            samples) of the silent fallback.

    Returns:
        The waveform resampled to ``sample_rate`` and, if it had several
        channels, averaged down to one (channel dimension kept). When the
        file cannot be read, or it holds no samples, a ``(1, sample_rate)``
        tensor of zeros is returned instead so callers always receive a
        usable waveform.
    """
    try:
        waveform, sr = torchaudio.load(path)
    except Exception as e:
        # Best-effort: keep the data pipeline alive, but make the failure
        # visible -- otherwise the model silently trains on stretches of zeros.
        print(f"[Warning] 오디오 로드 실패: {path} / 에러: {e}")
        return torch.zeros((1, sample_rate))

    if sr != sample_rate:
        resampler = T.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)

    # Down-mix multi-channel audio to mono.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # An empty file would break later cropping/convolution; substitute silence.
    if waveform.shape[1] == 0:
        return torch.zeros((1, sample_rate))
    return waveform

def pad_or_trim(waveform, target_len):
    """Force ``waveform`` to exactly ``target_len`` samples along dim 1.

    A longer input is cropped to a window whose start offset is drawn with
    ``random.randint``; a shorter input is zero-padded on the right; an input
    that already has the right length is returned as is.
    """
    excess = waveform.shape[1] - target_len
    if excess == 0:
        return waveform
    if excess < 0:
        # Too short: append -excess zeros at the end of the last dimension.
        return torch.nn.functional.pad(waveform, (0, -excess))
    # Too long: keep a randomly positioned window of target_len samples.
    offset = random.randint(0, excess)
    return waveform[:, offset : offset + target_len]

def mix_audio(signal, noise, snr_db):
    """Add ``noise`` to ``signal`` after scaling the noise to a target SNR.

    Args:
        signal: Reference waveform tensor.
        noise: Noise waveform tensor, addable to ``signal``.
        snr_db: Desired ratio, in dB, of the signal energy to the energy of
            the scaled noise.

    Returns:
        ``signal + scale * noise``. If either input has zero energy no
        meaningful scale exists and the plain sum ``signal + noise`` is
        returned.
    """
    # Energies (sum of squares). Using the L2 norm here -- the square root of
    # the energy -- inside a power-ratio formula makes the achieved SNR differ
    # from snr_db by the factor ||signal|| / ||noise||.
    signal_power = signal.pow(2).sum()
    noise_power = noise.pow(2).sum()
    if signal_power == 0 or noise_power == 0:
        return signal + noise
    snr = 10 ** (snr_db / 10)
    # Amplitude gain that brings the noise energy to signal_power / snr.
    scale = (signal_power / (noise_power * snr)) ** 0.5
    return signal + (noise * scale)

class SoribomDataset(Dataset):
    """On-the-fly mixture generator for announcement-vs-noise separation.

    Audio is collected from four sub-folders of ``data_root_path``:
    ``안내방송`` (target announcements), ``대화소음`` (dialogue noise),
    ``환경소음`` (environmental noise) and ``공간음향`` (used as room impulse
    responses). Every item convolves one announcement and two noises with one
    impulse response, mixes them at random SNRs and peak-normalises the
    result.

    Args:
        data_root_path: Folder that contains the four sub-folders above.
        sample_rate: Sampling rate every file is loaded at.
        segment_duration_sec: Length of each example in seconds.
        mode: ``'train'`` draws files at random and ignores the index; any
            other value selects files by index and produces the same example
            for the same index on every call.
        steps_per_epoch: In train mode, batches per epoch; together with
            ``batch_size`` it defines ``len(dataset)``.
        batch_size: Only used to compute the train-mode length.
    """

    def __init__(self, data_root_path, sample_rate=16000, segment_duration_sec=10,
                 mode='train', steps_per_epoch=None, batch_size=1):
        self.data_root = Path(data_root_path)
        self.sample_rate = sample_rate
        # int() keeps the length usable as a tensor size for fractional durations.
        self.segment_len = int(sample_rate * segment_duration_sec)
        self.mode = mode
        self.steps_per_epoch = steps_per_epoch
        self.batch_size = batch_size

        print(f"[{mode.upper()}] 데이터셋 스캔 중: {self.data_root}")
        self.announcements = self._find_files(self.data_root / '안내방송')
        self.dialogues = self._find_files(self.data_root / '대화소음')
        self.environments = self._find_files(self.data_root / '환경소음')
        self.rirs = self._find_files(self.data_root / '공간음향')

        print(f"  - 안내방송: {len(self.announcements)}개")
        if not self.announcements:
            print("[Warning] 안내방송 파일이 없습니다.")

    def _find_files(self, path):
        """Return all .wav/.mp3/.flac files below ``path``, sorted.

        Sorting makes the index -> file mapping independent of the order in
        which the file system happens to list entries.
        """
        found = []
        for pattern in ('*.wav', '*.mp3', '*.flac'):
            found.extend(path.rglob(pattern))
        return sorted(found)

    def __len__(self):
        if self.mode == 'train' and self.steps_per_epoch:
            return self.steps_per_epoch * self.batch_size
        return len(self.announcements)

    def _load_from_pool(self, pool, idx, fallback):
        """Load one file from ``pool``; return ``fallback`` if the pool is empty.

        Train mode picks a random file, every other mode picks
        ``pool[idx % len(pool)]``.
        """
        if not pool:
            return fallback
        if self.mode == 'train':
            chosen = random.choice(pool)
        else:
            chosen = pool[idx % len(pool)]
        return load_audio(chosen, self.sample_rate)

    def __getitem__(self, idx):
        """Return ``(mixture, targets)`` for one example.

        ``targets`` stacks the clean (reverberant) announcement and the noise
        component along dim 0, and ``mixture`` equals their sum. Both are
        divided by the same peak value of the mixture.
        """
        if self.mode == 'train':
            return self._build_example(idx)

        # Validation / evaluation: pin Python's RNG to the index so the random
        # crops and SNRs are identical on every epoch, then restore the
        # caller's RNG state so nothing else is affected.
        saved_state = random.getstate()
        random.seed(int(idx))
        try:
            return self._build_example(idx)
        finally:
            random.setstate(saved_state)

    def _build_example(self, idx):
        """Load, reverberate, mix and normalise the sources of one example."""
        # 1. Load the source audio. An empty announcement pool (train) or an
        #    out-of-range index (other modes) raises IndexError.
        if self.mode == 'train':
            speech_path = random.choice(self.announcements)
        else:
            speech_path = self.announcements[idx]
        target_speech = load_audio(speech_path, self.sample_rate)
        dialogue_noise = self._load_from_pool(
            self.dialogues, idx, torch.zeros((1, self.segment_len)))
        env_noise = self._load_from_pool(
            self.environments, idx, torch.zeros((1, self.segment_len)))
        rir = self._load_from_pool(self.rirs, idx, torch.tensor([[1.0]]))

        # 2. Pre-processing: equal lengths, then reverberation.
        target_speech = pad_or_trim(target_speech, self.segment_len)
        dialogue_noise = pad_or_trim(dialogue_noise, self.segment_len)
        env_noise = pad_or_trim(env_noise, self.segment_len)

        rir_norm = torch.norm(rir, p=2)
        if rir_norm > 0:
            rir = rir / rir_norm
        else:
            # load_audio hands back all zeros for an unreadable file; dividing
            # by that norm would turn the whole example into NaN. Use a unit
            # impulse (no reverberation) instead.
            rir = torch.tensor([[1.0]])
        if rir.shape[1] > self.sample_rate:
            rir = rir[:, :self.sample_rate]

        # NOTE(review): mode='same' presumably keeps the centre part of the
        # full convolution, which shifts the output for a causal RIR. All
        # three signals share the same RIR, so mixture and targets stay
        # aligned -- confirm this is the intended reverb behaviour.
        target_reverb = torchaudio.functional.fftconvolve(target_speech, rir, mode='same')
        dialogue_reverb = torchaudio.functional.fftconvolve(dialogue_noise, rir, mode='same')
        env_reverb = torchaudio.functional.fftconvolve(env_noise, rir, mode='same')

        # 3. Mixing.
        target_clean = target_reverb
        target_noise = mix_audio(env_reverb, dialogue_reverb, random.uniform(-3, 3))
        mixture = mix_audio(target_clean, target_noise, random.uniform(-5, 10))
        # Use the noise exactly as it is present in the mixture (i.e. after
        # the SNR gain applied by mix_audio) so that the targets add up to it.
        target_noise = mixture - target_clean

        # 4. Peak normalisation, applied to mixture and targets alike.
        max_val = torch.max(torch.abs(mixture))
        if max_val > 0:
            mixture = mixture / max_val
            target_clean = target_clean / max_val
            target_noise = target_noise / max_val

        # 5. Return (mixture, [clean, noise]).
        targets = torch.stack([target_clean.squeeze(0), target_noise.squeeze(0)], dim=0)
        return mixture.squeeze(0), targets

# -----------------------------------------------------------------------------
# 5. Loss function (V3 - STOI + SI-SNR)
# -----------------------------------------------------------------------------
# One NegSTOILoss instance, created for SAMPLE_RATE and moved to `device`; it is
# shared by the pairwise training loss and by the final evaluation code.
stoi_loss_module = NegSTOILoss(sample_rate=SAMPLE_RATE).to(device)

def si_snr_score(estimate, target, epsilon=EPSILON):
    """Scale-invariant SNR, in dB, of ``estimate`` with respect to ``target``.

    The last dimension is treated as the sample axis; all leading dimensions
    are kept.

    Args:
        estimate: Estimated signal(s), samples on the last dimension.
        target: Reference signal(s) of the same shape.
        epsilon: Small constant guarding the divisions and the logarithm.

    Returns:
        Tensor with the shape of the inputs minus the last dimension, e.g.
        ``(B,)`` for ``(B, T)`` inputs -- including ``B == 1``.
    """
    dot = torch.sum(estimate * target, dim=-1, keepdim=True)
    target_norm_sq = torch.sum(target**2, dim=-1, keepdim=True)
    # Projection of the estimate onto the target ...
    target_scaled = (dot / (target_norm_sq + epsilon)) * target
    # ... and whatever is left over counts as noise.
    noise = estimate - target_scaled
    snr_sq = torch.sum(target_scaled**2, dim=-1) / (torch.sum(noise**2, dim=-1) + epsilon)
    # The sample axis is already reduced above. A trailing squeeze(-1) would
    # additionally drop a size-1 batch (or single-target) axis and hand
    # callers a 0-dim tensor, so it is intentionally not applied.
    return 10 * torch.log10(snr_sq + epsilon)

def pairwise_combined_loss_v3(estimates, targets, epsilon=EPSILON):
    """Pairwise loss matrix between every estimate and every target.

    Entry ``[b, i, j]`` scores estimate ``i`` against target ``j`` of batch
    item ``b``. For a target whose energy is below ``epsilon`` the entry is
    the mean absolute error; otherwise it is
    ``LOSS_LAMBDA_V3 * (-SI-SNR) + (1 - LOSS_LAMBDA_V3) * stoi_loss_module(...)``.

    Args:
        estimates: Tensor unpacked as ``(B, C_est, T)``.
        targets: Tensor unpacked as ``(B, C_tgt, T)``.
        epsilon: Silence threshold, also forwarded to ``si_snr_score``.

    Returns:
        Tensor of shape ``(B, C_est, C_tgt)``.
    """
    batch, n_est, n_samples = estimates.shape
    _, n_tgt, _ = targets.shape
    pair_shape = (batch, n_est, n_tgt)

    # Broadcast both inputs to (B, C_est, C_tgt, T) so each pair lines up.
    est_pairs = estimates.unsqueeze(2).expand(*pair_shape, n_samples)
    tgt_pairs = targets.unsqueeze(1).expand(*pair_shape, n_samples)

    # Speech term: weighted negative SI-SNR plus the STOI loss.
    neg_sisnr = -si_snr_score(est_pairs, tgt_pairs, epsilon)
    stoi_term = stoi_loss_module(
        est_pairs.reshape(-1, n_samples),
        tgt_pairs.reshape(-1, n_samples),
    ).reshape(pair_shape)
    speech_term = (LOSS_LAMBDA_V3 * neg_sisnr) + ((1 - LOSS_LAMBDA_V3) * stoi_term)

    # Silence term: plain L1, used wherever the target carries no energy.
    silence_term = torch.mean(torch.abs(est_pairs - tgt_pairs), dim=-1)
    silent_target = (torch.sum(targets**2, dim=-1) < epsilon).unsqueeze(1).expand(pair_shape)

    return torch.where(silent_target, silence_term, speech_term)

# -----------------------------------------------------------------------------
# 6. 모델 및 훈련 함수
# -----------------------------------------------------------------------------
def build_model_and_optimizer():
    """Create the pretrained separator, its optimizer and the PIT loss.

    Returns:
        Tuple ``(model, optimizer, loss_func)``: the
        ``JorisCos/ConvTasNet_Libri2Mix_sepnoisy_16k`` ConvTasNet moved to
        ``device``, an Adam optimizer over its parameters (lr 1e-4), and a
        ``PITLossWrapper`` around ``pairwise_combined_loss_v3``.
    """
    print("Pre-trained ConvTasNet 로드 중...")
    pretrained = ConvTasNet.from_pretrained("JorisCos/ConvTasNet_Libri2Mix_sepnoisy_16k")
    separator = pretrained.to(device)
    pit_loss = PITLossWrapper(pairwise_combined_loss_v3, pit_from='pw_mtx')
    adam = torch.optim.Adam(separator.parameters(), lr=1e-4)
    return separator, adam, pit_loss

def train_epoch(model, loader, optimizer, loss_func):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network called as ``model(mixture)``.
        loader: Iterable of ``(mixture, targets)`` batches.
        optimizer: Optimizer stepping the parameters of ``model``.
        loss_func: Callable ``loss_func(estimates, targets)`` returning a
            scalar loss tensor.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    total_loss = 0.0
    for mixture, targets in tqdm(loader, desc="[훈련]"):
        mixture, targets = mixture.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = loss_func(model(mixture), targets)
        loss.backward()
        # Clipping has to happen after backward(): before it the gradients
        # have just been cleared, so there is nothing to clip.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def validate_epoch(model, loader, loss_func):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.eval()
    running = 0.0
    with torch.no_grad():
        for noisy, refs in tqdm(loader, desc="[검증]"):
            noisy = noisy.to(device)
            refs = refs.to(device)
            batch_loss = loss_func(model(noisy), refs)
            running += batch_loss.item()
    return running / len(loader)

# -----------------------------------------------------------------------------
# 7. Main execution (build data loaders -> train -> save)
# -----------------------------------------------------------------------------
print("\n=== 7. 메인 프로세스 시작 ===")

# A. Datasets and loaders
# In train mode the dataset length is TRAIN_STEPS_PER_EPOCH * BATCH_SIZE, so
# one pass over train_loader yields TRAIN_STEPS_PER_EPOCH batches.
train_dataset = SoribomDataset(TRAIN_DATA_PATH, sample_rate=SAMPLE_RATE,
                               segment_duration_sec=SEGMENT_DURATION_SEC,
                               mode='train', steps_per_epoch=TRAIN_STEPS_PER_EPOCH, batch_size=BATCH_SIZE)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)

# mode='val': the dataset length is the number of announcement files found.
val_dataset_full = SoribomDataset(VAL_DATA_PATH, sample_rate=SAMPLE_RATE,
                                  segment_duration_sec=SEGMENT_DURATION_SEC,
                                  mode='val', batch_size=BATCH_SIZE)

# Fix the validation samples: at most MAX_VAL_SAMPLES indices, chosen with a
# fixed seed. Note that random.seed(42) reseeds the global `random` module.
if MAX_VAL_SAMPLES and len(val_dataset_full) > MAX_VAL_SAMPLES:
    print(f"검증 샘플을 {MAX_VAL_SAMPLES}개로 제한합니다.")
    random.seed(42)
    val_indices = random.sample(range(len(val_dataset_full)), MAX_VAL_SAMPLES)
    val_dataset = Subset(val_dataset_full, val_indices)
else:
    val_dataset = val_dataset_full

val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

# B. Model, optimizer and loss
model, optimizer, loss_func = build_model_and_optimizer()

# C. Resume from the latest checkpoint, if one exists
start_epoch = 1
best_val_loss = float('inf')

if os.path.exists(CHECKPOINT_PATH):
    print(f"체크포인트 로드: {CHECKPOINT_PATH}")
    ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    # The checkpoint stores the last finished epoch; continue with the next one.
    start_epoch = ckpt['epoch'] + 1
    best_val_loss = ckpt.get('best_val_loss', float('inf'))
    print(f"-> Epoch {start_epoch}부터 시작 (Best Loss: {best_val_loss:.4f})")
else:
    print("새로운 훈련 시작")

# D. Training loop (runs zero times when start_epoch is already past NUM_EPOCHS)
for epoch in range(start_epoch, NUM_EPOCHS + 1):
    print(f"\n--- Epoch {epoch} / {NUM_EPOCHS} ---")
    train_loss = train_epoch(model, train_loader, optimizer, loss_func)
    val_loss = validate_epoch(model, val_loader, loss_func)

    print(f"Epoch {epoch} 결과: Train={train_loss:.4f}, Val={val_loss:.4f}")

    # Save the best model (weights only) whenever the validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print(f"*** 최고 성능 갱신! 모델 저장: {BEST_MODEL_PATH}")

    # Save the resumable checkpoint after every epoch (overwrites the previous one)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_loss': best_val_loss,
    }, CHECKPOINT_PATH)
    print(f"체크포인트 저장됨.")

print("\n모든 훈련이 완료되었습니다.")

# -----------------------------------------------------------------------------
# 8. Final evaluation (SI-SNR, SI-SNRi, STOI)
# -----------------------------------------------------------------------------
print("\n=== 8. 최종 성능 평가 시작 (Best Model 사용) ===")

# Load the best model; fall back to the current in-memory weights if no file exists
if os.path.exists(BEST_MODEL_PATH):
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
    print(f"평가를 위해 최고 성능 모델 로드 완료: {BEST_MODEL_PATH}")
else:
    print("[Warning] 최고 모델 파일이 없습니다. 현재 모델 상태로 평가를 진행합니다.")

model.eval()
# Per-sample metric values collected over the whole validation loader
all_sisnr, all_sisnri, all_stoi = [], [], []

with torch.no_grad():
    for mixture, targets in tqdm(val_loader, desc="[최종 평가]"):
        mixture, targets = mixture.to(device), targets.to(device)
        # Channel 0 of the targets is the clean announcement
        target_clean = targets[:, 0, :]

        estimated_sources = model(mixture)
        est_1, est_2 = estimated_sources[:, 0, :], estimated_sources[:, 1, :]

        # Resolve the output permutation: the estimate with the higher SI-SNR
        # against the clean target is taken to be the clean source
        sisnr_1 = si_snr_score(est_1, target_clean)
        sisnr_2 = si_snr_score(est_2, target_clean)

        # Per-sample selector, broadcast over the time axis
        mask = (sisnr_1 > sisnr_2).unsqueeze(1)
        best_estimate = torch.where(mask.expand_as(est_1), est_1, est_2)

        # Metrics
        output_sisnr = torch.max(sisnr_1, sisnr_2)
        # SI-SNR of the unprocessed mixture, used as the baseline for SI-SNRi
        initial_sisnr = si_snr_score(mixture, target_clean)
        sisnri = output_sisnr - initial_sisnr
        # stoi_loss_module is a NegSTOILoss instance; negate its output to report STOI
        stoi_val = -stoi_loss_module(best_estimate, target_clean)

        all_sisnr.extend(output_sisnr.cpu().numpy())
        all_sisnri.extend(sisnri.cpu().numpy())
        all_stoi.extend(stoi_val.cpu().numpy())

print("\n================ 최종 성능 평가 결과 ================")
print(f" 1. SI-SNR (최종 음질)  : {np.mean(all_sisnr):.4f} (dB)")
print(f" 2. SI-SNRi (음질 향상도): {np.mean(all_sisnri):.4f} (dB)")
print(f" 3. STOI (명료도)       : {np.mean(all_stoi):.4f}")
print("===================================================")