# CRNN-based Voice Activity Detection (VAD) — Training Notebook

**EN**: This notebook mirrors `vad_train_crnn.py` with clear, sectioned explanations and bilingual comments.  
**KO**: 이 노트북은 `vad_train_crnn.py`의 내용을 섹션별 설명과 한/영 주석으로 정리한 버전입니다.

## 0. Environment / Imports

**EN**: Core libraries for audio I/O, feature extraction, PyTorch model training, and utilities.  
**KO**: 오디오 입출력, 특징 추출, PyTorch 학습 및 각종 유틸리티 임포트.

In [None]:
import os, math, time, json, random
from pathlib import Path
from typing import List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import soundfile as sf
import librosa

# Report the installed PyTorch version and whether CUDA is available,
# so the environment of a run can be read straight from the cell output.
print(f'[info] torch={torch.__version__} cuda={torch.cuda.is_available()}')

## 1. Reproducibility Utility

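**EN**: Sets the Python, NumPy, and PyTorch seeds and enables the cuDNN autotuner (`benchmark=True`) for speed. Because `cudnn.deterministic` is left `False`, GPU runs are not guaranteed to be bit-for-bit reproducible.  
**KO**: Python, NumPy, PyTorch 시드를 고정하고 CuDNN 자동 튜닝(`benchmark=True`)을 켜서 속도를 높입니다. `cudnn.deterministic`이 `False`이므로 GPU에서는 완전한 재현성이 보장되지 않습니다.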

In [None]:
def set_seed(seed: int = 1337):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN for speed.

    Args:
        seed: Value handed to every random number generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Throughput is preferred over bit-exact repeatability here: the cuDNN
    # autotuner is switched on and deterministic kernels are not requested.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = True
    cudnn.deterministic = False

## 2. Audio Utilities

**EN**: Load mono audio and compute log-mel spectrograms.  
**KO**: 모노 오디오 로드 및 로그-멜 스펙트로그램 계산 함수입니다.

In [None]:
def load_audio(path: Path, target_sr: int) -> np.ndarray:
    """Read an audio file as float32 samples, downmixed and resampled.

    Args:
        path: Audio file to read with ``soundfile``.
        target_sr: Sample rate the returned waveform should have.

    Returns:
        A 1-D waveform at ``target_sr``.
    """
    samples, native_sr = sf.read(str(path), dtype="float32", always_2d=False)

    is_multichannel = samples.ndim > 1
    if is_multichannel:
        # Downmix by averaging over axis 1 (assumed to be the channel axis).
        samples = samples.mean(axis=1)

    if native_sr == target_sr:
        return samples
    return librosa.resample(samples, orig_sr=native_sr, target_sr=target_sr)

def to_logmel(
    wav: np.ndarray,
    sr: int,
    n_fft: int = 1024,
    hop_length: int = 160,
    win_length: int = 400,
    n_mels: int = 64,
    fmin: int = 50,
    fmax: Optional[int] = None,
    eps: float = 1e-10,
) -> np.ndarray:
    """Compute a natural-log mel power spectrogram of a waveform.

    Args:
        wav: Input waveform.
        sr: Sample rate of ``wav``.
        n_fft: FFT size used for the STFT and the mel filterbank.
        hop_length: Hop between STFT frames, in samples.
        win_length: STFT window length, in samples.
        n_mels: Number of mel bands.
        fmin: Lowest filterbank frequency.
        fmax: Highest filterbank frequency; a falsy value selects ``sr // 2``.
        eps: Added to the mel power before the logarithm to avoid ``log(0)``.

    Returns:
        Array laid out as ``(time, n_mels)``.
    """
    upper_hz = fmax or sr // 2

    stft = librosa.stft(wav, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
    power = np.abs(stft) ** 2

    filterbank = librosa.filters.mel(
        sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=upper_hz
    )
    mel_power = np.dot(filterbank, power)

    # Transpose so that time is the leading axis.
    return np.log(mel_power + eps).T

## 3. Dataset: Frame-level Labels

**EN**: Reads pairs of `(wav_path, label_npy_path)` and returns `(T,F)` log-mel features with frame-level labels `(T,)`.  
**KO**: `(오디오, 라벨)` 쌍을 읽고 `(T,F)` 로그멜과 `(T,)` 프레임 단위 라벨을 반환합니다.

> Augmentation: Optional noise mixing using files from `noise_dir` during training.

In [None]:
class FrameLabelDataset(Dataset):
    """Dataset of ``(log-mel features, frame labels)`` pairs for VAD training.

    Each item is a ``(wav_path, label_npy_path)`` pair. The waveform is loaded
    with ``load_audio``, optionally mixed with a random noise file, converted
    with ``to_logmel`` and returned together with the labels loaded from the
    ``.npy`` file. Features and labels are truncated to the shorter of the two.

    Args:
        items: List of ``(wav_path, label_npy_path)`` pairs.
        sr: Sample rate audio is loaded at.
        n_fft: FFT size forwarded to ``to_logmel``.
        hop_length: Hop length forwarded to ``to_logmel``.
        win_length: Window length forwarded to ``to_logmel``.
        n_mels: Number of mel bands forwarded to ``to_logmel``.
        augment_noise_paths: Noise files to mix in; ``None`` or an empty list
            disables augmentation.
        snr_db_range: ``(low, high)`` bounds of the uniformly drawn SNR in dB.
        augment_prob: Probability in ``[0, 1]`` that noise mixing is attempted
            for an item.

    Raises:
        ValueError: If ``augment_prob`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        items: List[Tuple[Path, Path]],
        sr: int = 16000,
        n_fft: int = 1024,
        hop_length: int = 160,
        win_length: int = 400,
        n_mels: int = 64,
        augment_noise_paths: Optional[List[Path]] = None,
        snr_db_range: Tuple[float, float] = (5.0, 20.0),
        augment_prob: float = 0.7,
    ):
        if not 0.0 <= augment_prob <= 1.0:
            raise ValueError(f"augment_prob must be in [0, 1], got {augment_prob}")
        self.items = items
        self.sr = sr
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mels = n_mels
        self.augment_noise_paths = augment_noise_paths or []
        self.snr_db_range = snr_db_range
        self.augment_prob = augment_prob

    def __len__(self):
        return len(self.items)

    def _mix_noise(self, clean: np.ndarray) -> np.ndarray:
        """Add a randomly chosen noise clip to ``clean`` at a random SNR.

        Returns ``clean`` unchanged when no noise files are configured, or when
        either the noise clip or ``clean`` is empty.
        """
        if not self.augment_noise_paths:
            return clean
        noise_path = random.choice(self.augment_noise_paths)
        noise = load_audio(noise_path, self.sr)
        # An empty noise file would divide by zero below; an empty clean
        # signal has nothing to mix into.
        if len(noise) == 0 or len(clean) == 0:
            return clean
        if len(noise) < len(clean):
            # Loop short noise so it covers the whole utterance.
            reps = math.ceil(len(clean) / len(noise))
            noise = np.tile(noise, reps)[: len(clean)]
        else:
            # Take a random crop of long noise.
            start = random.randint(0, len(noise) - len(clean))
            noise = noise[start : start + len(clean)]
        snr_db = random.uniform(*self.snr_db_range)
        # The small constants keep the ratio finite for silent signals.
        sig_pwr = np.mean(clean**2) + 1e-12
        noise_pwr = np.mean(noise**2) + 1e-12
        scale = np.sqrt(sig_pwr / (10 ** (snr_db / 10) * noise_pwr))
        return clean + scale * noise

    def __getitem__(self, idx: int):
        """Return ``(x, y)``: float32 features ``(T, F)`` and labels ``(T,)``."""
        wav_path, label_path = self.items[idx]
        wav = load_audio(wav_path, self.sr)
        if random.random() < self.augment_prob:
            wav = self._mix_noise(wav)

        logmel = to_logmel(
            wav, sr=self.sr,
            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length,
            n_mels=self.n_mels
        )  # (T, F)

        labels = np.load(label_path)  # expected shape (T,) with values {0,1}
        # Trim both to the common length so features and labels stay aligned.
        T = min(len(labels), logmel.shape[0])
        logmel = logmel[:T]
        labels = labels[:T]

        x = torch.tensor(logmel, dtype=torch.float32)  # (T, F)
        y = torch.tensor(labels, dtype=torch.float32)  # (T,)
        return x, y

def collate_pad(batch):
    """Zero-pad a list of ``(x, y)`` samples to the longest sequence.

    Args:
        batch: Sequence of ``(x, y)`` pairs where ``x`` is ``(T, F)`` and
            ``y`` is ``(T,)``.

    Returns:
        ``(x_pad, y_pad, lengths)``: float32 tensors of shape ``(B, T_max, F)``
        and ``(B, T_max)``, plus the int32 original lengths ``(B,)``.
    """
    feats, labels = zip(*batch)
    n_frames = [feat.shape[0] for feat in feats]
    batch_size, longest = len(feats), max(n_frames)
    feat_dim = feats[0].shape[1]

    x_pad = torch.zeros(batch_size, longest, feat_dim, dtype=torch.float32)
    y_pad = torch.zeros(batch_size, longest, dtype=torch.float32)
    for row, (feat, lab, n) in enumerate(zip(feats, labels, n_frames)):
        # Everything past ``n`` stays zero and acts as padding.
        x_pad[row, :n] = feat
        y_pad[row, :n] = lab

    return x_pad, y_pad, torch.tensor(n_frames, dtype=torch.int32)

## 4. Model: CRNN for VAD

**EN**: Lightweight 2-D CNN over the time–frequency plane + bidirectional GRU over time → frame-level logits.  
**KO**: 시간–주파수 평면에 대한 2D CNN과 시간 축 양방향 GRU로 구성된 경량 CRNN → 프레임 로짓 출력.

In [None]:
class CRNNVAD(nn.Module):
    """Convolutional-recurrent network producing one VAD logit per frame.

    Two 3x3 convolution + batch-norm + ReLU stages (padding 1, so the time and
    frequency sizes are preserved) feed a single-layer bidirectional GRU,
    followed by a linear layer that maps each time step to a single logit.

    Args:
        n_mels: Size of the feature axis of the input.
        cnn_channels: Number of channels in both convolution layers.
        rnn_hidden: Hidden size of the GRU, per direction.
    """

    def __init__(self, n_mels: int = 64, cnn_channels: int = 64, rnn_hidden: int = 128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, cnn_channels, kernel_size=(3,3), padding=1),
            nn.BatchNorm2d(cnn_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(cnn_channels, cnn_channels, kernel_size=(3,3), padding=1),
            nn.BatchNorm2d(cnn_channels),
            nn.ReLU(inplace=True),
        )
        self.rnn = nn.GRU(input_size=cnn_channels * n_mels, hidden_size=rnn_hidden,
                          num_layers=1, batch_first=True, bidirectional=True)
        self.out = nn.Linear(2 * rnn_hidden, 1)

    def forward(self, x: torch.Tensor, lengths: torch.Tensor):
        """Compute frame logits.

        Args:
            x: Padded features of shape ``(B, T, F)``.
            lengths: Number of valid frames per sequence, shape ``(B,)``.

        Returns:
            Logits of shape ``(B, T)``, where ``T`` is the time size of ``x``.
        """
        x = x.unsqueeze(1)                 # (B, 1, T, F)
        x = self.conv(x)                   # (B, C, T, F)
        B, C, T, Fdim = x.shape
        # Fold channels and frequency into one feature vector per time step.
        x = x.permute(0, 2, 1, 3).contiguous().view(B, T, C * Fdim)  # (B, T, C*F)

        packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.rnn(packed)
        # total_length keeps the output time axis equal to the input's T even
        # when max(lengths) < T, so logits always line up with padded targets.
        out, _ = torch.nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=T
        )  # (B, T, 2H)
        logits = self.out(out).squeeze(-1)  # (B, T)
        return logits

## 5. Loss & Metrics

**EN**: BCE loss with padding mask; frame-level F1 metric for validation.  
**KO**: 패딩을 무시하는 BCE 손실과 프레임 F1 평가지표.

In [None]:
def bce_with_mask(logits: torch.Tensor, targets: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Mean binary cross-entropy over the valid (non-padded) frames only.

    Args:
        logits: Frame logits of shape ``(B, T)``.
        targets: Frame targets of shape ``(B, T)``.
        lengths: Number of valid frames per sequence, shape ``(B,)``.

    Returns:
        Scalar loss averaged over every frame whose index is below its
        sequence length.
    """
    _, n_steps = logits.shape
    frame_idx = torch.arange(n_steps, device=logits.device)
    # valid[b, t] is True while t is inside sequence b.
    valid = frame_idx[None, :] < lengths[:, None]
    return F.binary_cross_entropy_with_logits(logits[valid], targets[valid])

@torch.no_grad()
def frame_f1(logits: torch.Tensor, targets: torch.Tensor, lengths: torch.Tensor, thresh: float = 0.5) -> float:
    """Frame-level F1 score computed over valid (non-padded) frames.

    Args:
        logits: Frame logits of shape ``(B, T)``.
        targets: Frame targets of shape ``(B, T)``; values above 0.5 count
            as positive.
        lengths: Number of valid frames per sequence, shape ``(B,)``.
        thresh: Sigmoid-probability threshold for a positive prediction.

    Returns:
        The F1 score as a Python float.
    """
    _, n_steps = logits.shape
    valid = torch.arange(n_steps, device=logits.device)[None, :] < lengths[:, None]

    predicted = valid & (torch.sigmoid(logits) > thresh)
    actual = valid & (targets > 0.5)

    hits = int((predicted & actual).sum())
    false_alarms = int(predicted.sum()) - hits
    misses = int(actual.sum()) - hits

    tiny = 1e-9  # keeps every ratio defined when a count is zero
    precision = hits / (hits + false_alarms + tiny)
    recall = hits / (hits + misses + tiny)
    return float(2 * precision * recall / (precision + recall + tiny))

## 6. Training / Evaluation Loops

**EN**: AMP-enabled training loop, and an evaluation loop that reports the mean of the per-batch frame F1 scores.  
**KO**: AMP 지원 학습 루프와, 배치별 프레임 F1의 평균을 반환하는 평가 루프입니다.

In [None]:
def train_one_epoch(model, loader, optimizer, scaler, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: Network called as ``model(x, lengths)``.
        loader: Iterable yielding ``(x, y, lengths)`` batches.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        scaler: Gradient scaler used for the backward pass and the step.
        device: Device the batches are moved to; autocast is enabled only
            when its type is ``"cuda"``.

    Returns:
        Sum of per-batch losses divided by the number of batches (``0.0`` for
        an empty loader).
    """
    model.train()
    use_amp = device.type == "cuda"

    loss_sum = 0.0
    n_batches = 0
    for n_batches, (feats, labels, n_frames) in enumerate(loader, start=1):
        feats = feats.to(device)
        labels = labels.to(device)
        n_frames = n_frames.to(device)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            batch_loss = bce_with_mask(model(feats, n_frames), labels, n_frames)

        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_sum += batch_loss.item()

    return loss_sum / max(n_batches, 1)

@torch.no_grad()
def evaluate(model, loader, device):
    """Return the mean of the per-batch frame F1 scores over ``loader``.

    Args:
        model: Network called as ``model(x, lengths)``; put into eval mode.
        loader: Iterable yielding ``(x, y, lengths)`` batches.
        device: Device the batches are moved to.

    Returns:
        Average of ``frame_f1`` across batches (``0.0`` for an empty loader).
    """
    model.eval()

    f1_sum = 0.0
    n_batches = 0
    for n_batches, (feats, labels, n_frames) in enumerate(loader, start=1):
        feats = feats.to(device)
        labels = labels.to(device)
        n_frames = n_frames.to(device)
        f1_sum += frame_f1(model(feats, n_frames), labels, n_frames)

    return f1_sum / max(n_batches, 1)

## 7. Checkpoint & Export

**EN**: Save checkpoints; export TorchScript (`.jit`) and ONNX (`.onnx`).  
**KO**: 체크포인트 저장 및 TorchScript/ONNX 내보내기.

In [None]:
def save_checkpoint(model, path: Path):
    """Write the model's ``state_dict`` to ``path``, creating parent folders.

    Args:
        model: Module whose parameters are saved.
        path: Destination file.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    weights = model.state_dict()
    torch.save(weights, str(path))

def export_torchscript(model, n_mels: int, path: Path, device: torch.device):
    """Trace the model with a dummy batch and save it as TorchScript.

    Args:
        model: Module called as ``model(x, lengths)``; switched to eval mode.
        n_mels: Feature size of the dummy input.
        path: Destination file; parent folders are created.
        device: Device the dummy inputs are created on.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    model.eval()

    # One example sequence of 100 frames drives the trace.
    n_frames = 100
    example_inputs = (
        torch.randn(1, n_frames, n_mels, device=device),
        torch.tensor([n_frames], device=device),
    )
    traced = torch.jit.trace(model, example_inputs)
    traced.save(str(path))

def export_onnx(model, n_mels: int, path: Path, device: torch.device):
    """Export the model to ONNX (opset 17) using a dummy batch.

    The time axis of the ``x`` input and the ``logits`` output is declared
    dynamic; no other axis is.

    Args:
        model: Module called as ``model(x, lengths)``; switched to eval mode.
        n_mels: Feature size of the dummy input.
        path: Destination file; parent folders are created.
        device: Device the dummy inputs are created on.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    model.eval()

    n_frames = 100
    example_inputs = (
        torch.randn(1, n_frames, n_mels, device=device),
        torch.tensor([n_frames], device=device),
    )
    time_axis = {1: "time"}
    torch.onnx.export(
        model,
        example_inputs,
        str(path),
        input_names=["x", "lengths"],
        output_names=["logits"],
        opset_version=17,
        dynamic_axes={"x": dict(time_axis), "logits": dict(time_axis)},
    )

## 8. Data Loading Helpers

**EN**: Read lists of `(wav_path, label_npy_path)` from JSON; scan noise directory; make loaders.  
**KO**: JSON에서 `(오디오, 라벨)` 목록을 읽고, 노이즈 디렉터리 스캔 및 DataLoader 생성.

In [None]:
def load_item_list(json_path: str) -> List[Tuple[Path, Path]]:
    """Load ``(wav_path, label_path)`` pairs from a JSON array of 2-item lists.

    Args:
        json_path: Path to the JSON file; an empty value yields an empty list.

    Returns:
        The pairs, each entry converted to ``Path``.
    """
    if not json_path:
        return []

    with open(json_path, "r", encoding="utf-8") as handle:
        raw_pairs = json.load(handle)

    items = []
    for wav_file, label_file in raw_pairs:
        items.append((Path(wav_file), Path(label_file)))
    return items

def scan_noise(noise_dir: str) -> List[Path]:
    """Recursively collect ``*.wav`` files under ``noise_dir`` in sorted order.

    Args:
        noise_dir: Directory to search; an empty value yields an empty list.

    Returns:
        Sorted list of matching paths.
    """
    if noise_dir:
        return sorted(Path(noise_dir).rglob("*.wav"))
    return []

def make_loader(items: List[Tuple[Path, Path]], args, shuffle: bool) -> DataLoader:
    """Build a padded-batch ``DataLoader`` over ``FrameLabelDataset``.

    Args:
        items: ``(wav_path, label_path)`` pairs to serve.
        args: Namespace providing ``sr``, ``n_fft``, ``hop_length``,
            ``win_length``, ``n_mels``, ``noise_dir`` and ``batch_size``.
        shuffle: Whether to shuffle; noise augmentation files are scanned
            only when this is true.

    Returns:
        A loader with two workers, pinned memory and ``collate_pad`` batching.
    """
    # Noise files are only handed to the dataset for shuffled loaders.
    noise_files = scan_noise(args.noise_dir) if shuffle else []

    dataset = FrameLabelDataset(
        items,
        sr=args.sr,
        n_fft=args.n_fft,
        hop_length=args.hop_length,
        win_length=args.win_length,
        n_mels=args.n_mels,
        augment_noise_paths=noise_files,
    )
    loader_options = dict(
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=2,
        collate_fn=collate_pad,
        pin_memory=True,
    )
    return DataLoader(dataset, **loader_options)

## 9. Configuration & Main Entrypoint

**EN**: Hyperparameters and I/O settings. Set `export=True` to produce `.jit` and `.onnx`.  
**KO**: 하이퍼파라미터/입출력 설정. `export=True`로 내보내기(.jit/.onnx) 활성화.

In [None]:
import argparse

def get_default_args():
    """Build the training argument parser and return its default values.

    The parser is run on an empty argument list, so the result always holds
    the defaults regardless of ``sys.argv`` (safe inside a notebook).

    Returns:
        ``argparse.Namespace`` with every option set to its default.
    """
    parser = argparse.ArgumentParser(description="Train CRNN VAD")

    option_groups = (
        # Data
        (
            ("--train_list", dict(type=str, required=False, default="", help="Path to JSON list of (wav,label) for train")),
            ("--valid_list", dict(type=str, required=False, default="", help="Path to JSON list of (wav,label) for valid")),
            ("--noise_dir", dict(type=str, default="", help="Optional dir of noise wavs for augmentation")),
        ),
        # Audio / Features
        (
            ("--sr", dict(type=int, default=16000)),
            ("--n_mels", dict(type=int, default=64)),
            ("--n_fft", dict(type=int, default=1024)),
            ("--hop_length", dict(type=int, default=160)),
            ("--win_length", dict(type=int, default=400)),
        ),
        # Model / Train
        (
            ("--cnn_channels", dict(type=int, default=64)),
            ("--rnn_hidden", dict(type=int, default=128)),
            ("--batch_size", dict(type=int, default=16)),
            ("--epochs", dict(type=int, default=10)),
            ("--lr", dict(type=float, default=1e-3)),
            ("--seed", dict(type=int, default=1337)),
        ),
        # I/O
        (
            ("--out_dir", dict(type=str, default="vad_out")),
            ("--save_every", dict(type=int, default=1)),
            ("--export", dict(action="store_true", help="Export TorchScript and ONNX at the end")),
        ),
    )
    for group in option_groups:
        for flag, spec in group:
            parser.add_argument(flag, **spec)

    defaults, _unknown = parser.parse_known_args([])  # in-notebook safe
    return defaults

def run_training(args=None):
    """Train the CRNN VAD model end to end and write all outputs.

    Steps: seed the RNGs, load the train/validation item lists, build the
    model, optimizer and loaders, run ``args.epochs`` epochs with periodic and
    best-F1 checkpoints, optionally export the final-epoch model, and dump the
    configuration to ``config.json`` in ``args.out_dir``.

    Args:
        args: Namespace of settings as produced by ``get_default_args``;
            ``None`` selects the defaults. A ``save_every`` of zero or less
            disables the periodic checkpoints.
    """
    if args is None:
        args = get_default_args()
    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[info] device={device}")

    train_items = load_item_list(args.train_list)
    valid_items = load_item_list(args.valid_list)

    if not train_items:
        print("[warn] train_list is empty. Provide a JSON list of [(wav,label), ...]. See 'Data Format' section below.")
    if not valid_items:
        print("[warn] valid_list is empty. Using train split for eval.")
        valid_items = train_items

    model = CRNNVAD(n_mels=args.n_mels, cnn_channels=args.cnn_channels, rnn_hidden=args.rnn_hidden).to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

    train_loader = make_loader(train_items, args, shuffle=True) if train_items else None
    valid_loader = make_loader(valid_items, args, shuffle=False) if valid_items else None

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Start below any reachable F1 so the first valid score always produces
    # best.pth, even when that score is exactly 0.0.
    best_f1 = -math.inf
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        if train_loader is not None:
            train_loss = train_one_epoch(model, train_loader, optim, scaler, device)
        else:
            train_loss = float("nan")
        val_f1 = evaluate(model, valid_loader, device) if valid_loader is not None else float("nan")
        dt = time.time() - t0
        print(f"[epoch {epoch:03d}] loss={train_loss:.4f} val_f1={val_f1:.4f} time={dt:.1f}s")

        # Guard the modulo: save_every == 0 would otherwise divide by zero.
        if args.save_every > 0 and (epoch % args.save_every) == 0:
            save_checkpoint(model, out_dir / f"crnn_vad_epoch{epoch:03d}.pth")

        if not math.isnan(val_f1) and val_f1 > best_f1:
            best_f1 = val_f1
            save_checkpoint(model, out_dir / "best.pth")

    if args.export:
        export_torchscript(model, args.n_mels, out_dir / "vad_crnn.jit", device)
        # ONNX export is best-effort: a failure is reported, not raised.
        try:
            export_onnx(model, args.n_mels, out_dir / "vad_crnn.onnx", device)
        except Exception as e:
            print(f"[warn] ONNX export failed: {e}")

    cfg = vars(args).copy()
    with open(out_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(cfg, f, ensure_ascii=False, indent=2)

    print(f"[done] outputs saved to: {out_dir}")

## 10. Data Format (JSON)

**EN**: `--train_list` and `--valid_list` expect a JSON array of pairs:  
```json
[
  ["/path/to/audio_000.wav", "/path/to/audio_000_labels.npy"],
  ["/path/to/audio_001.wav", "/path/to/audio_001_labels.npy"]
]
```
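The label `.npy` should be a 1D array of shape `(T,)` with values `{0,1}`, one label per log-mel frame (frames are `hop_length` samples apart, i.e. 10 ms with the default settings). If the label and feature lengths differ, both are truncated to the shorter one.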

**KO**: `--train_list`, `--valid_list`는 위와 같은 형식의 JSON 배열을 기대합니다.  
라벨 `.npy`는 `(T,)` 1차원 배열이며 각 로그멜 프레임의 발화 여부 `{0,1}`를 가집니다. 라벨과 특징의 길이가 다르면 둘 다 짧은 쪽에 맞춰 잘립니다.

In [None]:
# (Optional) Quick schema check helper
def preview_list(json_path: str, max_rows: int = 3):
    """Print the item count and the first few pairs of an item-list JSON.

    Args:
        json_path: Path handed to ``load_item_list``.
        max_rows: Maximum number of pairs to print.
    """
    items = load_item_list(json_path)
    print(f'#items={len(items)}')
    for row, pair in enumerate(items[:max_rows]):
        wav_file, label_file = pair
        print(row, wav_file, label_file)

## 11. Usage (In-Notebook)

**EN**: Set `args.train_list` / `args.valid_list` to your JSON files and run `run_training(args)`.  
**KO**: `args.train_list` / `args.valid_list`에 JSON 경로 설정 후 `run_training(args)` 실행.

In [None]:
# Example:
# args = get_default_args()
# args.train_list = "/content/train_list.json"
# args.valid_list = "/content/valid_list.json"
# args.noise_dir  = "/content/noise_wavs"
# args.epochs = 5
# args.export = True
# run_training(args)