In [1]:
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchaudio

In [2]:
# Define a custom dataset class for audio processing
class CustomAudioDataset(Dataset):
    """
    Map-style dataset yielding ``(processed_audio, metadata)`` pairs.

    Every row of the annotation CSV describes one audio clip. The class reads
    the columns 'audio_filepath', 'duration', 'text', 'gender', 'age-group',
    'primary_language' and 'job_category'.
    """

    def __init__(self, csv_file, transformations=None, target_sample_rate=16000):
        """
        Initialize the dataset.
        
        Args:
            csv_file (str): Path to the CSV file with annotations.
            transformations (list): Callables applied in order to the mono,
                resampled waveform (e.g. torchaudio transforms). Defaults to
                no transformations.
            target_sample_rate (int): Sample rate every clip is resampled to.
        """
        self.data = pd.read_csv(csv_file)
        self.transformations = transformations if transformations is not None else []
        self.target_sample_rate = target_sample_rate
        # Resample modules keyed by (source rate, target rate); filled lazily
        # by resample() so each kernel is computed once, not once per sample.
        self._resamplers = {}

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Fetch a sample from the dataset.
        
        Args:
            idx (int or Tensor): Index of the sample to fetch. A tensor index
                is converted with ``tolist()`` before the row lookup.
        
        Returns:
            tuple: Audio waveform (after resampling, mono mix-down and all
            transformations) and a metadata dictionary with the keys
            'duration', 'transcription', 'gender', 'age_group',
            'primary_language' and 'job_category'.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Read the row once instead of re-indexing the frame for every field.
        row = self.data.iloc[idx]

        # Load and bring the audio to the target sample rate
        audio_waveform, sample_rate = torchaudio.load(row['audio_filepath'])
        audio_waveform = self.resample(audio_waveform, sample_rate)

        # Average the channels when the clip has more than one
        if audio_waveform.size(0) > 1:
            audio_waveform = torch.mean(audio_waveform, dim=0, keepdim=True)

        # Apply transformations in the order they were given
        for transform in self.transformations:
            audio_waveform = transform(audio_waveform)

        metadata = {
            'duration': row['duration'],
            'transcription': row['text'],
            **self._speaker_metadata(row),
        }

        return audio_waveform, metadata

    def resample(self, audio_waveform, sample_rate):
        """
        Resample the audio waveform if necessary.
        
        Args:
            audio_waveform (Tensor): The original audio waveform.
            sample_rate (int): The sample rate of the loaded audio.
        
        Returns:
            Tensor: The waveform at ``self.target_sample_rate``. The input is
            returned untouched when it is already at that rate.
        """
        if sample_rate == self.target_sample_rate:
            return audio_waveform

        # Instances created without __init__ (e.g. restored from an older
        # pickle) may lack the cache attribute, so create it on demand.
        resamplers = getattr(self, '_resamplers', None)
        if resamplers is None:
            resamplers = self._resamplers = {}

        # The target rate is part of the key so that changing
        # target_sample_rate on a live instance never reuses a stale kernel.
        key = (sample_rate, self.target_sample_rate)
        resampler = resamplers.get(key)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.target_sample_rate
            )
            resamplers[key] = resampler
        return resampler(audio_waveform)

    @staticmethod
    def _speaker_metadata(row):
        """Build the speaker-related metadata dictionary from one CSV row."""
        return {
            'gender': row['gender'],
            'age_group': row['age-group'],
            'primary_language': row['primary_language'],
            'job_category': row['job_category']
        }

    # Additional utility methods
    def get_audio_sample_path(self, idx):
        """Get the audio file path for a given index."""
        return self.data.iloc[idx]['audio_filepath']

    def get_audio_sample_duration(self, idx):
        """Get the duration of the audio sample for a given index."""
        return self.data.iloc[idx]['duration']

    def get_transcription(self, idx):
        """Get the transcription text for a given index."""
        return self.data.iloc[idx]['text']

    def get_metadata(self, idx):
        """Get additional metadata (gender, age group, language, job) for a given index."""
        return self._speaker_metadata(self.data.iloc[idx])


In [3]:
# Collate function for batching
def collate_fn(batch):
    """
    Merge dataset samples into one batch by zero-padding the last dimension.

    Every waveform is right-padded with zeros along its final axis until it
    is as long as the longest waveform in the batch, then all of them are
    stacked along a new leading batch axis.

    Args:
        batch (list): ``(waveform, metadata)`` pairs produced by the dataset.

    Returns:
        tuple: The stacked, padded waveforms and a tuple holding each
        sample's metadata in batch order.
    """
    waveforms, metadata = zip(*batch)
    longest = max(item.size(-1) for item in waveforms)

    padded = [
        torch.nn.functional.pad(item, (0, longest - item.size(-1)), 'constant', 0)
        for item in waveforms
    ]

    return torch.stack(padded), metadata


In [4]:
# One sample rate shared by the resampling step and the mel transform: the
# spectrogram's frequency axis is only meaningful when both use the same rate.
TARGET_SAMPLE_RATE = 16000
METADATA_CSV = "../meta_speaker_stats.csv"

# Define mel spectrogram transformation
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=TARGET_SAMPLE_RATE, n_mels=64
)

# Initialize dataset with transformations
dataset = CustomAudioDataset(
    METADATA_CSV,
    transformations=[mel_spectrogram],
    target_sample_rate=TARGET_SAMPLE_RATE,
)


FileNotFoundError: [Errno 2] No such file or directory: '../meta_speaker_stats.csv'

In [5]:
# Example of accessing dataset methods
sample_idx = 325  # Replace with any valid index

# Fail early with a readable message instead of the generic out-of-bounds
# IndexError from the row lookup when the CSV has fewer rows than expected.
# Negative indices remain allowed because the row lookup accepts them.
assert -len(dataset) <= sample_idx < len(dataset), (
    f"sample_idx={sample_idx} is out of range for a dataset of {len(dataset)} samples"
)

# Accessing __getitem__
sample, metadata = dataset[sample_idx]
print("Sample shape:", sample.shape)
print("Metadata:", metadata)

# Accessing other methods
audio_path = dataset.get_audio_sample_path(sample_idx)
duration = dataset.get_audio_sample_duration(sample_idx)
transcription = dataset.get_transcription(sample_idx)
metadata_details = dataset.get_metadata(sample_idx)

print("Audio File Path:", audio_path)
print("Audio Duration:", duration)
print("Transcription:", transcription)
print("Metadata Details:", metadata_details)


NameError: name 'dataset' is not defined

In [6]:
# Use DataLoader to iterate over the dataset with batching and collation.
# The shuffle order is driven by a seeded generator so the batch printed below
# is the same on every Restart & Run All.
SHUFFLE_SEED = 42
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

# Iterate through batches and print details of one batch
for batch_waveforms, batch_metadata in dataloader:
    print("Batch shape:", batch_waveforms.shape)
    print("Batch metadata:", batch_metadata)
    break  # Print only the first batch details


NameError: name 'dataset' is not defined