In [None]:
import os
import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset

class ASVspoofDataset(Dataset):
    """PyTorch Dataset for the ASVspoof2019 LA dataset.

    Each item is a fixed-length mono waveform paired with an integer class
    label (``0`` for ``bonafide``, ``1`` for ``spoof``).
    """

    # Mapping from the protocol file's textual labels to integer class ids.
    LABEL_MAP = {'bonafide': 0, 'spoof': 1}

    def __init__(self, protocol_file, audio_dir, sample_rate=16000, duration=4):
        """Read the protocol file and prepare the label table.

        Args:
            protocol_file: Path (or file-like object) of the space-separated
                protocol file; passed straight to ``pandas.read_csv``.
            audio_dir: Directory containing the ``<file_id>.flac`` files.
            sample_rate: Sample rate (Hz) every returned waveform will have.
            duration: Length of every returned waveform, in seconds.
                Fractional values are accepted.

        Raises:
            ValueError: If the protocol file contains a label other than
                ``bonafide`` or ``spoof`` (including a missing label).
        """
        # NOTE: when a row has more fields than there are names, pandas uses
        # the leading field(s) as the index, so the named columns are always
        # the LAST four fields of each row.
        table = pd.read_csv(
            protocol_file,
            sep=' ',
            header=None,
            names=['file_id', '1', '2', 'label'],
        )
        # Explicit copy so the assignment below operates on an independent
        # frame and does not trigger pandas' SettingWithCopyWarning.
        table = table[['file_id', 'label']].copy()

        # Fail early on labels we cannot encode; otherwise .map() would turn
        # them into NaN and the failure would only surface during training.
        unknown = ~table['label'].isin(list(self.LABEL_MAP))
        if unknown.any():
            bad = sorted(table.loc[unknown, 'label'].astype(str).unique())
            raise ValueError(
                "Unknown label(s) in protocol file: " + ", ".join(bad)
            )
        table['label'] = table['label'].map(self.LABEL_MAP)

        self.df = table
        self.audio_dir = audio_dir
        self.sample_rate = sample_rate
        # Target length in samples; int() keeps it usable as a slice bound
        # and pad width even when duration is fractional.
        self.max_len = int(sample_rate * duration)

    def __len__(self):
        """Return the number of utterances listed in the protocol file."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load one utterance.

        Args:
            idx: Positional index into the protocol table.

        Returns:
            A ``(waveform, label)`` tuple where ``waveform`` has shape
            ``(1, max_len)`` and ``label`` is a scalar ``torch.long`` tensor.
        """
        row = self.df.iloc[idx]
        file_path = os.path.join(self.audio_dir, row['file_id'] + ".flac")

        waveform, sr = torchaudio.load(file_path)
        waveform = waveform.mean(dim=0)  # Average over dim 0 to get mono

        # max_len is expressed in samples at self.sample_rate, so bring the
        # audio to that rate before padding/trimming.
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sr, self.sample_rate
            )

        waveform = self._fix_length(waveform)
        label = torch.tensor(int(row['label']), dtype=torch.long)
        return waveform.unsqueeze(0), label

    def _fix_length(self, waveform):
        """Zero-pad at the end, or truncate, a 1-D waveform to ``max_len``."""
        num_samples = waveform.size(0)
        if num_samples < self.max_len:
            missing = self.max_len - num_samples
            return torch.nn.functional.pad(waveform, (0, missing))
        return waveform[:self.max_len]
