In [14]:
# Install required packages (using latest versions)
!pip install torchaudio librosa tensorboardX matplotlib soundfile tqdm pyyaml

Collecting tensorboardX
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl.metadata (5.8 kB)
Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.7/101.7 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorboardX
Successfully installed tensorboardX-2.6.2.2


In [15]:
# Import necessary libraries
import os
import torch
import torchaudio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import yaml
import random
import time
import librosa
import soundfile as sf
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
import torch.nn.functional as F
from tqdm.notebook import tqdm
from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score
from tensorboardX import SummaryWriter
from collections import OrderedDict
import psutil
import gc
import glob

In [16]:
# Check GPU availability and set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [17]:
# Set random seed for reproducibility
def setup_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Keep deterministic mode for reproducibility of results
    torch.backends.cudnn.deterministic = True
    # Enable benchmark mode for optimized performance with fixed input sizes
    torch.backends.cudnn.benchmark = True

setup_seed()

In [18]:
# Memory monitoring function
def print_memory_stats():
    """Print current memory usage stats"""
    # RAM usage
    process = psutil.Process(os.getpid())
    ram_usage = process.memory_info().rss / (1024 * 1024)  # in MB
    
    # GPU memory if available
    gpu_usage = 0
    if torch.cuda.is_available():
        gpu_usage = torch.cuda.memory_allocated() / (1024 * 1024)  # in MB
        max_gpu = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024)
        
        print(f"RAM Usage: {ram_usage:.2f} MB")
        print(f"GPU Memory Usage: {gpu_usage:.2f} MB / {max_gpu:.2f} MB ({gpu_usage/max_gpu*100:.1f}%)")
    else:
        print(f"RAM Usage: {ram_usage:.2f} MB")

In [19]:
# LRU Cache implementation for memory-efficient caching
class LRUCache:
    """
    LRU (Least Recently Used) Cache for efficient memory usage
    """
    def __init__(self, capacity):
        self.cache = OrderedDict()
        self.capacity = capacity
    
    def get(self, key):
        if key not in self.cache:
            return None
        # Move the accessed item to the end (most recently used)
        self.cache.move_to_end(key)
        return self.cache[key]
    
    def put(self, key, value):
        if key in self.cache:
            # Move the updated item to the end
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.capacity:
            # Remove the least recently used item
            self.cache.popitem(last=False)
        self.cache[key] = value
    
    def __len__(self):
        return len(self.cache)
        
    def popitem(self, last=False):
        """Remove and return an item from the cache"""
        return self.cache.popitem(last=last)

In [20]:
class GraphAttentionLayer(nn.Module):
    """
    Graph Attention Layer (GAT) implementation with batch processing support
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        """
        Forward pass with support for batch processing
        h: Node features of shape (batch_size, N, in_features)
        adj: Adjacency matrices of shape (batch_size, N, N)
        """
        batch_size, N, _ = h.size()
        
        # Apply feature transformation to all nodes in all graphs - (batch_size, N, out_features)
        Wh = torch.matmul(h, self.W)
        
        # Prepare inputs for attention mechanism
        # Repeat first dimension for comparisons
        Wh1 = Wh.unsqueeze(2).repeat(1, 1, N, 1)  # (batch_size, N, N, out_features)
        # Repeat second dimension for comparisons
        Wh2 = Wh.unsqueeze(1).repeat(1, N, 1, 1)  # (batch_size, N, N, out_features)
        
        # Concatenate for attention calculation
        a_input = torch.cat([Wh1, Wh2], dim=-1)  # (batch_size, N, N, 2*out_features)
        
        # Apply attention mechanism
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(-1))  # (batch_size, N, N)
        
        # Mask attention scores using adjacency matrix
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        
        # Apply softmax and dropout
        attention = F.softmax(attention, dim=2)  # Normalize on the third dimension (N)
        attention = F.dropout(attention, self.dropout, training=self.training)
        
        # Apply attention to transform features
        h_prime = torch.matmul(attention, Wh)  # (batch_size, N, out_features)
        
        # Apply non-linearity if needed
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

In [21]:
class GAT(nn.Module):
    """
    Graph Attention Network with multiple attention heads and batch processing support
    """
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        super(GAT, self).__init__()
        self.dropout = dropout
        self.nheads = nheads

        # Multi-head attention layers
        self.attentions = nn.ModuleList([
            GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) 
            for _ in range(nheads)
        ])
        
        self.out_att = GraphAttentionLayer(
            nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False
        )

    def forward(self, x, adj):
        """
        Forward pass with batch processing support
        x: Node features of shape (batch_size, N, in_features)
        adj: Adjacency matrices of shape (batch_size, N, N)
        """
        x = F.dropout(x, self.dropout, training=self.training)
        
        # Apply each attention head and concatenate results
        x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
        
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        
        return x

In [22]:
class SincConv(nn.Module):
    """Sinc-based convolution for raw waveform processing"""
    def __init__(self, out_channels, kernel_size, sample_rate=16000, 
                 in_channels=1, stride=1, padding=0, dilation=1, bias=False):
        super(SincConv, self).__init__()
        
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        
        # Initialize filterbanks with mel scale
        self.freq_low = 50 / (sample_rate/2)  # Normalized low frequency
        self.freq_high = 7500 / (sample_rate/2)  # Normalized high frequency
        
        # Compute the mel scale frequencies
        self.mel_low = 2595 * np.log10(1 + self.freq_low * (sample_rate/2) / 700)
        self.mel_high = 2595 * np.log10(1 + self.freq_high * (sample_rate/2) / 700)
        
        # Equally spaced in mel scale
        mel_points = torch.linspace(self.mel_low, self.mel_high, out_channels + 1)
        
        # Convert back to frequency domain
        f_pts = 700 * (10 ** (mel_points / 2595) - 1)
        
        # Normalize to [0, 1]
        self.freq_bands = f_pts / (sample_rate/2)
        
        # Ensure proper dimensions for the filters and band_widths
        self.filters = nn.Parameter(self.freq_bands[:-1].unsqueeze(-1))      # (out_channels, 1)
        self.band_widths = nn.Parameter((self.freq_bands[1:] - self.freq_bands[:-1]).unsqueeze(-1))  # (out_channels, 1)
        
        # Non trainable
        n = torch.arange(-(kernel_size-1)/2, (kernel_size-1)/2 + 1)
        self.n = nn.Parameter(n.float(), requires_grad=False)  # (kernel_size,)
        
        # Window
        window = 0.54 - 0.46 * torch.cos(2 * np.pi * torch.arange(kernel_size) / kernel_size)
        self.window = nn.Parameter(window, requires_grad=False)  # (kernel_size,)

    def forward(self, x):
        # Get input dimensions
        batch_size, channels, signal_length = x.shape
        
        # Reshape n for proper broadcasting
        n = self.n.view(1, -1, 1)  # (1, kernel_size, 1)
        
        # Get filter center frequencies and bandwidths
        filters = self.filters.view(-1, 1, 1)  # (out_channels, 1, 1)
        band_widths = self.band_widths.view(-1, 1, 1)  # (out_channels, 1, 1)
        
        # Compute sinc filters
        low_pass1 = 2 * filters * torch.sinc(2 * filters * n)  # (out_channels, kernel_size, 1)
        low_pass2 = 2 * (filters + band_widths) * torch.sinc(2 * (filters + band_widths) * n)  # (out_channels, kernel_size, 1)
        band_pass = low_pass2 - low_pass1  # (out_channels, kernel_size, 1)
        
        # Apply window function
        band_pass = band_pass * self.window.view(1, -1, 1)  # (out_channels, kernel_size, 1)
        
        # Normalize
        band_pass = band_pass / (torch.norm(band_pass, p=2, dim=1, keepdim=True) + 1e-8)  # (out_channels, kernel_size, 1)
        
        # Reshape for convolution (out_channels, in_channels, kernel_size)
        filters = band_pass.squeeze(-1).view(self.out_channels, 1, self.kernel_size)  # (out_channels, 1, kernel_size)
        
        # Convolve 
        return F.conv1d(x, filters, stride=1, padding=(self.kernel_size-1)//2)

In [23]:
class ResBlock(nn.Module):
    """Residual block for feature extraction"""
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResBlock, self).__init__()
        
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [24]:
class SpectralGraphTransform(nn.Module):
    """Transform raw features into graph features"""
    def __init__(self, in_features, hidden_features, out_features):
        super(SpectralGraphTransform, self).__init__()
        
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [25]:
class AASIST(nn.Module):
    """
    AASIST: Audio Anti-Spoofing using Integrated Spectro-Temporal Graph Attention Networks
    Optimized for batch processing
    """
    def __init__(self, 
                 sinc_out_channels=70, 
                 sinc_kernel_size=1024, 
                 sample_rate=16000,
                 res_channels=[32, 64, 128],
                 gat_nfeat=128,
                 gat_nhid=64,
                 gat_nclass=32,
                 gat_nheads=8,
                 gat_alpha=0.2,
                 gat_dropout=0.3,
                 n_frame_node=64):
        super(AASIST, self).__init__()
        
        # Raw waveform feature extraction using SincConv
        self.sinc_conv = SincConv(
            out_channels=sinc_out_channels,
            kernel_size=sinc_kernel_size,
            sample_rate=sample_rate
        )
        
        # Residual blocks for feature processing
        self.res_block1 = ResBlock(sinc_out_channels, res_channels[0], stride=2)
        self.res_block2 = ResBlock(res_channels[0], res_channels[1], stride=2)
        self.res_block3 = ResBlock(res_channels[1], res_channels[2], stride=2)
        
        # Feature transformers for graph
        self.frame_transform = SpectralGraphTransform(
            res_channels[2], gat_nfeat*2, gat_nfeat
        )
        
        # Number of frame-level nodes for the adjacency matrix
        self.n_frame_node = n_frame_node
        
        # Graph attention network
        self.gat = GAT(
            nfeat=gat_nfeat,
            nhid=gat_nhid,
            nclass=gat_nclass,
            dropout=gat_dropout,
            alpha=gat_alpha,
            nheads=gat_nheads
        )
        
        # Output layer
        self.output = nn.Linear(gat_nclass, 1)
        
    def forward(self, x):
        """
        Forward pass with efficient batch processing
        x: Input audio of shape (batch_size, 1, signal_length)
        """
        batch_size = x.size(0)
        
        # Extract features using SincConv
        x = self.sinc_conv(x)
        x = F.relu(x)
        
        # Process through residual blocks
        x = self.res_block1(x)
        x = self.res_block2(x)
        x = self.res_block3(x)
        
        # Adaptive pooling to ensure fixed number of frames
        x = F.adaptive_avg_pool1d(x, self.n_frame_node)
        
        # Prepare for graph processing (batch_size, n_frames, features)
        x = x.transpose(1, 2)  # (batch_size, n_frame_node, res_channels[2])
        
        # Transform features for graph processing
        frame_feat = self.frame_transform(x)  # (batch_size, n_frame_node, gat_nfeat)
        
        # Create adjacency matrices (fully connected) - (batch_size, n_frame_node, n_frame_node)
        adj = torch.ones(batch_size, self.n_frame_node, self.n_frame_node, device=x.device)
        
        # Process all graphs in batch at once using the batch-optimized GAT
        gat_out = self.gat(frame_feat, adj)  # (batch_size, n_frame_node, gat_nclass)
        
        # Global pooling of graph features
        gat_out = torch.mean(gat_out, dim=1)  # (batch_size, gat_nclass)
        
        # Final classification
        output = torch.sigmoid(self.output(gat_out)).squeeze(-1)  # (batch_size,)
        
        return output

In [26]:
class AASIST_Fixed(AASIST):
    """Modified AASIST model without final sigmoid for use with BCEWithLogitsLoss"""
    def forward(self, x):
        # Keep all the processing the same as the original model
        batch_size = x.size(0)
        
        # Extract features using SincConv
        x = self.sinc_conv(x)
        x = F.relu(x)
        
        # Process through residual blocks
        x = self.res_block1(x)
        x = self.res_block2(x)
        x = self.res_block3(x)
        
        # Adaptive pooling to ensure fixed number of frames
        x = F.adaptive_avg_pool1d(x, self.n_frame_node)
        
        # Prepare for graph processing (batch_size, n_frames, features)
        x = x.transpose(1, 2)  # (batch_size, n_frame_node, res_channels[2])
        
        # Transform features for graph processing
        frame_feat = self.frame_transform(x)  # (batch_size, n_frame_node, gat_nfeat)
        
        # Create adjacency matrices (fully connected) - (batch_size, n_frame_node, n_frame_node)
        adj = torch.ones(batch_size, self.n_frame_node, self.n_frame_node, device=x.device)
        
        # Process all graphs in batch at once using the batch-optimized GAT
        gat_out = self.gat(frame_feat, adj)  # (batch_size, n_frame_node, gat_nclass)
        
        # Global pooling of graph features
        gat_out = torch.mean(gat_out, dim=1)  # (batch_size, gat_nclass)
        
        # Final classification - DIFFERENCE: No sigmoid here
        output = self.output(gat_out).squeeze(-1)  # (batch_size,)
        
        return output

In [27]:
def count_parameters(model):
    """Count the number of trainable parameters in a model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [28]:
# Test the SincConv layer
def test_sincconv():
    print("Testing SincConv layer...")
    # Create SincConv instance
    sinc_conv = SincConv(
        out_channels=70,
        kernel_size=1024,
        sample_rate=16000
    ).to(device)
    
    # Create random input
    x = torch.randn(2, 1, 16000).to(device)  # (batch_size, channels, signal_length)
    
    # Forward pass
    try:
        out = sinc_conv(x)
        print(f"SincConv test passed! Input shape: {x.shape}, Output shape: {out.shape}")
        return True
    except Exception as e:
        print(f"SincConv test failed: {e}")
        return False

In [29]:
class AudioAugmentation:
    """Enhanced audio augmentation techniques with GPU-compatible implementation"""
    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate
    
    def add_noise(self, audio, noise_level=0.005):
        """Add Gaussian noise to audio"""
        # Generate noise on same device as audio
        noise = torch.randn_like(audio) * noise_level
        return audio + noise
    
    def time_shift(self, audio, shift_range=0.1):
        """Random time shift"""
        shift = int(audio.shape[1] * shift_range)
        if shift > 0:
            direction = random.choice([1, -1])
            shift = direction * shift
            if shift > 0:
                # Shift right
                audio = torch.cat([torch.zeros_like(audio[:, :shift]), audio[:, :-shift]], dim=1)
            else:
                # Shift left
                audio = torch.cat([audio[:, -shift:], torch.zeros_like(audio[:, :shift])], dim=1)
        return audio
    
    def speed_change(self, audio, speed_range=0.3):
        """Change playback speed without affecting pitch"""
        # Get device of input tensor
        device = audio.device
        
        # Convert to numpy after moving to CPU
        audio_np = audio.cpu().squeeze().numpy()
        
        # Random speed factor
        speed_factor = random.uniform(1 - speed_range, 1 + speed_range)
        
        # Speed up or slow down
        indices = np.round(np.arange(0, len(audio_np), speed_factor)).astype(int)
        indices = indices[indices < len(audio_np)]
        
        # Get modified audio
        modified_audio = audio_np[indices]
        
        # Convert back to tensor and move to the original device
        modified_tensor = torch.from_numpy(modified_audio).float().unsqueeze(0).to(device)
        
        # Ensure shape matches original
        if modified_tensor.shape[1] < audio.shape[1]:
            # Pad
            padding = audio.shape[1] - modified_tensor.shape[1]
            modified_tensor = torch.nn.functional.pad(modified_tensor, (0, padding))
        elif modified_tensor.shape[1] > audio.shape[1]:
            # Truncate
            modified_tensor = modified_tensor[:, :audio.shape[1]]
        
        return modified_tensor
    
    def apply_batch_augmentation(self, audio_batch):
        """Apply augmentation to a batch of audio efficiently"""
        batch_size = audio_batch.size(0)
        # Clone to avoid modifying original tensor
        augmented_batch = audio_batch.clone()
        
        for i in range(batch_size):
            # Randomly select augmentation type with higher probability (80%)
            if random.random() < 0.8:
                aug_type = random.choice(['noise', 'time_shift', 'speed_change'])
                
                if aug_type == 'noise':
                    # Randomize noise level
                    noise_level = random.uniform(0.001, 0.01)
                    augmented_batch[i:i+1] = self.add_noise(audio_batch[i:i+1], noise_level)
                
                elif aug_type == 'time_shift':
                    # Randomize shift range
                    shift_range = random.uniform(0.05, 0.2)
                    augmented_batch[i:i+1] = self.time_shift(audio_batch[i:i+1], shift_range)
                
                elif aug_type == 'speed_change':
                    # Randomize speed range
                    speed_range = random.uniform(0.1, 0.3)  # Reduced range to avoid extreme changes
                    augmented_batch[i:i+1] = self.speed_change(audio_batch[i:i+1], speed_range)
        
        return augmented_batch

In [30]:
class ASVspoofDataset(Dataset):
    """
    Memory-efficient dataset for ASVspoof 2019 Logical Access
    """
    def __init__(self, 
                 path, 
                 protocol_file, 
                 sample_rate=16000, 
                 max_frames=64000,  # 4 seconds at 16kHz
                 augment=False,
                 cache_capacity=200,  # Reduced capacity for lower memory usage
                 max_cache_size_mb=500):  # Maximum cache size in MB
        
        self.path = path
        self.sample_rate = sample_rate
        self.max_frames = max_frames
        self.augment = augment
        self.max_cache_size_mb = max_cache_size_mb
        
        # Print the current working directory for debugging
        print(f"Current working directory: {os.getcwd()}")
        print(f"Base path: {path}")
        
        # Determine the protocols directory
        if os.path.basename(path) == 'LA':
            # We're at the LA directory level
            protocols_dir = os.path.join(path, 'ASVspoof2019_LA_cm_protocols')
        else:
            # We might be at a higher level
            protocols_dir = os.path.join(path, 'LA', 'ASVspoof2019_LA_cm_protocols')
            if not os.path.exists(protocols_dir):
                # Try another common structure
                protocols_dir = os.path.join(path, 'ASVspoof2019_LA_cm_protocols')
        
        # Get the protocol file path
        if os.path.isabs(protocol_file):
            protocol_path = protocol_file
        else:
            protocol_path = os.path.join(protocols_dir, os.path.basename(protocol_file))
            if not os.path.exists(protocol_path):
                # Try other locations
                potential_paths = [
                    os.path.join(path, protocol_file),
                    os.path.join(path, 'LA', protocol_file),
                    os.path.join(path, 'LA', 'ASVspoof2019_LA_cm_protocols', protocol_file),
                    os.path.join(protocols_dir, protocol_file)
                ]
                for p in potential_paths:
                    if os.path.exists(p):
                        protocol_path = p
                        break
        
        print(f"Using protocol file: {protocol_path}")
        
        # Determine audio directory based on protocol file
        protocol_basename = os.path.basename(protocol_path)
        if 'train' in protocol_basename:
            if os.path.basename(path) == 'LA':
                # We're at the LA directory level
                self.audio_dir = os.path.join(path, 'ASVspoof2019_LA_train', 'flac')
            else:
                # Try different possible paths
                potential_audio_dirs = [
                    os.path.join(path, 'ASVspoof2019_LA_train', 'flac'),
                    os.path.join(path, 'LA', 'ASVspoof2019_LA_train', 'flac')
                ]
                for d in potential_audio_dirs:
                    if os.path.exists(d):
                        self.audio_dir = d
                        break
        elif 'dev' in protocol_basename:
            if os.path.basename(path) == 'LA':
                self.audio_dir = os.path.join(path, 'ASVspoof2019_LA_dev', 'flac')
            else:
                potential_audio_dirs = [
                    os.path.join(path, 'ASVspoof2019_LA_dev', 'flac'),
                    os.path.join(path, 'LA', 'ASVspoof2019_LA_dev', 'flac')
                ]
                for d in potential_audio_dirs:
                    if os.path.exists(d):
                        self.audio_dir = d
                        break
        elif 'eval' in protocol_basename:
            if os.path.basename(path) == 'LA':
                self.audio_dir = os.path.join(path, 'ASVspoof2019_LA_eval', 'flac')
            else:
                potential_audio_dirs = [
                    os.path.join(path, 'ASVspoof2019_LA_eval', 'flac'),
                    os.path.join(path, 'LA', 'ASVspoof2019_LA_eval', 'flac')
                ]
                for d in potential_audio_dirs:
                    if os.path.exists(d):
                        self.audio_dir = d
                        break
        else:
            # Default case - try to find a suitable directory
            potential_dirs = [
                os.path.join(path, 'flac'),
                os.path.join(path, 'LA', 'flac')
            ]
            for d in potential_dirs:
                if os.path.exists(d):
                    self.audio_dir = d
                    break
            if not hasattr(self, 'audio_dir'):
                self.audio_dir = os.path.join(path, 'flac')  # Default
        
        # Check if audio directory exists
        if not os.path.exists(self.audio_dir):
            print(f"Warning: Audio directory {self.audio_dir} not found!")
            print("Available directories:")
            
            # Try to list directories at the base path and one level up
            paths_to_list = [path]
            if 'LA' in path:
                parent_path = os.path.dirname(path)
                paths_to_list.append(parent_path)
            
            for p in paths_to_list:
                try:
                    print(f"\nDirectories in {p}:")
                    for d in os.listdir(p):
                        if os.path.isdir(os.path.join(p, d)):
                            print(f"  - {d}")
                            try:
                                # List subdirectories one level down
                                subdirs = os.listdir(os.path.join(p, d))
                                print(f"    Subdirectories: {subdirs[:5]}{'...' if len(subdirs) > 5 else ''}")
                            except Exception:
                                pass
                except Exception as e:
                    print(f"Error listing {p}: {e}")
            
            # As a last resort, try to locate flac files using find
            print("\nSearching for flac files...")
            try:
                import glob
                flac_files = glob.glob(os.path.join(path, '**', '*.flac'), recursive=True)
                if flac_files:
                    print(f"Found {len(flac_files)} flac files")
                    print(f"Sample file paths: {flac_files[:3]}")
                    # Try to determine the correct audio directory
                    sample_path = os.path.dirname(flac_files[0])
                    self.audio_dir = sample_path
                    print(f"Setting audio directory to: {self.audio_dir}")
            except Exception as e:
                print(f"Error searching for flac files: {e}")
        
        print(f"Using audio directory: {self.audio_dir}")
        
        # Load protocol file
        try:
            self.protocol_df = pd.read_csv(protocol_path, sep=' ', header=None)
            
            # Check number of columns to determine format
            if self.protocol_df.shape[1] >= 5:  # Has at least 5 columns
                self.protocol_df.columns = ['speaker_id', 'file_name', 'environment', 'attack_id', 'spoofing_type'] 
            else:  # Likely the evaluation format
                if self.protocol_df.shape[1] == 3:  # Simplified format
                    self.protocol_df.columns = ['speaker_id', 'file_name', 'spoofing_type']
                else:  # Standard format with missing columns
                    self.protocol_df.columns = ['speaker_id', 'file_name', '-', '-', 'spoofing_type']
            
            # Create file list and labels
            self.file_list = list(self.protocol_df['file_name'])
            self.labels = [(1 if str(x).lower() == 'bonafide' else 0) for x in self.protocol_df['spoofing_type']]
            
            # Print some information about the protocol
            print(f"Protocol file has {self.protocol_df.shape[1]} columns")
            print(f"Protocol columns: {self.protocol_df.columns.tolist()}")
            print(f"First few rows of protocol file:")
            print(self.protocol_df.head(3))
            
        except Exception as e:
            print(f"Error loading protocol file: {e}")
            # Create empty lists as fallback
            self.file_list = []
            self.labels = []
        
        # Create augmenter
        self.augmenter = AudioAugmentation(sample_rate=sample_rate) if augment else None
        
        # Setup LRU cache for efficient memory usage
        self.cache = LRUCache(cache_capacity)
        self.current_cache_size = 0  # Track cache size in bytes
        
        print(f"Loaded {self.__len__()} files from {protocol_path}")
        # Print distribution of classes
        bonafide_count = self.labels.count(1)
        spoof_count = self.labels.count(0)
        print(f"Bonafide: {bonafide_count} ({bonafide_count/len(self.labels)*100:.2f}% of total)")
        print(f"Spoofed: {spoof_count} ({spoof_count/len(self.labels)*100:.2f}% of total)")
        
        # Print a few file paths to verify
        if len(self.file_list) > 0:
            print("Sample expected file paths:")
            for i in range(min(3, len(self.file_list))):
                file_path = os.path.join(self.audio_dir, f"{self.file_list[i]}.flac")
                print(f"  {file_path} (exists: {os.path.exists(file_path)})")
    
    def __len__(self):
        return len(self.file_list)
    
    def pad_or_truncate(self, audio):
        """Pad or truncate audio to max_frames"""
        if audio.shape[1] < self.max_frames:
            # Pad audio
            padding = self.max_frames - audio.shape[1]
            audio = torch.nn.functional.pad(audio, (0, padding))
        elif audio.shape[1] > self.max_frames:
            # Truncate audio - randomize starting point for better diversity if augmenting
            if self.augment:
                max_start = audio.shape[1] - self.max_frames
                start = random.randint(0, max_start)
                audio = audio[:, start:start+self.max_frames]
            else:
                audio = audio[:, :self.max_frames]
        return audio
    
    def check_cache_size(self):
        """Check if cache is too large and clear if necessary"""
        max_bytes = self.max_cache_size_mb * 1024 * 1024  # Convert MB to bytes
        if self.current_cache_size > max_bytes:
            # Clear half the cache to avoid frequent clearing
            items_to_remove = len(self.cache) // 2
            for _ in range(items_to_remove):
                # LRU cache removes least recently used items
                self.cache.popitem(last=False)
            
            # Force garbage collection
            gc.collect()
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
            
            # Reset cache size counter (approximation)
            self.current_cache_size = max_bytes // 2
    
    def __getitem__(self, idx):
        file_name = self.file_list[idx]
        
        # Check if audio is in cache
        audio = self.cache.get(file_name)
        
        if audio is None:
            file_path = os.path.join(self.audio_dir, f"{file_name}.flac")
            
            try:
                # Check if file exists before trying to load it
                if not os.path.exists(file_path):
                    print(f"Warning: Audio file does not exist: {file_path}")
                    # Try different capitalization or path formats
                    alternate_paths = [
                        os.path.join(self.audio_dir, f"{file_name.upper()}.flac"),
                        os.path.join(self.audio_dir, f"{file_name.lower()}.flac"),
                        os.path.join(os.path.dirname(self.audio_dir), f"{file_name}.flac"),
                        os.path.join(os.path.dirname(os.path.dirname(self.audio_dir)), f"{file_name}.flac")
                    ]
                    
                    for alt_path in alternate_paths:
                        if os.path.exists(alt_path):
                            print(f"Found alternate path: {alt_path}")
                            file_path = alt_path
                            break
                
                if not os.path.exists(file_path):
                    # If file still doesn't exist, raise error to trigger the exception handler
                    raise FileNotFoundError(f"Could not find audio file: {file_path}")
                
                # Load audio file
                audio, sr = torchaudio.load(file_path)
                
                # Resample if needed
                if sr != self.sample_rate:
                    resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
                    audio = resampler(audio)
                
                # Ensure mono audio
                if audio.shape[0] > 1:
                    audio = torch.mean(audio, dim=0, keepdim=True)
                
                # Normalize audio
                audio = audio / (torch.max(torch.abs(audio)) + 1e-8)
                
                # Add to cache if there's space
                audio_size_bytes = audio.element_size() * audio.nelement()
                if self.current_cache_size + audio_size_bytes < self.max_cache_size_mb * 1024 * 1024:
                    self.cache.put(file_name, audio.clone())
                    self.current_cache_size += audio_size_bytes
                    
                    # Check if cache is too large
                    self.check_cache_size()
                
            except Exception as e:
                print(f"Error loading {file_path}: {e}")
                # Return a zero tensor and the label
                audio = torch.zeros(1, self.max_frames)
        
        # Pad or truncate to fixed length
        audio = self.pad_or_truncate(audio)
        
        # Apply augmentation if enabled
        # Note: We don't apply augmentation here for individual samples
        # Instead, we'll apply batch augmentation in the training loop
        
        label = self.labels[idx]
        
        return audio, label

In [31]:
def get_dataloaders(path, batch_size=32, sample_rate=16000, max_frames=64000, 
                   num_workers=4, augment_train=True, pin_memory=True,
                   persistent_workers=True):
    """
    Create DataLoaders for ASVspoof 2019 LA with memory-efficient loading
    """
    # Check and fix path to match the Kaggle dataset structure
    print(f"Original path: {path}")
    
    # Check if the path exists
    if not os.path.exists(path):
        print(f"Warning: Path {path} does not exist!")
        
        # Try to find the correct path
        potential_paths = [
            '/kaggle/input/asvspoof2019-la',
            '/kaggle/input/asvspoof2019-la/LA',
            '/kaggle/input/asvspoof-2019',
            '/kaggle/input/asvspoof2019',
        ]
        
        for p in potential_paths:
            if os.path.exists(p):
                print(f"Found alternate path: {p}")
                path = p
                break
    
    # Print available directories for debugging
    print("Available directories in dataset path:")
    try:
        for d in os.listdir(path):
            if os.path.isdir(os.path.join(path, d)):
                print(f"  - {d}")
    except Exception as e:
        print(f"Error listing directories: {e}")
    
    # Check for the 'LA' directory
    la_path = os.path.join(path, 'LA')
    if os.path.exists(la_path):
        # If there's an LA subdirectory, use it as the base path
        print(f"Found LA directory: {la_path}")
        path = la_path
    
    # Find protocol files
    # First look in the standard location
    protocols_dir = os.path.join(path, 'ASVspoof2019_LA_cm_protocols')
    
    # Define protocol file paths
    train_protocol_file = 'ASVspoof2019.LA.cm.train.trn.txt'
    dev_protocol_file = 'ASVspoof2019.LA.cm.dev.trl.txt'
    eval_protocol_file = 'ASVspoof2019.LA.cm.eval.trl.txt'
    
    # Get full paths to protocol files
    if os.path.exists(protocols_dir):
        train_protocol = os.path.join(protocols_dir, train_protocol_file)
        dev_protocol = os.path.join(protocols_dir, dev_protocol_file)
        eval_protocol = os.path.join(protocols_dir, eval_protocol_file)
    else:
        # Try directly in the path
        train_protocol = os.path.join(path, train_protocol_file)
        dev_protocol = os.path.join(path, dev_protocol_file)
        eval_protocol = os.path.join(path, eval_protocol_file)
    
    # Verify protocol files exist
    print(f"Checking train protocol: {train_protocol} (exists: {os.path.exists(train_protocol)})")
    print(f"Checking dev protocol: {dev_protocol} (exists: {os.path.exists(dev_protocol)})")
    print(f"Checking eval protocol: {eval_protocol} (exists: {os.path.exists(eval_protocol)})")
    
    # If protocol files don't exist, search for them
    if not os.path.exists(train_protocol):
        print("Searching for protocol files...")
        import glob
        protocol_files = glob.glob(os.path.join(path, '**', train_protocol_file), recursive=True)
        if protocol_files:
            print(f"Found train protocol at: {protocol_files[0]}")
            train_protocol = protocol_files[0]
            
            # Try to find related protocols in the same directory
            protocols_dir = os.path.dirname(train_protocol)
            dev_protocol = os.path.join(protocols_dir, dev_protocol_file)
            eval_protocol = os.path.join(protocols_dir, eval_protocol_file)
    
    # Create datasets with error handling
    try:
        print("\nCreating training dataset...")
        train_dataset = ASVspoofDataset(
            path=path,
            protocol_file=train_protocol,
            sample_rate=sample_rate,
            max_frames=max_frames,
            augment=False  # We'll handle augmentation in the training loop for better efficiency
        )
    except Exception as e:
        print(f"Error creating train dataset: {e}")
        # Create an empty dataset as a fallback
        train_dataset = ASVspoofDataset(
            path=path,
            protocol_file="",
            sample_rate=sample_rate,
            max_frames=max_frames,
            augment=False
        )
    
    try:
        print("\nCreating development dataset...")
        dev_dataset = ASVspoofDataset(
            path=path,
            protocol_file=dev_protocol,
            sample_rate=sample_rate,
            max_frames=max_frames,
            augment=False
        )
    except Exception as e:
        print(f"Error creating dev dataset: {e}")
        # Use a copy of the training dataset as fallback
        print("Using training dataset as fallback for development dataset")
        dev_dataset = train_dataset
    
    try:
        print("\nCreating evaluation dataset...")
        eval_dataset = ASVspoofDataset(
            path=path,
            protocol_file=eval_protocol,
            sample_rate=sample_rate,
            max_frames=max_frames,
            augment=False
        )
    except Exception as e:
        print(f"Error creating eval dataset: {e}")
        # Use a copy of the dev dataset as fallback
        print("Using development dataset as fallback for evaluation dataset")
        eval_dataset = dev_dataset
    
    # Create DataLoaders with error handling
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=num_workers,
        pin_memory=pin_memory,
        prefetch_factor=2 if num_workers > 0 else None,
        persistent_workers=persistent_workers if num_workers > 0 else False
    )
    
    dev_loader = DataLoader(
        dev_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers if num_workers > 0 else False
    )
    
    eval_loader = DataLoader(
        eval_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=max(1, num_workers // 2),  # Use fewer workers for eval
        pin_memory=pin_memory
    )
    
    return train_loader, dev_loader, eval_loader

In [32]:
def mixup_data(x, y, alpha=0.2):
    """
    Mixup data augmentation that properly handles device placement
    """
    # Get the device of input data
    device = x.device
    
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(device)
    
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    
    return mixed_x, y_a, y_b, lam




In [33]:
def compute_tdcf(bonafide_scores, spoof_scores, cost_model=None):
    """
    Compute tandem detection cost function (t-DCF)
    Simplified implementation for ASVspoof 2019
    """
    if cost_model is None:
        # Default cost model as per ASVspoof 2019
        cost_model = {
            'Pspoof': 0.05,  # Prior probability of spoofing attack
            'Cmiss_asv': 1,  # Cost of ASV system rejecting target speaker
            'Cfa_asv': 10,   # Cost of ASV system accepting impostor
            'Cmiss_cm': 1,   # Cost of CM system rejecting genuine speech
            'Cfa_cm': 10,    # Cost of CM system accepting spoofed speech
            'beta_cm': 0.5   # Weight for countermeasure errors
        }
    
    # Sort scores
    bonafide_scores = sorted(bonafide_scores)
    spoof_scores = sorted(spoof_scores)
    
    # Calculate FRR (genuine) and FAR (spoof) for different thresholds
    thresholds = sorted(list(set(bonafide_scores + spoof_scores)))
    
    frr_list = []
    far_list = []
    
    for threshold in thresholds:
        frr = sum(score > threshold for score in bonafide_scores) / len(bonafide_scores)
        far = sum(score <= threshold for score in spoof_scores) / len(spoof_scores)
        frr_list.append(frr)
        far_list.append(far)
    
    # Calculate t-DCF for each threshold
    tdcf_list = []
    for frr, far in zip(frr_list, far_list):
        # Simplified t-DCF calculation
        tdcf = (cost_model['Cmiss_cm'] * frr * (1 - cost_model['Pspoof'])) + \
               (cost_model['Cfa_cm'] * far * cost_model['Pspoof'])
        tdcf_list.append(tdcf)
    
    # Find minimum t-DCF
    min_tdcf = min(tdcf_list)
    min_tdcf_threshold = thresholds[tdcf_list.index(min_tdcf)]
    
    return min_tdcf, min_tdcf_threshold

In [34]:
# Configuration for faster training and lower memory usage
config = {
    # Training parameters
    'batch_size': 32,      # Smaller batch size to prevent OOM errors
    'epochs': 20,          # Reduced epochs 
    'lr': 0.0003,          # Learning rate
    'weight_decay': 0.0001,
    'mixup_alpha': 0.2,    # Mixup augmentation parameter
    'amp': True,           # Enable Automatic Mixed Precision for faster training
    
    # Model parameters
    'model': {
        'sinc_out_channels': 70,
        'sinc_kernel_size': 1024,
        'sample_rate': 16000,
        'res_channels': [32, 64, 128],
        'gat_nfeat': 128,
        'gat_nhid': 64,
        'gat_nclass': 32,
        'gat_nheads': 8,
        'gat_alpha': 0.2,
        'gat_dropout': 0.3,  # Reduced dropout for faster convergence
        'n_frame_node': 64
    },
    
    # Data parameters
    'data': {
        'sample_rate': 16000,
        'max_frames': 64000,  # 4 seconds at 16kHz
        'num_workers': 2,     # Reduced for stability
        'augment_train': True,
        'cache_capacity': 100,  # Number of files to cache
        'max_cache_size_mb': 300  # Maximum cache size in MB
    },
    
    # Optimization parameters
    'optimizer': 'AdamW',  # Options: 'Adam', 'AdamW', 'SGD'
    'scheduler': 'OneCycleLR',  # Options: 'CosineAnnealingWarmRestarts', 'OneCycleLR', 'ReduceLROnPlateau'
    
    # Paths - will be updated automatically
    'data_path': '/kaggle/input/asvspoof2019-la/LA',
    'output_dir': './outputs'
}

# Create output directory
os.makedirs(config['output_dir'], exist_ok=True)

# Detect and configure ASVspoof dataset path
print("Detecting ASVspoof dataset location on Kaggle...")

# List all available datasets in the input directory
print("Available datasets in Kaggle input directory:")
!ls -la /kaggle/input/

# Check for ASVspoof in the directory names
asvspoof_dirs = []
for root, dirs, files in os.walk('/kaggle/input'):
    for d in dirs:
        if 'asvspoof' in d.lower():
            asvspoof_dirs.append(os.path.join(root, d))

if asvspoof_dirs:
    print(f"Found potential ASVspoof directories: {asvspoof_dirs}")
    
    # Find the most likely candidate (with LA or logical_access in name)
    best_dir = None
    for d in asvspoof_dirs:
        # Check if this directory has LA structure
        if os.path.exists(os.path.join(d, 'LA')):
            best_dir = os.path.join(d, 'LA')
            break
        elif 'LA' in d or 'logical' in d.lower() or 'logical_access' in d.lower():
            best_dir = d
            break
    
    # If no obvious best directory, just use the first one
    if best_dir is None and asvspoof_dirs:
        best_dir = asvspoof_dirs[0]
    
    if best_dir:
        print(f"Setting dataset path to: {best_dir}")
        config['data_path'] = best_dir
        
        # Check for specific paths known to work
        if 'asvspoof2019-la' in best_dir:
            print("Detected asvspoof2019-la dataset, using known structure")
            
            # If path is to asvspoof2019-la, add LA subdirectory 
            if not best_dir.endswith('/LA'):
                config['data_path'] = os.path.join(best_dir, 'LA')
                print(f"Updated path to include LA directory: {config['data_path']}")
else:
    print("No ASVspoof directories found. Using default path.")

# Update configuration based on the structure
print(f"\nFinal dataset path: {config['data_path']}")

# Update model and loss function for mixed precision compatibility
print("Updating model and loss function for mixed precision compatibility...")
config['use_fixed_model'] = True  # Use the fixed model version
config['use_bce_with_logits'] = True  # Use BCEWithLogitsLoss instead of BCELoss

Detecting ASVspoof dataset location on Kaggle...
Available datasets in Kaggle input directory:
total 8
drwxr-xr-x 3 root   root    4096 Apr  5 14:06 .
drwxr-xr-x 5 root   root    4096 Apr  5 14:06 ..
drwxr-xr-x 3 nobody nogroup    0 Apr  3 14:46 asvspoof2019-la
Found potential ASVspoof directories: ['/kaggle/input/asvspoof2019-la', '/kaggle/input/asvspoof2019-la/LA/ASVspoof2019_LA_dev', '/kaggle/input/asvspoof2019-la/LA/ASVspoof2019_LA_train', '/kaggle/input/asvspoof2019-la/LA/ASVspoof2019_LA_cm_protocols', '/kaggle/input/asvspoof2019-la/LA/ASVspoof2019_LA_eval', '/kaggle/input/asvspoof2019-la/LA/ASVspoof2019_LA_asv_protocols', '/kaggle/input/asvspoof2019-la/LA/ASVspoof2019_LA_asv_scores']
Setting dataset path to: /kaggle/input/asvspoof2019-la/LA
Detected asvspoof2019-la dataset, using known structure

Final dataset path: /kaggle/input/asvspoof2019-la/LA
Updating model and loss function for mixed precision compatibility...


In [35]:
# Update AMP functions to use the newer syntax
def get_amp_autocast():
    """Get the appropriate autocast context manager"""
    if torch.cuda.is_available():
        return torch.amp.autocast(device_type='cuda')
    else:
        return torch.amp.autocast(device_type='cpu')


In [36]:
def get_amp_scaler():
    """Get the appropriate GradScaler"""
    if torch.cuda.is_available():
        return torch.amp.GradScaler()
    else:
        return None

In [37]:
def main():
    # Set random seed
    setup_seed(42)
    
    # Print initial memory stats
    print("Initial memory usage:")
    print_memory_stats()
    
    # Test SincConv before proceeding
    print("\nTesting SincConv implementation:")
    test_sincconv()
    
    # Create data loaders
    try:
        train_loader, dev_loader, eval_loader = get_dataloaders(
            path=config['data_path'],
            batch_size=config['batch_size'],
            sample_rate=config['data']['sample_rate'],
            max_frames=config['data']['max_frames'],
            num_workers=config['data']['num_workers'],
            augment_train=config['data']['augment_train']
        )
    except Exception as e:
        print(f"Error creating dataloaders: {e}")
        # Create empty dataloaders
        train_loader = []
        dev_loader = []
        eval_loader = []
    
    # Create model based on configuration
    if config.get('use_fixed_model', False):
        print("Using fixed model without sigmoid for BCEWithLogitsLoss")
        model = AASIST_Fixed(
            sinc_out_channels=config['model']['sinc_out_channels'],
            sinc_kernel_size=config['model']['sinc_kernel_size'],
            sample_rate=config['model']['sample_rate'],
            res_channels=config['model']['res_channels'],
            gat_nfeat=config['model']['gat_nfeat'],
            gat_nhid=config['model']['gat_nhid'],
            gat_nclass=config['model']['gat_nclass'],
            gat_nheads=config['model']['gat_nheads'],
            gat_alpha=config['model']['gat_alpha'],
            gat_dropout=config['model']['gat_dropout'],
            n_frame_node=config['model']['n_frame_node']
        ).to(device)
    else:
        print("Using original model with sigmoid and BCELoss")
        model = AASIST(
            sinc_out_channels=config['model']['sinc_out_channels'],
            sinc_kernel_size=config['model']['sinc_kernel_size'],
            sample_rate=config['model']['sample_rate'],
            res_channels=config['model']['res_channels'],
            gat_nfeat=config['model']['gat_nfeat'],
            gat_nhid=config['model']['gat_nhid'],
            gat_nclass=config['model']['gat_nclass'],
            gat_nheads=config['model']['gat_nheads'],
            gat_alpha=config['model']['gat_alpha'],
            gat_dropout=config['model']['gat_dropout'],
            n_frame_node=config['model']['n_frame_node']
        ).to(device)
    
    print(f"Model has {count_parameters(model):,} trainable parameters")
    
    # Create optimizer
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(
            model.parameters(), 
            lr=config['lr'], 
            weight_decay=config['weight_decay']
        )
    elif config['optimizer'] == 'AdamW':
        optimizer = optim.AdamW(
            model.parameters(), 
            lr=config['lr'], 
            weight_decay=config['weight_decay']
        )
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(
            model.parameters(),
            lr=config['lr'],
            momentum=0.9,
            weight_decay=config['weight_decay'],
            nesterov=True
        )
    
    # Check if data loaders are empty
    if not train_loader or len(train_loader) == 0:
        print("Error: Empty train loader. Cannot continue training.")
        return None, {}
    
    # Create learning rate scheduler
    if config['scheduler'] == 'CosineAnnealingWarmRestarts':
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=10,  # Restart every 10 epochs
            T_mult=1,
            eta_min=config['lr'] / 10
        )
    elif config['scheduler'] == 'OneCycleLR':
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=config['lr'],
            epochs=config['epochs'],
            steps_per_epoch=len(train_loader),
            pct_start=0.2,  # Warm up for 20% of training
            div_factor=25,  # Initial lr = max_lr/25
            final_div_factor=1000  # Final lr = max_lr/1000
        )
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 
            mode='min', 
            factor=0.5, 
            patience=3, 
            verbose=True
        )
    
    # Create loss function
    if config.get('use_bce_with_logits', False):
        print("Using BCEWithLogitsLoss for mixed precision compatibility")
        criterion = nn.BCEWithLogitsLoss()
    else:
        print("Using standard BCELoss (not compatible with mixed precision)")
        criterion = nn.BCELoss()
    
    # Create tensorboard writer
    writer = SummaryWriter(os.path.join(config['output_dir'], 'logs'))
    
    # Create augmenter
    augmenter = AudioAugmentation(sample_rate=config['data']['sample_rate'])
    
    # Initialize tracking variables
    best_eer = float('inf')
    best_epoch = 0
    
    # Print start message
    print(f"Starting training for {config['epochs']} epochs")
    
    # Initialize Automatic Mixed Precision (AMP) components
    use_amp = config['amp']
    scaler = get_amp_scaler() if use_amp else None
    
    # Training loop
    for epoch in range(config['epochs']):
        print(f"\nEpoch {epoch+1}/{config['epochs']}")
        
        # Train
        model.train()
        running_loss = 0.0
        pred_scores = []
        true_labels = []
        
        # Training loop with fixed autocast
        pbar = tqdm(train_loader, desc="Training")
        for batch_idx, (audio, labels) in enumerate(pbar):
            try:
                # Move data to device
                audio = audio.to(device)
                labels = labels.float().to(device)
                
                # Apply batch augmentation if enabled
                if augmenter is not None:
                    try:
                        audio = augmenter.apply_batch_augmentation(audio)
                    except Exception as e:
                        print(f"Warning: Augmentation failed: {e}. Continuing with original audio.")
                
                # Apply mixup
                mixup_applied = False
                if config.get('mixup_alpha', 0) > 0:
                    try:
                        audio, labels_a, labels_b, lam = mixup_data(audio, labels, config['mixup_alpha'])
                        mixup_applied = True
                    except Exception as e:
                        print(f"Warning: Mixup failed: {e}. Continuing without mixup.")
                
                # Forward pass with fixed mixed precision
                optimizer.zero_grad()
                
                if use_amp and scaler is not None:
                    with torch.amp.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
                        outputs = model(audio)
                        
                        if mixup_applied:
                            loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
                        else:
                            loss = criterion(outputs, labels)
                    
                    # Backward and optimize with gradient scaling
                    scaler.scale(loss).backward()
                    
                    # Clip gradients to prevent exploding gradients
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
                    
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    # Standard forward pass without mixed precision
                    outputs = model(audio)
                    
                    if mixup_applied:
                        loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
                    else:
                        loss = criterion(outputs, labels)
                    
                    # Backward and optimize
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
                    optimizer.step()
                
                # Step scheduler if batch-based
                if scheduler is not None and isinstance(scheduler, (optim.lr_scheduler.CyclicLR, 
                                                                  optim.lr_scheduler.OneCycleLR)):
                    scheduler.step()
                
                # For metrics, use original labels if not using mixup
                if not mixup_applied:
                    # Get scores appropriate to the model type
                    if isinstance(model, AASIST_Fixed):
                        scores = torch.sigmoid(outputs).detach().cpu().numpy()
                    else:
                        scores = outputs.detach().cpu().numpy()
                        
                    pred_scores.extend(scores)
                    true_labels.extend(labels.detach().cpu().numpy())
                
                # Update progress bar
                running_loss += loss.item()
                avg_loss = running_loss / (batch_idx + 1)
                pbar.set_postfix({'loss': avg_loss})
                
                # Free up memory
                del audio, labels, outputs, loss
                if mixup_applied:
                    del labels_a, labels_b
                torch.cuda.empty_cache() if torch.cuda.is_available() else None
                
            except Exception as e:
                print(f"Error processing batch {batch_idx}: {e}")
                # Try to free memory
                torch.cuda.empty_cache() if torch.cuda.is_available() else None
                continue
        
        # Calculate metrics
        epoch_loss = running_loss / len(train_loader) if len(train_loader) > 0 else float('inf')
        
        # Skip metrics calculation if no predictions were collected
        if len(pred_scores) > 0:
            pred_labels = [1 if score >= 0.5 else 0 for score in pred_scores]
            accuracy = accuracy_score(true_labels, pred_labels)
            auc = roc_auc_score(true_labels, pred_scores) if len(np.unique(true_labels)) > 1 else 0.5
        else:
            accuracy = 0
            auc = 0
            
        # Print training metrics
        print(f"Train Loss: {epoch_loss:.4f}, Acc: {accuracy:.4f}, AUC: {auc:.4f}")
        
        # Validate
        model.eval()
        val_running_loss = 0.0
        val_pred_scores = []
        val_true_labels = []
        
        with torch.no_grad():
            for audio, labels in tqdm(dev_loader, desc="Validation"):
                try:
                    # Move data to device
                    audio = audio.to(device)
                    labels = labels.float().to(device)
                    
                    # Forward pass with fixed mixed precision
                    if use_amp and torch.cuda.is_available():
                        with torch.amp.autocast(device_type='cuda'):
                            outputs = model(audio)
                            loss = criterion(outputs, labels)
                    else:
                        outputs = model(audio)
                        loss = criterion(outputs, labels)
                    
                    # Collect statistics
                    val_running_loss += loss.item()
                    
                    # Get scores appropriate to the model type
                    if isinstance(model, AASIST_Fixed):
                        scores = torch.sigmoid(outputs).detach().cpu().numpy()
                    else:
                        scores = outputs.detach().cpu().numpy()
                        
                    val_pred_scores.extend(scores)
                    val_true_labels.extend(labels.detach().cpu().numpy())
                    
                except Exception as e:
                    print(f"Error during validation: {e}")
                    continue
        
        # Calculate validation metrics
        val_loss = val_running_loss / len(dev_loader) if len(dev_loader) > 0 else float('inf')
        
        if len(val_pred_scores) > 0:
            val_pred_labels = [1 if score >= 0.5 else 0 for score in val_pred_scores]
            val_accuracy = accuracy_score(val_true_labels, val_pred_labels)
            val_auc = roc_auc_score(val_true_labels, val_pred_scores) if len(np.unique(val_true_labels)) > 1 else 0.5
            
            # Calculate EER
            fpr, tpr, thresholds = roc_curve(val_true_labels, val_pred_scores)
            fnr = 1 - tpr
            eer_idx = np.nanargmin(np.absolute(fnr - fpr))
            val_eer = fpr[eer_idx]
            eer_threshold = thresholds[eer_idx]
        else:
            val_accuracy = 0
            val_auc = 0.5
            val_eer = 0.5
            eer_threshold = 0.5
        
        # Update learning rate for epoch-based schedulers
        if config['scheduler'] == 'CosineAnnealingWarmRestarts':
            scheduler.step()
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_loss)
        
        # Save logs
        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/train', accuracy, epoch)
        writer.add_scalar('Accuracy/val', val_accuracy, epoch)
        writer.add_scalar('AUC/train', auc, epoch)
        writer.add_scalar('AUC/val', val_auc, epoch)
        writer.add_scalar('EER/val', val_eer, epoch)
        writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)
        
        # Print validation metrics
        print(f"Val Loss: {val_loss:.4f}, Acc: {val_accuracy:.4f}, AUC: {val_auc:.4f}, EER: {val_eer:.4f}")
        
        # Memory check after each epoch
        print(f"Memory usage after epoch {epoch+1}:")
        print_memory_stats()
        
        # Save checkpoint (only save every few epochs to save disk space)
        if (epoch + 1) % 5 == 0 or epoch == config['epochs'] - 1:
            checkpoint_path = os.path.join(config['output_dir'], f'checkpoint_epoch_{epoch+1}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_eer': val_eer,
                'best_eer': best_eer,
                'is_fixed_model': isinstance(model, AASIST_Fixed)
            }, checkpoint_path)
        
        # Save best model
        if val_eer < best_eer:
            best_eer = val_eer
            best_epoch = epoch
            best_model_path = os.path.join(config['output_dir'], 'best_model.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_eer': val_eer,
                'best_eer': best_eer,
                'is_fixed_model': isinstance(model, AASIST_Fixed)
            }, best_model_path)
            print(f"New best model saved with EER: {val_eer:.4f}")
        
        # Clean up to prevent memory leaks
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # Final evaluation on evaluation set
    print("\nTraining completed!")
    print(f"Best validation EER: {best_eer:.4f} at epoch {best_epoch+1}")
    print("\nEvaluating on evaluation set...")
    
    # Load best model
    best_model_path = os.path.join(config['output_dir'], 'best_model.pth')
    if os.path.exists(best_model_path):
        checkpoint = torch.load(best_model_path)
        
        # Check if it's a fixed model
        is_fixed_model = checkpoint.get('is_fixed_model', False)
        if is_fixed_model and not isinstance(model, AASIST_Fixed):
            print("Loading best model (fixed version without sigmoid)")
            model = AASIST_Fixed(
                sinc_out_channels=config['model']['sinc_out_channels'],
                sinc_kernel_size=config['model']['sinc_kernel_size'],
                sample_rate=config['model']['sample_rate'],
                res_channels=config['model']['res_channels'],
                gat_nfeat=config['model']['gat_nfeat'],
                gat_nhid=config['model']['gat_nhid'],
                gat_nclass=config['model']['gat_nclass'],
                gat_nheads=config['model']['gat_nheads'],
                gat_alpha=config['model']['gat_alpha'],
                gat_dropout=config['model']['gat_dropout'],
                n_frame_node=config['model']['n_frame_node']
            ).to(device)
        elif not is_fixed_model and isinstance(model, AASIST_Fixed):
            print("Loading best model (original version with sigmoid)")
            model = AASIST(
                sinc_out_channels=config['model']['sinc_out_channels'],
                sinc_kernel_size=config['model']['sinc_kernel_size'],
                sample_rate=config['model']['sample_rate'],
                res_channels=config['model']['res_channels'],
                gat_nfeat=config['model']['gat_nfeat'],
                gat_nhid=config['model']['gat_nhid'],
                gat_nclass=config['model']['gat_nclass'],
                gat_nheads=config['model']['gat_nheads'],
                gat_alpha=config['model']['gat_alpha'],
                gat_dropout=config['model']['gat_dropout'],
                n_frame_node=config['model']['n_frame_node']
            ).to(device)
            
        model.load_state_dict(checkpoint['model_state_dict'])
    
    # Evaluate final model
    model.eval()
    test_pred_scores = []
    test_true_labels = []
    
    with torch.no_grad():
        for audio, labels in tqdm(eval_loader, desc="Evaluation"):
            try:
                # Move data to device
                audio = audio.to(device)
                
                # Forward pass
                if use_amp and torch.cuda.is_available():
                    with torch.amp.autocast(device_type='cuda'):
                        outputs = model(audio)
                else:
                    outputs = model(audio)
                
                # Get scores appropriate to the model type
                if isinstance(model, AASIST_Fixed):
                    scores = torch.sigmoid(outputs).detach().cpu().numpy()
                else:
                    scores = outputs.detach().cpu().numpy()
                    
                test_pred_scores.extend(scores)
                test_true_labels.extend(labels.numpy())
                
            except Exception as e:
                print(f"Error during evaluation: {e}")
                continue
    
    # Calculate final evaluation metrics
    if len(test_pred_scores) > 0:
        test_pred_labels = [1 if score >= 0.5 else 0 for score in test_pred_scores]
        test_accuracy = accuracy_score(test_true_labels, test_pred_labels)
        test_auc = roc_auc_score(test_true_labels, test_pred_scores) if len(np.unique(test_true_labels)) > 1 else 0.5
        
        # Calculate EER
        fpr, tpr, thresholds = roc_curve(test_true_labels, test_pred_scores)
        fnr = 1 - tpr
        eer_idx = np.nanargmin(np.absolute(fnr - fpr))
        test_eer = fpr[eer_idx]
        test_eer_threshold = thresholds[eer_idx]
        
        # Calculate t-DCF
        bonafide_scores = [test_pred_scores[i] for i in range(len(test_pred_scores)) if test_true_labels[i] == 1]
        spoof_scores = [test_pred_scores[i] for i in range(len(test_pred_scores)) if test_true_labels[i] == 0]
        
        if bonafide_scores and spoof_scores:
            test_min_tdcf, test_tdcf_threshold = compute_tdcf(bonafide_scores, spoof_scores)
        else:
            test_min_tdcf, test_tdcf_threshold = 0.5, 0.5
    else:
        test_accuracy = 0
        test_auc = 0.5
        test_eer = 0.5
        test_eer_threshold = 0.5
        test_min_tdcf = 0.5
        test_tdcf_threshold = 0.5
    
    # Report results
    print("\nEvaluation Results:")
    print(f"Accuracy: {test_accuracy:.4f}")
    print(f"AUC: {test_auc:.4f}")
    print(f"EER: {test_eer:.4f}")
    print(f"min-tDCF: {test_min_tdcf:.4f}")
    print(f"EER Threshold: {test_eer_threshold:.4f}")
    
    # Save final results
    results = {
        'test_accuracy': test_accuracy,
        'test_auc': test_auc,
        'test_eer': test_eer,
        'test_tdcf': test_min_tdcf,
        'eer_threshold': test_eer_threshold,
        'tdcf_threshold': test_tdcf_threshold,
        'best_epoch': best_epoch + 1,
        'is_fixed_model': isinstance(model, AASIST_Fixed)
    }
    
    with open(os.path.join(config['output_dir'], 'results.yaml'), 'w') as f:
        yaml.dump(results, f)
    
    # Print final memory usage
    print("\nFinal memory usage:")
    print_memory_stats()
    
    return model, results

# Run the training function
if __name__ == "__main__":
    model, results = main()

Initial memory usage:
RAM Usage: 739.97 MB
GPU Memory Usage: 0.00 MB / 16269.25 MB (0.0%)

Testing SincConv implementation:
Testing SincConv layer...
SincConv test passed! Input shape: torch.Size([2, 1, 16000]), Output shape: torch.Size([2, 70, 15999])
Original path: /kaggle/input/asvspoof2019-la/LA
Available directories in dataset path:
  - ASVspoof2019_LA_dev
  - ASVspoof2019_LA_train
  - ASVspoof2019_LA_cm_protocols
  - ASVspoof2019_LA_eval
  - ASVspoof2019_LA_asv_protocols
  - ASVspoof2019_LA_asv_scores
Checking train protocol: /kaggle/input/asvspoof2019-la/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt (exists: True)
Checking dev protocol: /kaggle/input/asvspoof2019-la/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt (exists: True)
Checking eval protocol: /kaggle/input/asvspoof2019-la/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.eval.trl.txt (exists: True)

Creating training dataset...
Current working directory: /kaggle/working
Base path: /kag

Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.3662, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.3278, Acc: 0.8974, AUC: 0.5947, EER: 0.4368
Memory usage after epoch 1:
RAM Usage: 1689.70 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.4368

Epoch 2/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.3216, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.2917, Acc: 0.8974, AUC: 0.7639, EER: 0.3115
Memory usage after epoch 2:
RAM Usage: 1691.84 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.3115

Epoch 3/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.3023, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.3993, Acc: 0.8987, AUC: 0.8227, EER: 0.2497
Memory usage after epoch 3:
RAM Usage: 1692.34 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.2497

Epoch 4/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.2561, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.1409, Acc: 0.9457, AUC: 0.9615, EER: 0.1067
Memory usage after epoch 4:
RAM Usage: 1692.34 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.1067

Epoch 5/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1937, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0799, Acc: 0.9733, AUC: 0.9920, EER: 0.0501
Memory usage after epoch 5:
RAM Usage: 1692.34 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.0501

Epoch 6/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1802, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0837, Acc: 0.9702, AUC: 0.9944, EER: 0.0330
Memory usage after epoch 6:
RAM Usage: 1692.41 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.0330

Epoch 7/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1684, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0233, Acc: 0.9928, AUC: 0.9996, EER: 0.0099
Memory usage after epoch 7:
RAM Usage: 1692.71 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.0099

Epoch 8/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1522, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0174, Acc: 0.9958, AUC: 0.9998, EER: 0.0087
Memory usage after epoch 8:
RAM Usage: 1692.71 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.0087

Epoch 9/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1549, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0257, Acc: 0.9933, AUC: 0.9998, EER: 0.0067
Memory usage after epoch 9:
RAM Usage: 1692.84 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.0067

Epoch 10/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1537, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0132, Acc: 0.9965, AUC: 0.9999, EER: 0.0047
Memory usage after epoch 10:
RAM Usage: 1692.96 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.0047

Epoch 11/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1541, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0080, Acc: 0.9977, AUC: 0.9999, EER: 0.0035
Memory usage after epoch 11:
RAM Usage: 1692.96 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.0035

Epoch 12/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1510, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0120, Acc: 0.9967, AUC: 0.9999, EER: 0.0047
Memory usage after epoch 12:
RAM Usage: 1693.09 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)

Epoch 13/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1472, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0113, Acc: 0.9976, AUC: 0.9999, EER: 0.0049
Memory usage after epoch 13:
RAM Usage: 1693.09 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)

Epoch 14/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1425, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0089, Acc: 0.9977, AUC: 0.9999, EER: 0.0048
Memory usage after epoch 14:
RAM Usage: 1693.34 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)

Epoch 15/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1411, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0060, Acc: 0.9986, AUC: 1.0000, EER: 0.0028
Memory usage after epoch 15:
RAM Usage: 1693.34 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.0028

Epoch 16/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1362, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0081, Acc: 0.9984, AUC: 1.0000, EER: 0.0020
Memory usage after epoch 16:
RAM Usage: 1693.59 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.0020

Epoch 17/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1398, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0065, Acc: 0.9983, AUC: 1.0000, EER: 0.0019
Memory usage after epoch 17:
RAM Usage: 1693.71 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.0019

Epoch 18/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1447, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0108, Acc: 0.9974, AUC: 1.0000, EER: 0.0019
Memory usage after epoch 18:
RAM Usage: 1693.71 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)

Epoch 19/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1365, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0065, Acc: 0.9982, AUC: 1.0000, EER: 0.0018
Memory usage after epoch 19:
RAM Usage: 1693.84 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)
New best model saved with EER: 0.0018

Epoch 20/20


Training:   0%|          | 0/794 [00:00<?, ?it/s]

Train Loss: 0.1330, Acc: 0.0000, AUC: 0.0000


Validation:   0%|          | 0/777 [00:00<?, ?it/s]

Val Loss: 0.0051, Acc: 0.9986, AUC: 1.0000, EER: 0.0020
Memory usage after epoch 20:
RAM Usage: 1693.96 MB
GPU Memory Usage: 23.42 MB / 16269.25 MB (0.1%)

Training completed!
Best validation EER: 0.0018 at epoch 19

Evaluating on evaluation set...


  checkpoint = torch.load(best_model_path)


Evaluation:   0%|          | 0/2227 [00:00<?, ?it/s]


Evaluation Results:
Accuracy: 0.6281
AUC: 0.9285
EER: 0.1562
min-tDCF: 0.5000
EER Threshold: 0.9854

Final memory usage:
RAM Usage: 1705.58 MB
GPU Memory Usage: 24.78 MB / 16269.25 MB (0.2%)


In [54]:
def predict_audio(model, audio_path, device, sample_rate=16000, max_frames=64000, amp=True):
    """Predict whether an audio file is real or fake"""
    # Load audio
    audio, sr = torchaudio.load(audio_path)
    
    # Resample if needed
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(sr, sample_rate)
        audio = resampler(audio)
    
    # Ensure mono
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)
    
    # Normalize
    audio = audio / (torch.max(torch.abs(audio)) + 1e-8)
    
    # Pad or truncate
    if audio.shape[1] < max_frames:
        padding = max_frames - audio.shape[1]
        audio = torch.nn.functional.pad(audio, (0, padding))
    elif audio.shape[1] > max_frames:
        audio = audio[:, :max_frames]
    
    # Add batch dimension - THIS LINE FIXES THE ISSUE
    audio = audio.unsqueeze(0)
    
    # Move to device
    audio = audio.to(device)
    
    # Set model to eval mode
    model.eval()
    
    # Check if it's the fixed model type (without sigmoid in forward)
    is_fixed_model = isinstance(model, AASIST_Fixed)
    
    # Get prediction
    with torch.no_grad():
        if amp and torch.cuda.is_available():
            with torch.amp.autocast(device_type='cuda'):
                output = model(audio)
        else:
            output = model(audio)
    
    # Apply sigmoid if using the fixed model
    if is_fixed_model:
        score = torch.sigmoid(output).item()
    else:
        score = output.item()
    
    # Return score
    return score

In [55]:
def visualize_waveform_and_spectrogram(audio_path, score, sample_rate=16000, threshold=0.5):
    """Visualize audio waveform and spectrogram with prediction"""
    # Load audio
    audio, sr = torchaudio.load(audio_path)
    
    # Resample if needed
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(sr, sample_rate)
        audio = resampler(audio)
    
    # Ensure mono
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)
    
    # Convert to numpy
    audio_np = audio.squeeze().numpy()
    
    # Create figure
    plt.figure(figsize=(12, 10))
    
    # Plot waveform
    plt.subplot(3, 1, 1)
    time_axis = np.linspace(0, len(audio_np) / sample_rate, len(audio_np))
    plt.plot(time_axis, audio_np)
    plt.title('Audio Waveform')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.grid(True)
    
    # Plot spectrogram
    plt.subplot(3, 1, 2)
    D = librosa.stft(audio_np, n_fft=2048, hop_length=512)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    img = librosa.display.specshow(S_db, sr=sample_rate, hop_length=512, 
                          x_axis='time', y_axis='log', cmap='viridis')
    plt.colorbar(img, format='%+2.0f dB')
    plt.title('Spectrogram')
    
    # Add prediction visualization
    plt.subplot(3, 1, 3)
    
    # Create a simple gauge chart for prediction
    prediction = "REAL" if score > threshold else "FAKE"
    confidence = score if score > threshold else 1 - score
    
    # Use different colors for real vs fake
    color = 'green' if prediction == "REAL" else 'red'
    
    plt.barh(['Prediction'], [confidence], color=color)
    plt.xlim(0, 1)
    plt.xticks([0, 0.25, 0.5, 0.75, 1], ['0%', '25%', '50%', '75%', '100%'])
    plt.title(f'Prediction: {prediction} (Score: {score:.4f}, Confidence: {confidence:.2%})')
    
    plt.tight_layout()
    plt.show()

In [56]:
def batch_evaluate_files(model, file_dir, device, sample_rate=16000, max_frames=64000, 
                         threshold=0.5, amp=True, ext='.flac'):
    """
    Evaluate multiple audio files in a directory
    """
    results = []
    
    # Find all audio files
    files = [f for f in os.listdir(file_dir) if f.endswith(ext)]
    print(f"Found {len(files)} files to evaluate")
    
    # Check if it's the fixed model type (without sigmoid in forward)
    is_fixed_model = isinstance(model, AASIST_Fixed)
    
    # Process files
    for file in tqdm(files, desc="Evaluating files"):
        file_path = os.path.join(file_dir, file)
        
        try:
            # Load and preprocess audio
            audio, sr = torchaudio.load(file_path)
            
            # Resample if needed
            if sr != sample_rate:
                resampler = torchaudio.transforms.Resample(sr, sample_rate)
                audio = resampler(audio)
            
            # Ensure mono
            if audio.shape[0] > 1:
                audio = torch.mean(audio, dim=0, keepdim=True)
            
            # Normalize
            audio = audio / (torch.max(torch.abs(audio)) + 1e-8)
            
            # Pad or truncate
            if audio.shape[1] < max_frames:
                padding = max_frames - audio.shape[1]
                audio = torch.nn.functional.pad(audio, (0, padding))
            elif audio.shape[1] > max_frames:
                audio = audio[:, :max_frames]
            
            # Move to device
            audio = audio.to(device)
            
            # Get prediction
            model.eval()
            with torch.no_grad():
                if amp and torch.cuda.is_available():
                    with torch.amp.autocast(device_type='cuda'):
                        output = model(audio)
                else:
                    output = model(audio)
            
            # Apply sigmoid if using the fixed model
            if is_fixed_model:
                score = torch.sigmoid(output).item()
            else:
                score = output.item()
            
            # Determine prediction
            prediction = "REAL" if score > threshold else "FAKE"
            confidence = score if score > threshold else 1 - score
            
            # Store result
            results.append({
                'file': file,
                'score': score,
                'prediction': prediction,
                'confidence': confidence
            })
        except Exception as e:
            print(f"Error processing {file}: {e}")
    
    # Convert to dataframe
    results_df = pd.DataFrame(results)
    
    # Display summary
    print("\nEvaluation Summary:")
    print(f"Total files: {len(results_df)}")
    print(f"Predicted as real: {len(results_df[results_df['prediction'] == 'REAL'])}")
    print(f"Predicted as fake: {len(results_df[results_df['prediction'] == 'FAKE'])}")
    
    # Save results
    results_df.to_csv(os.path.join(file_dir, 'evaluation_results.csv'), index=False)
    
    return results_df

In [57]:
# Example of loading and using a saved model
def load_and_use_model(model_path, device):
    """Load a saved model and prepare it for inference"""
    # Load checkpoint
    checkpoint = torch.load(model_path, map_location=device)
    
    # Check if it's a fixed model
    is_fixed_model = checkpoint.get('is_fixed_model', False)
    
    # Create appropriate model
    if is_fixed_model:
        print("Loading fixed model without sigmoid")
        model = AASIST_Fixed(
            sinc_out_channels=config['model']['sinc_out_channels'],
            sinc_kernel_size=config['model']['sinc_kernel_size'],
            sample_rate=config['model']['sample_rate'],
            res_channels=config['model']['res_channels'],
            gat_nfeat=config['model']['gat_nfeat'],
            gat_nhid=config['model']['gat_nhid'],
            gat_nclass=config['model']['gat_nclass'],
            gat_nheads=config['model']['gat_nheads'],
            gat_alpha=config['model']['gat_alpha'],
            gat_dropout=config['model']['gat_dropout'],
            n_frame_node=config['model']['n_frame_node']
        ).to(device)
    else:
        print("Loading original model with sigmoid")
        model = AASIST(
            sinc_out_channels=config['model']['sinc_out_channels'],
            sinc_kernel_size=config['model']['sinc_kernel_size'],
            sample_rate=config['model']['sample_rate'],
            res_channels=config['model']['res_channels'],
            gat_nfeat=config['model']['gat_nfeat'],
            gat_nhid=config['model']['gat_nhid'],
            gat_nclass=config['model']['gat_nclass'],
            gat_nheads=config['model']['gat_nheads'],
            gat_alpha=config['model']['gat_alpha'],
            gat_dropout=config['model']['gat_dropout'],
            n_frame_node=config['model']['n_frame_node']
        ).to(device)
    
    # Load state dict
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Set to evaluation mode
    model.eval()
    
    return model

In [None]:
# Load a trained model
model_path = os.path.join(config['output_dir'], 'best_model.pth')
model = load_and_use_model(model_path, device)

# Predict on an audio file from your dataset
# Using a real sample file from your dataset
audio_path = "/kaggle/input/asvspoof2019-la/LA/ASVspoof2019_LA_eval/flac/LA_E_1000147.flac"
print(f"Analyzing file: {audio_path}")
print(f"File exists: {os.path.exists(audio_path)}")

# Get prediction score
score = predict_audio(model, audio_path, device)
print(f"Prediction score: {score:.4f} (Higher = more likely to be real)")

# Visualize the results
visualize_waveform_and_spectrogram(audio_path, score)

  checkpoint = torch.load(model_path, map_location=device)


Loading fixed model without sigmoid
Analyzing file: /kaggle/input/asvspoof2019-la/LA/ASVspoof2019_LA_eval/flac/LA_E_1000147.flac
File exists: True
Prediction score: 0.9702 (Higher = more likely to be real)


In [52]:
# Example usage - uncomment to run
# # Load a trained model
# model_path = os.path.join(config['output_dir'], 'best_model.pth')
# model = load_and_use_model(model_path, device)
# 
# # Predict on an audio file
# audio_path = "/path/to/your/audio/file.flac"
# score = predict_audio(model, audio_path, device)
# visualize_waveform_and_spectrogram(audio_path, score)

# Conclusion and Further Improvements

### The AASIST implementation offers state-of-the-art performance for detecting AI-generated human speech. Here are some possible improvements for future work:

### 1. **Model Optimizations**:
    - Model quantization for faster inference
    - Knowledge distillation for smaller models
    - Adaptive frame selection for variable-length audio
 
### 2. **Training Enhancements**:
    - Adversarial training for improved robustness
    - Transfer learning from larger datasets
    - Multi-task learning (combining spoofing detection with other tasks)
 
### 3. **Real-world Applications**:
    - Integration with web services and APIs
    - Browser extensions for real-time detection
    - Mobile applications for on-device detection