In [None]:
import os
import csv
import pandas as pd
import librosa
import numpy as np
from pathlib import Path
from tqdm import tqdm
import gc
import re
import glob
import pickle

import soundfile as sf
from musicnn.extractor import extractor as musicnn_extractor

import torch
import resampy
from transformers import Wav2Vec2Processor, Wav2Vec2Model

# Add CLAP import
from laion_clap import CLAP_Module

import warnings
warnings.filterwarnings('ignore')

In [None]:
class GTZANFeatureExtractor:
    """
    Build a metadata table for a GTZAN ``genres_original`` folder tree and run
    a pluggable feature extractor over every clip it lists.

    Attributes:
        data_path (Path): Root of the GTZAN data directory.
        audio_base_dir (Path): ``data_path / "genres_original"``.
        metadata_df (pd.DataFrame | None): Filled in by ``load_metadata``.
        mapping_validation (dict | None): Result of the last ``validate_mapping`` call.
        genres (list[str]): Names of the genre sub-folders that are scanned.
    """

    # Explicit column order so that a scan which finds no files still yields a
    # frame carrying these columns (the summary prints index it by 'GENRE').
    _METADATA_COLUMNS = ['TRACK_ID', 'PATH', 'GENRE', 'TRACK_NUMBER',
                         'FILENAME', 'FULL_PATH']

    def __init__(self, data_path):
        """
        Initialize the feature extractor for GTZAN dataset

        Args:
            data_path (str): Path to GTZAN data directory (e.g., "E:/Oxford/Extra/ICASSP/Draft_1/GTZAN-data")
        """
        self.data_path = Path(data_path)
        self.audio_base_dir = self.data_path / "genres_original"
        self.metadata_df = None
        self.mapping_validation = None

        # GTZAN genre list
        self.genres = ['blues', 'classical', 'country', 'disco', 'hiphop',
                      'jazz', 'metal', 'pop', 'reggae', 'rock']

    def load_metadata(self):
        """
        Create metadata DataFrame from GTZAN folder structure.

        Each genre folder is scanned for ``*.wav`` files; names are expected to
        look like ``genre.00042.wav``. Files are listed in sorted order so that
        row positions (and therefore ``start_index`` / ``end_index`` in
        ``process_dataset``) are the same on every run.

        Returns:
            pd.DataFrame: One row per WAV file with the columns listed in
            ``_METADATA_COLUMNS``; also stored on ``self.metadata_df``.

        Raises:
            FileNotFoundError: If ``audio_base_dir`` does not exist.
        """
        print(f"Creating metadata from GTZAN folder structure: {self.audio_base_dir}")

        if not self.audio_base_dir.exists():
            raise FileNotFoundError(f"GTZAN audio directory not found: {self.audio_base_dir}")

        metadata_rows = []

        for genre in self.genres:
            genre_dir = self.audio_base_dir / genre
            if not genre_dir.exists():
                print(f"Warning: Genre directory not found: {genre_dir}")
                continue

            # glob() order depends on the filesystem; sort for a stable row order.
            wav_files = sorted(genre_dir.glob("*.wav"))
            print(f"Found {len(wav_files)} files in {genre} directory")

            for wav_file in wav_files:
                # Extract track number from filename (e.g., blues.00042.wav -> "00042")
                parts = wav_file.stem.split('.')
                track_num = parts[1] if len(parts) >= 2 else "unknown"

                metadata_rows.append({
                    'TRACK_ID': f"{genre}_{track_num}",
                    'PATH': f"{genre}/{wav_file.name}",  # e.g., "blues/blues.00042.wav"
                    'GENRE': genre,
                    'TRACK_NUMBER': track_num,
                    'FILENAME': wav_file.name,
                    'FULL_PATH': str(wav_file)
                })

        # Create DataFrame (columns given explicitly so an empty scan is still well-formed)
        self.metadata_df = pd.DataFrame(metadata_rows, columns=self._METADATA_COLUMNS)
        print(f"\nCreated metadata for {len(self.metadata_df)} tracks")
        print(f"Genres found: {sorted(self.metadata_df['GENRE'].unique())}")
        print(f"Tracks per genre: {self.metadata_df['GENRE'].value_counts().to_dict()}")
        print(f"Columns: {list(self.metadata_df.columns)}\n")
        print(self.metadata_df.head(3))
        return self.metadata_df

    def validate_audio_directory(self):
        """
        Validate the GTZAN audio directory structure.

        Prints the WAV count of every expected genre folder.

        Returns:
            bool: Always ``True`` when at least one genre folder exists.

        Raises:
            FileNotFoundError: If the base directory is missing or contains
                none of the expected genre folders.
        """
        if not self.audio_base_dir.exists():
            raise FileNotFoundError(f"Audio base directory not found: {self.audio_base_dir}")

        print(f"Found GTZAN audio base directory: {self.audio_base_dir}")

        # Check for genre folders
        found_genres = []
        total_wav_count = 0

        for genre in self.genres:
            genre_dir = self.audio_base_dir / genre
            if genre_dir.exists() and genre_dir.is_dir():
                found_genres.append(genre)
                wav_files = list(genre_dir.glob("*.wav"))
                total_wav_count += len(wav_files)
                print(f"  Genre '{genre}': {len(wav_files)} WAV files")
            else:
                print(f"  Genre '{genre}': MISSING")

        print(f"\nFound {len(found_genres)}/{len(self.genres)} genre directories")
        print(f"Total WAV files: {total_wav_count}")

        if len(found_genres) == 0:
            raise FileNotFoundError("No genre directories found in GTZAN directory")

        return True

    def _lookup_audio_path(self, track_id):
        """Return the path the metadata records for ``track_id`` without checking the disk."""
        if self.metadata_df is None:
            self.load_metadata()
        row = self.metadata_df.loc[self.metadata_df['TRACK_ID'] == track_id]
        if row.empty:
            raise KeyError(f"No metadata found for track {track_id}")
        return self.audio_base_dir / row.iloc[0]['PATH']     # e.g. "blues/blues.00042.wav"

    def get_audio_path(self, track_id):
        """
        Get audio file path from track ID.

        Args:
            track_id (str): Value from the ``TRACK_ID`` metadata column.

        Returns:
            Path: Existing audio file for that track.

        Raises:
            KeyError: If the track ID is not in the metadata.
            FileNotFoundError: If the recorded file is not on disk.
        """
        fullpath = self._lookup_audio_path(track_id)
        if not fullpath.exists():
            raise FileNotFoundError(f"Audio file not found at {fullpath}")
        return fullpath

    def validate_mapping(self, sample_size=100):
        """
        Validate mapping between metadata and actual audio files

        Args:
            sample_size (int): Number of random samples to check

        Returns:
            dict: Validation results with statistics (``total_checked``,
            ``found``, ``missing``, ``found_tracks``, ``missing_tracks``,
            ``sample_paths``). Missing files are reported in the result rather
            than raised.
        """
        print(f"\n{'='*50}")
        print("VALIDATING AUDIO-METADATA MAPPING")
        print(f"{'='*50}")

        if self.metadata_df is None:
            self.load_metadata()

        # Sample tracks for validation (fixed seed keeps the sample repeatable)
        n_check = min(sample_size, len(self.metadata_df))
        if n_check > 0:
            sample_df = self.metadata_df.sample(n_check, random_state=42)
        else:
            sample_df = self.metadata_df.iloc[0:0]

        validation_results = {
            'total_checked': len(sample_df),
            'found': 0,
            'missing': 0,
            'found_tracks': [],
            'missing_tracks': [],
            'sample_paths': []
        }

        print(f"Checking {len(sample_df)} random tracks from metadata...")

        for idx, row in sample_df.iterrows():
            track_id = row['TRACK_ID']
            # Build the expected path from the row itself: get_audio_path() raises
            # for a missing file, which would abort the check instead of counting it.
            audio_path = self.audio_base_dir / row['PATH']

            if audio_path.exists():
                validation_results['found'] += 1
                validation_results['found_tracks'].append({
                    'track_id': track_id,
                    'path': str(audio_path),
                    'size_mb': audio_path.stat().st_size / (1024*1024),
                    'genre': row['GENRE']
                })
                # Store first 5 found paths as samples
                if len(validation_results['sample_paths']) < 5:
                    validation_results['sample_paths'].append(str(audio_path))
            else:
                validation_results['missing'] += 1
                validation_results['missing_tracks'].append({
                    'track_id': track_id,
                    'expected_path': str(audio_path)
                })

        # Calculate statistics (nothing checked -> both percentages are 0)
        total_checked = validation_results['total_checked']
        if total_checked > 0:
            found_percentage = (validation_results['found'] / total_checked) * 100
            missing_percentage = 100 - found_percentage
        else:
            found_percentage = 0.0
            missing_percentage = 0.0

        print(f"\nVALIDATION RESULTS:")
        print(f"  Total checked: {validation_results['total_checked']}")
        print(f"  Found: {validation_results['found']} ({found_percentage:.1f}%)")
        print(f"  Missing: {validation_results['missing']} ({missing_percentage:.1f}%)")

        if validation_results['sample_paths']:
            print(f"\nSample found audio paths:")
            for path in validation_results['sample_paths']:
                print(f"  {path}")

        # Always list up to five missing tracks, however many are missing in total.
        if validation_results['missing_tracks']:
            print(f"\nMissing tracks (showing first 5):")
            for track in validation_results['missing_tracks'][:5]:
                print(f"  Track {track['track_id']}: Expected at {track['expected_path']}")

        # Check genre distribution for found tracks
        if validation_results['found_tracks']:
            genre_distribution = {}
            for track in validation_results['found_tracks']:
                genre = track['genre']
                genre_distribution[genre] = genre_distribution.get(genre, 0) + 1

            print(f"\nGenre distribution of found tracks:")
            for genre, count in sorted(genre_distribution.items()):
                print(f"  {genre}: {count} tracks")

        self.mapping_validation = validation_results

        if found_percentage < 50:
            print(f"\n⚠️  WARNING: Only {found_percentage:.1f}% of tracks found!")
            print("This might indicate an issue with the folder structure or file naming.")
        else:
            print(f"\n✅ Validation successful: {found_percentage:.1f}% of tracks found")

        return validation_results

    def load_audio_file(self, file_path, sr=22050, duration=30.0):
        """
        Load an audio file using librosa

        Args:
            file_path (Path): Path to the audio file
            sr (int): Sample rate (default: 22050 Hz)
            duration (float): Duration to load in seconds (default: 30.0 for 30-second clips)

        Returns:
            tuple: (audio_data, sample_rate) or (None, None) if loading fails
        """
        try:
            # Load audio file
            y, sr_actual = librosa.load(file_path, sr=sr, duration=duration)
            return y, sr_actual
        except Exception as e:
            # Best-effort: report and let the caller count the track as failed.
            print(f"Error loading {file_path}: {e}")
            return None, None

    def _save_checkpoint(self, results, output_dir, processed_count, last_pkl):
        """Write the rows gathered so far as CSV + pickle, drop the previous pickle, return the new pickle path."""
        temp_df = pd.DataFrame(results)

        # Save CSV checkpoint
        csv_path = output_dir / f"gtzan_features_intermediate_{processed_count}.csv"
        temp_df.to_csv(csv_path, index=False)

        # Save Pickle checkpoint
        pkl_path = output_dir / f"gtzan_features_intermediate_{processed_count}.pkl"
        temp_df.to_pickle(pkl_path)

        # Delete previous snapshot if any (only the pickle; CSV checkpoints are kept)
        if last_pkl and os.path.exists(last_pkl):
            os.remove(last_pkl)

        # Optional: free memory if needed
        del temp_df
        gc.collect()

        print(f"Saved intermediate results at {processed_count} files →")
        print(f"  • CSV:  {csv_path}")
        print(f"  • Pickle: {pkl_path}")
        return str(pkl_path)

    def process_dataset(self,
                        feature_extractor_func,
                        start_index: int = 0,
                        end_index: int = None,
                        save_interval: int = 100,
                        validate_first: bool = True,
                        output_dir_name: str = "clap_features") -> pd.DataFrame:
        """
        Main processing function to extract features from all audio files

        Args:
            feature_extractor_func: Callable ``(audio_data, sample_rate) -> dict``
                whose keys become columns of the result
            start_index (int): Starting row position for processing
            end_index (int): Ending row position (exclusive); None for all
            save_interval (int): Save intermediate results every N files;
                0 or None disables checkpoints
            validate_first (bool): Whether to validate mapping before processing
            output_dir_name (str): Name of output directory for intermediate saves

        Returns:
            pd.DataFrame: DataFrame with extracted features, or ``None`` when
            the user declines to continue after a poor validation result.
        """
        # Load metadata and validate audio directory
        self.load_metadata()
        self.validate_audio_directory()

        # Run validation check first
        if validate_first:
            validation_results = self.validate_mapping(sample_size=50)
            checked = validation_results['total_checked']
            found_ratio = validation_results['found'] / checked if checked else 0.0

            # Ask user if they want to continue if validation shows issues
            if found_ratio < 0.5:
                print("\n⚠️  Low success rate in validation. Please check the issues above.")
                response = input("Do you want to continue anyway? (y/n): ")
                if response.lower() != 'y':
                    print("Processing cancelled.")
                    return None

        # Slice the DataFrame to only the desired segment.
        # Compare against None explicitly so end_index=0 means "nothing", not "all".
        total = len(self.metadata_df)
        if end_index is None:
            end_index = total
        df_segment = self.metadata_df.iloc[start_index:end_index]

        print(f"Processing tracks {start_index} to {end_index} (total {len(df_segment)})")

        # Initialize processing variables.
        # processed_count starts at start_index so checkpoint file names continue
        # the numbering of an earlier, partial run.
        results = []
        processed_count = start_index
        failed_count = 0
        last_pkl = None

        print(f"Processing starting at index {start_index} (track #{processed_count+1})")
        print(f"Tracks left: {len(df_segment)}")

        # Create output directory if it doesn't exist
        output_dir = self.data_path / output_dir_name
        output_dir.mkdir(parents=True, exist_ok=True)

        # Process each track
        for idx, row in tqdm(df_segment.iterrows(), total=len(df_segment), desc="Processing audio"):
            track_id = row['TRACK_ID']

            # Expected location from the row itself; a file that vanished since the
            # scan is counted as a failure instead of aborting the whole run.
            audio_path = self.audio_base_dir / row['PATH']

            if not audio_path.exists():
                print(f"Audio file not found, skipping {track_id}: {audio_path}")
                failed_count += 1
                continue

            # Load audio
            audio_data, sample_rate = self.load_audio_file(audio_path)

            if audio_data is None:
                failed_count += 1
                continue

            # Extract features using your feature extractor
            try:
                features = feature_extractor_func(audio_data, sample_rate)

                # Combine metadata with features
                result_row = {
                    'track_id': track_id,
                    'audio_path': str(audio_path),
                    'duration': len(audio_data) / sample_rate,
                    'sample_rate': sample_rate,
                    **features  # Unpack feature dictionary
                }

                # Add original metadata columns
                for col in row.index:
                    if col not in result_row:
                        result_row[f'meta_{col}'] = row[col]
            except Exception as e:
                print(f"Feature extraction failed for {track_id}: {e}")
                failed_count += 1
                continue

            results.append(result_row)
            processed_count += 1

            # Save intermediate results. Kept outside the extraction try-block so a
            # disk error is reported as such and does not mark the track as failed.
            if save_interval and save_interval > 0 and processed_count % save_interval == 0:
                try:
                    last_pkl = self._save_checkpoint(results, output_dir,
                                                     processed_count, last_pkl)
                except Exception as e:
                    print(f"Checkpoint save failed at {processed_count} files: {e}")

        # Create final DataFrame
        features_df = pd.DataFrame(results)

        print(f"\nProcessing complete!")
        # Count the rows actually produced in this call (not offset by start_index).
        print(f"Successfully processed: {len(results)} files")
        print(f"Failed: {failed_count} files")
        if len(features_df) > 0:
            print(f"Feature DataFrame shape: {features_df.shape}")

        return features_df

In [None]:
def quick_validation_check(data_path: str = "E:/Oxford/Extra/ICASSP/Draft_1/GTZAN-data",
                           sample_size: int = 100):
    """
    Quick function to just validate the mapping without doing feature extraction.
    Run this first to make sure everything is set up correctly.

    Args:
        data_path (str): Root of the GTZAN data directory. Defaults to the
            location this notebook was written against.
        sample_size (int): Number of random tracks passed to
            ``GTZANFeatureExtractor.validate_mapping``.

    Returns:
        dict: The result dictionary returned by ``validate_mapping``.
    """
    extractor = GTZANFeatureExtractor(data_path)

    print("Loading metadata and validating audio directory structure...")
    extractor.load_metadata()
    extractor.validate_audio_directory()

    print("\nValidating audio-metadata mapping...")
    validation_results = extractor.validate_mapping(sample_size=sample_size)

    return validation_results

def musicnn_feature_extractor(audio_data: np.ndarray,
                              sample_rate: int,
                              model: str = 'MTT_musicnn') -> dict:
    """
    Run musicnn's extractor on an in-memory clip.

    The extractor is called with a file path, so the clip is first written to
    a WAV file. The file lives in a private temporary directory that is
    removed before returning, so nothing is left behind in the working
    directory and a pre-existing ``temp_clip.wav`` there is never overwritten.

    Args:
        audio_data (np.ndarray): Raw audio samples.
        sample_rate (int): Sample rate of ``audio_data``.
        model (str): Model name forwarded to the musicnn extractor.

    Returns:
        dict: The features dictionary (third element of the extractor's
        return value when ``extract_features=True``).
    """
    import tempfile  # local import: not part of the notebook's import cell

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, 'temp_clip.wav')
        sf.write(tmp_path, audio_data, sample_rate)

        # extractor returns (taggram, tags, features_dict)
        _, _, feats = musicnn_extractor(tmp_path,
                                        model=model,
                                        extract_features=True)
    return feats


# wav2vec2 is loaded a single time at module scope; wav2vec_feature_extractor
# below reads `processor` and `model` as globals on every call.
WAV2VEC2_CHECKPOINT = 'facebook/wav2vec2-base-960h'
processor = Wav2Vec2Processor.from_pretrained(WAV2VEC2_CHECKPOINT)
model = Wav2Vec2Model.from_pretrained(WAV2VEC2_CHECKPOINT)
model.eval()

def wav2vec_feature_extractor(audio_data: np.ndarray,
                              sample_rate: int) -> dict:
    """
    Turn one audio clip into a single fixed-length wav2vec2 vector.

    The clip is brought to 16 kHz (skipped when it already is), passed through
    the module-level ``processor`` and ``model``, and the model's
    ``last_hidden_state`` is averaged over dimension 1.

    Args:
        audio_data (np.ndarray): Raw audio samples.
        sample_rate (int): Sample rate of ``audio_data``.

    Returns:
        dict: ``{'wav2vec2_embedding': np.ndarray}``.
    """
    waveform = audio_data
    if sample_rate != 16000:
        waveform = resampy.resample(waveform, sample_rate, 16000)

    encoded = processor(waveform,
                        sampling_rate=16000,
                        return_tensors='pt',
                        padding=True)

    # Inference only: no gradient bookkeeping needed.
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state

    pooled = hidden.mean(dim=1).squeeze()
    return {'wav2vec2_embedding': pooled.cpu().numpy()}


# ============================================================================
# CLAP Feature Extractor
# ============================================================================

# Initialize the CLAP model once at module scope from a checkpoint file on the
# local disk; clap_feature_extractor below reads `clap_model` as a global.
# NOTE(review): ckpt_path is an absolute, machine-specific location — adjust it
# before running on another machine.
print("Loading CLAP model...")
ckpt_path = "D:/Models/630k-audioset-best.pt"
clap_model = CLAP_Module(enable_fusion=False)
clap_model.load_ckpt(ckpt_path)  # weights are read from the local file named above
# presumably switches the module to inference mode (torch-style .eval()) — TODO confirm
clap_model.eval()
print("CLAP model loaded successfully!")

def clap_feature_extractor(audio_data: np.ndarray, sample_rate: int) -> dict:
    """
    Compute a CLAP embedding for one clip of raw audio.

    The clip is resampled to 48 kHz when it arrives at any other rate, given a
    leading batch axis of size one, and handed to the module-level
    ``clap_model``; the batch axis is removed again from the result.

    Args:
        audio_data (np.ndarray): Raw audio data
        sample_rate (int): Sample rate of the audio

    Returns:
        dict: Dictionary containing 'clap_embedding' key with embedding vector
    """
    clap_rate = 48000

    waveform = audio_data
    if sample_rate != clap_rate:
        waveform = resampy.resample(waveform, sample_rate, clap_rate)

    # (T,) -> (1, T): a batch holding this single clip
    batch = waveform.reshape(1, -1)

    embeddings = clap_model.get_audio_embedding_from_data(x=batch, use_tensor=False)

    # Drop the batch axis so the caller gets one vector per clip
    return {"clap_embedding": np.squeeze(embeddings, axis=0)}

In [None]:
# Root folder of the GTZAN dataset: `data_path` is the plain string handed to
# GTZANFeatureExtractor, `DATA_ROOT` the same location as a Path for building
# output file names.
data_path = "E:/Oxford/Extra/ICASSP/Draft_1/GTZAN-data"
DATA_ROOT = Path(data_path)

In [None]:
# Sanity check before any feature extraction: scans the dataset folder and
# checks a random sample of metadata rows against the files on disk.
# The returned dict is kept in `validation_results` for inspection.
validation_results = quick_validation_check()

In [None]:
# # 1) musicnn embeddings
# extractor = GTZANFeatureExtractor(data_path)
# df_musicnn = extractor.process_dataset(
#     feature_extractor_func=musicnn_feature_extractor,
#     start_index=0,
#     end_index=None,
#     save_interval=100,
#     validate_first=False
# )
# df_musicnn.to_csv(DATA_ROOT / 'gtzan_musicnn_embeddings.csv',
#                   index=False)
# df_musicnn.to_pickle(DATA_ROOT / 'gtzan_musicnn_embeddings.pkl')

In [None]:
# # 2) wav2vec2 embeddings
# extractor = GTZANFeatureExtractor(data_path)
# df_wav2vec = extractor.process_dataset(
#     feature_extractor_func=wav2vec_feature_extractor,
#     start_index=0,
#     end_index=None,
#     save_interval=100,
#     validate_first=False
# )
# df_wav2vec.to_csv(DATA_ROOT / 'gtzan_wav2vec2_embeddings.csv',
#                   index=False)
# df_wav2vec.to_pickle(DATA_ROOT / 'gtzan_wav2vec2_embeddings.pkl')

In [None]:
# # Convert wav2vec embeddings to separate columns for CSV compatibility
# emb = df_wav2vec.pop('wav2vec2_embedding').tolist()
# df_wav2vec[[f'w2v_{i}' for i in range(len(emb[0]))]] = pd.DataFrame(emb)

# df_wav2vec.to_csv(DATA_ROOT / 'gtzan_wav2vec2_embeddings_v2.csv',
#                   index=False)
# df_wav2vec.to_pickle(DATA_ROOT / 'gtzan_wav2vec2_embeddings_v2.pkl')

In [None]:
# 3) CLAP embeddings
# Runs clap_feature_extractor over the whole dataset (start_index=0,
# end_index=None) with a checkpoint every 300 processed files; the mapping
# validation step is skipped (validate_first=False).
extractor = GTZANFeatureExtractor(data_path)
df_clap = extractor.process_dataset(
    feature_extractor_func=clap_feature_extractor,
    start_index=0,
    end_index=None,
    save_interval=300,
    validate_first=False
)
# Each 'clap_embedding' cell holds a whole array, so this CSV stores it as
# text; the pickle keeps the array objects.
df_clap.to_csv(DATA_ROOT / 'gtzan_clap_embeddings.csv',
                  index=False)
df_clap.to_pickle(DATA_ROOT / 'gtzan_clap_embeddings.pkl')

In [None]:
# Convert CLAP embeddings to separate columns for CSV compatibility.
# Guarded so the cell can be re-run: once the array column has been expanded
# there is nothing left to convert, and a result frame with no rows never had
# the column in the first place.
if 'clap_embedding' in df_clap.columns:
    emb = df_clap['clap_embedding'].tolist()
    clap_columns = pd.DataFrame(
        emb,
        index=df_clap.index,  # keep rows paired with their own embedding
        columns=[f'clap_{i}' for i in range(len(emb[0]))],
    )
    # One concat instead of inserting the embedding columns one by one.
    df_clap = pd.concat([df_clap.drop(columns='clap_embedding'), clap_columns],
                        axis=1)

df_clap.to_csv(DATA_ROOT / 'gtzan_clap_embeddings_v2.csv',
               index=False)
df_clap.to_pickle(DATA_ROOT / 'gtzan_clap_embeddings_v2.pkl')