In [1]:
from crosspy.preprocessing.signal import filter_data
from crosspy.core.synchrony import wpli, cplv
from crosspy.core.criticality import dfa, compute_BiS, efi
from numpy import log10
import matplotlib.pyplot as plt
from scipy.io import loadmat
from scipy.io import savemat 
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
import time
from pathlib import Path
import cupy as cp
import gc 

In [2]:
def free_gpu_mem():
    mempool = cp.get_default_memory_pool()
    mempool.free_all_blocks()

    for i in range(3):
        gc.collect(i)

def process_data(data, sfreq, n_freqs):
    #data = np.load('test.npy') # channels x samples
    NORM_POWER       = True
    n_chans          = data.shape[0]    
    omega            = 7.5
    MIN_FREQ         = 1     # Hz
    MAX_FREQ         = 225   # Hz    
    DFA_MIN_WIN      = 20 # num x [f narrow-band cycle length]
    DFA_MAX_WIN      = 10 # e.g., 5 means 1/5 of the recording duration
    FEI_WIN_MULTIP   = 50
    
    frequencies        = np.round(np.geomspace(MIN_FREQ, MAX_FREQ, n_freqs), 2)
    wpli_freqwise      = np.zeros((n_freqs, n_chans, n_chans))
    cplv_freqwise      = np.zeros((n_freqs, n_chans, n_chans), dtype=complex)
    dfa_freqwise       = np.zeros((n_freqs, n_chans))
    fei_freqwise       = np.zeros((n_freqs, n_chans))
    bis_freqwise       = np.zeros((n_freqs, n_chans))
    
    for freq_idx, freq in enumerate(frequencies):
        #print(str(freq_idx) + ': ' + str(freq) + ' Hz ...')
        dfa_window_sizes = np.geomspace(DFA_MIN_WIN*sfreq/freq, data.shape[-1]//DFA_MAX_WIN, 30).astype(int)
        fei_window_size = int(FEI_WIN_MULTIP*sfreq/freq)        
        data_filt = filter_data(data, sfreq=sfreq, frequency=freq, omega=omega, n_jobs='cuda')
        
        
        # dont need data_power because we are not computing BiS
        # tip: move power calc to BiS function?
        # data_power = data_envelope**2
        # if NORM_POWER: 
        #     CNETERED = data_power - np.median(data_power)
        #     data_power = CNETERED / np.max(np.abs(CNETERED))
            

        # have to use get() function to move data from GPU to CPU
        wpli_freqwise[freq_idx] = wpli(data_filt).get()
        cplv_freqwise[freq_idx] = cplv(data_filt).get()

        data_envelope = np.abs(data_filt)
        del data_filt
        free_gpu_mem()

        dfa_freqwise[freq_idx] = dfa(data_envelope, window_lengths = dfa_window_sizes)[2]
        fei_freqwise[freq_idx] = efi(data_envelope, window_size=fei_window_size, overlap=0.5).get()          
        #bis_freqwise[freq_idx] = compute_BiS(data_power.get(), method='mle')   #need to move GPU to CPU
         
        # lets free some memory
        del data_envelope
        free_gpu_mem()
            
    return frequencies, wpli_freqwise, cplv_freqwise, dfa_freqwise, fei_freqwise, bis_freqwise




'''

def process_data(data, sfreq, n_freqs):
    #data = np.load('test.npy') # channels x samples
    NORM_POWER       = True
    n_chans          = data.shape[0]    
    omega            = 7.5
    MIN_FREQ         = 1     # Hz
    MAX_FREQ         = 225   # Hz    
    DFA_MIN_WIN      = 20 # num x [f narrow-band cycle length]
    DFA_MAX_WIN      = 10 # e.g., 5 means 1/5 of the recording duration
    FEI_WIN_MULTIP   = 50
    
    frequencies        = np.round(np.geomspace(MIN_FREQ, MAX_FREQ, n_freqs), 2)
    wpli_freqwise      = np.zeros((n_freqs, n_chans, n_chans))
    cplv_freqwise      = np.zeros((n_freqs, n_chans, n_chans), dtype=complex)
    dfa_freqwise       = np.zeros((n_freqs, n_chans))
    fei_freqwise       = np.zeros((n_freqs, n_chans))
    bis_freqwise       = np.zeros((n_freqs, n_chans))
    
    for freq_idx, freq in enumerate(frequencies):
        #print(str(freq_idx) + ': ' + str(freq) + ' Hz ...')
        dfa_window_sizes = np.geomspace(DFA_MIN_WIN*sfreq/freq, data.shape[-1]//DFA_MAX_WIN, 30).astype(int)
        fei_window_size = int(FEI_WIN_MULTIP*sfreq/freq)        
        data_filt = filter_data(data, sfreq=sfreq, frequency=freq, omega=omega, n_jobs='cuda')
        data_envelope = np.abs(data_filt)
        
        data_power = data_envelope**2
        if NORM_POWER: 
            CNETERED = data_power - np.median(data_power)
            data_power = CNETERED / np.max(np.abs(CNETERED))
            

        # have to use get() function to move data from GPU to CPU
        wpli_freqwise[freq_idx] = wpli(data_filt).get()
        cplv_freqwise[freq_idx] = cplv(data_filt).get()
        dfa_freqwise[freq_idx] = dfa(data_envelope, window_lengths = dfa_window_sizes)[2]
        fei_freqwise[freq_idx] = efi(data_envelope, window_size=fei_window_size, overlap=0.5).get()          
        #bis_freqwise[freq_idx] = compute_BiS(data_power.get(), method='mle')   #need to move GPU to CPU
         
        # lets free some memory
        del data_filt
        del data_envelope
        mempool = cp.get_default_memory_pool()
        mempool.free_all_blocks()

        for i in range(3):
            gc.collect(i)
    
    return frequencies, wpli_freqwise, cplv_freqwise, dfa_freqwise, fei_freqwise, bis_freqwise
'''

"\n\ndef process_data(data, sfreq, n_freqs):\n    #data = np.load('test.npy') # channels x samples\n    NORM_POWER       = True\n    n_chans          = data.shape[0]    \n    omega            = 7.5\n    MIN_FREQ         = 1     # Hz\n    MAX_FREQ         = 225   # Hz    \n    DFA_MIN_WIN      = 20 # num x [f narrow-band cycle length]\n    DFA_MAX_WIN      = 10 # e.g., 5 means 1/5 of the recording duration\n    FEI_WIN_MULTIP   = 50\n    \n    frequencies        = np.round(np.geomspace(MIN_FREQ, MAX_FREQ, n_freqs), 2)\n    wpli_freqwise      = np.zeros((n_freqs, n_chans, n_chans))\n    cplv_freqwise      = np.zeros((n_freqs, n_chans, n_chans), dtype=complex)\n    dfa_freqwise       = np.zeros((n_freqs, n_chans))\n    fei_freqwise       = np.zeros((n_freqs, n_chans))\n    bis_freqwise       = np.zeros((n_freqs, n_chans))\n    \n    for freq_idx, freq in enumerate(frequencies):\n        #print(str(freq_idx) + ': ' + str(freq) + ' Hz ...')\n        dfa_window_sizes = np.geomspace(DFA_MIN_W

In [4]:
Fs = 1000 # MEG sampling rate in Hz
smple_CAP = 12*60*Fs 

SUBJ_DIR  = Path('/data/sheng/MEG_source_data/')
SUBJS     = [d.name for d in SUBJ_DIR.iterdir() if d.is_dir() and d.name.startswith('case_')]

#MAT_FILE_WRITEOUT  = Path('/mnt/megshare/Development/Sheng/__dump/20241216_aes_poster_figures/')
MAT_FILE_WRITEOUT  = Path('/mnt/megshare/Development/Sheng/__dump/20241216_Cohort_n126')
MAT_FILE_WRITEOUT  = MAT_FILE_WRITEOUT / 'WHOLE_COHORT_CROSSPY_FEATURES'

if not MAT_FILE_WRITEOUT.exists():
    MAT_FILE_WRITEOUT.mkdir()

for i, iSubj in enumerate(SUBJS):
    print(i, iSubj)

0 case_0307
1 case_0332
2 case_0303
3 case_0340
4 case_0440
5 case_0442
6 case_0441
7 case_0321
8 case_0345
9 case_0297
10 case_0335
11 case_0443
12 case_0237
13 case_0336
14 case_0356
15 case_0400
16 case_0449
17 case_0299
18 case_0393
19 case_0291
20 case_0220
21 case_0392
22 case_0391
23 case_0199
24 case_0324
25 case_0347_1
26 case_0228
27 case_0195
28 case_0212
29 case_0234
30 case_0284
31 case_0333_2
32 case_0095
33 case_0108
34 case_0142
35 case_0152
36 case_0183_2
37 case_0184
38 case_0217
39 case_0266
40 case_0255
41 case_0268
42 case_0287
43 case_0257
44 case_0438_1
45 case_0203_1
46 case_0260
47 case_0293
48 case_0295_1
49 case_0147
50 case_0176
51 case_0187
52 case_0197
53 case_0213
54 case_0214
55 case_0420
56 case_0444_0
57 case_0434_0
58 case_0310
59 case_0138_2
60 case_0017
61 case_0106
62 case_0116
63 case_0118
64 case_0240_1
65 case_0253_1
66 case_0080
67 case_0037
68 case_0059
69 case_0065
70 case_0074
71 case_0068_2
72 case_0448
73 case_0353
74 case_0044
75 case_008

In [8]:
#SUBJS = ['case_0017', 'case_0037', 'case_0059', 'case_0068_2', 'case_0074', 'case_0080', 'case_0147', 'case_0213', 'case_0253_1', 'case_0434_0']
#SUBJS = ['case_0434_0']

In [5]:
t1 = time.time()
#SUBJS = SUBJS[49:]
# for all subjects ##############################################################################
for i, subj in enumerate(SUBJS):
    SETS  = os.listdir(SUBJ_DIR / subj)
    for j, iSet in enumerate(SETS): 
        
        subjPath = Path(SUBJ_DIR / subj / iSet / 'parcel_ts')
        filePath = list(subjPath.glob("*.npy"))[0]
        print('Computing crosspy features for ', i, subj, 
              j, iSet, filePath.name, '=========================>')

        ## compute one session for one subject's then save 
        t11=time.time()
        
        DATA = np.load(filePath) # data: channels x samples
        print('\tOrig  data: ', DATA.shape, '. If original data is longer than 12, it will be truncated to', smple_CAP, 'samples.')
        
        if DATA.shape[1] > smple_CAP + 10000:
            
            data = DATA[:, 10000:smple_CAP + 10000]
            
        else: 
            data = DATA

        del DATA        
        
        
        freq_bank, wPLI, cPLV, DFA, fEI, BiS = process_data(data, sfreq=Fs, n_freqs=50)

        # Output name set to 'case_xxxx_y_setzz_...'  ======================================
        SUBJ = filePath.parent.parent.parent.stem
        iSet  = filePath.parent.parent.stem
        if len(SUBJ) == 9: 
            SUBJ = SUBJ + '_1'
        SET   = 'set' + (f"{int(iSet[3:]):02d}")    
        
        #parcTS = {'parcTS': data}
        matfile = MAT_FILE_WRITEOUT / (SUBJ + '_' + SET + '.mat')        
        savemat(matfile, {'wPLI': wPLI, 'cPLV': cPLV, 'DFA': DFA, 'fEI':fEI, 'BiS': BiS, 'freq_bank':freq_bank})
        t12=time.time()
        print('\tTime spent:', round((t12-t11)/60,1), ' mins.' )
        del data, freq_bank, wPLI, cPLV, DFA, fEI, BiS


t2 = time.time()    
print(f"Finished at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(t2))}")
print('Total time spent:', round((t2-t1)/60,1), ' mins.' )

	Orig  data:  (200, 610000) . If original data is longer than 12, it will be truncated to 720000 samples.




	Time spent: 1.6  mins.
	Orig  data:  (200, 598000) . If original data is longer than 12, it will be truncated to 720000 samples.
	Time spent: 1.7  mins.
	Orig  data:  (200, 604000) . If original data is longer than 12, it will be truncated to 720000 samples.
	Time spent: 1.6  mins.
	Orig  data:  (200, 604000) . If original data is longer than 12, it will be truncated to 720000 samples.
	Time spent: 1.7  mins.
	Orig  data:  (200, 628000) . If original data is longer than 12, it will be truncated to 720000 samples.


OutOfMemoryError: Out of memory allocating 960,000,000 bytes (allocated so far: 12,915,623,936 bytes).

In [18]:
(smple_CAP + 10000)/1000

730.0