In [1]:
import multiprocessing
# Show the number of CPUs reported for this machine (rendered as the cell's output).
multiprocessing.cpu_count()

56

# Environment Setup

In [2]:
import sys
# Install the notebook's dependencies using the interpreter that runs this kernel
# (sys.executable). NOTE(review): versions are not pinned.
!{sys.executable} -m pip install torch torchaudio transformers librosa matplotlib numpy scikit-learn pandas seaborn tqdm 

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


# Data Preparation

In [3]:
import os
import torchaudio
import pandas as pd
from sklearn.model_selection import train_test_split

# Dataset locations -- edit both values to match your JupyterHub environment.
BASE_PATH = "./data/InTheWild"  # directory passed to the loader as the home of meta.csv
DATASET_PATH = os.path.join(
    BASE_PATH,
    "release_in_the_wild",  # directory the entries of the `file` column are joined onto
)

def load_InTheWild_metadata(base_path, dataset_path):
    """Read ``meta.csv`` from ``base_path`` and prepare it for the datasets.

    Args:
        base_path: Directory containing ``meta.csv``.
        dataset_path: Directory that each entry of the ``file`` column is joined onto.

    Returns:
        The DataFrame read from ``meta.csv`` with an added ``filepath`` column and
        with ``label`` replaced by integers: 1 for ``'spoof'``, 0 for anything else.
    """
    meta_file = os.path.join(base_path, "meta.csv")
    print("Meta file: ", meta_file)

    def _resolve(file_name):
        # Full path of one audio file.
        return os.path.join(dataset_path, file_name)

    def _encode(raw_label):
        # Binary target: spoof -> 1, everything else -> 0.
        return 1 if raw_label == 'spoof' else 0

    frame = pd.read_csv(meta_file)
    frame['filepath'] = frame['file'].map(_resolve)
    frame['label'] = frame['label'].map(_encode)
    return frame

def load_ASVspoof_metadata(dataset_path):
    """Read the space-separated ``protocol.txt`` found in ``dataset_path``.

    The file is parsed without a header row into the columns ``speaker_id``,
    ``filename``, ``unknown1``, ``unknown2`` and ``label``.

    Args:
        dataset_path: Directory containing ``protocol.txt`` and a ``flac`` sub-folder.

    Returns:
        DataFrame with an added ``filepath`` column
        (``<dataset_path>/flac/<filename>.flac``) and with ``label`` replaced by
        integers: 1 for ``'spoof'``, 0 for anything else.
    """
    column_names = ['speaker_id', 'filename', 'unknown1', 'unknown2', 'label']
    protocol_file = os.path.join(dataset_path, 'protocol.txt')
    frame = pd.read_csv(protocol_file, sep=' ', header=None, names=column_names)

    flac_dir = os.path.join(dataset_path, 'flac')
    frame['filepath'] = frame['filename'].map(
        lambda stem: os.path.join(flac_dir, stem + '.flac')
    )
    frame['label'] = frame['label'].map(lambda raw: 1 if raw == 'spoof' else 0)
    return frame

# Read the metadata; `combined_meta` is the single frame that gets split below.
inthewild_meta = load_InTheWild_metadata(BASE_PATH, DATASET_PATH)
combined_meta = pd.concat([inthewild_meta], ignore_index=True)

# Stratified split in two stages: hold out 20% for testing, then take 10% of the
# remainder for validation.
trainval_df, test_df = train_test_split(
    combined_meta,
    test_size=0.2,
    random_state=42,
    stratify=combined_meta['label'],
)
train_df, val_df = train_test_split(
    trainval_df,
    test_size=0.1,
    random_state=42,
    stratify=trainval_df['label'],
)

Meta file:  ./data/InTheWild/meta.csv


## Data Loaders

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
import librosa
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

class AudioDataset(Dataset):
    """Fixed-length mono waveform dataset driven by a metadata DataFrame.

    ``metadata`` must provide a ``filepath`` and a ``label`` column. Each item is
    ``(waveform, label)``: ``waveform`` is a 1-D tensor of exactly ``max_length``
    samples at ``sample_rate`` and ``label`` is a float32 scalar tensor.

    A file that cannot be loaded does not raise. It yields an all-zero waveform
    with label ``-1`` so that consumers can recognise and skip it.
    """

    def __init__(self, metadata, sample_rate=16000, max_length=64600, name="dataset"):
        """
        Args:
            metadata: DataFrame with ``filepath`` and ``label`` columns.
            sample_rate: Target sampling rate; other rates are resampled to it.
            max_length: Number of samples every returned waveform has.
            name: Label used in the summary printed at construction.
        """
        self.metadata = metadata
        self.sample_rate = sample_rate
        self.max_length = max_length
        self.name = name
        # One Resample transform per source rate, built on first use and reused.
        self._resamplers = {}
        self._analyze_dataset()

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        try:
            waveform, sr = torchaudio.load(row['filepath'])
            waveform = self._prepare(waveform, sr)
            return waveform, torch.tensor(row['label'], dtype=torch.float32)
        except Exception as e:
            # Deliberate best-effort: report the file and return the sentinel item.
            print(f"\nError loading {row['filepath']}: {str(e)}")
            return torch.zeros(self.max_length), torch.tensor(-1, dtype=torch.float32)

    def _prepare(self, waveform, sr):
        """Turn a loaded ``(channels, samples)`` tensor into ``(max_length,)``."""
        # Mix multi-channel audio down to one channel. Without this, squeeze(0)
        # below would leave a (channels, max_length) tensor.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sr != self.sample_rate:
            waveform = self._get_resampler(sr)(waveform)

        n_samples = waveform.shape[1]
        if n_samples < self.max_length:
            # Right-pad short clips with zeros.
            waveform = torch.nn.functional.pad(waveform, (0, self.max_length - n_samples))
        else:
            # Keep only the first max_length samples of long clips.
            waveform = waveform[:, :self.max_length]

        return waveform.squeeze(0)

    def _get_resampler(self, sr):
        """Return the cached resampler from ``sr`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            self._resamplers[sr] = resampler
        return resampler

    def _analyze_dataset(self):
        """Print the sample count and the label-0 / label-1 counts."""
        labels = self.metadata['label']
        n_real = int((labels == 0).sum())
        n_fake = int((labels == 1).sum())
        print(f"\n{'='*50}")
        print(f"Initializing {self.name} dataset")
        print(f"{'='*50}")
        print(f"Total samples: {len(self.metadata)}")
        print(f"Real/Fake ratio: {n_real}/{n_fake}")

# Loader settings shared by the training and validation loaders.
batch_size = 32
num_workers = 0
train_val_opts = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=False)

# Build one dataset per split, in train / validation / test order (each
# constructor prints its own summary).
train_dataset, val_dataset, test_dataset = (
    AudioDataset(split_frame, name=split_name)
    for split_frame, split_name in (
        (train_df, "Training"),
        (val_df, "Validation"),
        (test_df, "Test"),
    )
)

# Only the training loader shuffles; the test loader relies on DataLoader defaults.
train_loader = DataLoader(train_dataset, shuffle=True, **train_val_opts)
val_loader = DataLoader(val_dataset, **train_val_opts)
test_loader = DataLoader(test_dataset, batch_size=batch_size)


Initializing Training dataset
Total samples: 22880
Real/Fake ratio: 14373/8507

Initializing Validation dataset
Total samples: 2543
Real/Fake ratio: 1597/946

Initializing Test dataset
Total samples: 6356
Real/Fake ratio: 3993/2363


# Model Implementation

In [5]:
import zipfile

def unzip_dataset(zip_path, extract_to):
    """
    Extract a zip archive into a directory, reporting progress on stdout.

    Parameters:
    - zip_path (str): Location of the archive
    - extract_to (str): Destination directory for the extracted members

    Returns:
    - bool: True when extraction finished; False when the archive is missing
      or any exception was raised (the exception is printed, not propagated)
    """
    succeeded = False
    try:
        if os.path.exists(zip_path):
            print(f"Unzipping {zip_path} to {extract_to}...")
            with zipfile.ZipFile(zip_path, 'r') as archive:
                archive.extractall(extract_to)
            print("Unzip completed successfully!")
            succeeded = True
        else:
            print(f"Zip file not found at {zip_path}")
    except Exception as err:
        print(f"Error unzipping file: {err}")
        succeeded = False
    return succeeded

# Archive holding the pretrained WavLM files; change the name if yours differs.
zip_file_path = "models/models--microsoft--wavlm-base.zip"

# Extract into "models" when the archive is present; otherwise assume it was
# unpacked earlier. `unzip_success` is only assigned when extraction is attempted.
if not os.path.exists(zip_file_path):
    print("No zip file found, assuming dataset is already extracted")
else:
    unzip_success = unzip_dataset(zip_file_path, "models")

Unzipping models/models--microsoft--wavlm-base.zip to models...
Unzip completed successfully!


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WavLMModel, WavLMConfig
from typing import Optional, Tuple

class WavLMFeatureExtractor(nn.Module):
    def __init__(self, model_name: str = "models/wavlm-base", freeze: bool = True):
        """
        WavLM feature extractor with optional fine-tuning

        Args:
            model_name: Path to local pretrained WavLM model
            freeze: Whether to freeze WavLM parameters. A frozen backbone is also
                kept in eval mode (see ``train``).
        """
        super().__init__()
        self._freeze_wavlm = freeze
        self.config = WavLMConfig.from_pretrained(model_name)
        self.wavlm = WavLMModel.from_pretrained(model_name)

        if freeze:
            for param in self.wavlm.parameters():
                param.requires_grad = False
            self.wavlm.eval()

        self.sample_rate = 16000  # WavLM's expected sample rate
        self.output_dim = self.config.hidden_size

    def train(self, mode: bool = True):
        """Set the training mode, keeping a frozen backbone in eval mode.

        ``model.train()`` on a parent module reaches every child. Without this
        override a frozen WavLM would run with its train-only stochastic layers
        (such as dropout) enabled, so the features would change from call to call.
        """
        super().train(mode)
        if getattr(self, "_freeze_wavlm", False):
            self.wavlm.eval()
        return self

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveforms: Input audio tensor of shape (seq_len,), (batch, seq_len)
                or (batch, 1, seq_len)
        Returns:
            features: ``last_hidden_state`` of WavLM, shape
                (batch, num_frames, hidden_size)
        """
        # Input validation and reshaping
        if waveforms.dim() == 1:
            waveforms = waveforms.unsqueeze(0)
        elif waveforms.dim() == 3:
            waveforms = waveforms.squeeze(1)

        # Bring every clip whose peak exceeds 1 back into [-1, 1]. The peak is
        # taken per clip, so one loud sample cannot rescale the rest of the batch.
        # Clips already inside [-1, 1] are divided by exactly 1.0.
        peak = waveforms.abs().amax(dim=-1, keepdim=True)
        waveforms = waveforms / peak.clamp(min=1.0)

        outputs = self.wavlm(waveforms)
        return outputs.last_hidden_state

class AASIST(nn.Module):
    def __init__(self, input_dim: int = 768, num_heads: int = 4, dropout: float = 0.3):
        """
        Convolution + self-attention classification head for spoofing detection

        Args:
            input_dim: Dimension of input features
            num_heads: Number of attention heads (must divide 256)
            dropout: Dropout probability
        """
        super().__init__()

        # Temporal convolutions; the two max-pool layers reduce the frame count 4x.
        self.conv_block = nn.Sequential(
            nn.Conv1d(input_dim, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.MaxPool1d(2),

            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.MaxPool1d(2),

            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Attention mechanism (batch_first is left at its default, so inputs are
        # laid out as (seq_len, batch, channels))
        self.attention = nn.MultiheadAttention(
            embed_dim=256,
            num_heads=num_heads,
            dropout=dropout
        )

        # Classification head; the final Sigmoid makes the output a probability.
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(64, 1),
            nn.Sigmoid()
        )

        self.pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input features of shape (batch, seq_len, input_dim)
        Returns:
            predictions: Output scores in [0, 1] of shape (batch,)
        """
        # Conv1d expects (batch, channels, seq_len)
        x = x.permute(0, 2, 1)
        x = self.conv_block(x)                      # (batch, 256, frames)

        # Self-attention across the time frames. Attention has to run before
        # pooling: a sequence already pooled to one step has nothing to attend to.
        x = x.permute(2, 0, 1)                      # (frames, batch, 256)
        attn_output, _ = self.attention(x, x, x)
        x = x + attn_output                         # Residual connection

        # Global average pooling over time
        x = self.pool(x.permute(1, 2, 0)).squeeze(2)  # (batch, 256)

        return self.classifier(x).squeeze(1)

class WavLM_AASIST_Model(nn.Module):
    """WavLM feature extractor chained with the AASIST classification head."""

    def __init__(self, wavlm_model: str = "microsoft/wavlm-base", freeze_wavlm: bool = True):
        """
        Build both stages of the detector.

        Args:
            wavlm_model: Model name or path handed to ``WavLMFeatureExtractor``
            freeze_wavlm: Forwarded as the extractor's ``freeze`` flag
        """
        super().__init__()
        extractor = WavLMFeatureExtractor(wavlm_model, freeze_wavlm)
        self.feature_extractor = extractor
        # The head's input width follows the extractor's feature width.
        self.aasist = AASIST(input_dim=extractor.output_dim)

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """
        Run waveforms through the extractor and then through the head.

        Args:
            waveforms: Audio tensor passed unchanged to the feature extractor
        Returns:
            Whatever the AASIST head returns for the extracted features
        """
        return self.aasist(self.feature_extractor(waveforms))

    def get_feature_dim(self) -> int:
        """Width of the feature vectors produced by the extractor."""
        return self.feature_extractor.output_dim

# Ignore the tqdm warning (optional)
# import warnings
# warnings.filterwarnings("ignore", category=UserWarning, module="tqdm.auto")

# Training

## Training Setup

In [7]:
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed torch's global RNG before the model is built so that weight
# initialisation is repeatable after "Restart Kernel -> Run All".
# 42 matches the random_state used for the train/val/test split.
SEED = 42
torch.manual_seed(SEED)

print("Initializeing model")
model = WavLM_AASIST_Model(
    wavlm_model="models/wavlm-base",  # Point to your local model directory
    freeze_wavlm=True
).to(device)

# The model ends in a Sigmoid, so plain BCE on probabilities is the matching loss.
print("Criterion")
criterion = nn.BCELoss()

print("Optimizer")
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Halve the learning rate after 3 epochs without a lower monitored value
# (the training loop steps it with the validation loss).
print("Scheduler")
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss. An epoch is "bad"
    when its loss exceeds the best loss seen so far by more than ``delta``.
    After ``patience`` consecutive bad epochs ``early_stop`` becomes True.

    Attributes:
        best_score: Lowest validation loss observed so far (None before the first call).
        counter: Number of consecutive bad epochs.
        early_stop: True once ``counter`` has reached ``patience``.
    """

    def __init__(self, patience=5, delta=0):
        """
        Args:
            patience: Consecutive bad epochs tolerated before stopping.
            delta: Tolerance above the best loss that still counts as acceptable.
        """
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss):
        # NaN compares False against everything, so without this guard it would
        # overwrite best_score and reset the counter. Treat it as a bad epoch.
        if val_loss != val_loss:
            self._register_bad_epoch()
            return

        if self.best_score is None:
            self.best_score = val_loss
        elif val_loss > self.best_score + self.delta:
            self._register_bad_epoch()
        else:
            # Within tolerance: reset the counter, but never let the reference
            # loss get worse -- otherwise it drifts upward when delta > 0.
            self.best_score = min(self.best_score, val_loss)
            self.counter = 0

    def _register_bad_epoch(self):
        """Count one non-improving epoch and raise the stop flag when due."""
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

early_stopping = EarlyStopping(patience=5)

Using device: cpu
Initializeing model
Criterion
Optimizer
Scheduler


## Training Loop

In [8]:
from tqdm import tqdm
import time


def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Samples whose label is negative are dropped before the forward pass. The
    dataset uses label ``-1`` for files it could not load, and feeding those to
    the loss would corrupt both the gradient and the reported accuracy.

    Args:
        model: Module mapping a waveform batch to one probability per sample.
        dataloader: Iterable of ``(waveforms, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating ``model``.
        device: Device the batches are moved to.

    Returns:
        ``(mean loss over the batches used, accuracy at a 0.5 threshold)``.

    Raises:
        ValueError: If no valid sample was seen.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    batches_used = 0

    # Use position=0 and leave=True for the main bar
    with tqdm(dataloader, desc="Training", position=0, leave=True) as pbar:
        for waveforms, labels in pbar:
            keep = labels >= 0
            if not bool(keep.any()):
                continue  # every sample in this batch failed to load
            waveforms = waveforms[keep].to(device)
            labels = labels[keep].to(device)

            optimizer.zero_grad()
            outputs = model(waveforms)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batches_used += 1
            running_loss += loss.item()
            preds = (outputs > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            # Update the progress bar description with current metrics
            pbar.set_postfix({
                'loss': running_loss / batches_used,
                'acc': correct / total
            })

    if total == 0:
        raise ValueError("train_epoch received no valid samples")
    return running_loss / batches_used, correct / total

def validate(model, dataloader, criterion, device):
    """Compute loss and accuracy over ``dataloader`` without updating the model.

    Samples whose label is negative are dropped before the forward pass. The
    dataset uses label ``-1`` for files it could not load, and those would
    distort the loss that drives the scheduler and early stopping.

    Args:
        model: Module mapping a waveform batch to one probability per sample.
        dataloader: Iterable of ``(waveforms, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean loss over the batches used, accuracy at a 0.5 threshold)``.

    Raises:
        ValueError: If no valid sample was seen.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    batches_used = 0

    with torch.no_grad():
        # Use position=0 and leave=True for the main bar
        with tqdm(dataloader, desc="Validating", position=0, leave=True) as pbar:
            for waveforms, labels in pbar:
                keep = labels >= 0
                if not bool(keep.any()):
                    continue  # every sample in this batch failed to load
                waveforms = waveforms[keep].to(device)
                labels = labels[keep].to(device)

                outputs = model(waveforms)
                loss = criterion(outputs, labels)

                batches_used += 1
                running_loss += loss.item()
                preds = (outputs > 0.5).float()
                correct += (preds == labels).sum().item()
                total += labels.size(0)

                # Update the progress bar description with current metrics
                pbar.set_postfix({
                    'loss': running_loss / batches_used,
                    'acc': correct / total
                })

    if total == 0:
        raise ValueError("validate received no valid samples")
    return running_loss / batches_used, correct / total

# Training loop
import copy

num_epochs = 3
train_losses, val_losses = [], []
train_accs, val_accs = [], []

# Optional data-loading probe (disabled):
# start = time.time()
# for i, (waveforms, labels) in enumerate(train_loader):
#     if i == 10:
#         break
# print(f"Time to load 10 batches: {time.time() - start:.2f} sec")

# --- Timing probe -----------------------------------------------------------
# Measures 10 optimisation steps on a single batch. The steps are real, so the
# model and optimizer are snapshotted first and restored afterwards. Otherwise
# epoch 1 would start from weights already fitted to this one batch.
waveforms, labels = next(iter(train_loader))
waveforms, labels = waveforms.to(device), labels.to(device)

model_snapshot = copy.deepcopy(model.state_dict())
optimizer_snapshot = copy.deepcopy(optimizer.state_dict())

model.train()
start = time.time()
for _ in range(10):
    optimizer.zero_grad()
    outputs = model(waveforms)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
print(f"Time for 10 training steps: {time.time() - start:.2f} sec")

model.load_state_dict(model_snapshot)
optimizer.load_state_dict(optimizer_snapshot)
optimizer.zero_grad()
del model_snapshot, optimizer_snapshot

# --- Epoch loop -------------------------------------------------------------
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # Both the LR schedule and early stopping follow the validation loss.
    scheduler.step(val_loss)
    early_stopping(val_loss)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"Train Acc: {train_acc:.2%} | Val Acc: {val_acc:.2%}")

    if early_stopping.early_stop:
        print("Early stopping triggered!")
        break

Time for 10 training steps: 78.04 sec

Epoch 1/3


Training:   0%|          | 1/715 [00:42<8:27:54, 42.68s/it, loss=0.691, acc=0.531]


KeyboardInterrupt: 

# Evaluation and Visualization

In [None]:
from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score, precision_score, recall_score, f1_score

def evaluate_model(model, dataloader, device):
    """Score ``dataloader`` with ``model`` and compute classification metrics.

    Samples whose label is negative are excluded. The dataset uses label ``-1``
    for files it could not load, and passing those to the metric functions would
    add a third class.

    Args:
        model: Module mapping a waveform batch to one probability per sample.
        dataloader: Iterable of ``(waveforms, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1`` (all at a 0.5
        threshold), ``auc``, and ``eer``. ``eer`` is the false-positive rate at
        the ROC point where it is closest to the false-negative rate.

    Raises:
        ValueError: If no valid sample was seen.
    """
    model.eval()
    label_chunks = []
    score_chunks = []

    with torch.no_grad():
        for waveforms, labels in dataloader:
            keep = labels >= 0
            if not bool(keep.any()):
                continue  # every sample in this batch failed to load
            outputs = model(waveforms[keep].to(device))
            label_chunks.append(labels[keep].cpu())
            score_chunks.append(outputs.cpu())

    if not label_chunks:
        raise ValueError("evaluate_model received no valid samples")

    all_labels = torch.cat(label_chunks).numpy()
    all_outputs = torch.cat(score_chunks).numpy()

    preds = (all_outputs > 0.5).astype(int)
    fpr, tpr, thresholds = roc_curve(all_labels, all_outputs)
    fnr = 1 - tpr
    # Operating point where the two error rates are closest to each other.
    eer = fpr[np.nanargmin(np.absolute(fnr - fpr))]

    return {
        'accuracy': accuracy_score(all_labels, preds),
        'precision': precision_score(all_labels, preds),
        'recall': recall_score(all_labels, preds),
        'f1': f1_score(all_labels, preds),
        'auc': roc_auc_score(all_labels, all_outputs),
        'eer': eer
    }

# Score the held-out split and list every metric with four decimals.
test_metrics = evaluate_model(model, test_loader, device)
print("\nTest Metrics:")
for metric_name, metric_value in test_metrics.items():
    print(f"{metric_name}: {metric_value:.4f}")

# Training curves: loss on the left, accuracy on the right, one point per epoch.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))
panels = (
    (loss_ax, 'Loss', train_losses, val_losses),
    (acc_ax, 'Accuracy', train_accs, val_accs),
)
for axis, panel_title, train_curve, val_curve in panels:
    axis.plot(train_curve, label='Train')
    axis.plot(val_curve, label='Validation')
    axis.set_title(panel_title)
    axis.legend()
plt.show()

# Save Model

In [None]:
# Persist the model's state_dict (parameters and buffers) rather than the module object.
torch.save(model.state_dict(), 'audio_deepfake_model.pth')

# To load later, rebuild the model with the same WavLM directory used for training
# (the class default is the hub id "microsoft/wavlm-base") and map the weights onto
# the current device:
# model = WavLM_AASIST_Model(wavlm_model="models/wavlm-base").to(device)
# model.load_state_dict(torch.load('audio_deepfake_model.pth', map_location=device))