### Install Dependencies

In [None]:
pip install torch torchaudio numpy transformers librosa

### Connecting Google Drive

In [None]:
# Mount Google Drive; the protocol file and audio directory used below live under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

### Importing Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Wav2Vec2Model, HubertModel, WavLMModel
import numpy as np
from sklearn.metrics import roc_curve
from torch.utils.data import Dataset, DataLoader
import torchaudio
import os
import tarfile
from sklearn.model_selection import train_test_split
import torch.optim as optim

### Loading Data

In [None]:
# Paths — Google Drive must already be mounted at /content/drive.
DRIVE_ROOT = '/content/drive/MyDrive'
protocol_path = f'{DRIVE_ROOT}/ASVspoof5_protocols/ASVspoof5.train.tsv'  # training protocol (labels)
audio_dir = f'{DRIVE_ROOT}/flac_T_aa'  # directory of .flac utterances (presumably one training shard)

In [None]:
# Custom Dataset
class ASVspoofDataset(Dataset):
    """Map-style dataset yielding ``(waveform, label)`` pairs for ASVspoof utterances.

    Waveforms are down-mixed to mono and cropped / repeat-padded to ``max_len``
    samples, so the default ``DataLoader`` collate can stack a batch into
    ``(B, max_len)``, the layout the SSL front-end takes as input.

    Args:
        audio_dir: Directory containing the audio files named in ``label_map``.
        label_map: Mapping ``filename -> int label`` (e.g. as produced by
            ``parse_and_filter_protocol``: 0 = bonafide, 1 = spoof).
        transform: Optional callable applied to the 1-D, length-fixed waveform.
        max_len: Number of samples per returned waveform. ``None`` keeps the native
            length (batching then needs ``batch_size=1`` or a custom ``collate_fn``).
            The default 64600 is ~4 s at 16 kHz.
        target_sample_rate: Files whose sample rate differs are resampled to this
            rate (presumably the rate the SSL front-end was pretrained on — confirm).
            ``None`` disables resampling.
    """

    def __init__(self, audio_dir, label_map, transform=None, max_len=64600, target_sample_rate=16000):
        self.audio_dir = audio_dir
        self.label_map = label_map
        self.files = list(label_map.keys())
        self.transform = transform
        self.max_len = max_len
        self.target_sample_rate = target_sample_rate

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _fix_length(waveform, max_len):
        """Down-mix a ``(channels, samples)`` tensor to mono and crop / repeat-pad it to ``max_len``."""
        if waveform.dim() == 2:
            waveform = waveform.mean(dim=0)
        if max_len is None:
            return waveform

        num_samples = waveform.shape[-1]
        if num_samples == 0:
            return waveform.new_zeros(max_len)
        if num_samples >= max_len:
            return waveform[:max_len]
        # Tile short clips instead of zero-padding so the front-end does not see long silent tails.
        num_repeats = -(-max_len // num_samples)  # ceiling division
        return waveform.repeat(num_repeats)[:max_len]

    def __getitem__(self, idx):
        filename = self.files[idx]
        filepath = os.path.join(self.audio_dir, filename)
        waveform, sample_rate = torchaudio.load(filepath)  # (channels, samples)

        if self.target_sample_rate is not None and sample_rate != self.target_sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.target_sample_rate)

        # Fixed-length mono audio: required by the default collate and by the SSL model's (B, samples) input.
        waveform = self._fix_length(waveform, self.max_len)

        # Optional: Add transformations (e.g., MFCC, MelSpectrogram)
        if self.transform:
            waveform = self.transform(waveform)

        label = self.label_map[filename]
        return waveform, label


In [None]:
#Parse the protocol file and filter available files
def parse_and_filter_protocol(protocol_path, audio_dir, sep='\t', utt_col=0, label_col=2):
    """Build a ``{filename: label}`` map for protocol entries whose audio exists in ``audio_dir``.

    Each protocol line is stripped and split on ``sep``. Column ``utt_col`` holds the
    utterance ID (``.flac`` is appended to form the filename) and column ``label_col``
    holds the key: ``'bonafide'`` maps to 0, any other key to 1 (spoof). Lines with
    too few columns (blank or truncated lines) are skipped instead of raising
    ``IndexError``.

    NOTE(review): the defaults (tab-separated, ID in column 0, key in column 2) are
    the layout this notebook originally assumed — confirm them against the protocol
    README. For a whitespace-separated file pass ``sep=None``.

    Args:
        protocol_path: Path to the protocol file.
        audio_dir: Directory whose listing decides which utterances are kept.
        sep: Separator passed to ``str.split`` (``None`` splits on any whitespace).
        utt_col: Index of the utterance-ID column.
        label_col: Index of the key column.

    Returns:
        dict mapping ``'<utt_id>.flac'`` to 0 (bonafide) or 1 (spoof).
    """
    available_files = set(os.listdir(audio_dir))  # set for O(1) membership tests
    min_columns = max(utt_col, label_col) + 1
    label_map = {}

    with open(protocol_path, 'r', encoding='utf-8') as file:
        for line in file:
            parts = line.strip().split(sep)
            if len(parts) < min_columns:
                continue
            audio_file = parts[utt_col] + '.flac'
            if audio_file in available_files:
                label_map[audio_file] = 0 if parts[label_col] == 'bonafide' else 1
    return label_map



In [None]:
# Load label map and split into train/test
label_map = parse_and_filter_protocol(protocol_path, audio_dir)
assert label_map, (
    f"No protocol entries matched files in {audio_dir!r}; "
    "check that Drive is mounted and the protocol columns are correct."
)

utt_files = list(label_map.keys())
utt_labels = [label_map[f] for f in utt_files]

# Stratify so both splits keep the bonafide/spoof ratio (EER needs both classes in the
# test split); fall back to a plain random split when a class is too rare to stratify.
class_counts = np.bincount(utt_labels, minlength=2)
can_stratify = bool((class_counts >= 2).all())

train_keys, test_keys = train_test_split(
    utt_files,
    test_size=0.2,
    random_state=42,
    stratify=utt_labels if can_stratify else None,
)
train_map = {k: label_map[k] for k in train_keys}
test_map = {k: label_map[k] for k in test_keys}


In [None]:
# Wrap each split in a dataset and a batched loader; only the training loader shuffles.
BATCH_SIZE = 32
NUM_WORKERS = 2

train_dataset, test_dataset = (ASVspoofDataset(audio_dir, split_map) for split_map in (train_map, test_map))

loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


### Training Architecture

In [None]:
# --- Residual Block ---
class ResidualBlock(nn.Module):
    """Stack of ``num_layers`` (Conv2d 3x3 -> BatchNorm2d -> SELU) stages plus a skip connection.

    Spatial size is preserved (3x3 kernels with padding 1). When ``in_channels``
    differs from ``out_channels`` the skip path is a 1x1 convolution so the two
    branches can be summed; otherwise it is the identity.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        num_layers: Number of conv stages in the main branch (must be >= 1). With the
            default of 2 the module layout (and state_dict keys) is two stages.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, in_channels, out_channels, num_layers=2):
        super(ResidualBlock, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        stages = []
        for i in range(num_layers):
            # Only the first conv changes the channel count; later ones map out -> out.
            stages += [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.SELU(),
            ]
        self.layers = nn.Sequential(*stages)

        # Ensure residual connection has the same dimensions
        self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()

    def forward(self, x):
        return self.layers(x) + self.residual(x)


# --- Forgery Detection Model ---
class ForgeryDetectionModel(nn.Module):
    """SSL front-end followed by a small 2-D residual CNN back-end producing 2-class logits.

    Pipeline: SSL ``last_hidden_state`` -> Linear(hidden_size -> 128) -> treat the
    (frames x 128) map as a 1-channel image -> MaxPool 3x3 / stride 3 -> BN + SELU ->
    residual blocks (32 then 64 channels) -> global average pool -> Linear(64 -> 2).

    Args:
        ssl_model_name: Hugging Face model ID passed to ``get_ssl_model``.
    """

    def __init__(self, ssl_model_name):
        super(ForgeryDetectionModel, self).__init__()

        # Load SSL model dynamically (get_ssl_model is defined in a later cell; it is
        # only looked up when a model is instantiated).
        self.ssl_model = get_ssl_model(ssl_model_name)
        self.ssl_out_dim = self.ssl_model.config.hidden_size  # Typically 768

        # Dimensionality Reduction
        self.fc1 = nn.Linear(self.ssl_out_dim, 128)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=3)
        self.bn_selu = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.SELU()
        )

        # Residual Blocks
        self.res_block1 = ResidualBlock(1, 32, num_layers=2)
        self.res_block2 = ResidualBlock(32, 64, num_layers=4)

        # Global Average Pooling & Classification
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        """Return logits of shape ``(B, 2)`` for raw audio ``x``.

        ``x`` may be ``(B, samples)`` or ``(B, channels, samples)`` (the layout produced
        by batching ``torchaudio.load`` output directly); channels are averaged to mono
        because the SSL model takes ``(B, samples)`` input.
        """
        if x.dim() == 3:
            x = x.mean(dim=1)

        # SSL frontend
        x = self.ssl_model(x).last_hidden_state  # (B, T, hidden_size)

        # Dimensionality Reduction
        x = self.fc1(x)  # (B, T, 128)
        x = x.unsqueeze(1)  # Add channel dimension for CNN (B, 1, T, 128)
        x = self.pool(x)  # (B, 1, T // 3, 42)
        x = self.bn_selu(x)  # Apply BN & SeLU

        # Residual Blocks (3x3 convs with padding 1 keep the spatial size)
        x = self.res_block1(x)  # (B, 32, T // 3, 42)
        x = self.res_block2(x)  # (B, 64, T // 3, 42)

        # Global Average Pooling & Classification
        x = self.global_avg_pool(x)  # (B, 64, 1, 1)
        x = x.view(x.size(0), -1)  # Flatten (B, 64)
        x = self.fc2(x)  # (B, 2)
        return x


### Evaluation Implementation

In [None]:
# --- Function to Compute EER ---
def compute_eer(y_true, y_scores):
    """Compute the equal error rate (EER) from binary labels and detection scores.

    The ROC is built with ``sklearn.metrics.roc_curve`` (class 1 is the positive
    class, so scores should be higher for class 1). The EER is read at the ROC point
    where the false-negative and false-positive rates are closest; the
    false-positive rate at that point is reported.

    Args:
        y_true: Sequence of 0/1 labels.
        y_scores: Scores of the same length; higher means "more likely class 1".

    Returns:
        ``(eer_percent, threshold)``: EER in percent and the score threshold where it occurs.

    Raises:
        ValueError: if ``y_true`` does not contain both classes (the EER is undefined).
    """
    if len(np.unique(y_true)) < 2:
        raise ValueError("EER is undefined: y_true must contain both classes (0 and 1).")

    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    fnr = 1 - tpr
    idx = np.nanargmin(np.abs(fnr - fpr))  # ROC point closest to FNR == FPR
    return fpr[idx] * 100, thresholds[idx]

# --- Function to Compute t-DCF ---
def compute_tdcf(y_true, y_scores, P_miss=1, P_fa=1):
    """Return the minimum of a simplified, unnormalised detection cost over all thresholds.

    Computes ``min_t(P_miss * FNR(t) + P_fa * FPR(t))`` on the ROC of ``y_true`` vs
    ``y_scores`` (class 1 positive); ``P_miss`` and ``P_fa`` act as cost weights.

    NOTE(review): despite the name, this is not the ASVspoof tandem DCF (t-DCF). The
    t-DCF also involves ASV-system error rates, priors and normalisation. Treat this
    value as a countermeasure-only DCF proxy, and do not compare it to published
    t-DCF numbers.

    Returns:
        The minimum cost as a Python float.

    Raises:
        ValueError: if ``y_true`` does not contain both classes (FNR/FPR are undefined).
    """
    if len(np.unique(y_true)) < 2:
        raise ValueError("DCF is undefined: y_true must contain both classes (0 and 1).")

    fpr, tpr, _ = roc_curve(y_true, y_scores)
    fnr = 1 - tpr
    dcf = P_miss * fnr + P_fa * fpr
    return float(np.min(dcf))

### Training and Evaluation

In [None]:
# --- Function to Select SSL Model ---
def get_ssl_model(model_name):
    """Load a pretrained SSL speech encoder, picking the class from a substring of ``model_name``.

    Family keys are checked in the order wav2vec2 -> hubert -> wavlm and the first
    one contained in ``model_name`` wins (matching is case-sensitive).

    Raises:
        ValueError: if none of the family keys occurs in ``model_name``.
    """
    # Lambdas defer the class lookup until a family actually matches.
    family_loaders = (
        ("wav2vec2", lambda name: Wav2Vec2Model.from_pretrained(name)),
        ("hubert", lambda name: HubertModel.from_pretrained(name)),
        ("wavlm", lambda name: WavLMModel.from_pretrained(name)),
    )
    for family, load in family_loaders:
        if family in model_name:
            return load(model_name)
    raise ValueError(f"Unknown SSL Model: {model_name}")


# --- Evaluate Model ---
def evaluate_model(model, test_loader, device):
    """Score every batch in ``test_loader`` with ``model``, print and return EER (%) and the DCF value.

    Each utterance is scored with the logit margin ``logit[1] - logit[0]`` (class 1 =
    spoof). This ranks utterances exactly like the softmax probability of class 1
    (``softmax(logits)[:, 1] == sigmoid(margin)``), but it does not saturate at 1.0.
    Confidently classified utterances therefore do not collapse onto tied scores,
    and the ROC/EER keep full resolution. The model's train/eval mode is restored
    on exit.

    Returns:
        ``(eer, tdcf)`` as returned by ``compute_eer`` (percent) and ``compute_tdcf``.
    """
    was_training = model.training
    model.eval()
    y_true, y_scores = [], []

    try:
        with torch.no_grad():
            for audio, labels in test_loader:
                logits = model(audio.to(device))
                margin = logits[:, 1] - logits[:, 0]
                y_scores.extend(margin.float().cpu().numpy())
                # Labels are only needed on the CPU for the metric computation.
                y_true.extend(labels.cpu().numpy())
    finally:
        model.train(was_training)

    eer, _ = compute_eer(y_true, y_scores)
    tdcf = compute_tdcf(y_true, y_scores)

    print(f"EER: {eer:.2f}%")
    print(f"t-DCF: {tdcf:.4f}")

    return eer, tdcf

# --- Train Model ---
def train_model(model, train_loader, device, optimizer, criterion, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs, printing per-epoch mean loss and accuracy.

    Args:
        model: Classifier returning logits of shape ``(B, num_classes)``.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device each batch is moved to.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(logits, labels)``.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        List with one ``(mean_loss, accuracy_percent)`` tuple per epoch. Both values are
        NaN for an epoch that saw no samples (empty loader) instead of raising
        ``ZeroDivisionError``.
    """
    model.train()
    history = []
    for epoch in range(num_epochs):
        loss_sum = 0.0
        correct_preds = 0
        total_preds = 0

        for audio, labels in train_loader:
            audio = audio.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Forward pass
            logits = model(audio)
            loss = criterion(logits, labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # Assumes criterion uses reduction='mean' (the nn.CrossEntropyLoss default);
            # weighting by batch size makes the epoch figure a per-sample mean even
            # when the last batch is smaller.
            loss_sum += loss.item() * batch_size

            # Metrics for classification accuracy
            predicted = logits.argmax(dim=1)
            correct_preds += (predicted == labels).sum().item()
            total_preds += batch_size

        if total_preds:
            epoch_loss = loss_sum / total_preds
            epoch_accuracy = 100 * correct_preds / total_preds
        else:
            epoch_loss = epoch_accuracy = float('nan')
        history.append((epoch_loss, epoch_accuracy))
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")
    return history



In [None]:
# --- Run Model for Different SSL Frontends ---
ssl_models = ["facebook/wav2vec2-base", "facebook/hubert-base-ls960", "microsoft/wavlm-base-plus"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42  # reseeded per front-end so each run's init and shuffling repeat (up to CUDA nondeterminism)

results = {}
for ssl_model in ssl_models:
    print(f"\nTraining and Evaluating Model with {ssl_model}...")
    torch.manual_seed(SEED)

    model = ForgeryDetectionModel(ssl_model).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-6)
    criterion = nn.CrossEntropyLoss()

    train_model(model, train_loader, device, optimizer, criterion, num_epochs=10)
    eer, tdcf = evaluate_model(model, test_loader, device)
    results[ssl_model] = {"EER (%)": eer, "t-DCF": tdcf}

results