In [2]:
# !pip install librosa

In [3]:
import os
import random
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import json
from collections import Counter
import warnings
from pathlib import Path
import shutil

In [4]:
# Set parameters
# Dataset locations, given relative to the notebook's working directory.
INPUT_DIR = "../dataset/raw data"  # Folder containing genre subdirectories with 30-second audio files
OUTPUT_DIR = "../dataset/melspectrograms"  # Folder where generated images will be saved
METADATA_DIR = "../dataset/metadata"  # Directory for metadata files

# Fraction of spectrograms routed to each split. assign_split() walks this
# mapping in insertion order and accumulates the ratios.
SPLIT_RATIOS = {"train": 0.8, "validation": 0.1, "test": 0.1}

# assign_split() silently falls back to "train" when the ratios sum to less
# than 1, and never reaches trailing splits when they sum to more, so fail
# loudly here instead of producing a skewed dataset.
assert all(ratio >= 0 for ratio in SPLIT_RATIOS.values()), "SPLIT_RATIOS must be non-negative"
assert abs(sum(SPLIT_RATIOS.values()) - 1.0) < 1e-9, "SPLIT_RATIOS must sum to 1.0"

# Audio processing parameters
SAMPLE_RATE = 22050       # Target sampling rate (Hz) passed to librosa.load
SEGMENT_DURATION = 3      # Length of each segment in seconds
NUM_SEGMENTS = 10         # Segments cut from every audio file
N_MELS = 128              # n_mels passed to librosa.feature.melspectrogram
N_FFT = 2048              # n_fft passed to librosa.feature.melspectrogram
HOP_LENGTH = 512          # hop_length used for the spectrogram and its display

In [5]:
# Make sure the metadata folder (and any missing parents) exists before anything writes to it
Path(METADATA_DIR).mkdir(parents=True, exist_ok=True)

In [6]:
def create_output_dirs():
    """Create OUTPUT_DIR/<split>/<genre> for every split and every genre folder.

    Genres are the immediate subdirectories of INPUT_DIR; plain files found
    there are ignored. Splits are the keys of SPLIT_RATIOS. Directories that
    already exist are left untouched.
    """
    # List and filter the input folder once instead of once per split.
    genres = [
        entry for entry in os.listdir(INPUT_DIR)
        if os.path.isdir(os.path.join(INPUT_DIR, entry))
    ]

    for split in SPLIT_RATIOS:
        for genre in genres:
            os.makedirs(os.path.join(OUTPUT_DIR, split, genre), exist_ok=True)



In [7]:
def assign_split():
    """Draw one split name at random, weighted by SPLIT_RATIOS.

    A single uniform number in [0, 1) is compared against the running sum of
    the ratios, in the mapping's iteration order; the first split whose upper
    bound exceeds the draw is returned.
    """
    draw = random.random()
    upper_bound = 0.0
    for split_name in SPLIT_RATIOS:
        upper_bound += SPLIT_RATIOS[split_name]
        if draw < upper_bound:
            return split_name
    # Reached only when the running sum never exceeds the draw
    return "train"



In [8]:
def save_mel_spectrogram(mel_db, sr, output_filepath):
    """Render a Mel spectrogram to an image file with axes hidden.

    Args:
        mel_db: Spectrogram array handed to librosa.display.specshow
            (the caller passes a dB-scaled Mel spectrogram).
        sr: Sampling rate forwarded to specshow.
        output_filepath: Destination path handed to plt.savefig (the caller
            passes a .jpg path).

    Any exception from plotting or saving propagates to the caller; the
    figure is closed in every case.
    """
    fig = plt.figure(figsize=(3, 3))
    try:
        # Display the spectrogram
        librosa.display.specshow(mel_db, sr=sr, hop_length=HOP_LENGTH,
                                 x_axis='time', y_axis='mel', cmap='viridis')
        plt.axis('off')
        plt.tight_layout(pad=0)
        plt.savefig(output_filepath, bbox_inches='tight', pad_inches=0)
    finally:
        # Close this specific figure even when drawing/saving fails; the caller
        # catches exceptions and keeps going, so a leaked figure would otherwise
        # stay registered with pyplot for the rest of the run.
        plt.close(fig)



In [9]:
def process_audio_file(filepath, genre, metadata_dict):
    """Load one audio file, cut it into segments and save a Mel spectrogram per segment.

    Args:
        filepath: Path of the audio file to load.
        genre: Genre label; used as the output sub-folder and recorded in metadata.
        metadata_dict: Mutable tracking dict, updated in place. Uses the keys
            "files", "errors", "skipped_files", "genre_counts", "split_counts"
            and "genre_preview_counts".

    Files with fewer than NUM_SEGMENTS * SEGMENT_DURATION seconds of samples are
    recorded under "skipped_files" and not processed. Any exception is caught,
    recorded under "errors" and reported as a warning, so one bad file does not
    stop the run. Segments saved before the failure stay recorded.
    """
    try:
        # Load the full audio file (resampled to SAMPLE_RATE)
        y, sr = librosa.load(filepath, sr=SAMPLE_RATE)
        samples_per_segment = int(SAMPLE_RATE * SEGMENT_DURATION)

        # Ensure the audio file has the expected length
        if len(y) < samples_per_segment * NUM_SEGMENTS:
            warnings.warn(f"Warning: {filepath} is shorter than expected. Skipping.")
            metadata_dict["skipped_files"].append(filepath)
            return

        # Same for every segment of this file, so compute it once.
        base_filename = os.path.splitext(os.path.basename(filepath))[0]

        # Process each segment
        for i in range(NUM_SEGMENTS):
            start = i * samples_per_segment
            segment = y[start:start + samples_per_segment]

            # Compute the Mel spectrogram
            mel_spec = librosa.feature.melspectrogram(y=segment, sr=sr, n_fft=N_FFT,
                                                      hop_length=HOP_LENGTH, n_mels=N_MELS)
            # Convert to decibels for visualization (relative to the segment's own peak)
            mel_db = librosa.power_to_db(mel_spec, ref=np.max)

            # NOTE(review): the split is drawn per segment, so segments of one
            # recording can land in different splits. Confirm this is intended;
            # it lets audio from the same track appear in both train and test.
            split = assign_split()

            output_filename = f"{base_filename}_segment{i+1}.jpg"
            output_path = os.path.join(OUTPUT_DIR, split, genre, output_filename)

            # Save the spectrogram as a jpg file
            save_mel_spectrogram(mel_db, sr, output_path)

            # Record the segment only after its image was written, so that
            # "files", "split_counts" and "genre_counts" always agree.
            metadata_dict["files"].append({
                "original_file": filepath,
                "segment": i+1,
                "genre": genre,
                "split": split,
                "spectrogram_path": output_path
            })
            metadata_dict["split_counts"][split] += 1
            metadata_dict["genre_counts"][genre] += 1

            # For the first segment of the first few files per genre, save a preview with labels
            if i == 0 and metadata_dict["genre_preview_counts"][genre] < 3:
                save_preview_spectrogram(mel_db, sr, genre, base_filename)
                metadata_dict["genre_preview_counts"][genre] += 1

    except Exception as e:
        # Deliberate best-effort: log the failure and let the batch continue.
        metadata_dict["errors"].append(f"Error processing {filepath}: {str(e)}")
        warnings.warn(f"Error processing {filepath}: {e}")

In [10]:
def save_preview_spectrogram(mel_db, sr, genre, base_filename):
    """Save a labeled preview spectrogram (axes, colorbar, title) for visual inspection.

    The image is written to METADATA_DIR/previews/<genre>_<base_filename>_preview.jpg;
    the previews folder is created if needed.

    Args:
        mel_db: Spectrogram array handed to librosa.display.specshow.
        sr: Sampling rate forwarded to specshow.
        genre: Genre label shown in the title and used in the file name.
        base_filename: Source file name without extension, used in the file name.
    """
    preview_dir = os.path.join(METADATA_DIR, "previews")
    os.makedirs(preview_dir, exist_ok=True)

    fig = plt.figure(figsize=(5, 5))
    try:
        librosa.display.specshow(mel_db, sr=sr, hop_length=HOP_LENGTH,
                                 x_axis='time', y_axis='mel', cmap='viridis')
        plt.colorbar(format='%+2.0f dB')
        plt.title(f'Mel Spectrogram - {genre}')
        plt.tight_layout()
        plt.savefig(os.path.join(preview_dir, f"{genre}_{base_filename}_preview.jpg"))
    finally:
        # Always release the figure, including when plotting or saving raises.
        plt.close(fig)

In [11]:
def verify_split_distribution(metadata):
    """Print how closely the realised split proportions match SPLIT_RATIOS.

    Args:
        metadata: Tracking dict; only metadata["split_counts"] (split name ->
            number of spectrograms) is read.

    Prints the target ratios, the actual ratios, their absolute differences and
    a verdict against a 5% tolerance. Splits that received no spectrograms are
    treated as a ratio of 0. When nothing was generated at all, a warning is
    printed and the comparison is skipped.
    """
    split_counts = metadata["split_counts"]
    total = sum(split_counts.values())

    print("\nSplit Distribution Verification:")
    print(f"Target ratios: {SPLIT_RATIOS}")

    # Nothing to compare (and nothing to divide by) when no segment was produced.
    if total == 0:
        print("WARNING: No spectrograms were generated; cannot verify the split distribution.")
        return

    actual_ratios = {split: count / total for split, count in split_counts.items()}
    # A split that never got a segment has no entry in the counter; report it as 0.
    for split in SPLIT_RATIOS:
        actual_ratios.setdefault(split, 0.0)
    print(f"Actual ratios: {actual_ratios}")

    # Calculate absolute differences
    differences = {split: abs(target - actual_ratios[split])
                   for split, target in SPLIT_RATIOS.items()}
    print(f"Differences: {differences}")

    # Check if any difference is more than 5% (a reasonable threshold)
    if any(diff > 0.05 for diff in differences.values()):
        print("WARNING: Split distribution differs from target by more than 5%")
    else:
        print("Split distribution is close to target (within 5% tolerance)")

In [12]:
def generate_dataset_statistics(metadata):
    """Print a text report of what was generated.

    Reads "files", "genre_counts", "split_counts", "errors" and
    "skipped_files" from the tracking dict and prints: the total number of
    spectrograms, the genre list, per-genre and per-split counts, a
    genre-by-split table (only when at least one spectrogram exists), up to
    five error messages, and the number of skipped files.
    """
    records = metadata["files"]
    n_records = len(records)

    # Tabular view of the per-segment records, used for the cross-tabulation
    frame = pd.DataFrame(records)

    print("\n=== Dataset Statistics ===")
    print(f"Total spectrograms generated: {n_records}")
    print(f"Genres: {list(metadata['genre_counts'].keys())}")

    # Per genre counts, then per split counts: same layout for both sections
    sections = (
        ("\nSpectrograms per genre:", "genre_counts"),
        ("\nSpectrograms per split:", "split_counts"),
    )
    for heading, counter_key in sections:
        print(heading)
        for label, amount in metadata[counter_key].items():
            print(f"  {label}: {amount}")

    # Cross-tabulation of genres and splits
    if n_records > 0:
        print("\nDistribution across genres and splits:")
        print(pd.crosstab(frame['genre'], frame['split']))

    # Error summary
    failures = metadata["errors"]
    if not failures:
        print("\nNo errors occurred during processing.")
    else:
        print(f"\nTotal errors: {len(metadata['errors'])}")
        print("First 5 errors:")
        for message in failures[:5]:
            print(f"  {message}")

    skipped = metadata["skipped_files"]
    if not skipped:
        print("\nNo files were skipped.")
    else:
        print(f"\nTotal skipped files: {len(metadata['skipped_files'])}")

In [13]:
def save_metadata(metadata):
    """Persist the tracking dict to METADATA_DIR as JSON plus a CSV file listing.

    Args:
        metadata: Tracking dict with the keys "files", "genre_counts",
            "split_counts" and "genre_preview_counts" (other keys are written
            to the JSON unchanged).

    Writes:
        dataset_metadata.json: the full dict, with the counters converted to
            plain dicts.
        dataset_files.csv: one row per entry of metadata["files"]; only
            written when that list is non-empty.

    The JSON is written to a temporary file and then moved into place, so an
    interrupted run never leaves a truncated dataset_metadata.json behind for
    a later resume to trip over.
    """
    os.makedirs(METADATA_DIR, exist_ok=True)

    # Convert Counter objects to dictionaries for JSON serialization
    json_metadata = metadata.copy()
    for key in ("genre_counts", "split_counts", "genre_preview_counts"):
        json_metadata[key] = dict(metadata[key])

    json_path = os.path.join(METADATA_DIR, 'dataset_metadata.json')
    tmp_path = json_path + '.tmp'
    try:
        with open(tmp_path, 'w') as f:
            json.dump(json_metadata, f, indent=2)
        # Swap the finished file in; the previous JSON stays intact until now.
        os.replace(tmp_path, json_path)
    except BaseException:
        # Also covers a kernel interrupt mid-write: drop the partial temp file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    # Save file listing as CSV for easy loading in ML pipelines
    if metadata["files"]:
        df = pd.DataFrame(metadata["files"])
        df.to_csv(os.path.join(METADATA_DIR, 'dataset_files.csv'), index=False)

In [14]:
def visualize_dataset(metadata):
    """Save a 2x2 overview figure of the dataset to METADATA_DIR.

    Args:
        metadata: Tracking dict; only metadata["files"] (list of per-segment
            records with "genre", "split" and "segment" keys) is read.

    The panels show spectrograms per genre, the split share, a genre-by-split
    heatmap, and the count per segment position. Prints a notice and returns
    without writing anything when there are no records.
    """
    if not metadata["files"]:
        print("No files to visualize.")
        return

    # Create a DataFrame from the files metadata
    df = pd.DataFrame(metadata["files"])
    output_path = os.path.join(METADATA_DIR, 'dataset_visualization.png')

    # Set up the figure
    fig = plt.figure(figsize=(15, 10))
    try:
        # Plot 1: Genre distribution
        plt.subplot(2, 2, 1)
        genre_counts = df['genre'].value_counts()
        genre_counts.plot(kind='bar')
        plt.title('Number of Spectrograms per Genre')
        plt.ylabel('Count')
        plt.xticks(rotation=45)

        # Plot 2: Split distribution
        plt.subplot(2, 2, 2)
        split_counts = df['split'].value_counts()
        split_counts.plot(kind='pie', autopct='%1.1f%%')
        plt.title('Distribution of Train/Validation/Test Splits')

        # Plot 3: Heatmap of genre vs split
        plt.subplot(2, 2, 3)
        cross_tab = pd.crosstab(df['genre'], df['split'])
        plt.imshow(cross_tab, cmap='viridis')
        plt.colorbar(label='Count')
        plt.xticks(range(len(cross_tab.columns)), cross_tab.columns, rotation=45)
        plt.yticks(range(len(cross_tab.index)), cross_tab.index)
        plt.title('Genre vs Split Distribution')

        # Plot 4: Segment distribution (should be uniform)
        plt.subplot(2, 2, 4)
        segment_counts = df['segment'].value_counts().sort_index()
        segment_counts.plot(kind='bar')
        plt.title('Number of Spectrograms per Segment Position')
        plt.xlabel('Segment Position')
        plt.ylabel('Count')

        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Release the figure even if one of the plotting steps raises.
        plt.close(fig)

    print(f"Dataset visualization saved to {output_path}")

In [15]:
def clean_output_directory():
    """Remove OUTPUT_DIR and everything in it, then recreate it empty.

    A failure while deleting is reported on stdout and otherwise ignored, so
    the existing contents are kept in that case. The directory is (re)created
    at the end either way.
    """
    already_there = os.path.exists(OUTPUT_DIR)
    if already_there:
        print(f"Cleaning output directory: {OUTPUT_DIR}")
        try:
            shutil.rmtree(OUTPUT_DIR)
        except Exception as e:
            print(f"Error cleaning output directory: {e}")
            print("Continuing with existing directory.")
        else:
            print("Output directory cleaned successfully.")

    # Recreate the main output directory
    os.makedirs(OUTPUT_DIR, exist_ok=True)

In [16]:
def resume_processing(metadata_path=None):
    """Collect the audio files a previous run already handled.

    Args:
        metadata_path: Path to a previously saved metadata JSON, or None.

    Returns:
        A set with the "original_file" value of every entry under "files" in
        that JSON. The set is empty when no path is given or the file does not
        exist. If reading or parsing fails, the problem is printed and whatever
        was collected up to that point is returned.
    """
    already_done = set()

    # Nothing to resume from
    if not metadata_path or not os.path.exists(metadata_path):
        return already_done

    try:
        with open(metadata_path, 'r') as handle:
            previous_run = json.load(handle)

        # Extract already processed files
        for entry in previous_run["files"]:
            already_done.add(entry["original_file"])
    except Exception as e:
        print(f"Error loading previous metadata: {e}")
        print("Starting from scratch.")
    else:
        print(f"Resuming from previous run. {len(already_done)} files already processed.")

    return already_done


In [17]:
# Run-time switches for the processing cells below; edit them before running
CLEAN_OUTPUT_DIR = False        # True: delete OUTPUT_DIR and recreate it empty before processing
RESUME_PROCESSING = False       # True: skip audio files already listed in the saved metadata JSON
SAVE_METADATA_INTERVAL = 10     # Interval the main loop uses to decide when to checkpoint metadata



In [18]:
# Initialize metadata and directories
# Initialize metadata tracking
metadata = {
    "files": [],                        # one record per saved spectrogram
    "errors": [],                       # messages for files that raised during processing
    "skipped_files": [],                # files shorter than the required length
    "genre_counts": Counter(),          # genre -> number of spectrograms
    "split_counts": Counter(),          # split -> number of spectrograms
    "genre_preview_counts": Counter()   # genre -> number of preview images written
}

# Clean output directory if requested
if CLEAN_OUTPUT_DIR:
    clean_output_directory()

# Set up for resuming if requested
processed_files = set()
if RESUME_PROCESSING:
    if CLEAN_OUTPUT_DIR:
        warnings.warn(
            "CLEAN_OUTPUT_DIR and RESUME_PROCESSING are both enabled: images of "
            "already processed files were just deleted and will not be regenerated."
        )

    metadata_path = os.path.join(METADATA_DIR, 'dataset_metadata.json')
    processed_files = resume_processing(metadata_path)

    # Carry the previous run's records forward. Without this the next
    # save_metadata() call would overwrite the JSON with only the newly
    # processed files and the earlier entries would be lost.
    if processed_files:
        try:
            with open(metadata_path, 'r') as f:
                saved_metadata = json.load(f)
            restored = {"files": list(saved_metadata.get("files", []))}
            for key in ("genre_counts", "split_counts", "genre_preview_counts"):
                restored[key] = Counter(saved_metadata.get(key, {}))
            # "errors" and "skipped_files" start fresh: files that are not in
            # processed_files are attempted again and would be logged twice.
            metadata.update(restored)
        except (OSError, ValueError, TypeError, AttributeError) as e:
            print(f"Could not restore previous metadata: {e}")

# Create the output directory structure
create_output_dirs()

# Set random seed for reproducibility
random.seed(42)

In [None]:
# Main processing loop
# Process all genres and audio files
AUDIO_EXTENSIONS = (".wav", ".mp3", ".au", ".ogg", ".flac")

# Number of files handled since the last metadata checkpoint.
files_since_save = 0

# os.listdir() returns entries in arbitrary order; sorting makes the sequence of
# random split draws (seeded earlier) identical from one machine/run to the next.
for genre in sorted(os.listdir(INPUT_DIR)):
    genre_dir = os.path.join(INPUT_DIR, genre)
    if not os.path.isdir(genre_dir):
        continue  # skip non-directory files
    print(f'Genre: {genre}')

    # Get list of audio files
    audio_files = sorted(
        f for f in os.listdir(genre_dir)
        if f.lower().endswith(AUDIO_EXTENSIONS)
    )

    # Process each audio file in the genre directory
    for filename in tqdm(audio_files, desc=f"Processing {genre}", leave=True):
        filepath = os.path.join(genre_dir, filename)

        # Skip already processed files if resuming
        if filepath in processed_files:
            print(f"Skipping already processed file: {filepath}")
            continue

        process_audio_file(filepath, genre, metadata)
        files_since_save += 1

        # Save metadata every SAVE_METADATA_INTERVAL files. Counting files
        # rather than segment records matters: each file adds NUM_SEGMENTS
        # records, so a check on len(metadata["files"]) fired after every
        # file and rewrote the whole JSON/CSV each time.
        if files_since_save >= SAVE_METADATA_INTERVAL and metadata["files"]:
            save_metadata(metadata)
            files_since_save = 0

Genre: blues


Processing blues: 100%|██████████| 100/100 [01:00<00:00,  1.65it/s]


Genre: classical


Processing classical: 100%|██████████| 100/100 [01:04<00:00,  1.55it/s]


Genre: country


Processing country: 100%|██████████| 100/100 [01:23<00:00,  1.19it/s]


Genre: disco


Processing disco:  73%|███████▎  | 73/100 [02:02<01:41,  3.76s/it]

In [None]:
# Final steps - analysis and reporting
# Persist the metadata first so that a failure in the reporting steps below
# cannot lose records gathered since the last periodic save.
save_metadata(metadata)

# Verify and analyze the final dataset
verify_split_distribution(metadata)
generate_dataset_statistics(metadata)
visualize_dataset(metadata)

print("\nProcessing complete!")
print(f"Generated {len(metadata['files'])} spectrogram images.")
print(f"Metadata saved to {METADATA_DIR}")