In [None]:
import mne
import re
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Switch matplotlib to interactive mode (figures are drawn without blocking).
# NOTE(review): none of the cells below plot anything, and pickle / torch /
# torchvision imported above are not used in this notebook — confirm whether
# they are still needed.
plt.ion()

In [None]:
# Directory holding the .mff recordings. Set the WBI_DATA_DIR environment
# variable to point elsewhere instead of editing this absolute local path.
# Joining with '' guarantees a trailing separator, so plain string
# concatenation (MFF_DIR + fname) also yields a valid path.
MFF_DIR = os.path.join(os.environ.get('WBI_DATA_DIR', '/home/ajays/Desktop/WBI-data/'), '')
# Stimulus channel names '2010'..'2019'; get_mapping assigns label i to the
# channel at position i of this list.
STIM_CHANNEL_NAMES = ['201' + str(i) for i in range(10)]

In [None]:
def get_mapping(raw, stim_channels=None):
    """
    Map each stimulus channel to the event code it carries, and collect all events.

    Algorithm:
    - copy `raw` once and keep only the stimulus channels
    - for each stimulus channel i, find its events and take the event code of
      the first event (events[0, 2]) as id_list[i]
    - find the events across all stimulus channels together

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to scan. It is copied and never modified.
    stim_channels : list of str, optional
        Stimulus channel names in label order. Defaults to STIM_CHANNEL_NAMES.

    Returns
    -------
    id_list : list
        id_list[i] is the event code found on stim_channels[i].
    events : np.ndarray
        Output of mne.find_events over all the stimulus channels combined.
    picks : mne.io.Raw
        The copy of `raw` restricted to the stimulus channels.

    Raises
    ------
    ValueError
        If one of the stimulus channels contains no events.
    """
    if stim_channels is None:
        stim_channels = STIM_CHANNEL_NAMES

    # Copy the (possibly preloaded) recording only once. Raw.pick works in
    # place and returns the same object, so `picks` is that reduced copy.
    picks = raw.copy().pick(list(stim_channels))

    id_list = []
    for sc in stim_channels:
        # Restrict the search to this one channel instead of copying `raw` again.
        channel_events = mne.find_events(picks, stim_channel=sc, verbose=False)
        if len(channel_events) == 0:
            raise ValueError(f"No events found on stimulus channel {sc!r}")
        id_list.append(channel_events[0, 2])

    # Only stimulus channels remain in `picks`, so the default channel
    # selection covers all of them.
    events = mne.find_events(picks, verbose=False)

    return id_list, events, picks

In [None]:
# Keep only entries whose whole name is "<name>.mff". The dot is escaped and
# fullmatch anchors both ends, so names such as "fooXmff" or "foo.mff.bak"
# are excluded (the previous pattern accepted them).
mff_name_pattern = re.compile(r"[a-zA-Z0-9_-]+\.mff")
fnames = sorted(name for name in os.listdir(MFF_DIR) if mff_name_pattern.fullmatch(name))
print(fnames)

In [None]:
"""
Algorithm to build the event dataframe:
data_df = empty
for each .mff file:
    read the raw recording
    id_list, events, picks = get_mapping(raw)
    new_events = convert_labels(id_list, events):
        replace each event code with its index in id_list
    df = prepare_df(fname, raw, new_events):
        keep columns 0 and 2 of new_events (s_time, label)
        build a DataFrame from them
        add an fname column holding the file name
        return df
    append df to data_df
"""
def convert_labels(id_list,events):
    """
    Replace each event code in the last column of `events` with its label index.

    The label of a code is its position in `id_list` (the first occurrence if a
    code is repeated). With id_list from get_mapping, that position is the
    index of the stimulus channel whose first event carried the code.

    Parameters
    ----------
    id_list : sequence
        id_list[i] is the event code associated with label i.
    events : np.ndarray
        2-D event array whose last column holds event codes. Modified in place.

    Returns
    -------
    np.ndarray
        The same `events` array, with codes replaced by label indices.

    Raises
    ------
    ValueError
        If an event code does not appear in `id_list`.
    """
    # Dict lookup instead of a list.index scan per event. setdefault keeps the
    # first occurrence, so repeated codes resolve exactly as list.index would.
    code_to_label = {}
    for label, code in enumerate(id_list):
        code_to_label.setdefault(code, label)

    labels = []
    for code in events[:, -1]:
        if code not in code_to_label:
            raise ValueError(f"Event code {code} is not in id_list {list(id_list)}")
        labels.append(code_to_label[code])

    events[:, -1] = labels
    return events

def prepare_df(fname,raw,new_events):
    """
    Turn one recording's relabelled events into a DataFrame.

    Parameters
    ----------
    fname : str
        Recording file name, repeated on every row in the 'fname' column.
    raw : object
        Not used here. It is accepted only so the call signature stays the same.
    new_events : np.ndarray
        Event array. Column 0 becomes 's_time' and column 2 becomes 'label'.

    Returns
    -------
    pandas.DataFrame
        Columns 's_time', 'label', 'fname', one row per event.
    """
    onset_and_label = new_events[:, [0, 2]]
    frame = pd.DataFrame(onset_and_label, columns=['s_time', 'label'])
    frame.insert(len(frame.columns), 'fname', [fname] * len(frame))
    return frame
    
# Build one event table per recording and stack them once at the end.
# DataFrame.append was removed in pandas 2.0, and appending inside the loop
# re-copies the accumulated frame on every iteration.
per_file_frames = []
for i, fname in enumerate(fnames):
    raw = mne.io.read_raw_egi(os.path.join(MFF_DIR, fname), preload=True, verbose=False)
    id_list, events, _ = get_mapping(raw)
    new_events = convert_labels(id_list, events)
    per_file_frames.append(prepare_df(fname, raw, new_events))
    print(i, fname)

if per_file_frames:
    # ignore_index gives a clean 0..n-1 index instead of repeating each file's index.
    data_df = pd.concat(per_file_frames, ignore_index=True)
else:
    # No recordings matched: keep an empty table with the expected columns
    # so the export and preview cells still work.
    data_df = pd.DataFrame(columns=['s_time', 'label', 'fname'])

In [None]:
# DataFrame.to_csv returns None when given a path, so there is nothing to print.
data_df.to_csv('./event_data.csv', index=False)

In [None]:
# Bare last expression so Jupyter renders the frame as a rich table.
data_df.head()