In [1]:
import os 
import torch
import numpy as np 
from scipy.io import wavfile
from scipy.interpolate import CubicSpline
from tqdm import tqdm

In [2]:
# import os 
# import torch
# import numpy as np 
# from scipy.io import wavfile
# from scipy.interpolate import CubicSpline
# from scipy.signal import stft
# import matplotlib.pyplot as plt
# from tqdm import tqdm

In [4]:
# Count number of zero crossings

def count_zero_crossings(signal:torch.Tensor):
    """Count the sign changes between consecutive non-zero samples of ``signal``.

    Samples that are exactly zero carry no sign and are skipped, so a pass
    through zero such as ``[1, 0, -1]`` counts as one crossing and a touch
    such as ``[1, 0, 1]`` counts as none. Differencing ``torch.sign``
    directly would report two events in both of those cases.

    Args:
        signal: 1-D tensor of samples.

    Returns:
        A 0-dim integer tensor holding the number of sign changes.
    """
    signs = torch.sign(signal)
    # Drop exact zeros so each crossing is compared across its non-zero neighbours
    signs = signs[signs != 0]
    return torch.sum(signs[1:] != signs[:-1])

In [5]:
# Find the IMF for a residual

def sift(residual:torch.Tensor, max_iter:int=128, tol:float=1e-4, device:torch.device='cpu'):
    """Extract one intrinsic mode function (IMF) candidate from ``residual`` by sifting.

    Each iteration locates the strict interior local maxima and minima of the
    current candidate, fits a cubic spline through each set to form the upper
    and lower envelopes, and subtracts the mean of the two envelopes. Sifting
    stops when the zero-crossing count of the new candidate and the extrema
    count differ by at most one, when fewer than two maxima or fewer than two
    minima remain, or after ``max_iter`` iterations.

    Args:
        residual: 1-D tensor to sift. It is not modified.
        max_iter: Maximum number of sifting iterations.
        tol: Currently unused; kept so existing callers keep working.
        device: Device on which the mean envelope and every candidate produced
            by a sifting step are placed.

    Returns:
        The IMF candidate. Floating-point inputs keep their dtype; other
        inputs yield float64. If ``residual`` has fewer than two maxima or
        fewer than two minima on entry, ``residual`` itself (the same tensor
        object) is returned.
    """

    h = residual

    # SciPy evaluates the splines in float64; remember the dtype to cast back
    # to so a float32 signal does not silently become a float64 IMF.
    out_dtype = residual.dtype if residual.is_floating_point() else torch.float64

    # The sample positions are identical on every iteration
    sample_idx = np.arange(len(residual))

    for _ in range(max_iter):

        maxima = torch.where((h[:-2] < h[1:-1]) & (h[1:-1] > h[2:]))[0] + 1
        minima = torch.where((h[:-2] > h[1:-1]) & (h[1:-1] < h[2:]))[0] + 1

        # A cubic spline needs at least two knots per envelope
        if len(maxima) < 2 or len(minima) < 2:
            break

        h_np = h.cpu().numpy()
        maxima_np = maxima.cpu().numpy()
        minima_np = minima.cpu().numpy()

        upper_env = CubicSpline(maxima_np, h_np[maxima_np])(sample_idx)
        lower_env = CubicSpline(minima_np, h_np[minima_np])(sample_idx)

        mean_env = torch.from_numpy((upper_env + lower_env) / 2).to(device=device, dtype=out_dtype)
        new_h = h.to(device) - mean_env

        zero_crossings = count_zero_crossings(new_h)
        # NOTE(review): the extrema are counted on the candidate *before* the
        # mean envelope was subtracted, while the zero crossings are counted
        # after it — confirm this is intended.
        extrema_count = len(maxima) + len(minima)

        if abs(zero_crossings - extrema_count) <= 1:
            return new_h

        h = new_h

    return h

In [6]:
# Perform EMD

def EmpricalModeDecomposition(signal:torch.Tensor, max_imfs:int=10, device:torch.device='cpu'):
    """Decompose ``signal`` into IMFs by repeatedly sifting the running residual.

    Args:
        signal: Tensor to decompose. It is left untouched; the decomposition
            works on a copy. Non-floating-point input is converted to the
            default float dtype first.
        max_imfs: Maximum number of IMFs to extract.
        device: Passed through to ``sift``.

    Returns:
        ``(imfs, residual)``: the list of extracted IMFs, in extraction order,
        and what remains of the signal after subtracting all of them. The loop
        stops early once every residual sample is below 1e-6 in magnitude.
    """

    # Copy so the caller's tensor is never modified by the decomposition
    work_dtype = signal.dtype if signal.is_floating_point() else torch.get_default_dtype()
    residual = signal.to(dtype=work_dtype, copy=True)

    imfs = []

    for _ in range(max_imfs):

        imf = sift(residual, device=device)
        imfs.append(imf)

        # Out-of-place on purpose: sift may hand back `residual` itself, and an
        # in-place subtraction would then zero the IMF that was just stored.
        # The cast keeps the residual in its working dtype even when the IMF
        # comes back in a wider one.
        residual = (residual - imf).to(work_dtype)

        if torch.all(torch.abs(residual) < 1e-6):
            break

    return imfs, residual

In [7]:
# Save the audio file

def save_audio(data:torch.Tensor, file_path:str, sample_rate:int, device:torch.device='cpu'):
    """Peak-normalise ``data`` to the int16 range and write it as a WAV file.

    Non-finite samples (NaN, +/-inf) are replaced by 0 before normalisation.
    A signal whose peak is 0 is written as silence instead of being divided
    by zero, so a file is always produced.

    Args:
        data: Tensor of samples; it may live on any device.
        file_path: Destination path of the WAV file.
        sample_rate: Sample rate passed to ``wavfile.write``.
        device: Unused; kept for backward compatibility. The tensor is always
            brought to host memory, whatever device it lives on.
    """

    # .cpu() is a no-op for CPU tensors, so `device` does not need inspecting
    samples = data.detach().cpu().numpy()

    # Non-finite samples would poison both the peak and the int16 cast
    samples = np.nan_to_num(samples, nan=0.0, posinf=0.0, neginf=0.0)

    peak = np.max(np.abs(samples)) if samples.size else 0
    if peak > 0:
        samples = (samples / peak) * 32767

    wavfile.write(file_path, sample_rate, samples.astype(np.int16))

In [8]:
# # Conversion to mel spectrogram

# def hz_to_mel(hz):
#     return 2595 * np.log10(1 + hz / 700.0)

# def mel_to_hz(mel):
#     return 700 * (10**(mel / 2595) - 1)

# def save_mel_spectrogram(signal:torch.tensor, sample_rate:int, save_path:str, device:torch.device='cpu'):

#     n_mels = 512
#     n_fft = 1024
#     hop_length = 512

#     mel_bins = np.linspace(hz_to_mel(0), hz_to_mel(sample_rate / 2), n_mels + 2)
#     hz_bins = mel_to_hz(mel_bins)
#     bin_idx = np.floor(hz_bins / (sample_rate / n_fft)).astype(int)
    
#     # Create the filter bank
#     filter_bank = np.zeros((n_mels, n_fft // 2 + 1))
#     for m in range(1, n_mels + 1):
#         filter_bank[m - 1, bin_idx[m - 1]:bin_idx[m]] = np.linspace(0, 1, bin_idx[m] - bin_idx[m - 1])
#         filter_bank[m - 1, bin_idx[m]:bin_idx[m + 1]] = np.linspace(1, 0, bin_idx[m + 1] - bin_idx[m])
#     return filter_bank

# # Compute STFT
# f, t, Zxx = stft(signal, fs, nperseg=1024)
# power_spectrogram = np.abs(Zxx)**2

# # Convert to Mel scale
# n_mels = 128
# mel_filter = mel_filter_bank(n_mels, Zxx.shape[0] * 2, fs)
# mel_spectrogram = np.dot(mel_filter, power_spectrogram)

# # Plot the Mel spectrogram
# plt.figure(figsize=(10, 6))
# plt.imshow(10 * np.log10(mel_spectrogram + 1e-10), origin='lower', aspect='auto', cmap='magma', extent=[t[0], t[-1], 0, n_mels])
# plt.colorbar(label='Power (dB)')
# plt.title('Mel Spectrogram')
# plt.xlabel('Time (s)')
# plt.ylabel('Mel bands')
# plt.show()

In [9]:
# Process an audio file, given its path

def process_audio(audio_path:str, save_path:str, min_val:int=None, max_val:int=None, device:torch.device='cpu'):
    """Decompose one WAV file with EMD and write every component to ``save_path``.

    The directory receives one ``imf_<i>.wav`` per extracted IMF, plus
    ``residual.wav`` and ``reconstructed.wav`` (the sum of all IMFs and the
    residual).

    Args:
        audio_path: WAV file to read.
        save_path: Output directory; created if it does not exist.
        min_val: Lower bound for optional min-max scaling of the samples.
        max_val: Upper bound for optional min-max scaling of the samples.
            Scaling is applied only when both bounds are given.
        device: Device the samples are placed on and that is handed to the
            decomposition and saving helpers. Taken as a parameter so the
            function does not depend on a notebook-level ``device`` global.

    Raises:
        ValueError: If both bounds are given and are equal.
    """

    sample_rate, data = wavfile.read(audio_path)
    data = torch.tensor(data, dtype=torch.float32, device=device)

    # Compare against None: a bound of 0 is a legitimate value
    if min_val is not None and max_val is not None:
        if max_val == min_val:
            raise ValueError("min_val and max_val must differ")
        data = (data - min_val) / (max_val - min_val)

    imfs, residual = EmpricalModeDecomposition(signal=data, device=device)
    reconstructed = sum(imfs) + residual

    os.makedirs(save_path, exist_ok=True)

    for i, imf in enumerate(imfs):
        imf_path = os.path.join(save_path, f'imf_{i}.wav')
        save_audio(data=imf, file_path=imf_path, sample_rate=sample_rate, device=device)

    residual_path = os.path.join(save_path, 'residual.wav')
    save_audio(data=residual, file_path=residual_path, sample_rate=sample_rate, device=device)

    reconstructed_path = os.path.join(save_path, 'reconstructed.wav')
    save_audio(data=reconstructed, file_path=reconstructed_path, sample_rate=sample_rate, device=device)

In [10]:
# Trial run: decompose a single recording

audio_path = '../../Dataset/neurovoz_v3/data/audios/HC_A1_0034.wav'
save_path = '../../Dataset/EMD Audios/HC_A1_0034'

# Keyword arguments keep the input file and output directory from being swapped
process_audio(audio_path=audio_path, save_path=save_path)

In [11]:
# # Perform EMD for all the data

# inp_path = '../../Dataset/neurovoz_v3/data/audios'
# out_path = '../../Dataset/EMD Audios'
# files = list( os.listdir(inp_path))

# with tqdm(enumerate(files), total=len(files)) as t:
    
#     for i, file in t:
        
#         audio_path = os.path.join(inp_path, file)
#         save_path = os.path.join(out_path, file[:-4])
        
#         process_audio(audio_path=audio_path, save_path=save_path)