In [1]:
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import librosa
from tqdm import tqdm
from sklearn.metrics import roc_curve

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

SAMPLE_RATE = 16000
SEGMENT_LEN = 64000
FRAME_SIZE = 160

BATCH_SIZE = 16
LR = 1e-4
EPOCHS = 10

os.makedirs("checkpoints", exist_ok=True)


Using device: cuda


In [2]:
def circular_shift(arr, shift):
    return np.roll(arr, shift)

def least_number(a, b):
    g = math.gcd(a, b)
    return a // g, b // g
def generate_matrix(N, M, P, Q):
    row = np.concatenate([np.ones(M), np.zeros(N - M)])
    rows = []
    for _ in range(M):
        rows.append(row)
        row = circular_shift(row, M)
    return np.array(rows)
def apply_cs(audio, sensing_matrix, frame_size):
    n_frames = len(audio) // frame_size
    out = []
    for i in range(n_frames):
        frame = audio[i*frame_size:(i+1)*frame_size]
        out.append(sensing_matrix @ frame)
    return np.concatenate(out)


In [3]:
class ASVspoofCSDataset(Dataset):
    def __init__(self, audio_dir, protocol_path, apply_cs=True):
        self.audio_dir = audio_dir
        self.apply_cs = apply_cs

        M = N = FRAME_SIZE
        P, Q = least_number(M, N)
        self.sensing_matrix = generate_matrix(N, M, P, Q)

        available_files = set(
            os.path.splitext(f)[0]
            for f in os.listdir(audio_dir)
            if f.endswith(".flac")
        )

        self.items = []
        skipped = 0

        with open(protocol_path) as f:
            for line in f:
                parts = line.strip().split()
                utt = parts[1]                      # IMPORTANT
                label = 1 if parts[-1] == "spoof" else 0

                if utt not in available_files:
                    skipped += 1
                    continue

                self.items.append((utt, label))

        print(f"[Dataset] Loaded {len(self.items)} | Skipped {skipped}")

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        utt, label = self.items[idx]
        audio_path = os.path.join(self.audio_dir, utt + ".flac")

        audio, _ = librosa.load(audio_path, sr=SAMPLE_RATE)

        if len(audio) < SEGMENT_LEN:
            audio = np.pad(audio, (0, SEGMENT_LEN - len(audio)))
        else:
            audio = audio[:SEGMENT_LEN]

        if self.apply_cs:
            audio = apply_cs(audio, self.sensing_matrix, FRAME_SIZE)

        return torch.tensor(audio, dtype=torch.float32), torch.tensor(label)


In [None]:
class TFAttention(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // 4)
        self.fc2 = nn.Linear(channels // 4, channels)

    def forward(self, x):
        # x: (B, C, T)
        w = x.mean(dim=2)
        w = F.relu(self.fc1(w))
        w = torch.sigmoid(self.fc2(w)).unsqueeze(-1)
        return x * w
class ResidualBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()

        self.conv1 = nn.Conv1d(in_ch, out_ch, 3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_ch)

        self.conv2 = nn.Conv1d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_ch)

        self.att = TFAttention(out_ch)
        self.act = nn.LeakyReLU(0.3)

        self.skip = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        self.pool = nn.MaxPool1d(3)

    def forward(self, x):
        identity = self.skip(x)
        x = self.act(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x = self.att(x)
        x = self.act(x + identity)
        return self.pool(x)
class RawTFNet(nn.Module):
    def __init__(self, base_channels=32):
        super().__init__()

        self.frontend = nn.Sequential(
            nn.Conv1d(1, base_channels, 251, padding=125),
            nn.BatchNorm1d(base_channels),
            nn.LeakyReLU(0.3),
            nn.MaxPool1d(3)
        )

        self.block1 = ResidualBlock(base_channels, base_channels * 2)
        self.block2 = ResidualBlock(base_channels * 2, base_channels * 4)
        self.block3 = ResidualBlock(base_channels * 4, base_channels * 8)

        self.gru = nn.GRU(base_channels * 8, 512, batch_first=True)
        self.fc1 = nn.Linear(512, 512)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, x):
        x = x.unsqueeze(1)            # (B, 1, T)
        x = self.frontend(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        x = x.permute(0, 2, 1)
        x, _ = self.gru(x)
        x = F.normalize(self.fc1(x[:, -1]), dim=1)
        return self.fc2(x)
# RawTFNet-16
# model = RawTFNet(base_channels=16).to(DEVICE)

# RawTFNet-32 (later)
model = RawTFNet(base_channels=32).to(DEVICE)
def compute_eer(scores, labels):
    fpr, tpr, _ = roc_curve(labels, scores, pos_label=1)
    fnr = 1 - tpr
    return 100 * fpr[np.nanargmin(np.abs(fnr - fpr))]
def compute_min_tDCF(scores, labels):
    Ptar, Cmiss, Cfa = 0.01, 1, 1
    fpr, tpr, _ = roc_curve(labels, scores, pos_label=1)
    fnr = 1 - tpr
    return np.min(Cmiss * Ptar * fnr + Cfa * (1 - Ptar) * fpr)


In [5]:
def train_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for x, y in tqdm(loader):
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    scores, labels = [], []
    for x, y in loader:
        probs = torch.softmax(model(x.to(DEVICE)), dim=1)
        scores.extend(probs[:, 1].cpu().numpy())
        labels.extend(y.numpy())
    return np.array(scores), np.array(labels)


In [6]:
train_set = ASVspoofCSDataset(
    "/kaggle/input/asvspoof-19/LA/ASVspoof2019_LA_train/flac",
    "/kaggle/input/asvspoof-19/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt",
    apply_cs=True
)

dev_set = ASVspoofCSDataset(
    "/kaggle/input/asvspoof-19/LA/ASVspoof2019_LA_dev/flac",
    "/kaggle/input/asvspoof-19/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt",
    apply_cs=True
)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
dev_loader = DataLoader(dev_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

best_eer = 100


[Dataset] Loaded 25380 | Skipped 0
[Dataset] Loaded 24844 | Skipped 0


In [7]:
for epoch in range(EPOCHS):
    loss = train_epoch(model, train_loader, optimizer, criterion)
    scores, labels = evaluate(model, dev_loader)

    eer = compute_eer(scores, labels)
    tdcf = compute_min_tDCF(scores, labels)

    print(f"Epoch {epoch+1} | Loss {loss:.4f} | EER {eer:.2f}% | min-tDCF {tdcf:.4f}")

    if eer < best_eer:
        best_eer = eer
        torch.save(
            {"model": model.state_dict(), "eer": eer, "tdcf": tdcf},
            "/kaggle/working/checkpoints/rawtfnet_cs_best.pth"
        )


100%|██████████| 1587/1587 [05:48<00:00,  4.55it/s]


Epoch 1 | Loss 0.3450 | EER 47.29% | min-tDCF 0.0100


100%|██████████| 1587/1587 [05:28<00:00,  4.83it/s]


Epoch 2 | Loss 0.3289 | EER 45.29% | min-tDCF 0.0100


100%|██████████| 1587/1587 [05:30<00:00,  4.80it/s]


Epoch 3 | Loss 0.3257 | EER 38.74% | min-tDCF 0.0099


100%|██████████| 1587/1587 [05:30<00:00,  4.80it/s]


Epoch 4 | Loss 0.2561 | EER 13.38% | min-tDCF 0.0089


100%|██████████| 1587/1587 [05:31<00:00,  4.79it/s]


Epoch 5 | Loss 0.2050 | EER 13.30% | min-tDCF 0.0064


100%|██████████| 1587/1587 [05:31<00:00,  4.79it/s]


Epoch 6 | Loss 0.1761 | EER 9.73% | min-tDCF 0.0061


100%|██████████| 1587/1587 [05:31<00:00,  4.79it/s]


Epoch 7 | Loss 0.1522 | EER 10.44% | min-tDCF 0.0078


100%|██████████| 1587/1587 [05:34<00:00,  4.75it/s]


Epoch 8 | Loss 0.1273 | EER 11.11% | min-tDCF 0.0097


100%|██████████| 1587/1587 [05:42<00:00,  4.64it/s]


Epoch 9 | Loss 0.1109 | EER 7.81% | min-tDCF 0.0040


100%|██████████| 1587/1587 [05:31<00:00,  4.78it/s]


Epoch 10 | Loss 0.0987 | EER 7.30% | min-tDCF 0.0058


In [8]:
!zip -r checkpoints.zip /kaggle/working/checkpoints

  adding: kaggle/working/checkpoints/ (stored 0%)
  adding: kaggle/working/checkpoints/rawtfnet_cs_best.pth (deflated 8%)
