In [1]:
pip install -q librosa numpy soundfile torchaudio praat-parselmouth

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m28.3 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.5/207.5 MB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.1/21.1 MB[0m [31m81.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━

In [2]:
import os
import glob
import torch
import torchaudio
import librosa
import numpy as np
from scipy.signal import butter, filtfilt
import parselmouth
from tqdm import tqdm
from typing import Tuple, Optional, Dict, List
import json

class AudioPreprocessor:
    def __init__(self, target_sr: int = 16000):
        self.target_sr = target_sr
        self.min_duration = 0.5  # minimum duration in seconds
        self.max_duration = 10.0  # maximum duration in seconds
        
    def preprocess(self, audio_path: str) -> Tuple[np.ndarray, int]:
        """Enhanced audio preprocessing pipeline"""
        # Load audio
        y, sr = librosa.load(audio_path, sr=None)
        
        # Quality checks
        if len(y) < self.min_duration * sr:
            raise ValueError(f"Audio too short: {len(y)/sr:.2f}s")
        if len(y) > self.max_duration * sr:
            y = y[:int(self.max_duration * sr)]
            
        # Resample if needed
        if sr != self.target_sr:
            y = librosa.resample(y, orig_sr=sr, target_sr=self.target_sr)
            
        # Trim silence
        y, _ = librosa.effects.trim(y, top_db=20, frame_length=2048, hop_length=512)
        
        # Apply pre-emphasis
        y = librosa.effects.preemphasis(y)
        
        # Apply high-pass filter to remove DC offset
        y = self._apply_highpass_filter(y, self.target_sr)
        
        # Normalize using LUFS
        y = self._normalize_lufs(y, self.target_sr)
        
        return y, self.target_sr
    
    def _apply_highpass_filter(self, y: np.ndarray, sr: int, cutoff: float = 20.0) -> np.ndarray:
        """Apply high-pass filter to remove DC offset"""
        try:
            # Method 1: Using scipy's butter filter
            nyquist = sr / 2
            normal_cutoff = cutoff / nyquist
            b, a = butter(4, normal_cutoff, btype='high', analog=False)
            y_filtered = filtfilt(b, a, y)
            return y_filtered
        except Exception as e:
            print(f"Warning: High-pass filter failed, using alternative method: {e}")
            try:
                # Method 2: Using librosa's high-pass filter
                y_filtered = librosa.effects.preemphasis(y, coef=0.97)
                return y_filtered
            except Exception as e:
                print(f"Warning: Alternative high-pass filter failed: {e}")
                # If both methods fail, return original signal
                return y
    
    def _normalize_lufs(self, y: np.ndarray, sr: int, target_lufs: float = -23.0) -> np.ndarray:
        """Normalize audio to target LUFS level"""
        try:
            # Calculate current LUFS
            current_lufs = self._calculate_lufs(y, sr)
            
            # Calculate gain adjustment
            gain_db = target_lufs - current_lufs
            gain_linear = 10 ** (gain_db / 20)
            
            # Apply gain
            y_normalized = y * gain_linear
            
            # Prevent clipping
            if np.max(np.abs(y_normalized)) > 0.99:
                y_normalized = y_normalized * 0.99 / np.max(np.abs(y_normalized))
                
            return y_normalized
        except Exception as e:
            print(f"Warning: LUFS normalization failed: {e}")
            # Fallback to simple normalization
            return librosa.util.normalize(y) * 0.95
    
    def _calculate_lufs(self, y: np.ndarray, sr: int) -> float:
        """Calculate LUFS (Loudness Units Full Scale)"""
        try:
            # Simplified LUFS calculation
            rms = np.sqrt(np.mean(y ** 2))
            lufs = 20 * np.log10(rms) + 0.691
            return lufs
        except Exception as e:
            print(f"Warning: LUFS calculation failed: {e}")
            # Fallback to simple RMS
            return 20 * np.log10(np.sqrt(np.mean(y ** 2)))
            
class FeatureExtractor:
    def __init__(self, 
                 n_mels: int = 80,
                 n_fft: int = 1024,
                 hop_length: int = 256,
                 win_length: int = 1024,
                 fmin: float = 80.0,
                 fmax: float = 7600.0,
                 n_mfcc: int = 13):
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.fmin = fmin
        self.fmax = fmax
        self.n_mfcc = n_mfcc
        
    def extract_features(self, y: np.ndarray, sr: int) -> dict:
        """Extract all required features"""
        features = {}
        
        # 1. Basic features
        features['mel'] = self._extract_mel(y, sr)
        features['prosody'] = self._extract_prosody(y, sr)
        
        # 2. Spectral features
        features['spectral'] = self._extract_spectral_features(y, sr)
        
        # 3. Formant features
        features['formants'] = self._extract_formants(y, sr)
        
        # 4. MFCC features
        features['mfcc'] = self._extract_mfcc(y, sr)
        
        # 5. Voice quality features
        features['voice_quality'] = self._extract_voice_quality(y, sr)
        
        return features
    
    def _extract_mel(self, y: np.ndarray, sr: int) -> np.ndarray:
        """Extract mel spectrogram with enhanced parameters"""
        mel = librosa.feature.melspectrogram(
            y=y,
            sr=sr,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            fmin=self.fmin,
            fmax=self.fmax,
            power=1.0
        )
        
        # Convert to log scale
        mel = librosa.power_to_db(mel, ref=1.0)
        
        # Normalize
        mel = (mel - mel.mean()) / (mel.std() + 1e-8)
        
        return mel
    
    def _extract_prosody(self, y: np.ndarray, sr: int) -> np.ndarray:
        """Extract prosody features using librosa.pyin"""
        # Extract F0 with proper voicing detection
        f0, voiced_flag, _ = librosa.pyin(
            y, 
            fmin=80, 
            fmax=400, 
            sr=sr,
            frame_length=self.n_fft, 
            hop_length=self.hop_length
        )
        f0[~voiced_flag] = 0  # Set unvoiced to 0
        
        # Energy with same params as mel
        energy = librosa.feature.rms(
            y=y, 
            frame_length=self.n_fft, 
            hop_length=self.hop_length
        ).squeeze()
        
        # Duration
        duration = np.linspace(0, 1, len(f0))
        
        # Normalize
        f0 = (f0 - np.mean(f0)) / (np.std(f0) + 1e-6)
        energy = (energy - np.mean(energy)) / (np.std(energy) + 1e-6)
        
        return np.stack([f0, energy, duration], axis=1)
    
    def _extract_spectral_features(self, y: np.ndarray, sr: int) -> dict:
        """Extract spectral features"""
        # Spectral Centroid
        centroid = librosa.feature.spectral_centroid(
            y=y, sr=sr, n_fft=self.n_fft, hop_length=self.hop_length
        )[0]
        
        # Spectral Bandwidth
        bandwidth = librosa.feature.spectral_bandwidth(
            y=y, sr=sr, n_fft=self.n_fft, hop_length=self.hop_length
        )[0]
        
        # Spectral Contrast
        contrast = librosa.feature.spectral_contrast(
            y=y, sr=sr, n_fft=self.n_fft, hop_length=self.hop_length
        )
        
        # Spectral Rolloff
        rolloff = librosa.feature.spectral_rolloff(
            y=y, sr=sr, n_fft=self.n_fft, hop_length=self.hop_length
        )[0]
        
        # Normalize features
        centroid = (centroid - np.mean(centroid)) / (np.std(centroid) + 1e-6)
        bandwidth = (bandwidth - np.mean(bandwidth)) / (np.std(bandwidth) + 1e-6)
        contrast = (contrast - np.mean(contrast)) / (np.std(contrast) + 1e-6)
        rolloff = (rolloff - np.mean(rolloff)) / (np.std(rolloff) + 1e-6)
        
        return {
            'centroid': centroid,    # Brightness of sound
            'bandwidth': bandwidth,  # Spread of frequencies
            'contrast': contrast,    # Peak vs. valley differences
            'rolloff': rolloff      # Frequency distribution
        }
    
    def _extract_formants(self, y: np.ndarray, sr: int) -> np.ndarray:
        """Extract formant frequencies (F1, F2, F3)"""
        try:
            # Use Praat through parselmouth
            snd = parselmouth.Sound(y, sr)
            formants = snd.to_formant_burg()
            
            # Get time points for formant extraction
            # Use the same hop_length as other features for consistency
            time_points = np.arange(0, len(y)/sr, self.hop_length/sr)
            
            # Initialize arrays for formants
            f1 = np.zeros(len(time_points))
            f2 = np.zeros(len(time_points))
            f3 = np.zeros(len(time_points))
            
            # Extract formants at each time point
            for i, t in enumerate(time_points):
                try:
                    f1[i] = formants.get_value_at_time(1, t)
                    f2[i] = formants.get_value_at_time(2, t)
                    f3[i] = formants.get_value_at_time(3, t)
                except:
                    # If formant extraction fails at this point, use previous values
                    if i > 0:
                        f1[i] = f1[i-1]
                        f2[i] = f2[i-1]
                        f3[i] = f3[i-1]
                    else:
                        f1[i] = 0
                        f2[i] = 0
                        f3[i] = 0
            
            # Handle NaN values
            f1 = np.nan_to_num(f1, nan=0.0)
            f2 = np.nan_to_num(f2, nan=0.0)
            f3 = np.nan_to_num(f3, nan=0.0)
            
            # Normalize
            f1 = (f1 - np.mean(f1)) / (np.std(f1) + 1e-6)
            f2 = (f2 - np.mean(f2)) / (np.std(f2) + 1e-6)
            f3 = (f3 - np.mean(f3)) / (np.std(f3) + 1e-6)
            
            return np.stack([f1, f2, f3], axis=1)
        except Exception as e:
            print(f"Error extracting formants: {e}")
            # Return zero array if formant extraction fails
            return np.zeros((len(y)//self.hop_length, 3))
    
    def _extract_mfcc(self, y: np.ndarray, sr: int) -> np.ndarray:
        """Extract MFCC features"""
        # Extract base MFCC
        mfcc = librosa.feature.mfcc(
            y=y, 
            sr=sr,
            n_mfcc=self.n_mfcc,
            n_fft=self.n_fft,
            hop_length=self.hop_length
        )
        
        # Add delta and delta-delta features
        delta = librosa.feature.delta(mfcc)
        delta2 = librosa.feature.delta(mfcc, order=2)
        
        # Normalize
        mfcc = (mfcc - np.mean(mfcc)) / (np.std(mfcc) + 1e-6)
        delta = (delta - np.mean(delta)) / (np.std(delta) + 1e-6)
        delta2 = (delta2 - np.mean(delta2)) / (np.std(delta2) + 1e-6)
        
        return np.concatenate([mfcc, delta, delta2], axis=0)
    
    def _extract_voice_quality(self, y: np.ndarray, sr: int) -> dict:
        """Extract voice quality features"""
        # Jitter (pitch period variation)
        jitter = librosa.feature.zero_crossing_rate(y)
        
        # Shimmer (amplitude variation)
        shimmer = np.diff(np.abs(y))
        # Pad shimmer to match other features
        shimmer = np.pad(shimmer, (0, 1), mode='edge')
        
        # Harmonics-to-Noise Ratio (HNR)
        try:
            hnr = librosa.effects.harmonic(y)
            hnr = np.abs(hnr)  # Take magnitude
        except:
            hnr = np.zeros_like(y)
        
        # Resample to match other features
        target_length = len(y) // self.hop_length
        jitter = librosa.resample(jitter, orig_sr=sr/self.hop_length, target_sr=target_length)
        shimmer = librosa.resample(shimmer, orig_sr=sr, target_sr=target_length)
        hnr = librosa.resample(hnr, orig_sr=sr, target_sr=target_length)
        
        # Normalize
        jitter = (jitter - np.mean(jitter)) / (np.std(jitter) + 1e-6)
        shimmer = (shimmer - np.mean(shimmer)) / (np.std(shimmer) + 1e-6)
        hnr = (hnr - np.mean(hnr)) / (np.std(hnr) + 1e-6)
        
        return {
            'jitter': jitter,
            'shimmer': shimmer,
            'hnr': hnr
        }

class FeatureProcessor:
    def __init__(self, 
                 input_root: str,
                 output_root: str,
                 target_sr: int = 16000):
        self.input_root = input_root
        self.output_root = output_root
        self.preprocessor = AudioPreprocessor(target_sr)
        self.feature_extractor = FeatureExtractor()
        
        # Define accents and speakers
        self.accents = {
            "hindi": ["ASI", "RRBI", "SVBI", "TNI"],
            "spanish": ["EBVS", "ERMS", "MBMPS", "NJS"]
        }
        
        # Create output directory
        os.makedirs(output_root, exist_ok=True)
        
    def process_dataset(self):
        """Process all audio files in the dataset"""
        # Create metadata dictionary
        metadata = {
            'accents': self.accents,
            'processed_files': []
        }
        
        # Process each accent
        for accent, speakers in self.accents.items():
            print(f"\nProcessing {accent} accent...")
            
            for speaker in speakers:
                print(f"\nProcessing speaker: {speaker}")
                
                # Create speaker directory
                speaker_dir = os.path.join(self.output_root, speaker)
                os.makedirs(speaker_dir, exist_ok=True)
                
                # Get all wav files for this speaker
                wav_path = os.path.join(self.input_root, speaker, speaker, "wav")
                if not os.path.exists(wav_path):
                    print(f"Warning: No wav directory found for speaker {speaker}")
                    continue
                    
                wav_files = glob.glob(os.path.join(wav_path, "*.wav"))
                
                # Process each wav file
                for wav_file in tqdm(wav_files, desc=f"Processing {speaker}"):
                    try:
                        # Get utterance ID
                        utt_id = os.path.basename(wav_file).replace(".wav", "")
                        
                        # Preprocess audio
                        y, sr = self.preprocessor.preprocess(wav_file)
                        
                        # Extract features
                        features = self.feature_extractor.extract_features(y, sr)
                        
                        # Save features
                        self._save_features(speaker_dir, utt_id, features)
                        
                        # Update metadata
                        metadata['processed_files'].append({
                            'speaker': speaker,
                            'accent': accent,
                            'utt_id': utt_id,
                            'duration': len(y) / sr
                        })
                        
                    except Exception as e:
                        print(f"Error processing {wav_file}: {e}")
                        continue
        
        # Save metadata
        with open(os.path.join(self.output_root, 'metadata.json'), 'w') as f:
            json.dump(metadata, f, indent=2)
            
        print("\nFeature extraction completed!")
        print(f"Processed files: {len(metadata['processed_files'])}")
        
    def _save_features(self, speaker_dir: str, utt_id: str, features: dict):
        """Save extracted features"""
        # Save each feature type
        for feature_name, feature_data in features.items():
            if isinstance(feature_data, dict):
                # Handle nested features (like spectral features)
                for sub_feature_name, sub_feature_data in feature_data.items():
                    np.save(
                        os.path.join(speaker_dir, f"{utt_id}_{feature_name}_{sub_feature_name}.npy"),
                        sub_feature_data
                    )
            else:
                # Handle direct features
                np.save(
                    os.path.join(speaker_dir, f"{utt_id}_{feature_name}.npy"),
                    feature_data
                )



In [3]:
# Usage
if __name__ == "__main__":
    # Set paths
    L2_ARCTIC_ROOT = "/kaggle/input/l2-arctic-data"
    PROCESSED_DIR = "/kaggle/working/processed_features"
    
    # Create processor
    processor = FeatureProcessor(
        input_root=L2_ARCTIC_ROOT,
        output_root=PROCESSED_DIR
    )
    
    # Process dataset
    processor.process_dataset()


Processing hindi accent...

Processing speaker: ASI


Processing ASI: 100%|██████████| 1131/1131 [09:23<00:00,  2.01it/s]



Processing speaker: RRBI


Processing RRBI: 100%|██████████| 1130/1130 [10:38<00:00,  1.77it/s]



Processing speaker: SVBI


Processing SVBI: 100%|██████████| 1132/1132 [09:08<00:00,  2.06it/s]



Processing speaker: TNI


Processing TNI: 100%|██████████| 1131/1131 [09:39<00:00,  1.95it/s]



Processing spanish accent...

Processing speaker: EBVS


Processing EBVS: 100%|██████████| 1007/1007 [10:47<00:00,  1.56it/s]



Processing speaker: ERMS


Processing ERMS: 100%|██████████| 1132/1132 [12:12<00:00,  1.54it/s]



Processing speaker: MBMPS


Processing MBMPS: 100%|██████████| 1132/1132 [12:40<00:00,  1.49it/s]



Processing speaker: NJS


Processing NJS: 100%|██████████| 1131/1131 [09:18<00:00,  2.02it/s]


Feature extraction completed!
Processed files: 8926





In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Dict, List, Tuple
import os
import json
from tqdm import tqdm
import librosa

class AccentEmbeddingDataset(Dataset):
    def __init__(self, feature_root, accents, target_size=(128, 128)):
        """
        Args:
            feature_root (str): Root directory containing the processed features
            accents (dict): Dictionary mapping accent groups to their codes
            target_size (tuple): Target size for mel spectrograms (time_steps, n_mels)
        """
        self.feature_root = feature_root
        self.accents = accents
        self.target_size = target_size
        
        # Get all feature files and their labels
        self.feature_files = []
        self.labels = []
        
        # Walk through the directory structure
        for accent_group, accent_codes in accents.items():
            for accent_code in accent_codes:
                accent_dir = os.path.join(feature_root, accent_code)
                if os.path.exists(accent_dir):
                    for file in os.listdir(accent_dir):
                        if file.endswith('.npy'):
                            self.feature_files.append(os.path.join(accent_dir, file))
                            self.labels.append(accent_group)
        
        # Create label to index mapping
        self.label_to_idx = {label: idx for idx, label in enumerate(sorted(set(self.labels)))}
        self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}
        
        # Validate at least one file exists
        if len(self.feature_files) == 0:
            raise ValueError(f"No .npy files found in the specified directories: {feature_root}")
            
        # Check the format of the first feature file to determine how to load data
        sample_file = self.feature_files[0]
        sample_features = np.load(sample_file, allow_pickle=True)
        
        if isinstance(sample_features, np.ndarray) and sample_features.dtype == np.dtype('object'):
            # Array of arrays
            self.is_array_of_arrays = True
        else:
            # Single array, need to reshape
            self.is_array_of_arrays = False
            # Determine if it's a structured array with named fields
            self.has_structured_fields = hasattr(sample_features, 'dtype') and sample_features.dtype.names is not None
    
    def __len__(self):
        return len(self.feature_files)
    
    def pad_or_truncate(self, tensor, target_size):
        """Pad or truncate tensor to target size"""
        # Convert to tensor if not already
        if not isinstance(tensor, torch.Tensor):
            tensor = torch.tensor(tensor, dtype=torch.float32)
            
        # Ensure tensor is 2D
        if len(tensor.shape) == 1:
            tensor = tensor.unsqueeze(0)  # Add channel dimension
        
        current_size = tensor.shape
        
        # Handle time dimension (first dimension)
        if current_size[0] < target_size[0]:
            # Pad time dimension
            pad_size = target_size[0] - current_size[0]
            tensor = F.pad(tensor, (0, 0, 0, pad_size))
        elif current_size[0] > target_size[0]:
            # Truncate time dimension
            tensor = tensor[:target_size[0], :]
            
        # Handle frequency dimension (second dimension)
        if current_size[1] < target_size[1]:
            # Pad frequency dimension
            pad_size = target_size[1] - current_size[1]
            tensor = F.pad(tensor, (0, pad_size))
        elif current_size[1] > target_size[1]:
            # Truncate frequency dimension
            tensor = tensor[:, :target_size[1]]
            
        return tensor
    
    def __getitem__(self, idx):
        feature_path = self.feature_files[idx]
        label = self.labels[idx]
        
        # Load features
        features = np.load(feature_path, allow_pickle=True)
        
        # Handle different feature formats
        if self.is_array_of_arrays:
            # Features are an array of arrays
            mel_spec = torch.tensor(features[0], dtype=torch.float32)
            prosody = torch.tensor(features[1], dtype=torch.float32)
            mfcc = torch.tensor(features[2], dtype=torch.float32)
            chroma = torch.tensor(features[3], dtype=torch.float32)
            spectral_contrast = torch.tensor(features[4], dtype=torch.float32)
            tonnetz = torch.tensor(features[5], dtype=torch.float32)
        elif self.has_structured_fields:
            # Features are a structured array with named fields
            mel_spec = torch.tensor(features['mel_spec'], dtype=torch.float32)
            prosody = torch.tensor(features['prosody'], dtype=torch.float32)
            mfcc = torch.tensor(features['mfcc'], dtype=torch.float32)
            chroma = torch.tensor(features['chroma'], dtype=torch.float32)
            spectral_contrast = torch.tensor(features['spectral_contrast'], dtype=torch.float32)
            tonnetz = torch.tensor(features['tonnetz'], dtype=torch.float32)
        else:
            # Assume features is a single array that needs to be split
            # This is a simplification - adjust based on your actual data format
            feature_size = features.shape[0] // 6
            mel_spec = torch.tensor(features[:feature_size], dtype=torch.float32)
            prosody = torch.tensor(features[feature_size:2*feature_size], dtype=torch.float32)
            mfcc = torch.tensor(features[2*feature_size:3*feature_size], dtype=torch.float32)
            chroma = torch.tensor(features[3*feature_size:4*feature_size], dtype=torch.float32)
            spectral_contrast = torch.tensor(features[4*feature_size:5*feature_size], dtype=torch.float32)
            tonnetz = torch.tensor(features[5*feature_size:], dtype=torch.float32)
        
        # Process mel spectrogram
        mel_spec = self.pad_or_truncate(mel_spec, self.target_size)
        
        # Process prosody features
        prosody = self.pad_or_truncate(prosody, self.target_size)
        
        # Process spectral features
        spectral = {
            'mfcc': self.pad_or_truncate(mfcc, self.target_size),
            'chroma': self.pad_or_truncate(chroma, self.target_size),
            'spectral_contrast': self.pad_or_truncate(spectral_contrast, self.target_size),
            'tonnetz': self.pad_or_truncate(tonnetz, self.target_size)
        }
        
        # Convert label to index
        label_idx = self.label_to_idx[label]
        
        return {
            'mel': mel_spec,
            'prosody': prosody,
            'spectral': spectral,
            'label': label_idx,
            'accent': label,
            'feature_path': feature_path
        }

def collate_fn(batch):
    """
    Custom collate function to handle the batch creation
    """
    # Stack mel spectrograms
    mel = torch.stack([item['mel'] for item in batch])
    
    # Stack prosody features
    prosody = torch.stack([item['prosody'] for item in batch])
    
    # Stack spectral features
    spectral = {
        'mfcc': torch.stack([item['spectral']['mfcc'] for item in batch]),
        'chroma': torch.stack([item['spectral']['chroma'] for item in batch]),
        'spectral_contrast': torch.stack([item['spectral']['spectral_contrast'] for item in batch]),
        'tonnetz': torch.stack([item['spectral']['tonnetz'] for item in batch])
    }
    
    # Stack labels (numeric indices)
    labels = torch.tensor([item['label'] for item in batch])
    
    # Get other information
    accents = [item['accent'] for item in batch]
    feature_paths = [item['feature_path'] for item in batch]
    
    return {
        'mel': mel,
        'prosody': prosody,
        'spectral': spectral,
        'label': labels,
        'accent': accents,
        'feature_path': feature_paths
    }

class EnhancedAccentEncoder(nn.Module):
    def __init__(self, 
                 mel_dim: int = 128,
                 prosody_dim: int = 128,
                 spectral_dim: int = 128,
                 hidden_dim: int = 256,
                 emb_dim: int = 64,
                 num_heads: int = 8,
                 num_layers: int = 3,
                 dropout: float = 0.1):
        super().__init__()
        
        # Feature-specific encoders with residual connections
        self.mel_encoder = ResidualBlock(mel_dim, hidden_dim)
        self.prosody_encoder = ResidualBlock(prosody_dim, hidden_dim)
        
        # Combined spectral features encoder (mfcc + chroma + spectral_contrast + tonnetz)
        # We'll use a separate encoder for each feature type and then combine them
        self.mfcc_encoder = nn.Conv1d(mel_dim, hidden_dim // 4, kernel_size=3, padding=1)
        self.chroma_encoder = nn.Conv1d(mel_dim, hidden_dim // 4, kernel_size=3, padding=1)
        self.spectral_contrast_encoder = nn.Conv1d(mel_dim, hidden_dim // 4, kernel_size=3, padding=1)
        self.tonnetz_encoder = nn.Conv1d(mel_dim, hidden_dim // 4, kernel_size=3, padding=1)
        
        # Feature combination and residual block for spectral features
        self.spectral_combiner = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)
        self.spectral_residual = ResidualBlock(hidden_dim, hidden_dim)
        
        # Learnable feature weights
        self.feature_weights = nn.Parameter(torch.ones(3))
        
        # Transformer encoder for temporal modeling
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )
        
        # Positional encoding
        self.pos_encoder = PositionalEncoding(hidden_dim)
        
        # Feature fusion with attention
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim * 2),
            nn.LayerNorm(hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, emb_dim)
        )
        
        # Reconstruction heads for multi-task learning
        self.mel_head = nn.Linear(emb_dim, mel_dim)
        self.prosody_head = nn.Linear(emb_dim, prosody_dim)
        self.spectral_head = nn.Linear(emb_dim, spectral_dim)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    
    def forward(self, mel, prosody, spectral):
        # Encode mel and prosody features with residual connections
        mel_feat = self.mel_encoder(mel)  # [B, hidden_dim, T]
        prosody_feat = self.prosody_encoder(prosody)  # [B, hidden_dim, T]
        
        # Process spectral features
        mfcc_feat = F.relu(self.mfcc_encoder(spectral['mfcc']))
        chroma_feat = F.relu(self.chroma_encoder(spectral['chroma']))
        contrast_feat = F.relu(self.spectral_contrast_encoder(spectral['spectral_contrast']))
        tonnetz_feat = F.relu(self.tonnetz_encoder(spectral['tonnetz']))
        
        # Combine spectral features
        spectral_combined = torch.cat([mfcc_feat, chroma_feat, contrast_feat, tonnetz_feat], dim=1)
        spectral_feat = self.spectral_combiner(spectral_combined)
        spectral_feat = self.spectral_residual(spectral_feat)
        
        # Apply learned feature weights
        mel_feat = mel_feat * self.feature_weights[0]
        prosody_feat = prosody_feat * self.feature_weights[1]
        spectral_feat = spectral_feat * self.feature_weights[2]
        
        # Prepare for transformer
        mel_feat = mel_feat.transpose(1, 2)  # [B, T, hidden_dim]
        prosody_feat = prosody_feat.transpose(1, 2)
        spectral_feat = spectral_feat.transpose(1, 2)
        
        # Add positional encoding
        mel_feat = self.pos_encoder(mel_feat)
        prosody_feat = self.pos_encoder(prosody_feat)
        spectral_feat = self.pos_encoder(spectral_feat)
        
        # Apply transformer
        mel_feat = self.transformer(mel_feat)
        prosody_feat = self.transformer(prosody_feat)
        spectral_feat = self.transformer(spectral_feat)
        
        # Global average pooling
        mel_feat = torch.mean(mel_feat, dim=1)
        prosody_feat = torch.mean(prosody_feat, dim=1)
        spectral_feat = torch.mean(spectral_feat, dim=1)
        
        # Combine features
        combined = torch.cat([mel_feat, prosody_feat, spectral_feat], dim=1)
        
        # Generate embedding
        embedding = self.fusion(combined)
        embedding = F.normalize(embedding, p=2, dim=1)
        
        # Reconstruction
        mel_recon = self.mel_head(embedding)
        prosody_recon = self.prosody_head(embedding)
        spectral_recon = self.spectral_head(embedding)
        
        return embedding, mel_recon, prosody_recon, spectral_recon

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, 1),
                nn.BatchNorm1d(out_channels)
            )
    
    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return F.relu(out)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
    
    def forward(self, embeddings, labels):
        """
        Compute contrastive loss from embeddings and labels
        
        Args:
            embeddings: Tensor of shape [batch_size, embedding_dimension]
            labels: Tensor of numeric indices for each embedding
            
        Returns:
            Contrastive loss value
        """
        batch_size = embeddings.size(0)
        
        # Normalize embeddings
        embeddings = F.normalize(embeddings, p=2, dim=1)
        
        # Compute similarity matrix
        similarity_matrix = torch.matmul(embeddings, embeddings.T) / self.temperature
        
        # Create mask for positive pairs - same accent = positive pair
        # Make sure labels is a tensor of integers
        if not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels, device=embeddings.device)
            
        mask = labels.unsqueeze(0) == labels.unsqueeze(1)
        
        # Remove self-similarity from positive pair mask
        mask.fill_diagonal_(False)
        
        # Create negative mask - different accent = negative pair
        negative_mask = ~mask
        
        # Compute log_prob for positive pairs
        exp_sim = torch.exp(similarity_matrix)
        
        # Sum for normalization constant, excluding self-similarity
        log_prob = similarity_matrix - torch.log(
            (exp_sim * negative_mask.float()).sum(dim=1, keepdim=True) + 1e-8
        )
        
        # Compute mean of positive pairs log_prob
        mean_log_prob = (mask.float() * log_prob).sum(dim=1) / (mask.float().sum(dim=1) + 1e-8)
        
        # Loss is negative mean of log probabilities
        loss = -mean_log_prob.mean()
        
        return loss

def train_accent_encoder(model, train_loader, val_loader, num_epochs, device):
    """Train the accent encoder model"""
    
    # Initialize optimizers and schedulers
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-3, epochs=num_epochs, steps_per_epoch=len(train_loader)
    )
    
    # Loss functions
    triplet_loss_fn = nn.TripletMarginLoss(margin=1.0)
    reconstruction_loss_fn = nn.MSELoss()
    contrastive_loss_fn = ContrastiveLoss(temperature=0.07)
    
    # Training loop
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        # Training
        model.train()
        total_loss = 0
        train_triplet = 0
        train_recon = 0
        train_contrast = 0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            # Move data to device
            mel = batch['mel'].to(device)
            prosody = batch['prosody'].to(device)
            
            # Handle spectral features
            spectral = {}
            for key, value in batch['spectral'].items():
                spectral[key] = value.to(device)
            
            # Get numeric labels - already tensor from DataLoader
            labels = batch['label'].to(device)
            
            # Forward pass
            embedding, mel_recon, prosody_recon, spectral_recon = model(mel, prosody, spectral)
            
            # Calculate triplet loss
            triplet = torch.tensor(0.0, device=device)
            if embedding.size(0) >= 3:  # Need at least 3 samples for triplet loss
                # Find positive and negative samples for each anchor based on numeric labels
                for i in range(len(embedding)):
                    # Find positive examples (same label)
                    positive_indices = [j for j in range(len(labels)) 
                                      if j != i and labels[j].item() == labels[i].item()]
                    # Find negative examples (different label)
                    negative_indices = [j for j in range(len(labels)) 
                                      if labels[j].item() != labels[i].item()]
                    
                    if positive_indices and negative_indices:  # Make sure we have valid pairs
                        pos_idx = np.random.choice(positive_indices)
                        neg_idx = np.random.choice(negative_indices)
                        
                        anchor = embedding[i].unsqueeze(0)
                        positive = embedding[pos_idx].unsqueeze(0)
                        negative = embedding[neg_idx].unsqueeze(0)
                        
                        triplet += triplet_loss_fn(anchor, positive, negative)
                
                # Average triplet loss
                if len(embedding) > 0:
                    triplet = triplet / len(embedding)
            
            # Reconstruction loss
            recon = reconstruction_loss_fn(mel_recon, mel.mean(dim=2)) + \
                   reconstruction_loss_fn(prosody_recon, prosody.mean(dim=2)) + \
                   reconstruction_loss_fn(spectral_recon, spectral['mfcc'].mean(dim=2))  # Using MFCC as representative
            
            # Contrastive loss - use numeric labels
            contrast = contrastive_loss_fn(embedding, labels)
            
            # Combined loss - weighted sum
            loss = triplet + 0.1 * recon + 0.1 * contrast
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            
            # Update metrics
            total_loss += loss.item()
            train_triplet += triplet.item()
            train_recon += recon.item()
            train_contrast += contrast.item()
        
        # Validation
        model.eval()
        val_loss = 0
        val_triplet = 0
        val_recon = 0
        val_contrast = 0
        
        with torch.no_grad():
            for batch in val_loader:
                # Move data to device
                mel = batch['mel'].to(device)
                prosody = batch['prosody'].to(device)
                
                # Handle spectral features
                spectral = {}
                for key, value in batch['spectral'].items():
                    spectral[key] = value.to(device)
                
                # Get numeric labels
                labels = batch['label'].to(device)
                
                # Forward pass
                embedding, mel_recon, prosody_recon, spectral_recon = model(mel, prosody, spectral)
                
                # Calculate same losses as in training
                triplet = torch.tensor(0.0, device=device)
                if embedding.size(0) >= 3:
                    for i in range(len(embedding)):
                        positive_indices = [j for j in range(len(labels)) 
                                          if j != i and labels[j].item() == labels[i].item()]
                        negative_indices = [j for j in range(len(labels)) 
                                          if labels[j].item() != labels[i].item()]
                        
                        if positive_indices and negative_indices:
                            pos_idx = np.random.choice(positive_indices)
                            neg_idx = np.random.choice(negative_indices)
                            
                            anchor = embedding[i].unsqueeze(0)
                            positive = embedding[pos_idx].unsqueeze(0)
                            negative = embedding[neg_idx].unsqueeze(0)
                            
                            triplet += triplet_loss_fn(anchor, positive, negative)
                    
                    if len(embedding) > 0:
                        triplet = triplet / len(embedding)
                
                recon = reconstruction_loss_fn(mel_recon, mel.mean(dim=2)) + \
                       reconstruction_loss_fn(prosody_recon, prosody.mean(dim=2)) + \
                       reconstruction_loss_fn(spectral_recon, spectral['mfcc'].mean(dim=2))
                
                contrast = contrastive_loss_fn(embedding, labels)
                
                loss = triplet + 0.1 * recon + 0.1 * contrast
                
                val_loss += loss.item()
                val_triplet += triplet.item()
                val_recon += recon.item()
                val_contrast += contrast.item()
        
        # Calculate average losses
        avg_train_loss = total_loss / len(train_loader)
        avg_train_triplet = train_triplet / len(train_loader)
        avg_train_recon = train_recon / len(train_loader)
        avg_train_contrast = train_contrast / len(train_loader)
        
        avg_val_loss = val_loss / len(val_loader) if len(val_loader) > 0 else 0
        avg_val_triplet = val_triplet / len(val_loader) if len(val_loader) > 0 else 0
        avg_val_recon = val_recon / len(val_loader) if len(val_loader) > 0 else 0
        avg_val_contrast = val_contrast / len(val_loader) if len(val_loader) > 0 else 0
        
        # Print metrics
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {avg_train_loss:.4f}")
        print(f"Train Triplet: {avg_train_triplet:.4f}")
        print(f"Train Recon: {avg_train_recon:.4f}")
        print(f"Train Contrast: {avg_train_contrast:.4f}")
        print(f"Val Loss: {avg_val_loss:.4f}")
        print(f"Val Triplet: {avg_val_triplet:.4f}")
        print(f"Val Recon: {avg_val_recon:.4f}")
        print(f"Val Contrast: {avg_val_contrast:.4f}")
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': avg_val_loss,
            }, 'best_accent_encoder.pt')
        
        # Save checkpoint every few epochs
        if (epoch + 1) % 5 == 0 or epoch == num_epochs - 1:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': avg_val_loss,
            }, f'accent_encoder_epoch_{epoch+1}.pt')

def extract_embeddings(model, dataloader, device):
    """Extract embeddings for all samples in the dataset"""
    model.eval()
    embeddings = []
    labels = []
    accents = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting embeddings"):
            # Move data to device
            mel = batch['mel'].to(device)
            prosody = batch['prosody'].to(device)
            
            # Handle spectral features
            spectral = {}
            for key, value in batch['spectral'].items():
                spectral[key] = value.to(device)
            
            # Get labels and accents
            batch_labels = batch['label']
            batch_accents = batch['accent']
            
            # Forward pass - only need embeddings
            embedding, _, _, _ = model(mel, prosody, spectral)
            
            # Store embeddings and labels
            embeddings.append(embedding.cpu())
            labels.extend(batch_labels.tolist())
            accents.extend(batch_accents)
    
    # Concatenate embeddings
    embeddings = torch.cat(embeddings, dim=0)
    
    return embeddings, labels, accents

if __name__ == "__main__":
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    feature_root = "/kaggle/working/extracted_data/kaggle/working/processed_features"
    
    accents = {
        "hindi": ["ASI", "RRBI", "SVBI", "TNI"],
        "spanish": ["EBVS", "ERMS", "MBMPS", "NJS"]
    }
    
    try:
        dataset = AccentEmbeddingDataset(
            feature_root=feature_root,
            accents=accents,
            target_size=(128, 128)  
        )
        print(f"Dataset created successfully with {len(dataset)} samples")
        
        first_sample = dataset[0]
        print(f"First sample mel shape: {first_sample['mel'].shape}")
        print(f"First sample prosody shape: {first_sample['prosody'].shape}")
        print(f"First sample accent: {first_sample['accent']}")
        print(f"First sample label index: {first_sample['label']}")
        
        print(f"Label to index mapping: {dataset.label_to_idx}")
        
        train_size = int(0.9 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
        print(f"Split dataset: {train_size} training samples, {val_size} validation samples")
        
        train_loader = DataLoader(
            train_dataset,
            batch_size=32,
            shuffle=True,
            num_workers=2,  
            pin_memory=True,
            collate_fn=collate_fn
        )
        
        val_loader = DataLoader(
            val_dataset,
            batch_size=32,
            shuffle=False,
            num_workers=2, 
            pin_memory=True,
            collate_fn=collate_fn
        )
        
        model = EnhancedAccentEncoder().to(device)
        print("Model created successfully")
        
        print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
        
        print("Starting training...")
        train_accent_encoder(model, train_loader, val_loader, num_epochs=50, device=device)
        
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

Using device: cuda
Dataset created successfully with 98186 samples
First sample mel shape: torch.Size([128, 128])
First sample prosody shape: torch.Size([128, 128])
First sample accent: hindi
First sample label index: 0
Label to index mapping: {'hindi': 0, 'spanish': 1}
Split dataset: 88367 training samples, 9819 validation samples
Model created successfully
Model parameters: 4040899
Starting training...


Epoch 1/50: 100%|██████████| 2762/2762 [08:29<00:00,  5.42it/s]



Epoch 1/50
Train Loss: 1.2823
Train Triplet: 0.9998
Train Recon: 0.0145
Train Contrast: 2.8098
Val Loss: 1.2802
Val Triplet: 1.0002
Val Recon: 0.0118
Val Contrast: 2.7876


Epoch 2/50:  22%|██▏       | 594/2762 [01:49<06:39,  5.42it/s]