In [None]:
from scipy.io import loadmat
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from scipy.signal import butter, filtfilt

In [None]:
def load_data(filename):
    """
    Read one recording from a .mat file and split it into signal and markers.

    Args:
        filename: the path of the .mat file
    Returns:
        A tuple ``(signal, markers)``:
        signal: the ``'lsl_data'`` array with its last column dropped
        markers: the ``'marker_data'`` array reshaped to 4 values per row
    Note:
        Python- and MATLAB-recorded files store marker data in different
        formats. A previous variant walked ``(timestamp, label)`` pairs and
        dropped entries whose label was 99; it is disabled here and the
        markers are returned as stored, only reshaped.
    """
    contents = loadmat(filename)
    raw_markers = contents['marker_data']
    # Everything except the final column of the recorded stream is kept.
    signal = contents['lsl_data'][:, :-1]
    markers = raw_markers.reshape(-1, 4)
    return signal, markers

In [None]:
def highpass_filter(data, cutoff, fs, axis=-1):
    """
    Zero-phase high-pass filter that removes low frequency noise.

    A 2nd-order Butterworth filter is designed and applied forward and
    backward with ``filtfilt``.

    Args:
        data: a numpy array holding the signal; either a 1-D array of samples
            or an N-D array that is filtered along ``axis``
        cutoff: the cutoff frequency (same units as ``fs``)
        fs: the sampling frequency
        axis: the axis of ``data`` that runs over time samples. Defaults to
            the last axis, so an array laid out as (n_samples, n_channels)
            must be passed with ``axis=0``; otherwise the filter would run
            across channels instead of across time.
    Returns:
        The filtered signal, with the same shape as ``data``.
    Raises:
        ValueError: raised by scipy when ``cutoff`` is not strictly between
            0 and the Nyquist frequency (``fs / 2``).
    """
    nyquist = 0.5 * fs
    # butter() expects the cutoff normalised to the Nyquist frequency.
    b, a = butter(2, cutoff / nyquist, btype='highpass')
    return filtfilt(b, a, data, axis=axis)

def bandstop_filter(data, lowcut, highcut, fs=1000, axis=-1):
    """
    Zero-phase band-stop filter that suppresses one frequency band
    (for example power-line interference around 60 Hz).

    A 2nd-order Butterworth band-stop filter is designed and applied forward
    and backward with ``filtfilt``.

    Args:
        data: a numpy array holding the signal; either a 1-D array of samples
            or an N-D array that is filtered along ``axis``
        lowcut: the lower cutoff frequency of the rejected band
        highcut: the higher cutoff frequency of the rejected band
        fs: the sampling frequency
        axis: the axis of ``data`` that runs over time samples. Defaults to
            the last axis, so an array laid out as (n_samples, n_channels)
            must be passed with ``axis=0``; otherwise the filter would run
            across channels instead of across time.
    Returns:
        The filtered signal, with the same shape as ``data``.
    Raises:
        ValueError: raised by scipy when the band edges are not strictly
            between 0 and the Nyquist frequency (``fs / 2``).
    """
    nyq = 0.5 * fs
    # butter() expects band edges normalised to the Nyquist frequency.
    band = [lowcut / nyq, highcut / nyq]
    b, a = butter(2, band, btype='bandstop')
    return filtfilt(b, a, data, axis=axis)

In [None]:
def filter_channels(data, fs):
    """
    Clean every column of ``data`` independently: high-pass at 20 Hz to drop
    low frequency noise, then band-stop around 60 Hz, 120 Hz and 180 Hz to
    drop noise from electrical devices.

    Args:
        data: a numpy array of shape (n_samples, n_channels)
        fs: the sampling frequency
    Returns:
        ret: a list with one filtered 1-D array per column of ``data``
    """
    # (low, high) edges of each rejected band, applied in this order.
    stop_bands = ((58, 62), (118, 122), (178, 182))

    def _clean(channel):
        channel = highpass_filter(channel, 20, fs)
        for low_edge, high_edge in stop_bands:
            channel = bandstop_filter(channel, low_edge, high_edge, fs)
        return channel

    # Transposing iterates over columns, i.e. one channel at a time.
    return [_clean(column) for column in data.T]


In [None]:
def filter_data(raw_data, marker_data, fs, window_size=1400):
    """
    Cut the recording into fixed-length windows, one per valid marker, and
    filter each window.

    Markers whose fourth field is 99 are skipped. For every other marker the
    rows of ``raw_data`` whose first column (the timestamp) lies in
    ``[start, end]`` are selected; markers with fewer than ``window_size``
    such rows are skipped, otherwise the first ``window_size`` rows are
    filtered with ``filter_channels`` to remove low frequency noise and
    electrical-device noise.

    Args:
        raw_data: a numpy array of shape (n_samples, n_channels); column 0
            holds the timestamps compared against the markers
        marker_data: a numpy array of shape (n_markers, 4) whose rows are
            (start, label, end, isbad)
        fs: the sampling frequency
        window_size: number of samples kept per window. The default of 1400
            corresponds to 1.4 s at a 1000 Hz sampling rate.
    Returns:
        cleaned_data: numpy array of shape (n_windows, n_channels, window_size)
            with the filtered windows. When no marker yields a window the
            array is empty but keeps this 3-D shape, so results from several
            recordings can still be concatenated.
        labels (list): the label of each returned window
    """
    timestamps = raw_data[:, 0]
    n_channels = raw_data.shape[1]

    windows = []
    labels = []
    for start, label, end, isbad in marker_data:
        # 99 in the fourth marker field flags an entry to discard.
        if abs(isbad - 99) < 1e-3:
            continue

        in_range = (timestamps >= start) & (timestamps <= end)
        segment = raw_data[in_range]
        if len(segment) < window_size:
            # Not enough samples between start and end for a full window.
            continue

        # NOTE(review): every column, including the timestamp column 0, is
        # passed through the filters -- confirm downstream code expects that.
        windows.append(np.asarray(filter_channels(segment[:window_size], fs)))
        labels.append(label)

    if not windows:
        return np.empty((0, n_channels, window_size)), labels
    return np.stack(windows), labels

In [None]:
"""
Load all data and labels from raw_data folder.
"""
all_data = []
all_labels = []
# os.listdir returns names in arbitrary order; sorting keeps the order of the
# samples (and their labels) identical from one run to the next.
for filename in sorted(os.listdir('raw_data')):
    if not filename.endswith('.mat'):
        continue
    print(filename)
    raw_data, marker_data = load_data(os.path.join('raw_data', filename))

    # Recordings are sampled at 1000 Hz.
    cleaned_data, labels = filter_data(raw_data, marker_data, 1000)
    if len(labels) == 0:
        # A recording without any usable window contributes nothing; its
        # empty array would not have the shape the concatenation needs.
        continue
    all_data.append(cleaned_data)
    all_labels += labels

if not all_data:
    raise ValueError("no usable data found in the .mat files under 'raw_data'")
all_labels = np.array(all_labels)
all_data = np.concatenate(all_data)

In [None]:
# all_data_1 = np.load('processed_data/tune_data1.npy')
# all_labels_1 = np.load('processed_data/tune_labels1.npy')


In [None]:
# all_data = np.concatenate([all_data_1, all_data])
# all_labels = np.concatenate([all_labels_1, all_labels])

In [None]:
# np.save( 'processed_data/tune_data.npy', all_data)
# np.save('processed_data/tune_labels.npy', all_labels)