In [1]:
from scipy.io import loadmat
from scipy.io import savemat 
import os
import numpy as np
import matplotlib.pyplot as plt
from alphacsc import GreedyCDL
from alphacsc.utils import split_signal
from scipy import signal
from scipy.signal import butter, filtfilt
from matplotlib import gridspec
import time
from alphacsc import BatchCDL
from pathlib import Path

def butter_lowpass(data, cutoff, fs, order=5):
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = butter(order, normal_cutoff, btype='low')
    y = filtfilt(b, a, data)
    return y

def butter_highpass(data, cutoff, fs, order=5):
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = butter(order, normal_cutoff, btype='high')
    y = filtfilt(b, a, data)
    return y

In [2]:
!ls /mnt/megshare/Development/Sheng/__dump/20241126_aes_poster_figures/

#################################################################
## all parameters are set here
#################################################################
ori_fs             = 1000 # Original sampling rate (Hz)
decimate           = 10   # dec factor for analysis data
HI_PASS            = 1    #Hz 
LOW_PASS           = 30   #Hz
fs                 = ori_fs // decimate  # decimated signal fs
n_atoms            = 30
ATOM_DURATION      = 3    # seconds
SPLIT_BLOCKS       = 8    # Spliting a long resting into x blocks
N_JOBS             = 80   # cdl jobs
NORMALIZE          = True # normalize the parcel time series before CDL

MAT_FILE_WRITEOUT  = Path('/mnt/megshare/Development/Sheng/__dump/20241126_aes_poster_figures/')
MAT_FILE_WRITEOUT  = MAT_FILE_WRITEOUT / 'WHOLE_COHORT_ATOMS_30x3s_1-30Hz'
if not MAT_FILE_WRITEOUT.exists():
    MAT_FILE_WRITEOUT.mkdir()


SUBJ_DIR  = Path('/data/sheng/MEG_source_data/')
SUBJS     = [d.name for d in SUBJ_DIR.iterdir() if d.is_dir() and d.name.startswith('case_')]

if not MAT_FILE_WRITEOUT.exists():
    MAT_FILE_WRITEOUT.mkdir(parent=True)

__old				     tmp
png_WHOLE_COHORT_ATOMS_30x3s_1-30Hz  WHOLE_COHORT_ATOMS_30x3s_1-30Hz
subjects.txt			     WHOLE_COHORT_CROSSPY_FEATURES


In [3]:
t1 = time.time()
# for all subjects ##############################################################################
for i, subj in enumerate(SUBJS):
    SETS  = os.listdir(SUBJ_DIR / subj)
    
    t11=time.time()
    # for all spont sessions of the subj ####################
    for j, iSet in enumerate(SETS): 
                        
        subjPath = Path(SUBJ_DIR / subj / iSet / 'parcel_ts')
        filePath = list(subjPath.glob("*.npy"))[0]       
        
        #filePath = SUBJ_DIR / subj / iSet
        parcelTS     = np.load(filePath) # data: channels x samples
        PARC_N       = parcelTS.shape[0]
        SAMPLE_N     = signal.decimate(parcelTS[0], decimate).shape[0]
        bpTS         = np.zeros([PARC_N, SAMPLE_N])

        # band-pass the data, and decimate if need be #######################
        for iPARC in np.arange(0,PARC_N):
            aParcelTS            = parcelTS[iPARC]    
            lowpass_data         = butter_lowpass(aParcelTS, LOW_PASS, ori_fs)
            highpass_data        = butter_highpass(lowpass_data, HI_PASS, ori_fs)
            bpTS[iPARC, :]       = signal.decimate(highpass_data, decimate)
            if NORMALIZE:            
                iMEAN            = np.mean(bpTS[iPARC, :])
                iSTD             = np.std(bpTS[iPARC, ])            
                bpTS[iPARC,]     = (bpTS[iPARC, :]- iMEAN) / iSTD            

        del lowpass_data, highpass_data
        print('Band-pass done, now fitting: '+subj+ '-' + iSet + '...')


        # Split a long trial into x blocks
        split_parcTS = split_signal(bpTS[None], SPLIT_BLOCKS)

        # Define the shape of the dictionary
        n_times_atom = int(round(fs * ATOM_DURATION)) 
        print('Data shape:', split_parcTS.shape, '; N of Atoms: ',  n_atoms, '; Sample per atom: ', n_times_atom)


        cdl = BatchCDL(
            # Shape of the dictionary
            n_atoms=n_atoms,
            n_times_atom=n_times_atom,
            # Request a rank1 dictionary with unit norm temporal and spatial maps
            rank1=True, uv_constraint='separate',
            # Initialize the dictionary with random chunk from the data
            D_init='chunk',
            # rescale the regularization parameter to be 20% of lambda_max
            lmbd_max="scaled", reg=.2,
            # Number of iteration for the alternate minimization and cvg threshold
            n_iter=100, eps=1e-4,
            # solver for the z-step
            solver_z="lgcd", solver_z_kwargs={'tol': 1e-2, 'max_iter': 1000},
            # solver for the d-step
            solver_d='alternate_adaptive', solver_d_kwargs={'max_iter': 300},
            # Technical parameters
            verbose=1, random_state=0, n_jobs=N_JOBS)


        ######################################### FIT IT   
        cdl.fit(split_parcTS)
        del split_parcTS   
        ######################################### FIT IT   

        u_hat = cdl.u_hat_ # spatial
        v_hat = cdl.v_hat_ # temporal
        z_hat = cdl.z_hat_ # scores        
       
        # Output name set to 'case_xxxx_y_setzz_...'  ======================================
        
        SUBJ = filePath.parent.parent.parent.stem
        iSet  = filePath.parent.parent.stem
        
        if len(SUBJ) == 9: 
            SUBJ = SUBJ + '_1'
        SET   = 'set' + (f"{int(iSet[3:]):02d}")        
        
        FILENAME = (SUBJ + '_' + SET + '_BB_' + str(HI_PASS) + 
                    '-' + str(LOW_PASS) + 'Hz_' + str(n_atoms) + 
                    '_Atoms_' + str(ATOM_DURATION) + '(s)'+ '_NormParcTS(' + str(NORMALIZE)+ ').mat')
        MAT_FILE = MAT_FILE_WRITEOUT / FILENAME
        
        savemat(MAT_FILE, {"u_hat": u_hat, "v_hat": v_hat, 'z_hat': z_hat, 'fs': fs})
        print('\tWriting: ' + FILENAME)  
        t12=time.time()
        print('\tTime spent:', round((t12-t11)/60,1), ' mins.' )



t2 = time.time()    
print('\nN of atoms: ', n_atoms, ', Num of Jobs: ', N_JOBS)
print('Low: ', LOW_PASS, 'Hz, High:', HI_PASS, 'Hz')
print(f"Finished at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(t2))}")
print('Total time spent:', round((t2-t1)/60,1), ' mins.' )

Band-pass done, now fitting: case_0307-set1...
Data shape: (8, 200, 7625) ; N of Atoms:  30 ; Sample per atom:  300
.............
[BatchCDL] Converged after 13 iteration, (dz, du) = 7.812e-05, 8.515e-05
[BatchCDL] Fit in 772.1s
	Writing: case_0307_1_set01_BB_1-30Hz_30_Atoms_3(s)_NormParcTS(True).mat
	Time spent: 13.0  mins.
Band-pass done, now fitting: case_0307-set2...
Data shape: (8, 200, 7475) ; N of Atoms:  30 ; Sample per atom:  300
.....................
[BatchCDL] Converged after 21 iteration, (dz, du) = 6.718e-05, 9.471e-05
[BatchCDL] Fit in 1166.7s
	Writing: case_0307_1_set02_BB_1-30Hz_30_Atoms_3(s)_NormParcTS(True).mat
	Time spent: 32.6  mins.
Band-pass done, now fitting: case_0307-set3...
Data shape: (8, 200, 7550) ; N of Atoms:  30 ; Sample per atom:  300
..............



.......
[BatchCDL] Converged after 21 iteration, (dz, du) = 4.918e-05, 8.149e-05
[BatchCDL] Fit in 1105.0s
	Writing: case_0307_1_set03_BB_1-30Hz_30_Atoms_3(s)_NormParcTS(True).mat
	Time spent: 51.2  mins.
Band-pass done, now fitting: case_0307-set4...
Data shape: (8, 200, 7550) ; N of Atoms:  30 ; Sample per atom:  300
...................
[BatchCDL] Converged after 19 iteration, (dz, du) = 3.742e-05, 8.307e-05
[BatchCDL] Fit in 1036.9s
	Writing: case_0307_1_set04_BB_1-30Hz_30_Atoms_3(s)_NormParcTS(True).mat
	Time spent: 68.6  mins.
Band-pass done, now fitting: case_0307-set5...
Data shape: (8, 200, 7850) ; N of Atoms:  30 ; Sample per atom:  300
................
[BatchCDL] Converged after 16 iteration, (dz, du) = 4.909e-05, 8.768e-05
[BatchCDL] Fit in 998.6s
	Writing: case_0307_1_set05_BB_1-30Hz_30_Atoms_3(s)_NormParcTS(True).mat
	Time spent: 85.4  mins.
Band-pass done, now fitting: case_0307-set6...
Data shape: (8, 200, 7637) ; N of Atoms:  30 ; Sample per atom:  300
.............
[Ba

..............
[BatchCDL] Converged after 14 iteration, (dz, du) = 7.546e-05, 8.846e-05
[BatchCDL] Fit in 814.6s
	Writing: case_0420_1_set03_BB_1-30Hz_30_Atoms_3(s)_NormParcTS(True).mat
	Time spent: 13.7  mins.
Band-pass done, now fitting: case_0390_0-set3...
Data shape: (8, 200, 7537) ; N of Atoms:  30 ; Sample per atom:  300
.............
[BatchCDL] Converged after 13 iteration, (dz, du) = 9.514e-05, 9.923e-05
[BatchCDL] Fit in 799.7s
	Writing: case_0390_0_set03_BB_1-30Hz_30_Atoms_3(s)_NormParcTS(True).mat
	Time spent: 13.5  mins.
Band-pass done, now fitting: case_0390_0-set2...
Data shape: (8, 200, 7512) ; N of Atoms:  30 ; Sample per atom:  300
......................
[BatchCDL] Converged after 22 iteration, (dz, du) = 6.665e-05, 6.256e-05
[BatchCDL] Fit in 1339.3s
	Writing: case_0390_0_set02_BB_1-30Hz_30_Atoms_3(s)_NormParcTS(True).mat
	Time spent: 35.9  mins.
Band-pass done, now fitting: case_0390_0-set4...
Data shape: (8, 200, 7662) ; N of Atoms:  30 ; Sample per atom:  300
.....

..........
[BatchCDL] Converged after 10 iteration, (dz, du) = 9.678e-05, 9.980e-05
[BatchCDL] Fit in 799.6s
	Writing: case_0037_1_set04_BB_1-30Hz_30_Atoms_3(s)_NormParcTS(True).mat
	Time spent: 13.6  mins.
Band-pass done, now fitting: case_0059-set1...
Data shape: (8, 200, 15162) ; N of Atoms:  30 ; Sample per atom:  300
........
[BatchCDL] Converged after 8 iteration, (dz, du) = 9.626e-05, 9.395e-05
[BatchCDL] Fit in 606.6s
	Writing: case_0059_1_set01_BB_1-30Hz_30_Atoms_3(s)_NormParcTS(True).mat
	Time spent: 10.4  mins.
Band-pass done, now fitting: case_0065-set2...
Data shape: (8, 200, 6087) ; N of Atoms:  30 ; Sample per atom:  300
....................
[BatchCDL] Converged after 20 iteration, (dz, du) = 5.430e-05, 9.915e-05
[BatchCDL] Fit in 1108.3s
	Writing: case_0065_1_set02_BB_1-30Hz_30_Atoms_3(s)_NormParcTS(True).mat
	Time spent: 18.6  mins.
Band-pass done, now fitting: case_0074-set1...
Data shape: (8, 200, 11300) ; N of Atoms:  30 ; Sample per atom:  300
..............
[Batch