In [8]:
from scipy.io import loadmat
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from scipy.signal import butter, filtfilt


In [9]:
def load_data(filename):
    """
    Read one recording from a MATLAB ``.mat`` file.

    Parameters
    ----------
    filename : str or path-like
        Path passed directly to ``scipy.io.loadmat``.

    Returns
    -------
    signals : ndarray
        The ``'lsl_data'`` matrix with its last column removed.
    markers : ndarray
        The ``'marker_data'`` array reshaped into rows of 4 values
        (consumed in this notebook as ``(start, label, end, isbad)``).
    """
    contents = loadmat(filename)
    signals = contents['lsl_data'][:, :-1]
    markers = contents['marker_data'].reshape(-1, 4)
    return signals, markers

In [10]:
def bandpassFilter(data, sr, lowcut, highcut, order=5):
    """
    Zero-phase Butterworth band-pass filter applied to every column of ``data``.

    Parameters
    ----------
    data : array_like
        2-D array of shape (n_samples, n_columns); each column is filtered
        independently. A 1-D array is treated as a single column.
    sr : float
        Sampling rate in Hz.
    lowcut, highcut : float
        Pass-band edges in Hz. They are normalised by the Nyquist frequency
        (sr / 2) before being handed to ``scipy.signal.butter``, which rejects
        normalised values outside (0, 1) with ``ValueError``.
    order : int, optional
        Butterworth order passed to ``butter`` (default 5, the value that was
        previously hard-coded).

    Returns
    -------
    list of ndarray
        One filtered 1-D array per input column, in column order.

    Notes
    -----
    ``filtfilt`` runs the filter forward and backward, so the output has no
    phase shift. It raises ``ValueError`` when a column is too short for its
    default edge padding.
    """
    samples = np.asarray(data)
    if samples.ndim == 1:
        # A single channel: view it as one column instead of iterating scalars.
        samples = samples[:, np.newaxis]
    nyq = 0.5 * sr
    b, a = butter(order, [lowcut / nyq, highcut / nyq], btype='band')
    return [filtfilt(b, a, column) for column in samples.T]

In [11]:
def filter_data(raw_data, marker_data, n_samples=1400, sr=1000, lowcut=50, highcut=450):
    """
    Cut a recording into trials, band-pass filter each one, and stack them.

    Parameters
    ----------
    raw_data : ndarray, shape (n_rows, n_columns)
        Recording as returned by ``load_data``. Column 0 is compared against
        the marker start/end values, so it serves as the time axis.
    marker_data : ndarray, shape (n_trials, 4)
        Rows of ``(start, label, end, isbad)``. Rows whose ``isbad`` equals 99
        (within 1e-3) are skipped.
    n_samples : int, optional
        Number of leading rows kept per trial (default 1400).
    sr, lowcut, highcut : float, optional
        Forwarded to ``bandpassFilter`` (defaults 1000 Hz, 50 Hz, 450 Hz).

    Returns
    -------
    trials : ndarray, shape (n_kept, n_columns, n_samples)
        Filtered trials. When no trial is kept, an empty array with this
        3-D shape is returned so it can still be concatenated with others.
    labels : list
        The ``label`` value of each kept trial, aligned with ``trials``.

    Raises
    ------
    ValueError
        If a kept trial spans fewer than ``n_samples`` rows, so the trials
        could not form a fixed-size array.
    """
    trials = []
    labels = []
    timestamps = raw_data[:, 0]
    for start, label, end, isbad in marker_data:
        # 99 in the 4th marker field flags a rejected trial; compared with a
        # tolerance, presumably because markers are stored as floats.
        if abs(isbad - 99) < 1e-3:
            continue
        in_window = (timestamps >= start) & (timestamps <= end)
        segment = raw_data[in_window][:n_samples]
        if len(segment) < n_samples:
            raise ValueError(
                f"trial with label {label} spans only {len(segment)} samples "
                f"between {start} and {end}; expected at least {n_samples}"
            )
        # NOTE(review): every column, including column 0 (the time axis used
        # for the mask above), is filtered and kept as a "channel". Confirm
        # whether the time column should be dropped before filtering.
        trials.append(bandpassFilter(segment, sr, lowcut, highcut))
        labels.append(label)
    if not trials:
        # np.array([]) would have shape (0,) and break np.concatenate with
        # the 3-D results from other recordings.
        return np.empty((0, raw_data.shape[1], n_samples)), labels
    return np.array(trials), labels

In [12]:
# Process every .mat recording in RAW_DATA_DIR and stack all kept trials.
# Files are visited in sorted order: os.listdir order is arbitrary, so the
# trial order (and anything derived from it) would otherwise differ between
# machines.
RAW_DATA_DIR = 'raw_data'
all_data = []
all_labels = []
for filename in sorted(os.listdir(RAW_DATA_DIR)):
    if not filename.endswith('.mat'):
        continue
    raw_data, marker_data = load_data(os.path.join(RAW_DATA_DIR, filename))
    cleaned_data, labels = filter_data(raw_data, marker_data)
    all_data.append(cleaned_data)
    all_labels += labels
if not all_data:
    # np.concatenate([]) would fail with an unhelpful message.
    raise FileNotFoundError(f"no .mat files found in {RAW_DATA_DIR!r}")
all_labels = np.array(all_labels)
all_data = np.concatenate(all_data)
assert len(all_labels) == len(all_data), "labels and trials are misaligned"

In [13]:
# np.save does not create missing directories, so ensure the output folder
# exists before writing the stacked trials and their labels.
os.makedirs('processed_data', exist_ok=True)
np.save('processed_data/all_data.npy', all_data)
np.save('processed_data/all_labels.npy', all_labels)

In [14]:
np.shape(all_data)  # (n_trials, n_columns, n_samples) of the stacked dataset

(240, 5, 1400)