# 🎯 AASIST Model Evaluation on ASVspoof 2019 LA

Evaluates the AASIST-inspired anti-spoofing model trained on MLAAD.

**Run Cell 1, then RESTART RUNTIME, then continue from Cell 2.**

In [None]:
# CELL 1: Install dependencies
# Uninstall any `datasets` package already present (stderr is discarded, so a
# missing package produces no noise), then install the pinned `datasets`
# release together with `soundfile` and `librosa`.
!pip uninstall datasets -y -q 2>/dev/null
!pip install datasets==2.14.7 soundfile librosa -q
# The runtime has to be restarted after this cell before running Cell 2.
print("âœ… Done! RESTART RUNTIME: Runtime â†’ Restart runtime")

In [None]:
# CELL 2: Upload model weights
from google.colab import files

print("ðŸ“¤ Upload aasist_best.pth:")
# files.upload() opens the browser upload widget; the returned mapping is
# keyed by the names of the uploaded files.
uploaded = files.upload()
if not uploaded:
    # A cancelled upload leaves the mapping empty.  Fail here with a clear
    # message instead of an IndexError on the lookup below.
    raise RuntimeError(
        "No file was uploaded; re-run this cell and select aasist_best.pth."
    )
# Only the first uploaded file is used as the checkpoint path.
MODEL_PATH = next(iter(uploaded))
print(f"âœ… Using: {MODEL_PATH}")

In [None]:
# CELL 3: Imports and config
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio.transforms as T
import numpy as np
from sklearn.metrics import roc_curve
from datasets import load_dataset
from tqdm.auto import tqdm
import math
import gc
import warnings
warnings.filterwarnings('ignore')

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Sample rate (Hz) that every clip is brought to before inference.
SAMPLE_RATE = 16000
# Fixed model input length in samples: 4 seconds at SAMPLE_RATE.
MAX_LEN = SAMPLE_RATE * 4
# Hyper-parameters passed to AASISTEncoder.
# NOTE(review): presumably these must match the configuration the uploaded
# checkpoint was trained with -- confirm against the training notebook.
SINC_CHANNELS = 70
ENCODER_DIM = 128

print(f"Device: {DEVICE}")

In [None]:
# CELL 4: AASIST Model Definition

class SincConv(nn.Module):
    """Learnable bank of windowed-sinc band-pass filters for raw waveforms.

    Every output channel is a band-pass FIR filter described by two learnable
    scalars: a lower cut-off (``low_hz_``) and a bandwidth (``band_hz_``).
    Both are initialised on a mel-spaced grid, and the filters are rebuilt
    from them on every forward pass.

    ``window_`` and ``n_`` are constant helper tensors.  They are registered
    as non-persistent buffers so that they follow ``module.to(...)`` without
    adding any key to ``state_dict()`` (the only entries remain ``low_hz_``
    and ``band_hz_``).

    Args:
        out_channels: Number of band-pass filters.
        kernel_size: Filter length in samples.  An even value is bumped to the
            next odd value, because the kernel is assembled as
            left half + centre tap + mirrored left half.
        sample_rate: Sample rate in Hz used to place the cut-offs.
        min_low_hz: Constant added to every learned low cut-off.
        min_band_hz: Constant added to every learned bandwidth.
    """

    def __init__(self, out_channels=70, kernel_size=251, sample_rate=16000,
                 min_low_hz=50, min_band_hz=50):
        super().__init__()
        # With an even length the assembled kernel would have kernel_size + 1
        # taps and the reshape in forward() would fail.
        if kernel_size % 2 == 0:
            kernel_size += 1

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # Mel-spaced initial band edges between 30 Hz and just below Nyquist.
        low_hz = 30
        high_hz = sample_rate / 2 - (min_low_hz + min_band_hz)
        mel_low = 2595 * np.log10(1 + low_hz / 700)
        mel_high = 2595 * np.log10(1 + high_hz / 700)
        mel_points = np.linspace(mel_low, mel_high, out_channels + 1)
        hz_points = 700 * (10 ** (mel_points / 2595) - 1)

        self.low_hz_ = nn.Parameter(torch.Tensor(hz_points[:-1]).view(-1, 1))
        self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz_points)).view(-1, 1))

        # Left half of a Hamming window (the right half is obtained by
        # mirroring the finished left half of the filter).
        n_lin = torch.linspace(0, (kernel_size / 2) - 1, steps=kernel_size // 2)
        window = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / kernel_size)
        self.register_buffer("window_", window, persistent=False)

        # Time axis of the left half, pre-scaled by 2*pi / sample_rate.
        n = (kernel_size - 1) / 2.0
        time_axis = 2 * math.pi * torch.arange(-n, 0).view(1, -1) / sample_rate
        self.register_buffer("n_", time_axis, persistent=False)

    def forward(self, x):
        """Filter a batch of waveforms.

        Args:
            x: Tensor of shape (batch, 1, time); the filters have a single
                input channel.

        Returns:
            Tensor of shape (batch, out_channels, time).
        """
        low = self.min_low_hz + torch.abs(self.low_hz_)
        high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),
                           self.min_low_hz, self.sample_rate / 2)
        band = (high - low)[:, 0]

        # NOTE(review): the cut-offs are divided by sample_rate here although
        # n_ is already scaled by 1 / sample_rate.  This is kept exactly as
        # is, since the checkpoint was presumably trained with this
        # formulation -- confirm against the training code before changing.
        f_low = low / self.sample_rate
        f_high = high / self.sample_rate

        band_pass_left = ((torch.sin(f_high * self.n_) - torch.sin(f_low * self.n_)) /
                          (self.n_ / 2)) * self.window_
        band_pass_center = 2 * band.view(-1, 1)
        band_pass_right = torch.flip(band_pass_left, dims=[1])

        band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1)
        # Normalise so that the centre tap of every filter equals 1.
        band_pass = band_pass / (2 * band[:, None])
        filters = band_pass.view(self.out_channels, 1, self.kernel_size)

        return F.conv1d(x, filters, stride=1, padding=self.kernel_size // 2)

class SEBlock(nn.Module):
    """Channel gate: rescales each channel by a learned sigmoid weight.

    The input is averaged over its trailing axes (axis 2 for 3-D input,
    axes 2 and 3 otherwise), passed through a two-layer bottleneck MLP, and
    the resulting per-channel weights in (0, 1) multiply the input.

    Args:
        channels: Size of the channel axis (axis 1) of the input.
        reduction: Bottleneck factor; the hidden size is
            ``channels // reduction``.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        hidden = channels // reduction
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)

    def forward(self, x):
        """Return ``x`` scaled channel-wise; the shape is unchanged."""
        is_3d = x.dim() == 3

        # Squeeze: one average per (sample, channel).
        if is_3d:
            pooled = x.mean(dim=2)
        else:
            pooled = x.mean(dim=[2, 3])

        # Excite: bottleneck MLP followed by a sigmoid gate.
        gate = torch.sigmoid(self.fc2(F.relu(self.fc1(pooled))))

        # Broadcast the gate back over the pooled axes.
        gate = gate.unsqueeze(2)
        if is_3d:
            return x * gate
        return x * gate.unsqueeze(3)

class Res2NetBlock(nn.Module):
    """1-D residual block with Res2Net-style hierarchical channel groups.

    A 1x1 convolution expands the input to ``scale`` groups of
    ``out_ch // scale`` channels.  The first group is passed through
    untouched; every later group goes through its own 3-tap convolution, and
    from the third group on the previous group's output is added to the
    input first.  The groups are concatenated, projected by a second 1x1
    convolution, gated by an SEBlock and added to the shortcut branch.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        scale: Number of channel groups.
        stride: Stride of the group convolutions and of the projection
            shortcut.
            NOTE(review): the first group is not strided, so values other
            than 1 look unsupported (the concatenation would see mismatched
            lengths) -- confirm before using.
    """

    def __init__(self, in_ch, out_ch, scale=4, stride=1):
        super().__init__()
        self.scale = scale
        self.width = out_ch // scale
        expanded = self.width * scale
        n_branches = scale - 1

        self.conv1 = nn.Conv1d(in_ch, expanded, 1, bias=False)
        self.bn1 = nn.BatchNorm1d(expanded)
        self.convs = nn.ModuleList([
            nn.Conv1d(self.width, self.width, 3, stride=stride, padding=1, bias=False)
            for _ in range(n_branches)
        ])
        self.bns = nn.ModuleList([
            nn.BatchNorm1d(self.width) for _ in range(n_branches)
        ])
        self.conv3 = nn.Conv1d(expanded, out_ch, 1, bias=False)
        self.bn3 = nn.BatchNorm1d(out_ch)
        self.se = SEBlock(out_ch)

        # A projection shortcut is only needed when the shapes differ.
        needs_projection = stride != 1 or in_ch != out_ch
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm1d(out_ch),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the block to a (batch, in_ch, time) tensor."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        groups = torch.split(hidden, self.width, dim=1)

        # The first group is forwarded unchanged.
        outputs = [groups[0]]
        for idx in range(1, self.scale):
            conv = self.convs[idx - 1]
            norm = self.bns[idx - 1]
            # From the third group on, feed in the previous group's output.
            branch_in = groups[idx] if idx == 1 else groups[idx] + outputs[-1]
            outputs.append(F.relu(norm(conv(branch_in))))

        merged = torch.cat(outputs, dim=1)
        merged = self.se(self.bn3(self.conv3(merged)))
        return F.relu(merged + self.shortcut(x))

class GraphAttentionLayer(nn.Module):
    """Multi-head scaled dot-product self-attention over a set of nodes.

    Every node attends to every node (no mask), so the "graph" is fully
    connected.

    Args:
        in_dim: Feature size of each input node.
        out_dim: Feature size of each output node; split evenly across heads
            (``out_dim // num_heads`` per head).
        num_heads: Number of attention heads.
    """

    def __init__(self, in_dim, out_dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = out_dim // num_heads
        self.q = nn.Linear(in_dim, out_dim)
        self.k = nn.Linear(in_dim, out_dim)
        self.v = nn.Linear(in_dim, out_dim)
        self.out = nn.Linear(out_dim, out_dim)

    def _split_heads(self, projected, batch, nodes):
        """Reshape (batch, nodes, out_dim) to (batch, heads, nodes, head_dim)."""
        per_head = projected.view(batch, nodes, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Attend over the node axis of a (batch, nodes, in_dim) tensor.

        Returns:
            Tensor of shape (batch, nodes, out_dim).
        """
        batch, nodes, _ = x.size()
        queries = self._split_heads(self.q(x), batch, nodes)
        keys = self._split_heads(self.k(x), batch, nodes)
        values = self._split_heads(self.v(x), batch, nodes)

        # Scaled dot-product scores, normalised over the attended nodes.
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = F.softmax(scores, dim=-1)

        # Weighted sum of values, then merge the heads back together.
        mixed = torch.matmul(weights, values)
        merged = mixed.transpose(1, 2).contiguous().view(batch, nodes, -1)
        return self.out(merged)

class AASISTEncoder(nn.Module):
    """Raw-waveform classifier: sinc front-end, Res2Net stack, attention, FC.

    Pipeline: SincConv -> BatchNorm/ReLU -> max-pool -> four Res2NetBlocks
    (with max-pooling between them) -> residual self-attention over the time
    frames -> average pooling over time -> 2-way linear classifier.

    Args:
        sinc_channels: Number of sinc filters in the front-end.
        encoder_dim: Channel width of the last Res2Net block, which is also
            the size of the returned embedding.
    """

    def __init__(self, sinc_channels=70, encoder_dim=128):
        super().__init__()
        self.sinc = SincConv(out_channels=sinc_channels, kernel_size=251)
        self.sinc_bn = nn.BatchNorm1d(sinc_channels)
        self.sinc_pool = nn.MaxPool1d(3)

        # Block, pool, block, pool, block, pool, block: a pooling layer sits
        # between consecutive blocks but not after the last one.
        channel_plan = [sinc_channels, 64, 128, 256, encoder_dim]
        stages = []
        for position, (c_in, c_out) in enumerate(zip(channel_plan, channel_plan[1:])):
            if position > 0:
                stages.append(nn.MaxPool1d(3))
            stages.append(Res2NetBlock(c_in, c_out, scale=4))
        self.res2net = nn.Sequential(*stages)

        self.gat = GraphAttentionLayer(encoder_dim, encoder_dim, num_heads=4)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(encoder_dim, 2)

    def forward(self, x):
        """Encode a batch of waveforms.

        Args:
            x: Waveform tensor that is fed straight into the sinc front-end.

        Returns:
            Tuple ``(embedding, logits)``: the time-pooled feature vector and
            the output of the 2-way linear classifier applied to it.
        """
        features = self.sinc_pool(F.relu(self.sinc_bn(self.sinc(x))))
        features = self.res2net(features)

        # The attention layer expects (batch, frames, channels).
        frames = features.transpose(1, 2)
        frames = frames + self.gat(frames)

        # Back to (batch, channels, frames) for average pooling over time.
        embedding = self.pool(frames.transpose(1, 2)).squeeze(-1)
        logits = self.fc(embedding)
        return embedding, logits

print("âœ… Model defined")

In [None]:
# CELL 5: Load model
# Build the network on the target device, then restore the uploaded weights.
model = AASISTEncoder(sinc_channels=SINC_CHANNELS, encoder_dim=ENCODER_DIM)
model = model.to(DEVICE)
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
)
# Inference mode for BatchNorm (running statistics instead of batch ones).
model.eval()
print("âœ… Model loaded")

In [None]:
# CELL 6: Load ASVspoof dataset
print("ðŸ“¥ Loading ASVspoof 2019 LA eval...")
# streaming=False loads the "eval" split as a regular (non-streaming)
# dataset; the cells below rely on len() and integer indexing.
hf_ds = load_dataset("Bisher/ASVspoof_2019_LA", split="eval", streaming=False)
print(f"âœ… {len(hf_ds)} samples")

In [None]:
# CELL 7: Dataset wrapper for raw waveform
class HFDataset(Dataset):
    """Map-style wrapper turning dataset rows into fixed-length waveforms.

    Each item is a ``(waveform, label)`` pair where ``waveform`` is a float
    tensor of shape (1, MAX_LEN) at SAMPLE_RATE and ``label`` is
    ``int(row['key'])``.

    Args:
        hf_ds: Indexable dataset whose rows provide ``row['audio']['array']``,
            ``row['audio']['sampling_rate']`` and ``row['key']``.
    """

    def __init__(self, hf_ds):
        self.ds = hf_ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        row = self.ds[idx]
        clip = row['audio']
        waveform = torch.FloatTensor(clip['array'])

        # Bring the clip to the model's sample rate when it differs.
        source_rate = clip['sampling_rate']
        if source_rate != SAMPLE_RATE:
            resampler = T.Resample(source_rate, SAMPLE_RATE)
            waveform = resampler(waveform)

        # Crop long clips and zero-pad short ones to exactly MAX_LEN samples.
        n_samples = waveform.shape[0]
        if n_samples > MAX_LEN:
            waveform = waveform[:MAX_LEN]
        else:
            padding = torch.zeros(MAX_LEN - n_samples)
            waveform = torch.cat([waveform, padding])

        # NOTE(review): assumes 'key' encodes 0=bonafide, 1=spoof -- confirm
        # against the dataset's label definition.
        label = int(row['key'])
        # Add the channel axis: (MAX_LEN,) -> (1, MAX_LEN).
        return waveform.unsqueeze(0), label

dataset = HFDataset(hf_ds)
# Smoke test: fetch a single item so that decoding or padding problems show
# up here rather than in the middle of the full evaluation run.
test_audio, test_label = dataset[0]
print(f"âœ… Dataset works! Shape: {test_audio.shape}, Label: {test_label}")

In [None]:
# CELL 8: Run evaluation
def calc_eer(labels, scores):
    """Return the equal error rate (EER) as a fraction in [0, 1].

    The EER is read off the ROC curve at the threshold where the false
    positive rate and the false negative rate are closest; the two rates at
    that point are averaged (no interpolation between thresholds).

    Args:
        labels: Ground-truth labels; ``1`` is treated as the positive class.
        scores: Scores where larger values mean "more likely positive".

    Returns:
        The EER as a float fraction (multiply by 100 for a percentage).

    Raises:
        ValueError: If ``labels`` is empty or holds a single class, in which
            case the ROC curve -- and therefore the EER -- is undefined.
    """
    # Without this check the failure only surfaces later as an opaque
    # "All-NaN slice encountered" error from nanargmin.
    if np.unique(np.asarray(labels)).size < 2:
        raise ValueError(
            "EER is undefined: labels must contain both positive and "
            "negative examples."
        )

    fpr, tpr, _ = roc_curve(labels, scores, pos_label=1)
    fnr = 1 - tpr
    # Operating point where the two error rates cross (closest match).
    idx = np.nanargmin(np.abs(fpr - fnr))
    return (fpr[idx] + fnr[idx]) / 2

print(f"ðŸš€ Evaluating {len(dataset)} samples...")

# Per-utterance ground truth and scores, consumed by the EER computation.
all_labels = []
all_scores = []
BATCH = 32
# Free cached memory roughly every 5000 samples.  The interval is counted in
# batches: `start` advances in steps of BATCH, so testing `start % 5000 == 0`
# would only fire every 20000 samples.
CLEANUP_EVERY_N_BATCHES = max(1, 5000 // BATCH)

with torch.no_grad():
    for batch_idx, start in enumerate(tqdm(range(0, len(dataset), BATCH))):
        stop = min(start + BATCH, len(dataset))
        batch_audio = []
        batch_labels = []
        for i in range(start, stop):
            audio, label = dataset[i]
            batch_audio.append(audio)
            batch_labels.append(label)

        batch_tensor = torch.stack(batch_audio).to(DEVICE)
        _, logits = model(batch_tensor)
        # Score = softmax probability of class index 1, which is the positive
        # label used by calc_eer.
        probs = F.softmax(logits, dim=1)[:, 1].cpu().numpy()

        all_labels.extend(batch_labels)
        all_scores.extend(probs)

        if batch_idx % CLEANUP_EVERY_N_BATCHES == 0:
            gc.collect()
            torch.cuda.empty_cache()

print(f"\nâœ… Processed {len(all_scores)} samples")
print(f"Bonafide: {all_labels.count(0)}, Spoof: {all_labels.count(1)}")

In [None]:
# CELL 9: Calculate EER
# all_labels / all_scores are the per-utterance lists filled by the
# evaluation loop in Cell 8.
eer = calc_eer(all_labels, all_scores)

print("\n" + "="*50)
# calc_eer returns a fraction; scale by 100 to report a percentage.
print(f"ðŸŽ¯ ASVspoof 2019 LA Eval EER: {eer * 100:.2f}%")
print("="*50)