In [1]:
import os 
import torch
import numpy as np 
from scipy.io import wavfile
from scipy.interpolate import CubicSpline
from tqdm import tqdm


In [2]:
# Global compute device for the notebook: the first visible CUDA GPU if PyTorch reports one, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
def count_zero_crossings(signal: torch.Tensor):
    """
    Count sign changes along the last dimension of `signal`, summed over any leading dimensions.

    Exact zero samples are skipped before comparing signs. A crossing that passes through a zero
    sample (e.g. [1, 0, -1]) therefore counts once, and a signal that only touches zero
    (e.g. [1, 0, 1]) counts no crossing at all.

    Parameters:
        signal (torch.Tensor): Input signal; the last dimension is treated as time.

    Returns:
        torch.Tensor: 0-d int64 tensor on `signal`'s device holding the crossing count.
    """
    total = torch.zeros((), dtype=torch.int64, device=signal.device)
    if signal.numel() < 2:
        return total

    signs = torch.sign(signal)
    # Each row along the last dimension is handled on its own so crossings never span two rows.
    for row in signs.reshape(-1, signs.shape[-1]):
        nonzero_signs = row[row != 0]
        if nonzero_signs.numel() > 1:
            total = total + torch.count_nonzero(torch.diff(nonzero_signs))
    return total

In [4]:
def _local_extrema(x: np.ndarray):
    """Return (maxima, minima): indices of the strict interior local maxima and minima of the 1-D array x."""
    left, mid, right = x[:-2], x[1:-1], x[2:]
    maxima = np.flatnonzero((left < mid) & (mid > right)) + 1
    minima = np.flatnonzero((left > mid) & (mid < right)) + 1
    return maxima, minima


def _count_sign_changes(x: np.ndarray) -> int:
    """Count sign changes in x; exact zeros are skipped so a crossing through 0 counts once."""
    signs = np.sign(x)
    signs = signs[signs != 0]
    return int(np.count_nonzero(np.diff(signs)))


def sift(residual: torch.Tensor, max_iter: int = 128, tol: float = 1e-4, device='cpu'):
    """
    Extract one Intrinsic Mode Function (IMF) candidate from a 1-D signal by iterative sifting.

    Each iteration fits cubic-spline envelopes through the strict local maxima and minima of the
    current estimate `h` and subtracts their mean. Iteration stops as soon as one of these holds:
      * `h` has fewer than two maxima or two minima (nothing left to sift);
      * the new estimate satisfies the IMF condition, i.e. its number of zero crossings and its
        number of extrema differ by at most one;
      * the removed mean envelope carries less than `tol` of the energy of `h`;
      * `max_iter` iterations have run.

    Parameters:
        residual (torch.Tensor): 1-D signal to sift. It is never modified, and the result never
            shares storage with it.
        max_iter (int): Maximum number of sifting iterations.
        tol (float): Convergence threshold on sum(mean_env**2) / sum(h**2).
        device: Device the returned tensor is placed on.

    Returns:
        torch.Tensor: The IMF candidate on `device`. Its dtype matches `residual` when that is
        floating point, otherwise float64.
    """
    out_dtype = residual.dtype if residual.is_floating_point() else torch.float64
    # CubicSpline only runs on the CPU, so the whole loop stays in NumPy instead of moving every
    # iterate to `device` and back on each iteration.
    h = residual.detach().cpu().numpy().astype(np.float64)  # astype copies: `residual` is never aliased
    samples = np.arange(h.shape[0])

    for _ in range(max_iter):
        maxima, minima = _local_extrema(h)
        if len(maxima) < 2 or len(minima) < 2:
            break

        # NOTE(review): the splines are extrapolated past the first/last extremum with no end-point
        # treatment (e.g. mirroring), so the envelopes can swing strongly near the signal edges.
        upper_env = CubicSpline(maxima, h[maxima])(samples)
        lower_env = CubicSpline(minima, h[minima])(samples)
        mean_env = (upper_env + lower_env) / 2
        new_h = h - mean_env

        # The IMF condition is evaluated on the new estimate itself, not on the previous iterate.
        new_maxima, new_minima = _local_extrema(new_h)
        is_imf = abs(_count_sign_changes(new_h) - (len(new_maxima) + len(new_minima))) <= 1

        # Relative energy of the envelope mean that was just removed.
        energy = float(np.dot(h, h))
        converged = energy > 0 and float(np.dot(mean_env, mean_env)) / energy < tol

        h = new_h
        if is_imf or converged:
            break

    return torch.from_numpy(h).to(device=device, dtype=out_dtype)

In [5]:
def _emd_has_enough_extrema(x: torch.Tensor) -> bool:
    """Return True when the 1-D tensor x has at least two strict local maxima and two strict local minima."""
    left, mid, right = x[:-2], x[1:-1], x[2:]
    n_maxima = int(torch.count_nonzero((left < mid) & (mid > right)))
    n_minima = int(torch.count_nonzero((left > mid) & (mid < right)))
    return n_maxima >= 2 and n_minima >= 2


def EmpiricalModeDecomposition(signal: torch.Tensor, max_imfs: int = 16, device: torch.device = 'cpu'):
    """
    Decompose a 1-D signal into Intrinsic Mode Functions (IMFs) by repeatedly sifting the residual.

    Parameters:
        signal (torch.Tensor): 1-D input signal. It is not modified.
        max_imfs (int): Maximum number of IMFs to extract.
        device (torch.device): Forwarded to `sift`.

    Returns:
        list: The extracted IMFs (may be empty).
        torch.Tensor: What remains of `signal` after subtracting every IMF. If extraction stops
            because this remainder has fewer than two local maxima or minima, the remainder (the
            trend) is kept here instead of being sifted into the IMF list.
    """
    imfs = []
    # Work on a copy and subtract out of place. This keeps the caller's tensor intact, and an IMF
    # that shares storage with the residual is not zeroed together with it.
    residual = signal.clone()

    for _ in range(max_imfs):
        # The precondition `sift` needs; without it the residual is the trend and is left as is.
        if not _emd_has_enough_extrema(residual):
            break
        imf = sift(residual, device=device)
        imfs.append(imf)
        residual = residual - imf

        if torch.all(torch.abs(residual) < 1e-6):
            break

    return imfs, residual


In [15]:
def EnsembleEmpiricalModeDecomposition(signal: torch.Tensor, ensemble_size: int = 5, noise_std: float = 0.1, max_imfs: int = 16):
    """
    Perform Ensemble Empirical Mode Decomposition (EEMD) on the input signal.

    For each of `ensemble_size` trials, Gaussian white noise with standard deviation `noise_std`
    is added to the signal and the noisy copy is decomposed with `EmpiricalModeDecomposition`.
    The IMFs and the residuals are then averaged over all trials.

    Parameters:
        signal (torch.Tensor): The input signal. It is moved to the notebook-level `device` and is
            not modified.
        ensemble_size (int): Number of noisy trials; must be at least 1.
        noise_std (float): Standard deviation of the added noise, in the same units as `signal`.
        max_imfs (int): Maximum number of IMFs to extract per trial.

    Returns:
        list: `max_imfs` averaged IMFs. A trial that yields fewer IMFs contributes zeros to the
            missing slots.
        torch.Tensor: The residual averaged over all trials.

    Raises:
        ValueError: If `ensemble_size` is less than 1.
    """
    if ensemble_size < 1:
        raise ValueError(f"ensemble_size must be at least 1, got {ensemble_size}")

    # `device` is the notebook-level global defined in the setup cell.
    signal = signal.to(device)
    # NOTE(review): noise_std is absolute. For raw int16-range audio (|x| up to 32767) a std of 0.1
    # is negligible, so EEMD effectively reduces to plain EMD; consider scaling it by signal.std().
    # Noise comes from torch's global RNG: call torch.manual_seed(...) first for reproducible results.

    # Running sums instead of re-concatenating ever-growing stacks on every trial.
    imf_sums = [torch.zeros_like(signal) for _ in range(max_imfs)]
    residual_sum = torch.zeros_like(signal)

    for _ in range(ensemble_size):
        noisy_signal = signal + torch.randn_like(signal) * noise_std
        imfs, residual = EmpiricalModeDecomposition(noisy_signal, max_imfs=max_imfs, device=device)
        for idx, imf in enumerate(imfs):
            imf_sums[idx] = imf_sums[idx] + imf
        # Average the residual as well, rather than keeping only the last trial's noisy one.
        residual_sum = residual_sum + residual

    averaged_imfs = [imf_sum / ensemble_size for imf_sum in imf_sums]
    return averaged_imfs, residual_sum / ensemble_size


In [19]:
def save_audio(data: torch.Tensor, file_path: str, sample_rate: int, device: torch.device = 'cpu'):
    """
    Save a tensor as a 16-bit PCM WAV file, peak-normalised to full scale (32767).

    NaN and infinite samples are written as silence and excluded from the peak. An all-zero (or
    all non-finite) signal is written as zeros.

    Parameters:
        data (torch.Tensor): Audio signal tensor.
        file_path (str): Path to save the WAV file.
        sample_rate (int): Sampling rate of the audio.
        device (torch.device): Unused; kept for backward compatibility. The tensor is always
            detached and moved to the CPU before writing.
    """
    # detach()+cpu() covers autograd tensors and any accelerator, not only CUDA.
    samples = data.detach().cpu().numpy().astype(np.float64)

    # Zero out NaN/inf before computing the peak so one bad sample cannot silence the whole file.
    samples[~np.isfinite(samples)] = 0.0
    peak = np.max(np.abs(samples)) if samples.size else 0.0
    if peak > 0:
        samples = samples / peak * 32767

    # Round (rather than truncate) to the nearest int16 level; the clip guards the int16 range.
    pcm = np.clip(np.rint(samples), -32768, 32767).astype(np.int16)

    wavfile.write(file_path, sample_rate, pcm)


In [24]:
import matplotlib.pyplot as plt

def plot_imfs(imfs, residual, save_path, signal_name, sample_rate, dpi: int = 300):
    """
    Plot every IMF and the residual of a signal as stacked subplots and save the figure as a PNG.

    The figure is written to `<save_path>/<signal_name>_imfs_plot.png`; `save_path` is created if
    needed.

    Parameters:
        imfs (list of torch.Tensor): Intrinsic Mode Functions; may be empty.
        residual (torch.Tensor): Residual signal; its length defines the time axis.
        save_path (str): Directory to save the plot in.
        signal_name (str): Name of the signal being processed (used in the file name).
        sample_rate (int): Sample rate of the audio signal, used to convert samples to seconds.
        dpi (int): Resolution of the saved PNG.
    """
    residual_np = residual.detach().cpu().numpy()
    # Derive the time axis from the residual so an empty IMF list does not raise IndexError.
    time_axis = np.arange(residual_np.shape[0]) / sample_rate

    # (title, samples, colour) for each panel: IMFs in blue, then the residual in red.
    panels = [(f'IMF {idx + 1}', imf.detach().cpu().numpy(), 'blue') for idx, imf in enumerate(imfs)]
    panels.append(('Residual', residual_np, 'red'))

    fig_height = max(3 * len(panels), 10)  # Adjust height for clarity
    fig, axes = plt.subplots(len(panels), 1, figsize=(15, fig_height), squeeze=False)
    try:
        for ax, (title, values, color) in zip(axes[:, 0], panels):
            ax.plot(time_axis, values, color=color)
            ax.set_title(title)
            ax.set_ylabel('Amplitude')
            ax.set_xlabel('Time (s)')

        fig.tight_layout()
        os.makedirs(save_path, exist_ok=True)
        plot_file = os.path.join(save_path, f'{signal_name}_imfs_plot.png')
        fig.savefig(plot_file, dpi=dpi)
    finally:
        # Close even on failure so a long batch loop does not accumulate open figures.
        plt.close(fig)


In [27]:
def process_audio_eemd(audio_path: str, save_path: str, min_val: int = None, max_val: int = None, 
                       ensemble_size: int = 5, noise_std: float = 0.1, max_imfs: int = 16):
    """
    Process a single audio file: run EEMD on it, save each IMF and the residual as WAV, and plot them.

    Outputs written to `save_path`: `eemd_imf_<i>.wav` for each IMF, `residual.wav`, and
    `<name>_imfs_plot.png`, where `<name>` is the audio file name without its extension.

    Parameters:
        audio_path (str): Path to the input audio file.
        save_path (str): Directory to save the outputs in (created if needed).
        min_val (int, optional): Lower bound for min-max normalisation; used only if max_val is also given.
        max_val (int, optional): Upper bound for min-max normalisation; used only if min_val is also given.
        ensemble_size (int): Number of noisy trials for EEMD.
        noise_std (float): Standard deviation of the added noise.
        max_imfs (int): Maximum number of IMFs.

    Raises:
        ValueError: If min_val and max_val are both given and equal (normalisation would divide by zero).
    """
    # `is not None` rather than truthiness, so a bound of 0 still enables normalisation.
    normalize = min_val is not None and max_val is not None
    if normalize and max_val == min_val:
        raise ValueError(f"min_val and max_val must differ, got both equal to {min_val}")

    sample_rate, data = wavfile.read(audio_path)
    # NOTE(review): assumes mono audio; a multi-channel file gives a (n_samples, n_channels) array,
    # which the 1-D sifting code is not written for.
    # `device` is the notebook-level global defined in the setup cell.
    data = torch.as_tensor(data, dtype=torch.float32, device=device)

    if normalize:
        data = (data - min_val) / (max_val - min_val)

    # Extract IMFs and residual using EEMD
    imfs, residual = EnsembleEmpiricalModeDecomposition(signal=data,
                                                        ensemble_size=ensemble_size,
                                                        noise_std=noise_std,
                                                        max_imfs=max_imfs)

    os.makedirs(save_path, exist_ok=True)

    # Save IMFs as .wav files
    for i, imf in enumerate(imfs):
        imf_path = os.path.join(save_path, f'eemd_imf_{i}.wav')
        save_audio(data=imf, file_path=imf_path, sample_rate=sample_rate, device=device)

    # Save residual as a .wav file
    residual_path = os.path.join(save_path, 'residual.wav')
    save_audio(data=residual, file_path=residual_path, sample_rate=sample_rate, device=device)

    # splitext keeps interior dots in the name (split('.')[0] cut "a.b.wav" down to "a").
    signal_name = os.path.splitext(os.path.basename(audio_path))[0]
    plot_imfs(imfs=imfs, residual=residual, save_path=save_path, signal_name=signal_name, sample_rate=sample_rate)

inp_path = r"./neurovoz_v3/data/audios"
out_path = r"./Outputs/EEMD/IMFs"
# Only .wav entries (wavfile.read cannot parse anything else). Sorted so the processing order,
# and therefore how far an interrupted run got, is the same on every run and platform.
files = sorted(f for f in os.listdir(inp_path) if f.lower().endswith('.wav'))

In [28]:
# The batch run decomposes each file with 10 noisy trials; pass it explicitly rather than relying
# on process_audio_eemd's default.
ENSEMBLE_SIZE = 10
# Skip files that already have outputs, so an interrupted multi-day run can resume where it stopped.
SKIP_EXISTING = True

with tqdm(files, total=len(files)) as progress:
    for file in progress:
        audio_path = os.path.join(inp_path, file)
        # splitext instead of file[:-4], which assumed a 4-character extension.
        save_path = os.path.join(out_path, os.path.splitext(file)[0])
        # residual.wav is written after every IMF file, so its presence means the audio outputs are
        # complete. The plot is written after it and may be missing if a run stopped while plotting.
        if SKIP_EXISTING and os.path.exists(os.path.join(save_path, 'residual.wav')):
            continue
        process_audio_eemd(audio_path=audio_path, save_path=save_path, ensemble_size=ENSEMBLE_SIZE)

  3%|▎         | 96/2976 [3:42:12<111:06:07, 138.88s/it]


KeyboardInterrupt: 

In [21]:
# def process_audio_eemd(audio_path: str, save_path: str, min_val: int = None, max_val: int = None, 
#                        ensemble_size: int = 5, noise_std: float = 0.1, max_imfs: int = 16):
#     """
#     Process a single audio file, extract EEMD IMFs, save them as audio, and plot them.

#     Parameters:
#         audio_path (str): Path to the input audio file.
#         save_path (str): Path to save the outputs.
#         min_val (int, optional): Minimum value for normalization.
#         max_val (int, optional): Maximum value for normalization.
#         ensemble_size (int): Number of noisy signals for EEMD.
#         noise_std (float): Standard deviation of the added noise.
#         max_imfs (int): Maximum number of IMFs.
#     """
#     sample_rate, data = wavfile.read(audio_path)
#     data = torch.tensor(data).type(torch.float32).to(device)
    
#     if min_val and max_val:
#         data = (data - min_val) / (max_val - min_val)

#     # Extract IMFs and residual using EEMD
#     imfs, residual = EnsembleEmpiricalModeDecomposition(signal=data, 
#                                                         ensemble_size=ensemble_size, 
#                                                         noise_std=noise_std, 
#                                                         max_imfs=max_imfs)

#     os.makedirs(save_path, exist_ok=True)

#     # Save IMFs as .wav files
#     for i, imf in enumerate(imfs):
#         imf_path = os.path.join(save_path, f'eemd_imf_{i}.wav')
#         save_audio(data=imf, file_path=imf_path, sample_rate=sample_rate, device=device)

#     # Save residual as a .wav file
#     residual_path = os.path.join(save_path, 'residual.wav')
#     save_audio(data=residual, file_path=residual_path, sample_rate=sample_rate, device=device)

#     # Plot IMFs and residual
#     signal_name = os.path.basename(audio_path).split('.')[0]  # Extract name without extension
#     plot_imfs(imfs=imfs, residual=residual, save_path=save_path, signal_name=signal_name, sample_rate=sample_rate)

# # %% Perform EEMD for all data
# inp_path = r"./neurovoz_v3/data/audios"
# out_path = r"./Outputs/EEMD/IMFs"
# files = list(os.listdir(inp_path))


In [22]:
# with tqdm(enumerate(files), total=len(files)) as t:
#     for i, file in t:
#         audio_path = os.path.join(inp_path, file)
#         save_path = os.path.join(out_path, file[:-4])
#         process_audio_eemd(audio_path=audio_path, save_path=save_path)

  0%|          | 3/2976 [00:31<8:32:51, 10.35s/it]


KeyboardInterrupt: 