# 03_compute_bandSNR

In [None]:
import numpy as np
from os.path import join as pjoin
from os.path import isdir
import os
import matplotlib.pyplot as plt
from matplotlib import cm, colors
import mne_bids
import mne
from mne_bids import write_raw_bids, BIDSPath
from scipy import stats
import re
from scipy import signal
import pandas as pd
from scipy import signal, fftpack

## define functions

In [None]:
def load_sub_raw_data(data_folder='/nfs/e5/studyforrest/forrest_movie_meg/', subject_idx='01', run_idx='01'):
    """
    Load one raw CTF MEG recording of the movie session into memory.

    Parameters
    ----------
    data_folder : str
        Root folder of the raw dataset; the recording is looked up under
        ``sub-<subject_idx>/ses-movie/meg`` inside it. A trailing path
        separator is optional.
    subject_idx : str
        Subject label, e.g. ``'01'``.
    run_idx : str
        Run label, e.g. ``'01'``.

    Returns
    -------
    raw_data
        The object returned by ``mne.io.read_raw_ctf``, preloaded into memory.

    Raises
    ------
    ValueError
        If ``subject_idx`` or ``run_idx`` is not a str.
    """
    if not isinstance(subject_idx, str):
        raise ValueError('subject_idx must be str')

    if not isinstance(run_idx, str):
        raise ValueError('run_idx must be str')

    # pjoin works whether or not data_folder ends with a separator
    # (plain string concatenation required the trailing '/').
    subject_data_folder = pjoin(data_folder, 'sub-' + subject_idx, 'ses-movie', 'meg')
    fname = 'sub-' + subject_idx + '_ses-movie_task-movie_run-' + run_idx + '_meg.ds'
    raw_data_path = pjoin(subject_data_folder, fname)
    # preload must be the bool True: MNE interprets a str preload as the
    # filename of a memory-mapped file, so preload='True' wrote a file named
    # "True" to the working directory instead of loading into RAM.
    raw_data = mne.io.read_raw_ctf(raw_data_path, preload=True)

    print('total channels number is {}'.format(len(raw_data.info['chs'])))
    print('sample frequency is {} Hz'.format(raw_data.info['sfreq']))

    return raw_data

def fALFF(data, fs, f_range='band', bands=None):
    """
    Compute the fraction of spectral power falling in each frequency band.

    The signal is linearly detrended along the last axis. Its power spectral
    density is then estimated with ``scipy.signal.welch`` (default settings).
    The power summed inside each band is divided by the power summed over
    all frequencies strictly below Nyquist (``0.5 * fs``).

    Parameters
    ----------
    data : array_like, shape (..., n_samples)
        Time series with time on the last axis; for MEG data this is
        (n_channels, n_samples).
    fs : float
        Sampling frequency in Hz.
    f_range : str
        Only ``'band'`` is supported.
    bands : sequence of (low, high) pairs, optional
        Band edges in Hz. Each band keeps the frequencies strictly between
        its edges. Defaults to delta 1-4, theta 4-8, alpha 8-13, beta 13-30
        and gamma 30-100 Hz, in that order.

    Returns
    -------
    falff : ndarray, shape (..., n_bands)
        One ratio per band for every leading index of ``data``.

    Raises
    ------
    ValueError
        If ``f_range`` is not ``'band'``.
    """
    if f_range != 'band':
        # Previously fell through to `return falff` with falff unbound.
        raise ValueError("f_range must be 'band', got {!r}".format(f_range))
    if bands is None:
        bands = [[1, 4], [4, 8], [8, 13], [13, 30], [30, 100]]

    # remove linear trend, then convert to the frequency domain
    data_detrend = signal.detrend(data, axis=-1)
    freqs, psd = signal.welch(data_detrend, fs=fs)

    # Denominator is the same for every band; the Nyquist bin itself is excluded.
    total_power = np.sum(psd[..., freqs < 0.5 * fs], axis=-1)
    # Open intervals: a bin lying exactly on an edge is counted in neither band.
    falff = [np.sum(psd[..., (freqs > low) & (freqs < high)], axis=-1) / total_power
             for low, high in bands]
    return np.stack(falff, axis=-1)

def load_post_megdata(sub='02', run=1, bids_root='/nfs/e5/studyforrest/forrest_movie_meg/preproc_data'):
    """
    Load preprocessed MEG data for one run, taken at the annotation event samples.

    Parameters
    ----------
    sub : str
        BIDS subject label, e.g. ``'02'``.
    run : int or str
        Run number; converted with ``int()``.
    bids_root : str
        Root of the preprocessed BIDS dataset.

    Returns
    -------
    sub_data : ndarray, shape (n_channels, n_events)
        Data of the MEG-type channels whose names match ``'M[LRZ]...-4503'``.
        One column is taken at the sample of each event derived from the
        run's annotations.
    """
    sub_path = BIDSPath(subject=sub, run=int(run), task='movie', session='movie',
                        processing='preproc', root=bids_root)
    raw_sub = mne_bids.read_raw_bids(sub_path)

    ch_name_picks = mne.pick_channels_regexp(raw_sub.ch_names, regexp='M[LRZ]...-4503')
    type_picks = mne.pick_types(raw_sub.info, meg=True)
    sub_picks = np.intersect1d(ch_name_picks, type_picks)
    sub_raw_data = raw_sub.get_data(picks=sub_picks)

    events, _ = mne.events_from_annotations(raw_sub)
    # Event sample numbers are absolute (they include raw.first_samp), while
    # the get_data() array is indexed from 0; shift them before indexing.
    # This is a no-op when first_samp == 0.
    sample_idx = events[:, 0] - raw_sub.first_samp

    return sub_raw_data.take(sample_idx, axis=1)

def load_pre_megdata(sub='02', run=1, bids_root='/nfs/e5/studyforrest/forrest_movie_meg/preproc_data'):
    """
    Load raw (unpreprocessed) MEG data for one run, taken at the event samples
    defined by the annotations of the corresponding preprocessed run.

    Parameters
    ----------
    sub : str
        BIDS subject label, e.g. ``'02'``.
    run : int or str
        Run number; converted with ``int()``.
    bids_root : str
        Root of the preprocessed BIDS dataset that supplies the annotations.

    Returns
    -------
    sub_data : ndarray, shape (n_channels, n_events)
        Raw data of the MEG-type channels whose names match
        ``'M[LRZ]...-4503'``, one column per annotation event.
    """
    sub_path = BIDSPath(subject=sub, run=int(run), task='movie', session='movie',
                        processing='preproc', root=bids_root)
    raw_sub_ev = mne_bids.read_raw_bids(sub_path)
    # Zero-pad to two digits. Works for int or str run and for run >= 10;
    # '0' + str(run) gave '010' for run 10 and '001' for run='01'.
    raw_sub = load_sub_raw_data(subject_idx=sub, run_idx='{:02d}'.format(int(run)))

    ch_name_picks = mne.pick_channels_regexp(raw_sub.ch_names, regexp='M[LRZ]...-4503')
    type_picks = mne.pick_types(raw_sub.info, meg=True)
    sub_picks = np.intersect1d(ch_name_picks, type_picks)
    sub_raw_data = raw_sub.get_data(picks=sub_picks)

    events, _ = mne.events_from_annotations(raw_sub_ev)
    # Convert absolute event samples into indices of raw_sub's 0-based data array.
    # NOTE(review): like the original indexing, this assumes the preprocessed
    # run shares the raw recording's sample timeline (same sfreq, no
    # resampling) — confirm.
    sample_idx = events[:, 0] - raw_sub.first_samp

    return sub_raw_data.take(sample_idx, axis=1)

## define variables

In [None]:
# Zero-padded subject labels '01'..'11' and run labels '01'..'08'.
sub_list = list(map('{0:0>2d}'.format, range(1, 12)))
run_list = list(map('{0:0>2d}'.format, range(1, 9)))
# Band names, in the order of fALFF's default bands (delta .. gamma).
band = ['delta', 'theta', 'alpha', 'beta', 'gamma']
# change path
data_pth = '/nfs/s2/userhome/daiyuxuan/workingdir/MEG-paper/output_data'
# NOTE(review): load_post_megdata / load_pre_megdata do not read this
# module-level variable; they use their own path.
bids_root = '/nfs/e2/workingshop/daiyuxuan/MEG-paper/preproc_data'

## compute band fALFF

In [None]:
# compute falff
post_falff_data = {}
pre_falff_data ={}
for sub in sub_list:
    post_falff_data[sub] = []
    pre_falff_data[sub] = []
    if sub == '01':
        run_ls = run_list + ['09']
    else:
        run_ls = run_list
    for run in run_ls:
        post_raw_data = load_post_megdata(sub=sub, run=int(run))
        post_falff = fALFF(post_raw_data, fs=600)
        pre_raw_data = load_pre_megdata(sub=sub, run=int(run))
        pre_falff = fALFF(pre_raw_data, fs=600)
        post_falff_data[sub].append(post_falff)
        pre_falff_data[sub].append(pre_falff)

# save falff
for sub in sub_list[1:]:
    pre_falff_data[sub].append(np.nan)
    post_falff_data[sub].append(np.nan)
pre_df = pd.DataFrame(pre_falff_data, columns=sub_list, index=run_list+['09'])
post_df = pd.DataFrame(post_falff_data, columns=sub_list, index=run_list+['09'])
pre_df.to_pickle(pjoin(data_pth, 'pre_falff_data.pickle'))
post_df.to_pickle(pjoin(data_pth, 'post_falff_data.pickle'))