# Setup

Initial module setup.

In [1]:
import mne.io
import pandas as pd
import typing
import mne
import numpy as np

from sklearn.svm import LinearSVC
from sklearn.model_selection import StratifiedKFold
from eeg_auth_defense_utilities import data_retrieval, filtration, features, formatting

# Constants

In [2]:
DATASET_SAMPLE_FREQ_HZ = 200
DATA_CHANNEL_NAMES = ['T7','F8','Cz','P4']
FREQUENCIES = [
    filtration.FrequencyBand(lower=8.0, upper=12.0, label='Alpha'),
    filtration.FrequencyBand(lower=12.0, upper=35.0, label='Beta'),
    filtration.FrequencyBand(lower=4.0, upper=8.0, label='Theta'),
    filtration.FrequencyBand(lower=35.0, upper=None, label='Gamma'),
    filtration.FrequencyBand(lower=None, upper=None, label='Raw'),
]
K_FOLDS = 10

# Utilities

## Types

In [3]:
SubjectFramesMap = typing.Dict[str, typing.List[pd.DataFrame]]
SubjectFrameFeaturesMap = typing.Dict[str, typing.List[np.ndarray]]
LabelledDataset = typing.Tuple[typing.List[np.ndarray], typing.List[int]]
LabelledDatasetMap = typing.Dict[str, LabelledDataset]
StratifiedData = typing.List[typing.Tuple[LabelledDataset, LabelledDataset]]
StratifiedDatasetMap = typing.Dict[str, StratifiedData]
T = typing.TypeVar('T')

## Functions

In [4]:
def filter_subject_data(subject_data: formatting.SubjectDataMap) -> formatting.SubjectDataMap:
    """
    Applies filtration to all the dataframes for each subject in the given data map.
    
    :param subject_data: the subject data to filter.
    :return: a new data map, wherein the keys are the subject identifiers and the values are the
             filtered data.
    """
    data_windows_filtered = {}
    bandpass_filter = filtration.EEGBandpassFilter(FREQUENCIES)
    
    for identifier, data_to_filter in subject_data.items():
        mne_data = convert_dataframe_to_mne(data_to_filter)
        data_windows_filtered[identifier] = bandpass_filter.apply_filter(mne_data, DATA_CHANNEL_NAMES)
        
    return data_windows_filtered


def convert_dataframe_to_mne(dataframe: pd.DataFrame) -> mne.io.RawArray:
    """
    Converts the given dataframe over to Python-MNE format.
    
    :param dataframe: The dataframe to convert.
    :return: A Python-MNE data array.
    """
    transposed_dataframe = dataframe.transpose(copy=True)
    data_info = mne.create_info(DATA_CHANNEL_NAMES, DATASET_SAMPLE_FREQ_HZ, ch_types='eeg')
    return mne.io.RawArray(transposed_dataframe.to_numpy(), data_info)


def extract_features(data_map: SubjectFramesMap) -> SubjectFrameFeaturesMap:
    """
    Extracts features from the given map of subjects to their windowed data samples.
    
    :param data_map: the data to extract features from.
    :return: the extracted features for each subject's windows, in a map.
    """
    extracted_features_map = {}
    extractor = features.ARFeatureExtractor(ar_model_config={'lags': 25})
    
    for subject in data_map:
        extracted_features_map[subject] = [extractor.extract(window) for window in data_map[subject]]
        
    return extracted_features_map


def get_labelled_dataset_map(map_to_convert: SubjectFrameFeaturesMap) -> LabelledDatasetMap:
    """
    Helper function which converts the given subject features map to a map of labelled datasets.
    
    :param map_to_convert: the original subject features map to convert. 
    :return: a new map wherein the keys are subject identifiers and the values are labelled datasets.
    """
    converted_map = {}
    
    for key in map_to_convert:
        converted_map[key] = _get_x_y_labelled_dataset(map_to_convert, key)
        
    return converted_map


def _get_x_y_labelled_dataset(map_to_label: SubjectFrameFeaturesMap, target_subject_key: str) -> LabelledDataset:
    """
    Utility function which generates a list of samples and a list of associated labels, based on the given target subject
    (i.e., '1' indicates the sample is for the target, '0' otherwise).
    
    
    :param map_to_label: a map wherein the keys are subject identifiers and the values are lists of data samples.
    :param target_subject_key: the key to use to tailor the dataset to.
    :return: a Tuple containing samples, and the corresponding labels.
    """
    if target_subject_key not in map_to_label:
        raise KeyError(f'Key "{target_subject_key}" not found in data map!')
    label_translation_map = {}
    samples_list = []
    labels_list = []
    
    for key in map_to_label:
        label_id = 1 if key == target_subject_key else 0
        label_translation_map[key] = label_id
        for subject_frame_sample in map_to_label[key]:
            samples_list.append(subject_frame_sample)
            labels_list.append(label_id)
    
    return samples_list, labels_list
        
        
def apply_stratified_k_fold(labelled_data_map: LabelledDatasetMap, folds: int) -> StratifiedDatasetMap:
    """
    Apply stratified k-fold to the datasets stored in the given labelled dataset map.
    
    :param labelled_data_map: the labelled subject dataset map to apply k-fold to.  
    :param folds: the number of folds.
    :return: an update data map, wherein each key is the subject identifier and each value is a list of k-fold datasets.
    """
    stratified_data_map = {}
    splitter = StratifiedKFold(folds)
    
    for key in labelled_data_map:
        subject_x, subject_y = labelled_data_map[key]
        subject_x, subject_y = np.array(subject_x), np.array(subject_y)
        stratified_subject_data = []
        for train, test in splitter.split(subject_x, subject_y):
            subject_x_train = subject_x[train]
            subject_y_train = subject_y[train]
            subject_x_test = subject_x[test]
            subject_y_test = subject_y[test]
            stratified_subject_data.append(
                (
                    (subject_x_train, subject_y_train),
                    (subject_x_test, subject_y_test)
                )
            )
        stratified_data_map[key] = stratified_subject_data
    
    return stratified_data_map
        
        

def get_sample_value_from_map(map_to_sample: typing.Dict[str, T]) -> T:
    """
    Helper function which retrieves a sample dataframe from the given map of data.
    
    :param map_to_sample: The data map to get a sample from. 
    """
    return next(iter(map_to_sample.values()))


def print_info_about_subjects(map_to_summarize: formatting.SubjectDataMap):
    """
    Helper function which prints some basic information about the
    subjects in a data map.
    
    :param map_to_summarize: the map to print info from.
    """
    print('SUBJECT DATA')
    print(f'Number of subjects: {len(map_to_summarize.keys())}')
    print('Subject identifiers:')
    for key in map_to_summarize:
        print(key)


def print_windowed_data_summary(windowed_data_map: SubjectFramesMap):
    """
    Helper function which prints some basic information on a windowed data map.
    
    :param windowed_data_map: The windowed data to summarize.
    """
    print('WINDOWED DATA')
    for key in windowed_data_map:
        print(f'Subject: {key}, Windows: {len(windowed_data_map[key])}')
        
        
def print_labelled_data_summary(labelled_data: LabelledDatasetMap):
    """
    Helper function which prints basic information about the given labelled data map.
    
    :param labelled_data: the labelled data map to summarize.
    """
    print('LABELLED DATA')
    for key in labelled_data:
        subject_labels = labelled_data[key][1]
        positive_count = len(
            list(
                filter(lambda label: label == 1, subject_labels)
            )
        )
        negative_count = len(
            list(
                filter(lambda label: label == 0, subject_labels)
            )
        )
        print(f'Subject: {key}')
        print(f'\tPositive data samples: {positive_count}')
        print(f'\tNegative data samples: {negative_count}')


def print_scores_summary(scores_map: typing.Dict[str, typing.List[float]]):
    """
    Helper function which prints basic information about the
    scores provided in a scores map.
    
    :param scores_map: a map of scores, where the key is the subject
                       and the value is a list of scores.
    """
    print('AVERAGE SCORES')
    for key in scores_map:
        avg_score = np.mean(scores_map[key])
        print(f'{key}: {avg_score}')

# Setup Dataset

In [5]:
dataset_downloader = data_retrieval.AuditoryDataDownloader()
dataset_path = dataset_downloader.retrieve()
dataset_formatter = formatting.AuditoryDataFormatter()
data = dataset_formatter.format_data(dataset_path)
print(f'{len(data.keys())} SUBJECTS LOADED FROM DATASET')
print('SAMPLE:')
sample_value = get_sample_value_from_map(data)
sample_value.head()

20 SUBJECTS LOADED FROM DATASET
SAMPLE:


Unnamed: 0,T7,F8,Cz,P4
13200,431.251617,-1189.493896,454.405334,345.306824
13201,444.240265,-1194.415649,471.23114,363.666016
13202,439.06427,-1188.719727,457.135437,325.425537
13203,442.071136,-1193.476929,458.751099,340.463654
13204,435.93396,-1197.149414,442.688232,333.630859


# Pre-process Data

## Filter

In [6]:
data = filter_subject_data(data)
print('FILTERED DATA')
print('SAMPLE:')
sample_value = get_sample_value_from_map(data)
sample_value.head()

Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=40114
    Range : 0 ... 40113 =      0.000 ...   200.565 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ..

Unnamed: 0,T7.Alpha,T7.Beta,T7.Theta,T7.Gamma,T7.Raw,F8.Alpha,F8.Beta,F8.Theta,F8.Gamma,F8.Raw,Cz.Alpha,Cz.Beta,Cz.Theta,Cz.Gamma,Cz.Raw,P4.Alpha,P4.Beta,P4.Theta,P4.Gamma,P4.Raw
0,-7.704948e-08,4.041212e-08,-5.87308e-08,-8.437695e-08,431251600.0,1.94178e-07,-1.985079e-07,1.718625e-07,9.05942e-08,-1189494000.0,-6.883383e-08,6.328271e-08,-7.177592e-08,-8.881784e-10,454405300.0,-6.150636e-08,5.373479e-08,-4.551914e-08,-3.108624e-08,345306800.0
1,1597428.0,5617263.0,260440.5,6113604.0,444240300.0,-700075.1,-1784533.0,-625694.4,-3457973.0,-1194416000.0,1179360.0,2073403.0,106048.1,14290680.0,471231100.0,642413.8,-5416940.0,-223332.7,24436510.0,363666000.0
2,3003843.0,7803250.0,516193.8,-2495471.0,439064300.0,-1323957.0,-2837481.0,-1221153.0,4026421.0,-1188720000.0,2203946.0,2787055.0,224076.0,-1378198.0,457135400.0,1211374.0,-7326968.0,-434206.4,-11518090.0,325425500.0
3,4046442.0,5954752.0,767039.3,1200391.0,442071100.0,-1809801.0,-3154577.0,-1753179.0,1247092.0,-1193477000.0,2936890.0,1816007.0,366279.9,45875.54,458751100.0,1644117.0,-4896868.0,-623130.0,1218548.0,340463700.0
4,4598322.0,2588707.0,1012805.0,-2499456.0,435934000.0,-2103975.0,-3215320.0,-2199496.0,-962059.5,-1197149000.0,3276660.0,-3953.14,544973.3,-15118960.0,442688200.0,1887982.0,-453984.1,-776919.0,-9855401.0,333630900.0


## Window

In [7]:
window_formatter = formatting.DataWindowFormatter(window_size=1200, overlap=0.5)
windowed_data = {subject: window_formatter.create_windows(data) for subject, data in data.items()}

# Feature Extraction

In [8]:
features_map = extract_features(windowed_data)
print('EXTRACTED FEATURES')
sample_windows = get_sample_value_from_map(features_map)
sample_value = sample_windows[0]
print('SAMPLE:')
print(f'SIZE: {len(sample_value)}')
print(f'ELEMENTS: {sample_value}')

EXTRACTED FEATURES
SAMPLE:
SIZE: 520
ELEMENTS: [ 5.16522137e+01  2.24630794e+00 -9.03048407e-01 -5.82481784e-01
 -1.97534788e-01 -8.72548760e-02  3.46408427e-01  5.38384508e-01
 -2.40385954e-02 -2.88540029e-01 -1.69869593e-01  8.50351196e-02
  1.13690514e-01 -4.26474550e-01  4.28296715e-02  2.24245348e-01
  1.60818455e-01 -6.91497089e-02  4.95879506e-02  4.41601278e-02
 -1.01926451e-01 -2.73979528e-01 -9.29162278e-02  2.62771285e-01
  2.74920747e-01 -2.60489809e-01 -5.32375179e+02  4.65319564e+00
 -1.03634806e+01  1.28730021e+01 -8.04141945e+00 -5.29821507e-01
  4.34629590e+00 -1.51491641e+00 -2.24124622e+00  1.96164584e+00
  3.67623140e-01 -1.28788791e+00  1.33769993e-01  5.27846890e-01
  5.85958561e-01 -2.28847621e+00  1.86782973e+00  1.90941937e-01
 -1.08242894e+00 -3.59226297e-01  1.53422271e+00 -7.54151041e-01
 -7.37856488e-01  1.09292261e+00 -5.60687953e-01  1.00726258e-01
  1.74189788e+00  2.17268855e+00 -6.83306976e-01 -5.12570144e-01
 -2.85191559e-01 -2.66779384e-01  1.9417064

# Prepare Training and Test Data

## Label Datasets

In [9]:
labelled_feature_data = get_labelled_dataset_map(features_map)
print_labelled_data_summary(labelled_feature_data)

LABELLED DATA
Subject: S01
	Positive data samples: 39
	Negative data samples: 767
Subject: S02
	Positive data samples: 39
	Negative data samples: 767
Subject: S03
	Positive data samples: 39
	Negative data samples: 767
Subject: S04
	Positive data samples: 39
	Negative data samples: 767
Subject: S05
	Positive data samples: 65
	Negative data samples: 741
Subject: S06
	Positive data samples: 39
	Negative data samples: 767
Subject: S07
	Positive data samples: 39
	Negative data samples: 767
Subject: S08
	Positive data samples: 39
	Negative data samples: 767
Subject: S09
	Positive data samples: 39
	Negative data samples: 767
Subject: S10
	Positive data samples: 39
	Negative data samples: 767
Subject: S11
	Positive data samples: 39
	Negative data samples: 767
Subject: S12
	Positive data samples: 39
	Negative data samples: 767
Subject: S13
	Positive data samples: 39
	Negative data samples: 767
Subject: S14
	Positive data samples: 39
	Negative data samples: 767
Subject: S15
	Positive data sample

## Split Data

In [10]:
stratified_data = apply_stratified_k_fold(labelled_feature_data, K_FOLDS)
print(f'DATA SPLIT FOR {K_FOLDS}-FOLD CROSS VALIDATION')

DATA SPLIT FOR 10-FOLD CROSS VALIDATION


# Model

## Configuration

In [11]:
# Every subject gets their own classifier, trained to verify their data
classifiers: typing.Dict[str, LinearSVC] = {
    subject: LinearSVC(
        random_state=32,
        dual='auto',
        max_iter=2000
    )
    for subject in stratified_data
}

## Training

In [12]:
scores: typing.Dict[str, typing.List[float]] = {subject: [] for subject in stratified_data}
x_idx = 0
y_idx = 1
for subject in stratified_data:
    print(f'TRAINING MODEL FOR SUBJECT: {subject}')
    k_folds = stratified_data[subject]
    clf = classifiers[subject]
    fold_counter = 1
    for train_data, test_data in k_folds:
        print(f'FOLD: {fold_counter}')
        clf.fit(train_data[x_idx], train_data[y_idx])
        scores[subject].append(
            clf.score(test_data[x_idx], test_data[y_idx])
        )
        fold_counter += 1
print('TRAINING COMPLETE')

TRAINING MODEL FOR SUBJECT: S01
FOLD: 1
FOLD: 2
FOLD: 3
FOLD: 4
FOLD: 5
FOLD: 6
FOLD: 7
FOLD: 8
FOLD: 9
FOLD: 10
TRAINING MODEL FOR SUBJECT: S02
FOLD: 1
FOLD: 2
FOLD: 3
FOLD: 4
FOLD: 5
FOLD: 6
FOLD: 7
FOLD: 8
FOLD: 9
FOLD: 10
TRAINING MODEL FOR SUBJECT: S03
FOLD: 1
FOLD: 2
FOLD: 3
FOLD: 4
FOLD: 5
FOLD: 6
FOLD: 7
FOLD: 8
FOLD: 9
FOLD: 10
TRAINING MODEL FOR SUBJECT: S04
FOLD: 1
FOLD: 2
FOLD: 3
FOLD: 4
FOLD: 5
FOLD: 6
FOLD: 7
FOLD: 8
FOLD: 9
FOLD: 10
TRAINING MODEL FOR SUBJECT: S05
FOLD: 1
FOLD: 2
FOLD: 3
FOLD: 4
FOLD: 5
FOLD: 6
FOLD: 7
FOLD: 8
FOLD: 9
FOLD: 10
TRAINING MODEL FOR SUBJECT: S06
FOLD: 1
FOLD: 2
FOLD: 3
FOLD: 4
FOLD: 5
FOLD: 6
FOLD: 7
FOLD: 8
FOLD: 9
FOLD: 10
TRAINING MODEL FOR SUBJECT: S07
FOLD: 1
FOLD: 2
FOLD: 3
FOLD: 4
FOLD: 5
FOLD: 6
FOLD: 7
FOLD: 8
FOLD: 9
FOLD: 10
TRAINING MODEL FOR SUBJECT: S08
FOLD: 1
FOLD: 2
FOLD: 3
FOLD: 4
FOLD: 5
FOLD: 6
FOLD: 7
FOLD: 8
FOLD: 9
FOLD: 10
TRAINING MODEL FOR SUBJECT: S09
FOLD: 1
FOLD: 2
FOLD: 3
FOLD: 4
FOLD: 5
FOLD: 6
FOLD: 7
FOLD: 8


## Results

In [13]:
print_scores_summary(scores)

AVERAGE SCORES
S01: 0.8140277777777778
S02: 0.8003549382716049
S03: 0.8076234567901235
S04: 0.8069907407407408
S05: 0.733179012345679
S06: 0.7400462962962963
S07: 0.699212962962963
S08: 0.8153549382716049
S09: 0.7278549382716049
S10: 0.7451851851851852
S11: 0.7787654320987654
S12: 0.7997222222222222
S13: 0.9065432098765432
S14: 0.7773148148148148
S15: 0.7734413580246914
S16: 0.8084259259259259
S17: 0.7996604938271605
S18: 0.8486265432098765
S19: 0.8507716049382715
S20: 0.7599228395061728
