In [1]:
import os
import re
import pickle
import numpy as np
from scipy.io import loadmat
from pprint import pprint

In [None]:
# function to collect matching files and dirs
def collect_files(root, res, pattern="", collect_dirs=True, min_depth=None, max_depth=None):
    """Recursively walk `root`, appending matching entries to the list `res` in place.

    An entry is appended when all of these hold:
      * its depth is at least `min_depth` (direct children of `root` are depth 1;
        None means no lower bound);
      * `re.match(pattern, path)` succeeds, where `path` is `os.path.join(root, name)`,
        i.e. the pattern is anchored at the start of the joined path, not the bare name;
      * it is not a directory, or `collect_dirs` is true.

    At most `max_depth` levels of entries are examined (None means unlimited;
    0 examines nothing). Nothing is returned. Results accumulate in `os.listdir`
    order, so callers should sort if they need a stable order.

    Note: `os.path.isdir` follows symlinks, so a symlinked directory cycle would
    recurse without bound.
    """
    # no depth budget left at this level
    if max_depth == 0:
        return

    collect_here = min_depth is None or min_depth <= 1
    child_min_depth = None if min_depth is None else min_depth - 1
    child_max_depth = None if max_depth is None else max_depth - 1

    for name in os.listdir(root):
        path = os.path.join(root, name)
        is_dir = os.path.isdir(path)

        if collect_here and re.match(pattern, path) and (collect_dirs or not is_dir):
            res.append(path)

        if is_dir:
            collect_files(path, res, pattern, collect_dirs, child_min_depth, child_max_depth)

In [None]:
# collect every .mat file below the working directory (files only, not dirs);
# sorting gives a stable order, which the loading cell relies on to number subjects.
# The pattern is a raw string: "\." in a normal string is an invalid escape
# sequence and warns on modern Python.
mat_files = []
collect_files("./", mat_files, pattern=r".*\.mat$", collect_dirs=False)
mat_files.sort()
mat_files

In [None]:
# load all data into memory; all_data[k] holds the fields of subject k + 1
all_data = []
for subject_idx, path in enumerate(mat_files, start=1):

    # loadmat (default squeeze_me=False) returns the MATLAB struct wrapped in a
    # 2-D array; [0][0] takes its single record, then each named field becomes a key
    record = loadmat(path)["data"][0][0]
    fields = dict(zip(record.dtype.names, record))

    # expose the signal matrix under a lowercase key
    fields["x"] = fields.pop("X")

    # flatten these fields to 1-D
    for key in ("y", "y_stim", "trial"):
        fields[key] = fields[key].reshape(-1)

    # 1-based subject number, following the sorted file order
    fields["subject"] = subject_idx

    all_data.append(fields)

In [None]:
# Inspect one subject (index 1, i.e. the second loaded file): print each field's
# shape rather than dumping the raw arrays; non-array fields are shown as-is.
pprint({name: getattr(value, "shape", value) for name, value in all_data[1].items()})

In [None]:
# constants for data_extraction
sample_rate = 250  # Hz
# duration of one sample in ms; true division so sample rates that don't divide
# 1000 evenly (e.g. 256 Hz) still yield the correct epoch length downstream
tick_len = 1000 / sample_rate  # ms
pre_epoch = 0  # ms of signal kept before each tick
post_epoch = 700  # ms of signal kept after each tick

In [None]:
# give raw eeg data and tick times, return a 3-D array of epoch windows
def extract_epochs(raw, ticks, pre_ms=None, post_ms=None, ms_per_tick=None):
    """Cut a fixed-length window out of `raw` at each tick.

    Parameters
    ----------
    raw : np.ndarray
        2-D recording indexed as raw[sample, ...]; axis 1 is presumably the
        EEG channels (TODO confirm).
    ticks : iterable of int-like
        Sample indices to cut at; non-integer types are converted with int().
    pre_ms, post_ms : number, optional
        Window length before / after each tick, in ms. Default to the
        notebook-level ``pre_epoch`` / ``post_epoch``.
    ms_per_tick : number, optional
        Duration of one sample in ms. Defaults to the notebook-level ``tick_len``.

    Returns
    -------
    np.ndarray
        Shape (n_kept, pre + post, raw.shape[1]). Ticks whose window would run
        past the end of `raw` are skipped, so n_kept <= len(ticks). Kept epochs
        are in tick order, so they match a prefix of `ticks` only when the
        skipped ticks are the trailing ones.

    Raises
    ------
    ValueError
        If a tick's window would start before sample 0 (possible only when
        pre_ms > 0). Skipping it silently would misalign labels that callers
        trim by position.
    """
    if pre_ms is None:
        pre_ms = pre_epoch
    if post_ms is None:
        post_ms = post_epoch
    if ms_per_tick is None:
        ms_per_tick = tick_len

    pre_tick = int(pre_ms // ms_per_tick)
    post_tick = int(post_ms // ms_per_tick)
    raw_len = len(raw)

    signals = []
    for t in ticks:
        # slicing needs an integer index, in case ticks arrive as floats
        t = int(t)
        if t - pre_tick < 0:
            # a negative start would wrap around and give a wrong-length window
            raise ValueError(f"tick {t} has fewer than {pre_tick} samples before it")
        if t + post_tick <= raw_len:
            signals.append(raw[t - pre_tick:t + post_tick, :])

    if not signals:
        # keep the (0, window, channels) shape so callers can still index it
        return np.empty((0, pre_tick + post_tick) + raw.shape[1:], dtype=raw.dtype)
    return np.array(signals)

In [None]:
# extract epochs for every subject and save (epoch, label, stimulus) rows to s<N>.pkl
for i, data in enumerate(all_data):

    # columns 0, 2 and 3 of the flash table: tick, stimulus code, raw label
    ticks, y_stim, y = data["flash"][:, [0, 2, 3]].T
    raw = data["x"]

    # extract_epochs drops ticks whose window runs past the end of the
    # recording; trimming labels from the tail below is only correct if the
    # ticks ascend, so check that instead of silently misaligning labels
    assert np.all(np.diff(ticks) >= 0), "flash ticks must be in ascending order"
    epochs = extract_epochs(raw, ticks)

    # raw labels must be 1 or 2; map 2 -> 1 and 1 -> 0 (keeps y's dtype)
    assert np.isin(y, (1, 2)).all(), "unexpected label value in flash column 3"
    y = (y == 2).astype(y.dtype)

    # drop labels of trailing ticks that had no complete epoch
    y = y[:len(epochs)]
    y_stim = y_stim[:len(epochs)]

    assert len(epochs) == len(y) == len(y_stim)

    # each row mixes a 2-D epoch with two scalars, so the array must be
    # dtype=object; without it NumPy >= 1.24 raises on the ragged input
    samples = np.array(list(zip(epochs, y, y_stim)), dtype=object)

    # save the data
    with open(f"s{i+1}.pkl", "wb") as outfile:
        pickle.dump(samples, outfile)

In [None]:
# reload subject 1's saved samples (pickle.load can execute code, so only load
# files this notebook wrote itself)
with open("s1.pkl", "rb") as pkl_file:
    data = pickle.load(pkl_file)

In [None]:
# collect subject 1's epochs (column 0) and stack them into one dense array to
# confirm they share a common shape; np.stack raises if any epoch differs.
# (No loop variable, so the global `i` used as an index elsewhere is not clobbered.)
a = list(data[:, 0])
np.stack(a).shape

In [2]:
# per-subject class balance: count of label-1 rows, then count of the remaining rows
for subject in range(1, 9):
    with open(f"s{subject}.pkl", "rb") as pkl_file:
        data = pickle.load(pkl_file)
        n_target = np.sum(data[:, 1])
        n_other = len(data) - n_target
        print(n_target, n_other)

700 3500
700 3498
700 3499
700 3498
698 3500
700 3498
698 3500
700 3498
