In [1]:
import os
import glob
from typing import List, Optional, Dict

import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio

In [2]:
#config
####################################
# 경로 및 공통 설정
####################################

# 실험 노트북(.ipynb)이 있는 위치를 기준으로 상대경로를 잡는다고 가정
PROJECT_ROOT = os.path.abspath(".")  # 필요하면 직접 바꿔도 됨

LIBRISPEECH_ROOT = os.path.join(PROJECT_ROOT, "data", "train", "dev-clean")
# 예: ./data/train/dev-clean/84/..., 174/..., SPEAKERS.TXT ...

print("LIBRISPEECH_ROOT =", LIBRISPEECH_ROOT)
assert os.path.isdir(LIBRISPEECH_ROOT), "dev-clean 경로를 확인하세요."

# 오디오 파라미터
SAMPLE_RATE = 16000  # LibriSpeech 기본 16kHz

# DataLoader 파라미터
BATCH_SIZE = 4
NUM_WORKERS = 0  # 윈도우면 0~2 정도로, 리눅스면 더 올려도 됨
PIN_MEMORY = True if torch.cuda.is_available() else False

# 특정 화자만 사용할 경우 지정 (None이면 전체 화자)
# 예: speaker_list = ["84", "174"]
speaker_list = None


LIBRISPEECH_ROOT = c:\project\vcshield\data\train\dev-clean


In [3]:
#dataset
class LibriSpeechSpeakerDataset(Dataset):
    """
    LibriSpeech dev-clean에서 화자 단위로 발화들을 로드한다.
    각 item은 dict:
        {
          "waveform": Tensor shape (1, T) float32 [-1,1],
          "speaker_id": str,
          "utt_path": str
        }
    """

    def __init__(
        self,
        root_dir: str,
        target_sr: int = 16000,
        speaker_list: Optional[List[str]] = None,
        min_duration_sec: float = 0.5,
    ):
        """
        root_dir: dev-clean 경로
        target_sr: 리샘플할 샘플레이트 (보통 16k)
        speaker_list: 사용할 화자 ID 리스트. None이면 root_dir 내 모든 화자 ID 사용.
        min_duration_sec: 너무 짧은 음성을 버리기 위한 최소 길이(초)
        """
        self.root_dir = root_dir
        self.target_sr = target_sr
        self.min_duration_sec = min_duration_sec

        # 1) 화자 폴더 수집 (폴더명이 전부 숫자인 것만 화자로 간주)
        all_speakers = [
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d)) and d.isdigit()
        ]

        if speaker_list is None:
            self.speakers = sorted(all_speakers)
        else:
            # 교집합만 취함
            self.speakers = sorted([s for s in all_speakers if s in speaker_list])

        # 2) 각 화자에서 실제 오디오 파일(.flac/.wav) 경로 모으기
        #    LibriSpeech는 일반적으로 spk_id/chapter_id/*.flac 형태
        self.items = []
        for spk in self.speakers:
            spk_dir = os.path.join(root_dir, spk)
            # 챕터 디렉토리들
            for ch_name in os.listdir(spk_dir):
                ch_dir = os.path.join(spk_dir, ch_name)
                if not os.path.isdir(ch_dir):
                    continue

                # .wav / .flac 다 지원
                wav_paths = glob.glob(os.path.join(ch_dir, "*.wav"))
                flac_paths = glob.glob(os.path.join(ch_dir, "*.flac"))
                audio_paths = sorted(wav_paths + flac_paths)

                for ap in audio_paths:
                    # 길이 필터링을 위해 일단 메타만 등록하고,
                    # 실제 __getitem__에서 필요하면 필터하자.
                    self.items.append({
                        "speaker_id": spk,
                        "utt_path": ap
                    })

        print(f"[LibriSpeechSpeakerDataset] speakers: {len(self.speakers)}, utterances(raw): {len(self.items)}")

        # 사전 길이 필터를 적용해도 되지만, 여기서는 __getitem__에서 처리 실패 시 재시도 하기보단
        # 그냥 __len__/__getitem__이 일관되게 동작하도록 그대로 둔다.
        # (필요하면 나중에 pre-filter 로직 추가 가능)

    def __len__(self):
        return len(self.items)

    def _load_audio(self, path: str) -> torch.Tensor:
        """
        path에서 오디오 로딩하고 mono+resample까지 맞춰서 (1, T) float32 [-1,1] 반환
        """
        wav, sr = torchaudio.load(path)  # wav: (C, T), float32 -1~1 범위일 가능성 높음
        # 모노화
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        # 리샘플
        if sr != self.target_sr:
            wav = torchaudio.functional.resample(wav, sr, self.target_sr)
        return wav

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        meta = self.items[idx]
        spk = meta["speaker_id"]
        path = meta["utt_path"]

        wav = self._load_audio(path)  # (1, T)

        # 너무 짧으면(말 없는 부분 등) downstream에서 학습이 불안정할 수 있으므로 여기서 잘라낼 수 있다.
        # 여기서는 min_duration_sec 이상만 보장하도록 잘라주는 정도만 (필요하다면)
        min_len = int(self.min_duration_sec * self.target_sr)
        if wav.shape[1] < min_len:
            # 너무 짧다면 패딩 혹은 스킵 로직을 짤 수도 있다.
            # 간단하게는 zero-pad
            pad_len = min_len - wav.shape[1]
            wav = torch.cat([wav, torch.zeros((1, pad_len), dtype=wav.dtype)], dim=1)

        return {
            "waveform": wav,       # (1, T_resampled)
            "speaker_id": spk,     # string
            "utt_path": path       # string
        }


In [4]:
#데이터 체크
# Dataset 생성
dataset = LibriSpeechSpeakerDataset(
    root_dir=LIBRISPEECH_ROOT,
    target_sr=SAMPLE_RATE,
    speaker_list=speaker_list,   # None이면 전체 speaker
    min_duration_sec=0.5         # 너무 짧은 샘플은 최소 0.5초 길이로 패딩
)

# # DataLoader 생성
# loader = DataLoader(
#     dataset,
#     batch_size=BATCH_SIZE,
#     shuffle=True,
#     num_workers=NUM_WORKERS,
#     pin_memory=PIN_MEMORY,
#     drop_last=False
# )

# # Sanity check: 한 배치만 뽑아서 정보 출력
# batch = next(iter(loader))

# print("=== Sanity Check ===")
# print("waveform.shape:", batch["waveform"].shape)  # (B, 1, T)
# print("speaker_id:", batch["speaker_id"])         # list[str] 길이 B
# print("utt_path[0]:", batch["utt_path"][0])

# # wave 시각화 / 플레이 등은 나중에 노트북에서 직접 할 수 있음
# # 예: torchaudio.display.waveplot 등


[LibriSpeechSpeakerDataset] speakers: 40, utterances(raw): 2703


In [5]:
import torch.nn.functional as F

def collate_with_padding(batch_list):
    """
    batch_list는 __getitem__에서 나온 dict들의 리스트 (len = B)

    목표:
    - waveform들을 가장 긴 샘플 길이에 맞춰 zero-pad (right padding)
    - speaker_id / utt_path는 리스트 그대로 유지
    """
    wave_list = [item["waveform"] for item in batch_list]  # [(1, T_i), ...]
    spk_list = [item["speaker_id"] for item in batch_list]
    path_list = [item["utt_path"] for item in batch_list]

    # 가장 긴 길이 구하기
    max_len = max(w.shape[1] for w in wave_list)

    # pad 후 스택
    padded = []
    for w in wave_list:
        if w.shape[1] < max_len:
            pad_amount = max_len - w.shape[1]
            w = F.pad(w, (0, pad_amount), mode="constant", value=0.0)
        padded.append(w)  # (1, max_len)
    wave_tensor = torch.stack(padded, dim=0)  # (B, 1, max_len)

    return {
        "waveform": wave_tensor,
        "speaker_id": spk_list,
        "utt_path": path_list
    }

# collate_fn을 적용한 DataLoader
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    drop_last=False,
    collate_fn=collate_with_padding
)

batch = next(iter(loader))
print("=== Sanity Check w/ collate_fn ===")
print("waveform.shape:", batch["waveform"].shape)  # (B, 1, max_len_in_batch)
print("speaker_id example:", batch["speaker_id"])
print("utt_path[0]:", batch["utt_path"][0])


=== Sanity Check w/ collate_fn ===
waveform.shape: torch.Size([4, 1, 514320])
speaker_id example: ['2078', '2428', '5338', '84']
utt_path[0]: c:\project\vcshield\data\train\dev-clean\2078\142845\2078-142845-0004.flac


In [6]:
import torch
import torchaudio

def build_mel_extractor(
    sample_rate=16000,
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    n_mels=80,
    f_min=0.0,
    f_max=8000.0,
    power=1.0,
    log_offset=1e-6
):
    """
    반환 함수 wav_to_mel_db:
      입력  : waveform (B,1,T) 또는 (1,T)
      출력  : mel_batch (B, n_mels, time)
    """

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        power=power,
    )

    def wav_to_mel_db(waveform: torch.Tensor):
        """
        waveform: (B,1,T) or (1,T)
        returns: (B, n_mels, time)
        """
        single_input = False

        # case: (1,T) -> (1,1,T)
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(0)
            single_input = True

        # sanity check
        if waveform.dim() != 3 or waveform.shape[1] != 1:
            raise ValueError(f"Expected (B,1,T) or (1,T); got {tuple(waveform.shape)}")

        B = waveform.size(0)
        mel_list = []

        for b in range(B):
            # waveform[b]: (1, T)
            mel = mel_transform(waveform[b])  # (1, n_mels, time)
            # squeeze channel dim -> (n_mels, time)
            mel = mel.squeeze(0)

            mel_db = torch.log(mel + log_offset)  # (n_mels, time)
            mel_list.append(mel_db)

        # stack -> (B, n_mels, time)
        mel_batch = torch.stack(mel_list, dim=0)

        if single_input:
            mel_batch = mel_batch[0]  # (n_mels, time)

        return mel_batch

    return wav_to_mel_db


# 빌드
wav_to_mel_db = build_mel_extractor(
    sample_rate=SAMPLE_RATE,
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    n_mels=80,
    f_min=0.0,
    f_max=8000.0,
    power=1.0,
    log_offset=1e-6
)

# 테스트
test_batch = next(iter(loader))
test_wave = test_batch["waveform"]      # (B,1,T)
test_mel  = wav_to_mel_db(test_wave)    # 기대: (B,80,time)

print("waveform:", test_wave.shape)
print("mel_db:", test_mel.shape)
print("speaker_id:", test_batch["speaker_id"])


waveform: torch.Size([4, 1, 193520])
mel_db: torch.Size([4, 80, 756])
speaker_id: ['174', '2428', '2086', '8842']


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
        self.norm = nn.BatchNorm2d(out_ch)
        self.act  = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x

def match_spatial(src, ref):
    """
    src, ref: 4D 텐서 (B,C,H,W)
    리턴: src와 ref 중 작은 쪽 크기로 둘을 맞춰서 (src_adj, ref_adj)를 돌려준다.
    - 만약 H/W가 서로 다르면
      중앙을 기준으로 잘라서 동일한 (H',W')로 맞춤
    - src, ref 중 어느 쪽이 더 큰지/작은지는 각각 축마다 독립적으로 처리
    """
    B1,C1,Hs,Ws = src.shape
    B2,C2,Hr,Wr = ref.shape
    assert B1 == B2, "batch mismatch in match_spatial"

    # 최종 목표 크기
    Ht = min(Hs, Hr)
    Wt = min(Ws, Wr)

    def center_crop(t, Ht, Wt):
        _,_,H,W = t.shape
        start_h = (H - Ht)//2
        start_w = (W - Wt)//2
        return t[:,:,start_h:start_h+Ht, start_w:start_w+Wt]

    src_c = center_crop(src, Ht, Wt)
    ref_c = center_crop(ref, Ht, Wt)
    return src_c, ref_c

class UNetLikeGenerator(nn.Module):
    """
    입력:  mel_batch (B, n_mels, T)
    출력:  adv_mel   (B, n_mels, T)
    구조:  2D U-Net 스타일 (Freq x Time)
    """
    def __init__(self, base_ch=64, n_mels=80):
        super().__init__()
        self.n_mels = n_mels

        # Encoder -----------------
        self.enc1 = nn.Sequential(
            ConvBlock(1, base_ch),
            ConvBlock(base_ch, base_ch)
        )
        self.down1 = nn.Conv2d(base_ch, base_ch*2,
                               kernel_size=4, stride=2, padding=1)  # /2

        self.enc2 = nn.Sequential(
            ConvBlock(base_ch*2, base_ch*2),
            ConvBlock(base_ch*2, base_ch*2)
        )
        self.down2 = nn.Conv2d(base_ch*2, base_ch*4,
                               kernel_size=4, stride=2, padding=1)  # /4

        self.enc3 = nn.Sequential(
            ConvBlock(base_ch*4, base_ch*4),
            ConvBlock(base_ch*4, base_ch*4)
        )

        # Decoder -----------------
        self.up2 = nn.ConvTranspose2d(base_ch*4, base_ch*2,
                                      kernel_size=4, stride=2, padding=1) # *2
        self.dec2 = nn.Sequential(
            ConvBlock(base_ch*4, base_ch*2),  # concat(e2)
            ConvBlock(base_ch*2, base_ch*2)
        )

        self.up1 = nn.ConvTranspose2d(base_ch*2, base_ch,
                                      kernel_size=4, stride=2, padding=1) # *4
        self.dec1 = nn.Sequential(
            ConvBlock(base_ch*2, base_ch),    # concat(e1)
            ConvBlock(base_ch, base_ch)
        )

        # Output ------------------
        self.out_conv = nn.Conv2d(base_ch, 1, kernel_size=1)

    def forward(self, mel_batch):
        """
        mel_batch: (B, n_mels, T)
        return:    adv_mel  : (B, n_mels, T)
        """
        B, Freq, Time = mel_batch.shape

        x = mel_batch.unsqueeze(1)  # (B,1,Freq,Time)

        # ----- Encoder -----
        e1 = self.enc1(x)           # (B,base_ch,Freq,Time)
        d1 = self.down1(e1)         # (B,base_ch*2,~,~)

        e2 = self.enc2(d1)          # (B,base_ch*2,~,~)
        d2 = self.down2(e2)         # (B,base_ch*4,~,~)

        bottleneck = self.enc3(d2)  # (B,base_ch*4,~,~)

        # ----- Decoder -----
        u2 = self.up2(bottleneck)   # (B,base_ch*2,~,~)
        # match spatial size between u2 and e2
        u2m, e2m = match_spatial(u2, e2)
        cat2 = torch.cat([u2m, e2m], dim=1)  # (B,base_ch*4,~,~)
        dec2 = self.dec2(cat2)               # (B,base_ch*2,~,~)

        u1 = self.up1(dec2)                  # (B,base_ch,~,~)
        # match spatial size between u1 and e1
        u1m, e1m = match_spatial(u1, e1)
        cat1 = torch.cat([u1m, e1m], dim=1)  # (B,base_ch*2,Freq,Time) ideally
        dec1 = self.dec1(cat1)               # (B,base_ch,Freq,Time)

        out_delta = self.out_conv(dec1)      # (B,1,Freq_dec,Time_dec)

        # mel_batch.unsqueeze(1): (B,1,Freq_orig,Time_orig)
        mel_as_img = mel_batch.unsqueeze(1)

        # 최종 residual add 전에 공간 크기 맞추기
        out_delta_adj, mel_as_img_adj = match_spatial(out_delta, mel_as_img)

        adv_mel = mel_as_img_adj + out_delta_adj   # (B,1,Freq_final,Time_final)
        adv_mel = adv_mel.squeeze(1)               # (B,Freq_final,Time_final)

        import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
        self.norm = nn.BatchNorm2d(out_ch)
        self.act  = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x

def match_spatial(src, ref):
    """
    src, ref: (B,C,H,W)
    반환: (src_cropped, ref_cropped)
    - 두 텐서를 동일한 (H',W')로 중앙 크롭해서 맞춘다.
    - H',W'는 각 축별로 min(Hsrc,Href), min(Wsrc,Wref)
    """
    B1,C1,Hs,Ws = src.shape
    B2,C2,Hr,Wr = ref.shape
    assert B1 == B2, "batch mismatch in match_spatial"

    Ht = min(Hs, Hr)
    Wt = min(Ws, Wr)

    def center_crop(t, Ht, Wt):
        _,_,H,W = t.shape
        start_h = (H - Ht)//2
        start_w = (W - Wt)//2
        return t[:, :, start_h:start_h+Ht, start_w:start_w+Wt]

    src_c = center_crop(src, Ht, Wt)
    ref_c = center_crop(ref, Ht, Wt)
    return src_c, ref_c


class UNetLikeGenerator(nn.Module):
    """
    입력:  mel_batch (B, n_mels, T)
    출력:  adv_mel   (B, n_mels, T')  <- off-by-one 있을 수 있으므로 최종 크롭으로 원본과 최대한 맞춤
    구조:  2D U-Net 스타일 (Freq x Time)
    """
    def __init__(self, base_ch=64, n_mels=80):
        super().__init__()
        self.n_mels = n_mels

        # ----- Encoder -----
        self.enc1 = nn.Sequential(
            ConvBlock(1, base_ch),
            ConvBlock(base_ch, base_ch)
        )
        self.down1 = nn.Conv2d(
            base_ch, base_ch*2,
            kernel_size=4, stride=2, padding=1
        )  # /2

        self.enc2 = nn.Sequential(
            ConvBlock(base_ch*2, base_ch*2),
            ConvBlock(base_ch*2, base_ch*2)
        )
        self.down2 = nn.Conv2d(
            base_ch*2, base_ch*4,
            kernel_size=4, stride=2, padding=1
        )  # /4

        self.enc3 = nn.Sequential(
            ConvBlock(base_ch*4, base_ch*4),
            ConvBlock(base_ch*4, base_ch*4)
        )

        # ----- Decoder -----
        self.up2 = nn.ConvTranspose2d(
            base_ch*4, base_ch*2,
            kernel_size=4, stride=2, padding=1
        )  # *2
        self.dec2 = nn.Sequential(
            ConvBlock(base_ch*4, base_ch*2),  # concat(e2)
            ConvBlock(base_ch*2, base_ch*2)
        )

        self.up1 = nn.ConvTranspose2d(
            base_ch*2, base_ch,
            kernel_size=4, stride=2, padding=1
        )  # *4
        self.dec1 = nn.Sequential(
            ConvBlock(base_ch*2, base_ch),    # concat(e1)
            ConvBlock(base_ch, base_ch)
        )

        # ----- Output -----
        self.out_conv = nn.Conv2d(base_ch, 1, kernel_size=1)

    def forward(self, mel_batch):
        """
        mel_batch: (B, n_mels, T)
        return:    adv_mel: (B, n_mels, T_adj)
        """
        B, Freq, Time = mel_batch.shape  # Freq ~80, Time ~900~1000 가변

        x = mel_batch.unsqueeze(1)  # (B,1,Freq,Time)

        # ===== Encoder =====
        e1 = self.enc1(x)           # (B,base_ch,Freq,Time)
        d1 = self.down1(e1)         # (B,base_ch*2,~,~)

        e2 = self.enc2(d1)          # (B,base_ch*2,~,~)
        d2 = self.down2(e2)         # (B,base_ch*4,~,~)

        bottleneck = self.enc3(d2)  # (B,base_ch*4,~,~)

        # ===== Decoder =====
        u2 = self.up2(bottleneck)   # (B,base_ch*2,~,~)
        u2m, e2m = match_spatial(u2, e2)
        cat2 = torch.cat([u2m, e2m], dim=1)   # (B,base_ch*4,~,~)
        dec2 = self.dec2(cat2)                # (B,base_ch*2,~,~)

        u1 = self.up1(dec2)                   # (B,base_ch,~,~)
        u1m, e1m = match_spatial(u1, e1)
        cat1 = torch.cat([u1m, e1m], dim=1)   # (B,base_ch*2,~,~)
        dec1 = self.dec1(cat1)                # (B,base_ch,~,~)

        out_delta = self.out_conv(dec1)       # (B,1,H_out,W_out)

        # ===== Residual add with original mel =====
        mel_as_img = mel_batch.unsqueeze(1)   # (B,1,Freq,Time)

        out_delta_adj, mel_as_img_adj = match_spatial(out_delta, mel_as_img)
        # 이제 두 텐서의 (H,W)이 동일해졌음

        adv_mel = mel_as_img_adj + out_delta_adj  # (B,1,Ht,Wt)
        adv_mel = adv_mel.squeeze(1)              # (B,Ht,Wt)

        return adv_mel



# 간단하게 모델 만들어보고 shape 통과 확인
device = "cuda" if torch.cuda.is_available() else "cpu"
G = UNetLikeGenerator(base_ch=64, n_mels=80).to(device)

with torch.no_grad():
    dummy_in = torch.randn(2, 80, 200).to(device)  # (B, n_mels, T_mel)
    dummy_out = G(dummy_in)
print("G input shape:", dummy_in.shape)
print("G output shape:", dummy_out.shape)


G input shape: torch.Size([2, 80, 200])
G output shape: torch.Size([2, 80, 200])


In [13]:
# 한 스텝 예시 (forward 경로 점검)

batch = next(iter(loader))

wave = batch["waveform"].to(device)  # (B,1,T)
spk_ids = batch["speaker_id"]
paths = batch["utt_path"]

# 1. waveform -> mel (CPU에서 mel 추출한 뒤 device로 올리도록 했었는데
#    torchaudio는 GPU안써도 괜찮으니 그냥 CPU에서 돌리고 나중에 to(device))
mel_clean_cpu = wav_to_mel_db(wave.cpu())   # (B,80,Tmel) on CPU
mel_clean = mel_clean_cpu.to(device)

print("mel_clean:", mel_clean.shape)  # 기대: (B,80,Tmel)

# 2. pass through G
mel_adv = G(mel_clean)  # (B,80,Tmel)

print("mel_adv:", mel_adv.shape)

# 3. (미래 작업 TODO)
#   - mel_adv를 디노이저 D에 통과 -> mel_denoised
#   - vocoder V(mel_denoised) -> wav_adv
#   - speaker encoder E(wav_adv)와 E(wave_clean) 비교해서 loss
#   - L_content, L_resist 계산
#   - backward/optimizer.step()

# 지금 단계의 목표:
#   데이터로더 -> mel 변환 -> G forward 가 오류 없이 shape 일관성 있게 흐르는지 확인


mel_clean: torch.Size([4, 80, 457])
mel_adv: torch.Size([4, 80, 456])


In [14]:
#디노이저
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleDenoiser(nn.Module):
    """
    멜 스펙 (B, n_mels, T) -> (B, n_mels, T)
    시간 방향으로만 살짝 smoothing을 넣어서
    '노이즈 정화 시도'를 근사한다.
    """
    def __init__(self, kernel_size=5):
        super().__init__()
        self.kernel_size = kernel_size
        # depthwise conv1d 흉내: 각 mel 채널별로 동일 커널 적용
        # weight shape: (C_out, C_in/groups, K)
        # groups = n_mels -> mel 채널별 독립 필터
        # 우리는 runtime에 weight를 생성해서 F.conv1d로 처리할 예정 (학습 X, 고정)
        # 이유: 간단히 deterministic blur만 하고 싶기 때문
        # 즉, forward 안에서 커널을 만들어 쓴다.

    def forward(self, mel):
        """
        mel: (B, n_mels, T)
        return: (B, n_mels, T)
        """
        B, C, T = mel.shape
        k = self.kernel_size

        # 평균커널: (C,1,k)
        kernel = torch.ones(C, 1, k, device=mel.device, dtype=mel.dtype) / k

        # padding 'reflect' or 'replicate' 등으로 시간 길이 유지
        pad = (k // 2, k // 2)
        mel_padded = F.pad(mel, pad=pad, mode='reflect')  # (B,C,T+padL+padR)

        # depthwise conv1d
        mel_blur = F.conv1d(
            mel_padded,          # (B,C,T+pad)
            kernel,              # (C,1,k)
            bias=None,
            stride=1,
            padding=0,
            groups=C
        )  # (B,C,T)

        return mel_blur

D = SimpleDenoiser(kernel_size=5).to(device)

# quick shape test
dummy = torch.randn(4, 80, 992).to(device)
out_d = D(dummy)
print("D input:", dummy.shape, "D output:", out_d.shape)


D input: torch.Size([4, 80, 992]) D output: torch.Size([4, 80, 992])


In [15]:
#ECAPA(임베딩을 뽑는 네트워크인데 임시로 흉내냄)
class DummySpeakerEncoder(nn.Module):
    """
    mel (B, 80, T) -> speaker embedding (B, emb_dim)
    아주 단순한 conv + global pooling + linear.
    실제로는 ECAPA 등으로 대체할 예정.
    """
    def __init__(self, in_ch=80, emb_dim=192):
        super().__init__()
        # conv over time
        self.conv1 = nn.Conv1d(in_ch, 128, kernel_size=5, padding=2)
        self.conv2 = nn.Conv1d(128, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.norm  = nn.BatchNorm1d(256)
        self.act   = nn.ReLU(inplace=True)
        self.proj  = nn.Linear(256, emb_dim)

    def forward(self, mel):
        """
        mel: (B, 80, T)
        return: (B, emb_dim)
        """
        x = self.conv1(mel)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)          # (B,256,T)
        x = self.norm(x)
        x = self.act(x)

        # global average pool over time
        x = torch.mean(x, dim=2)   # (B,256)

        x = self.proj(x)           # (B,emb_dim)
        # normalize for cosine stability
        x = F.normalize(x, p=2, dim=1)
        return x

E = DummySpeakerEncoder(in_ch=80, emb_dim=192).to(device)

# quick shape test
with torch.no_grad():
    emb_test = E(dummy)  # dummy: (4,80,992)
print("E input:", dummy.shape, "E output(emb):", emb_test.shape)



E input: torch.Size([4, 80, 992]) E output(emb): torch.Size([4, 192])


In [17]:
#손실 계산

def match_mel_for_loss(mel_a, mel_b):
    """
    mel_a: (B, F, T)
    mel_b: (B, F, T)
    return: (a_crop, b_crop) with identical (F', T')
    둘 중 더 작은 쪽에 중앙 크롭 맞춘다.
    """
    A4 = mel_a.unsqueeze(1)  # (B,1,F,T)
    B4 = mel_b.unsqueeze(1)  # (B,1,F,T)

    A4c, B4c = match_spatial(A4, B4)  # (B,1,F',T'), (B,1,F',T')

    A3c = A4c.squeeze(1)  # (B,F',T')
    B3c = B4c.squeeze(1)  # (B,F',T')

    return A3c, B3c


# 하이퍼파라미터
lambda_c = 1.0   # content 보존 가중치
lambda_r = 1.0   # resist(식별 방해) 가중치
lr = 1e-4

optimizer = torch.optim.Adam(G.parameters(), lr=lr)

# === 1 step example ===
batch = next(iter(loader))
wave = batch["waveform"].to(device)        # (B,1,T)

# 1. waveform -> mel_clean (CPU에서 mel 추출 후 device로 옮김)
mel_clean_cpu = wav_to_mel_db(wave.cpu())  # (B,80,Tmel) on CPU
mel_clean = mel_clean_cpu.to(device)       # (B,80,Tmel)

# 2. 생성기 통과
mel_adv = G(mel_clean)                     # (B,80,Tmel)

# 3. 디노이저(=공격자 정화 시뮬레이션)
mel_denoised = D(mel_adv)                  # (B,80,Tmel)

# 4. 화자 임베딩 (placeholder)
emb_orig = E(mel_clean)                    # (B,emb_dim)
emb_denoised = E(mel_denoised)             # (B,emb_dim)

# 5. 손실 계산
# (a) content loss: adv mel이 clean mel과 너무 달라지지 않도록
mel_adv_crop, mel_clean_crop = match_mel_for_loss(mel_adv, mel_clean)
L_content = F.l1_loss(mel_adv_crop, mel_clean_crop)

# (b) resist loss: 디노이즈 후 임베딩이 원본 임베딩과 닮지 않도록
# cosine similarity를 계산하고, 그걸 낮추고 싶다 → 음수로 곱해주면 된다.
cos_sim = torch.sum(emb_orig * emb_denoised, dim=1)  # (B,)
L_resist = - torch.mean(cos_sim)  # 유사도가 낮아질수록(멀어질수록) 낮은 loss가 되게 부호 설정

L_total = lambda_c * L_content + lambda_r * L_resist

print("L_content:", float(L_content.item()))
print("L_resist :", float(L_resist.item()))
print("L_total  :", float(L_total.item()))

# 6. backward & optimizer step
optimizer.zero_grad()
L_total.backward()
optimizer.step()

print("✅ backward/step까지 완료")


L_content: 0.3105723261833191
L_resist : -0.9993799328804016
L_total  : -0.6888076066970825
✅ backward/step까지 완료


In [18]:
import torch
import torch.nn.functional as F

# 1) mel 크롭 유틸
def match_mel_for_loss(mel_a, mel_b):
    """
    mel_a: (B,F,T)
    mel_b: (B,F,T)
    return: (a_crop, b_crop) with same (F',T')
    """
    A4 = mel_a.unsqueeze(1)  # (B,1,F,T)
    B4 = mel_b.unsqueeze(1)  # (B,1,F,T)
    A4c, B4c = match_spatial(A4, B4)
    A3c = A4c.squeeze(1)     # (B,F',T')
    B3c = B4c.squeeze(1)     # (B,F',T')
    return A3c, B3c

# ---------------- Hyperparams / Optimizer ----------------
lambda_c = 1.0   # content weight
lambda_r = 1.0   # resist weight
lr = 1e-4

optimizer = torch.optim.Adam(G.parameters(), lr=lr)

# ---------------- One training step prototype ----------------
batch = next(iter(loader))
wave = batch["waveform"].to(device)            # (B,1,T)

# waveform -> mel_clean (still on CPU for torchaudio, then to device)
mel_clean_cpu = wav_to_mel_db(wave.cpu())      # (B,80,Tmel_clean)
mel_clean = mel_clean_cpu.to(device)           # (B,80,Tmel_clean)

# G forward
mel_adv = G(mel_clean)                         # (B,80,Tmel_adv)

# D (denoiser / purification simulation)
mel_denoised = D(mel_adv)                      # (B,80,Tmel_adv)  (same T as mel_adv by design)

# Speaker embeddings (placeholder E on mel)
emb_orig      = E(mel_clean)                   # (B,emb_dim)
emb_denoised  = E(mel_denoised)                # (B,emb_dim)

# Content loss: crop mel_adv and mel_clean to same spatial size
mel_adv_crop, mel_clean_crop = match_mel_for_loss(mel_adv, mel_clean)
L_content = F.l1_loss(mel_adv_crop, mel_clean_crop)

# Resist loss: we want cosine similarity to go DOWN
cos_sim = torch.sum(emb_orig * emb_denoised, dim=1)  # (B,)
L_resist = - torch.mean(cos_sim)

L_total = lambda_c * L_content + lambda_r * L_resist

print("mel_clean.shape      :", mel_clean.shape)
print("mel_adv.shape        :", mel_adv.shape)
print("mel_adv_crop.shape   :", mel_adv_crop.shape)
print("mel_clean_crop.shape :", mel_clean_crop.shape)
print("emb_orig.shape       :", emb_orig.shape)
print("emb_denoised.shape   :", emb_denoised.shape)
print("L_content:", float(L_content.item()))
print("L_resist :", float(L_resist.item()))
print("L_total  :", float(L_total.item()))

optimizer.zero_grad()
L_total.backward()
optimizer.step()

print("✅ backward/step까지 완료")


mel_clean.shape      : torch.Size([4, 80, 761])
mel_adv.shape        : torch.Size([4, 80, 760])
mel_adv_crop.shape   : torch.Size([4, 80, 760])
mel_clean_crop.shape : torch.Size([4, 80, 760])
emb_orig.shape       : torch.Size([4, 192])
emb_denoised.shape   : torch.Size([4, 192])
L_content: 0.4890805780887604
L_resist : -0.9991296529769897
L_total  : -0.5100491046905518
✅ backward/step까지 완료
