In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
from scipy.interpolate import CubicSpline
from multiprocessing import Pool, cpu_count
from tqdm import tqdm

In [2]:
# csv_file_path = r"../Data/audio_features.csv"
# audio_directory = r"../Data/audios"
# imf_directory = r"../IMFS/IMFS_EMD"
# residual_directory = r"../Residual/Residual_EMD"
# reconstructed_directory = r"../Reconstructed_Signal/RS_EMD"
# df = pd.read_csv(csv_file_path)

In [3]:
def count_zero_crossings(signal):
    """Count zero-crossings in a signal."""
    return np.sum(np.diff(np.sign(signal)) != 0)

In [4]:
def sift(residual, max_iter=128, tol=1e-4):
    """Extract one IMF with zero-crossing stopping criterion."""
    h = residual
    for i in range(max_iter):
        maxima = np.where((h[:-2] < h[1:-1]) & (h[1:-1] > h[2:]))[0] + 1
        minima = np.where((h[:-2] > h[1:-1]) & (h[1:-1] < h[2:]))[0] + 1

        if len(maxima) < 2 or len(minima) < 2:
            break

        upper_env = CubicSpline(maxima, h[maxima])(np.arange(len(h)))
        lower_env = CubicSpline(minima, h[minima])(np.arange(len(h)))
        mean_env = (upper_env + lower_env) / 2
        new_h = h - mean_env

        zero_crossings = count_zero_crossings(new_h)
        extrema_count = len(maxima) + len(minima)
        
        if abs(zero_crossings - extrema_count) <= 1:
            print(f"IMF found with {i + 1} iterations")
            return new_h
        h = new_h

    return h

In [5]:
def emd(signal, max_imfs=16):
    """Perform Empirical Mode Decomposition (EMD) on the signal."""
    imfs = []
    residual = signal
    for i in range(max_imfs):
        imf = sift(residual)
        imfs.append(imf)
        residual -= imf
        if np.all(np.abs(residual) < 1e-6):
            break
    return imfs, residual

In [6]:
def save_audio(file_path, sample_rate, data):
    """Save audio data to a file, normalized to int16."""
    data = (data / np.max(np.abs(data))) * 32767
    wavfile.write(file_path, sample_rate, data.astype(np.int16))

In [7]:
# For a single data point

audio_path = '../../Dataset/neurovoz_v3/data/audios/HC_A1_0034.wav'
sample_rate, data = wavfile.read(audio_path)
data = np.array(data)

data = data.astype(np.float32)
imfs, residual = emd(signal=data)

IMF found with 32 iterations
IMF found with 101 iterations
IMF found with 37 iterations
IMF found with 6 iterations
IMF found with 1 iterations
IMF found with 7 iterations
IMF found with 2 iterations
IMF found with 2 iterations
IMF found with 2 iterations


In [None]:
def process_audio(row):
    """Process a single audio file using EMD and save IMFs, residual, and reconstructed signal."""
    relative_path = row['AudioPath'].strip().replace('../data/audios/', '')
    file_path = os.path.join(audio_directory, relative_path)
    
    try:
        # Read and normalize audio
        sample_rate, data = wavfile.read(file_path)
        data = data[:, 0] if len(data.shape) == 2 else data
        data = data / np.max(np.abs(data))
        print(f"Processing file: {file_path}")

        # Apply EMD
        imfs, residual = emd(data, max_imfs=10)

        # Save each IMF and residual
        for i, imf in enumerate(imfs):
            output_file_path = os.path.join(imf_directory, f"{os.path.splitext(os.path.basename(file_path))[0]}_imf_{i+1}.wav")
            save_audio(output_file_path, sample_rate, imf)
            print(f"IMF {i+1} saved: {output_file_path}")

        # Save residual
        residual_path = os.path.join(residual_directory, f"{os.path.splitext(os.path.basename(file_path))[0]}_residual.wav")
        save_audio(residual_path, sample_rate, residual)
        print(f"Residual saved: {residual_path}")

        # Save reconstructed signal (sum of IMFs)
        reconstructed = np.sum(imfs, axis=0)
        reconstructed_path = os.path.join(reconstructed_directory, f"{os.path.splitext(os.path.basename(file_path))[0]}_reconstructed.wav")
        save_audio(reconstructed_path, sample_rate, reconstructed)
        print(f"Reconstructed signal saved: {reconstructed_path}")

    except FileNotFoundError as e:
        print(f"File not found: {file_path}. Error: {e}")
    except Exception as e:
        print(f"Error processing {file_path}: {e}")

In [None]:
print("Starting EMD processing for all audio files sequentially.")
for _, row in tqdm(df.iterrows(), total=len(df)):
    process_audio(row)
print("Processing completed.")

In [None]:
def plot_emd_results(file_base_name, sample_rate=16000):
    """
    Plots the IMFs, residual, and reconstructed signal for a specific audio file.
    Assumes files are saved in the specified directories.
    """
    # Load IMFs
    imf_files = sorted([f for f in os.listdir(imf_directory) if f.startswith(file_base_name) and '_imf_' in f])
    imfs = [wavfile.read(os.path.join(imf_directory, f))[1] for f in imf_files]

    # Load residual
    residual_path = os.path.join(residual_directory, f"{file_base_name}_residual.wav")
    _, residual = wavfile.read(residual_path)

    # Load reconstructed signal
    reconstructed_path = os.path.join(reconstructed_directory, f"{file_base_name}_reconstructed.wav")
    _, reconstructed = wavfile.read(reconstructed_path)

    # Plot IMFs, residual, and reconstructed signal
    num_imfs = len(imfs)
    plt.figure(figsize=(10, 2 * (num_imfs + 2)))  # Adjust height based on number of plots

    # Plot each IMF
    for i, imf in enumerate(imfs, 1):
        plt.subplot(num_imfs + 2, 1, i)
        plt.plot(imf)
        plt.title(f"IMF {i}")
        plt.xlabel("Sample")
        plt.ylabel("Amplitude")

    # Plot residual
    plt.subplot(num_imfs + 2, 1, num_imfs + 1)
    plt.plot(residual, color="orange")
    plt.title("Residual")
    plt.xlabel("Sample")
    plt.ylabel("Amplitude")

    # Plot reconstructed signal
    plt.subplot(num_imfs + 2, 1, num_imfs + 2)
    plt.plot(reconstructed, color="green")
    plt.title("Reconstructed Signal")
    plt.xlabel("Sample")
    plt.ylabel("Amplitude")

    plt.tight_layout()
    plt.show()

# Usage example:
file_base_name = "PD_A2_0047"  # Replace with your actual file's base name
plot_emd_results(file_base_name)
