# Convert audio signals to pre-trained CNN embeddings

Audio signals are converted to embeddings from pre-trained CNNs, which are then used as input features for classification. The pre-trained convolutional networks extract high-level features from the audio data.

Three widely used pre-trained CNN models for audio event classification are used:
- VGGish: https://github.com/tensorflow/models/tree/master/research/audioset/vggish
- YAMNet: https://github.com/tensorflow/models/tree/master/research/audioset/yamnet
- PANNs: https://github.com/qiuqiangkong/audioset_tagging_cnn

### **Code Functionality Breakdown**

### **1. Purpose of the Code**
This code is a pipeline for **extracting audio embeddings** using pre-trained models (YAMNet, VGGish, and PANNs) and calculating **acoustic features** for a **speech emotion recognition project**. It processes `.pkl` files containing preprocessed audio data, extracts embeddings, and saves them for use in downstream machine learning models.

---

### **2. Key Components**

#### **`AudioEmbeddingExtractor` Class**
This class encapsulates the entire workflow for:
- Extracting embeddings from pre-trained models.
- Calculating acoustic features like MFCCs and spectral indices.
- Saving the extracted features.

##### **Attributes**
1. **`csv_path`**: Path to the metadata CSV; its `pkl_path` column lists the `.pkl` files to process.
2. **`output_dir`**: Directory where extracted embeddings will be saved (created if it does not exist).
3. **`sample_rate`**: The sampling rate of the audio (default: 16 kHz).
4. **`target_length`**: Number of samples to which all audio signals are padded or truncated (default: 160,000, i.e. 10 s at 16 kHz).
5. **`batch_size`**: Batch size for PANNs processing (default: 100).
6. **`use_gpu`**: Flag to enable GPU for PANNs inference; stored internally as `device` (`'cuda'` or `'cpu'`).
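
The relationships between these attributes can be made explicit with a small configuration object. The sketch below is illustrative only: `ExtractorConfig`, `duration_seconds`, and `length_at` are hypothetical names that do not appear in the pipeline, and the defaults are copied from the constructor.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ExtractorConfig:
    """Hypothetical container mirroring the constructor arguments."""
    csv_path: str
    output_dir: str
    sample_rate: int = 16000      # Hz
    target_length: int = 160000   # samples
    batch_size: int = 100
    use_gpu: bool = False

    @property
    def duration_seconds(self) -> float:
        # Clip duration implied by target_length and sample_rate
        return self.target_length / self.sample_rate

    @property
    def device(self) -> str:
        # Same mapping the class applies to use_gpu
        return "cuda" if self.use_gpu else "cpu"

    def length_at(self, new_rate: int) -> int:
        # Number of samples after resampling the fixed-length clip
        return int(round(self.target_length * new_rate / self.sample_rate))


cfg = ExtractorConfig(csv_path="metadata.csv", output_dir="embeddings")
print(cfg.duration_seconds, cfg.device, cfg.length_at(32000))
```

With the defaults, every clip is 10 seconds long and doubles in sample count when resampled to 32 kHz.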

---

#### **2.1 Initialization and Pre-trained Model Loading**
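- **`_load_models`**: Loads the pre-trained models:
  1. **YAMNet**: Loaded from TensorFlow Hub; returns class scores, 1024-dimensional frame embeddings, and a log-mel spectrogram.
  2. **VGGish**: Loaded from TensorFlow Hub; returns 128-dimensional frame embeddings.
  3. **PANNs**: Loaded through `panns_inference.AudioTagging`; returns clip-level class probabilities and a clip-level embedding.

If a model fails to load, the error is logged and re-raised, which stops the pipeline.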

---

#### **2.2 Feature Calculation**
- **`calculate_acoustic_features`**: Computes one row of summary features per clip:
  - RMS energy, and the mean zero-crossing rate, spectral centroid, rolloff, and bandwidth.
  - MFCCs (mean and standard deviation of each of 13 coefficients).
  - Mean spectral contrast (first band only) and mean spectral flatness.

These features summarize each clip's energy, noisiness, and spectral shape in a single fixed-length row.
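
To make three of these quantities concrete, here is a minimal NumPy-only sketch. It is not the pipeline's implementation: the pipeline uses librosa's frame-based functions and averages over frames, whereas this version treats the whole signal as a single frame. The function name `simple_acoustic_features` is hypothetical.

```python
import numpy as np


def simple_acoustic_features(y: np.ndarray, sr: int) -> dict:
    """Whole-signal approximations of three of the features above."""
    y = np.asarray(y, dtype=np.float64)

    # Root-mean-square energy, same formula as in the pipeline
    rms = float(np.sqrt(np.mean(y ** 2)))

    # Fraction of adjacent sample pairs whose sign differs
    signs = np.signbit(y)
    zcr = float(np.mean(signs[1:] != signs[:-1]))

    # Spectral centroid: magnitude-weighted mean frequency of the spectrum
    magnitude = np.abs(np.fft.rfft(y))
    freqs = np.fft.rfftfreq(len(y), d=1.0 / sr)
    centroid = float(np.sum(freqs * magnitude) / (np.sum(magnitude) + 1e-12))

    return {"rms": rms, "zero_crossing_rate": zcr, "spectral_centroid": centroid}


# Usage: a one-second 440 Hz sine wave sampled at 16 kHz
sr = 16000
t = np.arange(sr) / sr
tone = np.sin(2 * np.pi * 440 * t)
features = simple_acoustic_features(tone, sr)
print(features)
```

For a pure tone, the centroid sits at the tone's frequency and the RMS is close to 0.707, which gives a quick sanity check before running the features on real speech.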

---

#### **2.3 Audio Preprocessing**
- **`load_and_preprocess_audio`**:
  1. Loads the `.pkl` file and reads the preprocessed waveform stored under the key `y`.
  2. Pads (with the signal mean) or truncates the audio to `target_length`.
  3. Calculates acoustic features if the file does not already provide them under `df_indices`.
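
The pad-or-truncate step can be sketched in isolation as follows. The helper name `pad_or_truncate` is hypothetical, and two details are assumptions added here rather than taken from the pipeline: the cast to `float32` and the fallback to 0.0 for an empty signal.

```python
import numpy as np


def pad_or_truncate(y: np.ndarray, target_length: int) -> np.ndarray:
    """Return a copy of y with exactly target_length samples."""
    y = np.asarray(y, dtype=np.float32)  # assumption: models are fed float32
    if len(y) < target_length:
        # Pad at the end with the signal mean, as the pipeline does;
        # falling back to 0.0 for an empty signal is an added assumption
        fill = float(y.mean()) if len(y) else 0.0
        return np.pad(y, (0, target_length - len(y)), "constant", constant_values=fill)
    return y[:target_length].copy()


short = pad_or_truncate(np.array([1.0, 3.0]), 4)
long = pad_or_truncate(np.arange(10), 4)
print(short, long)
```

Padding with the mean rather than zero avoids introducing a step at the end of a signal that has a DC offset.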

---

#### **2.4 Embedding Extraction**
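- **`extract_yamnet_vggish_features`**:
  - Runs YAMNet and VGGish on the waveform and averages their frame-level embeddings over time, giving one vector per clip.
  - Processes **MPS features** if available in the input data: keeps the columns where `wt <= 100` and flattens the result.

- **`process_panns_embeddings`**:
  - Resamples the audio from 16 kHz to 32 kHz (the rate PANNs expects).
  - Processes audio in batches of `batch_size` clips.
  - Collects the clip-level class probabilities (`clipwise_output`) and the clip-level embedding returned by `AudioTagging.inference`.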
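
Two array operations do most of the work in `extract_yamnet_vggish_features`: averaging frame embeddings over time, and selecting MPS columns with a boolean mask. The sketch below reproduces both on random stand-in data; the shapes and the values in `wt` are made up for illustration and are not taken from the dataset.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for a model output: 10 frames of 1024-dimensional embeddings
frame_embeddings = rng.normal(size=(10, 1024))

# Average over the frame axis to get one fixed-size vector per clip
clip_embedding = frame_embeddings.mean(axis=0)

# Stand-in MPS matrix (rows x columns) and a per-column axis vector wt
mps = rng.random((4, 6))
wt = np.array([0.0, 50.0, 100.0, 150.0, 200.0, 250.0])

# Keep only columns whose wt value is at most 100, then flatten
mps_vector = mps[:, wt <= 100].reshape(-1)

print(clip_embedding.shape, mps_vector.shape)
```

Averaging discards temporal order, which keeps the feature size independent of clip length at the cost of losing how the signal evolves over time.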

---

#### **2.5 Dataset Processing**
- **`process_dataset`**:
  1. Loads the dataset metadata from the provided CSV file.
  2. Iterates through the `.pkl` file paths (column `pkl_path`) to:
     - Preprocess audio.
     - Extract YAMNet and VGGish embeddings.
     - Collect MPS features and acoustic indices.
     - Log and skip any file that fails.
  3. Runs garbage collection every 100 files to manage memory.
  4. Stacks the collected waveforms and extracts the PANNs outputs in batches.
  5. Saves all features (embeddings and indices) to the output directory.

- **`_save_features`**:
  - Saves extracted features as `.npy` files (for embeddings) and `.csv` files (for acoustic indices).
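
A per-file loop that tolerates failures has to keep its result lists aligned, because the rows of every saved array are expected to refer to the same files. One way to structure such a loop is sketched below using only the standard library; `process_all`, `fake_load`, and `fake_embed` are hypothetical stand-ins for the real loading and embedding steps.

```python
import logging
from typing import Callable, Iterable, List, Tuple


def process_all(paths: Iterable[str],
                load: Callable[[str], list],
                embed: Callable[[list], list]) -> Tuple[List[list], List[list], List[str]]:
    """Run load and embed on every path, skipping paths that raise."""
    waveforms, embeddings, failed = [], [], []
    for path in paths:
        try:
            waveform = load(path)
            embedding = embed(waveform)
        except Exception as exc:
            logging.warning("Failed to process file %s: %s", path, exc)
            failed.append(path)
            continue
        # Append only after both steps succeeded so the lists stay aligned
        waveforms.append(waveform)
        embeddings.append(embedding)
    return waveforms, embeddings, failed


def fake_load(path: str) -> list:
    # Hypothetical loader: fails for one path to exercise the error branch
    if path == "bad.pkl":
        raise KeyError("Audio data 'y' not found in pickle file")
    return [0.0, 1.0]


def fake_embed(waveform: list) -> list:
    # Hypothetical embedding: the waveform mean as a one-element vector
    return [sum(waveform) / len(waveform)]


waves, embs, failed = process_all(["a.pkl", "bad.pkl", "b.pkl"], fake_load, fake_embed)
print(len(waves), len(embs), failed)
```

Returning the list of failed paths lets the caller report or retry them without rerunning the whole dataset.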

---

#### **2.6 Main Function**
The `main()` function defines the workflow:
1. Sets the paths for the dataset metadata and output directory.
2. Initializes the `AudioEmbeddingExtractor`.
3. Runs the `process_dataset()` method to extract embeddings and save them.

---

### **3. Role in the Project**
This code supplies the inputs for the speech emotion recognition project by:
- **Generating Input Features**:
  - Extracts embeddings from pre-trained models (YAMNet, VGGish, PANNs) and calculates acoustic features.
  - These features form the **input for the downstream deep learning models** (CNN or DNN).
  
- **Standardizing Audio Data**:
  - Pads or truncates audio signals to a consistent length for uniform model input.
  
- **Handling Large Data**:
  - Processes data in batches and uses memory management techniques to handle large datasets efficiently.
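
The batching pattern can be shown without any model. In the sketch below, `run_in_batches` and `fake_inference` are hypothetical, and `fake_inference` only stands in for a real model call. With 535 clips and a batch size of 100, the loop runs 6 times, which matches the `6/6` PANNs progress bar in the logged run.

```python
import numpy as np


def run_in_batches(items: np.ndarray, batch_size: int, fn) -> np.ndarray:
    """Apply fn to consecutive slices of items and concatenate the results."""
    outputs = []
    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        outputs.append(fn(batch))
    return np.concatenate(outputs, axis=0)


# Stand-in for model inference: one 2-dimensional "embedding" per clip
def fake_inference(batch: np.ndarray) -> np.ndarray:
    return np.stack([batch.mean(axis=1), batch.std(axis=1)], axis=1)


clips = np.random.default_rng(0).normal(size=(535, 160))
result = run_in_batches(clips, batch_size=100, fn=fake_inference)
n_batches = -(-len(clips) // 100)  # ceiling division
print(result.shape, n_batches)
```

Batching bounds the peak memory of each model call; the last batch is simply smaller when the dataset size is not a multiple of `batch_size`.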

---

### **4. Workflow Overview**
1. **Input Data**: `.pkl` files containing preprocessed audio signals.
2. **Processing Steps**:
   - Load and preprocess audio.
   - Extract embeddings (YAMNet, VGGish, PANNs).
   - Calculate acoustic features.
3. **Output Data**:
   - `.npy` files for the YAMNet, VGGish, and PANNs outputs and for the MPS features.
   - `.csv` file for acoustic features.
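
The output layout can be checked with a small round trip. This sketch follows the pipeline's naming scheme (`<name>_embedding.npy` and `<name>_raw.csv`) but, as an assumption made to avoid a pandas dependency, represents the acoustic-feature table as a list of row dictionaries instead of a DataFrame. The function `save_features` here is a simplified stand-in, not the class method.

```python
import csv
import tempfile
from pathlib import Path

import numpy as np


def save_features(features: dict, output_dir: Path) -> None:
    """Write arrays to .npy and lists of row dicts to .csv."""
    output_dir.mkdir(parents=True, exist_ok=True)
    for name, data in features.items():
        if isinstance(data, np.ndarray):
            np.save(output_dir / f"{name}_embedding.npy", data)
        else:
            # Rows are plain dicts here, standing in for a DataFrame
            with open(output_dir / f"{name}_raw.csv", "w", newline="") as fh:
                writer = csv.DictWriter(fh, fieldnames=list(data[0].keys()))
                writer.writeheader()
                writer.writerows(data)


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "embeddings"
    save_features({
        "yamnet": np.zeros((3, 1024), dtype=np.float32),
        "indices": [{"rms": 0.1, "zero_crossing_rate": 0.05}],
    }, out)
    saved = sorted(p.name for p in out.iterdir())
    reloaded = np.load(out / "yamnet_embedding.npy")

print(saved, reloaded.shape)
```

Downstream training code can then load each array with `np.load` and rely on row `i` of every file referring to the same clip.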


In [1]:
# Standard library
import gc
import logging
import os
import pickle
import time
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple

# Third-party
import librosa
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
from panns_inference import AudioTagging
from tqdm import tqdm


# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('audio_embeddings.log'),
        logging.StreamHandler()
    ]
)

class AudioEmbeddingExtractor:
    """Class to handle extraction of audio embeddings using multiple pre-trained models"""
    
    def __init__(self, 
                 csv_path: str,
                 output_dir: str,
                 sample_rate: int = 16000,
                 target_length: int = 160000,
                 batch_size: int = 100,
                 use_gpu: bool = False):
        """
        Initialize the embedding extractor
        
        Args:
            csv_path: Path to CSV containing audio file paths
            output_dir: Directory to save embeddings
            sample_rate: Original sample rate of audio
            target_length: Target length for padding/truncating
            batch_size: Batch size for PANN processing
            use_gpu: Whether to use GPU for PANN inference
        """
        self.csv_path = Path(csv_path)
        self.output_dir = Path(output_dir)
        self.sample_rate = sample_rate
        self.target_length = target_length
        self.batch_size = batch_size
        self.device = 'cuda' if use_gpu else 'cpu'
        
        # Create output directory
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        # Initialize models
        self._load_models()
        
    def _load_models(self) -> None:
        """Load all pre-trained models"""
        try:
            logging.info("Loading pre-trained models...")
            self.yamnet_model = hub.load('https://tfhub.dev/google/yamnet/1')
            self.vggish_model = hub.load('https://tfhub.dev/google/vggish/1')
            self.panns_model = AudioTagging(checkpoint_path=None, device=self.device)
            logging.info("Models loaded successfully")
        except Exception as e:
            logging.error(f"Error loading models: {str(e)}")
            raise

    def calculate_acoustic_features(self, y: np.ndarray) -> pd.DataFrame:
        """Calculate comprehensive acoustic features"""
        try:
            # Basic features
            features = {
                'rms': np.sqrt(np.mean(y**2)),
                'zero_crossing_rate': np.mean(librosa.feature.zero_crossing_rate(y)),
                'spectral_centroid': np.mean(librosa.feature.spectral_centroid(y=y, sr=self.sample_rate)[0]),
                'spectral_rolloff': np.mean(librosa.feature.spectral_rolloff(y=y, sr=self.sample_rate)[0]),
                'spectral_bandwidth': np.mean(librosa.feature.spectral_bandwidth(y=y, sr=self.sample_rate)[0])
            }
            
            # MFCC features
            mfccs = librosa.feature.mfcc(y=y, sr=self.sample_rate, n_mfcc=13)
            for i, mfcc in enumerate(mfccs):
                features[f'mfcc_{i+1}_mean'] = np.mean(mfcc)
                features[f'mfcc_{i+1}_std'] = np.std(mfcc)
            
            # Spectral features
            features.update({
                'spectral_contrast': np.mean(librosa.feature.spectral_contrast(y=y, sr=self.sample_rate)[0]),
                'spectral_flatness': np.mean(librosa.feature.spectral_flatness(y=y))
            })
            
            return pd.DataFrame([features])
            
        except Exception as e:
            logging.error(f"Error calculating acoustic features: {str(e)}")
            raise
            
    def load_and_preprocess_audio(self, file_path: str) -> Dict[str, np.ndarray]:
        """Load and preprocess a single audio file"""
        try:
            with open(file_path, 'rb') as file:
                data = pickle.load(file)
                
            # Check for required 'y' field
            if 'y' not in data:
                raise KeyError("Audio data 'y' not found in pickle file")
                
            y = data['y']
            
            # Pad/truncate to target length
            if len(y) < self.target_length:
                pad_length = self.target_length - len(y)
                y = np.pad(y, (0, pad_length), 'constant', constant_values=y.mean())
            else:
                y = y[:self.target_length]
                
            # Calculate acoustic features if not present
            indices = data.get('df_indices', None)
            if indices is None:
                indices = self.calculate_acoustic_features(y)
                
            return {
                'waveform': y,
                'indices': indices,
                'mps': data.get('mps', np.array([])),
                'wt': data.get('wt', np.array([]))
            }
            
        except Exception as e:
            logging.error(f"Error processing {file_path}: {str(e)}")
            raise

    def extract_yamnet_vggish_features(self, audio_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Extract YAMNet and VGGish embeddings"""
        try:
            y = audio_data['waveform']
            
            # YAMNet embeddings
            scores, embedding_tensor, _ = self.yamnet_model(y)
            yamnet_embedding = tf.reduce_mean(embedding_tensor, axis=0).numpy()
            
            # VGGish embeddings
            vggish_embedding = tf.reduce_mean(self.vggish_model(y), axis=0).numpy()
            
            # Process MPS features if available
            mps = audio_data.get('mps', np.array([]))
            wt = audio_data.get('wt', np.array([]))
            if mps.size > 0 and wt.size > 0:
                mps = mps[:, wt <= 100].reshape(-1)
                
            return {
                'yamnet': yamnet_embedding,
                'vggish': vggish_embedding,
                'mps': mps,
                'indices': audio_data['indices']
            }
            
        except Exception as e:
            logging.error(f"Error extracting features: {str(e)}")
            raise
            
    def process_panns_embeddings(self, waveforms: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Process PANNs embeddings in batches"""
        try:
            # Resample to 32kHz for PANN
            logging.info("Resampling audio for PANNs processing...")
            waveforms_32k = librosa.resample(
                waveforms, 
                orig_sr=self.sample_rate,
                target_sr=32000,
                axis=1
            )
            
            clipwise_outputs = []
            embeddings = []
            
            # Process in batches
            for i in tqdm(range(0, len(waveforms_32k), self.batch_size), 
                         desc="Processing PANNs batches"):
                batch = waveforms_32k[i:i + self.batch_size]
                clipwise_output, embedding = self.panns_model.inference(batch)
                clipwise_outputs.append(clipwise_output)
                embeddings.append(embedding)
                
            # Clean up to save memory
            del waveforms_32k
            gc.collect()
            
            return (np.concatenate(clipwise_outputs, axis=0),
                   np.concatenate(embeddings, axis=0))
                   
        except Exception as e:
            logging.error(f"Error processing PANNs embeddings: {str(e)}")
            raise

    def process_dataset(self) -> None:
        """Process entire dataset and extract all embeddings"""
        try:
            # Load dataset
            df = pd.read_csv(self.csv_path)
            logging.info(f"Processing {len(df)} audio files")
            
            # Initialize lists
            waveforms = []
            yamnet_embeddings = []
            vggish_embeddings = []
            mps_features = []
            all_indices = []
            failed_files = []
            
            # Process each file
            for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing files"):
                try:
                    # Load and process audio
                    audio_data = self.load_and_preprocess_audio(row['pkl_path'])
                    
                    # Extract features
                    features = self.extract_yamnet_vggish_features(audio_data)
                    
                    # Collect results only after every step succeeded, so the
                    # waveforms passed to PANNs stay aligned row-for-row with
                    # the YAMNet/VGGish embeddings and the acoustic indices
                    waveforms.append(audio_data['waveform'])
                    yamnet_embeddings.append(features['yamnet'])
                    vggish_embeddings.append(features['vggish'])
                    if features['mps'].size > 0:
                        mps_features.append(features['mps'])
                    all_indices.append(features['indices'])
                    
                except Exception as e:
                    logging.warning(f"Failed to process file {row['pkl_path']}: {str(e)}")
                    failed_files.append(row['pkl_path'])
                    continue
                
                # Periodic garbage collection
                if idx % 100 == 0:
                    gc.collect()
            
            # Process PANNs embeddings
            waveforms = np.stack(waveforms)
            panns_clip, panns_embedding = self.process_panns_embeddings(waveforms)
            del waveforms
            gc.collect()
            
            # Save all features
            self._save_features({
                'yamnet': np.stack(yamnet_embeddings),
                'vggish': np.stack(vggish_embeddings),
                'panns_clip': panns_clip,
                'panns_embedding': panns_embedding,
                'mps': np.stack(mps_features) if mps_features else np.array([]),
                'indices': pd.concat(all_indices, ignore_index=True)
            })
            
            if failed_files:
                logging.warning(f"Failed to process {len(failed_files)} files")
                
            logging.info("Processing completed successfully")
            
        except Exception as e:
            logging.error(f"Error processing dataset: {str(e)}")
            raise
            
    def _save_features(self, features: Dict[str, Any]) -> None:
        """Save arrays as '<name>_embedding.npy' and DataFrames as '<name>_raw.csv'"""
        try:
            for name, data in features.items():
                if isinstance(data, pd.DataFrame):
                    data.to_csv(self.output_dir / f'{name}_raw.csv', index=False)
                else:
                    np.save(self.output_dir / f'{name}_embedding.npy', data)
                    
            logging.info(f"Features saved to {self.output_dir}")
            
        except Exception as e:
            logging.error(f"Error saving features: {str(e)}")
            raise

def main():
    """Main execution function"""
    try:
        # Define paths
        csv_path = "/Users/huangjuhua/文档文稿/NYU/Time_Series/data/train_val_test_split_EMODB.csv"
        output_dir = "/Users/huangjuhua/文档文稿/NYU/Time_Series/data/processed/embeddings"
        
        # Initialize and run feature extraction
        extractor = AudioEmbeddingExtractor(
            csv_path=csv_path,
            output_dir=output_dir,
            use_gpu=False  # Set to True if GPU available
        )
        
        extractor.process_dataset()
        
    except Exception as e:
        logging.error(f"Processing failed: {str(e)}")

if __name__ == "__main__":
    main()

2024-12-07 16:21:38,205 - INFO - Loading pre-trained models...
2024-12-07 16:21:38,206 - INFO - Using /var/folders/8d/r58q4mcx1sv_vq93k0x2r2dw0000gn/T/tfhub_modules to cache modules.
2024-12-07 16:21:42,223 - INFO - Fingerprint not found. Saved model loading will continue.
2024-12-07 16:21:42,224 - INFO - path_and_singleprint metric could not be logged. Saved model loading will continue.
2024-12-07 16:21:42,481 - INFO - Fingerprint not found. Saved model loading will continue.
2024-12-07 16:21:42,482 - INFO - path_and_singleprint metric could not be logged. Saved model loading will continue.


Checkpoint path: /Users/huangjuhua/panns_data/Cnn14_mAP=0.431.pth


2024-12-07 16:21:43,717 - INFO - Models loaded successfully
2024-12-07 16:21:43,721 - INFO - Processing 535 audio files


Using CPU.


Processing files: 100%|██████████| 535/535 [02:14<00:00,  3.97it/s]
2024-12-07 16:23:58,566 - INFO - Resampling audio for PANNs processing...
Processing PANNs batches: 100%|██████████| 6/6 [00:56<00:00,  9.42s/it]
2024-12-07 16:24:55,965 - INFO - Features saved to /Users/huangjuhua/文档文稿/NYU/Time_Series/data/processed/embeddings
2024-12-07 16:24:55,966 - INFO - Processing completed successfully
