In [None]:
import os

import matplotlib.pyplot as plt
import mne
import numpy as np
from scipy.signal import butter, filtfilt
from scipy.signal import find_peaks
from vmdpy import VMD

In [None]:
# Load a single epoch (index 40) of sub-01 / ses-01 and flatten it to a DataFrame.
# Raw string: the Windows path contains backslashes that must not be parsed as escapes.
EPOCHS_PATH = r"F:\Research\QUML\EEG\Data\derivatives\sub-01\ses-01\sub-01_ses-01_eeg-epo.fif"
sample = mne.read_epochs(EPOCHS_PATH)[40]
sample_df = sample.to_data_frame()

# Split the epoch timeline into four windows. Bounds are in the units of the `time`
# column (presumably seconds - confirm against the MNE version's to_data_frame default).
full_interval = sample_df.time.unique()
conc_interval = full_interval[(full_interval >= -0.5) & (full_interval <= 0)]
cue_interval = full_interval[(full_interval >= 0) & (full_interval <= 0.5)]
action_interval = full_interval[(full_interval >= 0.6) & (full_interval <= 2.9)]
relax_interval = full_interval[(full_interval >= 3) & (full_interval <= 4)]

# Rows of the action window; swap the interval here to analyse a different window.
action_data = sample_df[sample_df.time.isin(action_interval)]

In [None]:
# Channels of interest: A10-A15, C5-C7, D5-D7 and D20-D23 (prefix, numeric range) pairs.
_channel_groups = [("A", range(10, 16)), ("C", range(5, 8)), ("D", range(5, 8)), ("D", range(20, 24))]
channels = [f"{prefix}{num}" for prefix, nums in _channel_groups for num in nums]

In [None]:
# Keep only the channel columns of interest (all rows of the selected window).
action_selected = action_data.loc[:, channels]

# Checking the provided VMD pipeline

In [None]:
def signalBoundary(x = None,nfft = None,SignalLength = None,isInverse = None):
    '''
    Mirror-extend a signal and return its FFT, or undo that extension.

    If isInverse is falsy, the signal is extended by mirroring about half of
    its length on each side and the nfft-point FFT (along axis 0) of the
    extended signal is returned. If isInverse is truthy, x is treated as a
    spectrum: it is inverse-transformed and the mirrored parts are removed,
    which is the inverse of the mirror extension.

    Parameters
    ----------
    x : array_like
        Forward: time-domain signal with samples along axis 0; a 1-D input is
        treated as a single column. Inverse: spectrum with bins along axis 0.
    nfft : int
        FFT length, normally the mirrored signal length (2N for even N,
        2N - 1 for odd N).
    SignalLength : int
        Length N of the original (un-mirrored) signal.
    isInverse : bool or int
        0/False applies the mirror extension, 1/True removes it.

    Returns
    -------
    numpy.ndarray
        Forward: complex spectrum with nfft rows.
        Inverse: real signal made of rows HalfSignalLength up to
        MirroredSignalLength - HalfSignalLength of the inverse FFT.
    '''
    HalfSignalLength = int(np.floor(SignalLength / 2))
    # 2N samples for even N, 2N - 1 for odd N.
    MirroredSignalLength = int(SignalLength * 2 + (np.floor(SignalLength / 2) - np.ceil(SignalLength / 2)))
    if isInverse:
        xr = np.real(np.fft.ifft(x, nfft, axis=0))
        # Drop the leading and trailing mirrored parts.
        y = xr[HalfSignalLength:(MirroredSignalLength - HalfSignalLength)]
    else:
        x = np.asarray(x)
        if x.ndim == 1:
            # np.vstack would stack 1-D pieces as rows of different lengths and fail.
            x = x.reshape(-1, 1)
        head = x[HalfSignalLength:0:-1]  # x[H], ..., x[1]
        tail = x[SignalLength - 1:int(np.ceil(SignalLength / 2)) - 1:-1]  # x[N-1], ..., x[ceil(N/2)]
        xr = np.vstack([head, x, tail])
        y = np.fft.fft(xr, nfft, axis=0)
    return y
def initialCentralFreqByFindPeaks(x = None,f = None,nfft = None,numIMFs = None):
    '''
    Initialize the IMF central frequencies from the largest spectral peaks.

    Peaks of the magnitude spectrum are located with scipy's find_peaks; the
    DC bin is added as a candidate when the spectrum does not rise from bin 0
    to bin 1. Candidates are ranked by height and the numIMFs highest give the
    initial central frequencies. If there are fewer candidates than IMFs, the
    remaining slots keep values drawn as 0.5 * U[0, 1).

    Parameters
    ----------
    x : array_like
        Magnitude spectrum on the grid f (computeVMD passes an (M, 1) column).
    f : array_like
        Uniformly spaced normalized frequency grid starting at 0, same length as x.
    nfft : int
        FFT length used to produce x.
    numIMFs : int
        Number of central frequencies to return.

    Returns
    -------
    numpy.ndarray
        Shape (numIMFs, 1); peak-derived entries first, by descending peak height.

    Notes
    -----
    Reseeds NumPy's global RNG with 1337 so the random fill is reproducible.
    '''
    spectrum = np.ravel(x)
    freqs = np.ravel(f)

    BW = 2 / nfft
    # freqs[1] is the bin width of the uniform grid: express the 2*BW gap in samples.
    # Round instead of truncating so a float result such as 3.9999999 maps to 4;
    # find_peaks requires a distance of at least 1.
    minBWGapIndex = 2 * BW / freqs[1]
    minPeakDistance = max(1, int(round(minBWGapIndex)))

    # Clamp values below the mean up to the mean, so only peaks above the mean can be found.
    spectrum = np.where(spectrum < np.mean(spectrum), np.mean(spectrum), spectrum)
    peakIdx, _ = find_peaks(spectrum, distance=minPeakDistance)
    pks = spectrum[peakIdx]
    locs = freqs[peakIdx]

    # Check for DC component: bin 0 belongs to frequency freqs[0], not freqs[1].
    if spectrum[0] >= spectrum[1]:
        pks = np.concatenate(([spectrum[0]], pks))
        locs = np.concatenate(([freqs[0]], locs))

    order = np.argsort(-pks)  # descending peak height

    np.random.seed(1337)
    centralFreq = 0.5 * np.random.rand(numIMFs, 1)
    # Fill as many slots as there are candidates, highest peak first.
    numUsed = min(len(locs), numIMFs)
    centralFreq[:numUsed, 0] = locs[order[:numUsed]]

    return centralFreq

def computeVMD(x = None,numIMFs = None,MaxIterations = None,AbsoluteTolerance = None,TolFactor = None,InitializeMethod = None,penaltyFactor = None,LMUpdateRate = None,Display = None):
    '''
    Variational mode decomposition (VMD) of a real signal, solved with ADMM.

    The signal is mirror-extended (signalBoundary), moved to the frequency
    domain and split into numIMFs modes. Each sweep updates every mode with
    the filter 1 / (1 + penaltyFactor * (f - fc)**2), moves its central
    frequency fc to the power-weighted mean frequency of the mode, and then
    updates the Lagrange multiplier. Frequencies are normalized (cycles per
    sample, from 0 to 0.5).

    Parameters
    ----------
    x : array_like
        Signal, 1-D or an (N, 1) column.
    numIMFs : int
        Number of modes to extract.
    MaxIterations : int
        Maximum number of ADMM sweeps.
    AbsoluteTolerance : float
        Iteration stops once the absolute improvement is <= this value AND the
        relative improvement is <= AbsoluteTolerance * TolFactor.
    TolFactor : float
        Factor turning AbsoluteTolerance into the relative tolerance.
    InitializeMethod : {'peaks', 'grid', 'random'}
        How the initial central frequencies are chosen.
    penaltyFactor : float
        Bandwidth penalty in the mode-update filter.
    LMUpdateRate : float
        Step size of the Lagrange multiplier update.
    Display : bool
        Print a progress table when truthy.

    Returns
    -------
    IMFs : numpy.ndarray
        (N, numIMFs) time-domain modes, columns in descending central frequency.
    residual : numpy.ndarray
        (N, 1) signal minus the sum of the IMFs.
    info : object
        ExitFlag (1 if stopped before MaxIterations, else 0),
        CentralFrequencies ((numIMFs, 1), same column order as IMFs),
        NumIterations, AbsoluteImprovement, RelativeImprovement,
        LagrangeMultiplier.

    Raises
    ------
    ValueError
        If InitializeMethod is not 'peaks', 'grid' or 'random'.
    '''
    # Work on one column so the residual is (N, 1) rather than an (N, N) broadcast.
    x = np.asarray(x).reshape(-1, 1)
    SignalLength = len(x)
    # Mirrored length: 2N for even N, 2N - 1 for odd N.
    MirroredSignalLength = int(SignalLength * 2 + (np.floor(SignalLength / 2) - np.ceil(SignalLength / 2)))
    nfft = MirroredSignalLength

    relativeDiff = float('inf')
    absoluteDiff = relativeDiff
    RelativeTolerance = AbsoluteTolerance * TolFactor
    tau = LMUpdateRate
    eps = 2.2204e-16  # tiny offset so the initial IMF spectra are not exactly zero

    # Reduce edge effects by mirroring the signal; 0 = apply the mirror extension.
    sigFDFull = signalBoundary(x, nfft, SignalLength, 0)
    # Keep the non-negative-frequency half (including the Nyquist bin for even nfft).
    if nfft % 2 == 0:
        NumHalfFreqSamples = nfft // 2 + 1
    else:
        NumHalfFreqSamples = (nfft + 1) // 2
    sigFD = np.array(sigFDFull[0:NumHalfFreqSamples], dtype=np.complex128)

    # Initial IMFs are zero in time, i.e. eps everywhere in frequency.
    InitialIMFs = np.zeros((SignalLength, numIMFs))
    initIMFfdFull = np.real(np.fft.fft(InitialIMFs, nfft, axis=0))
    initIMFfd = initIMFfdFull[0:NumHalfFreqSamples] + eps
    IMFfd = np.array(initIMFfd, dtype=np.complex128)

    sumIMF = np.sum(IMFfd, axis=1).reshape(NumHalfFreqSamples, 1)
    LM = np.zeros((NumHalfFreqSamples, 1), dtype=np.complex128)

    # Frequency vector: [0, 0.5) for odd nfft and [0, 0.5] for even nfft.
    f = np.linspace(0, np.floor(nfft / 2) / nfft, NumHalfFreqSamples).reshape(NumHalfFreqSamples, 1)

    # Initial central frequencies.
    if InitializeMethod == 'peaks':
        centralFreq = initialCentralFreqByFindPeaks(np.abs(sigFD), f, nfft, numIMFs)
    elif InitializeMethod == 'grid':
        # Evenly spaced grid within [0, 0.5).
        centralFreq = (0.5 / numIMFs) * np.arange(numIMFs)
    elif InitializeMethod == 'random':
        # Drawn from U[0, 0.5).
        centralFreq = 0.5 * np.random.rand(numIMFs, 1)
    else:
        raise ValueError(
            "InitializeMethod must be 'peaks', 'grid' or 'random', got %r" % (InitializeMethod,))
    # Every initializer ends up as a float (numIMFs, 1) column.
    centralFreq = np.asarray(centralFreq, dtype=float).reshape(numIMFs, 1)

    if Display:
        print('#Iteration  |  Absolute Improvement  |  Relative Improvement  \n')

    numIter = 0
    initIMFNorm = np.abs(initIMFfd) ** 2
    normIMF = np.zeros(initIMFfd.shape)
    fVec = f[:, 0]

    # ADMM iterations
    while numIter < MaxIterations and (relativeDiff > RelativeTolerance or absoluteDiff > AbsoluteTolerance):
        for kk in range(numIMFs):
            # Take mode kk out of the running sum, re-estimate it, then add it back.
            sumIMF[:, 0] -= IMFfd[:, kk]
            IMFfd[:, [kk]] = (sigFD - sumIMF + LM / 2) / (1 + penaltyFactor * (f - centralFreq[kk]) ** 2)
            normIMF[:, kk] = np.abs(IMFfd[:, kk]) ** 2
            # Central frequency = power-weighted mean frequency of the mode.
            centralFreq[kk] = np.dot(fVec, normIMF[:, kk]) / np.sum(normIMF[:, kk])
            sumIMF[:, 0] += IMFfd[:, kk]
        LM = LM + tau * (sigFD - sumIMF)

        absDiff = np.mean(np.abs(IMFfd - initIMFfd) ** 2, axis=0)
        absoluteDiff = np.sum(absDiff)
        relativeDiff = np.sum(absDiff / np.mean(initIMFNorm, axis=0))

        # Sort IMFs and central frequencies by descending power;
        # in ADMM the IMF with greater power is subtracted first.
        sortedIndex = np.argsort(-np.sum(np.abs(IMFfd) ** 2, axis=0))
        IMFfd = IMFfd[:, sortedIndex]
        centralFreq = centralFreq[sortedIndex]
        initIMFfd = IMFfd.copy()
        # Previous-iterate power in the new column order. This must be a fresh array:
        # aliasing normIMF (overwritten in place by the next sweep) would make the
        # relative improvement compare against the current iterate instead.
        initIMFNorm = np.abs(initIMFfd) ** 2
        numIter += 1

        # Progress display
        if Display and (numIter == 1 or numIter % 20 == 0 or numIter >= MaxIterations
                        or (relativeDiff <= RelativeTolerance and absoluteDiff <= AbsoluteTolerance)):
            print ("   ", numIter,"    |", absoluteDiff," |  ", relativeDiff,"\n")

    # Rebuild the full spectrum from Hermitian symmetry so the IMFs are real.
    IMFfdFull = np.zeros((nfft, numIMFs), dtype=np.complex128)
    IMFfdFull[0:NumHalfFreqSamples] = IMFfd
    if nfft % 2 == 0:
        # Even nfft: the Nyquist bin (last half-spectrum row) is not mirrored.
        IMFfdFull[NumHalfFreqSamples:] = np.conj(IMFfd[NumHalfFreqSamples - 2:0:-1])
    else:
        IMFfdFull[NumHalfFreqSamples:] = np.conj(IMFfd[NumHalfFreqSamples - 1:0:-1])

    # Order the output IMFs, and their central frequencies, by descending central frequency.
    index = np.argsort(-centralFreq[:, 0])
    centralFreq = centralFreq[index]
    IMFs = signalBoundary(IMFfdFull[:, index], nfft, SignalLength, 1)

    # Output information
    class myinfo:
        pass

    info = myinfo()
    info.ExitFlag = 1 if numIter < MaxIterations else 0
    info.CentralFrequencies = centralFreq
    info.NumIterations = numIter
    info.AbsoluteImprovement = absoluteDiff
    info.RelativeImprovement = relativeDiff
    info.LagrangeMultiplier = LM

    # Residual: what the IMFs do not explain.
    residual = x - np.sum(IMFs, axis=1).reshape(SignalLength, 1)

    if Display:
        if info.ExitFlag:
            print (numIter, 'iterations are done. The value of absolut and relative tolerances are ',AbsoluteTolerance, 'and', RelativeTolerance)
        else:
            print('All', numIter, 'iterations are done!')

    return IMFs, residual, info

In [None]:
# parameters configuration (default)
numIMFs = 3
MaxIterations = 500
AbsoluteTolerance = 5e-06
TolFactor = 1000.0
InitializeMethod = 'peaks'
penaltyFactor = 1000
LMUpdateRate = 0.01
Display = True

# Signal to decompose: the first selected channel as an (N, 1) column.
# Defined here so this cell does not rely on `x` leaking from a later cell.
vmd_channel = channels[0]
x = action_selected[vmd_channel].to_numpy().reshape(-1, 1)

# computing IMFs
IMFs, residual, info = computeVMD(x, numIMFs, MaxIterations, AbsoluteTolerance, TolFactor,
                                  InitializeMethod, penaltyFactor, LMUpdateRate, Display)

In [None]:
# Plot the input signal, each IMF and the residual as stacked panels.
n = numIMFs + 2
fig, axes = plt.subplots(n, 1, figsize=(16, 12))

panels = [(x, 'Main Data')]
panels += [(IMFs[:, k], f'IMF {k + 1}') for k in range(numIMFs)]
panels.append((residual, 'Residual'))

for ax, (series, title) in zip(axes, panels):
    ax.plot(series)
    ax.set_title(title, fontweight="bold")

fig.tight_layout()
plt.show()

# Checking vmdpy

In [None]:
# Sample parameters for vmdpy's VMD(f, alpha, tau, K, DC, init, tol)
alpha = 5000       # moderate bandwidth constraint
tau = 0            # noise-tolerance (no strict fidelity enforcement)
K = 5              # number of modes to extract (5)
DC = 0             # no DC part imposed
init = 1           # initialize omegas uniformly
tol = 1e-7         # convergence tolerance

In [None]:
# Output folder for the per-channel IMF figures (raw string: Windows backslashes are literal).
figures_dir = r"F:\Research\QUML\EEG\Figures\VMD plots for selected channels"

In [None]:
# Decompose every selected channel with vmdpy and save one figure per mode.
# Looping over the rows of `u` keeps the plots in sync with K instead of a hard-coded 5.
os.makedirs(figures_dir, exist_ok=True)
for channel in action_selected.columns:
    x = action_selected[channel].to_numpy()
    print("Channel: ", channel)
    u, u_hat, omega = VMD(x, alpha, tau, K, DC, init, tol)
    for mode_num, mode in enumerate(u, start=1):
        fig, ax = plt.subplots(figsize=(30, 20))
        ax.plot(mode)
        ax.set_title(f"{channel} IMF {mode_num}", fontdict={'fontsize': 45})
        fig.savefig(os.path.join(figures_dir, f"{channel}_imf_{mode_num}.png"))
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across channels

In [None]:
# One-sided amplitude spectrum of the first VMD mode of the last channel processed above.
Fs = sample.info['sfreq']  # sampling rate (Hz) of the loaded epochs, instead of a hard-coded 256
mode = u[0]
n = len(mode)  # length of the signal actually transformed, so the frequency axis matches it
k = np.arange(n)
T = n / Fs
frq = (k / T)[: n // 2]  # one side frequency range

Y = np.fft.fft(mode)[: n // 2] / n  # dft and normalization

fig, ax = plt.subplots()
ax.plot(frq, np.abs(Y))
ax.set_xlabel('Freq (Hz)')
ax.set_ylabel('|Y(freq)|')
ax.set_title(f'{channel} IMF 1 spectrum')
plt.show()

In [None]:
# Demo: VMD on a synthetic test signal standing in for ECG
# (1.2, 2.3 and 6 Hz sinusoids plus Gaussian noise), band-passed to 5-15 Hz.
np.random.seed(0)
fs_demo = 1000  # Hz; the same rate is used for the time base and the filter design
t = np.arange(10 * fs_demo) / fs_demo  # 10 s of samples
ecg_signal = (np.sin(2 * np.pi * 1.2 * t) + np.sin(2 * np.pi * 2.3 * t)
              + np.sin(2 * np.pi * 6 * t) + 0.5 * np.random.randn(len(t)))

# Apply a 4th-order zero-phase Butterworth band-pass to isolate the "QRS" band
b, a = butter(4, (5, 15), btype='bandpass', fs=fs_demo)
qrs_signal = filtfilt(b, a, ecg_signal)

# vmdpy's VMD is a function used as VMD(f, alpha, tau, K, DC, init, tol) -> (u, u_hat, omega)
# (same call as in the per-channel loop); the modes are the rows of u, so their sum
# reconstructs the signal. Distinct names keep the EEG results in `u` untouched.
# NOTE(review): 10,000 samples x 10 modes may be slow / memory-heavy; shorten if needed.
num_demo_modes = 10
u_demo, u_hat_demo, omega_demo = VMD(qrs_signal, alpha, tau, num_demo_modes, DC, init, tol)
vmd_reconstructed = u_demo.sum(axis=0)

# Plot the original and reconstructed signals (t sliced in case VMD returns fewer samples)
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(t, qrs_signal, alpha=0.5, label='Original QRS signal')
ax.plot(t[:len(vmd_reconstructed)], vmd_reconstructed, linewidth=2, label='Reconstructed signal')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude')
ax.legend()
plt.show()