# Setup

Import the standard-library and third-party modules used throughout this notebook.

In [10]:
import bs4
import copy
import dataclasses
import pathlib
import urllib.request
import urllib.parse
import re
import mne.io
import requests
import pandas as pd
import typing
import mne
import numpy as np

# Data Structures

In [11]:
@dataclasses.dataclass
class FrequencyBand:
    """
    A named EEG frequency band described by its lower and upper cut-off frequencies.

    Either bound may be ``None`` to leave that side of the band open; the bounds are
    handed to MNE's ``filter`` as ``l_freq`` / ``h_freq`` when the band is extracted.
    """

    # Lower cut-off frequency (Hz), or None for no lower bound.
    lower: typing.Optional[float]
    # Upper cut-off frequency (Hz), or None for no upper bound.
    upper: typing.Optional[float]
    # Band name; used as the band key and as the suffix of "<channel>.<band>" columns.
    label: str

# Constants

In [12]:
# Directory listing on PhysioNet that holds the per-subject CSV recordings.
DATASET_URL = 'https://physionet.org/files/auditory-eeg/1.0.0/Segmented_Data/'
# Sampling rate passed to MNE when the CSV data is converted to a RawArray.
DATASET_SAMPLE_FREQ_HZ = 200
# EEG channel names, in the order they are given to MNE.
DATA_CHANNEL_NAMES = ['T7', 'F8', 'Cz', 'P4']
# Label of the unfiltered copy kept alongside the filtered bands.
RAW_FREQUENCY = 'Raw'
# Bands to extract; list order determines the band order of the generated columns.
FREQUENCIES = [
    FrequencyBand(lower=low_hz, upper=high_hz, label=band_name)
    for band_name, low_hz, high_hz in (
        ('Alpha', 8.0, 12.0),
        ('Beta', 12.0, 35.0),
        ('Theta', 4.0, 8.0),
        ('Gamma', 35.0, None),  # no upper bound
    )
]
TRAINING_SPLIT_RATIO = 0.8
# Names of the per-column features, in the order they appear in each feature vector chunk.
FEATURE_MIN = 'Min'
FEATURE_MAX = 'Max'
FEATURE_MEAN = 'Mean'
FEATURE_ZCR = 'ZCR'

# Utilities

## Types

In [13]:
# Generic value type for helpers that accept a map of any value type.
T = typing.TypeVar('T')

# Subject identifier -> that subject's recording.
SubjectDataMap = typing.Dict[str, pd.DataFrame]
# Subject identifier -> ordered windows (frames) cut from that subject's data.
SubjectFramesMap = typing.Dict[str, typing.List[pd.DataFrame]]
# Frequency band label -> MNE raw object holding that band's signal.
RawFrequencyDataMap = typing.Dict[
    str,
    typing.Union[mne.io.Raw, mne.io.RawArray],
]
# Subject identifier -> one feature vector per frame.
SubjectFrameFeaturesMap = typing.Dict[str, typing.List[np.ndarray]]
# (samples, labels); a label is 1 for the target subject and 0 for everyone else.
LabelledDataset = typing.Tuple[typing.List[np.ndarray], typing.List[int]]
# Target subject identifier -> dataset labelled with respect to that subject.
LabelledDatasetMap = typing.Dict[str, LabelledDataset]

In [14]:
def retrieve_dataset() -> pathlib.Path:
    """
    Downloads (if necessary) the dataset and retrieves the path to the root of the dataset files directory.

    The download is skipped only when the data directory already contains at least one CSV file.
    A directory left empty by an earlier, interrupted run is therefore downloaded again rather than
    being mistaken for a complete dataset.

    :return: The path object pointing to the dataset directory.
    """
    data_directory = _get_data_directory()
    if data_directory.is_dir() and any(data_directory.glob('*.csv')):
        return data_directory
    data_directory.mkdir(parents=True, exist_ok=True)
    _download_dataset(data_directory)
    return data_directory


def _get_data_directory() -> pathlib.Path:
    """
    Helper function which generates the path to the data directory.

    :return: The data directory path object.
    """
    return pathlib.Path().resolve().parent / 'data'


def _download_dataset(target_path: pathlib.Path, timeout: float = 60.0):
    """
    Initiates download of the dataset and saves all files into the given target path directory.

    :param target_path: The target path directory.
    :param timeout: Seconds to wait for the server before a request is abandoned. Without a
                    timeout, requests can wait indefinitely on a stalled connection.
    :raises requests.HTTPError: If the dataset listing page could not be retrieved. Without this check
                                an error page would be parsed and nothing would be downloaded.
    """
    with requests.get(DATASET_URL, timeout=timeout) as listing_page:
        listing_page.raise_for_status()
        listing_soup = bs4.BeautifulSoup(
            listing_page.content,
            features='html.parser'
        )
    # The listing connection is released before the (potentially long) file downloads start.
    _download_files_in_listing(target_path, listing_soup)


def _download_files_in_listing(target_path: pathlib.Path, listing_soup: bs4.BeautifulSoup):
    """
    Helper function which iterates over all file links in the given BeautifulSoup object and downloads each file into
    the target path's directory.

    Only links whose *entire* href is a file name of the form ``sNN_ex05.csv`` are downloaded.
    Matching the whole href (not just its prefix) keeps out names such as ``s01_ex05.csv.bak``.
    It also ensures no directory components end up in the local file path.

    :param target_path: The target path directory.
    :param listing_soup: The BeautifulSoup object to use to find download links.
    """
    # NOTE(review): the name says "experiment 1" but the pattern selects "ex05" files - confirm intent.
    experiment_1_data_pattern = re.compile(r's\d{2}_ex05\.csv')

    for file_link in listing_soup.find_all('a'):
        file_href = file_link.get('href')
        if not file_href or not experiment_1_data_pattern.fullmatch(file_href):
            continue
        file_path = target_path / file_href
        file_url = urllib.parse.urljoin(DATASET_URL, file_href)
        _download_url_to_file(file_path, file_url)


def _download_url_to_file(file_path: pathlib.Path, url: str, timeout: float = 60.0):
    """
    Downloads the given URL's remote content to the given file path.

    The response body is streamed in chunks. It is written to a sibling ``<name>.part`` file, which is
    renamed to ``file_path`` only after the whole body has been written. An interrupted download
    therefore never leaves a truncated file under the final (``.csv``) name.

    :param file_path: The file path to download to.
    :param url: The URL to download from.
    :param timeout: Seconds to wait for the server before the request is abandoned.
    :raises requests.HTTPError: If the server responds with an error status. Without this check,
                                the error body would be saved as data.
    """
    partial_path = file_path.with_name(file_path.name + '.part')
    # stream=True defers the body download so iter_content really reads it chunk by chunk.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(partial_path, 'wb') as out_file:
            for chunk in response.iter_content(chunk_size=1024):
                out_file.write(chunk)
    partial_path.replace(file_path)
                

def read_data(target_path: pathlib.Path) -> SubjectDataMap:
    """
    Reads all CSV data files in the given directory path and generates a structure of dataframes.

    Files are read in sorted path order. The subject order of the returned map is therefore
    deterministic and does not depend on the filesystem's directory-listing order. This matters
    because later steps iterate it and take its first value as a "sample".
    If two files yield the same subject identifier, the one read later replaces the earlier one.

    :param target_path: The target directory path.
    :return: A map of dataframes, where the key is the subject identifier parsed from the file name and
             the value is the file's dataframe (first CSV column used as the index).
    """
    loaded_data_map = {}

    for data_file in sorted(target_path.iterdir()):
        if data_file.suffix != '.csv':
            continue
        subject_identifier = _get_subject_identifier(data_file.name)
        loaded_data_map[subject_identifier] = pd.read_csv(data_file, index_col=0, header=0)

    return loaded_data_map


def _get_subject_identifier(data_file_name: str) -> str:
    """
    Helper function which parses a subject identifier from a data file name.
    
    :param data_file_name: The file name to parse. 
    :return: A subject identifier.
    :raises ValueError: If the identifier could not be parsed.
    """
    identifier_pattern = re.compile(r'(?P<identifier>s\d{2})')
    search_result = re.search(identifier_pattern, data_file_name)
    identifier = search_result.group('identifier')
    if not identifier:
        raise ValueError(f'Unable to parse subject identifier from file: "{data_file_name}"')
    return identifier.upper()
    
    
def window_dataset(
    dataframe_map: SubjectDataMap,
    window_size: int = 1200,
    overlap: float = 0.5,
) -> SubjectFramesMap:
    """
    Windows every dataframe in the given data map.

    With the defaults (1,200 samples, 50% overlap) and DATASET_SAMPLE_FREQ_HZ = 200, each window
    covers 6 seconds and consecutive windows start 600 samples apart.

    :param dataframe_map: The data map.
    :param window_size: Number of rows per window.
    :param overlap: Fraction (e.g., 0.5) of each window shared with the next one.
    :return: The windowed data map, keyed like the input.
    """
    return {
        identifier: _window_dataframe(dataframe, window_size, overlap)
        for identifier, dataframe in dataframe_map.items()
    }


def _window_dataframe(dataframe: pd.DataFrame, size: int, overlap: float) -> typing.List[pd.DataFrame]:
    """
    Create a series of windows from the given dataframe.
    
    :param dataframe: The dataframe to window.
    :param size: The size of the windows to create.
    :param overlap: The percentage overlap (e.g., 0.5) of the windows.
    :return: The windowed data.
    """
    windowed_data = []
    start = 0
    end = size
    
    while end <= len(dataframe):
        window = dataframe[start:end]
        windowed_data.append(window)
        
        start += int(size * (1 - overlap))
        end += int(size * (1 - overlap))
    
    return windowed_data
    
    
def filter_subject_data(subject_data: SubjectDataMap) -> SubjectDataMap:
    """
    Runs every subject's recording through MNE and splits it into frequency bands.

    Each dataframe is converted to an MNE RawArray and then expanded into one column per
    channel / frequency band combination (see ``_retrieve_target_bands``).

    :param subject_data: the subject data to filter.
    :return: a new data map keyed by the same subject identifiers, whose values are the
             filtered data.
    """
    return {
        identifier: _retrieve_target_bands(convert_dataframe_to_mne(recording))
        for identifier, recording in subject_data.items()
    }
            
            
def _retrieve_target_bands(mne_data: mne.io.RawArray) -> pd.DataFrame:
    """
    Splits the given MNE data into the configured frequency bands plus an unfiltered copy.

    One filtered copy is produced per entry in FREQUENCIES (in that order), followed by an
    unfiltered copy stored under RAW_FREQUENCY. The result is flattened into a single DataFrame
    with one column per channel / frequency band combination.

    :param mne_data: The MNE data to retrieve target bands from; it is not modified.
    :return: Target band data, per channel and frequency band, in a single DataFrame.
    """
    def _band_pass(band: FrequencyBand) -> mne.io.Raw:
        # MNE's Raw.filter() modifies the object in place, so every band works on its own deep copy.
        return copy.deepcopy(mne_data).filter(
            l_freq=band.lower,
            h_freq=band.upper,
            verbose=False,
            l_trans_bandwidth=1,
            h_trans_bandwidth=1,
        )

    bands_map: RawFrequencyDataMap = {band.label: _band_pass(band) for band in FREQUENCIES}
    bands_map[RAW_FREQUENCY] = copy.deepcopy(mne_data)

    return _map_channel_frequencies(bands_map)
    
    
def _map_channel_frequencies(bands_map: RawFrequencyDataMap) -> pd.DataFrame:
    """
    Generates a single DataFrame wherein each column is a channel / frequency band combination.
    For example, if there was only one channel (e.g., T7) and two frequency bands (e.g., Alpha, Beta) then
    the columns would be:
        "T7.Alpha", "T7.Beta"
    Columns are ordered channel-major: all bands of the first channel in DATA_CHANNEL_NAMES, then the next.

    :param bands_map: The frequency band data in a map, where each key is a frequency band label and each
                      value is the MNE raw object holding that band's signal.
    :return: The DataFrame of channel / frequency band data.
    """
    # Convert each band to a DataFrame once, for all channels together. This avoids one
    # conversion per channel per band.
    # NOTE(review): to_data_frame() applies MNE's default scalings (EEG x1e6, volts -> microvolts).
    # The recorded notebook output shows "Raw" columns ~1e6 times the CSV values.
    # Confirm the source units of the CSV data.
    band_frames: typing.Dict[str, pd.DataFrame] = {
        frequency_type: raw_band.to_data_frame(picks=DATA_CHANNEL_NAMES)
        for frequency_type, raw_band in bands_map.items()
    }

    frequency_data: typing.Dict[str, np.ndarray] = {
        f'{channel}.{frequency_type}': band_frame[channel].to_numpy()
        for channel in DATA_CHANNEL_NAMES
        for frequency_type, band_frame in band_frames.items()
    }

    return pd.DataFrame(frequency_data)
    
    
def convert_dataframe_to_mne(dataframe: pd.DataFrame) -> mne.io.RawArray:
    """
    Converts the given dataframe over to Python-MNE format.

    Channel columns are selected by name, in DATA_CHANNEL_NAMES order. Each data row therefore
    lines up with the channel name MNE assigns to it, whatever the column order of the input.

    :param dataframe: The dataframe to convert; it must have a column for every name in DATA_CHANNEL_NAMES
                      (rows are samples).
    :return: A Python-MNE data array sampled at DATASET_SAMPLE_FREQ_HZ.
    :raises KeyError: If a channel column is missing from the dataframe.
    """
    # MNE expects (n_channels, n_times), hence the transpose. The data is copied so the
    # RawArray never shares memory with the caller's dataframe.
    channel_data = dataframe[DATA_CHANNEL_NAMES].to_numpy(copy=True).T
    data_info = mne.create_info(DATA_CHANNEL_NAMES, DATASET_SAMPLE_FREQ_HZ, ch_types='eeg')
    # verbose=False matches the filtering step and suppresses per-recording creation logs.
    return mne.io.RawArray(channel_data, data_info, verbose=False)


def extract_features(filtered_data: SubjectFramesMap) -> SubjectFrameFeaturesMap:
    """
    Turns every frame of every subject into a feature vector.

    The returned map keeps the subject keys of ``filtered_data``. Each value is a list holding
    one feature vector per frame, in the same order as the frames.

    :param filtered_data: The filtered frame data, mapped by subject.
    :return: A map of subject windowed feature vectors.
    """
    return {
        subject: [_extract_features_from_frame(frame) for frame in frames]
        for subject, frames in filtered_data.items()
    }


def _extract_features_from_frame(frame_channel_data: pd.DataFrame) -> np.ndarray:
    """
    Extracts a feature vector from the given frame data.

    Four features are computed for every column, in column order: minimum, maximum, mean and
    zero crossing rate. The per-column groups are concatenated, so the result has
    ``4 * n_columns`` elements.

    :param frame_channel_data: The frame data, with the channels / frequency bands by column.
    :return: A single flat array of features.
    """
    extracted_feature_chunks = []

    for _, frame_data_column in frame_channel_data.items():
        extracted_feature_chunks.append(np.array([
            frame_data_column.min(),
            frame_data_column.max(),
            frame_data_column.mean(),
            # Called directly on the whole column. Series.agg(udf) would first try to apply the
            # function element-wise and only reach this call via a deprecated exception fallback.
            _get_zero_crossing_rate(frame_data_column),
        ]))

    return np.array(extracted_feature_chunks).flatten()


def _get_zero_crossing_rate(data_to_process: pd.Series) -> float:
    """
    Helper function which computes the Zero Crossing Rate (ZCR) of a Pandas Series.

    The rate is the number of zero crossings between adjacent samples divided by the
    total number of samples in the series.

    :param data_to_process: The series to compute the ZCR for.
    :return: The ZCR.
    """
    samples = data_to_process.to_numpy()
    sample_count = len(samples)
    return _count_zero_crossings(samples) / sample_count
    
    
def _count_zero_crossings(target_array: np.ndarray) -> int:
    """
    Helper function which counts the number of zero crossings in a given array.
    
    see: https://stackoverflow.com/a/30281079/13261549
    
    :param target_array: The array to count zero crossings from.
    :return: The number of zero crossings in the array.
    """
    return ((target_array[:-1] * target_array[1:]) < 0).sum()


def get_labelled_dataset_map(map_to_convert: SubjectFrameFeaturesMap) -> LabelledDatasetMap:
    """
    Builds one labelled dataset per subject from the given subject features map.

    Each subject's dataset contains the samples of *all* subjects. Samples belonging to that subject
    are labelled 1 and every other sample is labelled 0.

    :param map_to_convert: the original subject features map to convert.
    :return: a new map wherein the keys are subject identifiers and the values are labelled datasets.
    """
    return {
        subject_key: _get_x_y_labelled_dataset(map_to_convert, subject_key)
        for subject_key in map_to_convert
    }


def _get_x_y_labelled_dataset(map_to_label: SubjectFrameFeaturesMap, target_subject_key: str) -> LabelledDataset:
    """
    Utility function which generates a list of samples and a list of associated labels, based on the given target subject
    (i.e., '1' indicates the sample is for the target, '0' otherwise).

    Samples keep the map's iteration order (subject by subject, then sample order within a subject).
    ``labels[i]`` is the label of ``samples[i]``.

    :param map_to_label: a map wherein the keys are subject identifiers and the values are lists of data samples.
    :param target_subject_key: the key to use to tailor the dataset to.
    :return: a Tuple containing samples, and the corresponding labels.
    :raises KeyError: If ``target_subject_key`` is not a key of ``map_to_label``.
    """
    if target_subject_key not in map_to_label:
        raise KeyError(f'Key "{target_subject_key}" not found in data map!')
    samples_list = []
    labels_list = []

    for key, subject_samples in map_to_label.items():
        label_id = 1 if key == target_subject_key else 0
        for subject_frame_sample in subject_samples:
            samples_list.append(subject_frame_sample)
            labels_list.append(label_id)

    return samples_list, labels_list
        

def get_sample_value_from_map(map_to_sample: typing.Dict[str, T]) -> T:
    """
    Helper function which retrieves a sample value from the given map of data.

    :param map_to_sample: The data map to get a sample from.
    :return: The first value in the map's iteration (insertion) order.
    :raises ValueError: If the map is empty. A bare StopIteration could otherwise escape and be
                        misread as the end of an iteration by surrounding generator code.
    """
    for value in map_to_sample.values():
        return value
    raise ValueError('Cannot retrieve a sample value from an empty map')


def print_info_about_subjects(map_to_summarize: SubjectDataMap):
    """
    Prints a short summary of a subject data map: a header, the subject count,
    and then each subject identifier on its own line.

    :param map_to_summarize: the map to print info from.
    """
    subject_count = len(map_to_summarize)
    print('SUBJECT DATA')
    print(f'Number of subjects: {subject_count}')
    print('Subject identifiers:')
    for subject_identifier in map_to_summarize:
        print(subject_identifier)


def print_windowed_data_summary(windowed_data: SubjectFramesMap):
    """
    Prints, for every subject in a windowed data map, how many windows it has.

    :param windowed_data: The windowed data to summarize.
    """
    print('WINDOWED DATA')
    for subject, windows in windowed_data.items():
        print(f'Subject: {subject}, Windows: {len(windows)}')
        
        
def print_labelled_data_summary(labelled_data: LabelledDatasetMap):
    """
    Prints, for every subject's labelled dataset, how many samples are labelled 1 (positive)
    and how many are labelled 0 (negative).

    :param labelled_data: the labelled data map to summarize.
    """
    print('LABELLED DATA')
    for subject, dataset in labelled_data.items():
        labels = dataset[1]
        positives = sum(1 for label in labels if label == 1)
        negatives = sum(1 for label in labels if label == 0)
        print(f'Subject: {subject}')
        print(f'\tPositive data samples: {positives}')
        print(f'\tNegative data samples: {negatives}')

# Setup Dataset

In [15]:
# Download the dataset if needed, then load one DataFrame per subject, keyed by subject identifier.
# NOTE: data_map is rebound to a new structure by each pipeline cell below, so the cells
# must be run in order, each exactly once.
dataset_path = retrieve_dataset()
data_map = read_data(dataset_path)
print_info_about_subjects(data_map)
print('SAMPLE DATAFRAME')
# Show the first rows of one subject's recording (the last expression is the cell's displayed output).
sample_dataframe = get_sample_value_from_map(data_map)
sample_dataframe.head()

SUBJECT DATA
Number of subjects: 20
Subject identifiers:
S01
S02
S03
S04
S05
S06
S07
S08
S09
S10
S11
S12
S13
S14
S15
S16
S17
S18
S19
S20
SAMPLE DATAFRAME


Unnamed: 0,T7,F8,Cz,P4
13200,431.251617,-1189.493896,454.405334,345.306824
13201,444.240265,-1194.415649,471.23114,363.666016
13202,439.06427,-1188.719727,457.135437,325.425537
13203,442.071136,-1193.476929,458.751099,340.463654
13204,435.93396,-1197.149414,442.688232,333.630859


# Pre-process Data

## Filter Data

In [16]:
# Replace each subject's recording with a DataFrame of "<channel>.<band>" columns: one per
# channel and frequency band, plus an unfiltered "<channel>.Raw" column.
data_map = filter_subject_data(data_map)
print('SAMPLE FILTERED DATAFRAME')
# Show the first rows of one subject's filtered data (the cell's displayed output).
sample_dataframe = get_sample_value_from_map(data_map)
sample_dataframe.head()

Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=40114
    Range : 0 ... 40113 =      0.000 ...   200.565 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ...   119.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=24000
    Range : 0 ... 23999 =      0.000 ..

Unnamed: 0,T7.Alpha,T7.Beta,T7.Theta,T7.Gamma,T7.Raw,F8.Alpha,F8.Beta,F8.Theta,F8.Gamma,F8.Raw,Cz.Alpha,Cz.Beta,Cz.Theta,Cz.Gamma,Cz.Raw,P4.Alpha,P4.Beta,P4.Theta,P4.Gamma,P4.Raw
0,-7.704948e-08,4.041212e-08,-5.87308e-08,-8.437695e-08,431251600.0,1.94178e-07,-1.985079e-07,1.718625e-07,9.05942e-08,-1189494000.0,-6.883383e-08,6.328271e-08,-7.177592e-08,-8.881784e-10,454405300.0,-6.150636e-08,5.373479e-08,-4.551914e-08,-3.108624e-08,345306800.0
1,1597428.0,5617263.0,260440.5,6113604.0,444240300.0,-700075.1,-1784533.0,-625694.4,-3457973.0,-1194416000.0,1179360.0,2073403.0,106048.1,14290680.0,471231100.0,642413.8,-5416940.0,-223332.7,24436510.0,363666000.0
2,3003843.0,7803250.0,516193.8,-2495471.0,439064300.0,-1323957.0,-2837481.0,-1221153.0,4026421.0,-1188720000.0,2203946.0,2787055.0,224076.0,-1378198.0,457135400.0,1211374.0,-7326968.0,-434206.4,-11518090.0,325425500.0
3,4046442.0,5954752.0,767039.3,1200391.0,442071100.0,-1809801.0,-3154577.0,-1753179.0,1247092.0,-1193477000.0,2936890.0,1816007.0,366279.9,45875.54,458751100.0,1644117.0,-4896868.0,-623130.0,1218548.0,340463700.0
4,4598322.0,2588707.0,1012805.0,-2499456.0,435934000.0,-2103975.0,-3215320.0,-2199496.0,-962059.5,-1197149000.0,3276660.0,-3953.14,544973.3,-15118960.0,442688200.0,1887982.0,-453984.1,-776919.0,-9855401.0,333630900.0


## Window Data

In [17]:
# Split each subject's filtered data into overlapping fixed-size windows (a list of DataFrames per subject).
data_map = window_dataset(data_map)
print_windowed_data_summary(data_map)
print('SAMPLE WINDOW')
# Show the first rows of one subject's first window (the cell's displayed output).
sample_windows = get_sample_value_from_map(data_map)
sample_windows[0].head()

WINDOWED DATA
Subject: S01, Windows: 39
Subject: S02, Windows: 39
Subject: S03, Windows: 39
Subject: S04, Windows: 39
Subject: S05, Windows: 65
Subject: S06, Windows: 39
Subject: S07, Windows: 39
Subject: S08, Windows: 39
Subject: S09, Windows: 39
Subject: S10, Windows: 39
Subject: S11, Windows: 39
Subject: S12, Windows: 39
Subject: S13, Windows: 39
Subject: S14, Windows: 39
Subject: S15, Windows: 39
Subject: S16, Windows: 39
Subject: S17, Windows: 39
Subject: S18, Windows: 39
Subject: S19, Windows: 39
Subject: S20, Windows: 39
SAMPLE WINDOW


Unnamed: 0,T7.Alpha,T7.Beta,T7.Theta,T7.Gamma,T7.Raw,F8.Alpha,F8.Beta,F8.Theta,F8.Gamma,F8.Raw,Cz.Alpha,Cz.Beta,Cz.Theta,Cz.Gamma,Cz.Raw,P4.Alpha,P4.Beta,P4.Theta,P4.Gamma,P4.Raw
0,-7.704948e-08,4.041212e-08,-5.87308e-08,-8.437695e-08,431251600.0,1.94178e-07,-1.985079e-07,1.718625e-07,9.05942e-08,-1189494000.0,-6.883383e-08,6.328271e-08,-7.177592e-08,-8.881784e-10,454405300.0,-6.150636e-08,5.373479e-08,-4.551914e-08,-3.108624e-08,345306800.0
1,1597428.0,5617263.0,260440.5,6113604.0,444240300.0,-700075.1,-1784533.0,-625694.4,-3457973.0,-1194416000.0,1179360.0,2073403.0,106048.1,14290680.0,471231100.0,642413.8,-5416940.0,-223332.7,24436510.0,363666000.0
2,3003843.0,7803250.0,516193.8,-2495471.0,439064300.0,-1323957.0,-2837481.0,-1221153.0,4026421.0,-1188720000.0,2203946.0,2787055.0,224076.0,-1378198.0,457135400.0,1211374.0,-7326968.0,-434206.4,-11518090.0,325425500.0
3,4046442.0,5954752.0,767039.3,1200391.0,442071100.0,-1809801.0,-3154577.0,-1753179.0,1247092.0,-1193477000.0,2936890.0,1816007.0,366279.9,45875.54,458751100.0,1644117.0,-4896868.0,-623130.0,1218548.0,340463700.0
4,4598322.0,2588707.0,1012805.0,-2499456.0,435934000.0,-2103975.0,-3215320.0,-2199496.0,-962059.5,-1197149000.0,3276660.0,-3953.14,544973.3,-15118960.0,442688200.0,1887982.0,-453984.1,-776919.0,-9855401.0,333630900.0


## Feature Extraction

In [18]:
# Replace every window with a flat feature vector: min, max, mean and ZCR for each "<channel>.<band>" column.
data_map = extract_features(data_map)
sample_windows = get_sample_value_from_map(data_map)
sample_frame = sample_windows[0]
print('SAMPLE FEATURES')
print(f'Size: {len(sample_frame)}')
print(f'Elements: {sample_frame}')

SAMPLE FEATURES
Size: 80
Elements: [-1.38215250e+07  1.39524481e+07  2.30669831e+04  1.04166667e-01
 -1.33188938e+07  1.23118873e+07  1.07595349e+04  2.13333333e-01
 -8.50688623e+06  7.58143167e+06  5.38479622e+03  6.33333333e-02
 -3.11927915e+07  3.15010114e+07  3.08707558e+03  5.43333333e-01
  3.86584106e+08  4.81055847e+08  4.40934159e+08  0.00000000e+00
 -1.63194953e+07  1.54157468e+07 -1.19110791e+04  1.05000000e-01
 -1.22098452e+07  1.36056335e+07  9.85539811e+02  2.15833333e-01
 -1.09752728e+07  9.74591110e+06 -1.56998768e+04  6.33333333e-02
 -3.36370308e+07  2.89119063e+07 -3.10105768e+03  5.55833333e-01
 -1.23803223e+09 -1.14405408e+09 -1.19490520e+09  0.00000000e+00
 -1.44408963e+07  1.33289983e+07  2.09954706e+04  1.02500000e-01
 -1.29073672e+07  9.43481925e+06 -5.55832882e+03  2.05000000e-01
 -6.66783987e+06  6.50344327e+06 -5.37477503e+03  6.16666667e-02
 -1.91941124e+07  2.21471060e+07  8.34716262e+03  5.11666667e-01
  4.01601685e+08  4.87976532e+08  4.54110164e+08  0.000

# Prepare Training and Testing Data

## Label Datasets

In [19]:
# Build one (samples, labels) dataset per subject: that subject's feature vectors are labelled 1,
# every other subject's are labelled 0.
data_map = get_labelled_dataset_map(data_map)
print_labelled_data_summary(data_map)

LABELLED DATA
Subject: S01
	Positive data samples: 39
	Negative data samples: 767
Subject: S02
	Positive data samples: 39
	Negative data samples: 767
Subject: S03
	Positive data samples: 39
	Negative data samples: 767
Subject: S04
	Positive data samples: 39
	Negative data samples: 767
Subject: S05
	Positive data samples: 65
	Negative data samples: 741
Subject: S06
	Positive data samples: 39
	Negative data samples: 767
Subject: S07
	Positive data samples: 39
	Negative data samples: 767
Subject: S08
	Positive data samples: 39
	Negative data samples: 767
Subject: S09
	Positive data samples: 39
	Negative data samples: 767
Subject: S10
	Positive data samples: 39
	Negative data samples: 767
Subject: S11
	Positive data samples: 39
	Negative data samples: 767
Subject: S12
	Positive data samples: 39
	Negative data samples: 767
Subject: S13
	Positive data samples: 39
	Negative data samples: 767
Subject: S14
	Positive data samples: 39
	Negative data samples: 767
Subject: S15
	Positive data sample