In [1]:
import glob
import os
import random
import time
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from speechbrain.pretrained import EncoderClassifier
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


  if ismodule(module) and hasattr(module, '__file__'):
  from speechbrain.pretrained import EncoderClassifier


In [2]:
# config
####################################
# Paths and shared settings
####################################

# Assumes the kernel's working directory is the notebook's folder
# (os.path.abspath(".") resolves against the current working directory).
PROJECT_ROOT = os.path.abspath(".")  # change directly if needed

LIBRISPEECH_ROOT = os.path.join(PROJECT_ROOT, "data", "train", "dev-clean")
# e.g. ./data/train/dev-clean/84/..., 174/..., SPEAKERS.TXT ...

print("LIBRISPEECH_ROOT =", LIBRISPEECH_ROOT)
# Explicit raise instead of `assert`, so the check still runs under `python -O`.
if not os.path.isdir(LIBRISPEECH_ROOT):
    raise FileNotFoundError(f"dev-clean 경로를 확인하세요: {LIBRISPEECH_ROOT}")

# Reproducibility: Python's `random` drives the purifier's random choices;
# torch's RNG drives weight init, DataLoader shuffling and torch.randint.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

# Audio parameters
SAMPLE_RATE = 16000  # LibriSpeech is 16 kHz

# DataLoader parameters
BATCH_SIZE = 4
NUM_WORKERS = 0  # about 0-2 on Windows; can be raised on Linux
PIN_MEMORY = torch.cuda.is_available()

# Restrict to specific speakers (None = every speaker)
# e.g. speaker_list = ["84", "174"]
speaker_list = None


LIBRISPEECH_ROOT = c:\project\vcshield\data\train\dev-clean


In [3]:
#dataset
class LibriSpeechSpeakerDataset(Dataset):
    """
    Load LibriSpeech utterances (e.g. dev-clean) grouped by speaker.

    Expected layout: ``root_dir/<speaker_id>/<chapter_id>/*.flac`` (``*.wav`` is
    accepted as well). Only directories whose name is all digits count as speakers.

    Each item is a dict:
        {
          "waveform": Tensor (1, T), mono, resampled to ``target_sr``,
          "speaker_id": str,
          "utt_path": str
        }
    """

    def __init__(
        self,
        root_dir: str,
        target_sr: int = 16000,
        speaker_list: Optional[List[str]] = None,
        min_duration_sec: float = 0.5,
        max_duration_sec: Optional[float] = None,
    ):
        """
        Args:
            root_dir: split directory, e.g. .../dev-clean.
            target_sr: sample rate every waveform is resampled to (usually 16k).
            speaker_list: speaker IDs (str or int) to keep. None keeps every speaker
                found under root_dir; IDs that are not on disk are ignored.
            min_duration_sec: clips shorter than this are right-padded with zeros up
                to this length (they are NOT dropped).
            max_duration_sec: if set, clips longer than this are center-cropped to
                this length. None (default) keeps the full clip.

        Raises:
            ValueError: if max_duration_sec is smaller than min_duration_sec.
        """
        if max_duration_sec is not None and max_duration_sec < min_duration_sec:
            raise ValueError(
                f"max_duration_sec ({max_duration_sec}) must be >= "
                f"min_duration_sec ({min_duration_sec})"
            )

        self.root_dir = root_dir
        self.target_sr = target_sr
        self.min_duration_sec = min_duration_sec
        self.max_duration_sec = max_duration_sec

        # 1) Speaker folders: only all-digit directory names are speakers.
        all_speakers = [
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d)) and d.isdigit()
        ]

        if speaker_list is None:
            self.speakers = sorted(all_speakers)
        else:
            # Intersection with what exists on disk; str() lets callers pass ints.
            wanted = {str(s) for s in speaker_list}
            self.speakers = sorted(s for s in all_speakers if s in wanted)

        # 2) Collect audio paths (spk_id/chapter_id/*.flac|*.wav).
        #    os.listdir order is filesystem-dependent, so chapters are sorted to make
        #    item indices (and therefore seeded shuffles) reproducible across machines.
        self.items = []
        for spk in self.speakers:
            spk_dir = os.path.join(root_dir, spk)
            for ch_name in sorted(os.listdir(spk_dir)):
                ch_dir = os.path.join(spk_dir, ch_name)
                if not os.path.isdir(ch_dir):
                    continue

                audio_paths = sorted(
                    glob.glob(os.path.join(ch_dir, "*.wav"))
                    + glob.glob(os.path.join(ch_dir, "*.flac"))
                )
                # Only metadata is stored; audio is decoded lazily in __getitem__.
                self.items.extend(
                    {"speaker_id": spk, "utt_path": ap} for ap in audio_paths
                )

        print(f"[LibriSpeechSpeakerDataset] speakers: {len(self.speakers)}, utterances(raw): {len(self.items)}")

    def __len__(self):
        return len(self.items)

    def _load_audio(self, path: str) -> torch.Tensor:
        """
        Load ``path``, downmix to mono and resample to ``target_sr``.

        Returns:
            (1, T) tensor as produced by torchaudio.load (float, likely in [-1, 1]).
        """
        wav, sr = torchaudio.load(path)  # (C, T)
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)  # average channels -> mono
        if sr != self.target_sr:
            wav = torchaudio.functional.resample(wav, sr, self.target_sr)
        return wav

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        meta = self.items[idx]
        spk = meta["speaker_id"]
        path = meta["utt_path"]

        wav = self._load_audio(path)  # (1, T)

        # Very short clips can destabilize downstream training: zero-pad on the right.
        min_len = int(self.min_duration_sec * self.target_sr)
        if wav.shape[1] < min_len:
            wav = F.pad(wav, (0, min_len - wav.shape[1]))

        # Optional center crop of long clips (disabled when max_duration_sec is None).
        if self.max_duration_sec is not None:
            max_len = int(self.max_duration_sec * self.target_sr)
            if wav.shape[1] > max_len:
                start = (wav.shape[1] - max_len) // 2
                wav = wav[:, start:start + max_len]

        return {
            "waveform": wav,       # (1, T_resampled)
            "speaker_id": spk,     # string
            "utt_path": path       # string
        }
    
# Build the utterance-level dataset over the configured split.
dataset = LibriSpeechSpeakerDataset(
    LIBRISPEECH_ROOT,
    target_sr=SAMPLE_RATE,
    # None -> every speaker folder under LIBRISPEECH_ROOT
    speaker_list=speaker_list,
    # clips shorter than 0.5 s are zero-padded up to 0.5 s
    min_duration_sec=0.5,
)


[LibriSpeechSpeakerDataset] speakers: 40, utterances(raw): 2703


In [4]:
from speechbrain.pretrained import EncoderClassifier
from speechbrain.utils.fetching import LocalStrategy

class ECAPASpeakerEncoder(nn.Module):
    """
    Frozen SpeechBrain ECAPA-TDNN speaker encoder.

    - ECAPA weights are frozen (requires_grad=False); the model is never fine-tuned.
    - forward() does NOT use torch.no_grad(), so gradients can still flow back to
      the input waveform (and from there through the vocoder into the generator),
      which the resist loss relies on.
    """
    def __init__(self, device="cuda"):
        super().__init__()

        self.classifier = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            run_opts={"device": device},
            savedir="./pretrained_ecapa",
            # COPY the fetched files instead of symlinking them (Windows-safe).
            local_strategy=LocalStrategy.COPY,
        )

        # Freeze ECAPA so it is not fine-tuned.
        for p in self.classifier.parameters():
            p.requires_grad = False

        self.device = device

    def forward(self, wave_batch, wav_lens=None):
        """
        Args:
            wave_batch: (B, 1, T) or (B, T) waveform; the pretrained model presumably
                expects 16 kHz audio — TODO confirm against the model card.
            wav_lens: optional (B,) tensor of lengths relative to the longest clip
                (1.0 = full length), forwarded to ``encode_batch`` so zero-padding
                can be excluded. None (default) treats every clip as full length.

        Returns:
            (B, emb_dim) speaker embeddings.
        """
        if wave_batch.dim() == 3 and wave_batch.shape[1] == 1:
            wave_in = wave_batch.squeeze(1)  # (B, T)
        else:
            wave_in = wave_batch             # assume already (B, T)

        # encode_batch returns (B, 1, emb_dim)
        emb = self.classifier.encode_batch(wave_in, wav_lens=wav_lens)
        return emb.squeeze(1)                # (B, emb_dim)


In [5]:
import torch.nn.functional as F

def match_spatial(src, ref):
    """
    Center-crop two 4-D (B, C, H, W) tensors to their common spatial size.

    The target size is (min(H_src, H_ref), min(W_src, W_ref)); each axis is
    handled independently, and each tensor is cropped around its own center.
    Channel counts may differ; batch sizes must match.

    Returns:
        (src_cropped, ref_cropped)

    NOTE(review): an identical match_spatial is defined again in the training
    cell; that later definition replaces this one at runtime.
    """
    n_src, _, h_src, w_src = src.shape
    n_ref, _, h_ref, w_ref = ref.shape
    assert n_src == n_ref, "batch mismatch in match_spatial"

    target_h = min(h_src, h_ref)
    target_w = min(w_src, w_ref)

    def _center(t):
        top = (t.shape[2] - target_h) // 2
        left = (t.shape[3] - target_w) // 2
        return t[:, :, top:top + target_h, left:left + target_w]

    return _center(src), _center(ref)

def collate_with_padding(batch_list):
    """
    Collate Dataset items into one batch, right-padding waveforms with zeros.

    Args:
        batch_list: list of B dicts from ``__getitem__``, each with
            "waveform" (1, T_i), "speaker_id" and "utt_path".

    Returns:
        dict with
          "waveform":   (B, 1, max_i T_i) tensor, zero-padded on the right
          "speaker_id": list of B speaker IDs (unchanged)
          "utt_path":   list of B paths (unchanged)
          "lengths":    (B,) int64 tensor of the original sample counts T_i, so
                        downstream code can tell real audio from padding
                        (added key; the other keys are unchanged)
    """
    wave_list = [item["waveform"] for item in batch_list]  # [(1, T_i), ...]
    spk_list = [item["speaker_id"] for item in batch_list]
    path_list = [item["utt_path"] for item in batch_list]

    lengths = torch.tensor([w.shape[-1] for w in wave_list], dtype=torch.long)
    max_len = int(lengths.max())

    # Right zero-pad every clip to the longest one, then stack -> (B, 1, max_len).
    wave_tensor = torch.stack(
        [F.pad(w, (0, max_len - w.shape[-1]), mode="constant", value=0.0) for w in wave_list],
        dim=0,
    )

    return {
        "waveform": wave_tensor,
        "speaker_id": spk_list,
        "utt_path": path_list,
        "lengths": lengths,
    }

# collate_fn을 적용한 DataLoader
# DataLoader that pads each batch to its longest clip via collate_with_padding.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    drop_last=False,
    collate_fn=collate_with_padding,
)

# Sanity check: pull one batch and inspect its shape and metadata.
batch = next(iter(loader))
print("=== Sanity Check w/ collate_fn ===")
print(f"waveform.shape: {batch['waveform'].shape}")  # (B, 1, max_len_in_batch)
print(f"speaker_id example: {batch['speaker_id']}")
print(f"utt_path[0]: {batch['utt_path'][0]}")


=== Sanity Check w/ collate_fn ===
waveform.shape: torch.Size([4, 1, 208480])
speaker_id example: ['2277', '5338', '251', '1673']
utt_path[0]: c:\project\vcshield\data\train\dev-clean\2277\149897\2277-149897-0004.flac


In [6]:
import torch
import torchaudio

def build_mel_extractor(
    sample_rate=16000,
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    n_mels=80,
    f_min=0.0,
    f_max=8000.0,
    power=1.0,
    log_offset=1e-6
):
    """
    Build a closure that converts waveforms into log-mel spectrograms.

    Despite its name, the returned ``wav_to_mel_db`` computes the NATURAL log
    ``log(mel + log_offset)`` of a ``power``-exponent mel spectrogram, not decibels.

    Args:
        sample_rate, n_fft, hop_length, win_length, n_mels, f_min, f_max, power:
            forwarded to ``torchaudio.transforms.MelSpectrogram``.
        log_offset: added before the log to avoid log(0).

    Returns:
        wav_to_mel_db(waveform):
            (B, 1, T) -> (B, n_mels, frames)
            (1, T)    -> (n_mels, frames)
        Raises ValueError for any other shape. The transform is moved to the
        waveform's device on each call, so CPU and GPU inputs both work.

    NOTE(review): the MelSpectrogram here uses torchaudio's defaults for mel
    scale / filter normalisation / centering, and a log offset of ``log_offset``.
    These may not match the features a pretrained HiFi-GAN vocoder was trained
    on — TODO confirm before feeding these mels to it.
    """

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        power=power,
    )

    def wav_to_mel_db(waveform: torch.Tensor) -> torch.Tensor:
        """
        waveform: (B,1,T) or (1,T)
        returns : (B, n_mels, frames) or (n_mels, frames) respectively
        """
        # (1,T) -> (1,1,T); remember to drop the batch dim again at the end.
        single_input = waveform.dim() == 2
        if single_input:
            waveform = waveform.unsqueeze(0)

        if waveform.dim() != 3 or waveform.shape[1] != 1:
            raise ValueError(f"Expected (B,1,T) or (1,T); got {tuple(waveform.shape)}")

        # Keep the transform's window/filterbank buffers on the input's device
        # (no-op when already there).
        mel_transform.to(waveform.device)

        # MelSpectrogram accepts leading batch dims: (B,1,T) -> (B,1,n_mels,frames).
        # One batched call replaces the former per-sample Python loop.
        mel = mel_transform(waveform).squeeze(1)   # (B, n_mels, frames)
        mel_db = torch.log(mel + log_offset)

        return mel_db[0] if single_input else mel_db

    return wav_to_mel_db


# 빌드
# Log-mel extractor shared by G's input and the content loss.
wav_to_mel_db = build_mel_extractor(
    sample_rate=SAMPLE_RATE,
    # STFT framing
    n_fft=1024,
    win_length=1024,
    hop_length=256,
    # mel filterbank
    n_mels=80,
    f_min=0.0,
    f_max=8000.0,
    # magnitude spectrogram (power=1), then log(mel + 1e-6)
    power=1.0,
    log_offset=1e-6,
)

# Quick check on one batch: expect mel of shape (B, 80, frames).
test_batch = next(iter(loader))
test_wave = test_batch["waveform"]      # (B,1,T)
test_mel = wav_to_mel_db(test_wave)

print(f"waveform: {test_wave.shape}")
print(f"mel_db: {test_mel.shape}")
print(f"speaker_id: {test_batch['speaker_id']}")


waveform: torch.Size([4, 1, 113920])
mel_db: torch.Size([4, 80, 446])
speaker_id: ['6345', '422', '2277', '652']


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

class StochasticPurifier(nn.Module):
    """
    Random, differentiable "purification" of a (B, F, T) mel used during
    training to simulate smoothing-style defenses.

    On each forward pass these ops are independently enabled and applied in this
    fixed order:
      - time blur:    box filter along T,          with probability p_time_blur
      - time down/up: shrink T then interpolate back, with probability p_time_down_up
      - freq blur:    box filter along F,          with probability p_freq_blur
    Kernel sizes and factors are drawn with Python's ``random`` module, so seed
    ``random`` for reproducible runs. The module has no parameters.
    """
    def __init__(self,
                 time_blur_kernel_sizes=(3,5,7),
                 freq_blur_kernel_sizes=(3,5),
                 downsample_factors=(2,3),
                 p_time_blur=0.7,
                 p_time_down_up=0.5,
                 p_freq_blur=0.5):
        super().__init__()
        self.time_blur_kernel_sizes = time_blur_kernel_sizes
        self.freq_blur_kernel_sizes = freq_blur_kernel_sizes
        self.downsample_factors = downsample_factors
        self.p_time_blur = p_time_blur
        self.p_time_down_up = p_time_down_up
        self.p_freq_blur = p_freq_blur

    @staticmethod
    def _pad_last(x, pad):
        """Pad the last dim by `pad` on both sides: reflect when valid, else replicate."""
        # F.pad's reflect mode needs pad < size of that dim; very short inputs fall
        # back to edge replication instead of raising.
        mode = 'reflect' if pad < x.shape[-1] else 'replicate'
        return F.pad(x, (pad, pad), mode=mode)

    def _time_blur(self, mel):
        """Depthwise box blur along time: each mel bin is smoothed independently."""
        B, Fm, Tm = mel.shape
        k = random.choice(self.time_blur_kernel_sizes)
        kernel = torch.ones(Fm, 1, k, device=mel.device, dtype=mel.dtype) / k
        mel_pad = self._pad_last(mel, k // 2)
        return F.conv1d(mel_pad, kernel, stride=1, padding=0, groups=Fm)

    def _time_down_up(self, mel):
        """Down-sample the time axis by a random factor, then interpolate back to Tm."""
        B, Fm, Tm = mel.shape
        factor = random.choice(self.downsample_factors)
        T_down = Tm // factor
        if T_down < 2:
            return mel
        mel_down = F.interpolate(
            mel.unsqueeze(1),  # (B,1,Fm,Tm)
            size=(Fm, T_down),
            mode='bilinear',
            align_corners=False
        ).squeeze(1)  # (B,Fm,T_down)
        mel_up = F.interpolate(
            mel_down.unsqueeze(1),
            size=(Fm, Tm),
            mode='bilinear',
            align_corners=False
        ).squeeze(1)  # (B,Fm,Tm)
        return mel_up

    def _freq_blur(self, mel):
        """Box blur along the mel-frequency axis, applied per time frame."""
        B, Fm, Tm = mel.shape
        k = random.choice(self.freq_blur_kernel_sizes)

        mel_bt = mel.transpose(1, 2).reshape(B * Tm, 1, Fm)   # one row per frame
        kernel = torch.ones(1, 1, k, device=mel.device, dtype=mel.dtype) / k
        mel_bt_blur = F.conv1d(self._pad_last(mel_bt, k // 2), kernel, stride=1, padding=0)
        return mel_bt_blur.reshape(B, Tm, Fm).transpose(1, 2)  # (B,Fm,Tm)

    def forward(self, mel):
        # Draw all three enable flags first, then run the enabled ops in order
        # (keeps the same RNG call sequence as before).
        ops = []
        if random.random() < self.p_time_blur:
            ops.append(self._time_blur)
        if random.random() < self.p_time_down_up:
            ops.append(self._time_down_up)
        if random.random() < self.p_freq_blur:
            ops.append(self._freq_blur)

        x = mel
        for op in ops:
            x = op(x)
        return x

In [8]:
# ===== 5. 손실 계산 시 mel 길이 mismatch 보정 유틸 =====
def match_mel_for_loss(mel_a, mel_b):
    """
    Center-crop two (B, F, T) mel tensors to a common (F', T') for loss terms.

    Used by the content loss to compare mel_adv against mel_clean when their
    time (or frequency) lengths differ. Batch sizes must match.

    NOTE(review): an identical match_mel_for_loss is defined again in the
    training cell; that later definition replaces this one at runtime.
    """
    # match_spatial works on 4-D (B, C, H, W) tensors: add a singleton channel
    # axis, crop, then drop it again.
    cropped_a, cropped_b = match_spatial(mel_a.unsqueeze(1), mel_b.unsqueeze(1))
    return cropped_a.squeeze(1), cropped_b.squeeze(1)

In [9]:
import sys, json, torch
import torch.nn as nn
import torchaudio

# HiFi-GAN import path
sys.path.append("./external/hifigan")
from models import Generator
from env import AttrDict

class HiFiGANVocoderWrapper(nn.Module):
    """
    Frozen HiFi-GAN vocoder (UNIVERSAL_V1) that keeps the autograd graph.

    Input : mel (B, n_mels, Tm), passed through ``preprocess_mel_for_hifigan``
    Output: waveform (B, 1, T) at ``sr_target``, clamped to [-1, 1]

    - Generator parameters are frozen (requires_grad=False).
    - forward() does NOT use torch.no_grad(), so gradients flow back to the mel.
    - NOTE(review): HiFi-GAN emits audio at ``sr_gen`` (22.05 kHz by default here)
      and presumably expects mels framed at that rate with its own filterbank/log
      scaling (see config.json — TODO confirm). If the input mels were computed at
      16 kHz with the same hop, the vocoded audio is ~sr_gen/16000 (≈1.38x)
      shorter/faster than the source after resampling back to ``sr_target``.
      Verify this before trusting speaker-similarity numbers.
    """
    def __init__(
        self,
        hifigan_config_path="./external/hifigan/UNIVERSAL_V1/config.json",
        hifigan_ckpt_path="./external/hifigan/UNIVERSAL_V1/g_02500000",
        sr_gen=22050,
        sr_target=16000,
        device="cuda"
    ):
        super().__init__()
        self.device = device
        self.sr_gen = sr_gen
        self.sr_target = sr_target

        # 1) Load the generator config.
        with open(hifigan_config_path, "r") as f:
            h = AttrDict(json.load(f))

        # 2) Build the generator and load weights (checkpoint may wrap them
        #    under a "generator" key).
        gen = Generator(h).to(device)
        state = torch.load(hifigan_ckpt_path, map_location=device)
        gen.load_state_dict(state["generator"] if "generator" in state else state)
        gen.eval()

        # 3) Fold weight-norm into plain weights, as HiFi-GAN's own inference does:
        #    cheaper forward with mathematically identical output. This must run
        #    AFTER load_state_dict (the checkpoint stores weight_g/weight_v) and
        #    BEFORE freezing, because it re-registers `weight` as a new Parameter
        #    with requires_grad=True.
        if hasattr(gen, "remove_weight_norm"):
            gen.remove_weight_norm()

        # 4) Freeze all generator parameters.
        for p in gen.parameters():
            p.requires_grad = False

        self.gen = gen

        # Placeholder normalisation stats (identity). Adjust if the mel features
        # are ever aligned with HiFi-GAN's training normalisation.
        self.register_buffer("dummy_mean", torch.tensor(0.0))
        self.register_buffer("dummy_std",  torch.tensor(1.0))

    def preprocess_mel_for_hifigan(self, mel):
        """Normalise mel (B, n_mels, T) with the placeholder stats (identity by default)."""
        mel_norm = (mel - self.dummy_mean) / (self.dummy_std + 1e-8)
        return mel_norm

    def forward(self, mel_batch):
        """
        mel_batch: (B, n_mels, Tm)
        return   : (B, 1, T) at sr_target
        """
        mel_in = self.preprocess_mel_for_hifigan(mel_batch).to(self.device)
        # No torch.no_grad(): keep the graph back to mel_batch.
        wave_gen = self.gen(mel_in)  # (B,1,T_gen) at sr_gen

        if self.sr_gen != self.sr_target:
            wave_out = torchaudio.functional.resample(
                wave_gen.squeeze(1),  # (B,T_gen)
                orig_freq=self.sr_gen,
                new_freq=self.sr_target
            ).unsqueeze(1)            # (B,1,T)
        else:
            wave_out = wave_gen

        # Note: clamp passes zero gradient for samples outside [-1, 1].
        return torch.clamp(wave_out, -1.0, 1.0)


In [10]:
class ConvBlock(nn.Module):
    """
    Conv2d -> BatchNorm2d -> LeakyReLU(0.2) building block.

    Args:
        in_ch, out_ch: input / output channel counts.
        k, s, p: kernel size, stride and padding of the convolution
            (defaults keep H and W unchanged).
    """
    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        # Attribute names (conv / bn / act) define the state_dict keys; keep them stable.
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p)
        self.bn   = nn.BatchNorm2d(out_ch)
        self.act  = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        h = self.conv(x)
        h = self.bn(h)
        return self.act(h)


class GatedSkipFusion(nn.Module):
    """
    Fuse a decoder feature map with an encoder skip map through a learned gate.

    The skip features are multiplied by a sigmoid gate computed from both inputs,
    so the skip path can be partially suppressed (the intent being to limit how
    much speaker-specific detail leaks through the skip connection).

    forward(up_feat, skip_feat):
        up_feat:   (B, C_up, H, W)
        skip_feat: (B, C_skip, H2, W2); both are center-cropped to a common size
    returns:
        (B, C_up + C_skip, H', W') = concat([up, gate * skip], dim=1)
    """
    def __init__(self, c_skip, c_up):
        super().__init__()
        # 1x1 conv over [up, skip] -> one sigmoid gate per skip channel & position.
        self.gate_gen = nn.Sequential(
            nn.Conv2d(c_skip + c_up, c_skip, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, up_feat, skip_feat):
        # Sizes can differ by a pixel or so after up-sampling; align them first.
        up_c, skip_c = match_spatial(up_feat, skip_feat)

        joint = torch.cat([up_c, skip_c], dim=1)          # (B, C_up+C_skip, H', W')
        gated_skip = skip_c * self.gate_gen(joint)        # (B, C_skip, H', W')
        return torch.cat([up_c, gated_skip], dim=1)       # (B, C_up+C_skip, H', W')


#########################################
# 개선된 U-Net 생성기
#########################################

class UNetLikeGeneratorV2(nn.Module):
    """
    U-Net generator that adds a bounded perturbation to a mel spectrogram.

    - 3 down/up stages (stride-2) to widen the receptive field
    - gated skip connections (GatedSkipFusion) to limit speaker-identity leakage
    - output perturbation bounded by epsilon * tanh(...)
    - the input is replicate-padded so both spatial dims are multiples of 8 and
      the result is cropped back, so the output has exactly the input shape and
      the skip maps line up with the decoder features

    Input : mel_batch (B, n_mels, T)
    Output: adv_mel   (B, n_mels, T)
    """
    def __init__(self, base_ch=64, n_mels=80, epsilon=0.1):
        super().__init__()
        self.n_mels  = n_mels
        self.epsilon = epsilon  # bound on the perturbation magnitude (|delta| < epsilon)

        # ========== Encoder Blocks ==========
        # enc1/down1
        self.enc1 = nn.Sequential(
            ConvBlock(1, base_ch),
            ConvBlock(base_ch, base_ch)
        )
        self.down1 = nn.Conv2d(
            base_ch, base_ch*2,
            kernel_size=4, stride=2, padding=1
        )  # /2

        # enc2/down2
        self.enc2 = nn.Sequential(
            ConvBlock(base_ch*2, base_ch*2),
            ConvBlock(base_ch*2, base_ch*2)
        )
        self.down2 = nn.Conv2d(
            base_ch*2, base_ch*4,
            kernel_size=4, stride=2, padding=1
        )  # /4

        # enc3/down3
        self.enc3 = nn.Sequential(
            ConvBlock(base_ch*4, base_ch*4),
            ConvBlock(base_ch*4, base_ch*4)
        )
        self.down3 = nn.Conv2d(
            base_ch*4, base_ch*8,
            kernel_size=4, stride=2, padding=1
        )  # /8

        # bottleneck
        self.bottleneck = nn.Sequential(
            ConvBlock(base_ch*8, base_ch*8),
            ConvBlock(base_ch*8, base_ch*8)
        )

        # ========== Decoder Blocks ==========
        # up3: bottleneck (/8) -> /4
        self.up3 = nn.ConvTranspose2d(
            base_ch*8, base_ch*4,
            kernel_size=4, stride=2, padding=1
        )
        self.skip_fuse3 = GatedSkipFusion(c_skip=base_ch*4, c_up=base_ch*4)
        # fused channels: base_ch*8
        self.dec3 = nn.Sequential(
            ConvBlock(base_ch*8, base_ch*4),
            ConvBlock(base_ch*4, base_ch*4)
        )

        # up2: /4 -> /2
        self.up2 = nn.ConvTranspose2d(
            base_ch*4, base_ch*2,
            kernel_size=4, stride=2, padding=1
        )
        self.skip_fuse2 = GatedSkipFusion(c_skip=base_ch*2, c_up=base_ch*2)
        self.dec2 = nn.Sequential(
            ConvBlock(base_ch*4, base_ch*2),
            ConvBlock(base_ch*2, base_ch*2)
        )

        # up1: /2 -> full resolution
        self.up1 = nn.ConvTranspose2d(
            base_ch*2, base_ch,
            kernel_size=4, stride=2, padding=1
        )
        self.skip_fuse1 = GatedSkipFusion(c_skip=base_ch, c_up=base_ch)
        self.dec1 = nn.Sequential(
            ConvBlock(base_ch*2, base_ch),
            ConvBlock(base_ch, base_ch)
        )

        # final 1x1 conv -> perturbation (delta)
        self.out_conv = nn.Conv2d(base_ch, 1, kernel_size=1)

    def forward(self, mel_batch):
        """
        mel_batch: (B, n_mels, T)
        return   : adv_mel (B, n_mels, T) = mel_batch + epsilon * tanh(delta)
        """
        B, Freq, Time = mel_batch.shape

        # Three stride-2 stages (k=4, s=2, p=1) halve / double sizes exactly only
        # when each dim is divisible by 8. Otherwise the decoder comes up short,
        # the skip maps get center-cropped out of alignment, and output frames are
        # lost. Replicate-pad up to the next multiple and crop back at the end.
        mult = 2 ** 3
        pad_f = (-Freq) % mult
        pad_t = (-Time) % mult
        x = mel_batch.unsqueeze(1)  # (B,1,F,T)
        if pad_f or pad_t:
            x = F.pad(x, (0, pad_t, 0, pad_f), mode="replicate")  # (B,1,F_pad,T_pad)

        # ----- Encoder -----
        e1 = self.enc1(x)            # (B,base_ch,   F_pad,   T_pad)
        d1 = self.down1(e1)          # (B,base_ch*2, F_pad/2, T_pad/2)

        e2 = self.enc2(d1)           # (B,base_ch*2, F_pad/2, T_pad/2)
        d2 = self.down2(e2)          # (B,base_ch*4, F_pad/4, T_pad/4)

        e3 = self.enc3(d2)           # (B,base_ch*4, F_pad/4, T_pad/4)
        d3 = self.down3(e3)          # (B,base_ch*8, F_pad/8, T_pad/8)

        bott = self.bottleneck(d3)   # (B,base_ch*8, F_pad/8, T_pad/8)

        # ----- Decoder -----
        u3 = self.up3(bott)                  # (B,base_ch*4, F_pad/4, T_pad/4)
        dec3 = self.dec3(self.skip_fuse3(u3, e3))

        u2 = self.up2(dec3)                  # (B,base_ch*2, F_pad/2, T_pad/2)
        dec2 = self.dec2(self.skip_fuse2(u2, e2))

        u1 = self.up1(dec2)                  # (B,base_ch,   F_pad,   T_pad)
        dec1 = self.dec1(self.skip_fuse1(u1, e1))

        delta = self.out_conv(dec1)          # (B,1,F_pad,T_pad)
        delta = delta[:, :, :Freq, :Time]    # drop the padded region

        # bounded perturbation, added as a residual to the original mel
        delta_limited = self.epsilon * torch.tanh(delta)
        adv_mel = mel_batch + delta_limited.squeeze(1)   # (B,F,T)

        return adv_mel

In [11]:
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

########################################################
# 1. 유틸 함수들 (이미 위에서 정의한 걸 그대로 사용)
########################################################

def match_wave_for_ecapa(wa, wb):
    """
    Center-crop the longer of two waveforms so both have the same length.

    Args:
        wa, wb: tensors whose last dim is time, e.g. (B, 1, T).

    Returns:
        (wa', wb') with last-dim length min(T_a, T_b). When the lengths already
        match, the inputs are returned unchanged (the same objects).

    NOTE(review): this cell defines match_wave_for_ecapa twice; the later
    definition shadows this one. Keep only one copy.
    """
    len_a, len_b = wa.shape[-1], wb.shape[-1]
    if len_a > len_b:
        offset = (len_a - len_b) // 2
        wa = wa[..., offset:offset + len_b]
    elif len_b > len_a:
        offset = (len_b - len_a) // 2
        wb = wb[..., offset:offset + len_a]
    return wa, wb

def match_spatial(src, ref):
    """
    Center-crop two (B, C, H, W) tensors to a shared (H', W').

    H' = min(H_src, H_ref) and W' = min(W_src, W_ref); each tensor is cropped
    around its own center. Batch sizes must match; channel counts may differ.

    NOTE(review): an identical match_spatial was already defined in an earlier
    cell; this redefinition silently replaces it, including for GatedSkipFusion,
    which looks the name up at call time. Keep only one copy.
    """
    (b_src, _, hs, ws), (b_ref, _, hr, wr) = src.shape, ref.shape
    assert b_src == b_ref, "batch mismatch in match_spatial"

    ht, wt = min(hs, hr), min(ws, wr)

    def crop(t):
        oh = (t.shape[-2] - ht) // 2
        ow = (t.shape[-1] - wt) // 2
        return t[..., oh:oh + ht, ow:ow + wt]

    return crop(src), crop(ref)

def match_mel_for_loss(mel_a, mel_b):
    """
    Center-crop two (B, F, T) mels to a common (F', T') for mel_adv vs mel_clean.

    NOTE(review): duplicate of the match_mel_for_loss defined in an earlier
    cell; this redefinition is the one in effect afterwards.
    """
    # Lift to 4-D (B, 1, F, T) for match_spatial, crop, then squeeze back.
    a4, b4 = (m.unsqueeze(1) for m in (mel_a, mel_b))
    a4c, b4c = match_spatial(a4, b4)
    return a4c.squeeze(1), b4c.squeeze(1)

########################################################
# 2. 하이퍼파라미터 및 모델 초기화
########################################################

lambda_c = 5.0   # weight of the content-preservation (mel L1) loss
lambda_r = 1.0   # weight of the speaker-similarity (resist) loss
lr = 1e-4
num_epochs = 40  # example value

device = "cuda" if torch.cuda.is_available() else "cpu"

# Construction order is kept fixed: G's weight init consumes the seeded torch RNG.
G = UNetLikeGeneratorV2(
        base_ch=64,
        n_mels=80,
        epsilon=0.1
    ).to(device)

D_purify = StochasticPurifier().to(device)

V_vocoder = HiFiGANVocoderWrapper(
    hifigan_config_path = "./external/hifigan/UNIVERSAL_V1/config.json",
    hifigan_ckpt_path   = "./external/hifigan/UNIVERSAL_V1/g_02500000",
    sr_gen              = 22050,
    sr_target           = 16000,
    device              = device
).to(device)

ECAPA = ECAPASpeakerEncoder(device=device).to(device)

# Only the generator is trained; vocoder and ECAPA stay frozen.
optimizer = torch.optim.Adam(G.parameters(), lr=lr)

# One TensorBoard directory per run: global_step restarts at 0 on every re-run,
# so a fixed directory would overlay curves from different runs.
RUN_NAME = f"exp_speedup_{time.strftime('%Y%m%d-%H%M%S')}"
writer = SummaryWriter(log_dir=os.path.join("./runs", RUN_NAME))
global_step = 0


########################################################
# 2. 유틸: 중앙 crop으로 wave 길이 맞추기
########################################################
def match_wave_for_ecapa(wa, wb):
    """
    Make two waveforms (..., T) equally long by center-cropping the longer one.

    Returns (wa', wb') of length min(T_a, T_b); a tensor that is already the
    target length is returned as-is (same object).

    NOTE(review): duplicate of match_wave_for_ecapa defined earlier in this
    cell; this later definition is the one in effect. Keep only one copy.
    """
    target = min(wa.shape[-1], wb.shape[-1])

    def _center(w):
        surplus = w.shape[-1] - target
        if surplus == 0:
            return w
        begin = surplus // 2
        return w[..., begin:begin + target]

    return _center(wa), _center(wb)


########################################################
# 3. 학습 루프
########################################################
# NOTE(review): the Dataset as instantiated above only zero-pads clips shorter
# than min_duration_sec; it does not crop long ones. T therefore varies per
# batch (up to ~13 s in the sanity-check batch above), and each step vocodes a
# full-length clip.
G.train()  # ensure G's BatchNorm uses batch statistics (an earlier eval() would persist)

try:
    for epoch in range(num_epochs):

        pbar = tqdm(loader,
                    desc=f"[SpeedUp] Epoch {epoch+1}/{num_epochs}",
                    dynamic_ncols=True)

        for batch in pbar:
            # 1. Clean waveform (B,1,T) at 16 kHz, right-padded to the batch's longest clip.
            wave_cpu = batch["waveform"]
            wave_clean = wave_cpu.to(device)

            # 2. Log-mel features. The extractor runs on the CPU, so feed it the CPU
            #    batch directly instead of copying the device tensor back to the host.
            mel_clean = wav_to_mel_db(wave_cpu).to(device)   # (B,80,Tmel_clean)

            # 3. Generator G: protective adversarial mel
            mel_adv = G(mel_clean)                            # (B,80,Tmel_adv)

            # 4. Simulated purification defence (time blur / down-up / freq blur)
            mel_purified = D_purify(mel_adv)                  # (B,80,Tmel_pur)

            # 5. Resist loss on ONE random sample (vocoder + ECAPA are expensive).
            B = mel_purified.shape[0]
            pick_idx = torch.randint(low=0, high=B, size=(1,)).item()

            mel_purified_sub = mel_purified[pick_idx:pick_idx+1, :, :]  # (1,80,T')
            wave_clean_sub   = wave_clean[pick_idx:pick_idx+1, :, :]    # (1,1,Tclean)

            # Vocoder keeps the graph, so gradients reach G through this sample.
            wave_adv_sub = V_vocoder(mel_purified_sub)                  # (1,1,Tadv_16k)

            # Raw length mismatch BEFORE cropping (after cropping it is always 0).
            wave_len_diff_sub = wave_clean_sub.shape[-1] - wave_adv_sub.shape[-1]
            wave_adv_m, wave_clean_m = match_wave_for_ecapa(wave_adv_sub, wave_clean_sub)

            # ECAPA embeddings; the clean one is only a target, so no graph is needed.
            with torch.no_grad():
                emb_orig_sub = ECAPA(wave_clean_m)       # (1,emb_dim)
            emb_purified_sub = ECAPA(wave_adv_m)         # (1,emb_dim)

            # Cosine speaker similarity; minimised directly (lower = less similar).
            cos_sim_sub = torch.sum(
                F.normalize(emb_orig_sub, p=2, dim=1) *
                F.normalize(emb_purified_sub, p=2, dim=1),
                dim=1
            )  # (1,)
            L_resist = torch.mean(cos_sim_sub)

            # 6. Content loss on the full batch
            mel_adv_crop, mel_clean_crop = match_mel_for_loss(mel_adv, mel_clean)
            L_content = F.l1_loss(mel_adv_crop, mel_clean_crop)

            # 7. Total loss
            L_total = lambda_c * L_content + lambda_r * L_resist

            # 8. backward & step
            optimizer.zero_grad()
            L_total.backward()
            optimizer.step()

            # 9. Logging. Each .item() forces a device sync, so read each scalar once.
            loss_total = L_total.item()
            loss_content = L_content.item()
            loss_resist = L_resist.item()
            mean_cos = cos_sim_sub.mean().item()  # only the single picked sample contributes

            writer.add_scalar("loss/total",     loss_total,   global_step)
            writer.add_scalar("loss/content",   loss_content, global_step)
            writer.add_scalar("loss/resist",    loss_resist,  global_step)

            writer.add_scalar("metric/cos_sim_sub", mean_cos,         global_step)
            writer.add_scalar("metric/used_idx",    float(pick_idx),  global_step)

            # mel length difference between G's input and output (reference only)
            writer.add_scalar("metric/delta_Tmel",
                              float(mel_clean.shape[-1] - mel_adv.shape[-1]),
                              global_step)

            # waveform length difference of the picked sample, before cropping
            writer.add_scalar("metric/wave_len_diff_sub",
                              float(wave_len_diff_sub),
                              global_step)

            pbar.set_postfix({
                "L_tot":   f"{loss_total:.3f}",
                "L_cont":  f"{loss_content:.3f}",
                "L_res":   f"{loss_resist:.3f}",
                "cos_sub": f"{mean_cos:.3f}",
                "idx":     int(pick_idx),
            })

            global_step += 1
finally:
    # Flush/close TensorBoard even if training is interrupted (e.g. KeyboardInterrupt).
    writer.close()

print("속도 개선 루프 종료 (partial-batch resist + trimmed audio)")

  WeightNorm.apply(module, name, dim)
  wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
[SpeedUp] Epoch 1/40:  22%|██▏       | 150/676 [00:43<02:34,  3.41it/s, L_tot=0.213, L_cont=0.001, L_res=0.207, cos_sub=0.207, idx=2] 


KeyboardInterrupt: 