In [1]:
import mne
import numpy as np
import pandas as pd

"""
Additional filtering is not required as the data is already preprocessed.
"""


def load_eeg_data(file_path):
    """
    Load EEG data from a csv file and separate data and label.
    :param file_path: File path of the EEG data
    :return: EEG data (DataFrame), label
    """
    data_src = pd.read_csv(file_path)
    data = data_src.iloc[:, :-1]  # Exclude the last column as it is a label
    label = data_src.iloc[:, -1]  # Use the last column as a label
    return data, label


def compute_band_power(raw, band):
    """
    Compute the power in a specific frequency band.
    :param raw: MNE Raw object
    :param band: Frequency band of interest (tuple)
    :return: Power in the frequency band
    """
    fmin, fmax = band  # Setting frequency band
    data = raw.get_data()
    sfreq = raw.info['sfreq']
    psds, freqs = mne.time_frequency.psd_array_welch(data, sfreq=sfreq, fmin=fmin, fmax=fmax, n_fft=128)  # Compute PSD
    # Compute power in the frequency band
    band_power = np.sum(psds, axis=-1)
    return band_power


def extract_features(data, selected_columns, sfreq=250):
    """
    Extract features from EEG data. Furthermore, the data is downsampled to the target sampling frequency.
    :param data: EEG data (DataFrame)
    :param selected_columns: List of tuples containing channel index and frequency bands
    :param sfreq: Sampling frequency of the data
    :param target_sfreq: Target sampling frequency
    :return: Extracted features (DataFrame)
    """
    feature_dict = {}  # 결과를 저장할 딕셔너리

    for item in selected_columns:
        channel_idx = item[0]  # 채널 인덱스
        bands = item[1]  # 해당 채널에서 추출할 주파수 대역 리스트

        # 주파수 대역이 하나만 주어졌을 때도 리스트로 처리
        if isinstance(bands, tuple):
            bands = [bands]

        # 채널의 데이터 추출
        eeg_data = data.iloc[:, channel_idx].values  # 특정 채널의 데이터를 가져옴
        ch_name = data.columns[channel_idx]  # 채널 이름

        # mne RawArray 객체 생성
        info = mne.create_info(ch_names=[ch_name], sfreq=sfreq, ch_types='eeg')
        raw = mne.io.RawArray(eeg_data[np.newaxis, :], info)  # 2D array 필요

        # 주파수 대역별로 PSD 계산
        for band in bands:
            band_power = compute_band_power(raw, band)
            # 열 이름 생성 (예: Channel_1_10-12Hz)
            column_name = f'{ch_name}_{band[0]}-{band[1]}Hz'
            feature_dict[column_name] = band_power

    # 최종 데이터프레임 생성
    features = pd.DataFrame([feature_dict])

    return features


In [2]:
base_path = 'your_path/'

train_csv_path = base_path + 'train.csv'
val_csv_path = base_path + 'val.csv'
test_csv_path = base_path + 'test.csv'

train_json_path = base_path + 'json/train.json'
train_jsonl_path = base_path + 'jsonl/train.jsonl'

val_json_path = base_path + 'json/val.json'
val_jsonl_path = base_path + 'jsonl/val.jsonl'

csp_train_path = base_path + 'csp1/class_1_vs_5_train_features.csv'
csp_val_path = base_path + 'csp1/class_1_vs_5_val_features.csv'
csp_test_path = base_path + 'csp1/class_1_vs_5_test_features.csv'

In [3]:
from feature_extraction import *
dftrain, labeltrain = load_eeg_data(train_csv_path)
dfval, labelval = load_eeg_data(val_csv_path)
dftest, labeltest = load_eeg_data(test_csv_path)

In [4]:
csptrain, csptrainlabel = load_eeg_data(csp_train_path)
csptrain = csptrain.to_numpy()
cspval, cspvallabel = load_eeg_data(csp_val_path)
cspval = cspval.to_numpy()
csptest, csptestlabel = load_eeg_data(csp_test_path)
csptest = csptest.to_numpy()

In [5]:
selected_columns = [
        [0, [(10, 12), (12, 14)]],  # FCz
        [2, [(20, 22), (22, 24)]],  # C3
        [3, [(8, 10)]],  # Cz
        [4, [(20, 22), (22, 24)]],  # C4
        [5, [(28, 30)]],  # CP3
]

In [6]:
def csv_to_df(df, csp, window_size, selected_columns, labels):
    """
    Convert a DataFrame of EEG data into a JSON format suitable for GPT-3 davinci.
    :param df: Data converted to pandas DataFrame from the original csv file
    :param window_size: Window size to divide EEG data
    :param selected_columns: EEG channel to use (provide a list with frequency bands)
    :param labels: Label for each window (provide a list, left, right, top, bottom)
    :return: List of data in JSON format
    """
    df_array = pd.DataFrame()

    # EEG 채널 이름을 selected_columns에 매핑합니다.
    channel_names = ['FCz', 'C3', 'Cz', 'C4', 'CP3']  # 각각 0, 1, 2, 3에 대응

    for start in range(0, len(df) - window_size + 1, window_size):
        window_data = df.iloc[start:start + window_size, :]  # 전체 데이터를 가져옴
        label = str(int(labels[start]))  # Assuming labels are provided for each window

        # Extract features using the updated extract_features function
        features = extract_features(window_data, selected_columns)  # feature extraction
        cspdata = pd.DataFrame(csp[int(start / 1000)]).T  # cspdata 가져옴
        # features와 cspdata 가로 방향으로 합침
        features = pd.concat([features, cspdata], axis=1)
        # 라벨까지 합치기, 라벨의 column명은 'Label'
        features['Label'] = label
        
        # 최종 데이터프레임 생성
        df_array = pd.concat([df_array, features], axis=0)
    return df_array

In [7]:
train4ml = csv_to_df(dftrain, csptrain, 1000, selected_columns, labeltrain)
train4ml

Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ... 

Unnamed: 0,FCz_10-12Hz,FCz_12-14Hz,C3_20-22Hz,C3_22-24Hz,Cz_8-10Hz,C4_20-22Hz,C4_22-24Hz,CP3_28-30Hz,0,1,Label
0,[2.7786458594297107e-05],[6.534202001009796e-06],[1.0130472628096345e-06],[7.331753863608874e-07],[2.033718063081021e-05],[1.3523823903022104e-06],[2.3240075120569586e-06],[1.2259994892338753e-06],-0.499882,-0.347643,5
0,[1.687638705039316e-05],[4.3827373457254e-06],[9.55979451930757e-07],[1.0677627534386007e-06],[1.6969380046355913e-05],[1.7850318610953953e-06],[1.2976369343593518e-06],[6.354569600546507e-07],-1.252038,-1.280571,1
0,[3.086740134086977e-05],[4.555036777364775e-06],[3.408454339035358e-07],[3.7697757510279986e-07],[3.353152539149078e-05],[1.8730413037123893e-06],[9.872644173648066e-07],[9.009861479545202e-07],-0.210339,-0.917821,1
0,[8.294395509890532e-06],[3.1994345508118116e-06],[2.8525236138546658e-06],[1.7614793996029338e-06],[2.7604201323840356e-05],[2.8692586591252587e-06],[1.5233991921838082e-06],[5.197436272665464e-07],-1.033145,-1.020001,1
0,[2.8796302857775622e-05],[2.79121587431383e-06],[1.0993873514132316e-06],[5.265761523324297e-07],[5.51623457564977e-05],[8.681873662900441e-07],[9.399759088636949e-07],[6.713900908757874e-07],-1.383115,-1.159320,1
...,...,...,...,...,...,...,...,...,...,...,...
0,[4.106836324006936e-05],[8.49221565685968e-06],[1.6337379665270453e-06],[1.4220875052484921e-06],[4.1661031622684804e-05],[3.12121815541909e-06],[4.105707091704173e-06],[7.999951154682973e-07],-0.549031,-0.561252,5
0,[2.1344017035668287e-05],[4.692333794260802e-06],[3.206780156247017e-06],[2.220878334012285e-06],[2.536699689085539e-05],[1.3081547942775915e-06],[3.259443344494703e-06],[7.909686222949669e-07],-0.251428,-0.537172,5
0,[3.532781349045091e-05],[9.784195706176183e-06],[4.328542825192819e-06],[2.8208861463736062e-06],[1.6170708455558354e-05],[4.4633881144328095e-06],[4.859084725792112e-06],[9.90129051019647e-07],-0.897200,-0.646019,5
0,[4.137183638219365e-05],[5.8195706999590605e-06],[2.5163255718169557e-06],[2.2860233461796853e-06],[2.5903911954690413e-05],[2.5894529803656887e-06],[4.259077783324512e-06],[1.046113530762091e-06],-0.562427,-0.571097,5


In [8]:
val4ml = csv_to_df(dfval, cspval, 1000, selected_columns, labelval)
val4ml

Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ... 

Unnamed: 0,FCz_10-12Hz,FCz_12-14Hz,C3_20-22Hz,C3_22-24Hz,Cz_8-10Hz,C4_20-22Hz,C4_22-24Hz,CP3_28-30Hz,0,1,Label
0,[4.7137633160687955e-05],[5.0230693449175205e-06],[1.9081715545390816e-06],[1.0397121166532323e-06],[0.00011021609923310416],[2.3248146630928578e-06],[1.6004716150989847e-06],[2.478395657438407e-07],-0.536163,-0.296584,1
0,[1.5502622327514992e-05],[3.91455267154262e-06],[2.1267760303007706e-06],[8.895512590700392e-07],[2.3409828040922734e-05],[2.657699744174111e-06],[2.6624222355273737e-06],[7.387570955010971e-07],0.16409,-0.475559,5
0,[3.840787292777667e-05],[6.1432730029109604e-06],[1.6150044448154704e-06],[1.164924946620321e-06],[4.3027857731904437e-05],[2.078654274703208e-06],[1.4674055557814024e-06],[1.4221867935869383e-06],-0.386492,-0.763659,1
0,[4.136167296019675e-05],[3.609391069040703e-06],[1.616863698811523e-06],[1.2772183098950907e-06],[2.9417314737017237e-05],[7.386214084641678e-07],[6.074632952405682e-07],[3.4664719842652217e-07],-0.504133,-0.592341,1
0,[8.369013386129521e-05],[7.940288107743614e-06],[3.2052394052175197e-06],[1.2983923160776532e-06],[4.023960430322492e-05],[6.469270785462478e-06],[3.820031422055077e-06],[5.590958468524213e-07],0.160788,-0.060112,5
0,[4.7686648101341695e-05],[6.078672637437661e-06],[4.01136596855942e-06],[2.904402428145362e-06],[4.105632428866088e-05],[1.473335052043084e-06],[2.7358697839302735e-06],[1.7822009619906383e-06],-0.535909,-0.132033,5
0,[1.1498590897491503e-05],[3.823305786893487e-06],[2.313009338752889e-06],[2.1210286450541544e-06],[1.113397192737158e-05],[2.345549268118903e-06],[1.0176827775220936e-06],[7.740668430417985e-07],-0.901603,-1.131594,1
0,[1.4164402282255635e-05],[3.6064279480704863e-06],[8.00402219918027e-07],[1.6040285083122634e-06],[3.681877767553094e-05],[3.492225138295367e-06],[3.7421743004162357e-06],[2.5705522546375413e-07],-1.252852,-0.585375,1
0,[1.6813248495987134e-05],[2.074562225794961e-06],[3.454504237354228e-06],[2.0771015291391103e-06],[3.762706961707506e-05],[4.1524240092625645e-06],[3.4551949934635527e-06],[1.034301038366453e-06],-0.536652,-0.827075,5
0,[3.386377546053987e-05],[5.244503936001568e-06],[1.035065717643072e-06],[4.606146106258605e-07],[3.3056469638877426e-05],[1.939357374038682e-06],[3.7463566942717892e-06],[5.706205583485806e-07],-0.552168,-0.375012,1


In [9]:
test4ml = csv_to_df(dftest, csptest, 1000, selected_columns, labeltest)
test4ml

Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ...     3.996 secs
Ready.
Effective window size : 0.512 (s)
Creating RawArray with float64 data, n_channels=1, n_times=1000
    Range : 0 ... 999 =      0.000 ... 

Unnamed: 0,FCz_10-12Hz,FCz_12-14Hz,C3_20-22Hz,C3_22-24Hz,Cz_8-10Hz,C4_20-22Hz,C4_22-24Hz,CP3_28-30Hz,0,1,Label
0,[3.636409158184905e-05],[6.712441756687734e-06],[2.2421640242953407e-06],[1.4777462296701713e-06],[1.7959646018896762e-05],[3.225589557007945e-06],[8.940261958520566e-07],[8.911567419124663e-07],-0.571638,-0.701341,5
0,[1.4869351233457928e-05],[4.792233229267913e-06],[2.855824355178831e-06],[1.1407103489635961e-06],[1.4391225995699224e-05],[1.6491401485643712e-06],[6.514236907320359e-07],[7.003869462442066e-07],-0.172076,0.018313,5
0,[7.495101283819305e-05],[1.1260690061788619e-05],[9.390429251170542e-07],[7.747989532440392e-07],[3.5310201109638266e-05],[2.2573227563335323e-06],[2.139128997745699e-06],[1.4590515562394636e-06],-0.311205,0.31968,5
0,[3.361793973544247e-05],[5.996493425690316e-06],[2.2200009577917517e-06],[5.984246790047414e-07],[4.27576495648852e-05],[9.237821026589628e-07],[1.1500474795800083e-06],[5.417148717863387e-07],-0.886263,-0.258722,1
0,[6.194498758846551e-06],[2.937683853118183e-06],[1.7425670614556816e-06],[9.35379137517342e-07],[1.305661687849384e-05],[2.921954062546908e-06],[1.1751019328962231e-06],[7.031476218686252e-07],-1.013004,-1.238462,5
0,[8.553697669507031e-06],[2.5391250317061974e-06],[1.5704637115749766e-06],[1.3558432312538372e-06],[1.8690894057525078e-05],[5.513286457938423e-06],[3.388537274714886e-06],[6.754282909396633e-07],-0.90019,-0.87687,1
0,[4.740785983432448e-05],[5.014922855093031e-06],[1.6652863068104823e-06],[1.1674313889699438e-06],[3.9469824003708797e-05],[2.7834323681467e-06],[1.4938462367406564e-06],[1.4470214349117057e-06],-0.535235,-0.257661,5
0,[1.8981331282496227e-05],[5.544842452534518e-06],[9.787021826540847e-07],[4.451942946943974e-07],[1.2529324171985875e-05],[7.466710548607147e-07],[9.66606992777451e-07],[6.368398408209463e-07],-1.27082,-1.154261,1
0,[6.965610733380811e-05],[3.7694532934019772e-06],[3.4712912390102963e-06],[1.3058266653307961e-06],[6.59102433362696e-05],[5.782596371447733e-07],[9.386874428804169e-07],[5.740386979681202e-07],0.11711,-0.611736,1
0,[2.1907137039741807e-05],[5.004173703131427e-06],[4.979444825910718e-06],[6.106629662626543e-06],[1.4432778027214772e-05],[5.141575330759545e-06],[3.453740625587922e-06],[1.6874889645961792e-06],-0.696927,-0.562915,5


In [17]:
train4ml.to_csv(base_path+'train4ml.csv')
val4ml.to_csv(base_path+'val4ml.csv')
test4ml.to_csv(base_path+'test4ml.csv')