In [1]:
import torch

print("CUDA available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())

for i in range(torch.cuda.device_count()):
    print(f"[{i}] {torch.cuda.get_device_name(i)}")

print(torch.__version__)

CUDA available: True
GPU count: 1
[0] NVIDIA GeForce RTX 4090
2.5.1+cu121


In [2]:
import os
import math
import random
from collections import OrderedDict
from typing import Any, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchaudio
from tqdm.auto import tqdm
from sklearn.metrics import roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn import metrics
from torch.utils.data import DataLoader

In [3]:
# Optional (nice for shapes):
try:
    from torchinfo import summary
    HAS_TORCHINFO = True
except Exception:
    HAS_TORCHINFO = False

save_path = r"C:\Users\Admin\Desktop\Test Folder Arth Shah\LearnableGammatone_best_cm_model.pth"
# ------------------ Configurable Paths ------------------ #
class SysConfig:
    """
    Folder-based dataset structure.
    Each split (train/dev/test) contains two subfolders: bonafide and spoof.
    """
    path_train =r"G:\INTERSPEECH_26\LA\ASV19\train"
    path_dev   = r"G:\INTERSPEECH_26\LA\ASV19\dev"
    path_test  =r"G:\INTERSPEECH_26\LA\ASV19\dev"


# ------------------ Experiment Hyperparameters ------------------ #
class ExpConfig:
    # Audio processing
    sample_rate = 16000
    pre_emphasis = 0.97
    train_duration_sec = 4
    test_duration_sec = 4

    # Model
    transformer_hidden = 660

    # Training hyperparameters
    batch_size = 32
    lr = 8*1e-4
    epochs = 50  # increase as needed

In [4]:
try:
    from torch_audiomentations import (
        Compose, AddColoredNoise, HighPassFilter, LowPassFilter, Gain
    )
    HAS_TA = True
except Exception:
    HAS_TA = False
    Compose = AddColoredNoise = HighPassFilter = LowPassFilter = Gain = None

class WaveformAugmentation(nn.Module):
    def __init__(self, aug_list=('ACN', 'HPF', 'LPF', 'GAN'), sr=16000):
        super().__init__()
        self.sr = sr
        if HAS_TA:
            transforms = []
            if 'ACN' in aug_list:
                transforms.append(AddColoredNoise(10, 40, -2.0, 2.0, p=0.5))
            if 'HPF' in aug_list:
                transforms.append(HighPassFilter(20.0, 2400.0, p=0.5))
            if 'LPF' in aug_list:
                transforms.append(LowPassFilter(150.0, 7500.0, p=0.5))
            if 'GAN' in aug_list:
                transforms.append(Gain(-15.0, 5.0, p=0.5))
            self.apply_augmentation = Compose(transforms) if transforms else None
        else:
            # No-op if torch_audiomentations isn't available
            self.apply_augmentation = None

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # wav: (B, T)
        if self.apply_augmentation is None:
            return wav
        return self.apply_augmentation(wav.unsqueeze(1), self.sr).squeeze(1)


In [5]:
class PreEmphasis(nn.Module):
    def __init__(self, pre_emphasis: float = 0.97):
        super().__init__()
        # Conv1D filter shape: (out_channels=1, in_channels=1, kernel_size=2)
        filt = torch.tensor([[-pre_emphasis, 1.0]], dtype=torch.float32).unsqueeze(0)
        self.register_buffer("filter", filt)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T)
        x = x.unsqueeze(1)  # (B,1,T)
        x = F.pad(x, (1, 0), mode="reflect")
        x = F.conv1d(x, self.filter)
        return x.squeeze(1)  # (B,T)


In [6]:
class SincConv(nn.Module):
    """
    Adapted from AASIST. One input channel only.
    """
    @staticmethod
    def to_mel(hz): return 2595 * np.log10(1 + hz / 700)
    @staticmethod
    def to_hz(mel): return 700 * (10**(mel / 2595) - 1)

    def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1, stride=1, padding=0, dilation=1):
        super().__init__()
        if in_channels != 1:
            raise ValueError("SincConv supports only one input channel.")
        self.out_channels = out_channels
        self.sample_rate = sample_rate
        self.kernel_size = kernel_size + (kernel_size % 2 == 0)

        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        NFFT = 512
        f = int(sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        filbandwidthsmel = np.linspace(fmel.min(), fmel.max(), out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)

        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)

        band_pass = torch.zeros(out_channels, self.kernel_size)
        for i in range(out_channels):
            fmin, fmax = filbandwidthsf[i], filbandwidthsf[i + 1]
            hHigh = (2 * fmax / sample_rate) * np.sinc(2 * fmax * self.hsupp / sample_rate)
            hLow  = (2 * fmin / sample_rate) * np.sinc(2 * fmin * self.hsupp / sample_rate)
            hideal = hHigh - hLow
            band_pass[i, :] = torch.tensor(np.hamming(self.kernel_size)) * torch.tensor(hideal)
        self.register_buffer("band_pass", band_pass)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B,1,T)
        filt = self.band_pass.to(x.device).view(self.out_channels, 1, self.kernel_size)
        return F.conv1d(x, filt, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

class LearnableSincConv(nn.Module):
    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1,
                 stride=1, padding=0, dilation=1, bias=False, min_low_hz=50, min_band_hz=50):
        super().__init__()
        if in_channels != 1:
            raise ValueError(f"SincConv only supports one input channel, got {in_channels}")
        if kernel_size % 2 == 0:
            kernel_size += 1

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # Mel-scale initialization
        NFFT = 512
        f = np.linspace(0, sample_rate / 2, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        mel_points = np.linspace(fmel.min(), fmel.max(), out_channels + 1)
        hz_points = self.to_hz(mel_points)

        # Initialize learnable parameters for low cutoff and bandwidth
        low_hz = hz_points[:-1]
        band_hz = np.diff(hz_points)

        self.low_hz_ = nn.Parameter(torch.tensor(low_hz, dtype=torch.float32))
        self.band_hz_ = nn.Parameter(torch.tensor(band_hz, dtype=torch.float32))

        # Time axis for filter generation
        n = torch.arange(-(kernel_size - 1) / 2, (kernel_size - 1) / 2 + 1)
        self.register_buffer('n', n)

    def forward(self, x):
        device = x.device
        n = self.n.to(device)

        # Enforce positive frequency constraints
        low = self.min_low_hz + torch.abs(self.low_hz_)
        high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_), self.min_low_hz, self.sample_rate / 2 - 1)

        band = (high - low)[:, None]
        f_times_t_low = 2 * np.pi * low[:, None] * n / self.sample_rate
        f_times_t_high = 2 * np.pi * high[:, None] * n / self.sample_rate

        # Compute filters using sinc functions
        sinc_high = torch.sin(f_times_t_high) / (n / self.sample_rate + 1e-8)
        sinc_low = torch.sin(f_times_t_low) / (n / self.sample_rate + 1e-8)
        filters = sinc_high - sinc_low

        # Apply window (Hamming)
        window = 0.54 - 0.46 * torch.cos(2 * np.pi * (torch.arange(self.kernel_size).to(device)) / self.kernel_size)
        filters = filters * window

        # Normalize
        filters = filters / (2 * band)

        filters = filters.view(self.out_channels, 1, self.kernel_size)
        return F.conv1d(x, filters, stride=self.stride, padding=self.padding)

class AdaptiveGaborConv(nn.Module):
    def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1):
        super().__init__()
        if in_channels != 1:
            raise ValueError("GaborConv only supports 1 input channel")
        if kernel_size % 2 == 0:
            kernel_size += 1  # ensure odd kernel size

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Time support
        self.t = torch.arange(-(kernel_size - 1) / 2, (kernel_size - 1) / 2 + 1, dtype=torch.float32) / sample_rate  # [kernel_size]

        # Mel-scale initialization for center frequencies
        NFFT = 512
        f = np.linspace(0, sample_rate / 2, int(NFFT / 2) + 1)
        mel = 2595 * np.log10(1 + f / 700)
        mel_centers = np.linspace(mel.min(), mel.max(), out_channels)
        hz_centers = 700 * (10 ** (mel_centers / 2595) - 1)
        eta_init = hz_centers / sample_rate  # normalized center freqs [0‚Äì0.5]
        self.eta = nn.Parameter(torch.tensor(eta_init, dtype=torch.float32))

        # Adaptive bandwidths inversely proportional to frequency
        base_sigma = (kernel_size / sample_rate) / 4  # base scale
        sigma_init = base_sigma / (self.eta + 1e-4)   # inverse proportionality
        sigma_init = torch.clamp(torch.tensor(sigma_init, dtype=torch.float32), 1e-4, 0.05)
        self.sigma_scale = nn.Parameter(torch.ones(out_channels))  # learnable global scaling
        self.register_buffer("sigma_init", sigma_init)

    def _create_filters(self, device):
        t = self.t.to(device)                         # [kernel_size]
        eta = torch.clamp(self.eta, 1e-4, 0.5)        # [out_channels]
        sigma = (self.sigma_init.to(device) * self.sigma_scale).unsqueeze(1)  # [out_channels, 1]
        eta = eta.unsqueeze(1)                        # [out_channels, 1]

        # Gaussian window ‚Äî broadcast over time
        gaussian = torch.exp(-t[None, :]**2 / (2 * sigma**2)) / (np.sqrt(2 * np.pi) * sigma)

        # Cosine/sine modulations
        cos_component = torch.cos(2 * np.pi * eta * t[None, :])
        sin_component = torch.sin(2 * np.pi * eta * t[None, :])

        filters_real = gaussian * cos_component
        filters_imag = gaussian * sin_component

        filters = torch.cat([filters_real, filters_imag], dim=0)  # [2*out_channels, kernel_size]
        filters = filters / (filters.abs().max(dim=1, keepdim=True)[0] + 1e-8)
        return filters

    def forward(self, x):
        filters = self._create_filters(x.device)
        filters = filters.view(2 * self.out_channels, 1, self.kernel_size)
        padding = self.kernel_size // 2
        out = F.conv1d(x, filters, stride=1, padding=padding)
        real, imag = out[:, :self.out_channels, :], out[:, self.out_channels:, :]
        magnitude = torch.sqrt(real**2 + imag**2 + 1e-8)
        return magnitude


In [7]:
# Nice-to-have (optional)
try:
    from torchinfo import summary
    HAS_TORCHINFO = True
except Exception:
    HAS_TORCHINFO = False

class ScaledDotProductAttention(nn.Module):
    """
    Expects Q,K,V: (B, H, S, D). Optional mask: (B,1,1,S) or broadcastable.
    """
    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, mask: Optional[torch.Tensor] = None):
        assert Q.dim() == K.dim() == V.dim() == 4  # (B,H,S,D)
        d_k = K.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)  # (B,H,S_q,S_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = self.softmax(scores)
        out = torch.matmul(attn, V)  # (B,H,S_q,D)
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_head: int):
        super().__init__()
        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        self.n_head = n_head
        self.d_head = d_model // n_head

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.attn = ScaledDotProductAttention()
        self.W_out = nn.Linear(d_model, d_model)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B,S,D) -> (B,H,S,Dh)
        B, S, D = x.size()
        x = x.view(B, S, self.n_head, self.d_head).permute(0, 2, 1, 3)
        return x

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B,H,S,Dh) -> (B,S,D)
        B, H, S, Dh = x.size()
        return x.permute(0, 2, 1, 3).contiguous().view(B, S, H * Dh)

    def forward(self, Q, K, V, mask: Optional[torch.Tensor] = None):
        assert Q.dim() == K.dim() == V.dim() == 3  # (B,S,D)
        q = self._split_heads(self.W_Q(Q))
        k = self._split_heads(self.W_K(K))
        v = self._split_heads(self.W_V(V))
        if mask is not None:
            # make mask broadcastable to (B,H,S_q,S_k)
            mask = mask.unsqueeze(1)
        context = self.attn(q, k, v, mask=mask)
        context = self._merge_heads(context)
        return self.W_out(context)


class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)
        xhat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * xhat + self.beta


class FFN(nn.Module):
    def __init__(self, d_model, ffn_hidden, drop_prob=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, ffn_hidden),
            nn.ReLU(),
            nn.Dropout(drop_prob),
            nn.Linear(ffn_hidden, d_model),
        )

    def forward(self, x):
        return self.net(x)


class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=64, n_head=8, ffn_hidden=2048, drop_prob=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_head)
        self.dropout1 = nn.Dropout(drop_prob)
        self.norm1 = LayerNorm(d_model)
        self.ffn = FFN(d_model, ffn_hidden, drop_prob)
        self.dropout2 = nn.Dropout(drop_prob)
        self.norm2 = LayerNorm(d_model)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        # x: (B,S,D)
        residual = x
        x = self.attn(x, x, x, mask=attn_mask)
        x = self.dropout1(x)
        x = self.norm1(x + residual)

        residual = x
        x = self.ffn(x)
        x = self.dropout2(x)
        x = self.norm2(x + residual)
        return x


In [8]:
"""
TD (Time-Domain) Filterbank and Learnable Gammatone Filterbank
PyTorch implementations intended as drop-in replacements for the SincConv layer.

Both modules accept (B,1,T) raw waveform and return (B, out_channels, T').
They are made fully drop-in compatible with a SincConv-style constructor
(i.e. accept `in_channels=1`, `out_channels=...`, `kernel_size=...`, `sample_rate=...`).
"""

from typing import Optional
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# -------------------- helpers --------------------

def hz_to_mel(hz: np.ndarray) -> np.ndarray:
    return 2595.0 * np.log10(1.0 + hz / 700.0)


def mel_to_hz(mel: np.ndarray) -> np.ndarray:
    return 700.0 * (10 ** (mel / 2595.0) - 1.0)


def windowed_sinc_impulse(kernel_size: int, sr: int, fmin: float, fmax: float) -> np.ndarray:
    """Return a single band-pass windowed-sinc impulse response (numpy array).
    kernel_size must be odd.
    """
    assert kernel_size % 2 == 1, "kernel_size should be odd"
    t = np.arange(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0 + 1.0)
    h_high = (2 * fmax / sr) * np.sinc(2 * fmax * t / sr)
    h_low = (2 * fmin / sr) * np.sinc(2 * fmin * t / sr)
    hideal = h_high - h_low
    w = np.hamming(kernel_size)
    return hideal * w


# -------------------- TDFilterbank --------------------

class TDFilterbank(nn.Module):
    """Time-Domain Filterbank (drop-in replacement for SincConv).

    Accepts `in_channels` argument for API compatibility but only supports mono input.
    """

    def __init__(self,
                 in_channels: int = 1,          # dummy for compatibility (ignored)
                 out_channels: int = 70,
                 kernel_size: int = 129,
                 sample_rate: int = 16000,
                 learnable: bool = True,
                 learnable_f: bool = True,
                 min_low_hz: float = 30.0,
                 min_band_hz: float = 50.0,
                 **kwargs):
        super().__init__()
        if kernel_size % 2 == 0:
            kernel_size += 1
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        self.learnable = learnable
        self.learnable_f = learnable_f

        # mel-spaced boundaries (numpy)
        NFFT = 512
        f = int(sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = hz_to_mel(f)
        mel_bins = np.linspace(fmel.min(), fmel.max(), out_channels + 1)
        hz_bins = mel_to_hz(mel_bins)

        fmins = hz_bins[:-1].copy()
        fmaxs = hz_bins[1:].copy()

        # ensure minimum band
        fmins = np.maximum(fmins, min_low_hz)
        fmaxs = np.maximum(fmaxs, fmins + min_band_hz)

        # init kernels (numpy) shape (out_channels, kernel_size)
        init_kernels = np.zeros((out_channels, kernel_size), dtype=np.float32)
        for i in range(out_channels):
            init_kernels[i, :] = windowed_sinc_impulse(kernel_size, sample_rate, fmins[i], fmaxs[i])

        # normalize
        init_kernels /= np.maximum(np.abs(init_kernels).sum(axis=1, keepdims=True), 1e-8)

        # store initial kernels or params depending on mode
        if not learnable:
            # fixed kernel bank (register buffer for zero-parameter behavior)
            self.register_buffer('kernels', torch.tensor(init_kernels, dtype=torch.float32).unsqueeze(1))
            return

        if learnable_f:
            centres = (fmins + fmaxs) / 2.0
            bws = (fmaxs - fmins)

            self.log_centres = nn.Parameter(torch.log(torch.tensor(centres + 1.0, dtype=torch.float32)))
            self.log_bws = nn.Parameter(torch.log(torch.tensor(bws + 1.0, dtype=torch.float32)))
            self.register_buffer('kernel_window_ref', torch.tensor(init_kernels, dtype=torch.float32))
        else:
            # learn full kernels (Conv1d-like)
            self.kernels_param = nn.Parameter(torch.tensor(init_kernels, dtype=torch.float32).unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B,1,T)
        returns: (B, out_channels, T')
        """
        if not self.learnable:
            return F.conv1d(x, self.kernels.to(x.device), stride=1, padding=(self.kernel_size - 1) // 2)

        if self.learnable_f:
            device = x.device
            centres = torch.exp(self.log_centres).to(device) - 1.0
            bws = torch.exp(self.log_bws).to(device) - 1.0

            fmin = centres - 0.5 * bws
            fmax = centres + 0.5 * bws

            # clamp ranges safely using tensor ops (avoid mixing tensor/scalar in positional args)
            max_freq_tensor = (self.sample_rate / 2.0) * torch.ones_like(fmax, device=device)
            fmin = torch.clamp(fmin, min=1.0)
            fmin = torch.min(fmin, max_freq_tensor - 2.0)
            fmax = torch.max(fmax, fmin + 1.0)
            fmax = torch.min(fmax, max_freq_tensor)

            # time vector (device)
            t = torch.linspace(
                -(self.kernel_size - 1) / 2.0,
                (self.kernel_size - 1) / 2.0,
                steps=self.kernel_size,
                device=device,
                dtype=torch.float32
            )

            kernels = []
            # build each kernel on-device
            for i in range(self.out_channels):
                hi = (2.0 * fmax[i] / self.sample_rate) * torch.sinc(2.0 * fmax[i] * t / self.sample_rate)
                lo = (2.0 * fmin[i] / self.sample_rate) * torch.sinc(2.0 * fmin[i] * t / self.sample_rate)

                h = hi - lo
                h = h * torch.hamming_window(self.kernel_size, periodic=False, device=device, dtype=torch.float32)
                h = h / (h.abs().sum() + 1e-8)
                kernels.append(h)

            kernels = torch.stack(kernels, dim=0).unsqueeze(1)  # (out,1,k)
            return F.conv1d(x, kernels, stride=1, padding=(self.kernel_size - 1) // 2)

        # learn full kernels branch
        return F.conv1d(x, self.kernels_param.to(x.device), stride=1, padding=(self.kernel_size - 1) // 2)


# -------------------- Learnable Gammatone Filterbank --------------------

class LearnableGammatone(nn.Module):
    """Learnable Gammatone Filterbank (drop-in compatible)."""

    def __init__(self,
                 in_channels: int = 1,         # dummy for compatibility (ignored)
                 out_channels: int = 70,
                 kernel_size: int = 129,
                 sample_rate: int = 16000,
                 n: int = 4,
                 min_freq: float = 30.0,
                 max_freq: Optional[float] = None,
                 **kwargs):
        super().__init__()

        if kernel_size % 2 == 0:
            kernel_size += 1

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        self.n = n
        self.max_freq = max_freq or sample_rate / 2.0

        # mel spaced centres (numpy)
        NFFT = 512
        f = int(sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        mel = hz_to_mel(f)
        mel_bins = np.linspace(mel.min(), mel.max(), out_channels)
        centres = mel_to_hz(mel_bins)
        centres = np.clip(centres, min_freq, self.max_freq - 10.0)

        # ERB approx
        erb = 24.7 + 0.108 * centres

        self.log_centres = nn.Parameter(torch.log(torch.tensor(centres + 1.0, dtype=torch.float32)))
        self.log_band = nn.Parameter(torch.log(torch.tensor(erb + 1.0, dtype=torch.float32)))

        self.log_amp = nn.Parameter(torch.zeros(out_channels, dtype=torch.float32))

        t = np.arange(kernel_size, dtype=np.float32) - (kernel_size - 1) / 2.0
        self.register_buffer("t", torch.tensor(t, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B,1,T)
        returns: (B, out_channels, T')
        """
        device = x.device

        centres = torch.exp(self.log_centres).to(device) - 1.0
        band = torch.exp(self.log_band).to(device) - 1.0
        amp = torch.exp(self.log_amp).to(device)

        centres = torch.clamp(centres, 20.0, self.max_freq - 1.0)
        band = torch.clamp(band, 1.0, self.sample_rate / 4.0)

        t = self.t.to(device)
        kernels = []

        for i in range(self.out_channels):
            fc = centres[i]
            b = band[i]
            a = amp[i]

            tp = t  # centered time vector

            env = (tp.abs() / self.sample_rate) ** (self.n - 1)
            env = env * torch.exp(-2.0 * math.pi * b * tp.abs() / self.sample_rate)
            carrier = torch.cos(2.0 * math.pi * fc * tp / self.sample_rate)

            g = a * env * carrier
            g = g / (g.abs().sum() + 1e-8)

            kernels.append(g)

        kernels = torch.stack(kernels, dim=0).unsqueeze(1)
        return F.conv1d(x, kernels.to(device), stride=1, padding=(self.kernel_size - 1) // 2)



In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------------------------------------------------
# 1. Frontend_S
# ----------------------------------------------------------------------
class Frontend_S(nn.Module):
    def __init__(self, device, sinc_kernel_size=128, sample_rate=16000):
        super().__init__()

        # ---- Sinc layer (no parameters ‚Üí safe on any device) ----
        self.sinc_layer = LearnableGammatone(
            in_channels=1,
            out_channels=70,
            kernel_size=sinc_kernel_size,
            sample_rate=sample_rate,
        )

        # ---- BatchNorm that must live on the target device ----
        self.bn = nn.BatchNorm2d(num_features=1).to(device)

        self.selu = nn.SELU(inplace=True)

        # ---- Conv blocks (they also contain BatchNorms) ----
        self.conv_blocks = nn.Sequential(
            Conv2DBlock_S(in_channels=1,  out_channels=32, is_first_block=True),
            Conv2DBlock_S(in_channels=32, out_channels=32),
            Conv2DBlock_S(in_channels=32, out_channels=64),
            Conv2DBlock_S(in_channels=64, out_channels=64),
        ).to(device)                     # <-- move the whole Sequential

    def forward(self, x):
        # x : [B, T]  (raw waveform)
        x = x.unsqueeze(1)                     # [B,1,T]
        x = self.sinc_layer(x)                 # [B,70,T']
        x = x.unsqueeze(1)                     # [B,1,70,T']
        x = F.max_pool2d(torch.abs(x), (3, 3)) # [B,1,F,T]
        x = self.bn(x)
        LFM = self.selu(x)

        HFM = self.conv_blocks(LFM)            # [B,64,f,t]
        return HFM


# ----------------------------------------------------------------------
# 2. Conv2DBlock_S
# ----------------------------------------------------------------------
class Conv2DBlock_S(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, is_first_block: bool = False):
        super().__init__()

        # ---- optional normaliser (BN+SELU) ----
        self.normalizer = None
        if not is_first_block:
            self.normalizer = nn.Sequential(
                nn.BatchNorm2d(in_channels),
                nn.SELU(inplace=True),
            )

        # ---- two conv layers + BN+SELU ----
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(2, 5), padding=(1, 2)),
            nn.BatchNorm2d(out_channels),
            nn.SELU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=(2, 3), padding=(0, 1)),
        )

        # ---- residual connection when channel count changes ----
        self.downsampler = None
        if in_channels != out_channels:
            self.downsampler = nn.Conv2d(in_channels, out_channels,
                                        kernel_size=(1, 3), padding=(0, 1))

        self.pooling = nn.MaxPool2d(kernel_size=(1, 6))

    def forward(self, x):
        identity = x
        if self.downsampler is not None:
            identity = self.downsampler(identity)

        if self.normalizer is not None:
            x = self.normalizer(x)

        x = self.layers(x) + identity
        x = self.pooling(x)
        return x


# ----------------------------------------------------------------------
# 3. PositionalAggregator1D
# ----------------------------------------------------------------------
class PositionalAggregator1D(nn.Module):
    def __init__(self, max_C: int, max_ft: int, device):
        super().__init__()

        self.flattener = nn.Flatten(start_dim=-2, end_dim=-1)

        # ----- sinusoidal positional encoding (no trainable params) -----
        pos = torch.arange(1, max_ft - 1, device=device).float().unsqueeze(1)   # (L-2,1)
        dim = torch.arange(0, max_C, step=2, device=device).float().unsqueeze(0)  # (1,D/2)

        enc = torch.zeros(max_ft, max_C, device=device)
        enc[1:-1, 0::2] = torch.sin(pos / (10000 ** (dim / max_C)))
        enc[1:-1, 1::2] = torch.cos(pos / (10000 ** (dim / max_C)))
        self.register_buffer('encoding', enc)   # stored on the correct device automatically

    def forward(self, HFM):
        """
        HFM : [B, C, f, t]
        out : [B, f*t, C]  with added positional encoding
        """
        B, C, f, t = HFM.shape
        ft = f * t
        out = self.flattener(HFM).transpose(1, 2)               # [B, f*t, C]
        out = out + self.encoding[:ft, :C]                      # broadcast
        return out

In [10]:
class Rawformer_S(nn.Module):
    def __init__(self, device, transformer_hidden=64, sample_rate: int = 16000):
        super().__init__()
        # ---- 1. give the front-end the device ----
        self.front_end = Frontend_S(sinc_kernel_size=128,
                                    sample_rate=sample_rate,
                                    device=device)          # <-- add this

        self.positional_embedding = PositionalAggregator1D(
            max_C=64, max_ft=23*16, device=device)

        self.classifier = RawformerClassifier(C=64, n_encoder=2, transformer_hidden=transformer_hidden)

        # ---- 2. move *everything* to the target device in one go ----
        self.to(device)                     # <-- important!

    def forward(self, x):
        x = self.front_end(x)               # now on correct device
        x = self.positional_embedding(x)
        x = self.classifier(x)
        return x

In [11]:
class SequencePooling(nn.Module):
    """
    Attention-style weighted pooling over sequence.
    Input: (B,S,C) -> Output: (B,C)
    """
    def __init__(self, d_model):
        super().__init__()
        self.linear = nn.Linear(d_model, 1)

    def forward(self, x):
        # x: (B,S,C)
        w = self.linear(x)               # (B,S,1)
        w = F.softmax(w.transpose(1, 2), dim=-1)  # (B,1,S)
        out = torch.matmul(w, x)         # (B,1,C)
        return out.squeeze(1)            # (B,C)


class RawformerClassifier(nn.Module):
    """
    Encoders (N layers) + SeqPool + Linear + Sigmoid
    Input: sequence (B,S,C)  Output: (B,) score in [0,1]
    """
    def __init__(self, C: int, n_encoder: int, transformer_hidden: int):
        super().__init__()
        self.encoders = nn.Sequential(OrderedDict([
            (f"encoder{i}", TransformerEncoderLayer(d_model=C, n_head=8, ffn_hidden=transformer_hidden))
            for i in range(n_encoder)
        ]))
        self.seq_pool = SequencePooling(d_model=C)
        self.fc = nn.Linear(C, 1)

    def forward(self, x):
        # x: (B,S,C)
        x = self.encoders(x)
        x = self.seq_pool(x)
        x = self.fc(x)
        return torch.sigmoid(x).squeeze(-1)   # (B,)


In [12]:
def collate_pad(batch):
    # Here all items are same length already; just stack.
    wavs, labels = zip(*batch)
    wavs = torch.stack(wavs, dim=0)
    labels = torch.tensor(labels, dtype=torch.float32)
    return wavs, labels


def train_one_epoch(model, loader, optimizer, criterion, preemph=None):
    model.train()
    total_loss = 0.0
    for wav, label in loader:
        wav = wav.to(DEVICE)
        label = label.to(DEVICE)

        if preemph is not None:
            wav = preemph(wav)

        optimizer.zero_grad()
        pred = model(wav)              # (B,)
        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * wav.size(0)
    return total_loss / len(loader.dataset)


@torch.no_grad()
def evaluate(model, loader, criterion, preemph=None):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    for wav, label in loader:
        wav = wav.to(DEVICE)
        label = label.to(DEVICE)
        if preemph is not None:
            wav = preemph(wav)
        pred = model(wav)
        loss = criterion(pred, label)
        total_loss += loss.item() * wav.size(0)
        total_correct += ((pred > 0.5).float() == label).sum().item()
    avg_loss = total_loss / len(loader.dataset)
    acc = total_correct / len(loader.dataset)
    return avg_loss, acc


In [13]:
# Build model and run a forward pass with dummy audio
exp_cfg = ExpConfig()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = Rawformer_S(device=DEVICE, transformer_hidden=exp_cfg.transformer_hidden,
                          sample_rate=exp_cfg.sample_rate)

B = 2
dummy_audio = torch.randn(B, exp_cfg.sample_rate * exp_cfg.train_duration_sec).to(DEVICE)
with torch.no_grad():
    out = model(dummy_audio)
print("Model output shape:", out.shape, "| values ~", (out.min().item(), out.max().item()))

if HAS_TORCHINFO:
    try:
        summary(model, input_size=(B, exp_cfg.sample_rate * exp_cfg.train_duration_sec))
    except Exception as e:
        print("torchinfo summary error (safe to ignore):", e)


Model output shape: torch.Size([2]) | values ~ (0.5365177392959595, 0.5377536416053772)


In [14]:
# torchaudio.set_audio_backend("ffmpeg")

In [15]:
# pip install soundfile

In [16]:
# ============================================================
# SIMPLE DATASET
# ============================================================

import soundfile as sf

class ASVspoofFolderDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, sample_rate=16000, duration_sec=4):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.duration_sec = duration_sec
        self.audio_paths = []
        self.labels = []

        for label_name, label_value in [("bonafide", 1), ("spoof", 0)]:
            class_dir = os.path.join(root_dir, label_name)
            if os.path.exists(class_dir):
                for file in os.listdir(class_dir):
                    if file.endswith(".flac") or file.endswith(".wav"):
                        self.audio_paths.append(os.path.join(class_dir, file))
                        self.labels.append(label_value)

        print(f"üìÅ Loaded {len(self.audio_paths)} files from {root_dir}")

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        path = self.audio_paths[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.float32)

        # --- Use soundfile for FLAC ---
        if path.lower().endswith(".flac"):
            wav_np, sr = sf.read(path)         # numpy array (T,) or (T, C)
            if wav_np.ndim > 1:                # convert stereo ‚Üí mono
                wav_np = wav_np.mean(axis=1)
            wav = torch.tensor(wav_np, dtype=torch.float32).unsqueeze(0)  # [1, T]

        # --- Use torchaudio for WAV ---
        else:
            wav, sr = torchaudio.load(path)

        # --- Resample if needed ---
        if sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)

        # --- Crop/pad ---
        num_samples = int(self.sample_rate * self.duration_sec)

        if wav.size(1) > num_samples:
            start = random.randint(0, wav.size(1) - num_samples)
            wav = wav[:, start:start + num_samples]
        elif wav.size(1) < num_samples:
            wav = F.pad(wav, (0, num_samples - wav.size(1)))

        return wav.squeeze(0), label

# ============================================================
# EER FUNCTION (FOR CM SYSTEM)
# ============================================================

def calculate_EER(labels, scores):
    """Equal Error Rate for Countermeasure system (bonafide=1, spoof=0)."""
    fpr, tpr, _ = metrics.roc_curve(labels, scores, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer


# ============================================================
# t-DCF FUNCTION (CM-only version using reference ASV parameters)
# ============================================================

def compute_tDCF(bonafide_score_cm, spoof_score_cm, Pfa_asv, Pmiss_asv, Pfa_spoof_asv, cost_model):
    # 1. Compute CM miss/false-alarm rates for thresholds
    cm_scores = np.concatenate([bonafide_score_cm, spoof_score_cm])
    labels = np.concatenate([np.ones_like(bonafide_score_cm), np.zeros_like(spoof_score_cm)])
    sorted_idx = np.argsort(cm_scores)[::-1]
    sorted_labels = labels[sorted_idx]

    tar = np.sum(sorted_labels)
    non = len(sorted_labels) - tar

    cm_miss = np.cumsum(sorted_labels == 1) / tar
    cm_fa = np.cumsum(sorted_labels == 0) / non

    # 2. Compute t-DCF per threshold
    Cmiss, Cfa, Cfa_spoof = cost_model['Cmiss'], cost_model['Cfa'], cost_model['Cfa_spoof']
    Ptar, Pnon, Pspoof = cost_model['Ptar'], cost_model['Pnon'], cost_model['Pspoof']

    tDCF = (Cmiss * Ptar * Pmiss_asv * (1 - cm_miss) +
            Cfa * Pnon * Pfa_asv * cm_fa +
            Cfa_spoof * Pspoof * Pfa_spoof_asv * (1 - cm_miss)) / (
            Cmiss * Ptar * Pmiss_asv + Cfa * Pnon * Pfa_asv)

    tDCF_norm = tDCF / np.min(tDCF)
    thresholds = cm_scores[sorted_idx]

    return tDCF_norm, thresholds


# ============================================================
# TRAIN + VALIDATE + TEST LOOP
# ============================================================

sys_cfg = SysConfig()
exp_cfg = ExpConfig()

train_ds = ASVspoofFolderDataset(sys_cfg.path_train, exp_cfg.sample_rate, exp_cfg.train_duration_sec)
val_ds   = ASVspoofFolderDataset(sys_cfg.path_dev, exp_cfg.sample_rate, exp_cfg.test_duration_sec)
test_ds  = ASVspoofFolderDataset(sys_cfg.path_test, exp_cfg.sample_rate, exp_cfg.test_duration_sec)

train_loader = DataLoader(train_ds, batch_size=exp_cfg.batch_size, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_ds, batch_size=exp_cfg.batch_size, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_ds, batch_size=exp_cfg.batch_size, shuffle=False, num_workers=0)

# --- your modules ---
pre = PreEmphasis(exp_cfg.pre_emphasis).to(DEVICE)
model = Rawformer_S(device=DEVICE, transformer_hidden=exp_cfg.transformer_hidden, sample_rate=exp_cfg.sample_rate).to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=exp_cfg.lr)
criterion = nn.BCELoss()

best_val_eer = 1.0  # initialize high value

print("üöÄ Starting training...\n")

# Reference ASV parameters (official ASVspoof setup)
Pfa_asv = 0.0005
Pmiss_asv = 0.05
Pmiss_spoof_asv = 0.95
Pfa_spoof_asv = 1.0 - Pmiss_spoof_asv
cost_model = {
    'Ptar': 0.9801,
    'Pnon': 0.0099,
    'Pspoof': 0.0100,
    'Cmiss': 1,
    'Cfa': 10,
    'Cfa_spoof': 10
}

for epoch in range(1, exp_cfg.epochs + 1):
    # === TRAIN ===
    model.train()
    total_loss, total_samples = 0.0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{exp_cfg.epochs} [Train]", leave=True)

    for wav, label in pbar:
        wav, label = wav.to(DEVICE), label.to(DEVICE)
        wav = pre(wav)

        opt.zero_grad()
        pred = model(wav).squeeze(-1)
        loss = criterion(pred, label)
        loss.backward()
        opt.step()

        bs = wav.size(0)
        total_loss += loss.item() * bs
        total_samples += bs
        pbar.set_postfix(loss=f"{total_loss / total_samples:.4f}")

    avg_train_loss = total_loss / total_samples

    # === VALIDATE ===
    model.eval()
    val_loss, val_samples = 0.0, 0
    all_scores, all_labels = [], []

    with torch.no_grad():
        pbar = tqdm(val_loader, desc=f"Epoch {epoch}/{exp_cfg.epochs} [Val]", leave=True)
        for wav, label in pbar:
            wav, label = wav.to(DEVICE), label.to(DEVICE)
            wav = pre(wav)
            pred = model(wav).squeeze(-1)
            loss = criterion(pred, label)

            bs = wav.size(0)
            val_loss += loss.item() * bs
            val_samples += bs

            all_scores.extend(pred.cpu().numpy())
            all_labels.extend(label.cpu().numpy())

    avg_val_loss = val_loss / val_samples
    eer = calculate_EER(all_labels, all_scores)

    # --- Compute t-DCF ---
    bona_cm = np.array(all_scores)[np.array(all_labels) == 1]
    spoof_cm = np.array(all_scores)[np.array(all_labels) == 0]
    tDCF_curve, thr = compute_tDCF(bona_cm, spoof_cm, Pfa_asv, Pmiss_asv, Pfa_spoof_asv, cost_model)
    min_tDCF = np.min(tDCF_curve)

    print(f"üßæ Epoch {epoch} Summary:")
    print(f"   Train Loss: {avg_train_loss:.4f}")
    print(f"   Val Loss:   {avg_val_loss:.4f}")
    print(f"   Val EER:    {eer * 100:.2f}%")
    print(f"   min-tDCF:   {min_tDCF:.4f}")

    # === SAVE BEST MODEL ===
    if eer < best_val_eer:
        best_val_eer = eer
        torch.save(model, save_path)
        print(f"üíæ Saved new best model (EER={eer*100:.2f}%) to {save_path}")

    print("-" * 60)


# ============================================================
# FINAL TEST PHASE
# ============================================================

print("\n" + "=" * 60)
print("üèÅ Starting final testing...")
print("=" * 60)

# Load best model
best_model = torch.load(save_path, map_location=DEVICE)
best_model.eval()

test_scores, test_labels = [], []

with torch.no_grad():
    pbar = tqdm(test_loader, desc="Testing", leave=True)
    for wav, label in pbar:
        wav, label = wav.to(DEVICE), label.to(DEVICE)
        wav = pre(wav)
        pred = best_model(wav).squeeze(-1)
        test_scores.extend(pred.cpu().numpy())
        test_labels.extend(label.cpu().numpy())

test_eer = calculate_EER(test_labels, test_scores)

bona_cm = np.array(test_scores)[np.array(test_labels) == 1]
spoof_cm = np.array(test_scores)[np.array(test_labels) == 0]
tDCF_curve, thr = compute_tDCF(bona_cm, spoof_cm, Pfa_asv, Pmiss_asv, Pfa_spoof_asv, cost_model)
min_tDCF = np.min(tDCF_curve)

print(f"üéØ Final Test EER:  {test_eer * 100:.2f}%")
print(f"üìä Final min-tDCF: {min_tDCF:.4f}")

üìÅ Loaded 25380 files from G:\INTERSPEECH_26\LA\ASV19\train
üìÅ Loaded 24844 files from G:\INTERSPEECH_26\LA\ASV19\dev
üìÅ Loaded 24844 files from G:\INTERSPEECH_26\LA\ASV19\dev
üöÄ Starting training...



Epoch 1/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 1/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 1 Summary:
   Train Loss: 0.2490
   Val Loss:   0.3054
   Val EER:    17.46%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=17.46%) to C:\Users\Admin\Desktop\Test Folder Arth Shah\LearnableGammatone_best_cm_model.pth
------------------------------------------------------------


Epoch 2/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 2/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 2 Summary:
   Train Loss: 0.0938
   Val Loss:   0.0867
   Val EER:    5.22%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=5.22%) to C:\Users\Admin\Desktop\Test Folder Arth Shah\LearnableGammatone_best_cm_model.pth
------------------------------------------------------------


Epoch 3/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 3/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 3 Summary:
   Train Loss: 0.0630
   Val Loss:   0.1821
   Val EER:    5.69%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 4/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 4/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 4 Summary:
   Train Loss: 0.0541
   Val Loss:   0.3511
   Val EER:    4.75%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=4.75%) to C:\Users\Admin\Desktop\Test Folder Arth Shah\LearnableGammatone_best_cm_model.pth
------------------------------------------------------------


Epoch 5/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 5/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 5 Summary:
   Train Loss: 0.0461
   Val Loss:   0.2787
   Val EER:    8.66%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 6/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 6/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 6 Summary:
   Train Loss: 0.0460
   Val Loss:   0.0957
   Val EER:    3.26%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=3.26%) to C:\Users\Admin\Desktop\Test Folder Arth Shah\LearnableGammatone_best_cm_model.pth
------------------------------------------------------------


Epoch 7/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 7/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 7 Summary:
   Train Loss: 0.0412
   Val Loss:   0.6872
   Val EER:    16.37%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 8/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 8/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 8 Summary:
   Train Loss: 0.0432
   Val Loss:   0.0941
   Val EER:    3.06%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=3.06%) to C:\Users\Admin\Desktop\Test Folder Arth Shah\LearnableGammatone_best_cm_model.pth
------------------------------------------------------------


Epoch 9/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 9/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 9 Summary:
   Train Loss: 0.0423
   Val Loss:   0.0358
   Val EER:    2.33%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=2.33%) to C:\Users\Admin\Desktop\Test Folder Arth Shah\LearnableGammatone_best_cm_model.pth
------------------------------------------------------------


Epoch 10/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 10/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 10 Summary:
   Train Loss: 0.0316
   Val Loss:   0.0690
   Val EER:    2.83%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 11/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 11/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 11 Summary:
   Train Loss: 0.0336
   Val Loss:   0.1027
   Val EER:    2.67%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 12/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 12/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 12 Summary:
   Train Loss: 0.0267
   Val Loss:   0.0523
   Val EER:    2.86%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 13/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 13/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 13 Summary:
   Train Loss: 0.0427
   Val Loss:   0.0448
   Val EER:    1.73%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=1.73%) to C:\Users\Admin\Desktop\Test Folder Arth Shah\LearnableGammatone_best_cm_model.pth
------------------------------------------------------------


Epoch 14/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 14/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 14 Summary:
   Train Loss: 0.0308
   Val Loss:   0.0405
   Val EER:    1.88%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 15/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 15/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 15 Summary:
   Train Loss: 0.0341
   Val Loss:   0.0324
   Val EER:    2.27%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 16/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 16/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 16 Summary:
   Train Loss: 0.0329
   Val Loss:   0.0376
   Val EER:    2.12%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 17/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 17/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 17 Summary:
   Train Loss: 0.0301
   Val Loss:   0.0435
   Val EER:    2.39%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 18/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 18/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 18 Summary:
   Train Loss: 0.0461
   Val Loss:   0.0426
   Val EER:    2.16%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 19/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 19/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 19 Summary:
   Train Loss: 0.0313
   Val Loss:   0.0775
   Val EER:    3.31%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 20/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 20/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 20 Summary:
   Train Loss: 0.0493
   Val Loss:   0.0471
   Val EER:    2.59%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 21/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 21/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 21 Summary:
   Train Loss: 0.0326
   Val Loss:   0.0411
   Val EER:    2.83%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 22/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 22/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 22 Summary:
   Train Loss: 0.0285
   Val Loss:   0.0494
   Val EER:    2.47%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 23/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 23/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 23 Summary:
   Train Loss: 0.0312
   Val Loss:   0.0303
   Val EER:    2.24%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 24/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 24/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 24 Summary:
   Train Loss: 0.0334
   Val Loss:   0.0334
   Val EER:    1.88%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 25/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 25/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 25 Summary:
   Train Loss: 0.0297
   Val Loss:   0.0402
   Val EER:    2.51%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 26/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 26/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 26 Summary:
   Train Loss: 0.0302
   Val Loss:   0.0337
   Val EER:    2.12%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 27/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 27/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 27 Summary:
   Train Loss: 0.0332
   Val Loss:   0.1361
   Val EER:    2.75%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 28/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 28/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 28 Summary:
   Train Loss: 0.0250
   Val Loss:   0.0272
   Val EER:    1.85%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 29/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 29/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 29 Summary:
   Train Loss: 0.0367
   Val Loss:   0.0352
   Val EER:    2.12%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 30/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 30/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 30 Summary:
   Train Loss: 0.0254
   Val Loss:   0.0261
   Val EER:    1.83%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 31/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 31/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 31 Summary:
   Train Loss: 0.0240
   Val Loss:   0.0464
   Val EER:    2.79%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 32/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 32/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 32 Summary:
   Train Loss: 0.0310
   Val Loss:   0.1346
   Val EER:    2.39%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 33/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 33/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 33 Summary:
   Train Loss: 0.0355
   Val Loss:   0.0317
   Val EER:    2.28%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 34/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 34/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 34 Summary:
   Train Loss: 0.0233
   Val Loss:   0.0299
   Val EER:    1.79%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 35/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 35/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 35 Summary:
   Train Loss: 0.0218
   Val Loss:   0.0335
   Val EER:    2.16%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 36/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 36/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 36 Summary:
   Train Loss: 0.0298
   Val Loss:   0.1101
   Val EER:    2.79%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 37/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 37/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 37 Summary:
   Train Loss: 0.0557
   Val Loss:   0.0572
   Val EER:    2.91%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 38/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 38/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 38 Summary:
   Train Loss: 0.0327
   Val Loss:   0.0354
   Val EER:    2.16%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 39/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 39/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 39 Summary:
   Train Loss: 0.0222
   Val Loss:   0.0367
   Val EER:    1.73%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 40/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 40/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 40 Summary:
   Train Loss: 0.0274
   Val Loss:   0.0332
   Val EER:    1.83%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 41/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 41/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 41 Summary:
   Train Loss: 0.0331
   Val Loss:   0.0385
   Val EER:    2.22%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 42/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 42/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 42 Summary:
   Train Loss: 0.0265
   Val Loss:   0.0464
   Val EER:    3.41%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 43/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 43/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 43 Summary:
   Train Loss: 0.0316
   Val Loss:   0.0292
   Val EER:    1.93%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 44/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 44/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 44 Summary:
   Train Loss: 0.0308
   Val Loss:   0.0480
   Val EER:    3.22%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 45/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 45/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 45 Summary:
   Train Loss: 0.0287
   Val Loss:   0.0362
   Val EER:    2.12%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 46/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 46/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 46 Summary:
   Train Loss: 0.0321
   Val Loss:   0.0724
   Val EER:    4.11%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 47/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 47/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 47 Summary:
   Train Loss: 0.0295
   Val Loss:   0.0273
   Val EER:    1.73%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 48/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 48/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 48 Summary:
   Train Loss: 0.0194
   Val Loss:   0.0321
   Val EER:    2.00%
   min-tDCF:   1.0000
------------------------------------------------------------


Epoch 49/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 49/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 49 Summary:
   Train Loss: 0.0229
   Val Loss:   0.0361
   Val EER:    1.60%
   min-tDCF:   1.0000
üíæ Saved new best model (EER=1.60%) to C:\Users\Admin\Desktop\Test Folder Arth Shah\LearnableGammatone_best_cm_model.pth
------------------------------------------------------------


Epoch 50/50 [Train]:   0%|          | 0/794 [00:00<?, ?it/s]

Epoch 50/50 [Val]:   0%|          | 0/777 [00:00<?, ?it/s]

üßæ Epoch 50 Summary:
   Train Loss: 0.0239
   Val Loss:   0.0303
   Val EER:    1.81%
   min-tDCF:   1.0000
------------------------------------------------------------

üèÅ Starting final testing...


  best_model = torch.load(save_path, map_location=DEVICE)


Testing:   0%|          | 0/777 [00:00<?, ?it/s]

üéØ Final Test EER:  1.52%
üìä Final min-tDCF: 1.0000


In [17]:
# --------------------------------------------------------------
#  Model size & FLOPs (place this right after model creation)
# --------------------------------------------------------------
import torch
from torchinfo import summary
from fvcore.nn import FlopCountAnalysis, parameter_count

# --------------------------------------------------------------
# 1. Parameter count (trainable + non-trainable) + size in MiB
# --------------------------------------------------------------
def print_model_params(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total     = sum(p.numel() for p in model.parameters())
    size_mb   = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)

    print("\n" + "="*60)
    print("MODEL PARAMETER SUMMARY")
    print("="*60)
    print(f"{'Trainable params':<25}: {trainable:,}")
    print(f"{'Total params'    :<25}: {total:,}")
    print(f"{'Model size (MiB)':<25}: {size_mb:.2f}")
    print("="*60 + "\n")

print_model_params(model)

# --------------------------------------------------------------
# 2. FLOPs / MACs
# --------------------------------------------------------------
# We need a dummy waveform that matches the shape expected by the model.
#   - Rawformer_S expects raw audio: (batch, time)
#   - Use the maximum length defined in the config (or a typical 4-second clip)
max_len_sec = getattr(exp_cfg, "max_len_sec", 4.0)          # fallback 4 s
max_samples = int(exp_cfg.sample_rate * max_len_sec)

dummy_wav = torch.randn(1, max_samples, device=DEVICE)     # (B, T)

# Apply pre-emphasis if it is used in training/validation
if pre is not None:
    dummy_wav = pre(dummy_wav)

# ---- fvcore (very accurate) ----
flops = FlopCountAnalysis(model, dummy_wav)
macs  = flops.total()                # MACs = multiply-adds
flops_2 = macs * 2                   # FLOPs = 2 √ó MACs (standard convention)

print("\n" + "="*60)
print("FLOPs / MACs (per forward pass)")
print("="*60)
print(f"{'Input shape'   :<25}: {list(dummy_wav.shape)}")
print(f"{'MACs'          :<25}: {macs/1e9:.3f} G")
print(f"{'FLOPs'         :<25}: {flops_2/1e9:.3f} G")
print("="*60 + "\n")

# ---- torchinfo (nice table, optional) ----
print("Detailed layer-wise breakdown (torchinfo):")
summary(model,
        input_data=dummy_wav,
        col_names=["input_size", "output_size", "num_params", "mult_adds"],
        depth=4,
        verbose=0)


MODEL PARAMETER SUMMARY
Trainable params         : 345,342
Total params             : 345,342
Model size (MiB)         : 1.32



Unsupported operator aten::exp encountered 73 time(s)
Unsupported operator aten::sub encountered 6 time(s)
Unsupported operator aten::abs encountered 211 time(s)
Unsupported operator aten::div encountered 286 time(s)
Unsupported operator aten::pow encountered 70 time(s)
Unsupported operator aten::mul encountered 497 time(s)
Unsupported operator aten::cos encountered 70 time(s)
Unsupported operator aten::sum encountered 70 time(s)
Unsupported operator aten::add encountered 87 time(s)
Unsupported operator aten::max_pool2d encountered 5 time(s)
Unsupported operator aten::selu_ encountered 8 time(s)
Unsupported operator aten::softmax encountered 3 time(s)
Unsupported operator aten::mean encountered 4 time(s)
Unsupported operator aten::var encountered 4 time(s)
Unsupported operator aten::sqrt encountered 4 time(s)
Unsupported operator aten::sigmoid encountered 1 time(s)



FLOPs / MACs (per forward pass)
Input shape              : [1, 64000]
MACs                     : 6.197 G
FLOPs                    : 12.394 G

Detailed layer-wise breakdown (torchinfo):


Layer (type:depth-idx)                                       Input Shape               Output Shape              Param #                   Mult-Adds
Rawformer_S                                                  [1, 64000]                [1]                       --                        --
‚îú‚îÄFrontend_S: 1-1                                            [1, 64000]                [1, 64, 23, 16]           --                        --
‚îÇ    ‚îî‚îÄLearnableGammatone: 2-1                               [1, 1, 64000]             [1, 70, 64000]            210                       --
‚îÇ    ‚îî‚îÄBatchNorm2d: 2-2                                      [1, 1, 23, 21333]         [1, 1, 23, 21333]         2                         2
‚îÇ    ‚îî‚îÄSELU: 2-3                                             [1, 1, 23, 21333]         [1, 1, 23, 21333]         --                        --
‚îÇ    ‚îî‚îÄSequential: 2-4                                       [1, 1, 23, 21333]         [1, 64, 23, 16]           -

In [18]:
seconds = max_samples / exp_cfg.sample_rate
print(f"GFLOPs per second of audio : {flops_2/1e9/seconds:.3f}")

GFLOPs per second of audio : 3.099
