In [1]:
import torch

print("CUDA available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())

for i in range(torch.cuda.device_count()):
    print(f"[{i}] {torch.cuda.get_device_name(i)}")


CUDA available: True
GPU count: 1
[0] NVIDIA GeForce RTX 4090


In [2]:
import os
import math
import random
from collections import OrderedDict
from typing import Any, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Optional: for dataset loading
import torchaudio

In [3]:
import torch
torch.cuda.is_available()

True

In [4]:
# Optional (nice for shapes):
try:
    from torchinfo import summary
    HAS_TORCHINFO = True
except Exception:
    HAS_TORCHINFO = False


# ------------------ Simple config you can edit ------------------ #
# import os
import math
import random
from collections import OrderedDict
from typing import Any, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Optional: for dataset loading
import torchaudio

# Optional (nice for shapes):
try:
    from torchinfo import summary
    HAS_TORCHINFO = True
except Exception:
    HAS_TORCHINFO = False


# ------------------ Configurable Paths ------------------ #
class SysConfig:
    """
    Folder-based dataset structure.
    Each split (train/dev/test) contains two subfolders: bonafide and spoof.
    """
    path_train =r"G:\INTERSPEECH_26\LA\ASV19\train"
    path_dev   = r"G:\INTERSPEECH_26\LA\ASV19\dev"
    path_test  =r"G:\INTERSPEECH_26\LA\ASV19\dev"



# ------------------ Experiment Hyperparameters ------------------ #
class ExpConfig:
    # Audio processing
    sample_rate = 16000
    pre_emphasis = 0.97
    train_duration_sec = 4
    test_duration_sec = 4

    # Model
    transformer_hidden = 660

    # Training hyperparameters
    batch_size = 32
    lr = 8*1e-4
    epochs = 50  # increase as needed


# ------------------ Device Selection ------------------ #
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"‚úÖ Using device: {DEVICE}")



‚úÖ Using device: cuda


In [5]:
from torch.utils.data import Dataset
import torchaudio
import os
import torch
import random

class ASVspoofFolderDataset(Dataset):
    def __init__(self, root_dir, sample_rate=16000, duration_sec=4, augment=False):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.duration_sec = duration_sec
        self.augment = augment

        self.audio_paths = []
        self.labels = []

        # 1 = bonafide, 0 = spoof
        for label_name, label_value in [("bonafide", 1), ("spoof", 0)]:
            class_dir = os.path.join(root_dir, label_name)
            if os.path.exists(class_dir):
                for file in os.listdir(class_dir):
                    if file.endswith(".flac") or file.endswith(".wav"):
                        self.audio_paths.append(os.path.join(class_dir, file))
                        self.labels.append(label_value)

        print(f"Loaded {len(self.audio_paths)} files from {root_dir}")

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        path = self.audio_paths[idx]
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)

        wav, sr = torchaudio.load(path)
        if sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)

        # Random cropping or padding to fixed duration
        num_samples = self.sample_rate * self.duration_sec
        if wav.size(1) > num_samples:
            start = random.randint(0, wav.size(1) - num_samples)
            wav = wav[:, start:start + num_samples]
        elif wav.size(1) < num_samples:
            wav = torch.nn.functional.pad(wav, (0, num_samples - wav.size(1)))

        return wav.squeeze(0), label.squeeze(0)


In [6]:
# --- Self-contained augmentation cell (safe if torch_audiomentations isn't installed) ---
import torch
import torch.nn as nn

try:
    from torch_audiomentations import (
        Compose, AddColoredNoise, HighPassFilter, LowPassFilter, Gain
    )
    HAS_TA = True
except Exception:
    HAS_TA = False
    Compose = AddColoredNoise = HighPassFilter = LowPassFilter = Gain = None

class WaveformAugmentation(nn.Module):
    def __init__(self, aug_list=('ACN', 'HPF', 'LPF', 'GAN'), sr=16000):
        super().__init__()
        self.sr = sr
        if HAS_TA:
            transforms = []
            if 'ACN' in aug_list:
                transforms.append(AddColoredNoise(10, 40, -2.0, 2.0, p=0.5))
            if 'HPF' in aug_list:
                transforms.append(HighPassFilter(20.0, 2400.0, p=0.5))
            if 'LPF' in aug_list:
                transforms.append(LowPassFilter(150.0, 7500.0, p=0.5))
            if 'GAN' in aug_list:
                transforms.append(Gain(-15.0, 5.0, p=0.5))
            self.apply_augmentation = Compose(transforms) if transforms else None
        else:
            # No-op if torch_audiomentations isn't available
            self.apply_augmentation = None

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # wav: (B, T)
        if self.apply_augmentation is None:
            return wav
        return self.apply_augmentation(wav.unsqueeze(1), self.sr).squeeze(1)


In [7]:
class PreEmphasis(nn.Module):
    def __init__(self, pre_emphasis: float = 0.97):
        super().__init__()
        # Conv1D filter shape: (out_channels=1, in_channels=1, kernel_size=2)
        filt = torch.tensor([[-pre_emphasis, 1.0]], dtype=torch.float32).unsqueeze(0)
        self.register_buffer("filter", filt)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T)
        x = x.unsqueeze(1)  # (B,1,T)
        x = F.pad(x, (1, 0), mode="reflect")
        x = F.conv1d(x, self.filter)
        return x.squeeze(1)  # (B,T)


In [8]:
class SincConv(nn.Module):
    """
    Adapted from AASIST. One input channel only.
    """
    @staticmethod
    def to_mel(hz): return 2595 * np.log10(1 + hz / 700)
    @staticmethod
    def to_hz(mel): return 700 * (10**(mel / 2595) - 1)

    def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1, stride=1, padding=0, dilation=1):
        super().__init__()
        if in_channels != 1:
            raise ValueError("SincConv supports only one input channel.")
        self.out_channels = out_channels
        self.sample_rate = sample_rate
        self.kernel_size = kernel_size + (kernel_size % 2 == 0)

        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        NFFT = 512
        f = int(sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        filbandwidthsmel = np.linspace(fmel.min(), fmel.max(), out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)

        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)

        band_pass = torch.zeros(out_channels, self.kernel_size)
        for i in range(out_channels):
            fmin, fmax = filbandwidthsf[i], filbandwidthsf[i + 1]
            hHigh = (2 * fmax / sample_rate) * np.sinc(2 * fmax * self.hsupp / sample_rate)
            hLow  = (2 * fmin / sample_rate) * np.sinc(2 * fmin * self.hsupp / sample_rate)
            hideal = hHigh - hLow
            band_pass[i, :] = torch.tensor(np.hamming(self.kernel_size)) * torch.tensor(hideal)
        self.register_buffer("band_pass", band_pass)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B,1,T)
        filt = self.band_pass.to(x.device).view(self.out_channels, 1, self.kernel_size)
        return F.conv1d(x, filt, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)


class Conv2DBlock_S(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, is_first_block: bool=False):
        super().__init__()
        self.normalizer = None if is_first_block else nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.SELU(inplace=True)
        )
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(2,5), padding=(1,2), stride=1),
            nn.BatchNorm2d(out_channels),
            nn.SELU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=(2,3), padding=(0,1), stride=1),
        )
        self.downsampler = None
        if in_channels != out_channels:
            self.downsampler = nn.Conv2d(in_channels, out_channels, kernel_size=(1,3), padding=(0,1), stride=1)
        self.pooling = nn.MaxPool2d(kernel_size=(1,6))

    def forward(self, x):
        identity = x if self.downsampler is None else self.downsampler(x)
        if self.normalizer is not None:
            x = self.normalizer(x)
        x = self.layers(x)
        x = x + identity
        return self.pooling(x)


class Conv2DBlock_L(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, is_first_block: bool=False):
        super().__init__()
        self.normalizer = None if is_first_block else nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.SELU(inplace=True)
        )
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(2,3), padding=(1,1), stride=1),
            nn.BatchNorm2d(out_channels),
            nn.SELU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=(2,3), padding=(0,1), stride=1),
        )
        self.downsampler = None
        if in_channels != out_channels:
            self.downsampler = nn.Conv2d(in_channels, out_channels, kernel_size=(1,3), padding=(0,1), stride=1)
        self.pooling = nn.MaxPool2d(kernel_size=(1,3))

    def forward(self, x):
        identity = x if self.downsampler is None else self.downsampler(x)
        if self.normalizer is not None:
            x = self.normalizer(x)
        x = self.layers(x)
        x = x + identity
        return self.pooling(x)


class SELayer(nn.Module):
    def __init__(self, channels, channel_reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // channel_reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // channel_reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class Conv2DBlock_SE(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, scale: int = 8, channel_reduction: int = 8):
        super().__init__()
        self.scale = scale
        self.sub_channels = out_channels // scale
        self.hidden_channels = self.sub_channels * scale
        relu = nn.ReLU(inplace=True)

        self.normalizer = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.SELU(inplace=True),
        )

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, self.hidden_channels, kernel_size=(1,7), padding=(0,3)),
            nn.BatchNorm2d(self.hidden_channels),
            relu
        )

        self.conv2 = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(self.sub_channels, self.sub_channels, kernel_size=(3,9), padding=(1,4)),
                nn.BatchNorm2d(self.sub_channels),
                relu
            ) for _ in range(1, scale)
        ])

        self.conv3 = nn.Sequential(
            nn.Conv2d(self.hidden_channels, out_channels, kernel_size=(1,7), padding=(0,3)),
            nn.BatchNorm2d(out_channels),
            relu
        )

        self.se_module = SELayer(out_channels, channel_reduction)

        self.downsampler = None
        if in_channels != out_channels:
            self.downsampler = nn.Conv2d(in_channels, out_channels, kernel_size=(1,7), padding=(0,3), stride=1)

        self.pooling = nn.MaxPool2d(kernel_size=(1,6))

    def forward(self, x):
        identity = x if self.downsampler is None else self.downsampler(x)
        x = self.normalizer(x)
        x = self.conv1(x)

        x_split = torch.split(x, self.sub_channels, dim=1)
        y_list = [x_split[0]]
        for i in range(1, self.scale):
            inp = x_split[i] if i == 1 else x_split[i] + y_list[i-1]
            y_list.append(self.conv2[i-1](inp))

        y = torch.cat(y_list, dim=1)
        y = self.conv3(y)
        y = self.se_module(y)
        y = y + identity
        return self.pooling(y)


class Frontend_S(nn.Module):
    def __init__(self, sinc_kernel_size=128, sample_rate=16000):
        super().__init__()
        self.sinc = SincConv(70, sinc_kernel_size, sample_rate)
        self.bn = nn.BatchNorm2d(1)
        self.act = nn.SELU(inplace=True)
        self.blocks = nn.Sequential(
            Conv2DBlock_S(1, 32, is_first_block=True),
            Conv2DBlock_S(32, 32),
            Conv2DBlock_S(32, 64),
            Conv2DBlock_S(64, 64),
        )

    def forward(self, x):
        # x: (B,T) -> (B,1,T) -> sinc -> (B,70,T') -> (B,1,70,T')
        x = x.unsqueeze(1)
        x = self.sinc(x)
        x = x.unsqueeze(1)
        x = F.max_pool2d(torch.abs(x), (3,3))
        x = self.bn(x)
        lfm = self.act(x)
        return self.blocks(lfm)   # HFM: (B,C=64, f‚âà23, t‚âà16)


class Frontend_L(nn.Module):
    def __init__(self, sinc_kernel_size=128, sample_rate=16000):
        super().__init__()
        self.sinc = SincConv(70, sinc_kernel_size, sample_rate)
        self.bn = nn.BatchNorm2d(1)
        self.act = nn.SELU(inplace=True)
        self.blocks = nn.Sequential(
            Conv2DBlock_L(1, 32, is_first_block=True),
            Conv2DBlock_L(32, 32),
            Conv2DBlock_L(32, 64),
            Conv2DBlock_L(64, 64),
            Conv2DBlock_L(64, 64),
            Conv2DBlock_L(64, 64),
        )

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.sinc(x)
        x = x.unsqueeze(1)
        x = F.max_pool2d(torch.abs(x), (3,3))
        x = self.bn(x)
        lfm = self.act(x)
        return self.blocks(lfm)   # (B,64, f‚âà23, t‚âà29 for 4s @16kHz)


In [9]:
# --- Master imports cell (run first) ---
import os, math, random
from typing import Optional, Tuple, Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Nice-to-have (optional)
try:
    from torchinfo import summary
    HAS_TORCHINFO = True
except Exception:
    HAS_TORCHINFO = False

class ScaledDotProductAttention(nn.Module):
    """
    Expects Q,K,V: (B, H, S, D). Optional mask: (B,1,1,S) or broadcastable.
    """
    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, mask: Optional[torch.Tensor] = None):
        assert Q.dim() == K.dim() == V.dim() == 4  # (B,H,S,D)
        d_k = K.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)  # (B,H,S_q,S_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = self.softmax(scores)
        out = torch.matmul(attn, V)  # (B,H,S_q,D)
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_head: int):
        super().__init__()
        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        self.n_head = n_head
        self.d_head = d_model // n_head

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.attn = ScaledDotProductAttention()
        self.W_out = nn.Linear(d_model, d_model)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B,S,D) -> (B,H,S,Dh)
        B, S, D = x.size()
        x = x.view(B, S, self.n_head, self.d_head).permute(0, 2, 1, 3)
        return x

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B,H,S,Dh) -> (B,S,D)
        B, H, S, Dh = x.size()
        return x.permute(0, 2, 1, 3).contiguous().view(B, S, H * Dh)

    def forward(self, Q, K, V, mask: Optional[torch.Tensor] = None):
        assert Q.dim() == K.dim() == V.dim() == 3  # (B,S,D)
        q = self._split_heads(self.W_Q(Q))
        k = self._split_heads(self.W_K(K))
        v = self._split_heads(self.W_V(V))
        if mask is not None:
            # make mask broadcastable to (B,H,S_q,S_k)
            mask = mask.unsqueeze(1)
        context = self.attn(q, k, v, mask=mask)
        context = self._merge_heads(context)
        return self.W_out(context)


class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)
        xhat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * xhat + self.beta


class FFN(nn.Module):
    def __init__(self, d_model, ffn_hidden, drop_prob=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, ffn_hidden),
            nn.ReLU(),
            nn.Dropout(drop_prob),
            nn.Linear(ffn_hidden, d_model),
        )

    def forward(self, x):
        return self.net(x)


class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=64, n_head=8, ffn_hidden=2048, drop_prob=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_head)
        self.dropout1 = nn.Dropout(drop_prob)
        self.norm1 = LayerNorm(d_model)
        self.ffn = FFN(d_model, ffn_hidden, drop_prob)
        self.dropout2 = nn.Dropout(drop_prob)
        self.norm2 = LayerNorm(d_model)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        # x: (B,S,D)
        residual = x
        x = self.attn(x, x, x, mask=attn_mask)
        x = self.dropout1(x)
        x = self.norm1(x + residual)

        residual = x
        x = self.ffn(x)
        x = self.dropout2(x)
        x = self.norm2(x + residual)
        return x


In [10]:
import torch
import torch.nn as nn


class PositionalAggregator1D(nn.Module):
    
    def __init__(self, max_C:int, max_ft:int, device):
        """_summary_

        Args:
            max_channels (int): for HFM, max size of C
            max_ft (int): for HFM, max size of f*t
        """
        super(PositionalAggregator1D, self).__init__()
        
        self.flattener = nn.Flatten(start_dim=-2, end_dim=-1)
        
        # ------------------ positional encoding -------------------- #        
        x = torch.arange(1, max_ft-1, device=device).float()
        x = x.float().unsqueeze(1)
        _2i = torch.arange(0, max_C, step=2, device=device).float().unsqueeze(0)
        
        self.encoding = torch.zeros(max_ft, max_C, device=device, requires_grad=False)
        self.encoding[1:-1, 0::2] = torch.sin(x / (10000 ** (_2i / max_C)))
        self.encoding[1:-1, 1::2] = torch.cos(x / (10000 ** (_2i / max_C)))
        
    def forward(self, HFM):
        batch, C, f, t = HFM.shape
        out = self.flattener(HFM).transpose(1, 2)# [batch, f*t, C]
        out = out + self.encoding[:f*t, :C].unsqueeze(0)
        return out
    

In [11]:
class PositionalAggregator2D(nn.Module):
    """
    Adds 2D sinusoidal PE over (f, t) before flattening to sequence.
    Input HFM: (B, C, F, T). Output: (B, F*T, C)
    """
    def __init__(self, max_C: int, max_F: int, max_T: int):
        super().__init__()
        self.max_C = max_C
        self.max_F = max_F
        self.max_T = max_T

        # Create 1D sin/cos for F and T, then combine
        pe_f = torch.zeros(max_F, max_C)
        pe_t = torch.zeros(max_T, max_C)

        pos_f = torch.arange(0, max_F, dtype=torch.float).unsqueeze(1)
        pos_t = torch.arange(0, max_T, dtype=torch.float).unsqueeze(1)

        div_term = torch.exp(torch.arange(0, max_C, 2).float() * (-math.log(10000.0) / max_C))
        pe_f[:, 0::2] = torch.sin(pos_f * div_term)
        pe_f[:, 1::2] = torch.cos(pos_f * div_term)

        pe_t[:, 0::2] = torch.sin(pos_t * div_term)
        pe_t[:, 1::2] = torch.cos(pos_t * div_term)

        # Combine as a simple sum of broadcasted (f and t) encodings
        # resulting PE shape for each (f,t): (C,)
        pe_2d = pe_f.unsqueeze(1) + pe_t.unsqueeze(0)  # (F,1,C) + (1,T,C) -> (F,T,C)
        self.register_buffer("pe_2d", pe_2d)  # (F,T,C)

    def forward(self, HFM: torch.Tensor) -> torch.Tensor:
        # HFM: (B, C, F, T)
        B, C, F, T = HFM.shape
        assert F <= self.max_F and T <= self.max_T and C <= self.max_C

        # permute to (B,F,T,C)
        x = HFM.permute(0, 2, 3, 1).contiguous()
        # add PE (F,T,C) -> broadcast to (B,F,T,C)
        x = x + self.pe_2d[:F, :T, :C]
        # flatten to (B, F*T, C)
        x = x.view(B, F * T, C)
        return x


In [12]:
class Conv2DBlock_S(nn.Module):
    """__summary__
    This is Conv2DBlock of Rawformer-S.\\
    This block is same as ResNet block of AASIST with some different parameters.
    (https://github.com/clovaai/aasist/blob/a04c9863f63d44471dde8a6abcb3b082b07cd1d1/models/AASIST.py#L413)
    """
    
    def __init__(self, in_channels: int, out_channels: int, is_first_block: bool=False):
        """_summary_

        Args:
            in_channels (int): num of input channels
            out_channels (int): num of output channels
            se_reduction (int, optional): reduction factor for squeeze and excitation of channels. Defaults to 8.
            is_first_block (bool, optional): if this is the first block must be True. Defaults to False.
        """
        
        super(Conv2DBlock_S, self).__init__()
        
        self.normalizer = None
        if not is_first_block:
            self.normalizer = nn.Sequential(
                nn.BatchNorm2d(num_features=in_channels),
                nn.SELU(inplace=True)
            )
        
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(2, 5), padding=(1, 2), stride=1),
            nn.BatchNorm2d(num_features=out_channels),
            nn.SELU(inplace=True),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(2, 3), padding=(0, 1), stride=1),
        )        
        
        self.downsampler = None
        if in_channels != out_channels:
            self.downsampler = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, padding=(0, 1), kernel_size=(1, 3), stride=1)
            )
        
        self.pooling = nn.MaxPool2d(kernel_size=(1, 6))
        
    def forward(self, x):
        
        identity = x
        if self.downsampler is not None:
            identity = self.downsampler(identity)
        
        if self.normalizer is not None:
            x = self.normalizer(x)
            
        x = self.layers(x)
        x = x + identity
        
        x = self.pooling(x)
        return x

class Conv2DBlock_L(nn.Module):
    """__summary__
    This is Conv2DBlock of Rawformer-L.\\
    This block is same as ResNet block of AASIST.
    (https://github.com/clovaai/aasist/blob/a04c9863f63d44471dde8a6abcb3b082b07cd1d1/models/AASIST.py#L413)
    """
    
    def __init__(self, in_channels: int, out_channels: int, is_first_block: bool=False):
        """_summary_

        Args:
            in_channels (int): num of input channels
            out_channels (int): num of output channels
            se_reduction (int, optional): reduction factor for squeeze and excitation of channels. Defaults to 8.
            is_first_block (bool, optional): if this is the first block must be True. Defaults to False.
        """
        
        super(Conv2DBlock_L, self).__init__()
        
        self.normalizer = None
        if not is_first_block:
            self.normalizer = nn.Sequential(
                nn.BatchNorm2d(num_features=in_channels),
                nn.SELU(inplace=True)
            )
        
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(2, 3), padding=(1, 1), stride=1),
            nn.BatchNorm2d(num_features=out_channels),
            nn.SELU(inplace=True),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(2, 3), padding=(0, 1), stride=1),
        )        
        
        self.downsampler = None
        if in_channels != out_channels:
            self.downsampler = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, padding=(0, 1), kernel_size=(1, 3), stride=1)
            )
        
        self.pooling = nn.MaxPool2d(kernel_size=(1, 3))
        
    def forward(self, x):
        
        identity = x
        if self.downsampler is not None:
            identity = self.downsampler(identity)
        
        if self.normalizer is not None:
            x = self.normalizer(x)
            
        x = self.layers(x)
        x = x + identity
        
        x = self.pooling(x)
        return x
    
class SELayer(nn.Module):
    def __init__(self, channels, channel_reduction=8):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Linear(channels, channels // channel_reduction),
                nn.ReLU(inplace=True),
                nn.Linear(channels // channel_reduction, channels),
                nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y
        
class Conv2DBlock_SE(nn.Module):
    
    def __init__(self, in_channels: int, out_channels: int, scale:int = 8, channel_reduction:int=8):
        super(Conv2DBlock_SE, self).__init__()
        
        self.scale = scale
        self.sub_channels = out_channels // scale
        self.hidden_channels = self.sub_channels * scale
        relu = nn.ReLU(inplace=True)
        
        
        self.normalizer = nn.Sequential(
            nn.BatchNorm2d(num_features=in_channels),
            nn.SELU(inplace=True)
        )
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=self.hidden_channels, kernel_size=(1, 7), padding=(0, 3)),
            nn.BatchNorm2d(num_features=self.hidden_channels),
            relu
        )
        
        self.conv2 = []
        for i in range(2, scale+1):
            self.conv2.append(nn.Sequential(
                nn.Conv2d(in_channels=self.sub_channels, out_channels=self.sub_channels, kernel_size=(3, 9), padding=(1, 4)),
                nn.BatchNorm2d(num_features=self.sub_channels),
                relu
            ))
        self.conv2 = nn.ModuleList(self.conv2)
            
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=self.hidden_channels, out_channels=out_channels, kernel_size=(1, 7), padding=(0, 3)),
            nn.BatchNorm2d(num_features=out_channels),
            relu
        )
        
        self.se_module = SELayer(channels=out_channels, channel_reduction=channel_reduction)
        
        self.downsampler = None
        if in_channels != out_channels:
            self.downsampler = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, padding=(0, 3), kernel_size=(1, 7), stride=1)
            )
            
        self.pooling = nn.MaxPool2d(kernel_size=(1, 6))
        #self.pooling = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1, 6), padding=(0, 0), )
        
        
    def forward(self, x):
        
        identity = x
        if self.downsampler is not None:
            identity = self.downsampler(identity)
            
        x = self.normalizer(x)
        
        x = self.conv1(x)
        
        x_sub = torch.split(x, split_size_or_sections=self.sub_channels, dim = 1)
        y_sub = [x_sub[0]]
        
        for i in range(1, self.scale):
            y_i = None
            if i == 1:
                y_i = self.conv2[i - 1](x_sub[i])
            else:
                y_i = self.conv2[i - 1](x_sub[i] + y_sub[i-1])
                
            y_sub.append(y_i)
        
        y = torch.cat(y_sub, dim = 1)
        y = self.conv3(y)
        y = self.se_module(y)
        
        y = y + identity
        y = self.pooling(y)
        
        return y
        
        
        
    
class Frontend_S(nn.Module):
    """_summary_
    This is frontend of Rawformer-S
    """
    
    def __init__(self, device, sinc_kernel_size=128, sample_rate=16000):
        """_summary_
        frontend of Rawformer-S\\
            
        N: number of conv2D-based blocks\\
        N is fixed to 4.
        
        C: output channel of front-end\\
        C is fixed to 64
        
        f: frequency \\
        f is fixed to 23
        
        t: number of temporal bins\\
        for 4 sec, t is 16. for 10 sec, t is 73\\
        
        Args:
            sinc_kernel_size (int, optional): kernel size of sinc layer. Defaults to 128.
            sample_rate (int, optional): _description_. Defaults to 16000.
        """
        super(Frontend_S, self).__init__()
        
        self.sinc_layer = SincConv(in_channels=1, out_channels=70, kernel_size=sinc_kernel_size, sample_rate=sample_rate)
        self.bn = nn.BatchNorm2d(num_features=1) 
        self.selu = nn.SELU(inplace=True)
        
        self.conv_blocks = nn.Sequential(
            Conv2DBlock_S(in_channels=1, out_channels=32, is_first_block=True),
            Conv2DBlock_S(in_channels=32, out_channels=32),
            Conv2DBlock_S(in_channels=32, out_channels=64),
            Conv2DBlock_S(in_channels=64, out_channels=64),            
        )
    
    def forward(self, x):
        
        x = x.unsqueeze(dim=1)
        x = self.sinc_layer(x)
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.bn(x)
        LFM = self.selu(x)
        
        HFM = self.conv_blocks(LFM)
        
        return HFM

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------------------------------------------------
# 1. Frontend_S
# ----------------------------------------------------------------------
class Frontend_S(nn.Module):
    def __init__(self, device, sinc_kernel_size=128, sample_rate=16000):
        super().__init__()

        # ---- Sinc layer (no parameters ‚Üí safe on any device) ----
        self.sinc_layer = SincConv(
            in_channels=1,
            out_channels=70,
            kernel_size=sinc_kernel_size,
            sample_rate=sample_rate,
        )

        # ---- BatchNorm that must live on the target device ----
        self.bn = nn.BatchNorm2d(num_features=1).to(device)

        self.selu = nn.SELU(inplace=True)

        # ---- Conv blocks (they also contain BatchNorms) ----
        self.conv_blocks = nn.Sequential(
            Conv2DBlock_S(in_channels=1,  out_channels=32, is_first_block=True),
            Conv2DBlock_S(in_channels=32, out_channels=32),
            Conv2DBlock_S(in_channels=32, out_channels=64),
            Conv2DBlock_S(in_channels=64, out_channels=64),
        ).to(device)                     # <-- move the whole Sequential

    def forward(self, x):
        # x : [B, T]  (raw waveform)
        x = x.unsqueeze(1)                     # [B,1,T]
        x = self.sinc_layer(x)                 # [B,70,T']
        x = x.unsqueeze(1)                     # [B,1,70,T']
        x = F.max_pool2d(torch.abs(x), (3, 3)) # [B,1,F,T]
        x = self.bn(x)
        LFM = self.selu(x)

        HFM = self.conv_blocks(LFM)            # [B,64,f,t]
        return HFM


# ----------------------------------------------------------------------
# 2. Conv2DBlock_S
# ----------------------------------------------------------------------
class Conv2DBlock_S(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, is_first_block: bool = False):
        super().__init__()

        # ---- optional normaliser (BN+SELU) ----
        self.normalizer = None
        if not is_first_block:
            self.normalizer = nn.Sequential(
                nn.BatchNorm2d(in_channels),
                nn.SELU(inplace=True),
            )

        # ---- two conv layers + BN+SELU ----
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(2, 5), padding=(1, 2)),
            nn.BatchNorm2d(out_channels),
            nn.SELU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=(2, 3), padding=(0, 1)),
        )

        # ---- residual connection when channel count changes ----
        self.downsampler = None
        if in_channels != out_channels:
            self.downsampler = nn.Conv2d(in_channels, out_channels,
                                        kernel_size=(1, 3), padding=(0, 1))

        self.pooling = nn.MaxPool2d(kernel_size=(1, 6))

    def forward(self, x):
        identity = x
        if self.downsampler is not None:
            identity = self.downsampler(identity)

        if self.normalizer is not None:
            x = self.normalizer(x)

        x = self.layers(x) + identity
        x = self.pooling(x)
        return x


# ----------------------------------------------------------------------
# 3. PositionalAggregator1D
# ----------------------------------------------------------------------
class PositionalAggregator1D(nn.Module):
    def __init__(self, max_C: int, max_ft: int, device):
        super().__init__()

        self.flattener = nn.Flatten(start_dim=-2, end_dim=-1)

        # ----- sinusoidal positional encoding (no trainable params) -----
        pos = torch.arange(1, max_ft - 1, device=device).float().unsqueeze(1)   # (L-2,1)
        dim = torch.arange(0, max_C, step=2, device=device).float().unsqueeze(0)  # (1,D/2)

        enc = torch.zeros(max_ft, max_C, device=device)
        enc[1:-1, 0::2] = torch.sin(pos / (10000 ** (dim / max_C)))
        enc[1:-1, 1::2] = torch.cos(pos / (10000 ** (dim / max_C)))
        self.register_buffer('encoding', enc)   # stored on the correct device automatically

    def forward(self, HFM):
        """
        HFM : [B, C, f, t]
        out : [B, f*t, C]  with added positional encoding
        """
        B, C, f, t = HFM.shape
        ft = f * t
        out = self.flattener(HFM).transpose(1, 2)               # [B, f*t, C]
        out = out + self.encoding[:ft, :C]                      # broadcast
        return out

In [14]:
# class Rawformer_S(nn.Module):
#     """
    
#     """
    
#     def __init__(self, device, transformer_hidden=64, sample_rate: int = 16000):
#         super(Rawformer_S, self).__init__()
#         self.front_end = Frontend_S(sinc_kernel_size=128, sample_rate=sample_rate)
        
#         self.positional_embedding = PositionalAggregator1D(max_C = 64, max_ft=23*16, device=device)# this max_ft is for input of 4-sec and 16000 sample-rate
        
#         self.classifier = RawformerClassifier(C = 64, n_encoder = 2, transformer_hidden=transformer_hidden)# output: [batch, C]
        
#     def forward(self, x):        
#         x = self.front_end(x)
#         x = self.positional_embedding(x)        
#         x = self.classifier(x)        
#         return x
class Rawformer_S(nn.Module):
    def __init__(self, device, transformer_hidden=64, sample_rate: int = 16000):
        super().__init__()
        # ---- 1. give the front-end the device ----
        self.front_end = Frontend_S(sinc_kernel_size=128,
                                    sample_rate=sample_rate,
                                    device=device)          # <-- add this

        self.positional_embedding = PositionalAggregator1D(
            max_C=64, max_ft=23*16, device=device)

        self.classifier = RawformerClassifier(C=64, n_encoder=2, transformer_hidden=transformer_hidden)

        # ---- 2. move *everything* to the target device in one go ----
        self.to(device)                     # <-- important!

    def forward(self, x):
        x = self.front_end(x)               # now on correct device
        x = self.positional_embedding(x)
        x = self.classifier(x)
        return x

In [15]:
class SequencePooling(nn.Module):
    """
    Attention-style weighted pooling over sequence.
    Input: (B,S,C) -> Output: (B,C)
    """
    def __init__(self, d_model):
        super().__init__()
        self.linear = nn.Linear(d_model, 1)

    def forward(self, x):
        # x: (B,S,C)
        w = self.linear(x)               # (B,S,1)
        w = F.softmax(w.transpose(1, 2), dim=-1)  # (B,1,S)
        out = torch.matmul(w, x)         # (B,1,C)
        return out.squeeze(1)            # (B,C)


class RawformerClassifier(nn.Module):
    """
    Encoders (N layers) + SeqPool + Linear + Sigmoid
    Input: sequence (B,S,C)  Output: (B,) score in [0,1]
    """
    def __init__(self, C: int, n_encoder: int, transformer_hidden: int):
        super().__init__()
        self.encoders = nn.Sequential(OrderedDict([
            (f"encoder{i}", TransformerEncoderLayer(d_model=C, n_head=8, ffn_hidden=transformer_hidden))
            for i in range(n_encoder)
        ]))
        self.seq_pool = SequencePooling(d_model=C)
        self.fc = nn.Linear(C, 1)

    def forward(self, x):
        # x: (B,S,C)
        x = self.encoders(x)
        x = self.seq_pool(x)
        x = self.fc(x)
        return torch.sigmoid(x).squeeze(-1)   # (B,)


In [16]:
class Rawformer_S_2DPE(nn.Module):
    def __init__(self, transformer_hidden=64, sample_rate: int = 16000):
        super().__init__()
        self.front_end = Frontend_S(sinc_kernel_size=128, sample_rate=sample_rate)
        # For 4s @ 16kHz, Frontend_S -> approx F=23, T=16 (as in your notes)
        self.pos_emb = PositionalAggregator2D(max_C=64, max_F=23, max_T=16)
        self.classifier = RawformerClassifier(C=64, n_encoder=2, transformer_hidden=transformer_hidden)

    def forward(self, x):
        # x: (B,T)
        hfm = self.front_end(x)         # (B,64,F,T)
        seq = self.pos_emb(hfm)         # (B,F*T,64)
        return self.classifier(seq)     # (B,)


class Rawformer_L_2DPE(nn.Module):
    def __init__(self, transformer_hidden=80, sample_rate: int = 16000):
        super().__init__()
        self.front_end = Frontend_L(sinc_kernel_size=128, sample_rate=sample_rate)
        # For 4s @ 16kHz, Frontend_L -> approx F=23, T=29
        self.pos_emb = PositionalAggregator2D(max_C=64, max_F=23, max_T=29)
        self.classifier = RawformerClassifier(C=64, n_encoder=3, transformer_hidden=transformer_hidden)

    def forward(self, x):
        hfm = self.front_end(x)         # (B,64,F,T)
        seq = self.pos_emb(hfm)         # (B,F*T,64)
        return self.classifier(seq)     # (B,)


In [17]:
class Rawformer_S_2DPE(nn.Module):
    def __init__(self, transformer_hidden=64, sample_rate: int = 16000):
        super().__init__()
        self.front_end = Frontend_S(sinc_kernel_size=128, sample_rate=sample_rate)
        # For 4s @ 16kHz, Frontend_S -> approx F=23, T=16 (as in your notes)
        self.pos_emb = PositionalAggregator2D(max_C=64, max_F=23, max_T=16)
        self.classifier = RawformerClassifier(C=64, n_encoder=2, transformer_hidden=transformer_hidden)

    def forward(self, x):
        # x: (B,T)
        hfm = self.front_end(x)         # (B,64,F,T)
        seq = self.pos_emb(hfm)         # (B,F*T,64)
        return self.classifier(seq)     # (B,)


class Rawformer_L_2DPE(nn.Module):
    def __init__(self, transformer_hidden=80, sample_rate: int = 16000):
        super().__init__()
        self.front_end = Frontend_L(sinc_kernel_size=128, sample_rate=sample_rate)
        # For 4s @ 16kHz, Frontend_L -> approx F=23, T=29
        self.pos_emb = PositionalAggregator2D(max_C=64, max_F=23, max_T=29)
        self.classifier = RawformerClassifier(C=64, n_encoder=3, transformer_hidden=transformer_hidden)

    def forward(self, x):
        hfm = self.front_end(x)         # (B,64,F,T)
        seq = self.pos_emb(hfm)         # (B,F*T,64)
        return self.classifier(seq)     # (B,)


In [18]:
class ASVspoof2021LA_eval(torch.utils.data.Dataset):
    def __init__(self, sys_config=SysConfig(), exp_config=ExpConfig()):
        super().__init__()
        self.duration = int(exp_config.test_duration_sec * exp_config.sample_rate)
        path_label = sys_config.path_label_asv_spoof_2021_la_eval
        path_eval  = sys_config.path_asv_spoof_2021_la_eval

        self.data_list = []
        if os.path.exists(path_label):
            for line in open(path_label, "r").readlines():
                parts = line.strip().split()
                if len(parts) < 8 or parts[7] != "eval":
                    continue
                file_id, attack_type = parts[1], parts[4]
                label = 0 if attack_type == "bonafide" else 1
                wav_path = os.path.join(path_eval, f"{file_id}.flac")
                self.data_list.append((wav_path, attack_type, label))
        else:
            print(f"[WARN] Label file missing: {path_label}")

    def __len__(self): return len(self.data_list)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        wav_path, _, label = self.data_list[idx]
        wav, _ = torchaudio.load(wav_path)  # (C,T)
        wav = self._fix_duration(wav)
        return wav, label

    def _fix_duration(self, x: torch.Tensor) -> torch.Tensor:
        x = x.squeeze(0) if x.dim() == 2 else x
        need = self.duration
        if x.numel() < need:
            reps = need // x.numel()
            rem  = need %  x.numel()
            x = torch.cat([x] * reps + ([x[:rem]] if rem > 0 else []))
        return x[:need]


class ASVspoof2019LA(torch.utils.data.Dataset):
    def __init__(self, sys_config=SysConfig(), exp_config=ExpConfig(), augment: bool = False):
        super().__init__()
        self.duration = int(exp_config.train_duration_sec * exp_config.sample_rate)
        self.aug = WaveformAugmentation(sr=exp_config.sample_rate) if augment else None

        train_label = sys_config.path_label_asv_spoof_2019_la_train
        dev_label   = sys_config.path_label_asv_spoof_2019_la_dev
        path_train  = sys_config.path_asv_spoof_2019_la_train
        path_dev    = sys_config.path_asv_spoof_2019_la_dev

        self.data_list = []
        if os.path.exists(train_label):
            for line in open(train_label, "r").readlines():
                parts = line.strip().split()
                file_id, attack_type = parts[1], parts[3]
                label = 0 if parts[4] == "bonafide" else 1
                wav_path = os.path.join(path_train, f"{file_id}.flac")
                self.data_list.append((wav_path, attack_type, label))
        if os.path.exists(dev_label):
            for line in open(dev_label, "r").readlines():
                parts = line.strip().split()
                file_id, attack_type = parts[1], parts[3]
                label = 0 if parts[4] == "bonafide" else 1
                wav_path = os.path.join(path_dev, f"{file_id}.flac")
                self.data_list.append((wav_path, attack_type, label))
        if not self.data_list:
            print("[WARN] No 2019 LA data found. Check your paths in SysConfig.")

    def __len__(self): return len(self.data_list)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        wav_path, _, label = self.data_list[idx]
        wav, _ = torchaudio.load(wav_path)
        wav = self._random_duration(wav)
        if self.aug is not None:
            wav = self.aug(wav)
        return wav, label

    def _random_duration(self, x: torch.Tensor) -> torch.Tensor:
        x = x.squeeze(0) if x.dim() == 2 else x
        L = x.numel()
        need = self.duration
        if L < need:
            reps = need // L
            rem  = need %  L
            x = torch.cat([x] * reps + ([x[:rem]] if rem > 0 else []))
            L = x.numel()
        start = random.randint(0, L - need)
        return x[start:start+need]


In [19]:
def collate_pad(batch):
    # Here all items are same length already; just stack.
    wavs, labels = zip(*batch)
    wavs = torch.stack(wavs, dim=0)
    labels = torch.tensor(labels, dtype=torch.float32)
    return wavs, labels


def train_one_epoch(model, loader, optimizer, criterion, preemph=None):
    model.train()
    total_loss = 0.0
    for wav, label in loader:
        wav = wav.to(DEVICE)
        label = label.to(DEVICE)

        if preemph is not None:
            wav = preemph(wav)

        optimizer.zero_grad()
        pred = model(wav)              # (B,)
        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * wav.size(0)
    return total_loss / len(loader.dataset)


@torch.no_grad()
def evaluate(model, loader, criterion, preemph=None):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    for wav, label in loader:
        wav = wav.to(DEVICE)
        label = label.to(DEVICE)
        if preemph is not None:
            wav = preemph(wav)
        pred = model(wav)
        loss = criterion(pred, label)
        total_loss += loss.item() * wav.size(0)
        total_correct += ((pred > 0.5).float() == label).sum().item()
    avg_loss = total_loss / len(loader.dataset)
    acc = total_correct / len(loader.dataset)
    return avg_loss, acc


In [20]:
# Build model and run a forward pass with dummy audio
exp_cfg = ExpConfig()
model = Rawformer_S(device=DEVICE, transformer_hidden=exp_cfg.transformer_hidden,
                          sample_rate=exp_cfg.sample_rate)

B = 2
dummy_audio = torch.randn(B, exp_cfg.sample_rate * exp_cfg.train_duration_sec).to(DEVICE)
with torch.no_grad():
    out = model(dummy_audio)
print("Model output shape:", out.shape, "| values ~", (out.min().item(), out.max().item()))

if HAS_TORCHINFO:
    try:
        summary(model, input_size=(B, exp_cfg.sample_rate * exp_cfg.train_duration_sec))
    except Exception as e:
        print("torchinfo summary error (safe to ignore):", e)


Model output shape: torch.Size([2]) | values ~ (0.49742913246154785, 0.49765050411224365)


In [21]:
import torch
print(torch.__version__)


2.5.1+cu121


In [22]:
# import torchaudio
# torchaudio.set_audio_backend("ffmpeg")


In [23]:
# pip install soundfile

In [24]:
from tqdm.auto import tqdm
from sklearn.metrics import roc_curve
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchaudio
import os
import random
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn import metrics

# ============================================================
# CONFIGURATION
# ============================================================

class SysConfig:
    path_train = r"G:\INTERSPEECH_26\LA\ASV19\train"
    path_dev   = r"G:\INTERSPEECH_26\LA\ASV19\dev"
    path_test  =r"G:\INTERSPEECH_26\LA\ASV19\dev"

class ExpConfig:
    # Audio processing
    sample_rate = 16000
    pre_emphasis = 0.97
    train_duration_sec = 4
    test_duration_sec = 4

    # Model
    transformer_hidden = 660

    # Training hyperparameters
    batch_size = 32
    lr = 8*1e-4
    epochs = 50  # increase as needed

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"‚úÖ Using device: {DEVICE}")

# ============================================================
# SIMPLE DATASET
# ============================================================

import soundfile as sf

class ASVspoofFolderDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, sample_rate=16000, duration_sec=4):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.duration_sec = duration_sec
        self.audio_paths = []
        self.labels = []

        for label_name, label_value in [("bonafide", 1), ("spoof", 0)]:
            class_dir = os.path.join(root_dir, label_name)
            if os.path.exists(class_dir):
                for file in os.listdir(class_dir):
                    if file.endswith(".flac") or file.endswith(".wav"):
                        self.audio_paths.append(os.path.join(class_dir, file))
                        self.labels.append(label_value)

        print(f"üìÅ Loaded {len(self.audio_paths)} files from {root_dir}")

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        path = self.audio_paths[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.float32)

        # --- Use soundfile for FLAC ---
        if path.lower().endswith(".flac"):
            wav_np, sr = sf.read(path)         # numpy array (T,) or (T, C)
            if wav_np.ndim > 1:                # convert stereo ‚Üí mono
                wav_np = wav_np.mean(axis=1)
            wav = torch.tensor(wav_np, dtype=torch.float32).unsqueeze(0)  # [1, T]

        # --- Use torchaudio for WAV ---
        else:
            wav, sr = torchaudio.load(path)

        # --- Resample if needed ---
        if sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)

        # --- Crop/pad ---
        num_samples = int(self.sample_rate * self.duration_sec)

        if wav.size(1) > num_samples:
            start = random.randint(0, wav.size(1) - num_samples)
            wav = wav[:, start:start + num_samples]
        elif wav.size(1) < num_samples:
            wav = F.pad(wav, (0, num_samples - wav.size(1)))

        return wav.squeeze(0), label

# ============================================================
# EER FUNCTION (FOR CM SYSTEM)
# ============================================================

def calculate_EER(labels, scores):
    """Equal Error Rate for Countermeasure system (bonafide=1, spoof=0)."""
    fpr, tpr, _ = metrics.roc_curve(labels, scores, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer


# ============================================================
# t-DCF FUNCTION (CM-only version using reference ASV parameters)
# ============================================================

def compute_tDCF(bonafide_score_cm, spoof_score_cm, Pfa_asv, Pmiss_asv, Pfa_spoof_asv, cost_model):
    # 1. Compute CM miss/false-alarm rates for thresholds
    cm_scores = np.concatenate([bonafide_score_cm, spoof_score_cm])
    labels = np.concatenate([np.ones_like(bonafide_score_cm), np.zeros_like(spoof_score_cm)])
    sorted_idx = np.argsort(cm_scores)[::-1]
    sorted_labels = labels[sorted_idx]

    tar = np.sum(sorted_labels)
    non = len(sorted_labels) - tar

    cm_miss = np.cumsum(sorted_labels == 1) / tar
    cm_fa = np.cumsum(sorted_labels == 0) / non

    # 2. Compute t-DCF per threshold
    Cmiss, Cfa, Cfa_spoof = cost_model['Cmiss'], cost_model['Cfa'], cost_model['Cfa_spoof']
    Ptar, Pnon, Pspoof = cost_model['Ptar'], cost_model['Pnon'], cost_model['Pspoof']

    tDCF = (Cmiss * Ptar * Pmiss_asv * (1 - cm_miss) +
            Cfa * Pnon * Pfa_asv * cm_fa +
            Cfa_spoof * Pspoof * Pfa_spoof_asv * (1 - cm_miss)) / (
            Cmiss * Ptar * Pmiss_asv + Cfa * Pnon * Pfa_asv)

    tDCF_norm = tDCF / np.min(tDCF)
    thresholds = cm_scores[sorted_idx]

    return tDCF_norm, thresholds


# ============================================================
# TRAIN + VALIDATE + TEST LOOP
# ============================================================

sys_cfg = SysConfig()
exp_cfg = ExpConfig()

train_ds = ASVspoofFolderDataset(sys_cfg.path_train, exp_cfg.sample_rate, exp_cfg.train_duration_sec)
val_ds   = ASVspoofFolderDataset(sys_cfg.path_dev, exp_cfg.sample_rate, exp_cfg.test_duration_sec)
test_ds  = ASVspoofFolderDataset(sys_cfg.path_test, exp_cfg.sample_rate, exp_cfg.test_duration_sec)

train_loader = DataLoader(train_ds, batch_size=exp_cfg.batch_size, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_ds, batch_size=exp_cfg.batch_size, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_ds, batch_size=exp_cfg.batch_size, shuffle=False, num_workers=0)

# --- your modules ---
pre = PreEmphasis(exp_cfg.pre_emphasis).to(DEVICE)
model = Rawformer_S(device=DEVICE, transformer_hidden=exp_cfg.transformer_hidden, sample_rate=exp_cfg.sample_rate).to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=exp_cfg.lr)
criterion = nn.BCELoss()

best_val_eer = 1.0  # initialize high value
save_path = "best_cm_model.pth"

print("üöÄ Starting training...\n")

# Reference ASV parameters (official ASVspoof setup)
Pfa_asv = 0.0005
Pmiss_asv = 0.05
Pmiss_spoof_asv = 0.95
Pfa_spoof_asv = 1.0 - Pmiss_spoof_asv
cost_model = {
    'Ptar': 0.9801,
    'Pnon': 0.0099,
    'Pspoof': 0.0100,
    'Cmiss': 1,
    'Cfa': 10,
    'Cfa_spoof': 10
}

for epoch in range(1, exp_cfg.epochs + 1):
    # === TRAIN ===
    model.train()
    total_loss, total_samples = 0.0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{exp_cfg.epochs} [Train]", leave=True)

    for wav, label in pbar:
        wav, label = wav.to(DEVICE), label.to(DEVICE)
        wav = pre(wav)

        opt.zero_grad()
        pred = model(wav).squeeze(-1)
        loss = criterion(pred, label)
        loss.backward()
        opt.step()

        bs = wav.size(0)
        total_loss += loss.item() * bs
        total_samples += bs
        pbar.set_postfix(loss=f"{total_loss / total_samples:.4f}")

    avg_train_loss = total_loss / total_samples

    # === VALIDATE ===
    model.eval()
    val_loss, val_samples = 0.0, 0
    all_scores, all_labels = [], []

    with torch.no_grad():
        pbar = tqdm(val_loader, desc=f"Epoch {epoch}/{exp_cfg.epochs} [Val]", leave=True)
        for wav, label in pbar:
            wav, label = wav.to(DEVICE), label.to(DEVICE)
            wav = pre(wav)
            pred = model(wav).squeeze(-1)
            loss = criterion(pred, label)

            bs = wav.size(0)
            val_loss += loss.item() * bs
            val_samples += bs

            all_scores.extend(pred.cpu().numpy())
            all_labels.extend(label.cpu().numpy())

    avg_val_loss = val_loss / val_samples
    eer = calculate_EER(all_labels, all_scores)

    # --- Compute t-DCF ---
    bona_cm = np.array(all_scores)[np.array(all_labels) == 1]
    spoof_cm = np.array(all_scores)[np.array(all_labels) == 0]
    tDCF_curve, thr = compute_tDCF(bona_cm, spoof_cm, Pfa_asv, Pmiss_asv, Pfa_spoof_asv, cost_model)
    min_tDCF = np.min(tDCF_curve)

    print(f"üßæ Epoch {epoch} Summary:")
    print(f"   Train Loss: {avg_train_loss:.4f}")
    print(f"   Val Loss:   {avg_val_loss:.4f}")
    print(f"   Val EER:    {eer * 100:.2f}%")
    print(f"   min-tDCF:   {min_tDCF:.4f}")

    # === SAVE BEST MODEL ===
    if eer < best_val_eer:
        best_val_eer = eer
        torch.save(model, save_path)
        print(f"üíæ Saved new best model (EER={eer*100:.2f}%) to {save_path}")

    print("-" * 60)


# ============================================================
# FINAL TEST PHASE
# ============================================================

print("\n" + "=" * 60)
print("üèÅ Starting final testing...")
print("=" * 60)

# Load best model
best_model = torch.load(save_path, map_location=DEVICE)
best_model.eval()

test_scores, test_labels = [], []

with torch.no_grad():
    pbar = tqdm(test_loader, desc="Testing", leave=True)
    for wav, label in pbar:
        wav, label = wav.to(DEVICE), label.to(DEVICE)
        wav = pre(wav)
        pred = best_model(wav).squeeze(-1)
        test_scores.extend(pred.cpu().numpy())
        test_labels.extend(label.cpu().numpy())

test_eer = calculate_EER(test_labels, test_scores)

bona_cm = np.array(test_scores)[np.array(test_labels) == 1]
spoof_cm = np.array(test_scores)[np.array(test_labels) == 0]
tDCF_curve, thr = compute_tDCF(bona_cm, spoof_cm, Pfa_asv, Pmiss_asv, Pfa_spoof_asv, cost_model)
min_tDCF = np.min(tDCF_curve)

print(f"üéØ Final Test EER:  {test_eer * 100:.2f}%")
print(f"üìä Final min-tDCF: {min_tDCF:.4f}")

‚úÖ Using device: cuda
üìÅ Loaded 25380 files from G:\INTERSPEECH_26\LA\ASV19\train
üìÅ Loaded 24844 files from G:\INTERSPEECH_26\LA\ASV19\dev
üìÅ Loaded 24844 files from G:\INTERSPEECH_26\LA\ASV19\dev
üöÄ Starting training...



Epoch 1/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 1/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 1 Summary:
   Train Loss: 0.2469
   Val Loss:   0.1918
   Val EER:    13.47%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=13.47%) to best_cm_model.pth
------------------------------------------------------------


Epoch 2/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 2/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 2 Summary:
   Train Loss: 0.1579
   Val Loss:   0.0964
   Val EER:    6.24%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=6.24%) to best_cm_model.pth
------------------------------------------------------------


Epoch 3/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 3/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 3 Summary:
   Train Loss: 0.0852
   Val Loss:   0.0784
   Val EER:    5.27%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=5.27%) to best_cm_model.pth
------------------------------------------------------------


Epoch 4/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 4/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 4 Summary:
   Train Loss: 0.0557
   Val Loss:   0.0320
   Val EER:    2.28%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=2.28%) to best_cm_model.pth
------------------------------------------------------------


Epoch 5/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 5/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 5 Summary:
   Train Loss: 0.0374
   Val Loss:   0.0353
   Val EER:    2.35%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 6/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 6/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 6 Summary:
   Train Loss: 0.0373
   Val Loss:   0.0266
   Val EER:    1.97%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=1.97%) to best_cm_model.pth
------------------------------------------------------------


Epoch 7/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 7/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 7 Summary:
   Train Loss: 0.0355
   Val Loss:   0.0394
   Val EER:    2.43%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 8/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 8/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 8 Summary:
   Train Loss: 0.0433
   Val Loss:   0.0645
   Val EER:    2.55%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 9/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 9/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 9 Summary:
   Train Loss: 0.0284
   Val Loss:   0.0716
   Val EER:    2.16%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 10/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 10/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 10 Summary:
   Train Loss: 0.0354
   Val Loss:   0.0523
   Val EER:    2.02%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 11/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 11/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 11 Summary:
   Train Loss: 0.0279
   Val Loss:   0.0206
   Val EER:    1.37%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=1.37%) to best_cm_model.pth
------------------------------------------------------------


Epoch 12/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 12/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 12 Summary:
   Train Loss: 0.0248
   Val Loss:   0.0482
   Val EER:    2.24%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 13/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 13/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 13 Summary:
   Train Loss: 0.0312
   Val Loss:   0.0613
   Val EER:    2.35%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 14/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 14/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 14 Summary:
   Train Loss: 0.0327
   Val Loss:   0.0537
   Val EER:    1.26%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=1.26%) to best_cm_model.pth
------------------------------------------------------------


Epoch 15/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 15/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 15 Summary:
   Train Loss: 0.0245
   Val Loss:   0.0182
   Val EER:    1.02%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=1.02%) to best_cm_model.pth
------------------------------------------------------------


Epoch 16/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 16/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 16 Summary:
   Train Loss: 0.0209
   Val Loss:   0.0225
   Val EER:    1.30%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 17/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 17/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 17 Summary:
   Train Loss: 0.0376
   Val Loss:   0.0205
   Val EER:    1.22%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 18/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 18/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 18 Summary:
   Train Loss: 0.0398
   Val Loss:   0.0558
   Val EER:    1.54%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 19/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 19/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 19 Summary:
   Train Loss: 0.0292
   Val Loss:   0.5847
   Val EER:    8.47%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 20/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 20/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 20 Summary:
   Train Loss: 0.0319
   Val Loss:   0.0322
   Val EER:    1.45%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 21/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 21/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 21 Summary:
   Train Loss: 0.0206
   Val Loss:   0.0146
   Val EER:    1.06%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 22/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 22/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 22 Summary:
   Train Loss: 0.0193
   Val Loss:   0.0222
   Val EER:    1.14%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 23/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 23/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 23 Summary:
   Train Loss: 0.0225
   Val Loss:   0.0275
   Val EER:    1.84%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 24/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 24/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 24 Summary:
   Train Loss: 0.0282
   Val Loss:   0.0382
   Val EER:    1.69%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 25/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 25/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 25 Summary:
   Train Loss: 0.0239
   Val Loss:   0.0192
   Val EER:    1.02%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 26/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 26/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 26 Summary:
   Train Loss: 0.0237
   Val Loss:   0.0166
   Val EER:    0.99%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=0.99%) to best_cm_model.pth
------------------------------------------------------------


Epoch 27/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 27/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 27 Summary:
   Train Loss: 0.0229
   Val Loss:   0.0188
   Val EER:    1.30%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 28/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 28/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 28 Summary:
   Train Loss: 0.0166
   Val Loss:   0.0143
   Val EER:    1.02%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 29/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 29/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 29 Summary:
   Train Loss: 0.0191
   Val Loss:   0.0208
   Val EER:    1.53%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 30/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 30/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 30 Summary:
   Train Loss: 0.0253
   Val Loss:   0.0247
   Val EER:    1.66%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 31/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 31/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 31 Summary:
   Train Loss: 0.0195
   Val Loss:   0.0205
   Val EER:    1.06%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 32/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 32/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 32 Summary:
   Train Loss: 0.0184
   Val Loss:   0.0115
   Val EER:    0.78%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=0.78%) to best_cm_model.pth
------------------------------------------------------------


Epoch 33/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 33/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 33 Summary:
   Train Loss: 0.0141
   Val Loss:   0.0284
   Val EER:    2.00%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 34/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 34/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 34 Summary:
   Train Loss: 0.0175
   Val Loss:   0.0275
   Val EER:    1.20%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 35/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 35/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 35 Summary:
   Train Loss: 0.0190
   Val Loss:   0.0204
   Val EER:    1.02%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 36/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 36/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 36 Summary:
   Train Loss: 0.0163
   Val Loss:   0.0234
   Val EER:    0.90%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 37/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 37/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 37 Summary:
   Train Loss: 0.0236
   Val Loss:   0.0180
   Val EER:    1.26%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 38/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 38/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 38 Summary:
   Train Loss: 0.0172
   Val Loss:   0.0163
   Val EER:    1.02%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 39/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 39/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 39 Summary:
   Train Loss: 0.0139
   Val Loss:   0.0194
   Val EER:    0.97%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 40/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 40/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 40 Summary:
   Train Loss: 0.0119
   Val Loss:   0.0122
   Val EER:    0.67%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=0.67%) to best_cm_model.pth
------------------------------------------------------------


Epoch 41/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 41/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 41 Summary:
   Train Loss: 0.0133
   Val Loss:   0.0184
   Val EER:    1.33%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 42/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 42/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 42 Summary:
   Train Loss: 0.0165
   Val Loss:   0.0221
   Val EER:    1.31%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 43/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 43/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 43 Summary:
   Train Loss: 0.0154
   Val Loss:   0.0111
   Val EER:    0.90%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 44/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 44/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 44 Summary:
   Train Loss: 0.0213
   Val Loss:   0.0907
   Val EER:    5.58%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 45/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 45/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 45 Summary:
   Train Loss: 0.0296
   Val Loss:   0.0168
   Val EER:    1.14%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 46/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 46/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 46 Summary:
   Train Loss: 0.0207
   Val Loss:   0.0906
   Val EER:    6.27%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 47/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 47/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 47 Summary:
   Train Loss: 0.0230
   Val Loss:   0.0188
   Val EER:    1.10%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 48/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 48/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 48 Summary:
   Train Loss: 0.0185
   Val Loss:   0.0340
   Val EER:    1.39%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 49/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 49/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 49 Summary:
   Train Loss: 0.0224
   Val Loss:   0.0212
   Val EER:    1.33%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 50/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 50/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 50 Summary:
   Train Loss: 0.0200
   Val Loss:   0.0217
   Val EER:    1.26%
   min-tDCF:   1.0000
------------------------------------------------------------

üèÅ Starting final testing...


  best_model = torch.load(save_path, map_location=DEVICE)


Testing:   0%|          | 0/777 [00:00<?, ?it/s]

üéØ Final Test EER:  0.70%
üìä Final min-tDCF: 1.0000


In [25]:
!pip install torchinfo fvcore

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Collecting fvcore
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting yacs>=0.1.6 (from fvcore)
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Collecting termcolor>=1.1 (from fvcore)
  Downloading termcolor-3.2.0-py3-none-any.whl.metadata (6.4 kB)
Collecting iopath>=0.1.7 (from fvcore)
  Downloading iopath-0.1.10.tar.gz (42 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting portalocker (from iopath>=0.1.7->fvcore)
  Downloading portalocker-3.2.0-py3-none-any.whl.metadata (8.7 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Downloading termcolor-3.2.0-py3-none-any.whl (7.7 kB)
Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Down

In [26]:
# --------------------------------------------------------------
#  Model size & FLOPs (place this right after model creation)
# --------------------------------------------------------------
import torch
from torchinfo import summary
from fvcore.nn import FlopCountAnalysis, parameter_count

# --------------------------------------------------------------
# 1. Parameter count (trainable + non-trainable) + size in MiB
# --------------------------------------------------------------
def print_model_params(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total     = sum(p.numel() for p in model.parameters())
    size_mb   = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)

    print("\n" + "="*60)
    print("MODEL PARAMETER SUMMARY")
    print("="*60)
    print(f"{'Trainable params':<25}: {trainable:,}")
    print(f"{'Total params'    :<25}: {total:,}")
    print(f"{'Model size (MiB)':<25}: {size_mb:.2f}")
    print("="*60 + "\n")

print_model_params(model)

# --------------------------------------------------------------
# 2. FLOPs / MACs
# --------------------------------------------------------------
# We need a dummy waveform that matches the shape expected by the model.
#   - Rawformer_S expects raw audio: (batch, time)
#   - Use the maximum length defined in the config (or a typical 4-second clip)
max_len_sec = getattr(exp_cfg, "max_len_sec", 4.0)          # fallback 4 s
max_samples = int(exp_cfg.sample_rate * max_len_sec)

dummy_wav = torch.randn(1, max_samples, device=DEVICE)     # (B, T)

# Apply pre-emphasis if it is used in training/validation
if pre is not None:
    dummy_wav = pre(dummy_wav)

# ---- fvcore (very accurate) ----
flops = FlopCountAnalysis(model, dummy_wav)
macs  = flops.total()                # MACs = multiply-adds
flops_2 = macs * 2                   # FLOPs = 2 √ó MACs (standard convention)

print("\n" + "="*60)
print("FLOPs / MACs (per forward pass)")
print("="*60)
print(f"{'Input shape'   :<25}: {list(dummy_wav.shape)}")
print(f"{'MACs'          :<25}: {macs/1e9:.3f} G")
print(f"{'FLOPs'         :<25}: {flops_2/1e9:.3f} G")
print("="*60 + "\n")

# ---- torchinfo (nice table, optional) ----
print("Detailed layer-wise breakdown (torchinfo):")
summary(model,
        input_data=dummy_wav,
        col_names=["input_size", "output_size", "num_params", "mult_adds"],
        depth=4,
        verbose=0)

Unsupported operator aten::abs encountered 1 time(s)
Unsupported operator aten::max_pool2d encountered 5 time(s)
Unsupported operator aten::selu_ encountered 8 time(s)
Unsupported operator aten::add encountered 17 time(s)
Unsupported operator aten::mul encountered 7 time(s)
Unsupported operator aten::div encountered 6 time(s)
Unsupported operator aten::softmax encountered 3 time(s)
Unsupported operator aten::mean encountered 4 time(s)
Unsupported operator aten::var encountered 4 time(s)
Unsupported operator aten::sub encountered 4 time(s)
Unsupported operator aten::sqrt encountered 4 time(s)
Unsupported operator aten::sigmoid encountered 1 time(s)



MODEL PARAMETER SUMMARY
Trainable params         : 345,132
Total params             : 345,132
Model size (MiB)         : 1.32


FLOPs / MACs (per forward pass)
Input shape              : [1, 64000]
MACs                     : 6.186 G
FLOPs                    : 12.371 G

Detailed layer-wise breakdown (torchinfo):


Layer (type:depth-idx)                                       Input Shape               Output Shape              Param #                   Mult-Adds
Rawformer_S                                                  [1, 64000]                [1]                       --                        --
‚îú‚îÄFrontend_S: 1-1                                            [1, 64000]                [1, 64, 23, 16]           --                        --
‚îÇ    ‚îî‚îÄSincConv: 2-1                                         [1, 1, 64000]             [1, 70, 63872]            --                        --
‚îÇ    ‚îî‚îÄBatchNorm2d: 2-2                                      [1, 1, 23, 21290]         [1, 1, 23, 21290]         2                         2
‚îÇ    ‚îî‚îÄSELU: 2-3                                             [1, 1, 23, 21290]         [1, 1, 23, 21290]         --                        --
‚îÇ    ‚îî‚îÄSequential: 2-4                                       [1, 1, 23, 21290]         [1, 64, 23, 16]           -

In [27]:
seconds = max_samples / exp_cfg.sample_rate
print(f"GFLOPs per second of audio : {flops_2/1e9/seconds:.3f}")

GFLOPs per second of audio : 3.093
