In [None]:
from crosspy.preprocessing.signal import filter_data
from crosspy.core.synchrony import wpli, cplv
from crosspy.core.criticality import dfa, compute_BiS, efi
from numpy import log10
import matplotlib.pyplot as plt
from scipy.io import loadmat
from scipy.io import savemat 
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
import time
from pathlib import Path
import cupy as cp
import gc 

In [None]:
def free_gpu_mem():
    """Release unreachable Python objects and hand cached CuPy memory back.

    Intended to be called right after ``del``-ing large GPU arrays, so that the
    next processing step starts with as much free device memory as possible.

    Takes no arguments and returns ``None``.
    """
    # Collect first: arrays that are only kept alive by reference cycles are
    # returned to the CuPy pools by the collector, and only blocks that are
    # already back in a pool can be released by free_all_blocks().
    # A bare gc.collect() is a full collection (all generations).
    gc.collect()

    # Device memory cached by CuPy's default pool.
    cp.get_default_memory_pool().free_all_blocks()
    # Page-locked host memory cached by CuPy's default pinned-memory pool.
    cp.get_default_pinned_memory_pool().free_all_blocks()

def process_data(data, sfreq, n_freqs, min_freq=1, max_freq=225, omega=7.5):
    """Compute narrow-band synchrony and criticality features for one recording.

    For each of ``n_freqs`` log-spaced centre frequencies the data are
    narrow-band filtered on the GPU (``filter_data(..., n_jobs='cuda')``) and
    wPLI, cPLV, DFA and fE/I are computed from the filtered signal / its
    amplitude envelope.

    Parameters
    ----------
    data : array, channels x samples
        Broadband recording. The first axis is used as the channel count and
        the last axis as the number of samples.
    sfreq : float
        Sampling rate in Hz.
    n_freqs : int
        Number of centre frequencies.
    min_freq, max_freq : float, optional
        Lowest and highest centre frequency in Hz (defaults 1 and 225).
    omega : float, optional
        Passed through to ``filter_data`` (default 7.5).
        NOTE(review): presumably the wavelet width parameter - confirm in crosspy.

    Returns
    -------
    frequencies : ndarray, shape (n_freqs,)
        Centre frequencies, log-spaced and rounded to 2 decimals.
    wpli_freqwise : ndarray, shape (n_freqs, n_chans, n_chans)
    cplv_freqwise : complex ndarray, shape (n_freqs, n_chans, n_chans)
    dfa_freqwise : ndarray, shape (n_freqs, n_chans)
    fei_freqwise : ndarray, shape (n_freqs, n_chans)
    bis_freqwise : ndarray, shape (n_freqs, n_chans)
        Always all zeros: BiS is currently not computed. The array is still
        returned so that callers unpacking six values keep working.
    """
    DFA_MIN_WIN    = 20  # shortest DFA window, in cycles of the narrow-band centre frequency
    DFA_MAX_WIN    = 10  # longest DFA window is 1/DFA_MAX_WIN of the recording duration
    DFA_N_WINDOWS  = 30  # number of log-spaced DFA window lengths
    FEI_WIN_MULTIP = 50  # fE/I window length, in cycles of the centre frequency

    n_chans   = data.shape[0]
    n_samples = data.shape[-1]

    frequencies   = np.round(np.geomspace(min_freq, max_freq, n_freqs), 2)
    wpli_freqwise = np.zeros((n_freqs, n_chans, n_chans))
    cplv_freqwise = np.zeros((n_freqs, n_chans, n_chans), dtype=complex)
    dfa_freqwise  = np.zeros((n_freqs, n_chans))
    fei_freqwise  = np.zeros((n_freqs, n_chans))
    bis_freqwise  = np.zeros((n_freqs, n_chans))  # placeholder, never filled (see docstring)

    for freq_idx, freq in enumerate(frequencies):
        # Window lengths in samples; sfreq/freq is the length of one cycle.
        dfa_window_sizes = np.geomspace(
            DFA_MIN_WIN*sfreq/freq, n_samples//DFA_MAX_WIN, DFA_N_WINDOWS
        ).astype(int)
        fei_window_size = int(FEI_WIN_MULTIP*sfreq/freq)

        data_filt = filter_data(data, sfreq=sfreq, frequency=freq, omega=omega, n_jobs='cuda')

        # Phase-based metrics need the filtered signal itself.
        # .get() moves the result from GPU to CPU.
        wpli_freqwise[freq_idx] = wpli(data_filt).get()
        cplv_freqwise[freq_idx] = cplv(data_filt).get()

        # Amplitude-based metrics only need the envelope, so the filtered
        # signal is dropped as soon as the envelope exists to reduce peak
        # GPU memory.
        data_envelope = np.abs(data_filt)
        del data_filt
        free_gpu_mem()

        # dfa() returns a sequence; element [2] is what is stored per channel.
        # NOTE(review): presumably the DFA exponent - confirm against crosspy.
        dfa_freqwise[freq_idx] = dfa(data_envelope, window_lengths=dfa_window_sizes)[2]
        fei_freqwise[freq_idx] = efi(data_envelope, window_size=fei_window_size, overlap=0.5).get()

        # Free the envelope before the next frequency is filtered.
        del data_envelope
        free_gpu_mem()

    return frequencies, wpli_freqwise, cplv_freqwise, dfa_freqwise, fei_freqwise, bis_freqwise




# Disabled earlier version of process_data, kept for reference only.
# It is wrapped in a string literal, so nothing below is executed and the
# process_data defined above is the one in effect. Unlike the live version it
# computes the (normalised) power and frees GPU memory once per frequency.
'''

def process_data(data, sfreq, n_freqs):
    #data = np.load('test.npy') # channels x samples
    NORM_POWER       = True
    n_chans          = data.shape[0]    
    omega            = 7.5
    MIN_FREQ         = 1     # Hz
    MAX_FREQ         = 225   # Hz    
    DFA_MIN_WIN      = 20 # num x [f narrow-band cycle length]
    DFA_MAX_WIN      = 10 # e.g., 5 means 1/5 of the recording duration
    FEI_WIN_MULTIP   = 50
    
    frequencies        = np.round(np.geomspace(MIN_FREQ, MAX_FREQ, n_freqs), 2)
    wpli_freqwise      = np.zeros((n_freqs, n_chans, n_chans))
    cplv_freqwise      = np.zeros((n_freqs, n_chans, n_chans), dtype=complex)
    dfa_freqwise       = np.zeros((n_freqs, n_chans))
    fei_freqwise       = np.zeros((n_freqs, n_chans))
    bis_freqwise       = np.zeros((n_freqs, n_chans))
    
    for freq_idx, freq in enumerate(frequencies):
        #print(str(freq_idx) + ': ' + str(freq) + ' Hz ...')
        dfa_window_sizes = np.geomspace(DFA_MIN_WIN*sfreq/freq, data.shape[-1]//DFA_MAX_WIN, 30).astype(int)
        fei_window_size = int(FEI_WIN_MULTIP*sfreq/freq)        
        data_filt = filter_data(data, sfreq=sfreq, frequency=freq, omega=omega, n_jobs='cuda')
        data_envelope = np.abs(data_filt)
        
        data_power = data_envelope**2
        if NORM_POWER: 
            CNETERED = data_power - np.median(data_power)
            data_power = CNETERED / np.max(np.abs(CNETERED))
            

        # have to use get() function to move data from GPU to CPU
        wpli_freqwise[freq_idx] = wpli(data_filt).get()
        cplv_freqwise[freq_idx] = cplv(data_filt).get()
        dfa_freqwise[freq_idx] = dfa(data_envelope, window_lengths = dfa_window_sizes)[2]
        fei_freqwise[freq_idx] = efi(data_envelope, window_size=fei_window_size, overlap=0.5).get()          
        #bis_freqwise[freq_idx] = compute_BiS(data_power.get(), method='mle')   #need to move GPU to CPU
         
        # lets free some memory
        del data_filt
        del data_envelope
        mempool = cp.get_default_memory_pool()
        mempool.free_all_blocks()

        for i in range(3):
            gc.collect(i)
    
    return frequencies, wpli_freqwise, cplv_freqwise, dfa_freqwise, fei_freqwise, bis_freqwise
'''

In [None]:
Fs = 1000  # MEG sampling rate in Hz
smple_CAP = 12*60*Fs  # number of samples kept per recording when it is truncated (12 min at Fs)

# SUBJ_DIR  = Path('/mnt/megshare/Development/Sheng/__202412_WrapUp/__MEG_temp/ResearchMEG_SrcData/') # grr storage
SUBJ_DIR  = Path('/m/nbe/scratch/grr_epilepsy/MEG_spont')
# File names (not full paths) of every .npy recording in SUBJ_DIR, sorted so
# the processing order is the same on every run.
SUBJS     = sorted(file for file in os.listdir(SUBJ_DIR) if file.endswith('.npy'))

#MAT_FILE_WRITEOUT  = Path('/mnt/megshare/Development/Sheng/__dump/20241216_aes_poster_figures/')
MAT_FILE_WRITEOUT  = Path('/m/nbe/scratch/grr_epilepsy/MEG_spont/results/')
MAT_FILE_WRITEOUT  = MAT_FILE_WRITEOUT / 'WHOLE_COHORT_CROSSPY_FEATURES'

# parents=True also creates a missing 'results/' directory instead of failing;
# exist_ok=True makes re-running this cell a no-op when the folder is already there.
MAT_FILE_WRITEOUT.mkdir(parents=True, exist_ok=True)

# Show the index of every recording (handy for restarting from SUBJS[k:]).
for i, iSubj in enumerate(SUBJS):
    print(i, iSubj)

In [8]:
#SUBJS = ['case_0017', 'case_0037', 'case_0059', 'case_0068_2', 'case_0074', 'case_0080', 'case_0147', 'case_0213', 'case_0253_1', 'case_0434_0']
#SUBJS = ['case_0434_0']

In [None]:
t1 = time.time()
#SUBJS = SUBJS[49:]

# Number of leading samples dropped from recordings that are long enough to be truncated.
SKIP_SAMPLES = 10000

# for all subjects ##############################################################################
for i, subj in enumerate(SUBJS):
    t11 = time.time()
    print(i, subj)

    filePath = SUBJ_DIR / subj
    DATA = np.load(filePath) # data: channels x samples
    print('\tOrig  data: ', DATA.shape,
          '. If the recording is longer than', smple_CAP + SKIP_SAMPLES,
          'samples, the first', SKIP_SAMPLES, 'samples are skipped and the next',
          smple_CAP, 'samples are kept.')

    if DATA.shape[1] > smple_CAP + SKIP_SAMPLES:
        # Long recording: skip the beginning, then keep exactly smple_CAP samples.
        # Basic slicing returns a view, so the full array stays in memory for as
        # long as `data` is alive even after `del DATA`.
        data = DATA[:, SKIP_SAMPLES:smple_CAP + SKIP_SAMPLES]
    else:
        # Short recording: used as-is, including its first SKIP_SAMPLES samples.
        data = DATA

    del DATA

    freq_bank, wPLI, cPLV, DFA, fEI, BiS = process_data(data, sfreq=Fs, n_freqs=50)

    # Output name set to 'case_xxxx_y_setzz_...'  ======================================
    # Only the first 13 characters of the file name are kept.
    # NOTE(review): assumes those 13 characters are unique per recording; two
    # files sharing them would overwrite each other's .mat file - confirm.
    writeoutPrefix = subj[0:13] + '_seg00'
    matfile = MAT_FILE_WRITEOUT / (writeoutPrefix + '.mat')

    # Processing time only; writing the .mat file is not included.
    t12 = time.time()
    timeSpent_min = round((t12-t11)/60,1)
    print('\tTime spent:', timeSpent_min, ' mins.' )

    savemat(matfile, {'wPLI': wPLI, 'cPLV': cPLV, 'DFA': DFA, 'fEI':fEI, 'BiS': BiS, 'freq_bank':freq_bank, 'timeSpent_min':timeSpent_min })
    del data, freq_bank, wPLI, cPLV, DFA, fEI, BiS



t2 = time.time()    
print(f"Finished at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(t2))}")
print('Total time spent:', round((t2-t1)/60,1), ' mins.' )

In [18]:
# Truncation threshold used in the processing loop (smple_CAP + 10000 samples),
# divided by 1000 - i.e. expressed in seconds when Fs = 1000 Hz.
(smple_CAP + 10000)/1000

730.0