# ASVspoof 2019 LA – Multi-branch Attention Model

Notebook này tái cấu trúc toàn bộ project huấn luyện mô hình đa nhánh phát hiện giả mạo giọng nói sang định dạng `.ipynb` để có thể chạy trực tiếp trên Kaggle hoặc các môi trường notebook khác. Mọi phần mã nguồn trong thư mục `src/` và script `train.py` đều đã được tổ chức lại thành các cell có chú thích rõ ràng.

## 1. Chuẩn bị môi trường
Chạy cell bên dưới để cài đặt các phụ thuộc cần thiết. Bạn có thể tuỳ chỉnh danh sách nếu môi trường đã có sẵn một số thư viện.

In [None]:
# Nếu chạy trên Kaggle hoặc môi trường mới, bỏ comment dòng dưới để cài đặt.
# !pip install torch torchaudio numpy scipy pandas PyYAML librosa soundfile tqdm matplotlib tensorboard

## 2. Thư viện và cấu hình toàn cục

In [None]:
from __future__ import annotations

import csv
import math
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, TypedDict

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import yaml
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

# Seed PyTorch's RNGs for reproducibility (change the value as needed).
# Only torch is seeded here; Python's `random` and NumPy keep their own state.
_GLOBAL_SEED = 42
torch.manual_seed(_GLOBAL_SEED)
if torch.cuda.is_available():
    # Seed every CUDA device explicitly as well.
    torch.cuda.manual_seed_all(_GLOBAL_SEED)

## 3. Cấu hình đặc trưng và xử lý âm thanh
Các cell dưới đây tương ứng với `src/data/features.py` và `src/utils/audio.py`.

In [None]:
# Mel-spectrogram ("spectral" branch) settings. Every key is optional
# (total=False); readers use dict.get with their own fallbacks.
SpectralConfig = TypedDict(
    "SpectralConfig",
    {
        "n_fft": int,
        "hop_length": int,
        "win_length": int,
        "n_mels": int,
        "f_min": float,
        "f_max": Optional[float],
        "power": float,
    },
    total=False,
)


# Raw-waveform ("temporal" branch) preprocessing settings; all keys optional.
TemporalConfig = TypedDict(
    "TemporalConfig",
    {"emphasis": bool, "highpass_cutoff": float},
    total=False,
)


# Constant-Q transform ("cepstral" branch) settings; all keys optional.
CepstralConfig = TypedDict(
    "CepstralConfig",
    {"hop_length": int, "n_bins": int, "bins_per_octave": int, "f_min": float},
    total=False,
)


@dataclass
class FeatureConfig:
    """Feature-extraction parameters for each of the three branches.

    Every branch section is a plain dict created fresh per instance by its
    ``default_factory``, so two configs never share mutable state.
    """

    # Sampling rate shared by all branches.
    sample_rate: int = 16000
    # Mel-spectrogram settings ("spectral" branch).
    spectral: SpectralConfig = field(
        default_factory=lambda: {
            "n_fft": 1024,
            "hop_length": 256,
            "win_length": 1024,
            "n_mels": 128,
            "f_min": 20.0,
            "f_max": None,
            "power": 2.0,
        }
    )
    # Waveform preprocessing ("temporal" branch).
    temporal: TemporalConfig = field(
        default_factory=lambda: {"emphasis": True, "highpass_cutoff": 30.0}
    )
    # Constant-Q transform settings ("cepstral" branch).
    cepstral: CepstralConfig = field(
        default_factory=lambda: {
            "hop_length": 256,
            "n_bins": 84,
            "bins_per_octave": 12,
            "f_min": 32.7,
        }
    )

In [None]:
def ensure_mono(waveform: Tensor) -> Tensor:
    """Downmix a channel-first waveform to a single channel.

    Args:
        waveform: Audio laid out as ``(channels, samples)``. A 1-D
            ``(samples,)`` tensor is treated as one channel and gets a
            leading channel axis.

    Returns:
        A single-channel tensor. Input that already has one channel is
        returned unchanged (same object). Multi-channel input is averaged
        over axis 0 with the axis kept.

    Raises:
        ValueError: if ``waveform`` is a 0-d (scalar) tensor.
    """
    if waveform.dim() == 0:
        raise ValueError("ensure_mono expects at least a 1-D waveform tensor.")
    if waveform.dim() == 1:
        # Averaging a bare (samples,) tensor over dim 0 would collapse the
        # whole signal to one value; add the channel axis instead.
        return waveform.unsqueeze(0)
    if waveform.size(0) == 1:
        return waveform
    return waveform.mean(dim=0, keepdim=True)


def load_audio(path: str, target_sample_rate: int, normalize: bool = True) -> Tuple[Tensor, int]:
    """Load an audio file as a mono tensor at ``target_sample_rate``.

    The steps run in this order:

    1. Read the file with ``torchaudio.load``.
    2. Resample every channel if the file's rate differs from the target.
    3. Downmix to mono with ``ensure_mono``.
    4. If ``normalize`` is set, divide by the absolute peak so samples lie
       in [-1, 1]. All-zero clips are left untouched.

    Returns:
        ``(waveform, sample_rate)``, where ``sample_rate`` equals
        ``target_sample_rate``.
    """
    audio, rate = torchaudio.load(path)
    if rate != target_sample_rate:
        audio = torchaudio.transforms.Resample(rate, target_sample_rate)(audio)
        rate = target_sample_rate

    audio = ensure_mono(audio)
    if not normalize:
        return audio, rate

    peak = audio.abs().max()
    # A silent clip has peak 0; skip the division so no NaNs are produced.
    return (audio / peak if peak > 0 else audio), rate


def pad_or_trim(waveform: Tensor, target_num_samples: int, mode: str = "repeat") -> Tensor:
    """Force the last (time) axis of ``waveform`` to ``target_num_samples``.

    Longer inputs are truncated from the end. Input that already has the
    target length is returned unchanged (same object). Shorter inputs are
    padded at the end according to ``mode``:

    * ``"zeros"``: append silence;
    * ``"reflect"``: mirror the signal (torch requires the pad length to be
      smaller than the input length);
    * ``"repeat"``: tile the signal along time until it is long enough.

    The leading axes and the tensor rank are preserved. A zero-length input
    has nothing to reflect or repeat, so it is always padded with zeros.

    Raises:
        ValueError: if padding is needed and ``mode`` is not recognised.
    """
    current = waveform.size(-1)
    if current == target_num_samples:
        return waveform
    if current > target_num_samples:
        return waveform[..., :target_num_samples]

    if mode not in ("zeros", "reflect", "repeat"):
        raise ValueError(f"pad_mode không được hỗ trợ: {mode}")

    if current == 0:
        # Repeating would divide by zero and reflect is undefined; emit silence.
        return waveform.new_zeros((*waveform.shape[:-1], target_num_samples))

    diff = target_num_samples - current
    if mode == "zeros":
        return torch.nn.functional.pad(waveform, (0, diff))
    if mode == "reflect":
        return torch.nn.functional.pad(waveform, (0, diff), mode="reflect")

    # "repeat": tile only the time axis, whatever the tensor rank is.
    repeats = math.ceil(target_num_samples / current)
    tile = [1] * (waveform.dim() - 1) + [repeats]
    return waveform.repeat(*tile)[..., :target_num_samples]

In [None]:
# Per-branch feature tensors keyed by branch name ("spectral", "temporal", "cepstral").
MultiBranchFeatures = Dict[str, Tensor]


class MultiBranchFeatureExtractor:
    """Turn one waveform into the inputs of the three model branches.

    * ``spectral``: ``log1p`` of a torchaudio Mel spectrogram;
    * ``temporal``: the waveform after optional pre-emphasis (coefficient
      0.97) and an optional high-pass biquad filter;
    * ``cepstral``: ``log1p(|CQT|**2)`` from librosa, with a leading channel axis.

    Keys missing from the per-branch config dicts fall back to the defaults
    written below.
    """

    def __init__(self, config: FeatureConfig) -> None:
        self.config = config

        spectral = config.spectral
        n_fft = spectral.get("n_fft", 1024)
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_fft=n_fft,
            hop_length=spectral.get("hop_length", 256),
            # The window defaults to the FFT size when not configured.
            win_length=spectral.get("win_length", n_fft),
            f_min=spectral.get("f_min", 20.0),
            f_max=spectral.get("f_max"),
            n_mels=spectral.get("n_mels", 128),
            power=spectral.get("power", 2.0),
            normalized=False,
        )

        temporal = config.temporal
        self.apply_pre_emphasis = temporal.get("emphasis", True)
        self.highpass_cutoff = temporal.get("highpass_cutoff", 30.0)
        self.cqt_cfg = config.cepstral

    def __call__(self, waveform: Tensor) -> MultiBranchFeatures:
        mono = ensure_mono(waveform)
        return {
            "spectral": self._compute_mel_spectrogram(mono),
            "temporal": self._prepare_temporal_branch(mono),
            "cepstral": self._compute_cqt(mono),
        }

    def _compute_mel_spectrogram(self, waveform: Tensor) -> Tensor:
        """Return ``log1p`` of the Mel spectrogram to compress its dynamic range."""
        return torch.log1p(self.mel_transform(waveform))

    def _prepare_temporal_branch(self, waveform: Tensor) -> Tensor:
        """Apply the configured pre-emphasis and high-pass filtering."""
        signal = (
            torchaudio.functional.preemphasis(waveform, 0.97)
            if self.apply_pre_emphasis
            else waveform
        )
        cutoff = self.highpass_cutoff
        if cutoff is not None and cutoff > 0:
            signal = torchaudio.functional.highpass_biquad(
                signal,
                sample_rate=self.config.sample_rate,
                cutoff_freq=cutoff,
            )
        return signal

    def _compute_cqt(self, waveform: Tensor) -> Tensor:
        """Return ``log1p(|CQT|**2)`` as a float32 tensor with a leading channel axis."""
        cfg = self.cqt_cfg
        samples = waveform.squeeze(0).cpu().numpy()
        spectrum = librosa.cqt(
            samples,
            sr=self.config.sample_rate,
            hop_length=cfg.get("hop_length", 256),
            n_bins=cfg.get("n_bins", 84),
            bins_per_octave=cfg.get("bins_per_octave", 12),
            fmin=cfg.get("f_min", 32.7),
        )
        power = (np.abs(spectrum) ** 2).astype("float32")
        return torch.log1p(torch.from_numpy(power)).unsqueeze(0)

## 4. Dataset và DataModule
Các cell tương ứng với `src/data/asvspoof_dataset.py` và `src/data/datamodule.py`.

In [None]:
# Protocol key token -> class index used as the training target.
LA_LABELS = dict(bonafide=0, spoof=1)


@dataclass(frozen=True)
class DatasetSpec:
    """Immutable description of how one ASVspoof release is laid out on disk.

    Templates are filled with ``str.format``. ``{partition}`` is the split
    name and ``{partition_dir}`` is the already-formatted partition
    directory. Tuples are searched in order, and the first match wins.
    """

    # Partition directory name under the data root.
    partition_dir_template: str
    # Candidate protocol file names.
    protocol_patterns: Tuple[str, ...]
    # Sub-directories of the partition directory searched for audio files.
    audio_subdirs: Tuple[str, ...] = field(default=("", "flac", "wav"))
    # Extensions tried for each utterance id.
    audio_extensions: Tuple[str, ...] = field(default=(".flac", ".wav"))
    # Directories (relative to the data root) searched for protocol files.
    protocol_dir_templates: Tuple[str, ...] = field(
        default=("{partition_dir}/protocol", "{partition_dir}")
    )


# Variant used when no ``dataset_variant`` is passed explicitly.
DEFAULT_DATASET_VARIANT = "ASVspoof2019_LA"


def _asvspoof2019_spec(
    track: str,
    audio_subdirs: Tuple[str, ...],
    audio_extensions: Tuple[str, ...],
) -> DatasetSpec:
    """Build the directory/protocol layout shared by the 2019 ``LA`` and ``PA`` tracks."""
    return DatasetSpec(
        partition_dir_template=f"ASVspoof2019_{track}_{{partition}}",
        protocol_patterns=tuple(
            f"ASVspoof2019.{track}.cm.{{partition}}{suffix}"
            for suffix in (".trn.txt", ".trl.txt", ".txt")
        ),
        audio_subdirs=audio_subdirs,
        audio_extensions=audio_extensions,
        protocol_dir_templates=(
            f"ASVspoof2019_{track}_cm_protocols",
            "{partition_dir}/protocol",
            "{partition_dir}",
        ),
    )


# Known on-disk layouts, keyed by ``dataset_variant``.
DATASET_SPECS: Dict[str, DatasetSpec] = {
    "ASVspoof2019_LA": _asvspoof2019_spec(
        "LA",
        audio_subdirs=("", "flac", "wav"),
        audio_extensions=(".flac", ".wav"),
    ),
    "ASVspoof2019_PA": _asvspoof2019_spec(
        "PA",
        audio_subdirs=("", "wav", "flac"),
        audio_extensions=(".wav", ".flac"),
    ),
    "ASVspoof5": DatasetSpec(
        partition_dir_template="ASVspoof5_{partition}",
        protocol_patterns=(
            "ASVspoof5.cm.{partition}.txt",
            "ASVspoof5.{partition}.cm.txt",
            "ASVspoof5.{partition}.txt",
        ),
        audio_subdirs=("", "wav", "flac"),
        audio_extensions=(".wav", ".flac"),
    ),
}


@dataclass
class ASVExample:
    """One protocol row resolved to an audio file on disk."""

    utt_id: str  # utterance id; also the audio file stem
    speaker_id: str
    path: str  # resolved path of the audio file
    label: int  # class index (see LA_LABELS)
    system_id: Optional[str] = field(default=None)  # None when the protocol has "-"
    attack_type: Optional[str] = field(default=None)  # None when the protocol has "-"


class ASVspoofLADataset(Dataset):
    """PyTorch dataset over one partition of an ASVspoof release.

    The on-disk layout (directory names, protocol file names, audio
    sub-directories and extensions) comes from ``DATASET_SPECS[dataset_variant]``.

    Each item is a dict with these keys:

    * ``utt_id`` and ``speaker_id``;
    * ``label``: a 0-d ``torch.long`` tensor with values from ``LA_LABELS``;
    * ``features``: the feature extractor's output;
    * ``meta``: present only when the protocol row provides a system id or
      attack type, and holds ``system_id`` / ``attack_type``.

    Waveforms are loaded with ``load_audio(..., normalize=True)`` and
    brought to ``int(sample_rate * max_duration)`` samples by
    ``pad_or_trim``. With ``preload_waveforms=True`` this is done once for
    every example up front.
    """

    def __init__(
        self,
        data_root: str,
        partition: str,
        feature_extractor: MultiBranchFeatureExtractor,
        protocol_file: Optional[str] = None,
        sample_rate: int = 16000,
        max_duration: float = 6.0,
        pad_mode: str = "repeat",
        preload_waveforms: bool = False,
        dataset_variant: str = DEFAULT_DATASET_VARIANT,
    ) -> None:
        """Resolve the protocol, index all examples and optionally preload audio.

        Raises:
            ValueError: unknown ``dataset_variant`` or a malformed protocol row.
            FileNotFoundError: no protocol file or audio file could be found.
        """
        super().__init__()
        if dataset_variant not in DATASET_SPECS:
            raise ValueError(
                f"Unsupported dataset_variant {dataset_variant!r}. Available: {list(DATASET_SPECS.keys())}"
            )
        self.dataset_variant = dataset_variant
        self.dataset_spec = DATASET_SPECS[dataset_variant]

        self.data_root = data_root
        self.partition = partition
        self.sample_rate = sample_rate
        self.max_num_samples = int(sample_rate * max_duration)
        self.pad_mode = pad_mode
        self.feature_extractor = feature_extractor
        self.preload_waveforms = preload_waveforms

        if protocol_file is None:
            protocol_file = self._infer_protocol_path()
        self.protocol_file = protocol_file
        self.examples = self._load_metadata()

        # utt_id -> fixed-length waveform; stays empty unless preloading.
        self._waveform_cache: Dict[str, Tensor] = {}
        if self.preload_waveforms:
            for example in self.examples:
                self._waveform_cache[example.utt_id] = self._load_waveform(example.path)

    def _partition_dir(self) -> str:
        """Return the partition's directory name relative to ``data_root``."""
        return self.dataset_spec.partition_dir_template.format(partition=self.partition)

    def _infer_protocol_path(self) -> str:
        """Return the first existing protocol file among the spec's candidates.

        The search order is: every protocol directory in spec order, then
        ``<partition_dir>/protocol``, then ``<partition_dir>``. Inside each
        directory, each pattern is tried, followed by its variant without
        ``.trn``/``.trl``.

        Raises:
            FileNotFoundError: if none of the candidates exists.
        """
        partition_dir = self._partition_dir()
        protocol_dirs = [
            os.path.join(
                self.data_root,
                template.format(partition=self.partition, partition_dir=partition_dir),
            )
            for template in self.dataset_spec.protocol_dir_templates
        ]
        protocol_dirs += [
            os.path.join(self.data_root, partition_dir, "protocol"),
            os.path.join(self.data_root, partition_dir),
        ]
        # Drop duplicates while keeping the search order.
        protocol_dirs = list(dict.fromkeys(protocol_dirs))

        for proto_dir in protocol_dirs:
            for pattern in self.dataset_spec.protocol_patterns:
                name = pattern.format(partition=self.partition)
                names = [name]
                if name.endswith(".trn.txt"):
                    names.append(name.replace(".trn", ""))
                if name.endswith(".trl.txt"):
                    names.append(name.replace(".trl", ""))
                for candidate_name in names:
                    candidate = os.path.join(proto_dir, candidate_name)
                    if os.path.exists(candidate):
                        return candidate

        raise FileNotFoundError(
            f"Không tìm thấy protocol cho partition={self.partition} trong {', '.join(protocol_dirs)}."
        )

    @staticmethod
    def _parse_protocol_row(
        tokens: List[str],
    ) -> Tuple[str, str, str, Optional[str], Optional[str]]:
        """Split one whitespace-tokenised protocol row.

        Rows look like ``SPEAKER UTT [extra ...] KEY [trailing ...]``. The
        key is the right-most token, from the third column on, that appears
        in ``LA_LABELS``. Any columns after it are ignored, so protocols with
        trailing fields after the key still parse. For rows whose last
        token is the key, the result is the same as taking the last token.

        Of the columns between the utterance id and the key, the last one is
        ``system_id``. When there are at least two, the first one is
        ``attack_type``. A ``"-"`` value means "not given" and maps to None.

        Returns:
            ``(speaker_id, utt_id, label_token, system_id, attack_type)``.

        Raises:
            ValueError: rows with fewer than three tokens or without a known label.
        """
        if len(tokens) < 3:
            raise ValueError(f"Không thể parse dòng protocol: {tokens}")
        speaker_id, utt_id = tokens[:2]

        label_idx = next(
            (i for i in range(len(tokens) - 1, 1, -1) if tokens[i].lower() in LA_LABELS),
            None,
        )
        if label_idx is None:
            raise ValueError(f"Nhãn không hợp lệ: {tokens[-1].lower()}")
        label_token = tokens[label_idx].lower()

        middle_tokens = tokens[2:label_idx]
        system_id: Optional[str] = None
        attack_type: Optional[str] = None
        if middle_tokens:
            if middle_tokens[-1] != "-":
                system_id = middle_tokens[-1]
            if len(middle_tokens) >= 2 and middle_tokens[0] != "-":
                attack_type = middle_tokens[0]
        return speaker_id, utt_id, label_token, system_id, attack_type

    def _find_audio(self, partition_dir: str, utt_id: str) -> Optional[str]:
        """Return the first existing ``<subdir>/<utt_id><ext>`` in spec order, or None."""
        for subdir in self.dataset_spec.audio_subdirs:
            audio_dir = os.path.join(self.data_root, partition_dir, subdir)
            for ext in self.dataset_spec.audio_extensions:
                candidate = os.path.join(audio_dir, f"{utt_id}{ext}")
                if os.path.exists(candidate):
                    return candidate
        return None

    def _load_metadata(self) -> List[ASVExample]:
        """Parse the protocol file into examples with resolved audio paths.

        Raises:
            ValueError: on a malformed row (see ``_parse_protocol_row``).
            FileNotFoundError: when an utterance has no matching audio file.
        """
        # Computed once rather than for every protocol row.
        partition_dir = self._partition_dir()
        examples: List[ASVExample] = []
        with open(self.protocol_file, "r", encoding="utf-8") as handle:
            for line in handle:
                # str.split() handles runs of spaces as well as tab-separated files.
                tokens = line.split()
                if not tokens:
                    continue
                speaker_id, utt_id, label_token, system_id, attack_type = (
                    self._parse_protocol_row(tokens)
                )

                audio_path = self._find_audio(partition_dir, utt_id)
                if audio_path is None:
                    raise FileNotFoundError(
                        f"Không tìm thấy file audio cho {utt_id} trong {partition_dir}"
                    )

                examples.append(
                    ASVExample(
                        utt_id=utt_id,
                        speaker_id=speaker_id,
                        path=audio_path,
                        label=LA_LABELS[label_token],
                        system_id=system_id,
                        attack_type=attack_type,
                    )
                )
        return examples

    def _load_waveform(self, path: str) -> Tensor:
        """Load ``path`` peak-normalised and padded/trimmed to ``max_num_samples``."""
        waveform, _ = load_audio(path, self.sample_rate, normalize=True)
        return pad_or_trim(waveform, self.max_num_samples, mode=self.pad_mode)

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        example = self.examples[index]
        waveform = self._waveform_cache.get(example.utt_id)
        if waveform is None:
            waveform = self._load_waveform(example.path)

        features = self.feature_extractor(waveform)
        sample: Dict[str, Any] = {
            "utt_id": example.utt_id,
            "speaker_id": example.speaker_id,
            "label": torch.tensor(example.label, dtype=torch.long),
            "features": features,
        }

        metadata = {
            key: value
            for key, value in (
                ("system_id", example.system_id),
                ("attack_type", example.attack_type),
            )
            if value is not None
        }
        if metadata:
            sample["meta"] = metadata
        return sample


def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge dataset samples into one batch dict.

    Each feature branch is stacked separately along a new leading batch
    axis, so every sample must provide same-shaped tensors for a branch.
    Labels are stacked with ``torch.stack`` and ids are kept as lists. The
    per-sample ``meta`` dicts are forwarded only if at least one sample has
    one; samples without one contribute ``None``.
    """
    labels = torch.stack([sample["label"] for sample in batch], dim=0)

    per_branch: Dict[str, List[Tensor]] = {}
    for sample in batch:
        for name, value in sample["features"].items():
            if name not in per_branch:
                per_branch[name] = []
            per_branch[name].append(value)

    merged: Dict[str, Any] = {
        "features": {name: torch.stack(values, dim=0) for name, values in per_branch.items()},
        "labels": labels,
        "utt_ids": [sample["utt_id"] for sample in batch],
        "speaker_ids": [sample["speaker_id"] for sample in batch],
    }

    metas = [sample.get("meta") for sample in batch]
    if not all(meta is None for meta in metas):
        merged["meta"] = metas
    return merged


In [None]:
@dataclass
class PartitionConfig:
    """Dataset and DataLoader settings for one split (train / valid / test)."""

    partition: str  # split name substituted into the dataset spec templates
    protocol_file: Optional[str] = field(default=None)  # None -> inferred from the spec
    batch_size: int = field(default=32)
    shuffle: bool = field(default=True)
    drop_last: bool = field(default=False)


@dataclass
class DataModuleConfig:
    """Everything ASVspoofDataModule needs to build datasets and loaders."""

    data_root: str  # root directory containing the ASVspoof partitions
    # Audio preprocessing shared by every split.
    sample_rate: int = field(default=16000)
    max_duration: float = field(default=6.0)  # seconds; clips are padded/trimmed to this
    pad_mode: str = field(default="repeat")  # "zeros", "reflect" or "repeat"
    # DataLoader worker settings.
    num_workers: int = field(default=4)
    pin_memory: bool = field(default=True)
    prefetch_factor: int = field(default=2)  # only used when num_workers > 0
    feature: FeatureConfig = field(default_factory=FeatureConfig)
    # Per-split settings; ASVspoofDataModule requires train and valid.
    train: Optional[PartitionConfig] = field(default=None)
    valid: Optional[PartitionConfig] = field(default=None)
    test: Optional[PartitionConfig] = field(default=None)
    preload_waveforms: bool = field(default=False)  # cache all fixed-length waveforms in memory
    dataset_variant: str = field(default=DEFAULT_DATASET_VARIANT)


class ASVspoofDataModule:
    """Build ASVspoof datasets and their DataLoaders for train / valid / test.

    Call ``setup("fit")`` (or ``setup()``) before ``train_dataloader`` /
    ``val_dataloader``, and ``setup("test")`` before ``test_dataloader``.
    """

    def __init__(self, config: DataModuleConfig) -> None:
        """Validate ``config`` and create the shared feature extractor.

        Raises:
            ValueError: if the train or valid partition is not configured.
        """
        if config.train is None or config.valid is None:
            raise ValueError("Cần cấu hình train và valid partitions.")
        self.config = config
        # A single extractor instance is shared by every split's dataset.
        self.feature_extractor = MultiBranchFeatureExtractor(config.feature)
        self._datasets: Dict[str, ASVspoofLADataset] = {}
        self._rng = self._build_generator()

    def setup(self, stage: Optional[str] = None) -> None:
        """Create datasets: ``"fit"`` builds train + valid, ``"test"`` builds test, ``None`` builds all."""
        if stage in (None, "fit"):
            self._datasets["train"] = self._build_dataset(self.config.train)
            self._datasets["valid"] = self._build_dataset(self.config.valid)
        if stage in (None, "test") and self.config.test is not None:
            self._datasets["test"] = self._build_dataset(self.config.test)

    def _build_dataset(self, part_cfg: PartitionConfig) -> ASVspoofLADataset:
        return ASVspoofLADataset(
            data_root=self.config.data_root,
            partition=part_cfg.partition,
            protocol_file=part_cfg.protocol_file,
            feature_extractor=self.feature_extractor,
            sample_rate=self.config.sample_rate,
            max_duration=self.config.max_duration,
            pad_mode=self.config.pad_mode,
            preload_waveforms=self.config.preload_waveforms,
            dataset_variant=self.config.dataset_variant,
        )

    def _build_loader(self, dataset: ASVspoofLADataset, part_cfg: PartitionConfig) -> DataLoader:
        """Create a DataLoader for ``dataset`` with the split's batch settings."""
        loader_kwargs: Dict[str, Any] = dict(
            dataset=dataset,
            batch_size=part_cfg.batch_size,
            shuffle=part_cfg.shuffle,
            drop_last=part_cfg.drop_last,
            num_workers=self.config.num_workers,
            # Pinned memory only speeds up host-to-GPU copies; requesting it
            # without CUDA just makes PyTorch emit a warning.
            pin_memory=self.config.pin_memory and torch.cuda.is_available(),
            collate_fn=collate_fn,
            generator=self._rng,
        )
        # DataLoader rejects prefetch_factor when no worker processes are used.
        if self.config.num_workers > 0:
            loader_kwargs["prefetch_factor"] = self.config.prefetch_factor
        return DataLoader(**loader_kwargs)

    def train_dataloader(self) -> DataLoader:
        return self._build_loader(self._datasets["train"], self.config.train)

    def val_dataloader(self) -> DataLoader:
        return self._build_loader(self._datasets["valid"], self.config.valid)

    def test_dataloader(self) -> DataLoader:
        if "test" not in self._datasets:
            raise RuntimeError("Chưa cấu hình test dataset.")
        assert self.config.test is not None
        return self._build_loader(self._datasets["test"], self.config.test)

    def _build_generator(self) -> torch.Generator:
        """Return a CPU generator seeded from ``torch.initial_seed()``.

        DataLoader draws its base seed and ``RandomSampler``'s permutation
        (``torch.randperm``) on the CPU. A CUDA generator makes those calls
        raise a RuntimeError, so the generator stays on the CPU even when
        training runs on a GPU.
        """
        generator = torch.Generator()
        generator.manual_seed(torch.initial_seed())
        return generator


## 5. Kiến trúc mô hình đa nhánh
Các cell này tương ứng với `src/models/multi_branch_model.py`.

In [None]:
def conv2d_block(
    in_channels: int,
    out_channels: int,
    kernel_size: Tuple[int, int] = (3, 3),
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (1, 1),
    dropout: float = 0.0,
) -> nn.Sequential:
    """Return Conv2d -> BatchNorm2d -> ReLU, plus channel-wise Dropout2d when ``dropout > 0``."""
    modules: List[nn.Module] = [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    if dropout > 0:
        return nn.Sequential(*modules, nn.Dropout2d(dropout))
    return nn.Sequential(*modules)


def conv1d_block(
    in_channels: int,
    out_channels: int,
    kernel_size: int = 3,
    stride: int = 1,
    padding: int = 1,
    dropout: float = 0.0,
) -> nn.Sequential:
    """Return Conv1d -> BatchNorm1d -> ReLU, plus element-wise Dropout when ``dropout > 0``."""
    modules: List[nn.Module] = [
        nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
        nn.BatchNorm1d(out_channels),
        nn.ReLU(inplace=True),
    ]
    if dropout > 0:
        return nn.Sequential(*modules, nn.Dropout(dropout))
    return nn.Sequential(*modules)

In [None]:
class SpectralBranch(nn.Module):
    """Convolutional encoder for the log-Mel ("spectral") input.

    Four conv blocks widen the channels in_channels -> 32 -> 64 -> 128 -> 128.
    A 2x2 max-pool follows each of the first two blocks. Global average
    pooling and a linear projection then give one ``hidden_dim`` embedding
    per example.

    NOTE(review): the input is presumably ``(batch, in_channels, n_mels,
    frames)`` — confirm against the feature extractor.
    """

    def __init__(self, in_channels: int = 1, hidden_dim: int = 256) -> None:
        super().__init__()
        self.features = nn.Sequential(
            conv2d_block(in_channels, 32, dropout=0.1),
            nn.MaxPool2d((2, 2)),
            conv2d_block(32, 64, dropout=0.15),
            nn.MaxPool2d((2, 2)),
            conv2d_block(64, 128, dropout=0.2),
            conv2d_block(128, 128, dropout=0.2),
        )
        # Collapse both spatial axes to a single value per channel.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.proj = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return one ``hidden_dim`` embedding per example."""
        return self.proj(self.pool(self.features(x)))


class TemporalBranch(nn.Module):
    """1-D convolutional encoder for the raw ("temporal") waveform.

    Three stride-2 conv blocks shorten the time axis roughly 8x, and a
    fourth conv block keeps the length. Global average pooling over time
    and a linear projection then give one ``hidden_dim`` embedding per
    example.
    """

    def __init__(self, in_channels: int = 1, hidden_dim: int = 256) -> None:
        super().__init__()
        # (in, out, kernel, stride, padding, dropout) for each conv block.
        layer_specs = [
            (in_channels, 32, 11, 2, 5, 0.1),
            (32, 64, 9, 2, 4, 0.1),
            (64, 128, 7, 2, 3, 0.15),
            (128, 128, 5, 1, 2, 0.15),
        ]
        self.conv_stack = nn.Sequential(
            *(
                conv1d_block(c_in, c_out, kernel_size=k, stride=s, padding=p, dropout=d)
                for c_in, c_out, k, s, p, d in layer_specs
            )
        )
        self.temporal_pool = nn.AdaptiveAvgPool1d(1)
        self.proj = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return one ``hidden_dim`` embedding per example."""
        return self.proj(self.temporal_pool(self.conv_stack(x)))


class CepstralBranch(nn.Module):
    """Convolutional encoder for the log-CQT ("cepstral") input.

    The first two conv blocks use wide 3x5 kernels, each followed by a 2x2
    max-pool; a 3x3 block then lifts the channels to 128. Global average
    pooling and a linear projection give one ``hidden_dim`` embedding per
    example.
    """

    def __init__(self, in_channels: int = 1, hidden_dim: int = 256) -> None:
        super().__init__()
        self.features = nn.Sequential(
            conv2d_block(in_channels, 32, kernel_size=(3, 5), padding=(1, 2), dropout=0.1),
            nn.MaxPool2d((2, 2)),
            conv2d_block(32, 64, kernel_size=(3, 5), padding=(1, 2), dropout=0.15),
            nn.MaxPool2d((2, 2)),
            conv2d_block(64, 128, kernel_size=(3, 3), padding=(1, 1), dropout=0.2),
        )
        # Collapse both spatial axes to a single value per channel.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.proj = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return one ``hidden_dim`` embedding per example."""
        pooled = self.pool(self.features(x))
        return self.proj(pooled)

In [None]:
class AttentionFusion(nn.Module):
    """Additive attention pooling over the stacked branch embeddings.

    Each branch embedding ``e`` gets the score ``v . tanh(W e)``. The
    softmax of the scores over the branch axis then weights a sum of the
    dropout-regularised embeddings.

    Input: ``(batch, num_branches, embed_dim)``, as built by the multi-branch
    model. Returns the fused ``(batch, embed_dim)`` tensor and the
    ``(batch, num_branches)`` attention weights.
    """

    def __init__(self, embed_dim: int, attn_dim: int = 128, dropout: float = 0.1) -> None:
        super().__init__()
        self.proj = nn.Linear(embed_dim, attn_dim)
        self.score = nn.Linear(attn_dim, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, branch_embeddings: Tensor) -> Tuple[Tensor, Tensor]:
        scores = self.score(torch.tanh(self.proj(branch_embeddings))).squeeze(-1)
        weights = scores.softmax(dim=-1)
        # Dropout is applied to the embeddings only after scoring, so the
        # attention weights themselves are not perturbed by it.
        dropped = self.dropout(branch_embeddings)
        fused = (dropped * weights.unsqueeze(-1)).sum(dim=1)
        return fused, weights


@dataclass
class MultiBranchModelConfig:
    """Size and regularisation settings for MultiBranchAttentionModel."""

    embed_dim: int = field(default=256)  # per-branch embedding size (also the fused size)
    attn_dim: int = field(default=128)  # hidden width of the attention scorer
    num_classes: int = field(default=2)
    classifier_hidden: int = field(default=128)  # hidden width of the classifier MLP
    dropout: float = field(default=0.3)  # used by the fusion layer and the classifier


class MultiBranchAttentionModel(nn.Module):
    """Three-branch spoofing classifier with attention-based fusion.

    The spectral, temporal and cepstral encoders each produce an
    ``embed_dim`` vector. ``AttentionFusion`` weights and sums them, and an
    MLP head maps the fused vector to ``num_classes`` logits.
    """

    def __init__(self, config: MultiBranchModelConfig) -> None:
        super().__init__()
        self.config = config
        dim = config.embed_dim
        self.branches = nn.ModuleDict(
            {
                "spectral": SpectralBranch(in_channels=1, hidden_dim=dim),
                "temporal": TemporalBranch(in_channels=1, hidden_dim=dim),
                "cepstral": CepstralBranch(in_channels=1, hidden_dim=dim),
            }
        )
        self.fusion = AttentionFusion(dim, config.attn_dim, dropout=config.dropout)
        self.classifier = nn.Sequential(
            nn.Linear(dim, config.classifier_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(config.dropout),
            nn.Linear(config.classifier_hidden, config.num_classes),
        )

    def forward(self, features: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run every branch, fuse the embeddings and classify.

        Args:
            features: one input tensor per branch name ("spectral",
                "temporal", "cepstral").

        Returns:
            A dict with these entries:

            * ``logits``, ``fused`` and ``attention_weights``;
            * ``branch_embeddings``: the branch outputs stacked on dim 1;
            * ``branch_order``: the branch names in stacking order. Note that
              this entry is a list of strings, not a Tensor.

        Raises:
            KeyError: if a branch input is missing. Branches listed before
                the missing one have already been run at that point.
        """
        embeddings: List[Tensor] = []
        names: List[str] = []
        for name, branch in self.branches.items():
            if name not in features:
                raise KeyError(f"Thiếu nhánh {name} trong input features.")
            embeddings.append(branch(features[name]).unsqueeze(1))
            names.append(name)

        stacked = torch.cat(embeddings, dim=1)
        fused, weights = self.fusion(stacked)
        return {
            "logits": self.classifier(fused),
            "fused": fused,
            "attention_weights": weights,
            "branch_embeddings": stacked,
            "branch_order": names,
        }

## 6. Hàm đánh giá và metric
Tương ứng với `src/utils/metrics.py`.

In [None]:
# Default priors (P_*), costs (C_*) and ASV error rates used by compute_tdcf.
# Callers may override any subset through its ``params`` argument.
# NOTE(review): confirm these priors against the ASVspoof 2019 evaluation
# plan before comparing results with published numbers.
DEFAULT_TDCF_PARAMS = dict(
    P_tar=0.9802,
    P_non=0.0091,
    P_spoof=0.0107,
    C_miss_asv=1.0,
    C_fa_asv=10.0,
    C_miss_cm=1.0,
    C_fa_cm=10.0,
    P_miss_asv=0.05,
    P_fa_asv=0.01,
    P_fa_asv_spoof=0.30,
)


def compute_accuracy(logits: Tensor, labels: Tensor) -> float:
    """Return the fraction of rows whose arg-max class equals the label.

    Args:
        logits: ``(batch, num_classes)`` class scores.
        labels: ``(batch,)`` integer class indices.

    Returns:
        Accuracy in [0, 1], or 0.0 for an empty batch instead of raising
        ``ZeroDivisionError``.
    """
    total = labels.numel()
    if total == 0:
        return 0.0
    preds = torch.argmax(logits, dim=1)
    return (preds == labels).sum().item() / total


def compute_eer(scores: Tensor, labels: Tensor) -> float:
    """Compute the Equal Error Rate for scores where higher means label 1.

    The decision threshold is swept over every distinct score, starting
    from the "accept nothing" point. At each operating point two rates are
    compared:

    * FAR: the fraction of label-0 items scored at or above the threshold;
    * FRR: the fraction of label-1 items scored below it.

    The EER is the mean of FAR and FRR at the first point where they are
    closest.

    Tied scores form a single cut point. Items with equal scores cannot be
    separated by any threshold, so an item-by-item sweep would make the
    result depend on the arbitrary order of ties. Such ties are common with
    saturated softmax outputs.

    Returns:
        The EER in [0, 1], or 0.0 when either class is absent.
    """
    labels_np = labels.detach().cpu().numpy().astype(np.int32)
    scores_np = scores.detach().cpu().numpy()

    is_pos = labels_np == 1
    positives = int(is_pos.sum())
    negatives = labels_np.size - positives
    if positives == 0 or negatives == 0:
        return 0.0

    # Descending scores; a stable sort keeps the ordering deterministic.
    order = np.argsort(-scores_np, kind="stable")
    sorted_scores = scores_np[order]
    sorted_pos = is_pos[order]

    # Items accepted (scored >= threshold) when the cut sits after each position.
    accepted_pos = np.cumsum(sorted_pos)
    accepted_neg = np.cumsum(~sorted_pos)

    # Only the last position of each run of equal scores is a valid cut point.
    group_end = np.append(sorted_scores[1:] != sorted_scores[:-1], True)

    far = np.concatenate(([0.0], accepted_neg[group_end] / negatives))
    frr = np.concatenate(([1.0], (positives - accepted_pos[group_end]) / positives))

    best = int(np.argmin(np.abs(far - frr)))
    return float((far[best] + frr[best]) / 2.0)


def compute_tdcf(scores: Tensor, labels: Tensor, params: Optional[Dict[str, float]] = None) -> float:
    """Minimum normalised tandem detection cost over all CM thresholds.

    ``scores`` are countermeasure scores where higher means "spoof" (label
    1); a trial is rejected as spoof when its score is ``>= tau``. The
    thresholds ``tau`` are every observed score plus ``-inf``/``+inf``. For
    each one the cost is evaluated and the minimum is returned:

        (C_miss_asv*P_tar*P_miss_asv + C_fa_asv*P_non*P_fa_asv
         + C_miss_cm*P_tar*(1 - P_miss_asv)*P_miss_cm(tau)
         + C_fa_cm*P_spoof*(1 - P_fa_asv_spoof)*P_fa_cm(tau))
        / min(C_miss_asv*P_tar, C_fa_asv*P_non)

    A non-positive normaliser is replaced by 1.0.

    Args:
        scores: per-trial CM scores.
        labels: 0 = bona fide, 1 = spoof.
        params: overrides for ``DEFAULT_TDCF_PARAMS``; missing keys keep
            their default values.

    Returns:
        The minimum cost, or 0.0 if either class is absent.

    NOTE(review): this is a simplified cost, not necessarily the exact
    ASVspoof 2019 t-DCF formulation; confirm before comparing with
    published numbers.
    """
    labels_np = labels.detach().cpu().numpy().astype(np.int32)
    scores_np = scores.detach().cpu().numpy()

    bona_scores = scores_np[labels_np == 0]
    spoof_scores = scores_np[labels_np == 1]
    if bona_scores.size == 0 or spoof_scores.size == 0:
        return 0.0

    params = {**DEFAULT_TDCF_PARAMS, **(params or {})}

    c_miss_asv = params["C_miss_asv"]
    c_fa_asv = params["C_fa_asv"]
    c_miss_cm = params["C_miss_cm"]
    c_fa_cm = params["C_fa_cm"]
    p_tar = params["P_tar"]
    p_non = params["P_non"]
    p_spoof = params["P_spoof"]
    p_miss_asv = params["P_miss_asv"]
    p_fa_asv = params["P_fa_asv"]
    p_fa_asv_spoof = params["P_fa_asv_spoof"]

    thresholds = np.concatenate(([-np.inf], np.sort(scores_np), [np.inf]))
    c_default = min(c_miss_asv * p_tar, c_fa_asv * p_non)
    if c_default <= 0:
        c_default = 1.0

    asv_term = c_miss_asv * p_tar * p_miss_asv + c_fa_asv * p_non * p_fa_asv

    # Evaluate every threshold at once in O(n log n) rather than rescanning
    # both score arrays per threshold (O(n^2)). searchsorted(side="left") on
    # a sorted array counts the values strictly below each tau, so:
    #   P_miss_cm(tau) = fraction of bona fide scores >= tau
    #   P_fa_cm(tau)   = fraction of spoof scores     <  tau
    bona_below = np.searchsorted(np.sort(bona_scores), thresholds, side="left")
    spoof_below = np.searchsorted(np.sort(spoof_scores), thresholds, side="left")
    p_miss_cm = (bona_scores.size - bona_below) / bona_scores.size
    p_fa_cm = spoof_below / spoof_scores.size

    cm_term = (
        c_miss_cm * p_tar * (1.0 - p_miss_asv) * p_miss_cm
        + c_fa_cm * p_spoof * (1.0 - p_fa_asv_spoof) * p_fa_cm
    )
    tdcf_values = (asv_term + cm_term) / c_default
    return float(np.min(tdcf_values))


def aggregate_metrics(logits: Tensor, labels: Tensor) -> Dict[str, float]:
    """Compute accuracy, EER and minimum t-DCF from classifier logits.

    The detection score is the class-1 (spoof) score. For two-class logits
    it is the log-odds ``logit_1 - logit_0``. This ranks trials exactly
    like ``softmax(logits)[:, 1]`` (the sigmoid is monotonic), but unlike
    the probability it does not saturate at 1.0 for confident predictions.
    Saturation would create artificial ties in the threshold sweeps. EER
    and t-DCF depend only on the ranking, so untied results are unchanged.

    Half-precision logits (e.g. produced under autocast) are promoted to
    float32 first; bfloat16 tensors cannot be converted to NumPy at all.

    Returns:
        ``{"accuracy": ..., "eer": ..., "t_dcf": ...}``.
    """
    # float16/bfloat16 -> float32; float32/float64 keep their dtype.
    work = logits.to(torch.promote_types(logits.dtype, torch.float32))
    if work.dim() == 2 and work.size(-1) == 2:
        spoof_scores = work[:, 1] - work[:, 0]
    else:
        spoof_scores = torch.softmax(work, dim=-1)[:, 1]
    return {
        "accuracy": compute_accuracy(logits, labels),
        "eer": compute_eer(spoof_scores, labels),
        "t_dcf": compute_tdcf(spoof_scores, labels),
    }

## 7. Vòng lặp huấn luyện
Các cell dưới đây tương ứng với `src/training/engine.py`.

In [None]:
@dataclass
class OptimizerConfig:
    """Hyper-parameters passed to ``torch.optim.AdamW`` by the trainer."""

    lr: float = field(default=1e-4)
    weight_decay: float = field(default=1e-5)
    betas: tuple = field(default=(0.9, 0.98))  # AdamW (beta1, beta2)
    eps: float = field(default=1e-8)


@dataclass
class SchedulerConfig:
    """Learning-rate schedule settings.

    NOTE(review): the code that consumes these settings is not visible
    here. Presumably ``use_cosine`` selects cosine annealing down to
    ``min_lr``, and ``t_max=None`` lets the period be derived elsewhere;
    confirm against the trainer's scheduler builder.
    """

    use_cosine: bool = field(default=True)
    min_lr: float = field(default=1e-6)
    t_max: Optional[int] = field(default=None)


@dataclass
class TrainingConfig:
    """Settings for the training loop."""

    epochs: int = field(default=50)
    device: Optional[str] = field(default=None)  # None -> "cuda" if available, else "cpu"
    log_interval: int = field(default=20)
    grad_clip: float = field(default=5.0)
    mixed_precision: bool = field(default=True)  # gradient scaling only takes effect on CUDA
    checkpoint_dir: str = field(default="checkpoints")  # created at the start of fit()
    best_metric: str = field(default="eer")
    patience: int = field(default=10)
    resume_from: Optional[str] = field(default=None)
    save_every: int = field(default=0)
    # Each instance gets its own fresh list.
    history: List[Dict[str, float]] = field(default_factory=list)
    evaluate_on_test: bool = field(default=False)
    model_output_path: Optional[str] = field(default=None)

In [None]:
class Trainer:
    """Training / evaluation loop for the multi-branch anti-spoofing model.

    Responsibilities:
      * place the model on ``train_config.device`` (CUDA when it is ``None`` and
        a GPU is available, otherwise CPU);
      * minimise cross-entropy on ``outputs["logits"]`` with AdamW, optional
        cosine LR annealing, gradient clipping and CUDA mixed precision;
      * track the best validation value of ``train_config.best_metric``,
        checkpoint it and stop early after ``patience`` non-improving epochs;
      * optionally resume from ``train_config.resume_from``.
    """

    # Validation metrics for which a smaller value is better. ``t_dcf`` is a
    # detection cost, so it must be minimised just like ``loss`` and ``eer``.
    _LOWER_IS_BETTER = frozenset({"loss", "eer", "t_dcf"})

    def __init__(
        self,
        model: nn.Module,
        train_config: TrainingConfig,
        optim_config: OptimizerConfig,
        scheduler_config: SchedulerConfig,
    ) -> None:
        self.model = model
        self.train_config = train_config
        self.optim_config = optim_config
        self.scheduler_config = scheduler_config

        device_str = (
            train_config.device
            if train_config.device is not None
            else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.device = torch.device(device_str)
        self.model.to(self.device)

        self.criterion = nn.CrossEntropyLoss()
        # Loss scaling (and autocast in ``_run_epoch``) is only active on CUDA.
        self.scaler = torch.cuda.amp.GradScaler(
            enabled=(self.device.type == "cuda" and train_config.mixed_precision)
        )
        self.best_metric_value: Optional[float] = None
        self.best_epoch: Optional[int] = None

    def fit(self, datamodule: ASVspoofDataModule) -> Dict[str, List[Dict[str, float]]]:
        """Train for up to ``train_config.epochs`` epochs, validating after each.

        Args:
            datamodule: object providing ``setup(stage="fit")``,
                ``train_dataloader()`` and ``val_dataloader()``.

        Returns:
            ``{"train": [...], "valid": [...]}`` with one metrics dict per epoch
            run by this call (epochs skipped by resuming are not included).

        Raises:
            KeyError: if ``train_config.best_metric`` is missing from the
                validation metrics.
            FileNotFoundError: if ``train_config.resume_from`` does not exist.
        """
        datamodule.setup(stage="fit")
        train_loader = datamodule.train_dataloader()
        valid_loader = datamodule.val_dataloader()

        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.optim_config.lr,
            weight_decay=self.optim_config.weight_decay,
            betas=self.optim_config.betas,
            eps=self.optim_config.eps,
        )
        scheduler = self._build_scheduler(optimizer)
        os.makedirs(self.train_config.checkpoint_dir, exist_ok=True)

        start_epoch = 1
        if self.train_config.resume_from:
            start_epoch = self._resume_training_state(
                self.train_config.resume_from, optimizer, scheduler
            )

        history: Dict[str, List[Dict[str, float]]] = {"train": [], "valid": []}
        patience_counter = 0

        for epoch in range(start_epoch, self.train_config.epochs + 1):
            train_metrics = self._run_epoch(
                loader=train_loader,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                train=True,
            )
            valid_metrics = self._run_epoch(
                loader=valid_loader,
                optimizer=None,
                scheduler=None,
                epoch=epoch,
                train=False,
            )

            history["train"].append(train_metrics)
            history["valid"].append(valid_metrics)
            self.train_config.history.append(
                {"epoch": epoch, **train_metrics, **{f"val_{k}": v for k, v in valid_metrics.items()}}
            )

            self._log_epoch_metrics(epoch, train_metrics, valid_metrics)

            current_metric = valid_metrics.get(self.train_config.best_metric)
            if current_metric is None:
                raise KeyError(
                    f"Không tìm thấy metric {self.train_config.best_metric} trong valid metrics: {valid_metrics}"
                )

            improved = self._is_better(current_metric)
            if improved:
                self.best_metric_value = current_metric
                self.best_epoch = epoch
                patience_counter = 0
            else:
                patience_counter += 1

            # Advance the LR schedule *before* checkpointing so a saved
            # checkpoint holds the optimizer/scheduler state for the next epoch;
            # resuming from it at ``epoch + 1`` then continues seamlessly.
            if scheduler is not None:
                scheduler.step()

            if improved:
                self._save_checkpoint(optimizer, epoch, best=True, scheduler=scheduler)
            if self.train_config.save_every > 0 and epoch % self.train_config.save_every == 0:
                self._save_checkpoint(optimizer, epoch, best=False, scheduler=scheduler)

            # Only a non-improving epoch can trigger early stopping; without the
            # ``> 0`` guard, ``patience <= 0`` would stop right after epoch 1.
            if patience_counter > 0 and patience_counter >= self.train_config.patience:
                print(f"[Trainer] Early stopping ở epoch {epoch}.")
                break

        return history

    def evaluate(self, datamodule: ASVspoofDataModule) -> Dict[str, float]:
        """Run one pass over the test split (no gradient updates) and return its metrics."""
        datamodule.setup(stage="test")
        test_loader = datamodule.test_dataloader()
        metrics = self._run_epoch(loader=test_loader, optimizer=None, scheduler=None, epoch=0, train=False)
        return metrics

    def _run_epoch(
        self,
        loader: DataLoader,
        optimizer: Optional[torch.optim.Optimizer],
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
        epoch: int,
        train: bool,
    ) -> Dict[str, float]:
        """Run one pass over ``loader``; update weights only when ``train`` is True.

        Each batch is expected to be a mapping with ``"features"`` (a dict of
        tensors fed to the model) and ``"labels"``. ``scheduler`` is accepted for
        signature symmetry but is stepped per epoch in ``fit``, not here.

        Returns:
            ``aggregate_metrics`` over all batches plus the sample-weighted
            mean ``"loss"``.

        Raises:
            RuntimeError: if ``loader`` yields no batches.
        """
        self.model.train(mode=train)

        total_loss = 0.0
        total_samples = 0
        all_logits: List[Tensor] = []
        all_labels: List[Tensor] = []

        for step, batch in enumerate(loader, start=1):
            features = {
                name: tensor.to(self.device, non_blocking=True)
                for name, tensor in batch["features"].items()
            }
            labels = batch["labels"].to(self.device, non_blocking=True)
            batch_size = labels.size(0)

            with torch.set_grad_enabled(train):
                with torch.cuda.amp.autocast(enabled=self.scaler.is_enabled()):
                    outputs = self.model(features)
                    logits = outputs["logits"]
                    loss = self.criterion(logits, labels)

                if train:
                    assert optimizer is not None
                    self.scaler.scale(loss).backward()
                    if self.train_config.grad_clip > 0:
                        # Gradients must be unscaled before clipping their norm.
                        self.scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.train_config.grad_clip
                        )
                    self.scaler.step(optimizer)
                    self.scaler.update()
                    optimizer.zero_grad(set_to_none=True)

            total_loss += loss.item() * batch_size
            total_samples += batch_size
            # Store float32 copies: under autocast the logits may be float16.
            all_logits.append(logits.detach().float().cpu())
            all_labels.append(labels.detach().cpu())

            if train and self.train_config.log_interval and step % self.train_config.log_interval == 0:
                current_loss = total_loss / total_samples
                lr = optimizer.param_groups[0]['lr'] if optimizer is not None else 0.0
                print(
                    f"[Epoch {epoch}] Step {step}/{len(loader)} "
                    f"Loss: {current_loss:.4f} LR: {lr:.2e}"
                )

        if not all_logits:
            raise RuntimeError(
                f"[Trainer] DataLoader rỗng ở epoch {epoch}: không có batch nào để xử lý."
            )

        avg_loss = total_loss / max(total_samples, 1)
        logits_tensor = torch.cat(all_logits, dim=0)
        labels_tensor = torch.cat(all_labels, dim=0)
        metrics = aggregate_metrics(logits_tensor, labels_tensor)
        metrics["loss"] = avg_loss
        return metrics

    def _is_better(self, value: float) -> bool:
        """Return True if ``value`` improves on the best validation metric so far."""
        if self.best_metric_value is None:
            return True
        if self.train_config.best_metric in self._LOWER_IS_BETTER:
            return value < self.best_metric_value
        return value > self.best_metric_value

    def _save_checkpoint(
        self,
        optimizer: Optional[torch.optim.Optimizer],
        epoch: int,
        best: bool,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    ) -> None:
        """Write ``checkpoint_best.pt`` (``best=True``) or ``checkpoint_epoch_XXX.pt``.

        The file stores model, optimizer and (if given) scheduler state plus
        the best-metric bookkeeping, so ``fit`` can resume from it.
        """
        state = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
            "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
            "best_metric": self.best_metric_value,
            "best_epoch": self.best_epoch,
        }
        suffix = "best" if best else f"epoch_{epoch:03d}"
        path = os.path.join(self.train_config.checkpoint_dir, f"checkpoint_{suffix}.pt")
        torch.save(state, path)
        tag = "BEST" if best else "SNAPSHOT"
        print(f"[Trainer] Đã lưu checkpoint ({tag}) tại {path}")

    def _resume_training_state(
        self,
        path: str,
        optimizer: torch.optim.Optimizer,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    ) -> int:
        """Restore training state from ``path`` and return the next epoch to run.

        Accepts a full checkpoint written by ``_save_checkpoint`` (model,
        optimizer, scheduler state when present, best-metric bookkeeping) or a
        bare model ``state_dict``, in which case only the weights are restored
        and training starts at epoch 1.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"Không tìm thấy checkpoint: {path}")
        state = torch.load(path, map_location=self.device)

        if "model_state_dict" not in state:
            self.model.load_state_dict(state)
            print(f"[Trainer] Chỉ nạp trọng số từ {path}; huấn luyện lại từ epoch 1.")
            return 1

        self.model.load_state_dict(state["model_state_dict"])
        if state.get("optimizer_state_dict") is not None:
            optimizer.load_state_dict(state["optimizer_state_dict"])
        # Checkpoints written before scheduler state was saved have no entry;
        # the schedule then restarts from its initial state.
        if scheduler is not None and state.get("scheduler_state_dict") is not None:
            scheduler.load_state_dict(state["scheduler_state_dict"])
        self.best_metric_value = state.get("best_metric", self.best_metric_value)
        self.best_epoch = state.get("best_epoch", self.best_epoch)

        next_epoch = int(state.get("epoch", 0)) + 1
        print(f"[Trainer] Tiếp tục huấn luyện từ {path} (epoch {next_epoch}).")
        return next_epoch

    def _build_scheduler(self, optimizer: torch.optim.Optimizer):
        """Build a per-epoch ``CosineAnnealingLR`` or return None when disabled."""
        if not self.scheduler_config.use_cosine:
            return None
        t_max = (
            self.scheduler_config.t_max
            if self.scheduler_config.t_max is not None
            else self.train_config.epochs
        )
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            # A period of 0 would divide by zero inside the cosine formula.
            T_max=max(1, int(t_max)),
            eta_min=self.scheduler_config.min_lr,
        )

    def load_checkpoint(self, path: str) -> None:
        """Load model weights (and best-metric bookkeeping, if present) from ``path``.

        Works with both full checkpoints (``model_state_dict`` entry) and bare
        ``state_dict`` files. Optimizer/scheduler state is not restored here.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"Không tìm thấy checkpoint: {path}")
        state = torch.load(path, map_location=self.device)
        model_state = state.get("model_state_dict", state)
        self.model.load_state_dict(model_state)
        self.model.to(self.device)
        self.best_metric_value = state.get("best_metric", self.best_metric_value)
        self.best_epoch = state.get("best_epoch", self.best_epoch)
        print(f"[Trainer] Đã tải checkpoint từ {path}")

    def _log_epoch_metrics(
        self,
        epoch: int,
        train_metrics: Dict[str, float],
        valid_metrics: Dict[str, float],
    ) -> None:
        """Print train and validation metrics, sorted by name, one line each."""
        def _format(metrics: Dict[str, float]) -> str:
            ordered = sorted(metrics.items())
            return " ".join(f"{name}: {value:.4f}" for name, value in ordered)

        print(f"[Epoch {epoch}] Train metrics -> {_format(train_metrics)}")
        print(f"[Epoch {epoch}] Valid metrics -> {_format(valid_metrics)}")


## 8. Hàm tiện ích đọc YAML và xây dựng cấu hình
Phần này tương đương `train.py` trong dự án gốc.

In [None]:
def load_yaml_config(path: str) -> Dict[str, Any]:
    """Read a YAML experiment config from ``path`` into a dict.

    ``yaml.safe_load`` is used so the file cannot construct arbitrary Python
    objects. An empty file yields ``{}`` rather than ``None``.

    Raises:
        ValueError: if the top level of the YAML document is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as handle:
        loaded = yaml.safe_load(handle)
    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(
            f"File cấu hình {path} phải chứa một mapping ở cấp cao nhất, "
            f"nhận được {type(loaded).__name__}."
        )
    return loaded


def build_feature_config(cfg: Dict[str, Any]) -> FeatureConfig:
    """Merge a ``feature`` config section over the ``FeatureConfig`` defaults.

    Each sub-section (``spectral``, ``temporal``, ``cepstral``) is merged
    key-by-key over the matching default dict, so a partial section only
    overrides the keys it names. A ``None`` cfg or ``None`` sub-section (e.g.
    ``spectral:`` left empty in YAML) is treated as empty.
    """
    cfg = cfg or {}
    base = FeatureConfig()
    spectral = {**base.spectral, **(cfg.get("spectral") or {})}
    temporal = {**base.temporal, **(cfg.get("temporal") or {})}
    cepstral = {**base.cepstral, **(cfg.get("cepstral") or {})}
    return FeatureConfig(
        sample_rate=cfg.get("sample_rate", base.sample_rate),
        spectral=spectral,
        temporal=temporal,
        cepstral=cepstral,
    )


def build_partition_config(cfg: Dict[str, Any]) -> PartitionConfig:
    """Create a ``PartitionConfig`` from one split section (train/valid/test).

    ``partition`` is mandatory (a missing key raises KeyError); the loader
    options fall back to batch_size=32, shuffle=True, drop_last=False.
    """
    loader_options = {
        "protocol_file": cfg.get("protocol_file"),
        "batch_size": cfg.get("batch_size", 32),
        "shuffle": cfg.get("shuffle", True),
        "drop_last": cfg.get("drop_last", False),
    }
    return PartitionConfig(partition=cfg["partition"], **loader_options)


def build_data_module_config(cfg: Dict[str, Any]) -> DataModuleConfig:
    """Build the full ``DataModuleConfig`` from the ``data`` config section.

    ``train`` and ``valid`` sections are required; ``test`` is optional
    (absent or ``null`` -> no test split).

    The waveform is loaded at the data-level ``sample_rate`` while features are
    built from ``feature`` settings, so unless ``feature.sample_rate`` is given
    explicitly it inherits the data-level rate to keep the two consistent.
    """
    feature_section = dict(cfg.get("feature") or {})
    if cfg.get("sample_rate") is not None:
        feature_section.setdefault("sample_rate", cfg["sample_rate"])
    feature_cfg = build_feature_config(feature_section)

    train_cfg = build_partition_config(cfg["train"])
    valid_cfg = build_partition_config(cfg["valid"])
    test_section = cfg.get("test")
    test_cfg = build_partition_config(test_section) if test_section is not None else None

    return DataModuleConfig(
        data_root=cfg["data_root"],
        sample_rate=cfg.get("sample_rate", feature_cfg.sample_rate),
        max_duration=cfg.get("max_duration", 6.0),
        pad_mode=cfg.get("pad_mode", "repeat"),
        num_workers=cfg.get("num_workers", 4),
        pin_memory=cfg.get("pin_memory", True),
        prefetch_factor=cfg.get("prefetch_factor", 2),
        feature=feature_cfg,
        train=train_cfg,
        valid=valid_cfg,
        test=test_cfg,
        preload_waveforms=cfg.get("preload_waveforms", False),
        dataset_variant=cfg.get("dataset_variant", DEFAULT_DATASET_VARIANT),
    )


def build_model_config(cfg: Dict[str, Any]) -> MultiBranchModelConfig:
    """Overlay the ``model`` config section on the ``MultiBranchModelConfig`` defaults."""
    defaults = MultiBranchModelConfig()
    overridable = ("embed_dim", "attn_dim", "num_classes", "classifier_hidden", "dropout")
    return MultiBranchModelConfig(
        **{name: cfg.get(name, getattr(defaults, name)) for name in overridable}
    )


def build_training_config(cfg: Dict[str, Any]) -> TrainingConfig:
    """Overlay the ``training`` config section on the ``TrainingConfig`` defaults.

    ``history`` is never read from the config; every call starts with a fresh list.
    """
    defaults = TrainingConfig()
    overridable = (
        "epochs",
        "device",
        "log_interval",
        "grad_clip",
        "mixed_precision",
        "checkpoint_dir",
        "best_metric",
        "patience",
        "resume_from",
        "save_every",
        "evaluate_on_test",
        "model_output_path",
    )
    return TrainingConfig(**{key: cfg.get(key, getattr(defaults, key)) for key in overridable})


def build_optimizer_config(cfg: Dict[str, Any]) -> OptimizerConfig:
    """Overlay the ``optimizer`` config section on the ``OptimizerConfig`` defaults.

    Numeric values are coerced with ``float``: PyYAML follows YAML 1.1, where
    exponent notation without a dot (``lr: 1e-4``) is loaded as the *string*
    ``"1e-4"``, which AdamW would reject.
    """
    base = OptimizerConfig()
    return OptimizerConfig(
        lr=float(cfg.get("lr", base.lr)),
        weight_decay=float(cfg.get("weight_decay", base.weight_decay)),
        betas=tuple(float(beta) for beta in cfg.get("betas", base.betas)),
        eps=float(cfg.get("eps", base.eps)),
    )


def build_scheduler_config(cfg: Dict[str, Any], total_epochs: int) -> SchedulerConfig:
    """Overlay the ``scheduler`` config section on the ``SchedulerConfig`` defaults.

    ``t_max`` defaults to ``total_epochs``; an explicit ``t_max: null`` is kept
    as ``None``. ``min_lr`` is coerced with ``float`` because PyYAML loads
    dot-less exponent literals such as ``1e-6`` as strings.
    """
    base = SchedulerConfig()
    use_cosine = cfg.get("use_cosine", base.use_cosine)
    min_lr = float(cfg.get("min_lr", base.min_lr))
    t_max = cfg.get("t_max", total_epochs if base.t_max is None else base.t_max)
    return SchedulerConfig(
        use_cosine=use_cosine,
        min_lr=min_lr,
        t_max=None if t_max is None else int(t_max),
    )

## 9. Hàm chạy huấn luyện chính
Cell này gom tất cả lại tương đương hàm `main()` trong `train.py`.

In [None]:
def plot_training_curves(history: Dict[str, List[Dict[str, float]]], output_path: str) -> None:
    """Save one stacked line chart per metric (train vs. validation) to ``output_path``.

    Metric names come from the first training record. Nothing is written (and
    no directory is created) when ``history`` is empty or has no training records.
    """
    if not history:
        return
    train_records = history.get("train", [])
    if not train_records:
        return

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    valid_records = history.get("valid", [])

    metric_names = list(train_records[0])
    epoch_axis = list(range(1, len(train_records) + 1))
    n_rows = len(metric_names)
    fig, axes = plt.subplots(n_rows, 1, figsize=(8, 4 * n_rows), sharex=True)
    if n_rows == 1:
        # plt.subplots returns a bare Axes (not an array) for a single row.
        axes = [axes]

    for ax, name in zip(axes, metric_names):
        ax.plot(epoch_axis, [rec.get(name) for rec in train_records], label="Train", marker="o")
        if valid_records:
            ax.plot(
                epoch_axis,
                [rec.get(name) for rec in valid_records],
                label="Validation",
                marker="s",
            )
        ax.set_ylabel(name.replace("_", " ").title())
        ax.grid(True, linestyle="--", alpha=0.3)
        ax.legend()

    axes[-1].set_xlabel("Epoch")
    fig.suptitle("Training Curves", fontsize=14)
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    fig.savefig(output_path, dpi=150)
    plt.close(fig)


def export_trained_model(trainer: Trainer, output_path: str) -> None:
    """Save only the weights (a plain ``state_dict``) of ``trainer.model``.

    Missing parent directories are created; a bare file name is written to
    the current working directory.
    """
    parent_dir = os.path.dirname(output_path)
    os.makedirs(parent_dir if parent_dir else ".", exist_ok=True)
    weights = trainer.model.state_dict()
    torch.save(weights, output_path)
    print(f"[Trainer] Đã lưu trọng số mô hình tại {output_path}")


def prepare_inference_from_config(
    config: Dict[str, Any],
    checkpoint_path: str,
    device: Optional[str] = None,
):
    """Build the model and feature extractor described by ``config`` and load weights.

    ``checkpoint_path`` may be a full training checkpoint (with a
    ``model_state_dict`` entry) or a bare ``state_dict`` such as the one
    written by ``export_trained_model``. An empty ``model:`` section falls back
    to the model defaults.

    Returns:
        ``(model in eval mode, feature extractor, data module config, torch.device)``.
    """
    data_cfg = build_data_module_config(config["data"])
    model_cfg = build_model_config(config.get("model") or {})

    model = MultiBranchAttentionModel(model_cfg)
    device_str = device or ("cuda" if torch.cuda.is_available() else "cpu")
    device_obj = torch.device(device_str)
    model.to(device_obj)

    state = torch.load(checkpoint_path, map_location=device_obj)
    model_state = state.get("model_state_dict", state)
    model.load_state_dict(model_state)
    model.eval()

    feature_extractor = MultiBranchFeatureExtractor(data_cfg.feature)
    return model, feature_extractor, data_cfg, device_obj


def infer_audio_from_checkpoint(
    audio_path: str,
    config: Dict[str, Any],
    checkpoint_path: str,
    device: Optional[str] = None,
) -> Dict[str, Any]:
    """Classify a single audio file with a freshly loaded checkpoint.

    The audio is loaded at the configured sample rate, padded/trimmed to
    ``max_duration`` seconds, turned into per-branch features and scored.
    Class index 1 is reported as ``"spoof"``, anything else as ``"bonafide"``.

    Returns:
        dict with ``probabilities`` (list of class probabilities),
        ``spoof_score``, ``prediction_index`` and ``prediction_label``.
    """
    model, extractor, data_cfg, target_device = prepare_inference_from_config(
        config, checkpoint_path, device=device
    )

    signal, _ = load_audio(audio_path, data_cfg.sample_rate, normalize=True)
    target_len = int(data_cfg.sample_rate * data_cfg.max_duration)
    signal = pad_or_trim(signal, target_len, mode=data_cfg.pad_mode)

    # Add a batch dimension of 1 to every branch's feature tensor.
    batch = {
        key: value.unsqueeze(0).to(target_device, non_blocking=True)
        for key, value in extractor(signal).items()
    }

    with torch.no_grad():
        probabilities = torch.softmax(model(batch)["logits"], dim=-1)

    predicted = int(torch.argmax(probabilities, dim=-1).item())
    return {
        "probabilities": probabilities.squeeze(0).cpu().tolist(),
        "spoof_score": float(probabilities[0, 1].item()),
        "prediction_index": predicted,
        "prediction_label": "spoof" if predicted == 1 else "bonafide",
    }


def run_training_from_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Train a model end-to-end from a config dict (the notebook's ``main()``).

    Steps: build configs -> fit -> save training curves -> reload the best
    checkpoint (if one was written) -> export weights -> optionally evaluate
    on the test split.

    Sections other than ``data`` are optional; a section that is present but
    empty in YAML (``model:``) loads as ``None`` and falls back to defaults.

    Returns:
        dict with ``history``, ``best_metric``, ``best_epoch``,
        ``training_plot``, ``model_path``, ``best_checkpoint`` and, when test
        evaluation ran, ``test_metrics``.
    """
    data_cfg = build_data_module_config(config["data"])
    model_cfg = build_model_config(config.get("model") or {})
    training_cfg = build_training_config(config.get("training") or {})
    optimizer_cfg = build_optimizer_config(config.get("optimizer") or {})
    scheduler_cfg = build_scheduler_config(
        config.get("scheduler") or {},
        total_epochs=training_cfg.epochs,
    )

    datamodule = ASVspoofDataModule(data_cfg)
    model = MultiBranchAttentionModel(model_cfg)
    trainer = Trainer(model, training_cfg, optimizer_cfg, scheduler_cfg)

    history = trainer.fit(datamodule)

    plot_path = os.path.join(training_cfg.checkpoint_dir, "training_curves.png")
    plot_training_curves(history, plot_path)

    best_checkpoint = os.path.join(training_cfg.checkpoint_dir, "checkpoint_best.pt")
    has_best_checkpoint = os.path.exists(best_checkpoint)
    if has_best_checkpoint:
        trainer.load_checkpoint(best_checkpoint)
    else:
        print(
            f"[Trainer] Không tìm thấy checkpoint tốt nhất tại {best_checkpoint}. "
            "Giữ nguyên trọng số cuối cùng."
        )

    model_path = training_cfg.model_output_path or os.path.join(
        training_cfg.checkpoint_dir, "multibranch_model.pt"
    )
    export_trained_model(trainer, model_path)

    results = {
        "history": history,
        "best_metric": trainer.best_metric_value,
        "best_epoch": trainer.best_epoch,
        "training_plot": plot_path,
        "model_path": model_path,
        "best_checkpoint": best_checkpoint if has_best_checkpoint else None,
    }

    if training_cfg.evaluate_on_test and data_cfg.test is not None:
        results["test_metrics"] = trainer.evaluate(datamodule)

    return results

## 10. Ví dụ cấu hình
Bạn có thể chỉnh sửa trực tiếp dictionary bên dưới hoặc đọc YAML bằng `load_yaml_config`.

In [None]:
# Example configuration for the Kaggle ASVspoof 2019 LA dataset layout.
# ``training.device`` is null so the Trainer auto-selects "cuda" when a GPU is
# available and falls back to "cpu" otherwise; set it explicitly to force one.
example_config = yaml.safe_load("""data:
  data_root: "/kaggle/input/asvpoof-2019-dataset/LA/LA"
  sample_rate: 16000
  max_duration: 6.0
  pad_mode: "repeat"
  num_workers: 4
  pin_memory: true
  prefetch_factor: 2
  preload_waveforms: false
  dataset_variant: "ASVspoof2019_LA"
  feature:
    spectral:
      n_mels: 128
      n_fft: 1024
      hop_length: 256
    temporal:
      emphasis: true
      highpass_cutoff: 20.0
    cepstral:
      n_bins: 96
      bins_per_octave: 12
  train:
    partition: "train"
    protocol_file: "/kaggle/input/asvpoof-2019-dataset/LA/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt"
    batch_size: 16
    shuffle: true
    drop_last: true
  valid:
    partition: "dev"
    protocol_file: "/kaggle/input/asvpoof-2019-dataset/LA/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt"
    batch_size: 16
    shuffle: false
    drop_last: false
  test: null
model:
  embed_dim: 256
  attn_dim: 128
  classifier_hidden: 128
  dropout: 0.3
training:
  epochs: 50
  device: null
  log_interval: 20
  grad_clip: 5.0
  mixed_precision: true
  checkpoint_dir: "checkpoints"
  best_metric: "eer"
  patience: 8
  save_every: 0
  evaluate_on_test: false
  model_output_path: null
optimizer:
  lr: 0.0002
  weight_decay: 0.00001
  betas: [0.9, 0.98]
scheduler:
  use_cosine: true
  min_lr: 0.000001
""")
example_config

## 11. Khám phá dữ liệu với cấu hình hiện tại
Sau khi đã cập nhật `example_config` (hoặc cấu hình của riêng bạn), hãy chạy các cell dưới đây để rà soát nhanh cấu trúc thư mục, file protocol và nội dung audio.


In [None]:
from pathlib import Path
from itertools import islice

# The exploration cells read the raw ``data`` section of the example config.
if "example_config" in globals():
    data_cfg = example_config["data"]
else:
    raise RuntimeError("Vui lòng chạy cell cấu hình ví dụ (mục 10) trước khi khám phá dữ liệu.")

data_root = Path(data_cfg["data_root"]).expanduser()
if not data_root.exists():
    raise FileNotFoundError(f"Không tìm thấy thư mục dữ liệu: {data_root}")

# Folder-name prefix shared by all partitions, e.g. "<prefix>_train".
variant_prefix = data_cfg.get("dataset_variant", "ASVspoof2019_LA")

def resolve_partition_dir(split_key: str) -> Path:
    """Return the directory for split ``split_key`` under the global ``data_root``.

    Resolution order: an explicit ``custom_dir`` wins; otherwise the
    ``partition`` name (defaulting to ``split_key``) is used as-is when it
    already starts with ``variant_prefix``, else it is prefixed with it.
    """
    split_section = data_cfg[split_key]
    explicit_dir = split_section.get("custom_dir")
    if explicit_dir is not None:
        return data_root / explicit_dir
    name = split_section.get("partition", split_key)
    folder = name if name.startswith(variant_prefix) else f"{variant_prefix}_{name}"
    return data_root / folder

# Summarise each configured split: audio count, a few example files, and the
# protocol files (both those found under <partition>/protocol and the one
# explicitly configured, which may live elsewhere).
for split_key in ("train", "valid", "test"):
    split_section = data_cfg.get(split_key)
    if split_section is None:
        continue
    partition_dir = resolve_partition_dir(split_key)
    print(f"[{split_key.upper()}] {partition_dir}")
    if not partition_dir.exists():
        print("  -> Không tìm thấy thư mục này. Kiểm tra lại đường dẫn hoặc giải nén dữ liệu.")
        continue

    flac_dir = partition_dir / "flac"
    protocol_dir = partition_dir / "protocol"

    # One directory scan: count every .flac file and keep the first five as examples.
    num_audio = 0
    sample_files = []
    if flac_dir.exists():
        for audio_file in flac_dir.glob("*.flac"):
            num_audio += 1
            if len(sample_files) < 5:
                sample_files.append(audio_file)
    print(f"  Số file âm thanh: {num_audio:,}")
    if sample_files:
        examples = ', '.join(path.name for path in sample_files)
        print(f"  Ví dụ file: {examples}")

    protocol_files = sorted(protocol_dir.glob("*.txt")) if protocol_dir.exists() else []
    if protocol_files:
        print("  Protocol:")
        for proto in protocol_files[:2]:
            print(f"    - {proto.name}")

    # The configured protocol may sit outside the partition folder (the example
    # config uses ASVspoof2019_LA_cm_protocols/). Relative paths are resolved
    # against <partition>/protocol, matching the protocol preview cell.
    configured_protocol = split_section.get("protocol_file")
    if configured_protocol:
        configured_path = Path(configured_protocol)
        if not configured_path.is_absolute():
            configured_path = protocol_dir / configured_protocol
        status = "OK" if configured_path.exists() else "KHÔNG TỒN TẠI"
        print(f"  Protocol cấu hình: {configured_path} ({status})")
    print()


In [None]:
from IPython.display import display
import pandas as pd

# Locate the train protocol: the configured path (relative paths resolved
# against <train_dir>/protocol) or, failing that, the first .txt in that folder.
train_dir = resolve_partition_dir("train")
protocol_cfg = data_cfg["train"].get("protocol_file")
if protocol_cfg:
    protocol_path = Path(protocol_cfg)
    if not protocol_path.is_absolute():
        protocol_path = train_dir / "protocol" / protocol_cfg
else:
    protocol_candidates = sorted((train_dir / "protocol").glob("*.txt"))
    protocol_path = protocol_candidates[0] if protocol_candidates else None

if not protocol_path or not protocol_path.exists():
    raise FileNotFoundError("Không tìm thấy file protocol cho tập train.")

# Protocol lines are whitespace-separated fields. A regex separator tolerates
# repeated spaces/tabs and trailing whitespace; with sep=" " a trailing blank
# adds an extra empty field, and pandas then silently moves the first column
# into the index, shifting every column name by one.
protocol_cols = ["speaker_id", "utt_id", "source", "attack_id", "label"]
protocol_df = pd.read_csv(
    protocol_path,
    sep=r"\s+",
    header=None,
    names=protocol_cols,
)

display(protocol_df.head())
print("\nSố lượng mẫu theo nhãn:")
display(protocol_df["label"].value_counts())


In [None]:
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
from IPython.display import Audio, display

if protocol_df.empty:
    raise ValueError("File protocol không có dòng nào để lấy mẫu audio.")

# Deterministic pick (random_state=0) so re-running shows the same utterance.
sample_row = protocol_df.sample(1, random_state=0) if len(protocol_df) > 1 else protocol_df
sample_row = sample_row.iloc[0]
audio_path = train_dir / "flac" / f"{sample_row['utt_id']}.flac"
if not audio_path.exists():
    raise FileNotFoundError(f"Không tìm thấy file audio: {audio_path}")

target_sr = data_cfg.get("sample_rate", 16000)
waveform, sr = librosa.load(audio_path, sr=target_sr)
duration = waveform.shape[0] / sr
print(f"Đang xem: {sample_row['utt_id']} ({sample_row['label']}) - {duration:.2f}s @ {sr} Hz")

display(Audio(waveform, rate=sr))

# Use the same STFT/mel settings as the configured spectral branch.
spectral_cfg = data_cfg.get("feature", {}).get("spectral", {})
n_fft = spectral_cfg.get("n_fft", 1024)
hop_length = spectral_cfg.get("hop_length", 256)
n_mels = spectral_cfg.get("n_mels", 128)
mel_spec = librosa.feature.melspectrogram(
    y=waveform, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
)
mel_db = librosa.power_to_db(mel_spec, ref=np.max)

fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=False)
# Sample i is taken at time i / sr; linspace(0, duration, n) would stretch the
# axis so the last sample lands at ``duration`` instead of (n - 1) / sr.
time_axis = np.arange(waveform.shape[0]) / sr
axes[0].plot(time_axis, waveform)
axes[0].set_title("Waveform")
axes[0].set_xlabel("Thời gian (s)")
axes[0].set_ylabel("Biên độ")
axes[0].grid(True, linestyle="--", alpha=0.3)

img = librosa.display.specshow(
    mel_db,
    x_axis="time",
    y_axis="mel",
    sr=sr,
    hop_length=hop_length,
    cmap="magma",
    ax=axes[1],
)
axes[1].set_title("Mel-spectrogram (dB)")
fig.colorbar(img, ax=axes[1], format="%.0f dB")
fig.tight_layout()
plt.show()


## 12. Thực thi huấn luyện (tùy chọn)
Chỉ chạy cell này khi bạn đã mount đúng dữ liệu. Nếu muốn đọc từ file YAML, dùng `config = load_yaml_config(path)`.

In [None]:
# Ví dụ:
# config = load_yaml_config("/kaggle/input/asvspoof-configs/asvspoof_multibranch.yaml")
# results = run_training_from_config(config)
# results

### 12.1 Theo dõi lịch sử huấn luyện
Sau khi chạy cell huấn luyện ở trên, bạn có thể sử dụng đoạn mã dưới đây để xem bảng metric theo từng epoch và hiển thị biểu đồ đường cong huấn luyện đã lưu.


In [None]:
import pandas as pd
from pathlib import Path
from IPython.display import Image, display

if "results" not in globals():
    raise RuntimeError("Chưa tìm thấy biến `results`. Hãy chạy hàm huấn luyện trước khi tổng hợp lịch sử.")

history = results.get("history") or {}
train_history = history.get("train", [])
valid_history = history.get("valid", [])

if not train_history:
    print("Lịch sử huấn luyện trống.")
else:
    # One row per epoch: train metrics as-is, validation metrics prefixed with "val_".
    rows = [
        {
            "epoch": epoch_idx,
            **train_metrics,
            **{f"val_{name}": value for name, value in valid_metrics.items()},
        }
        for epoch_idx, (train_metrics, valid_metrics) in enumerate(
            zip(train_history, valid_history), start=1
        )
    ]
    history_df = pd.DataFrame(rows)
    display(history_df)

best_metric = results.get("best_metric")
best_epoch = results.get("best_epoch")
if best_metric is not None and best_epoch is not None:
    print(f"Best validation metric đạt {best_metric:.4f} tại epoch {best_epoch}")

# Path("") is the current directory, which always exists; handle a missing
# ``training_plot`` explicitly and only display regular files.
plot_file = results.get("training_plot")
plot_path = Path(plot_file) if plot_file else None
if plot_path is not None and plot_path.is_file():
    display(Image(filename=str(plot_path)))
elif train_history:
    inline_plot = Path("training_curves_inline.png")
    plot_training_curves(history, str(inline_plot))
    if inline_plot.exists():
        display(Image(filename=str(inline_plot)))
