In [11]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
from transformers import AutoProcessor
import os
from sklearn.model_selection import train_test_split


class AudioDataset(Dataset):
    """RAVDESS clips as wav2vec2 processor inputs plus filename-derived metadata.

    Each item is ``(audio_input, emotion, label)``:
      * ``audio_input`` -- output of the wav2vec2 ``AutoProcessor`` for a clip
        down-mixed to one channel, resampled to 16 kHz and padded/cropped to
        ``target_length`` samples;
      * ``emotion`` -- the emotion code parsed from the filename, unshifted
        (the same number that appears in the filename);
      * ``label`` -- dict with every parsed metadata field.
    """

    def __init__(self, data_samples, target_length=16000*3):  # 3 seconds at 16kHz
        """
        Args:
            data_samples (list): List of tuples (audio_path, emotion_label).
                Only audio_path is read; all labels are re-parsed from the
                filename in ``__getitem__``.
            target_length (int): Target length in samples (default 3 seconds)
        """
        self.data_samples = data_samples
        self.audio_processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
        self.target_length = target_length

    def __len__(self):
        return len(self.data_samples)

    @staticmethod
    def _parse_metadata(audio_path):
        """Decode the dash-separated fields of a RAVDESS filename into a dict.

        Expected layout (as parsed below):
        modality-vocal_channel-emotion-intensity-statement-repetition-actor.wav
        ``gender`` is 0 for odd actor numbers and 1 for even ones
        (presumably RAVDESS's odd = male convention -- confirm with dataset docs).
        """
        parts = os.path.basename(audio_path).split("-")
        fields = {
            "modality": int(parts[0]),
            "vocal_channel": int(parts[1]),
            "emotion": int(parts[2]),
            "intensity": int(parts[3]),
            "statement": int(parts[4]),
            "repetition": int(parts[5]),
        }
        actor = int(parts[6].split(".")[0])
        fields["gender"] = 0 if actor % 2 == 1 else 1
        return fields

    def _pad_or_crop(self, waveform):
        """Right-pad with zeros or randomly crop ``waveform`` along its last axis
        so that it has exactly ``self.target_length`` samples."""
        length = waveform.shape[1]
        if length < self.target_length:
            return torch.nn.functional.pad(waveform, (0, self.target_length - length))
        if length > self.target_length:
            # torch.randint's upper bound is exclusive; +1 makes the final
            # window (start == length - target_length) reachable.
            start = torch.randint(0, length - self.target_length + 1, (1,)).item()
            return waveform[:, start:start + self.target_length]
        return waveform

    def __getitem__(self, idx):
        audio_path, _ = self.data_samples[idx]
        label = self._parse_metadata(audio_path)

        # Load and process audio
        waveform, sample_rate = torchaudio.load(audio_path, format="wav")

        # Down-mix multi-channel files so the processor always gets a 1-D signal.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample to 16kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)

        # NOTE(review): cropping is random for every split, so clips longer than
        # target_length give non-deterministic evaluation inputs.
        waveform = self._pad_or_crop(waveform)

        # Process with wav2vec processor
        audio_input = self.audio_processor(
            waveform.squeeze(0),
            sampling_rate=16000,
            return_tensors="pt"
        )

        return audio_input, label["emotion"], label


def collate_fn(batch):
    """
    Collate (processor_output, emotion, metadata) triples into one batch dict.

    Every ``input_values`` tensor -- and every ``attention_mask`` when the first
    item carries one -- is right-padded with zeros to the longest item, stacked,
    and its singleton dim 1 is squeezed (assumes each is shaped (1, L_i)).

    Returns:
        dict with 'input_values', 'attention_mask' (None when absent),
        'emotions' (tensor of the emotion ids) and 'metadata' (tuple of dicts).
    """
    processor_outputs, emotion_ids, metadata = zip(*batch)
    longest = max(out.input_values.shape[1] for out in processor_outputs)

    def _right_pad_and_stack(tensors):
        # Zero-pad each tensor on the right to `longest`, stack, drop dim 1.
        padded = [
            torch.nn.functional.pad(t, (0, longest - t.shape[1]), value=0)
            for t in tensors
        ]
        return torch.stack(padded).squeeze(1)

    input_values = _right_pad_and_stack([out.input_values for out in processor_outputs])

    if not hasattr(processor_outputs[0], 'attention_mask'):
        attention_mask = None
    else:
        attention_mask = _right_pad_and_stack(
            [out.attention_mask for out in processor_outputs]
        )

    return {
        'input_values': input_values,
        'attention_mask': attention_mask,
        'emotions': torch.tensor(emotion_ids),
        'metadata': metadata,
    }


def load_ravdess_dataset(base_dir, test_size=0.2, random_state=42):
    """
    Load the RAVDESS dataset and split it into training and testing datasets.

    Directory and file names are enumerated in sorted order: ``os.listdir``
    returns entries in arbitrary, platform-dependent order, which would make
    the ``random_state`` split differ between machines.

    NOTE(review): the split is over individual clips, not actors, so the same
    speaker can appear in both splits. Use a grouped split if speaker-independent
    evaluation is required.

    Args:
        base_dir (str): Path to the RAVDESS dataset containing actor folders.
        test_size (float): Proportion of the dataset to include in the test split.
        random_state (int): Random seed for reproducibility.

    Returns:
        train_dataset (AudioDataset): Training dataset
        test_dataset (AudioDataset): Testing dataset

    Raises:
        ValueError: if no ``.wav`` files are found under ``base_dir``.
    """
    data = []
    for actor_folder in sorted(os.listdir(base_dir)):
        actor_path = os.path.join(base_dir, actor_folder)
        if not os.path.isdir(actor_path):
            continue
        for audio_file in sorted(os.listdir(actor_path)):
            if audio_file.endswith(".wav"):
                # Third dash-separated field of the filename is the emotion code.
                emotion = int(audio_file.split("-")[2])
                data.append((os.path.join(actor_path, audio_file), emotion))

    if not data:
        raise ValueError(f"No .wav files found in actor folders under {base_dir!r}")

    train_data, test_data = train_test_split(data, test_size=test_size, random_state=random_state)

    train_dataset = AudioDataset(train_data)
    test_dataset = AudioDataset(test_data)

    return train_dataset, test_dataset


def get_data_loaders(base_dir, batch_size=8, test_size=0.2, random_state=42):
    """Build DataLoaders over the RAVDESS train/test split.

    Args:
        base_dir (str): Root folder passed to ``load_ravdess_dataset``.
        batch_size (int): Items per batch for both loaders.
        test_size (float): Fraction of clips held out for testing.
        random_state (int): Seed forwarded to the split.

    Returns:
        (train_loader, test_loader): the training loader shuffles, the test
        loader keeps order; both batch through ``collate_fn``.
    """
    train_split, test_split = load_ravdess_dataset(base_dir, test_size, random_state)
    loader_kwargs = {"batch_size": batch_size, "collate_fn": collate_fn}
    return (
        DataLoader(train_split, shuffle=True, **loader_kwargs),
        DataLoader(test_split, shuffle=False, **loader_kwargs),
    )


# Example Usage
if __name__ == "__main__":
    # Dataset root: set the RAVDESS_DIR environment variable instead of editing
    # this machine-specific default path.
    base_dir = os.environ.get(
        "RAVDESS_DIR",
        r"C:\Users\gnith\Desktop\multi_modal_audio_speech\ravdess_dataset",
    )

    # Get data loaders
    train_loader, test_loader = get_data_loaders(base_dir, batch_size=2)

    print(f"Number of training batches: {len(train_loader)}")
    print(f"Number of testing batches: {len(test_loader)}")

    # Peek at the metadata of the first training batch, if there is one.
    batch = next(iter(train_loader), None)
    if batch is not None:
        print(batch['metadata'])

Number of training batches: 405
Number of testing batches: 102
({'modality': 3, 'vocal_channel': 2, 'emotion': 4, 'intensity': 1, 'statement': 2, 'repetition': 2, 'gender': 1}, {'modality': 3, 'vocal_channel': 2, 'emotion': 6, 'intensity': 2, 'statement': 1, 'repetition': 1, 'gender': 1})


In [1]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import Wav2Vec2Model
from audio_dataset import get_data_loaders

# Pretrained wav2vec2 encoder; same checkpoint as the dataset's AutoProcessor.
audio_encoder = Wav2Vec2Model.from_pretrained(
    pretrained_model_name_or_path="facebook/wav2vec2-base-960h",
)

  from .autonotebook import tqdm as notebook_tqdm
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
from audio_dataset import get_data_loaders

import os

# Dataset root: set RAVDESS_DIR to point elsewhere without editing the notebook.
audio_dir = os.environ.get(
    "RAVDESS_DIR",
    r"C:\Users\gnith\Desktop\multi_modal_audio_speech\ravdess_dataset",
)

# Get train and validation dataloaders (the second loader is the held-out
# split returned by get_data_loaders).
train_loader, val_loader = get_data_loaders(audio_dir, batch_size=2)

In [3]:
# Fetch one training batch; later cells read it through the name `a`.
# next(iter(...)) raises here on an empty loader instead of leaving `a` unset.
a = next(iter(train_loader))
print(a)

{'input_values': tensor([[-7.0262e-03, -8.2524e-03, -1.0697e-02,  ..., -7.0473e-03,
         -1.2591e-01, -2.1741e-01],
        [-3.1370e-06, -1.3217e-05, -9.6774e-06,  ..., -1.0263e-05,
         -1.1941e-05, -2.2046e-06]]), 'attention_mask': None, 'emotions': tensor([2, 6]), 'metadata': ({'modality': 3, 'vocal_channel': 2, 'emotion': 2, 'intensity': 2, 'statement': 1, 'repetition': 1, 'gender': 0}, {'modality': 3, 'vocal_channel': 1, 'emotion': 6, 'intensity': 2, 'statement': 2, 'repetition': 2, 'gender': 1})}


In [9]:
# Inspection-only forward pass: no_grad skips building the autograd graph.
with torch.no_grad():
    hidden_shape = audio_encoder(a['input_values']).last_hidden_state.shape
hidden_shape

torch.Size([2, 149, 768])

In [13]:
# Time-averaged conv-feature-extractor output; no_grad since this is inspection only.
with torch.no_grad():
    feature_shape = audio_encoder(a['input_values']).extract_features.mean(dim=1).shape
feature_shape

torch.Size([2, 512])

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
from transformers import AutoProcessor
import os
from sklearn.model_selection import train_test_split
import numpy as np


class AudioDataset(Dataset):
    """RAVDESS clips rendered as 3-channel 224x224 log-Mel "images".

    Each item is ``(spectrogram_img, emotion, label)``:
      * ``spectrogram_img`` -- ``(3, 224, 224)`` tensor: a min-max normalized
        log-Mel spectrogram resized bilinearly and repeated over 3 channels;
      * ``emotion`` -- zero-based emotion index (filename code minus one);
      * ``label`` -- dict with every parsed metadata field.
    """

    def __init__(self, data_samples, target_length=16000*3, augment=False, n_mels=128, n_fft=1024, hop_length=512):
        """
        Args:
            data_samples (list): List of tuples (audio_path, emotion_label).
                Only audio_path is read; labels are re-parsed from the filename
                (the tuple's emotion_label is not used).
            target_length (int): Target length in samples (default 3 seconds)
            augment (bool): Whether to apply audio augmentation
            n_mels (int): Number of Mel bins
            n_fft (int): FFT window size
            hop_length (int): Hop length between STFT windows
        """
        self.data_samples = data_samples
        self.target_length = target_length
        self.augment = augment
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length

        # Spectrogram transforms
        self.mel_spectrogram = T.MelSpectrogram(
            sample_rate=16000,
            n_mels=n_mels,
            n_fft=n_fft,
            hop_length=hop_length,
            power=2.0  # Use power spectrum (magnitude squared)
        )

        # For log compression
        self.amplitude_to_db = T.AmplitudeToDB()

    def __len__(self):
        return len(self.data_samples)

    @staticmethod
    def _parse_metadata(audio_path):
        """Decode the dash-separated fields of a RAVDESS filename into a dict.

        Expected layout (as parsed below):
        modality-vocal_channel-emotion-intensity-statement-repetition-actor.wav
        ``emotion`` is shifted to be zero-based. ``gender`` is 0 for odd actor
        numbers and 1 for even ones (presumably RAVDESS's odd = male
        convention -- confirm with dataset docs).
        """
        parts = os.path.basename(audio_path).split("-")
        fields = {
            "modality": int(parts[0]),
            "vocal_channel": int(parts[1]),
            "emotion": int(parts[2]) - 1,
            "intensity": int(parts[3]),
            "statement": int(parts[4]),
            "repetition": int(parts[5]),
        }
        actor = int(parts[6].split(".")[0])
        fields["gender"] = 0 if actor % 2 == 1 else 1
        return fields

    def _pad_or_crop(self, waveform):
        """Right-pad with zeros or randomly crop ``waveform`` along its last axis
        so that it has exactly ``self.target_length`` samples."""
        length = waveform.shape[1]
        if length < self.target_length:
            return torch.nn.functional.pad(waveform, (0, self.target_length - length))
        if length > self.target_length:
            # torch.randint's upper bound is exclusive; +1 makes the final
            # window (start == length - target_length) reachable.
            start = torch.randint(0, length - self.target_length + 1, (1,)).item()
            return waveform[:, start:start + self.target_length]
        return waveform

    @staticmethod
    def _min_max_normalize(spec, eps=1e-8):
        """Rescale ``spec`` to [0, 1].

        The range is clamped to ``eps`` so a constant input (e.g. a silent
        clip) maps to zeros instead of NaN from 0 / 0.
        """
        spec_min = spec.min()
        span = (spec.max() - spec_min).clamp_min(eps)
        return (spec - spec_min) / span

    def __getitem__(self, idx):
        audio_path, _ = self.data_samples[idx]
        label = self._parse_metadata(audio_path)

        # Load and process audio
        waveform, sample_rate = torchaudio.load(audio_path, format="wav")

        # Apply waveform-level augmentation. apply_augmentation returns 16 kHz
        # audio, so record the new rate; otherwise the resampling step below
        # would resample the already-resampled clip again and shrink it.
        if self.augment:
            waveform = self.apply_augmentation(waveform, sample_rate)
            sample_rate = 16000

        # Normalize the waveform (zero mean, unit variance)
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-7)

        # Resample to 16kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)

        # NOTE(review): cropping is random for every split, so clips longer than
        # target_length give non-deterministic evaluation inputs.
        waveform = self._pad_or_crop(waveform)

        # Convert to mono if needed
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Log-Mel spectrogram, (1, n_mels, frames), rescaled to [0, 1]
        log_mel_spec = self.amplitude_to_db(self.mel_spectrogram(waveform))
        log_mel_spec = self._min_max_normalize(log_mel_spec)

        # Frequency masking only makes sense on a spectrogram (mel axis), not
        # on the raw waveform.
        if self.augment:
            log_mel_spec = self._augment_spectrogram(log_mel_spec)

        # Resize spectrogram to 224x224
        resized_spec = torch.nn.functional.interpolate(
            log_mel_spec.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
        ).squeeze(0)

        # Convert to 3-channel "image" by repeating the spectrogram
        # (ViT typically expects 3-channel input)
        spectrogram_img = resized_spec.repeat(3, 1, 1)

        return spectrogram_img, label["emotion"], label

    def apply_augmentation(self, waveform, sample_rate):
        """
        Apply waveform-domain augmentations; the returned audio is at 16 kHz.

        Frequency masking is intentionally not done here. On a (channels, samples)
        tensor, FrequencyMasking masks the channel axis and can blank the whole
        mono signal, so it is applied to the spectrogram in
        ``_augment_spectrogram`` instead.
        """
        # Resample to 16kHz
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

        # Zero a random span of up to 100 samples along the time axis
        waveform = torchaudio.transforms.TimeMasking(time_mask_param=100)(waveform)

        # Adjust volume.
        # NOTE(review): __getitem__ standardizes the waveform right after this,
        # which cancels a constant gain; use a random gain or drop this.
        waveform = waveform * 0.5  # Equivalent to T.Vol(0.5)

        return waveform

    def _augment_spectrogram(self, spec):
        """Mask a random band of up to 15 mel bins of the (1, n_mels, frames) spectrogram."""
        return torchaudio.transforms.FrequencyMasking(freq_mask_param=15)(spec)


def collate_fn(batch):
    """
    Custom collate function for spectrogram data.

    Args:
        batch: sequence of ``(spectrogram, emotion, metadata)`` items.

    Returns:
        dict with 'spectrograms' (stacked tensor), 'emotions' (tensor) and
        'metadata' (tuple of dicts); or None when the batch is empty or
        cannot be collated (e.g. mismatched spectrogram shapes). The latter
        case emits a RuntimeWarning instead of failing silently.
    """
    import warnings

    if not batch:
        # The caller may have filtered every item out; nothing to stack.
        return None

    try:
        spectrograms, emotions, labels = zip(*batch)
        return {
            'spectrograms': torch.stack(spectrograms),
            'emotions': torch.tensor(emotions),
            'metadata': labels
        }
    except Exception as exc:  # best-effort: skip the batch, but say why
        warnings.warn(f"collate_fn: dropping batch that could not be collated ({exc!r})",
                      RuntimeWarning)
        return None


def load_ravdess_dataset(base_dir, test_size=0.2, random_state=42):
    """
    Load the RAVDESS dataset and split it into training and testing datasets.

    Directory and file names are enumerated in sorted order because
    ``os.listdir`` order is arbitrary; without sorting the ``random_state``
    split would not be reproducible across machines.

    NOTE(review): the split is over clips, not actors, so a speaker can appear
    in both splits.

    Args:
        base_dir (str): Path containing the per-actor folders.
        test_size (float): Proportion of clips in the test split.
        random_state (int): Seed for the split.

    Returns:
        (train_dataset, test_dataset): two AudioDataset instances.

    Raises:
        ValueError: if no ``.wav`` files are found under ``base_dir``.
    """
    data = []
    for actor_folder in sorted(os.listdir(base_dir)):
        actor_path = os.path.join(base_dir, actor_folder)
        if not os.path.isdir(actor_path):
            continue
        for audio_file in sorted(os.listdir(actor_path)):
            if audio_file.endswith(".wav"):
                # Third dash-separated field of the filename is the emotion code.
                emotion = int(audio_file.split("-")[2])
                data.append((os.path.join(actor_path, audio_file), emotion))

    if not data:
        raise ValueError(f"No .wav files found in actor folders under {base_dir!r}")

    train_data, test_data = train_test_split(data, test_size=test_size, random_state=random_state)

    train_dataset = AudioDataset(train_data)
    test_dataset = AudioDataset(test_data)

    return train_dataset, test_dataset


def get_data_loaders(base_dir, batch_size=8, test_size=0.2, random_state=42, augment=False):
    """Build spectrogram DataLoaders over the RAVDESS train/test split.

    Args:
        base_dir (str): Root folder passed to ``load_ravdess_dataset``.
        batch_size (int): Items per batch for both loaders.
        test_size (float): Fraction of clips held out for testing.
        random_state (int): Seed forwarded to the split.
        augment (bool): Enables augmentation on the training split only.

    Returns:
        (train_loader, test_loader): shuffled training loader and ordered test
        loader; both drop None items before calling ``collate_fn``.
    """
    train_split, test_split = load_ravdess_dataset(base_dir, test_size, random_state)

    # Augmentation is switched on for the training split only.
    train_split.augment = augment

    def _collate_without_none(items):
        return collate_fn(list(filter(lambda item: item is not None, items)))

    loader_kwargs = {"batch_size": batch_size, "collate_fn": _collate_without_none}
    train_loader = DataLoader(train_split, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_split, shuffle=False, **loader_kwargs)
    return train_loader, test_loader

In [15]:
# Spectrogram loaders without augmentation; "ravdess_dataset" is resolved
# relative to the notebook's working directory.
train_loader, test_loader = get_data_loaders(
    "ravdess_dataset",
    batch_size=2,
    augment=False,
)

In [3]:
import timm

# ViT-B/16 backbone for the spectrogram "images". num_classes=0 leaves out
# timm's classification head (a custom head is added later), and in_chans=3
# matches the 3-channel spectrograms produced by the dataset.
vit = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=0,
    in_chans=3,
)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# List parameter names without rebinding globals: the old `for a, b in ...`
# loop overwrote `a`, the sample batch used by earlier cells.
print("\n".join(name for name, _ in vit.named_parameters()))

cls_token
pos_embed
patch_embed.proj.weight
patch_embed.proj.bias
blocks.0.norm1.weight
blocks.0.norm1.bias
blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight
blocks.0.attn.proj.bias
blocks.0.norm2.weight
blocks.0.norm2.bias
blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias
blocks.1.norm1.weight
blocks.1.norm1.bias
blocks.1.attn.qkv.weight
blocks.1.attn.qkv.bias
blocks.1.attn.proj.weight
blocks.1.attn.proj.bias
blocks.1.norm2.weight
blocks.1.norm2.bias
blocks.1.mlp.fc1.weight
blocks.1.mlp.fc1.bias
blocks.1.mlp.fc2.weight
blocks.1.mlp.fc2.bias
blocks.2.norm1.weight
blocks.2.norm1.bias
blocks.2.attn.qkv.weight
blocks.2.attn.qkv.bias
blocks.2.attn.proj.weight
blocks.2.attn.proj.bias
blocks.2.norm2.weight
blocks.2.norm2.bias
blocks.2.mlp.fc1.weight
blocks.2.mlp.fc1.bias
blocks.2.mlp.fc2.weight
blocks.2.mlp.fc2.bias
blocks.3.norm1.weight
blocks.3.norm1.bias
blocks.3.attn.qkv.weight
blocks.3.attn.qkv.bias
blocks.3.attn.proj.wei