In [7]:
import h5py
import numpy as np
from scipy.io import savemat
import matplotlib.pyplot as plt
import torch

# === MVMD parameters ===
alpha = 2000
tau = 0
K = 4
DC = True
init = 1
tol = 1e-6

# === Initialize storage for MVMD results ===
num_trials = 3
timepoints = 4000  # from index 2001 to 6001
channels = 306

u_all = np.zeros((num_trials, K, timepoints, channels))
u_hat_all = np.zeros((num_trials, K, timepoints, channels), dtype=np.complex64)
omega_all = np.zeros((num_trials, K, 2))

# === Load and process trials ===
file_path = r"D:\BTP\sub-1_ses-1_task-bcimici_meg.mat"

with h5py.File(file_path, 'r') as f:
    data = f['dataMAT']
    #trial_refs = data['trial'][0]
    trial_refs = data['trial']

    for i in range(num_trials):
        trial_ref = trial_refs[i,0]
        trial_data = f[trial_ref][2001:6001]  # (channels, timepoints)
        #first_trial = f[first_trial_ref][2001:6001]
        print("Shape of trial:", trial_data.shape)
        # Transpose to (channels, timepoints) if needed
        if trial_data.shape[0] != channels:
            trial_data = trial_data.T

        print(f"Processing trial {i+1}, shape: {trial_data.shape}")

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Running on device: {device}")
        trial_tensor = torch.tensor(trial_data, dtype=torch.float32, device=device)


        # Apply MVMD
        u, u_hat, omega = MVMD(trial_tensor, alpha, tau, K, DC, init, tol)
        # u shape: (K, C, T) → we want to store as (K, T, C)
        u_all[i] = u      # (K, T, C)
        u_hat_all[i] = u_hat
        omega_all[i] = np.transpose(omega,(1,0))                        # (K, C)

# === Save everything into one .mat file ===
save_dict = {
    'u': u_all,            # shape (3, 4, 4000, 306)
    'u_hat': u_hat_all,    # shape (3, 4, 4000, 306)
    'omega': omega_all     # shape (3, 4, 306)
}

savemat('mvmd_3_trials_N10_K04_02.mat', save_dict)
print("Saved to mvmd_3_trials_N10_K04.mat")


Shape of trial: (4000, 306)
Processing trial 1, shape: (306, 4000)
Running on device: cuda


OutOfMemoryError: CUDA out of memory. Tried to allocate 10.98 GiB. GPU 0 has a total capacity of 4.00 GiB of which 3.17 GiB is free. Of the allocated memory 23.38 MiB is allocated by PyTorch, and 18.62 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [6]:
import torch
import numpy as np
from torch.fft import fft, ifft, fftshift, ifftshift

def MVMD(signal, alpha, tau, K, DC, init, tol):
    device = signal.device
    x, y = signal.shape
    if x > y:
        C = y
        signal = signal.T
    else:
        C = x

    T = signal.shape[1]
    t = torch.arange(1, T + 1, device=device) / T
    freqs = t - 0.5 - 1 / T

    # FFT of the signal
    f_hat = fftshift(fft(signal, dim=1), dim=1)
    f_hat_plus = f_hat.clone()
    f_hat_plus[:, :T // 2] = 0  # Set the negative frequencies to zero

    # Initialization
    N = 300
    Alpha = alpha * torch.ones(K, device=device)
    u_hat_plus = torch.zeros((N + 1, len(freqs), K, C), dtype=torch.complex64, device=device)
    omega_plus = torch.zeros((N + 1, K), device=device)

    if init == 1:
        omega_plus[0, :] = (0.5 / K) * torch.arange(K, device=device)
    elif init == 2:
        omega_plus[0, :] = torch.sort(torch.exp(torch.log(torch.tensor(1. / T, device=device)) + 
                                         (torch.log(torch.tensor(0.5, device=device)) - 
                                          torch.log(torch.tensor(1. / T, device=device))) * 
                                         torch.rand(K, device=device)))[0]
    else:
        omega_plus[0, :] = 0

    if DC:
        omega_plus[0, 0] = 0

    lambda_hat = torch.zeros((N + 1, len(freqs), C), dtype=torch.complex64, device=device)
    uDiff = tol + torch.finfo(torch.float32).eps
    n = 0
    sum_uk = torch.zeros((C, len(freqs)), dtype=torch.complex64, device=device)

    # Main loop
    while uDiff > tol and n < N:
        k = 0
        for c in range(C):
            sum_uk[c, :] = u_hat_plus[n, :, K - 1, c] + sum_uk[c, :] - u_hat_plus[n, :, k, c]
            u_hat_plus[n + 1, :, k, c] = (f_hat_plus[c, :] - sum_uk[c, :] - lambda_hat[n, :, c] / 2) / \
                                         (1 + Alpha[k] * (freqs - omega_plus[n, k]) ** 2)

        if not DC:
            temp1 = 0
            temp2 = 0
            for c in range(C):
                temp1 += torch.sum(freqs * torch.abs(u_hat_plus[n + 1, :, k, c]) ** 2)
                temp2 += torch.sum(torch.abs(u_hat_plus[n + 1, :, k, c]) ** 2)
            omega_plus[n + 1, k] = temp1 / temp2

        for k in range(1, K):
            for c in range(C):
                sum_uk[c, :] = u_hat_plus[n + 1, :, k - 1, c] + sum_uk[c, :] - u_hat_plus[n, :, k, c]
                u_hat_plus[n + 1, :, k, c] = (f_hat_plus[c, :] - sum_uk[c, :] - lambda_hat[n, :, c] / 2) / \
                                              (1 + Alpha[k] * (freqs - omega_plus[n, k]) ** 2)

            temp1 = 0
            temp2 = 0
            for c in range(C):
                temp1 += torch.sum(freqs * torch.abs(u_hat_plus[n + 1, :, k, c]) ** 2)
                temp2 += torch.sum(torch.abs(u_hat_plus[n + 1, :, k, c]) ** 2)
            omega_plus[n + 1, k] = temp1 / temp2

        for c in range(C):
            lambda_hat[n + 1, :, c] = lambda_hat[n, :, c] + tau * (
                torch.sum(u_hat_plus[n + 1, :, :, c], dim=1) - f_hat_plus[c, :])

        n += 1
        uDiff = torch.finfo(torch.float32).eps
        for i in range(K):
            for c in range(C):
                uDiff += torch.sum(torch.abs(u_hat_plus[n, :, i, c] - u_hat_plus[n - 1, :, i, c]) ** 2)

    N = min(N, n)
    omega = omega_plus[:N, :].cpu().numpy()

    u_hat = u_hat_plus[N, :, :, :].permute(1, 0, 2)
    u = torch.zeros((K, T, C), dtype=torch.float32, device=device)
    for k in range(K):
        for c in range(C):
            u[k, :, c] = torch.real(ifft(ifftshift(u_hat[k, :, c], dim=0)))

    return u.cpu().numpy(), u_hat.cpu().numpy(), omega
