# Setup

Initial module setup.

In [1]:
import pandas as pd
import typing
import mne
import numpy as np

from sklearn.ensemble import RandomForestClassifier
from eeg_auth_models_framework import data, pre_process, features, training, model
from eeg_auth_models_framework.training.base import StratifiedSubjectData

# Constants

In [2]:
DATASET_SAMPLE_FREQ_HZ = 200
DATA_CHANNEL_NAMES = ['T7','F8','Cz','P4']
FREQUENCIES = [
    pre_process.FrequencyBand(lower=8.0, upper=12.0, label='Alpha'),
    pre_process.FrequencyBand(lower=12.0, upper=35.0, label='Beta'),
    pre_process.FrequencyBand(lower=4.0, upper=8.0, label='Theta'),
    pre_process.FrequencyBand(lower=35.0, upper=None, label='Gamma'),
    pre_process.FrequencyBand(lower=None, upper=None, label='Raw'),
]
WINDOW_SIZE = 1200
WINDOW_OVERLAP = 0.5
K_FOLDS = 10

# Utilities

## Types

In [3]:
SubjectDataMap = typing.Dict[str, pd.DataFrame]
SubjectFramesMap = typing.Dict[str, typing.List[pd.DataFrame]]
RawFrequencyDataMap = typing.Dict[str, typing.Union[mne.io.Raw, mne.io.RawArray]]
SubjectFrameFeaturesMap = typing.Dict[str, typing.List[np.ndarray]]
LabelledDataset = typing.Tuple[typing.List[np.ndarray], typing.List[int]]
LabelledDatasetMap = typing.Dict[str, LabelledDataset]
StratifiedData = typing.List[typing.Tuple[LabelledDataset, LabelledDataset]]
StratifiedDatasetMap = typing.Dict[str, StratifiedData]
MusicIDClassifiers = typing.Dict[str, RandomForestClassifier]
T = typing.TypeVar('T')

# Setup Dataset

# Model Builder Configuration

Configure data source, data reading method, data labelling method, and training process. 

In [4]:
class MusicIDModelBuilder(model.ModelBuilder[MusicIDClassifiers]):
    @property
    def data_downloader(self):
        return data.AuditoryDataDownloader()
    
    @property
    def data_reader(self):
        return data.AuditoryDataReader()
    
    @property
    def labeller(self):
        return training.SubjectDataLabeller()
    
    def run_training(self, labelled_data: typing.Dict[str, typing.List[StratifiedSubjectData]]):
        subject_models: typing.Dict[str, RandomForestClassifier] = {
            subject: RandomForestClassifier(
                n_estimators=100,
                criterion='gini',
                max_depth=10,
                random_state=32
            )
            for subject in labelled_data
        }
        training_scores: typing.Dict[str, typing.List[float]] = {
            subject: [] for subject in labelled_data
        }
        for subject in labelled_data:
            print(f'TRAINING MODEL FOR SUBJECT: {subject}')
            stratified_data = labelled_data[subject]
            classifier = subject_models[subject]
            iteration_count = 1
            for segment in stratified_data:
                print(f'FOLD: {iteration_count}')
                classifier.fit(segment.train.x, segment.train.y)
                training_scores[subject].append(
                    classifier.score(segment.test.x, segment.test.y)
                )
                iteration_count += 1
        print('TRAINING COMPLETE')
        
        return subject_models

# Pre-Processing Steps

Define pre-processing steps to be used in model.

In [5]:
pre_process_steps = [
    pre_process.EEGBandpassFilterStep(
        FREQUENCIES, 
        DATA_CHANNEL_NAMES, 
        DATASET_SAMPLE_FREQ_HZ
    ),
    pre_process.DataWindowStep(WINDOW_SIZE, WINDOW_OVERLAP)
]

# Feature Extraction Steps

Define feature extraction steps to be applied to the pre-processed data.

In [6]:
feature_extraction_steps = [
    features.StatisticalFeatureExtractor([
        features.StatisticalFeature.MIN,
        features.StatisticalFeature.MAX,
        features.StatisticalFeature.MEAN,
        features.StatisticalFeature.ZERO_CROSSING_RATE
    ])
]

# Training

Execute training of authentication models.

In [7]:
music_id_builder = MusicIDModelBuilder(
    pre_process_steps,
    feature_extraction_steps
)
music_id_builder.train(K_FOLDS) # FIXME: breaks with error "classification metrics can't handle a mix of continuous-multioutput and binary targets"

Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=40114
    Range : 0 ... 40113 =      0.000 ...   200.565 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ..

ValueError: Classification metrics can't handle a mix of continuous-multioutput and binary targets