In [None]:
import glob, os, mne, mne_features, mne_connectivity
import numpy as np
import pandas as pd

# Electrode labels; per-channel feature names are built from these.
CHANNEL_NAMES = [
    'AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1',
    'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4',
]
n_channels = len(CHANNEL_NAMES)

# Sampling frequency handed to the feature extractors (presumably in Hz).
sfreq = 128

# Frequency bands as [lower edge, upper edge] pairs, keyed by band name.
freq_bands = {
    # 'delta': (0.5, 4), # delta has been filtered out
    'theta': [4, 8],
    'alpha-1': [8, 10],
    'alpha-2': [10, 13],
    'beta': [13, 30],
    'gamma': [30, 45],  # above 45 has been filtered out
}
n_bands = len(freq_bands)
freq_band_names = [*freq_bands]
min_freqs = [low for low, _high in freq_bands.values()]
max_freqs = [high for _low, high in freq_bands.values()]

# The trailing digit of a label encodes the hemisphere: odd -> left,
# even -> right. Both lists are sorted alphabetically.
left_channels = sorted(
    name for name in CHANNEL_NAMES if int(name[-1]) % 2)
right_channels = sorted(
    name for name in CHANNEL_NAMES if not int(name[-1]) % 2)

# Output directory for feature files; created up front if it is missing.
out_dir = "../data/features"
os.makedirs(out_dir, exist_ok=True)

# Every epochs file matching <segment dir>/clean/<name>-epo.fif.
paths_to_subject_epochs = glob.glob("../data/segmented/*/clean/*-epo.fif")

# Group the files by segment length. The length is parsed from the name of
# the directory two levels above each file, with 's' characters stripped
# from both ends (e.g. "10s" -> 10).
paths_by_segment_length = {}
for path in paths_to_subject_epochs:
    clean_dir = os.path.dirname(path)
    segment_dir_name = os.path.basename(os.path.dirname(clean_dir))
    segment_length = int(segment_dir_name.strip('s'))
    paths_by_segment_length.setdefault(segment_length, []).append(path)

print(paths_by_segment_length)


def _time_domain_features(batch):
    """Return per-channel time-domain features for a one-epoch batch.

    Keys have the form ``time_<function name>_<channel name>``.
    """
    funcs = ['hjorth_mobility', 'hjorth_complexity', 'variance',
             'app_entropy', 'line_length', 'skewness', 'kurtosis', 'rms',
             'decorr_time', 'higuchi_fd', 'katz_fd', 'samp_entropy',
             'hurst_exp']

    values = mne_features.feature_extraction.extract_features(
        batch, sfreq, funcs, n_jobs=1)
    # Interpreted as one row per function and one column per channel.
    values = values.reshape((len(funcs), n_channels))

    return {
        f'time_{func_name}_{ch_name}': ch_val
        for func_name, per_channel in zip(funcs, values)
        for ch_val, ch_name in zip(per_channel, CHANNEL_NAMES)
    }


def _band_power_features(signal):
    """Return absolute and relative band power for every channel and band.

    Keys have the form ``abs_pow_<band>_<channel>`` (not normalized) and
    ``rel_pow_<band>_<channel>`` (normalized).
    """
    # Band edges as one increasing array: every lower edge followed by the
    # upper edge of the gamma band.
    band_edges = np.asanyarray([*min_freqs, freq_bands['gamma'][1]])
    num_powers = n_channels * n_bands

    feats = {}
    for normalize in (False, True):
        prefix = 'rel' if normalize else 'abs'

        powers_and_ratios = mne_features.univariate.compute_pow_freq_bands(
            sfreq, signal,
            band_edges, normalize=normalize, ratios='all',
            psd_method='welch',
            psd_params={'welch_n_overlap': sfreq // 2})

        # Powers and ratios come back in a single flat array; the first
        # n_channels * n_bands entries are the powers, reshaped here to
        # (n_channels, n_bands).
        powers = powers_and_ratios[:num_powers].reshape((n_channels, -1))

        # TODO also save power ratios (powers_and_ratios[num_powers:])

        for ch_name, ch_powers in zip(CHANNEL_NAMES, powers):
            for band_name, band_pow in zip(freq_band_names, ch_powers):
                feats[f'{prefix}_pow_{band_name}_{ch_name}'] = band_pow

    return feats


def _connectivity_features(batch):
    """Return band-averaged PLI connectivity for every channel pair.

    Keys have the form ``conn_<band>_<channel A>_<channel B>`` where
    channel B precedes channel A in ``CHANNEL_NAMES``.
    """
    min_freq = min(min_freqs)
    max_freq = max(max_freqs)

    # Four steps per unit of frequency across the whole analysed range.
    freqs = np.linspace(min_freq, max_freq,
                        int((max_freq - min_freq) * 4 + 1))

    res = mne_connectivity.spectral_connectivity_time(
        batch, freqs=freqs, method="pli", sfreq=sfreq, mode="cwt_morlet",
        fmin=min_freqs, fmax=max_freqs, faverage=True).get_data()

    # Only one epoch was passed in; unravel its connectivity into a
    # (band, channel, channel) array.
    matrix = res[0].reshape((n_channels, n_channels, n_bands))
    matrix = np.moveaxis(matrix, 2, 0)

    feats = {}
    for band_name, band in zip(freq_band_names, matrix):
        for row_idx, row in enumerate(band):
            # Strictly-lower triangle only: each unordered pair once and
            # no self-connections.
            for col_idx, value in enumerate(row[:row_idx]):
                key = (f'conn_{band_name}_{CHANNEL_NAMES[row_idx]}'
                       f'_{CHANNEL_NAMES[col_idx]}')
                feats[key] = value

    return feats


def _asymmetry_features(features):
    """Return log-power asymmetry indices from absolute band powers.

    ``features`` must already contain the ``abs_pow_<band>_<channel>``
    entries. Keys have the form ``ai_<band>_<right channel>-<left channel>``.
    """
    feats = {}
    for band_name in freq_band_names:
        # Pairs are formed positionally from the two sorted lists.
        # NOTE(review): this assumes the sorted left and right lists line
        # up as homologous electrodes - confirm if channels change.
        for left_ch_name, right_ch_name in zip(left_channels,
                                               right_channels):
            left_abs_pow = features[f'abs_pow_{band_name}_{left_ch_name}']
            right_abs_pow = features[f'abs_pow_{band_name}_{right_ch_name}']

            key = f'ai_{band_name}_{right_ch_name}-{left_ch_name}'
            feats[key] = np.log(right_abs_pow) - np.log(left_abs_pow)

    return feats


def get_epoch_features(epoch):
    """Extract all features of a single epoch into a flat dict.

    Four feature families are computed, in this order: time-domain
    statistics, absolute/relative band power, PLI connectivity between
    channel pairs, and left/right asymmetry indices of the absolute band
    power.

    Parameters
    ----------
    epoch : array-like
        Data of one epoch. Assumed to be 2-D (n_channels, n_times) with
        channels ordered as in ``CHANNEL_NAMES`` - TODO confirm.

    Returns
    -------
    dict
        Maps feature name to its scalar value.
    """
    # The extractors that work on batches expect a leading epoch axis.
    batch = np.expand_dims(epoch, axis=0)

    d = {}
    d.update(_time_domain_features(batch))
    d.update(_band_power_features(batch[0]))
    d.update(_connectivity_features(batch))
    # Must come last: it reads the abs_pow_* entries added above.
    d.update(_asymmetry_features(d))

    return d

def get_features_of_all_epochs(path):
    """Build the feature table of one epochs file.

    Parameters
    ----------
    path : str
        Path of an epochs file readable by ``mne.read_epochs``. Its metadata
        must provide the columns ``subject`` and ``dataset``.

    Returns
    -------
    pandas.DataFrame
        One row per epoch: every feature returned by ``get_epoch_features``
        followed by the columns ``subject``, ``dataset`` and
        ``uniq_subject_id``. Empty if the file contains no epochs.
    """
    epochs = mne.read_epochs(path)

    epoch_dicts = [get_epoch_features(epoch) for epoch in epochs]
    if not epoch_dicts:
        # Nothing to describe; avoid indexing into empty containers below.
        return pd.DataFrame()

    # Positional access: the metadata index labels are not guaranteed to
    # start at 0, so a label lookup with [0] could raise KeyError.
    subject_id = epochs.metadata['subject'].iloc[0]
    dataset = epochs.metadata['dataset'].iloc[0]
    # sam_label = epochs.metadata[SAM][0]

    df = pd.DataFrame(epoch_dicts)
    # The subject and dataset of the first epoch are applied to every row.
    df['subject'] = subject_id
    df['dataset'] = dataset
    df['uniq_subject_id'] = f'{dataset}_{subject_id}'

    return df

# Build a feature table for each segment length.
# NOTE: `paths[:1]` limits every segment length to its first epochs file,
# and `df` is rebound on each iteration, so only the table of the last
# segment length is left after the loop.
for seglen, paths in paths_by_segment_length.items():
    selected_paths = paths[:1]
    all_dfs = list(map(get_features_of_all_epochs, selected_paths))

    df = pd.concat(all_dfs)

    # Saving is currently disabled:
    # df.to_csv(os.path.join(out_dir, os.path.basename(path).replace("-epo.fif", "-features.csv")))

# Notebook display of the last table built above.
df

{10: ['../data/segmented/10s/clean/S10-epo.fif', '../data/segmented/10s/clean/S19-epo.fif', '../data/segmented/10s/clean/S09-epo.fif', '../data/segmented/10s/clean/S08-epo.fif', '../data/segmented/10s/clean/S18-epo.fif', '../data/segmented/10s/clean/S11-epo.fif', '../data/segmented/10s/clean/S01-epo.fif', '../data/segmented/10s/clean/S03-epo.fif', '../data/segmented/10s/clean/S13-epo.fif', '../data/segmented/10s/clean/S12-epo.fif', '../data/segmented/10s/clean/S02-epo.fif', '../data/segmented/10s/clean/S21-epo.fif', '../data/segmented/10s/clean/S07-epo.fif', '../data/segmented/10s/clean/S17-epo.fif', '../data/segmented/10s/clean/S16-epo.fif', '../data/segmented/10s/clean/S06-epo.fif', '../data/segmented/10s/clean/S20-epo.fif', '../data/segmented/10s/clean/S04-epo.fif', '../data/segmented/10s/clean/S14-epo.fif', '../data/segmented/10s/clean/S22-epo.fif', '../data/segmented/10s/clean/S23-epo.fif', '../data/segmented/10s/clean/S15-epo.fif', '../data/segmented/10s/clean/S05-epo.fif']}
Read

Unnamed: 0,time_hjorth_mobility_AF3,time_hjorth_mobility_F7,time_hjorth_mobility_F3,time_hjorth_mobility_FC5,time_hjorth_mobility_T7,time_hjorth_mobility_P7,time_hjorth_mobility_O1,time_hjorth_mobility_O2,time_hjorth_mobility_P8,time_hjorth_mobility_T8,...,time_hurst_exp_O2,time_hurst_exp_P8,time_hurst_exp_T8,time_hurst_exp_FC6,time_hurst_exp_F4,time_hurst_exp_F8,time_hurst_exp_AF4,subject,dataset,uniq_subject_id
0,0.61511,0.634856,0.62586,0.630724,0.615947,0.620682,0.62533,0.608927,0.613778,0.670879,...,0.135478,0.13701,-0.003604,0.163226,0.171551,0.174824,0.177105,10,dasps,dasps_10
1,0.581446,0.544871,0.55845,0.547877,0.539853,0.551864,0.592275,0.58977,0.594629,0.600414,...,0.218744,0.218784,0.208438,0.231207,0.154762,0.223394,0.223604,10,dasps,dasps_10
2,0.701578,0.636604,0.540489,0.686146,0.699415,0.702145,0.692845,0.657938,0.65772,0.677678,...,0.179952,0.184519,0.184963,0.208377,0.154433,0.243906,0.230809,10,dasps,dasps_10
3,0.735243,0.74243,0.705117,0.74621,0.76778,0.797493,0.734185,0.69454,0.690328,0.771464,...,0.193307,0.198997,0.216275,0.251171,0.205986,0.234003,0.20659,10,dasps,dasps_10
4,0.698854,0.707288,0.640635,0.73481,0.719751,0.711288,0.705523,0.690098,0.691115,0.743435,...,0.207603,0.198652,0.232827,0.211135,0.210614,0.254149,0.227354,10,dasps,dasps_10
5,0.64898,0.676992,0.640052,0.638089,0.651303,0.656697,0.633258,0.643304,0.639485,0.620999,...,0.179287,0.182193,0.195922,0.180362,0.173291,0.162677,0.190254,10,dasps,dasps_10
6,0.683304,0.759598,0.654954,0.709449,0.700255,0.599925,0.701177,0.67139,0.754927,0.642605,...,0.193458,0.232715,0.218477,0.188922,0.213262,0.241997,0.200184,10,dasps,dasps_10
7,0.653964,0.673287,0.625053,0.654068,0.727238,0.716943,0.594269,0.592052,0.572739,0.653758,...,0.132205,0.163151,0.13285,0.16403,0.148031,0.144888,0.131327,10,dasps,dasps_10
8,0.716858,0.764018,0.715415,0.680804,0.795422,0.790644,0.7156,0.6654,0.77888,0.783217,...,0.218633,0.188085,0.201056,0.17566,0.197989,0.194696,0.209038,10,dasps,dasps_10
9,0.625547,0.658742,0.684721,0.675292,0.702134,0.683148,0.704169,0.662005,0.657603,0.675027,...,0.15357,0.174612,0.169274,0.19926,0.169399,0.173951,0.2037,10,dasps,dasps_10
