In [6]:
import h5py
import numpy as np
from scipy.io import savemat
import matplotlib.pyplot as plt

# === MVMD parameters ===
alpha = 2000
tau = 0
K = 4
DC = True
init = 1
tol = 1e-6

# === Initialize storage for MVMD results ===
num_trials = 3
timepoints = 4000  # from index 2001 to 6001
channels = 306

u_all = np.zeros((num_trials, K, timepoints, channels))
u_hat_all = np.zeros((num_trials, K, timepoints, channels), dtype=np.complex64)
omega_all = np.zeros((num_trials, K, 2))

# === Load and process trials ===
file_path = r"D:\BTP\sub-1_ses-1_task-bcimici_meg.mat"

with h5py.File(file_path, 'r') as f:
    data = f['dataMAT']
    #trial_refs = data['trial'][0]
    trial_refs = data['trial']

    for i in range(num_trials):
        trial_ref = trial_refs[i,0]
        trial_data = f[trial_ref][2001:6001]  # (channels, timepoints)
        #first_trial = f[first_trial_ref][2001:6001]
        print("Shape of trial:", trial_data.shape)
        # Transpose to (channels, timepoints) if needed
        if trial_data.shape[0] != channels:
            trial_data = trial_data.T

        print(f"Processing trial {i+1}, shape: {trial_data.shape}")

        # Apply MVMD
        u, u_hat, omega = MVMD(trial_data, alpha, tau, K, DC, init, tol)
        # u shape: (K, C, T) → we want to store as (K, T, C)
        u_all[i] = u      # (K, T, C)
        u_hat_all[i] = u_hat
        omega_all[i] = np.transpose(omega,(1,0))                        # (K, C)

# === Save everything into one .mat file ===
save_dict = {
    'u': u_all,            # shape (3, 4, 4000, 306)
    'u_hat': u_hat_all,    # shape (3, 4, 4000, 306)
    'omega': omega_all     # shape (3, 4, 306)
}

savemat('mvmd_3_trials_N10_K04.mat', save_dict)
print("Saved to mvmd_3_trials_N10_K04.mat")


Shape of trial: (4000, 306)
Processing trial 1, shape: (306, 4000)


MemoryError: Unable to allocate 21.9 GiB for an array with shape (300, 4000, 4, 306) and data type complex128

In [1]:
import numpy as np
from scipy.fftpack import fft, ifft, fftshift, ifftshift

def MVMD(signal, alpha, tau, K, DC, init, tol):
    x, y = signal.shape
    if x > y:
        C = y
        signal = signal.T
    else:
        C = x

    T = signal.shape[1]
    t = np.arange(1, T + 1) / T
    freqs = t - 0.5 - 1 / T

    # FFT of the signal
    f_hat = fftshift(fft(signal, axis=1))
    f_hat_plus = f_hat.copy()
    f_hat_plus[:, :T//2] = 0  # Set the negative frequencies to zero

    # Initialization
    N = 500
    Alpha = alpha * np.ones(K)
    u_hat_plus = np.zeros((N, len(freqs), K, C), dtype=complex)
    omega_plus = np.zeros((N, K))

    if init == 1:
        omega_plus[0, :] = (0.5 / K) * np.arange(K)
    elif init == 2:
        omega_plus[0, :] = np.sort(np.exp(np.log(1 / T) + (np.log(0.5) - np.log(1 / T)) * np.random.rand(K)))
    else:
        omega_plus[0, :] = 0

    if DC:
        omega_plus[0, 0] = 0

    lambda_hat = np.zeros((N, len(freqs), C), dtype=complex)
    uDiff = tol + np.finfo(float).eps
    n = 0
    sum_uk = np.zeros((C, len(freqs)), dtype=complex)

    # Main algorithm
    while uDiff > tol and n < N:
        k = 0
        for c in range(C):
            sum_uk[c, :] = u_hat_plus[n, :, K-1, c] + sum_uk[c, :] - u_hat_plus[n, :, k, c]
            u_hat_plus[n+1, :, k, c] = (f_hat_plus[c, :] - sum_uk[c, :] - lambda_hat[n, :, c] / 2) / (1 + Alpha[k] * (freqs - omega_plus[n, k])**2)

        if not DC:
            temp1 = 0
            temp2 = 0
            for c in range(C):
                numerator = np.sum(freqs * np.abs(u_hat_plus[n+1, :, k, c])**2)
                denominator = np.sum(np.abs(u_hat_plus[n+1, :, k, c])**2)
                temp1 += numerator
                temp2 += denominator
            omega_plus[n+1, k] = temp1 / temp2

        for k in range(1, K):
            for c in range(C):
                sum_uk[c, :] = u_hat_plus[n+1, :, k-1, c] + sum_uk[c, :] - u_hat_plus[n, :, k, c]
                u_hat_plus[n+1, :, k, c] = (f_hat_plus[c, :] - sum_uk[c, :] - lambda_hat[n, :, c] / 2) / \
                                          (1 + Alpha[k] * (freqs - omega_plus[n, k])**2)

            temp1 = 0
            temp2 = 0
            for c in range(C):
                numerator = np.sum(freqs * np.abs(u_hat_plus[n+1, :, k, c])**2)
                denominator = np.sum(np.abs(u_hat_plus[n+1, :, k, c])**2)
                temp1 += numerator
                temp2 += denominator
            omega_plus[n+1, k] = temp1 / temp2

        for c in range(C):
            lambda_hat[n+1, :, c] = lambda_hat[n, :, c] + tau * (np.sum(u_hat_plus[n+1, :, :, c], axis=1) - f_hat_plus[c, :])

        n += 1

        uDiff = np.finfo(float).eps
        for i in range(K):
            for c in range(C):
                uDiff += np.sum((u_hat_plus[n, :, i, c] - u_hat_plus[n-1, :, i, c]) * \
                                np.conj(u_hat_plus[n, :, i, c] - u_hat_plus[n-1, :, i, c]))
        uDiff = np.abs(uDiff)

    N = min(N, n)
    omega = omega_plus[:N, :]

    u_hat = np.zeros((T, K, C), dtype=complex)
    for c in range(C):
        u_hat[:, :, c] = u_hat_plus[N, :, :, c]

    u = np.zeros((K, T, C))
    for k in range(K):
        for c in range(C):
            u[k, :, c] = np.real(ifft(ifftshift(u_hat[:, k, c])))

    u_hat = np.zeros((len(freqs), K, C), dtype=complex)
    for k in range(K):
        for c in range(C):
            u_hat[:, k, c] = fftshift(fft(u[k, :, c]))

    u_hat = np.transpose(u_hat, (1, 0, 2))

    return u, u_hat, omega
