In [1]:
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import librosa
from tqdm import tqdm
from sklearn.metrics import roc_curve


In [2]:
# --- Runtime configuration --------------------------------------------------
# Everything a reader might want to tune lives in this cell.

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# Seed the RNGs used by this notebook (weight init, DataLoader shuffling) so
# that a Restart & Run All starts from the same random state every time.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

SAMPLE_RATE = 16000   # sampling rate passed to librosa.load
SEGMENT_LEN = 32000   # every utterance is padded/truncated to this many samples
FRAME_SIZE = 160      # samples per frame fed to the sensing matrix

BATCH_SIZE = 1
LR = 1e-4
EPOCHS = 30

# Relative to the current working directory.
os.makedirs("checkpoints", exist_ok=True)
DEVICE

Using device: cuda


'cuda'

In [3]:
def circular_shift(arr, shift):
    """Return a copy of ``arr`` rotated by ``shift`` positions (wraps around).

    Thin wrapper over ``np.roll``; positive ``shift`` moves elements towards
    higher indices.
    """
    rotated = np.roll(a=arr, shift=shift)
    return rotated

def least_number(a, b):
    """Reduce the ratio ``a : b`` to lowest terms.

    Returns:
        tuple[int, int]: ``(a // g, b // g)`` where ``g = gcd(a, b)``.

    Raises:
        ZeroDivisionError: if both ``a`` and ``b`` are zero (gcd is 0).
    """
    common = math.gcd(a, b)
    reduced_a = a // common
    reduced_b = b // common
    return reduced_a, reduced_b
def generate_matrix(N, M, P, Q):
    """Build an ``M x N`` binary sensing matrix from a circularly shifted row.

    The first row holds ``M`` ones followed by ``N - M`` zeros; row ``k`` is
    that row rotated right by ``k * M`` positions (with wrap-around).

    Args:
        N: number of columns (length of the input frame).
        M: number of rows (measurements); also the run length of ones and the
            per-row shift. Must satisfy ``0 <= M <= N``.
        P, Q: accepted for interface compatibility; not used by this
            construction.

    Returns:
        np.ndarray: float array of shape ``(M, N)``.

    Raises:
        ValueError: if ``M`` is negative or larger than ``N``.

    Note:
        When ``M == N`` the base row is all ones, so every row (and the whole
        matrix) is all ones.
    """
    if not 0 <= M <= N:
        raise ValueError(f"M must satisfy 0 <= M <= N, got M={M}, N={N}")

    base_row = np.concatenate([np.ones(M), np.zeros(N - M)])

    # Rolling k times by M is the same as a single roll by k * M.
    matrix = np.empty((M, N))
    for k in range(M):
        matrix[k] = np.roll(base_row, k * M)
    return matrix
def apply_cs(audio, sensing_matrix, frame_size):
    """Project consecutive frames of ``audio`` through ``sensing_matrix``.

    The signal is cut into non-overlapping frames of ``frame_size`` samples
    (a trailing partial frame is dropped), each frame is multiplied by
    ``sensing_matrix``, and the per-frame results are concatenated.

    Args:
        audio: 1-D sequence of samples.
        sensing_matrix: array of shape ``(m, frame_size)``.
        frame_size: number of samples per frame.

    Returns:
        np.ndarray: 1-D array of length ``n_frames * m``. Empty when ``audio``
        is shorter than one frame.
    """
    audio = np.asarray(audio)
    sensing_matrix = np.asarray(sensing_matrix)

    n_frames = len(audio) // frame_size
    if n_frames == 0:
        # Nothing to project; np.concatenate on an empty list would raise.
        return np.zeros(0, dtype=np.result_type(sensing_matrix, audio))

    # One matmul over all frames: (n_frames, frame_size) @ (frame_size, m).
    frames = audio[: n_frames * frame_size].reshape(n_frames, frame_size)
    return (frames @ sensing_matrix.T).reshape(-1)


In [4]:
class ASVspoofCSDataset(Dataset):
    """Utterance-level dataset that optionally applies a sensing matrix.

    Each item is ``(waveform, label)`` where the waveform is loaded from
    ``<audio_dir>/<utt>.flac``, padded/truncated to ``SEGMENT_LEN`` samples and
    (if ``apply_cs`` is true) passed frame-by-frame through the sensing matrix.
    The label is ``1`` when the last protocol column is ``"spoof"`` and ``0``
    otherwise.

    Args:
        audio_dir: directory containing the ``.flac`` files.
        protocol_path: whitespace-separated protocol file; the second column is
            the utterance id and the last column is the class key.
        apply_cs: whether to apply the sensing matrix in ``__getitem__``.
    """

    def __init__(self, audio_dir, protocol_path, apply_cs=True):
        self.audio_dir = audio_dir
        self.apply_cs = apply_cs

        # NOTE(review): with M == N, generate_matrix() returns an all-ones
        # matrix, so every frame is replaced by its sum repeated M times.
        # Confirm that this (rather than M < N) is the intended setup.
        M = N = FRAME_SIZE
        P, Q = least_number(M, N)
        self.sensing_matrix = generate_matrix(N, M, P, Q)

        # Utterance ids (file stems) that actually exist on disk.
        available_files = {
            os.path.splitext(name)[0]
            for name in os.listdir(audio_dir)
            if name.endswith(".flac")
        }

        self.items, skipped = self._read_protocol(protocol_path, available_files)

        print(f"[Dataset] Loaded {len(self.items)} | Skipped {skipped}")

    @staticmethod
    def _read_protocol(protocol_path, available_files):
        """Parse the protocol file into ``([(utt, label), ...], n_skipped)``.

        Blank lines are ignored. Lines with fewer than two columns, and
        utterances missing from ``available_files``, are counted as skipped.
        """
        items = []
        skipped = 0

        with open(protocol_path) as f:
            for line in f:
                parts = line.split()
                if not parts:
                    continue  # blank line (e.g. trailing newline at EOF)
                if len(parts) < 2:
                    skipped += 1  # malformed: no utterance-id column
                    continue

                utt = parts[1]  # utterance id is the SECOND column
                label = 1 if parts[-1] == "spoof" else 0

                if utt not in available_files:
                    skipped += 1
                    continue

                items.append((utt, label))

        return items, skipped

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        utt, label = self.items[idx]
        audio_path = os.path.join(self.audio_dir, utt + ".flac")

        audio, _ = librosa.load(audio_path, sr=SAMPLE_RATE)

        # Force a fixed length: zero-pad short clips, truncate long ones.
        if len(audio) < SEGMENT_LEN:
            audio = np.pad(audio, (0, SEGMENT_LEN - len(audio)))
        else:
            audio = audio[:SEGMENT_LEN]

        if self.apply_cs:
            # Inside this method `apply_cs` is the module-level function;
            # the constructor flag lives on `self.apply_cs`.
            audio = apply_cs(audio, self.sensing_matrix, FRAME_SIZE)

        return torch.tensor(audio, dtype=torch.float32), torch.tensor(label)


In [5]:
class GraphAttention(nn.Module):
    """Dense attention over all pairs of positions in a sequence.

    Inputs are projected with a linear layer; a score is computed for every
    ordered pair of positions from the concatenated projected features, the
    scores are softmax-normalised over the second position, and the output is
    the attention-weighted sum of projected features.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)
        self.att = nn.Linear(2 * out_dim, 1)

    def forward(self, x):
        """Map ``x`` of shape (B, T, in_dim) to (B, T, out_dim)."""
        h = self.fc(x)
        batch, n_nodes, feat = h.shape
        pair_shape = (batch, n_nodes, n_nodes, feat)

        # Every (i, j) pairing; builds a (B, T, T, 2C) tensor.
        src = h[:, :, None, :].expand(pair_shape)
        dst = h[:, None, :, :].expand(pair_shape)
        pair_features = torch.cat((src, dst), dim=-1)

        pair_scores = F.leaky_relu(self.att(pair_features).squeeze(-1))
        weights = pair_scores.softmax(dim=-1)
        return weights @ h


In [6]:
class STBlock(nn.Module):
    """Conv1d -> BatchNorm -> LeakyReLU -> MaxPool(4), then temporal self-attention.

    Input and output are channel-first: (B, in_ch, T) -> (B, out_ch, T // 4).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        conv_layers = [
            nn.Conv1d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm1d(out_ch),
            nn.LeakyReLU(0.3),
            # Extra downsampling BEFORE attention.
            nn.MaxPool1d(4),
        ]
        self.conv = nn.Sequential(*conv_layers)

        self.attn = TemporalSelfAttention(out_ch, heads=2)

    def forward(self, x):
        features = self.conv(x)                 # (B, C, T_reduced)
        sequence = features.transpose(1, 2)     # (B, T_reduced, C)
        attended = self.attn(sequence)
        return attended.transpose(1, 2)         # back to (B, C, T_reduced)


class TemporalSelfAttention(nn.Module):
    """Multi-head self-attention along the time axis of a (B, T, C) tensor."""

    def __init__(self, dim, heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)

    def forward(self, x):
        """Attend ``x`` (B, T, C) to itself; the attention weights are discarded."""
        attended, _weights = self.attn(query=x, key=x, value=x)
        return attended


In [7]:
class AASIST(nn.Module):
    """Raw-waveform classifier: conv front-end, three STBlocks, GRU, 2-way head.

    ``forward`` takes a batch of 1-D signals of shape (B, T) and returns
    logits of shape (B, 2).
    """

    def __init__(self):
        super().__init__()

        frontend_layers = [
            nn.Conv1d(1, 64, 251, padding=125),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.3),
            nn.MaxPool1d(3),
        ]
        self.frontend = nn.Sequential(*frontend_layers)

        self.block1 = STBlock(64, 128)
        self.block2 = STBlock(128, 256)
        self.block3 = STBlock(256, 512)

        self.gru = nn.GRU(512, 512, batch_first=True)
        self.fc1 = nn.Linear(512, 512)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, x):
        # Add the channel axis expected by Conv1d: (B, T) -> (B, 1, T).
        features = self.frontend(x.unsqueeze(1))
        for block in (self.block1, self.block2, self.block3):
            features = block(features)

        # GRU is batch-first and wants (B, T, C).
        sequence, _ = self.gru(features.permute(0, 2, 1))

        # Summarise the utterance with the last GRU step, L2-normalised.
        last_step = sequence[:, -1]
        embedding = F.normalize(self.fc1(last_step), dim=1)
        return self.fc2(embedding)


In [8]:
def compute_eer(scores, labels):
    """Approximate the equal error rate, in percent.

    Uses the ROC operating point where ``|FNR - FPR|`` is smallest and returns
    the false-positive rate at that point times 100. Label ``1`` is treated as
    the positive class.
    """
    false_pos, true_pos, _thresholds = roc_curve(labels, scores, pos_label=1)
    false_neg = 1 - true_pos
    gap = np.abs(false_neg - false_pos)
    crossover = np.nanargmin(gap)
    return 100 * false_pos[crossover]
def compute_min_tDCF(scores, labels):
    """Minimum of ``Cmiss*Ptar*FNR + Cfa*(1-Ptar)*FPR`` over ROC thresholds.

    Uses fixed ``Ptar = 0.01`` and unit costs, with label ``1`` as the positive
    class. The value is not normalised.

    NOTE(review): this is a plain two-class detection cost computed from these
    scores alone; confirm whether it is meant to stand in for the ASVspoof
    tandem DCF.
    """
    Ptar = 0.01
    Cmiss = 1
    Cfa = 1
    fpr, tpr, _thresholds = roc_curve(labels, scores, pos_label=1)
    miss_cost = Cmiss * Ptar * (1 - tpr)
    false_alarm_cost = Cfa * (1 - Ptar) * fpr
    return np.min(miss_cost + false_alarm_cost)


In [9]:
def train_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The model is switched to training mode; each batch is moved to ``DEVICE``
    before the forward/backward step.
    """
    model.train()
    running_loss = 0
    for inputs, targets in tqdm(loader):
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
    return running_loss / len(loader)
@torch.no_grad()
def evaluate(model, loader):
    """Score every item in ``loader`` without tracking gradients.

    Returns:
        tuple[np.ndarray, np.ndarray]: softmax probability of class index 1
        for each item, and the corresponding labels.
    """
    model.eval()
    all_scores = []
    all_labels = []
    for inputs, targets in loader:
        logits = model(inputs.to(DEVICE))
        class1_prob = torch.softmax(logits, dim=1)[:, 1]
        all_scores.extend(class1_prob.cpu().numpy())
        all_labels.extend(targets.numpy())
    return np.array(all_scores), np.array(all_labels)


In [10]:
# Dataset locations: audio directory and protocol file for each split.
TRAIN_AUDIO_DIR = "/kaggle/input/la-dataset/LA/ASVspoof2019_LA_train/flac"
TRAIN_PROTOCOL = "/kaggle/input/la-dataset/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt"
DEV_AUDIO_DIR = "/kaggle/input/la-dataset/LA/ASVspoof2019_LA_dev/flac"
DEV_PROTOCOL = "/kaggle/input/la-dataset/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt"

train_set = ASVspoofCSDataset(TRAIN_AUDIO_DIR, TRAIN_PROTOCOL, apply_cs=True)
dev_set = ASVspoofCSDataset(DEV_AUDIO_DIR, DEV_PROTOCOL, apply_cs=True)

# Only the training split is shuffled; everything else is shared.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
dev_loader = DataLoader(dev_set, shuffle=False, **loader_kwargs)

model = AASIST().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Best dev EER seen so far (percent); updated by the training loop.
best_eer = 100


[Dataset] Loaded 25380 | Skipped 0
[Dataset] Loaded 24844 | Skipped 0


In [11]:
# Where the best-so-far checkpoint is written. Create its directory here so
# the save below cannot fail because the folder is missing.
CHECKPOINT_PATH = "/kaggle/working/checkpoints/aasist_cs_best.pth"
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

for epoch in range(EPOCHS):
    loss = train_epoch(model, train_loader, optimizer, criterion)
    scores, labels = evaluate(model, dev_loader)

    eer = compute_eer(scores, labels)
    tdcf = compute_min_tDCF(scores, labels)

    print(f"Epoch {epoch+1} | Loss {loss:.4f} | EER {eer:.2f}% | min-tDCF {tdcf:.4f}")

    # Keep only the checkpoint with the lowest dev EER seen so far.
    if eer < best_eer:
        best_eer = eer
        torch.save(
            {
                "model": model.state_dict(),
                "eer": eer,
                "tdcf": tdcf,
                "epoch": epoch + 1,  # 1-based, matches the printed log line
            },
            CHECKPOINT_PATH,
        )



100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:57<00:00, 42.48it/s]


Epoch 1 | Loss 0.3308 | EER 52.04% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:49<00:00, 43.04it/s]


Epoch 2 | Loss 0.3292 | EER 75.55% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:49<00:00, 43.04it/s]


Epoch 3 | Loss 0.3291 | EER 45.33% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:48<00:00, 43.12it/s]


Epoch 4 | Loss 0.3292 | EER 85.83% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:49<00:00, 43.07it/s]


Epoch 5 | Loss 0.3290 | EER 61.73% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:48<00:00, 43.10it/s]


Epoch 6 | Loss 0.3290 | EER 44.78% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:48<00:00, 43.11it/s]


Epoch 7 | Loss 0.3290 | EER 42.54% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:48<00:00, 43.10it/s]


Epoch 8 | Loss 0.3292 | EER 12.87% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:48<00:00, 43.10it/s]


Epoch 9 | Loss 0.3290 | EER 46.39% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:49<00:00, 43.04it/s]


Epoch 10 | Loss 0.3290 | EER 57.42% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:49<00:00, 43.04it/s]


Epoch 11 | Loss 0.3291 | EER 60.56% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:50<00:00, 43.01it/s]


Epoch 12 | Loss 0.3291 | EER 48.39% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:48<00:00, 43.12it/s]


Epoch 13 | Loss 0.3289 | EER 35.52% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:47<00:00, 43.18it/s]


Epoch 14 | Loss 0.3291 | EER 80.93% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:48<00:00, 43.13it/s]


Epoch 15 | Loss 0.3291 | EER 69.23% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:48<00:00, 43.14it/s]


Epoch 16 | Loss 0.3290 | EER 16.01% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:46<00:00, 43.25it/s]


Epoch 17 | Loss 0.3292 | EER 65.54% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:30<00:00, 44.46it/s]


Epoch 18 | Loss 0.3291 | EER 39.64% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:28<00:00, 44.64it/s]


Epoch 19 | Loss 0.3289 | EER 55.69% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:28<00:00, 44.62it/s]


Epoch 20 | Loss 0.3291 | EER 53.49% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:28<00:00, 44.65it/s]


Epoch 21 | Loss 0.3291 | EER 60.99% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:29<00:00, 44.53it/s]


Epoch 22 | Loss 0.3290 | EER 60.36% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:28<00:00, 44.63it/s]


Epoch 23 | Loss 0.3291 | EER 75.59% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:29<00:00, 44.58it/s]


Epoch 24 | Loss 0.3291 | EER 54.40% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:29<00:00, 44.60it/s]


Epoch 25 | Loss 0.3290 | EER 70.49% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:29<00:00, 44.58it/s]


Epoch 26 | Loss 0.3289 | EER 61.30% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:28<00:00, 44.61it/s]


Epoch 27 | Loss 0.3291 | EER 56.75% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:29<00:00, 44.56it/s]


Epoch 28 | Loss 0.3291 | EER 69.11% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:29<00:00, 44.58it/s]


Epoch 29 | Loss 0.3290 | EER 60.64% | min-tDCF 0.0100


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 25380/25380 [09:28<00:00, 44.64it/s]


Epoch 30 | Loss 0.3291 | EER 61.42% | min-tDCF 0.0100


In [12]:
import zipfile
import os

# Bundle everything under the checkpoint folder into a single zip archive.
zip_path = "/kaggle/working/checkpointsd.zip"
folder_path = "/kaggle/working/checkpoints"

with zipfile.ZipFile(zip_path, mode='w', compression=zipfile.ZIP_DEFLATED) as archive:
    for dirpath, _dirnames, filenames in os.walk(folder_path):
        for filename in filenames:
            source = os.path.join(dirpath, filename)
            # Store paths relative to the checkpoint folder, not absolute ones.
            archive.write(source, arcname=os.path.relpath(source, folder_path))

print("Zipped to:", zip_path)



Zipped to: /kaggle/working/checkpointsd.zip
