In [1]:
import numpy as np
import scipy.signal as sn
import matplotlib.pyplot as plt
from scipy.fft import fft, fftfreq, ifft
from librosa.display import specshow

In [2]:
def WaveletFunc(smpf, t_epoch, freq_coefs, WavType = 'morlet'):
    """
    Generates a family of complex Morlet wavelets, one per center frequency.

    Each wavelet is a Gaussian envelope multiplied by a complex exponential.
    The Gaussian width is s = cycles / (2*pi*f), where the number of cycles
    depends on the center frequency f: 3 cycles up to 10 Hz, one more for
    every further 10 Hz band up to 70 Hz (9 cycles), and 10 cycles above 70 Hz.
    ---------------
    Parameters:
        smpf (float): 
            Sampling frequency for the wavelets.
        t_epoch (float): 
            Duration of the wavelets (in seconds). Every wavelet has
            int(round(t_epoch * smpf)) samples.
        freq_coefs (array_like, float): 
            Center frequencies for the wavelets (expected to be positive;
            they are used as divisors).
        WavType (str): 
            The type of wavelet to generate (default='morlet'). Only
            'morlet' is implemented.

    Returns:
        wavelets (2D-array, complex): 
            A 2D array with N wavelets (one row per frequency in freq_coefs).
        freq_coefs (array of float): 
            The center frequencies, as an array (the input object itself
            when it already is a NumPy array).

    Raises:
        NotImplementedError: 
            If WavType is 'mexhat' (declared but not implemented).
        ValueError: 
            If WavType is not a known wavelet type.
    """

    if WavType == "mexhat":
        raise NotImplementedError("WavType 'mexhat' is not implemented yet")
    if WavType != 'morlet':
        raise ValueError("Unknown WavType: " + repr(WavType))

    # Accept any array_like (lists, tuples, ...), as documented.
    freq_coefs = np.asarray(freq_coefs)

    # np.linspace needs an integer sample count; smpf and t_epoch may be floats.
    n_samples = int(round(t_epoch * smpf))
    t = np.linspace(-1*t_epoch, 1*t_epoch, n_samples)  # Time array

    # Upper edges (inclusive) of the 10 Hz bands that select the cycle count.
    cycle_edges = np.arange(10, 71, 10)

    wavelets = []
    for f in freq_coefs:
        # side='left' keeps the edges inclusive: f <= 10 -> 3, 10 < f <= 20 -> 4,
        # ..., 60 < f <= 70 -> 9, anything above 70 -> 10.
        cycles = 3 + int(np.searchsorted(cycle_edges, f, side='left'))

        s = cycles / (2 * np.pi * f)          # Gaussian width
        A = 1/np.sqrt(s*np.sqrt(np.pi))       # amplitude normalization

        # t spans 2*t_epoch seconds in n_samples points, so it advances by about
        # 2/smpf per sample; the phase pi*f*t therefore advances by roughly
        # 2*pi*f/smpf per sample, i.e. f Hz when the wavelet is read at smpf.
        # NOTE(review): the same stretch applies to the Gaussian envelope, which
        # makes it about half as wide (in samples) as s suggests - confirm intended.
        wavelets.append(A*np.exp(-(t**2) / (2*s**2))*np.exp(1j*np.pi*f*t))

    return np.array(wavelets), freq_coefs

In [3]:
def CWT(signal, sf, BW, Nwvlts, plot = False, show_wvlt = False):
    """
    Performs a wavelet transform (CWT) of a 1-D signal with a family of
    complex Morlet wavelets, using FFT-based (circular) convolution.

    The signal is first truncated to a whole number of seconds. Each wavelet
    (built by WaveletFunc with the same length as the truncated signal) is
    multiplied with the signal in the frequency domain, transformed back, and
    the magnitude of the result is stored as one row of the output.

    This implementation is based on ideas presented in Chapter 12 of the book "Analyzing Neural Time Series Data" by Mike X. Cohen, 2014.
    ---------------
    Parameters:
        signal (array_like, float): 
            1-D array representing the input signal. Must hold at least
            one full second of data (len(signal) >= sf).
        sf (int): 
            Sampling frequency of the signal in Hz.
        BW (float, float): 
            Frequency band range (f_min, f_max) where the CWT is performed.
            Center frequencies are spaced logarithmically between the two,
            so both must be positive.
        Nwvlts (int): 
            Number of wavelets to use for the transform.
        plot (bool, optional): 
            If True, draws a spectrogram of the CWT on figure 1 (default=False).
        show_wvlt (bool, optional): 
            If True, plots the family of wavelets used, their spectra and
            the filtered signals (default=False).

    Returns:
        M_freq (2-D array, float): 
            Nwvlts x (sf * whole seconds in signal) array with the magnitude
            of the CWT.
        time (array_like, float): 
            Array representing the time values of the (truncated) signal.
        freqs (array_like, float): 
            Array of center frequencies used in the CWT.

    Raises:
        ValueError: 
            If sf is not positive, Nwvlts is smaller than 1, or the signal
            is shorter than one second.
    """

    sf = int(sf)
    if sf <= 0:
        raise ValueError("sf must be a positive sampling frequency")
    if Nwvlts < 1:
        raise ValueError("Nwvlts must be at least 1")

    signal = np.asarray(signal)
    duration = len(signal) // sf  # whole seconds available
    if duration == 0:
        raise ValueError("signal must contain at least one full second of samples (len(signal) >= sf)")

    # Work on the truncated length everywhere so the FFT axis, the wavelets
    # and the time vector all agree.
    N = sf * duration
    signal = signal[:N]
    time = np.linspace(0, duration, N)

    f_min = BW[0]
    f_max = BW[1]
    freq_coefs = np.logspace(np.log10(f_min), np.log10(f_max), Nwvlts)

    wvs, freqs = WaveletFunc(sf, duration, freq_coefs)

    yf_sg = fft(signal)

    if show_wvlt:
        # Number of FFT bins shown in the wavelet-spectrum panel. For sf >= 400
        # this is the range up to about 200 Hz; the max() guard keeps sampling
        # rates below 200 Hz from dividing by zero.
        n_show = N // max(1, sf // 200)
        xf_wv = fftfreq(N, 1/sf)[:n_show]

        # NOTE(review): this bare figure is left empty by the subplots call
        # below. It is kept because, in a fresh session, it becomes figure 1,
        # which the spectrogram branch (plot=True) reuses - change both together.
        plt.figure()
        fig, (ax, bx, cx) = plt.subplots(3, 1, figsize = (15, 15))
        cx.plot(time, signal, 'black', linewidth = 0.5)

    M_freq = []
    for wv, f_center in zip(wvs, freqs):
        yf_wv = fft(wv)

        # Product of spectra == circular convolution in the time domain.
        s_ifft = ifft(yf_sg*yf_wv)

        if show_wvlt:
            # Only the real part of the complex wavelet can be drawn.
            ax.plot(wv.real, label = str(np.round(f_center, 2)) + ' Hz WVLT')
            bx.plot(xf_wv, 2.0/N * np.abs(yf_wv[:n_show]))
            cx.plot(time, s_ifft.real)

        M_freq.append(np.abs(s_ifft))

    if show_wvlt:
        ax.legend(fontsize = 8)

    M_freq = np.array(M_freq)

    # The wavelets are centered in the middle of their N-sample window, so the
    # circular convolution output is delayed by N//2 samples; rotating every
    # row by half its length re-aligns it with the signal's time axis.
    M_freq = np.roll(M_freq, -(M_freq.shape[1]//2), axis = 1)

    if plot:
        plt.figure(1, figsize=(15, 8))
        spec_ax = plt.axes([0.1, 0.65, 0.8, 0.2])
        specshow(M_freq, x_axis='time', y_axis='linear', x_coords=time, y_coords=freqs, shading='auto', cmap="jet")
        spec_ax.set_xlabel("Time (s)")
        spec_ax.set_ylabel("Frequency (Hz)")

        plt.colorbar()
        # Clip the color scale to the 5th-95th percentile range of the data.
        clim = np.percentile(M_freq, [5, 95])
        plt.clim(clim)

    return M_freq, time, freqs

In [4]:
def CWT_long(signal, sf, window, BW, Nwvlts, plot=True, show_wvlt=False):
    """
    Performs a windowed wavelet transform (CWT) on a long signal.

    The signal is cut into consecutive, non-overlapping windows of
    int(window * sf) samples. CWT is run on every complete window and the
    result is averaged over time, giving one spectrum per window. A trailing
    partial window is ignored.

    This implementation is based on ideas presented in Chapter 12 of the book "Analyzing Neural Time Series Data" by Mike X. Cohen, 2014.
    ---------------
    Parameters:
        signal (array_like, float): 
            1-D array representing the input signal.
        sf (int): 
            Sampling frequency of the signal in Hz.
        window (float): 
            Window length in seconds. CWT works on whole seconds of data,
            so windows shorter than one second are not supported.
        BW (float, float): 
            Frequency band range where the CWT is performed.
        Nwvlts (int): 
            Number of wavelets to use for the transform.
        plot (bool, optional): 
            Forwarded to CWT for every window; if True a spectrogram is
            drawn for each window (default=True).
        show_wvlt (bool, optional): 
            Forwarded to CWT for every window; if True, plots the family of
            wavelets used (default=False).

    Returns:
        M_freqs (2-D array, float): 
            n_windows x Nwvlts array; row i is the time-averaged CWT
            magnitude of window i.
        windows (array of int): 
            Window indices 0 .. n_windows-1.
        freqs (array_like, float): 
            Array of center frequencies used in the CWT (None when the
            signal is shorter than one window).

    Raises:
        ValueError: 
            If window * sf is smaller than one sample.
    """

    # Use one integer window size for slicing, stepping and counting.
    samples_per_window = int(window * sf)
    if samples_per_window <= 0:
        raise ValueError("window * sf must be at least one sample")

    n_windows = len(signal) // samples_per_window  # complete windows only

    M_freqs = []
    freqs = None
    for w in range(n_windows):
        start = w * samples_per_window
        signal_slice = signal[start:start + samples_per_window]
        # CWT returns (M_freq, time, freqs); the per-window time axis is not needed.
        M_freq, _, freqs = CWT(signal_slice, sf=sf, BW=BW, Nwvlts=Nwvlts, plot=plot, show_wvlt=show_wvlt)
        M_freqs.append(np.mean(M_freq, axis=1))

    return np.array(M_freqs), np.arange(0, n_windows, 1), freqs

In [5]:
def z_norm(M_freq):
    """
    Performs a Z-transform normalization on a CWT-matrix.

    Every row (one wavelet / frequency) is normalized independently: its mean
    over time is subtracted and the result is divided by its standard
    deviation over time (population standard deviation, ddof=0), so each row
    of the output has zero mean and unit standard deviation.

    This implementation is based on ideas presented in Chapter 18 of the book "Analyzing Neural Time Series Data" by Mike X. Cohen, 2014.
    ---------------
    Parameters:
        M_freq (2-D array of float): 
            Matrix containing the CWT results (N_wvlts x time).

    Returns:
        M_freq (2-D array of float): 
            Z-transform normalized CWT results, same shape as the input.
            Rows that are constant over time have zero standard deviation
            and yield NaN values.
    """

    M_freq = np.asarray(M_freq)

    # keepdims=True keeps the statistics as (N_wvlts, 1) columns so they
    # broadcast along the time axis, one mean/std per row.
    row_mean = np.mean(M_freq, axis = 1, keepdims = True)
    row_std = np.std(M_freq, axis = 1, keepdims = True)

    return (M_freq - row_mean) / row_std