# Setup

Initial module setup.

In [1]:
import typing

from sklearn.svm import LinearSVC
from eeg_auth_models_framework import data, pre_process, features, training, model
from eeg_auth_models_framework.training.base import StratifiedSubjectData

# Constants

In [2]:
# Sampling frequency of the EEG recordings, in Hz.
DATASET_SAMPLE_FREQ_HZ = 200

# Names of the EEG channels handed to the band-pass filter step.
DATA_CHANNEL_NAMES = ['T7', 'F8', 'Cz', 'P4']

# Frequency bands (bounds in Hz) handed to the band-pass filter step.
# NOTE(review): a `None` bound presumably leaves that side of the band open,
# and 'Raw' (both bounds `None`) presumably means "no filtering" -- confirm
# against pre_process.FrequencyBand / EEGBandpassFilterStep.
FREQUENCIES = [
    pre_process.FrequencyBand(label='Alpha', lower=8.0, upper=12.0),
    pre_process.FrequencyBand(label='Beta', lower=12.0, upper=35.0),
    pre_process.FrequencyBand(label='Theta', lower=4.0, upper=8.0),
    pre_process.FrequencyBand(label='Gamma', lower=35.0, upper=None),
    pre_process.FrequencyBand(label='Raw', lower=None, upper=None),
]

# Windowing parameters passed to pre_process.DataWindowStep.
# NOTE(review): presumably WINDOW_SIZE is in samples (1200 samples would be
# 6 s at 200 Hz) and WINDOW_OVERLAP is a fraction of the window -- confirm.
WINDOW_SIZE = 1200
WINDOW_OVERLAP = 0.5

# Value passed to ModelBuilder.train(); presumably the number of folds.
K_FOLDS = 10

# Utilities

## Types

In [3]:
# Result type of the model builder: maps a subject identifier to the
# LinearSVC classifier trained for that subject.
ARSVMClassifiers = typing.Dict[str, LinearSVC]

## Functions

# Model Builder Configuration

Configure the data source, the data reading method, the data labelling method, and the training process.

In [4]:
class ARSVMBuilder(model.ModelBuilder[ARSVMClassifiers]):
    """Model builder that trains one ``LinearSVC`` classifier per subject.

    The three properties supply the data downloader, data reader and
    labeller components; ``run_training`` implements the training loop over
    the labelled, stratified data.
    """

    @property
    def data_downloader(self):
        """Return a new ``data.AuditoryDataDownloader`` on every access."""
        return data.AuditoryDataDownloader()

    @property
    def data_reader(self):
        """Return a new ``data.AuditoryDataReader`` on every access."""
        return data.AuditoryDataReader()

    @property
    def labeller(self):
        """Return a new ``training.SubjectDataLabeller`` on every access."""
        return training.SubjectDataLabeller()

    def run_training(
        self,
        labelled_data: typing.Dict[str, typing.List[StratifiedSubjectData]]
    ) -> ARSVMClassifiers:
        """Fit and evaluate one ``LinearSVC`` per subject across its folds.

        For each subject, a single classifier is fitted on the training split
        of every fold in turn and scored on that fold's test split. The mean
        of the held-out scores is printed once the subject's folds are done.

        NOTE: the same estimator instance is re-fitted on every fold. Per the
        scikit-learn estimator contract, ``fit`` learns from scratch on each
        call, so the returned classifier reflects only the training split of
        the last fold. A subject with no folds yields an unfitted classifier.

        Args:
            labelled_data: Maps each subject identifier to its list of
                folds; every fold exposes ``train.x``/``train.y`` and
                ``test.x``/``test.y``.

        Returns:
            A dictionary mapping each subject identifier to its classifier.
        """
        subject_models: typing.Dict[str, LinearSVC] = {}
        for subject, stratified_data in labelled_data.items():
            print(f'TRAINING MODEL FOR SUBJECT: {subject}')
            classifier = LinearSVC(
                random_state=32,
                dual='auto',
                max_iter=2000
            )
            fold_scores: typing.List[float] = []
            for fold_number, segment in enumerate(stratified_data, start=1):
                print(f'FOLD: {fold_number}')
                classifier.fit(segment.train.x, segment.train.y)
                fold_scores.append(
                    classifier.score(segment.test.x, segment.test.y)
                )
            # Report the evaluation result instead of silently discarding it;
            # the guard avoids a division by zero for a subject with no folds.
            if fold_scores:
                mean_score = sum(fold_scores) / len(fold_scores)
                print(f'MEAN FOLD SCORE FOR SUBJECT {subject}: {mean_score:.4f}')
            subject_models[subject] = classifier
        print('TRAINING COMPLETE')

        return subject_models

# Pre-Processing Steps

Define the pre-processing steps to be used in the model.

In [5]:
# Pre-processing steps handed to the model builder.
# NOTE(review): presumably applied in list order (filter, then window) --
# confirm against model.ModelBuilder.
pre_process_steps = [
    # Band-pass filter step configured with the frequency bands, channel
    # names and sampling frequency defined in the constants cell.
    pre_process.EEGBandpassFilterStep(
        FREQUENCIES, 
        DATA_CHANNEL_NAMES, 
        DATASET_SAMPLE_FREQ_HZ
    ),
    # Windowing step configured with the window size and overlap constants.
    pre_process.DataWindowStep(WINDOW_SIZE, WINDOW_OVERLAP)
]

# Feature Extraction Steps

Define feature extraction steps to be applied to the pre-processed data.

In [6]:
# Feature extraction steps handed to the model builder.
feature_extraction_steps = [
    # NOTE(review): presumably extracts autoregressive-model features using
    # 25 lags -- confirm the option schema against features.ARFeatureExtractor.
    features.ARFeatureExtractor({'lags': 25})
]

# Training

Execute training of authentication models.

In [7]:
# Build the per-subject SVM model builder from the pre-processing and
# feature extraction steps defined in the cells above.
ar_svm_builder = ARSVMBuilder(
    pre_process_steps,
    feature_extraction_steps
)
# Last expression of the cell, so its return value is shown as the cell
# output.
# NOTE(review): K_FOLDS is presumably the number of folds generated per
# subject -- confirm against model.ModelBuilder.train.
ar_svm_builder.train(K_FOLDS)

Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=40114
    Range : 0 ... 40113 =      0.000 ...   200.565 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ..

{'S01': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S02': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S03': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S04': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S05': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S06': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S07': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S08': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S09': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S10': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S11': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S12': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S13': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S14': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S15': LinearSVC(dual='auto', max_iter=2000, random_state=32),
 'S16': LinearSVC(dual='auto', max_iter=