# ASVspoof2019 LA Multi-Branch Training Notebook

Notebook này chuyển toàn bộ pipeline huấn luyện mô hình đa nhánh sang định dạng dễ chạy trên Kaggle. Các cell dưới đây gom lại toàn bộ logic từ project gốc với các chú thích chi tiết để bạn dễ tuỳ biến.

## 1. Thiết lập môi trường
Chạy cell dưới đây nếu bạn đang ở môi trường Kaggle và cần cài đặt phụ thuộc. Các gói trùng với bản `requirements.txt` gốc.

In [None]:
# Nếu chạy trên Kaggle Notebook bạn có thể bỏ comment dòng dưới để cài đặt phụ thuộc
# !pip install -q -r /kaggle/input/your-requirements/requirements.txt
# Hoặc khai báo thủ công:
# !pip install -q torch torchaudio librosa pyyaml

## 2. Import thư viện
Tập hợp toàn bộ import cần thiết cho pipeline.

In [None]:
from __future__ import annotations

import csv
import math
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import librosa
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
import torchaudio
import yaml

## 3. Cấu hình đặc trưng và helper xử lý audio
Các hàm tiện ích gộp từ `src/utils/audio.py` và cấu hình đặc trưng ở `src/data/features.py`.

In [None]:
def load_audio(path: str, target_sample_rate: int, normalize: bool = True) -> Tuple[Tensor, int]:
    """Read an audio file and return a mono waveform at the requested rate.

    Args:
        path: Audio file handed to ``torchaudio.load``.
        target_sample_rate: The waveform is resampled to this rate whenever
            the file's native rate differs from it.
        normalize: When true, divide the waveform by its largest absolute
            sample so the peak becomes 1.

    Returns:
        ``(waveform, sample_rate)``; the waveform has been passed through
        ``ensure_mono``.
    """
    audio, rate = torchaudio.load(path)
    if rate != target_sample_rate:
        audio = torchaudio.transforms.Resample(rate, target_sample_rate)(audio)
        rate = target_sample_rate
    audio = ensure_mono(audio)
    if not normalize:
        return audio, rate
    loudest = audio.abs().max()
    # A silent (all-zero) clip is returned unscaled to avoid dividing by zero.
    if loudest > 0:
        audio = audio / loudest
    return audio, rate


def ensure_mono(waveform: Tensor) -> Tensor:
    """Return a single-channel waveform with the channel axis first.

    Args:
        waveform: Either a 1-D tensor of samples or a tensor whose first
            dimension indexes channels.

    Returns:
        A tensor whose first dimension has size 1. A 1-D input is promoted to
        ``(1, samples)``; a single-channel input is returned unchanged;
        multi-channel input is averaged over the channel dimension.
    """
    if waveform.dim() == 1:
        # Without this guard the sample axis would be mistaken for the channel
        # axis and the whole clip would be averaged into a single value.
        return waveform.unsqueeze(0)
    if waveform.size(0) == 1:
        return waveform
    return waveform.mean(dim=0, keepdim=True)


def pad_or_trim(waveform: Tensor, target_num_samples: int, mode: str = "repeat") -> Tensor:
    """Force the last dimension of ``waveform`` to ``target_num_samples``.

    Longer input is cut from the end; shorter input is extended on the right.

    Args:
        waveform: Tensor whose last dimension is the sample axis.
        target_num_samples: Desired length of the last dimension.
        mode: How to extend short input. ``"zeros"`` appends zeros,
            ``"reflect"`` mirrors the signal, ``"repeat"`` tiles it.

    Returns:
        A tensor with the same leading dimensions and ``target_num_samples``
        samples. The input object itself is returned when it already has the
        right length.

    Raises:
        ValueError: If padding is needed and ``mode`` is not supported.
    """
    current = waveform.size(-1)
    if current == target_num_samples:
        return waveform
    if current > target_num_samples:
        return waveform[..., :target_num_samples]

    if mode not in ("zeros", "reflect", "repeat"):
        raise ValueError(f"pad_mode không được hỗ trợ: {mode}")

    if current == 0:
        # Nothing to mirror or tile: an empty clip can only become silence.
        return waveform.new_zeros(*waveform.shape[:-1], target_num_samples)

    missing = target_num_samples - current
    if mode == "zeros":
        return torch.nn.functional.pad(waveform, (0, missing))

    if mode == "reflect" and current > 1:
        # A single reflect pad must be shorter than the signal, so clips
        # shorter than half the target are mirrored in several passes.
        padded = waveform
        while missing > 0:
            step = min(missing, padded.size(-1) - 1)
            padded = torch.nn.functional.pad(padded, (0, step), mode="reflect")
            missing -= step
        return padded

    # "repeat", and "reflect" on a one-sample clip (which has no mirror image):
    # tile along the sample axis, whatever the rank, then cut to length.
    repeats = math.ceil(target_num_samples / current)
    return torch.cat([waveform] * repeats, dim=-1)[..., :target_num_samples]


class SpectralConfig(dict):
    """Plain ``dict`` subclass holding the mel-spectrogram parameters."""


class TemporalConfig(dict):
    """Plain ``dict`` subclass holding the waveform-processing options."""


class CepstralConfig(dict):
    """Plain ``dict`` subclass holding the CQT parameters."""


def _default_spectral_config() -> SpectralConfig:
    """Build the default mel-spectrogram settings."""
    return SpectralConfig(
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        n_mels=128,
        f_min=20.0,
        f_max=None,
        power=2.0,
    )


def _default_temporal_config() -> TemporalConfig:
    """Build the default waveform-branch settings."""
    return TemporalConfig(emphasis=True, highpass_cutoff=30.0)


def _default_cepstral_config() -> CepstralConfig:
    """Build the default CQT settings."""
    return CepstralConfig(
        hop_length=256,
        n_bins=84,
        bins_per_octave=12,
        f_min=32.7,
    )


@dataclass
class FeatureConfig:
    """Feature-extraction settings shared by the three input branches.

    Each instance gets its own freshly built dictionaries, so editing one
    configuration never leaks into another.
    """

    sample_rate: int = 16000
    spectral: SpectralConfig = field(default_factory=_default_spectral_config)
    temporal: TemporalConfig = field(default_factory=_default_temporal_config)
    cepstral: CepstralConfig = field(default_factory=_default_cepstral_config)


class MultiBranchFeatureExtractor:
    """Derive the spectral, temporal and cepstral branch inputs from a waveform.

    The extractor produces a log-compressed mel spectrogram, an optionally
    pre-emphasised / high-passed copy of the waveform, and a log-compressed
    constant-Q power spectrum.
    """

    def __init__(self, config: FeatureConfig) -> None:
        """Create the mel transform and cache the branch options from ``config``."""
        self.config = config

        mel_opts = config.spectral
        fft_size = mel_opts.get("n_fft", 1024)
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_fft=fft_size,
            hop_length=mel_opts.get("hop_length", 256),
            # The window defaults to the FFT size when not configured.
            win_length=mel_opts.get("win_length", fft_size),
            f_min=mel_opts.get("f_min", 20.0),
            f_max=mel_opts.get("f_max"),
            n_mels=mel_opts.get("n_mels", 128),
            power=mel_opts.get("power", 2.0),
            normalized=False,
        )

        wave_opts = config.temporal
        self.apply_pre_emphasis = wave_opts.get("emphasis", True)
        self.highpass_cutoff = wave_opts.get("highpass_cutoff", 30.0)

        self.cqt_cfg = config.cepstral

    def __call__(self, waveform: Tensor) -> Dict[str, Tensor]:
        """Return a dict with the ``spectral``, ``temporal`` and ``cepstral`` features."""
        mono = ensure_mono(waveform)
        return {
            "spectral": self._compute_mel_spectrogram(mono),
            "temporal": self._prepare_temporal_branch(mono),
            "cepstral": self._compute_cqt(mono),
        }

    def _compute_mel_spectrogram(self, waveform: Tensor) -> Tensor:
        """Mel spectrogram compressed with ``log(1 + x)``."""
        return torch.log1p(self.mel_transform(waveform))

    def _prepare_temporal_branch(self, waveform: Tensor) -> Tensor:
        """Apply the optional pre-emphasis and high-pass filter to the raw waveform."""
        signal = waveform
        if self.apply_pre_emphasis:
            signal = torchaudio.functional.preemphasis(signal, 0.97)

        cutoff = self.highpass_cutoff
        if cutoff is None or not cutoff > 0:
            # High-pass filtering is disabled.
            return signal
        return torchaudio.functional.highpass_biquad(
            signal,
            sample_rate=self.config.sample_rate,
            cutoff_freq=cutoff,
        )

    def _compute_cqt(self, waveform: Tensor) -> Tensor:
        """Constant-Q power spectrum compressed with ``log(1 + x)``, plus a leading axis."""
        opts = self.cqt_cfg
        # librosa operates on NumPy arrays, so the tensor is moved to the CPU first.
        samples = waveform.squeeze(0).cpu().numpy()
        spectrum = librosa.cqt(
            samples,
            sr=self.config.sample_rate,
            hop_length=opts.get("hop_length", 256),
            n_bins=opts.get("n_bins", 84),
            bins_per_octave=opts.get("bins_per_octave", 12),
            fmin=opts.get("f_min", 32.7),
        )
        power = (abs(spectrum) ** 2).astype("float32")
        return torch.log1p(torch.from_numpy(power)).unsqueeze(0)

## 4. Dataset ASVspoof2019 LA
Hợp nhất logic từ `src/data/asvspoof_dataset.py` với các chú thích chi tiết.

In [None]:
# Maps the lower-cased label token of a protocol row to its integer class id.
LA_LABELS = {"bonafide": 0, "spoof": 1}


@dataclass
class ASVExample:
    """Metadata for one utterance parsed from a protocol file."""

    utt_id: str  # utterance id; also the stem of the audio file name
    speaker_id: str  # first token of the protocol row
    path: str  # resolved location of the .flac (or .wav) file
    label: int  # class id looked up in LA_LABELS
    system_id: Optional[str] = None  # third token of the protocol row
    attack_type: Optional[str] = None  # fourth token; only set for rows with 5+ tokens


class ASVspoofLADataset(Dataset):
    """PyTorch dataset over a single partition of ASVspoof2019 LA.

    Every item is a dict with the utterance id, speaker id, label tensor and
    the features produced by ``feature_extractor``; a ``meta`` entry is added
    when the protocol row carried a system id or attack type.
    """

    def __init__(
        self,
        data_root: str,
        partition: str,
        feature_extractor: MultiBranchFeatureExtractor,
        protocol_file: Optional[str] = None,
        sample_rate: int = 16000,
        max_duration: float = 6.0,
        pad_mode: str = "repeat",
        preload_waveforms: bool = False,
    ) -> None:
        """Index the partition and optionally cache every waveform in memory.

        Args:
            data_root: Directory containing the ``ASVspoof2019_LA_*`` folders.
            partition: Partition name used in folder and protocol file names.
            feature_extractor: Callable applied to each fixed-length waveform.
            protocol_file: Explicit protocol path; looked up under
                ``data_root`` when omitted.
            sample_rate: Rate every clip is resampled to.
            max_duration: Clip length in seconds after padding/trimming.
            pad_mode: Padding strategy forwarded to ``pad_or_trim``.
            preload_waveforms: Load and cache all waveforms up front.

        Raises:
            FileNotFoundError: If no protocol file or an audio file is missing.
            ValueError: If a protocol row cannot be parsed.
        """
        super().__init__()
        self.data_root = data_root
        self.partition = partition
        self.sample_rate = sample_rate
        self.max_num_samples = int(sample_rate * max_duration)
        self.pad_mode = pad_mode
        self.feature_extractor = feature_extractor
        self.preload_waveforms = preload_waveforms

        if protocol_file is None:
            protocol_file = self._resolve_protocol_file()
        self.protocol_file = protocol_file
        self.examples = self._load_metadata()

        self._waveform_cache: Dict[str, Tensor] = {}
        if self.preload_waveforms:
            for example in self.examples:
                self._waveform_cache[example.utt_id] = self._load_waveform(example.path)

    def _resolve_protocol_file(self) -> str:
        """Return the first existing protocol file among the known layouts.

        The per-partition ``protocol`` folder is checked first, followed by
        the shared ``ASVspoof2019_LA_cm_protocols`` folder. The ``.trl.txt``
        suffix is included because the official dev/eval protocols use it
        instead of ``.trn.txt``.
        """
        partition = self.partition
        base_name = f"ASVspoof2019.LA.cm.{partition}"
        file_names = (f"{base_name}.trn.txt", f"{base_name}.txt", f"{base_name}.trl.txt")
        search_dirs = (
            os.path.join(self.data_root, f"ASVspoof2019_LA_{partition}", "protocol"),
            os.path.join(self.data_root, "ASVspoof2019_LA_cm_protocols"),
        )
        candidates = [
            os.path.join(directory, file_name)
            for directory in search_dirs
            for file_name in file_names
        ]
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(
            f"Không tìm thấy protocol cho partition={partition}. Cần cung cấp protocol_file. Checked: {candidates}"
        )

    def _load_waveform(self, path: str) -> Tensor:
        """Load one clip, peak-normalise it and force it to the fixed length."""
        waveform, _ = load_audio(path, self.sample_rate, normalize=True)
        return pad_or_trim(waveform, self.max_num_samples, mode=self.pad_mode)

    def _load_metadata(self) -> List[ASVExample]:
        """Parse the protocol file into one ``ASVExample`` per non-empty row."""
        # The audio folder is the same for every row, so build it once.
        audio_dir = os.path.join(self.data_root, f"ASVspoof2019_LA_{self.partition}", "flac")
        examples: List[ASVExample] = []
        with open(self.protocol_file, "r", encoding="utf-8") as handle:
            for line in handle:
                # Whitespace splitting tolerates tabs and repeated spaces.
                tokens = line.split()
                if not tokens:
                    continue
                if len(tokens) == 4:
                    speaker_id, utt_id, system_id, label_token = tokens
                    attack_type = None
                elif len(tokens) >= 5:
                    speaker_id, utt_id, system_id, attack_type, label_token = tokens[:5]
                else:
                    raise ValueError(f"Không thể parse dòng protocol: {tokens}")

                label_token = label_token.lower()
                if label_token not in LA_LABELS:
                    raise ValueError(f"Nhãn không hợp lệ: {label_token}")

                # Prefer .flac and fall back to .wav with the same stem.
                audio_path = os.path.join(audio_dir, f"{utt_id}.flac")
                if not os.path.exists(audio_path):
                    audio_path = os.path.join(audio_dir, f"{utt_id}.wav")
                    if not os.path.exists(audio_path):
                        raise FileNotFoundError(f"Không tìm thấy file audio cho {utt_id}")

                examples.append(
                    ASVExample(
                        utt_id=utt_id,
                        speaker_id=speaker_id,
                        path=audio_path,
                        label=LA_LABELS[label_token],
                        system_id=system_id,
                        attack_type=attack_type,
                    )
                )
        return examples

    def __len__(self) -> int:
        """Number of utterances listed in the protocol file."""
        return len(self.examples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the features, label and identifiers of the ``index``-th utterance."""
        example = self.examples[index]
        waveform = self._waveform_cache.get(example.utt_id)
        if waveform is None:
            waveform = self._load_waveform(example.path)

        sample: Dict[str, Any] = {
            "utt_id": example.utt_id,
            "speaker_id": example.speaker_id,
            "label": torch.tensor(example.label, dtype=torch.long),
            "features": self.feature_extractor(waveform),
        }

        metadata = {}
        if example.system_id is not None:
            metadata["system_id"] = example.system_id
        if example.attack_type is not None:
            metadata["attack_type"] = example.attack_type
        if metadata:
            sample["meta"] = metadata
        return sample


def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge dataset samples into one batch dictionary.

    Labels and each feature branch are stacked along a new leading dimension;
    utterance and speaker ids are gathered into lists. A ``meta`` list (with
    ``None`` for samples lacking metadata) is included only when at least one
    sample carries a ``meta`` entry.
    """
    label_tensors = []
    utt_ids = []
    speaker_ids = []
    metas = []
    per_branch: Dict[str, List[Tensor]] = {}

    for sample in batch:
        label_tensors.append(sample["label"])
        utt_ids.append(sample["utt_id"])
        speaker_ids.append(sample["speaker_id"])
        metas.append(sample.get("meta"))
        for name, value in sample["features"].items():
            if name not in per_branch:
                per_branch[name] = []
            per_branch[name].append(value)

    labels = torch.stack(label_tensors, dim=0)
    collated = {
        "features": {name: torch.stack(values, dim=0) for name, values in per_branch.items()},
        "labels": labels,
        "utt_ids": utt_ids,
        "speaker_ids": speaker_ids,
    }
    has_meta = not all(entry is None for entry in metas)
    if has_meta:
        collated["meta"] = metas
    return collated

## 5. DataModule tiện dụng
Giữ nguyên cấu hình từ `src/data/datamodule.py` để tạo DataLoader.

In [None]:
@dataclass
class PartitionConfig:
    """Per-partition dataset and DataLoader settings."""

    partition: str  # partition name used to locate the audio and protocol folders
    protocol_file: Optional[str] = None  # explicit protocol path; None triggers auto-discovery
    batch_size: int = 32  # forwarded to the DataLoader
    shuffle: bool = True  # forwarded to the DataLoader
    drop_last: bool = False  # forwarded to the DataLoader


@dataclass
class DataModuleConfig:
    """Settings shared by every partition handled by ``ASVspoofDataModule``."""

    data_root: str  # directory that contains the partition folders
    sample_rate: int = 16000  # rate passed to each dataset for loading audio
    max_duration: float = 6.0  # clip length in seconds after padding/trimming
    pad_mode: str = "repeat"  # padding strategy passed to each dataset
    num_workers: int = 4  # forwarded to the DataLoader
    pin_memory: bool = True  # forwarded to the DataLoader
    prefetch_factor: int = 2  # only forwarded when num_workers > 0
    feature: FeatureConfig = field(default_factory=FeatureConfig)
    train: Optional[PartitionConfig] = None  # required by ASVspoofDataModule
    valid: Optional[PartitionConfig] = None  # required by ASVspoofDataModule
    test: Optional[PartitionConfig] = None  # optional; needed only for the test loader
    preload_waveforms: bool = False  # cache all waveforms in memory at dataset creation


class ASVspoofDataModule:
    """Build the train/valid/test datasets and loaders from one shared configuration."""

    def __init__(self, config: DataModuleConfig) -> None:
        """Store ``config`` and create the feature extractor shared by all partitions.

        Raises:
            ValueError: If the train or valid partition is not configured.
        """
        lacks_required_split = config.train is None or config.valid is None
        if lacks_required_split:
            raise ValueError("Cần cấu hình partition train và valid")
        self.config = config
        self.feature_extractor = MultiBranchFeatureExtractor(config.feature)
        self._datasets: Dict[str, ASVspoofLADataset] = {}

    def setup(self, stage: Optional[str] = None) -> None:
        """Instantiate the datasets for ``stage`` (``"fit"``, ``"test"`` or ``None`` for both)."""
        build_everything = stage is None
        if build_everything or stage == "fit":
            for split in ("train", "valid"):
                self._datasets[split] = self._build_dataset(getattr(self.config, split))
        if (build_everything or stage == "test") and self.config.test is not None:
            self._datasets["test"] = self._build_dataset(self.config.test)

    def _build_dataset(self, part_cfg: PartitionConfig) -> ASVspoofLADataset:
        """Create the dataset for one partition with the shared audio settings."""
        shared = self.config
        return ASVspoofLADataset(
            data_root=shared.data_root,
            partition=part_cfg.partition,
            protocol_file=part_cfg.protocol_file,
            feature_extractor=self.feature_extractor,
            sample_rate=shared.sample_rate,
            max_duration=shared.max_duration,
            pad_mode=shared.pad_mode,
            preload_waveforms=shared.preload_waveforms,
        )

    def train_dataloader(self) -> DataLoader:
        """Loader over the training partition."""
        return self._build_loader(self._datasets["train"], self.config.train)

    def val_dataloader(self) -> DataLoader:
        """Loader over the validation partition."""
        return self._build_loader(self._datasets["valid"], self.config.valid)

    def test_dataloader(self) -> DataLoader:
        """Loader over the test partition.

        Raises:
            RuntimeError: If the test dataset has not been set up.
        """
        if "test" in self._datasets:
            return self._build_loader(self._datasets["test"], self.config.test)
        raise RuntimeError("Chưa setup test dataset")

    def _build_loader(self, dataset: ASVspoofLADataset, part_cfg: PartitionConfig) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using the partition and worker settings."""
        workers = self.config.num_workers
        options = dict(
            dataset=dataset,
            batch_size=part_cfg.batch_size,
            shuffle=part_cfg.shuffle,
            drop_last=part_cfg.drop_last,
            num_workers=workers,
            pin_memory=self.config.pin_memory,
            collate_fn=collate_fn,
        )
        # prefetch_factor is only passed along when worker processes are used.
        if workers > 0:
            options["prefetch_factor"] = self.config.prefetch_factor
        return DataLoader(**options)

## 6. Kiến trúc mô hình đa nhánh
Sao chép từ `src/models/multi_branch_model.py` với chú thích cho từng thành phần.

In [None]:
def conv2d_block(
    in_channels: int,
    out_channels: int,
    kernel_size: Tuple[int, int] = (3, 3),
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (1, 1),
    dropout: float = 0.0,
) -> nn.Sequential:
    """Conv2d -> BatchNorm2d -> ReLU, followed by Dropout2d when ``dropout`` > 0."""
    convolution = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
    normalisation = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    regulariser = [nn.Dropout2d(dropout)] if dropout > 0 else []
    return nn.Sequential(convolution, normalisation, activation, *regulariser)


def conv1d_block(
    in_channels: int,
    out_channels: int,
    kernel_size: int = 3,
    stride: int = 1,
    padding: int = 1,
    dropout: float = 0.0,
) -> nn.Sequential:
    """Conv1d -> BatchNorm1d -> ReLU, followed by Dropout when ``dropout`` > 0."""
    convolution = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
    normalisation = nn.BatchNorm1d(out_channels)
    activation = nn.ReLU(inplace=True)
    regulariser = [nn.Dropout(dropout)] if dropout > 0 else []
    return nn.Sequential(convolution, normalisation, activation, *regulariser)


class SpectralBranch(nn.Module):
    """2-D convolutional encoder for the mel-spectrogram input.

    Four conv blocks (with two 2x2 max-pools) are followed by global average
    pooling and a linear projection to ``hidden_dim``.
    """

    def __init__(self, in_channels: int = 1, hidden_dim: int = 256) -> None:
        super().__init__()
        halve = (2, 2)
        stages = [
            conv2d_block(in_channels, 32, dropout=0.1),
            nn.MaxPool2d(halve),
            conv2d_block(32, 64, dropout=0.15),
            nn.MaxPool2d(halve),
            conv2d_block(64, 128, dropout=0.2),
            conv2d_block(128, 128, dropout=0.2),
        ]
        self.features = nn.Sequential(*stages)
        # Collapses both spatial axes so the embedding does not depend on input size.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.proj = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` into one ``hidden_dim``-sized vector per batch item."""
        feature_map = self.features(x)
        pooled = self.pool(feature_map)
        return self.proj(pooled)


class TemporalBranch(nn.Module):
    """1-D convolutional encoder for the raw-waveform input.

    Three stride-2 conv blocks and one stride-1 block are followed by global
    average pooling over time and a linear projection to ``hidden_dim``.
    """

    def __init__(self, in_channels: int = 1, hidden_dim: int = 256) -> None:
        super().__init__()
        stages = [
            conv1d_block(in_channels, 32, kernel_size=11, stride=2, padding=5, dropout=0.1),
            conv1d_block(32, 64, kernel_size=9, stride=2, padding=4, dropout=0.1),
            conv1d_block(64, 128, kernel_size=7, stride=2, padding=3, dropout=0.15),
            conv1d_block(128, 128, kernel_size=5, stride=1, padding=2, dropout=0.15),
        ]
        self.conv_stack = nn.Sequential(*stages)
        # Averages over the time axis so clips of any length give one vector.
        self.temporal_pool = nn.AdaptiveAvgPool1d(1)
        self.proj = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` into one ``hidden_dim``-sized vector per batch item."""
        encoded = self.conv_stack(x)
        pooled = self.temporal_pool(encoded)
        return self.proj(pooled)


class CepstralBranch(nn.Module):
    """2-D convolutional encoder for the CQT input.

    Two 3x5 conv blocks (each followed by a 2x2 max-pool) and one 3x3 block
    are followed by global average pooling and a linear projection to
    ``hidden_dim``.
    """

    def __init__(self, in_channels: int = 1, hidden_dim: int = 256) -> None:
        super().__init__()
        wide_kernel = (3, 5)
        wide_padding = (1, 2)
        stages = [
            conv2d_block(in_channels, 32, kernel_size=wide_kernel, padding=wide_padding, dropout=0.1),
            nn.MaxPool2d((2, 2)),
            conv2d_block(32, 64, kernel_size=wide_kernel, padding=wide_padding, dropout=0.15),
            nn.MaxPool2d((2, 2)),
            conv2d_block(64, 128, kernel_size=(3, 3), padding=(1, 1), dropout=0.2),
        ]
        self.features = nn.Sequential(*stages)
        # Collapses both spatial axes so the embedding does not depend on input size.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.proj = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` into one ``hidden_dim``-sized vector per batch item."""
        feature_map = self.features(x)
        pooled = self.pool(feature_map)
        return self.proj(pooled)


class AttentionFusion(nn.Module):
    """Fuse branch embeddings with a learned softmax-weighted sum."""

    def __init__(self, embed_dim: int, attn_dim: int = 128, dropout: float = 0.1) -> None:
        super().__init__()
        self.proj = nn.Linear(embed_dim, attn_dim)
        self.score = nn.Linear(attn_dim, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, branch_embeddings: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(fused, weights)`` for the stacked branch embeddings.

        One scalar score is computed per branch and normalised with a softmax
        over the last remaining axis; the embeddings (after dropout) are then
        summed over dimension 1 using those weights.
        """
        # Scores come from the clean embeddings; dropout only touches the values being mixed.
        raw_scores = self.score(torch.tanh(self.proj(branch_embeddings)))
        weights = torch.softmax(raw_scores.squeeze(-1), dim=-1)
        weighted = self.dropout(branch_embeddings) * weights.unsqueeze(-1)
        return weighted.sum(dim=1), weights


@dataclass
class MultiBranchModelConfig:
    """Hyper-parameters of ``MultiBranchAttentionModel``."""

    embed_dim: int = 256  # size of each branch embedding and of the fused vector
    attn_dim: int = 128  # hidden size of the attention scorer
    num_classes: int = 2  # number of output logits
    classifier_hidden: int = 128  # hidden size of the classifier head
    dropout: float = 0.3  # used by both the fusion module and the classifier head


class MultiBranchAttentionModel(nn.Module):
    """Three-branch encoder whose embeddings are fused by attention and classified."""

    def __init__(self, config: MultiBranchModelConfig) -> None:
        super().__init__()
        self.config = config
        width = config.embed_dim
        self.branches = nn.ModuleDict(
            {
                "spectral": SpectralBranch(hidden_dim=width),
                "temporal": TemporalBranch(hidden_dim=width),
                "cepstral": CepstralBranch(hidden_dim=width),
            }
        )
        self.fusion = AttentionFusion(width, config.attn_dim, dropout=config.dropout)
        self.classifier = nn.Sequential(
            nn.Linear(width, config.classifier_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(config.dropout),
            nn.Linear(config.classifier_hidden, config.num_classes),
        )

    def forward(self, features: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run every branch, fuse the embeddings and classify the fused vector.

        Args:
            features: Mapping from branch name to that branch's input tensor.

        Returns:
            A dict with ``logits``, ``fused``, ``attention_weights``,
            ``branch_embeddings`` (branches stacked on dimension 1) and
            ``branch_order`` (branch names in stacking order).

        Raises:
            KeyError: If the input for a branch is missing.
        """
        embeddings = []
        order = []
        for name, encoder in self.branches.items():
            if name not in features:
                raise KeyError(f"Thiếu nhánh {name} trong input")
            embeddings.append(encoder(features[name]))
            order.append(name)

        # Stacking on a new axis 1 places the branches side by side per batch item.
        branch_stack = torch.stack(embeddings, dim=1)
        fused, weights = self.fusion(branch_stack)
        return {
            "logits": self.classifier(fused),
            "fused": fused,
            "attention_weights": weights,
            "branch_embeddings": branch_stack,
            "branch_order": order,
        }

## 7. Các hàm metric
Bao gồm accuracy và EER như trong `src/utils/metrics.py`.

In [None]:
def compute_accuracy(logits: Tensor, labels: Tensor) -> float:
    """Fraction of samples whose arg-max class (over dimension 1) equals the label.

    Args:
        logits: Per-class scores; classes are indexed along dimension 1.
        labels: Integer class labels, one per sample.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` when ``labels`` is empty.
    """
    total = labels.numel()
    if total == 0:
        # An empty batch has no meaningful accuracy; avoid dividing by zero.
        return 0.0
    predictions = torch.argmax(logits, dim=1)
    return (predictions == labels).sum().item() / total


def compute_eer(scores: Tensor, labels: Tensor) -> float:
    """Approximate the equal error rate of ``scores`` against binary ``labels``.

    A higher score means "more likely label 1". The decision threshold is
    swept over the distinct score values; at the threshold where the false
    accept rate and false reject rate are closest, their mean is returned.

    Args:
        scores: One score per sample.
        labels: One label per sample; 1 is the positive class, anything else
            is counted as negative.

    Returns:
        The EER estimate in ``[0, 1]``; ``0.0`` when only one class is present.
    """
    labels_np = labels.detach().cpu().numpy().astype(np.int32).reshape(-1)
    scores_np = scores.detach().cpu().numpy().reshape(-1)

    positives = labels_np.sum()
    negatives = labels_np.size - positives
    if positives == 0 or negatives == 0:
        return 0.0

    order = np.argsort(-scores_np, kind="stable")
    sorted_scores = scores_np[order]
    is_positive = labels_np[order] == 1

    # Running counts of accepted positives / negatives as the threshold drops.
    accepted_pos = np.cumsum(is_positive)
    accepted_neg = np.cumsum(~is_positive)

    # Samples with identical scores cannot be separated by any threshold, so
    # error rates are read only at the last sample of each run of equal scores.
    run_ends = np.flatnonzero(sorted_scores[1:] != sorted_scores[:-1])
    run_ends = np.append(run_ends, sorted_scores.size - 1)

    far = accepted_neg[run_ends] / negatives
    frr = (positives - accepted_pos[run_ends]) / positives
    closest = int(np.argmin(np.abs(far - frr)))
    return float((far[closest] + frr[closest]) / 2.0)


def aggregate_metrics(logits: Tensor, labels: Tensor) -> Dict[str, float]:
    """Compute accuracy and EER from raw logits.

    The softmax probability of class index 1 is used as the score for the EER.
    """
    class_one_scores = torch.softmax(logits, dim=-1)[:, 1]
    return {
        "accuracy": compute_accuracy(logits, labels),
        "eer": compute_eer(class_one_scores, labels),
    }

## 8. Trainer
Định nghĩa lớp Trainer gom từ `src/training/engine.py`.

In [None]:
@dataclass
class OptimizerConfig:
    """Arguments passed to ``torch.optim.AdamW`` by the Trainer."""

    lr: float = 1e-4
    weight_decay: float = 1e-5
    betas: Tuple[float, float] = (0.9, 0.98)
    eps: float = 1e-8


@dataclass
class SchedulerConfig:
    """Settings for the optional cosine-annealing learning-rate schedule."""

    use_cosine: bool = True  # False disables the scheduler entirely
    min_lr: float = 1e-6  # passed as eta_min
    t_max: Optional[int] = None  # passed as T_max; falls back to the epoch count when unset


@dataclass
class TrainingConfig:
    """Run-level settings consumed by the Trainer."""

    epochs: int = 50  # maximum number of epochs
    device: Optional[str] = None  # None selects "cuda" when available, else "cpu"
    log_interval: int = 20  # print the running loss every N training steps; 0 disables
    grad_clip: float = 5.0  # max gradient norm; values <= 0 disable clipping
    mixed_precision: bool = True  # only takes effect on a CUDA device
    checkpoint_dir: str = "checkpoints"  # created if missing
    best_metric: str = "eer"  # validation metric used for model selection
    patience: int = 10  # epochs without improvement before early stopping
    resume_from: Optional[str] = None  # NOTE(review): not read by the Trainer code in this section
    save_every: int = 0  # save a snapshot every N epochs; 0 disables
    history: List[Dict[str, float]] = field(default_factory=list)  # appended to by Trainer.fit
    evaluate_on_test: bool = False  # NOTE(review): not read by the Trainer code in this section


class Trainer:
    """Training / evaluation loop with AMP, gradient clipping, checkpoints and early stopping."""

    def __init__(
        self,
        model: nn.Module,
        train_config: TrainingConfig,
        optim_config: OptimizerConfig,
        scheduler_config: SchedulerConfig,
    ) -> None:
        """Move the model to the target device and set up the loss and gradient scaler."""
        self.model = model
        self.train_config = train_config
        self.optim_config = optim_config
        self.scheduler_config = scheduler_config
        # Fall back to CUDA when available if no device was configured.
        device_str = train_config.device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(device_str)
        self.model.to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        # The scaler is a no-op unless running on CUDA with mixed precision requested.
        self.scaler = torch.cuda.amp.GradScaler(
            enabled=(self.device.type == "cuda" and train_config.mixed_precision)
        )
        self.best_metric_value: Optional[float] = None
        self.best_epoch: Optional[int] = None

    def fit(self, datamodule: ASVspoofDataModule) -> Dict[str, List[Dict[str, float]]]:
        """Train for up to ``epochs`` epochs, validating after each one.

        A "best" checkpoint is written whenever the selected validation metric
        improves; training stops early after ``patience`` epochs without
        improvement.

        Returns:
            ``{"train": [...], "valid": [...]}`` with one metrics dict per epoch.

        Raises:
            KeyError: If the selected metric is absent from the validation metrics.
        """
        datamodule.setup(stage="fit")
        train_loader = datamodule.train_dataloader()
        valid_loader = datamodule.val_dataloader()
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.optim_config.lr,
            weight_decay=self.optim_config.weight_decay,
            betas=self.optim_config.betas,
            eps=self.optim_config.eps,
        )
        scheduler = self._build_scheduler(optimizer)
        os.makedirs(self.train_config.checkpoint_dir, exist_ok=True)
        history = {"train": [], "valid": []}
        patience_counter = 0
        for epoch in range(1, self.train_config.epochs + 1):
            train_metrics = self._run_epoch(
                loader=train_loader,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                train=True,
            )
            valid_metrics = self._run_epoch(
                loader=valid_loader,
                optimizer=None,
                scheduler=None,
                epoch=epoch,
                train=False,
            )
            history["train"].append(train_metrics)
            history["valid"].append(valid_metrics)
            # Flat per-epoch record; validation metrics get a "val_" prefix.
            self.train_config.history.append(
                {"epoch": epoch, **train_metrics, **{f"val_{k}": v for k, v in valid_metrics.items()}}
            )
            current_metric = valid_metrics.get(self.train_config.best_metric)
            if current_metric is None:
                raise KeyError(
                    f"Không tìm thấy metric {self.train_config.best_metric} trong valid metrics"
                )
            if self._is_better(current_metric):
                self.best_metric_value = current_metric
                self.best_epoch = epoch
                patience_counter = 0
                self._save_checkpoint(optimizer, epoch, best=True)
            else:
                patience_counter += 1
            if self.train_config.save_every > 0 and epoch % self.train_config.save_every == 0:
                self._save_checkpoint(optimizer, epoch, best=False)
            # The learning-rate schedule advances once per epoch, not per step.
            if scheduler is not None and getattr(scheduler, "step", None) is not None:
                scheduler.step()
            if patience_counter >= self.train_config.patience:
                print(f"[Trainer] Early stopping ở epoch {epoch}.")
                break
        return history

    def evaluate(self, datamodule: ASVspoofDataModule) -> Dict[str, float]:
        """Evaluate the model's current weights on the test partition."""
        datamodule.setup(stage="test")
        test_loader = datamodule.test_dataloader()
        return self._run_epoch(loader=test_loader, optimizer=None, scheduler=None, epoch=0, train=False)

    def _run_epoch(
        self,
        loader,
        optimizer,
        scheduler,
        epoch: int,
        train: bool,
    ) -> Dict[str, float]:
        """Run one pass over ``loader`` and return loss, accuracy and EER.

        With ``train=True`` the model is updated after every batch; otherwise
        gradients are disabled. ``scheduler`` is accepted but not used here.
        """
        self.model.train(mode=train)
        total_loss = 0.0
        total_samples = 0
        all_logits: List[Tensor] = []
        all_labels: List[Tensor] = []
        for step, batch in enumerate(loader, start=1):
            features = {
                name: tensor.to(self.device, non_blocking=True)
                for name, tensor in batch["features"].items()
            }
            labels = batch["labels"].to(self.device, non_blocking=True)
            batch_size = labels.size(0)
            with torch.set_grad_enabled(train):
                with torch.cuda.amp.autocast(enabled=self.scaler.is_enabled()):
                    outputs = self.model(features)
                    logits = outputs["logits"]
                    loss = self.criterion(logits, labels)
                if train:
                    self.scaler.scale(loss).backward()
                    if self.train_config.grad_clip > 0:
                        # Gradients must be unscaled before their norm is clipped.
                        self.scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.train_config.grad_clip
                        )
                    self.scaler.step(optimizer)
                    self.scaler.update()
                    optimizer.zero_grad(set_to_none=True)
            # Weight the batch loss by its size so the epoch average is per sample.
            total_loss += loss.item() * batch_size
            total_samples += batch_size
            all_logits.append(logits.detach().cpu())
            all_labels.append(labels.detach().cpu())
            if train and self.train_config.log_interval and step % self.train_config.log_interval == 0:
                current_loss = total_loss / total_samples
                print(
                    f"[Epoch {epoch}] Step {step}/{len(loader)} Loss: {current_loss:.4f} "
                    f"LR: {optimizer.param_groups[0]['lr']:.2e}"
                )
        avg_loss = total_loss / max(total_samples, 1)
        logits_tensor = torch.cat(all_logits, dim=0) if all_logits else torch.empty((0, 2))
        labels_tensor = torch.cat(all_labels, dim=0) if all_labels else torch.empty((0,), dtype=torch.long)
        # An empty loader yields zeroed metrics instead of calling the metric functions.
        metrics = (
            aggregate_metrics(logits_tensor, labels_tensor)
            if total_samples > 0
            else {"accuracy": 0.0, "eer": 0.0}
        )
        metrics["loss"] = avg_loss
        return metrics

    def _is_better(self, value: float) -> bool:
        """Return True when ``value`` beats the best value seen so far.

        "loss" and "eer" are minimised; every other metric is maximised.
        """
        if self.best_metric_value is None:
            return True
        if self.train_config.best_metric in {"loss", "eer"}:
            return value < self.best_metric_value
        return value > self.best_metric_value

    def _save_checkpoint(self, optimizer, epoch: int, best: bool) -> None:
        """Write model and optimizer state to ``checkpoint_dir``.

        The file is named ``checkpoint_best.pt`` for best checkpoints and
        ``checkpoint_epoch_NNN.pt`` for periodic snapshots.
        """
        state = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
            "best_metric": self.best_metric_value,
            "best_epoch": self.best_epoch,
        }
        suffix = "best" if best else f"epoch_{epoch:03d}"
        path = os.path.join(self.train_config.checkpoint_dir, f"checkpoint_{suffix}.pt")
        torch.save(state, path)
        tag = "BEST" if best else "SNAPSHOT"
        print(f"[Trainer] Đã lưu checkpoint ({tag}) tại {path}")

    def _build_scheduler(self, optimizer):
        """Return a cosine-annealing scheduler, or None when it is disabled."""
        if not self.scheduler_config.use_cosine:
            return None
        # Without an explicit t_max the cosine period spans the configured epoch count.
        t_max = self.scheduler_config.t_max or self.train_config.epochs
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=t_max,
            eta_min=self.scheduler_config.min_lr,
        )

## 9. Hàm hỗ trợ đọc YAML cấu hình (tuỳ chọn)
Giống `train.py` để dễ dàng load config từ file khi chạy trên Kaggle.

In [None]:
def load_yaml_config(path: str) -> Dict[str, Any]:
    """Parse the UTF-8 YAML file at ``path`` with the safe loader and return the result."""
    with open(path, "r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    return parsed


def build_feature_config(cfg: Dict[str, Any]) -> FeatureConfig:
    """Create a ``FeatureConfig`` by overlaying ``cfg`` on the defaults.

    Each of the ``spectral``, ``temporal`` and ``cepstral`` sections is merged
    key by key, so a partial section only overrides the keys it mentions.
    """
    defaults = FeatureConfig()
    merged = {
        section: {**getattr(defaults, section), **cfg.get(section, {})}
        for section in ("spectral", "temporal", "cepstral")
    }
    return FeatureConfig(
        sample_rate=cfg.get("sample_rate", defaults.sample_rate),
        spectral=SpectralConfig(**merged["spectral"]),
        temporal=TemporalConfig(**merged["temporal"]),
        cepstral=CepstralConfig(**merged["cepstral"]),
    )


def build_partition_config(cfg: Dict[str, Any]) -> PartitionConfig:
    """Build a ``PartitionConfig`` from one partition's dict.

    Args:
        cfg: Partition settings. ``partition`` is required (``KeyError`` when
            absent); ``protocol_file``, ``batch_size``, ``shuffle`` and
            ``drop_last`` are optional.

    Returns:
        The populated ``PartitionConfig``.
    """
    partition_name = cfg["partition"]
    optional_settings = {
        "protocol_file": cfg.get("protocol_file"),
        "batch_size": cfg.get("batch_size", 32),
        "shuffle": cfg.get("shuffle", True),
        "drop_last": cfg.get("drop_last", False),
    }
    return PartitionConfig(partition=partition_name, **optional_settings)


def build_data_module_config(cfg: Dict[str, Any]) -> DataModuleConfig:
    """Build the ``DataModuleConfig`` from the data section of a config dict.

    Args:
        cfg: Data settings. ``data_root``, ``train`` and ``valid`` are required
            (``KeyError`` when absent). ``feature`` and ``test`` are optional;
            a missing or falsy ``test`` entry results in ``test=None``.

    Returns:
        The populated ``DataModuleConfig``.
    """
    features = build_feature_config(cfg.get("feature", {}))
    train_part = build_partition_config(cfg["train"])
    valid_part = build_partition_config(cfg["valid"])
    if cfg.get("test"):
        test_part = build_partition_config(cfg["test"])
    else:
        test_part = None

    settings = {
        "data_root": cfg["data_root"],
        # The top-level sample rate falls back to the feature config's rate.
        "sample_rate": cfg.get("sample_rate", features.sample_rate),
        "max_duration": cfg.get("max_duration", 6.0),
        "pad_mode": cfg.get("pad_mode", "repeat"),
        "num_workers": cfg.get("num_workers", 4),
        "pin_memory": cfg.get("pin_memory", True),
        "prefetch_factor": cfg.get("prefetch_factor", 2),
        "feature": features,
        "train": train_part,
        "valid": valid_part,
        "test": test_part,
        "preload_waveforms": cfg.get("preload_waveforms", False),
    }
    return DataModuleConfig(**settings)


def build_model_config(cfg: Dict[str, Any]) -> MultiBranchModelConfig:
    """Build a ``MultiBranchModelConfig``, using defaults for keys absent from ``cfg``.

    Args:
        cfg: Model settings; any of ``embed_dim``, ``attn_dim``,
            ``num_classes``, ``classifier_hidden`` and ``dropout`` may be given.

    Returns:
        The populated ``MultiBranchModelConfig``.
    """
    defaults = MultiBranchModelConfig()
    option_names = ("embed_dim", "attn_dim", "num_classes", "classifier_hidden", "dropout")
    chosen = {name: cfg.get(name, getattr(defaults, name)) for name in option_names}
    return MultiBranchModelConfig(**chosen)


def build_training_config(cfg: Dict[str, Any]) -> TrainingConfig:
    """Build a ``TrainingConfig``, using defaults for keys absent from ``cfg``.

    Args:
        cfg: Training settings; every key is optional and falls back to the
            value on a default-constructed ``TrainingConfig``.

    Returns:
        The populated ``TrainingConfig``.
    """
    defaults = TrainingConfig()
    option_names = (
        "epochs",
        "device",
        "log_interval",
        "grad_clip",
        "mixed_precision",
        "checkpoint_dir",
        "best_metric",
        "patience",
        "resume_from",
        "save_every",
        "evaluate_on_test",
    )
    chosen = {name: cfg.get(name, getattr(defaults, name)) for name in option_names}
    return TrainingConfig(**chosen)


def build_optimizer_config(cfg: Dict[str, Any]) -> OptimizerConfig:
    """Build an ``OptimizerConfig``, using defaults for keys absent from ``cfg``.

    Numeric settings that arrive as strings are converted to ``float``.

    Args:
        cfg: Optimizer settings; ``lr``, ``weight_decay``, ``betas`` and
            ``eps`` are all optional.

    Returns:
        The populated ``OptimizerConfig``; ``betas`` is always a tuple.

    Raises:
        ValueError: If a string value cannot be parsed as a float.
    """

    def _number(value: Any) -> Any:
        # PyYAML (YAML 1.1) reads exponent literals without a decimal point,
        # e.g. ``1e-4``, as strings; convert those, leave other values alone.
        return float(value) if isinstance(value, str) else value

    base = OptimizerConfig()
    return OptimizerConfig(
        lr=_number(cfg.get("lr", base.lr)),
        weight_decay=_number(cfg.get("weight_decay", base.weight_decay)),
        betas=tuple(_number(beta) for beta in cfg.get("betas", base.betas)),
        eps=_number(cfg.get("eps", base.eps)),
    )


def build_scheduler_config(cfg: Dict[str, Any], total_epochs: int) -> SchedulerConfig:
    """Build a ``SchedulerConfig``, using defaults for keys absent from ``cfg``.

    Args:
        cfg: Scheduler settings; ``use_cosine``, ``min_lr`` and ``t_max`` are
            all optional.
        total_epochs: Used as ``t_max`` when neither ``cfg`` nor the default
            ``SchedulerConfig`` provides one.

    Returns:
        The populated ``SchedulerConfig``.

    Raises:
        ValueError: If ``min_lr`` is a string that cannot be parsed as a float.
    """
    base = SchedulerConfig()
    use_cosine = cfg.get("use_cosine", base.use_cosine)

    min_lr = cfg.get("min_lr", base.min_lr)
    if isinstance(min_lr, str):
        # PyYAML (YAML 1.1) reads exponent literals without a decimal point,
        # e.g. ``1e-6``, as strings; convert so the value is usable as a number.
        min_lr = float(min_lr)

    default_t_max = total_epochs if base.t_max is None else base.t_max
    t_max = cfg.get("t_max", default_t_max)
    return SchedulerConfig(use_cosine=use_cosine, min_lr=min_lr, t_max=t_max)

## 10. Ví dụ chạy nhanh
Cell dưới đây minh hoạ cách kết nối tất cả thành phần. Bạn cần cập nhật `DATA_ROOT` trỏ tới thư mục chứa dataset trên Kaggle (ví dụ `/kaggle/input/asvspoof2019`).

In [None]:
DATA_ROOT = "/kaggle/input/asvspoof2019"  # point this at the actual dataset location
USE_CONFIG = False  # set to True to build every config from a YAML file

if not USE_CONFIG:
    # Hand-written settings for a short demonstration run.
    data_cfg = DataModuleConfig(
        data_root=DATA_ROOT,
        train=PartitionConfig(partition="train"),
        valid=PartitionConfig(partition="dev"),
        test=PartitionConfig(partition="eval"),
        num_workers=2,
        preload_waveforms=False,
    )
    model_cfg = MultiBranchModelConfig()
    training_cfg = TrainingConfig(epochs=5, log_interval=5, evaluate_on_test=True)
    optimizer_cfg = OptimizerConfig(lr=1e-4)
    scheduler_cfg = SchedulerConfig(use_cosine=True, min_lr=1e-6)
else:
    # Every section except "data" is optional in the YAML file.
    CONFIG_PATH = "/kaggle/input/asvspoof-config/config.yaml"
    config_dict = load_yaml_config(CONFIG_PATH)
    data_cfg = build_data_module_config(config_dict["data"])
    model_cfg = build_model_config(config_dict.get("model", {}))
    training_cfg = build_training_config(config_dict.get("training", {}))
    optimizer_cfg = build_optimizer_config(config_dict.get("optimizer", {}))
    scheduler_cfg = build_scheduler_config(
        config_dict.get("scheduler", {}),
        total_epochs=training_cfg.epochs,
    )

# Wire the components together; nothing is trained until fit() is called.
datamodule = ASVspoofDataModule(data_cfg)
model = MultiBranchAttentionModel(model_cfg)
trainer = Trainer(model, training_cfg, optimizer_cfg, scheduler_cfg)

print("Pipeline đã sẵn sàng. Gọi trainer.fit(datamodule) để bắt đầu huấn luyện.")
print("Cảnh báo: việc huấn luyện đầy đủ cần nhiều thời gian và tài nguyên.")
# Uncomment the lines below to start training and, optionally, evaluate on the test split:
# history = trainer.fit(datamodule)
# if training_cfg.evaluate_on_test:
#     metrics = trainer.evaluate(datamodule)
#     print("Test metrics:", metrics)

---
**Gợi ý:** Sau khi chạy, bạn có thể lưu checkpoint tốt nhất từ thư mục `checkpoints/` về Kaggle Output bằng `shutil.copy` nếu cần nộp kết quả.