# Setup

Initial module setup.

In [36]:
import bs4
import dataclasses
import pathlib
import urllib.request
import urllib.parse
import re
import mne.io
import requests
import pandas as pd
import typing
import mne

# Data Structures

In [37]:
@dataclasses.dataclass
class FrequencyBand:
    """
    A frequency band described by its lower and upper edges.

    Either edge may be ``None``, in which case that side of the band is left open.
    The values are handed to the filter as its low/high cut-offs (presumably in Hz).
    """
    # Lower edge of the band; ``None`` means there is no lower cut-off.
    lower: typing.Optional[float]
    # Upper edge of the band; ``None`` means there is no upper cut-off.
    upper: typing.Optional[float]

# Constants

In [38]:
# Remote directory listing page that is scanned for the data file links to download.
DATASET_URL = 'https://physionet.org/files/auditory-eeg/1.0.0/Segmented_Data/'
# Sampling frequency handed to MNE when the channel info is created.
DATASET_SAMPLE_FREQ_HZ = 200
# Channel names handed to MNE; presumably these match the column order of the CSV files -- verify.
DATA_CHANNEL_NAMES = ['T7','F8','Cz','P4']
# Frequency bands extracted from every data window, in exactly this order.
FREQUENCIES = [
    FrequencyBand(lower=8.0, upper=12.0),   # Alpha
    FrequencyBand(lower=12.0, upper=35.0),  # Beta
    FrequencyBand(lower=4.0, upper=8.0),    # Theta
    FrequencyBand(lower=35.0, upper=None)   # Gamma
]

# Utilities

In [39]:
def retrieve_dataset() -> pathlib.Path:
    """
    Downloads (if necessary) the dataset and retrieves the path to the root of the dataset files directory.

    The dataset is considered present only when the data directory already contains at least one CSV file.
    A directory that exists but holds no CSV files (e.g., left behind by an interrupted download) triggers
    a new download instead of being returned as-is.

    :return: The path object pointing to the dataset directory.
    """
    data_directory = _get_data_directory()
    # Checking for actual data files rather than mere existence: the directory is created before the
    # download starts, so a failed download would otherwise be mistaken for a complete dataset forever.
    if data_directory.is_dir() and any(data_directory.glob('*.csv')):
        return data_directory
    data_directory.mkdir(parents=True, exist_ok=True)
    _download_dataset(data_directory)
    return data_directory


def _get_data_directory() -> pathlib.Path:
    """
    Helper function which generates the path to the data directory.

    :return: The data directory path object.
    """
    return pathlib.Path().resolve().parent / 'data'


def _download_dataset(target_path: pathlib.Path):
    """
    Initiates download of the dataset and saves all files into the given target path directory.

    :param target_path: The target path directory.
    :raises requests.HTTPError: If the dataset listing page responds with an HTTP error status.
    """
    # A timeout keeps an unresponsive server from blocking forever.
    with requests.get(DATASET_URL, timeout=30) as listing_page:
        # Without this check an error page would be parsed as a listing that contains no data files,
        # and the download would silently produce nothing.
        listing_page.raise_for_status()
        listing_soup = bs4.BeautifulSoup(
            listing_page.content,
            features='html.parser'
        )
    # The listing is fully parsed at this point, so its connection is released before the file downloads start.
    _download_files_in_listing(target_path, listing_soup)


def _download_files_in_listing(target_path: pathlib.Path, listing_soup: bs4.BeautifulSoup):
    """
    Helper function which walks every anchor element of the given BeautifulSoup object and downloads each linked
    data file into the target path's directory.

    Only links whose href starts with a name of the form ``sNN_ex05.csv`` are downloaded; every other link is skipped.

    :param target_path: The target path directory.
    :param listing_soup: The BeautifulSoup object to use to find download links.
    """
    wanted_file_name = re.compile(r's\d{2}_ex05\.csv')

    for anchor in listing_soup.find_all('a'):
        href = anchor.get('href')
        # Skip anchors without a href as well as links to anything other than the wanted data files.
        if not href or not wanted_file_name.match(href):
            continue
        _download_url_to_file(
            target_path / href,
            urllib.parse.urljoin(DATASET_URL, href)
        )


def _download_url_to_file(file_path: pathlib.Path, url: str):
    """
    Downloads the given URL's remote content to the given file path.

    :param file_path: The file path to download to.
    :param url: The URL to download from.
    :raises requests.HTTPError: If the server responds with an HTTP error status.
    """
    # stream=True defers fetching the body, so the chunked loop below really streams the download instead of
    # iterating over content that has already been loaded into memory in one piece.
    with requests.get(url, stream=True, timeout=30) as response:
        # Checked before the file is opened, so an HTTP error page is never saved under the data file's name.
        response.raise_for_status()
        with open(file_path, 'wb') as out_file:
            for chunk in response.iter_content(chunk_size=1024):
                out_file.write(chunk)
                

def read_data(target_path: pathlib.Path) -> typing.Dict[str, pd.DataFrame]:
    """
    Reads all data files in the given directory path and generates a structure of dataframes.

    Only entries with a ``.csv`` suffix are read. Files are processed in sorted path order, so the resulting
    map has a reproducible key order.

    :param target_path: The target directory path.
    :return: A map of dataframes, where the key is an identifier for the file and the value is the dataframe.
    """
    loaded_data_map = {}

    # iterdir() yields entries in arbitrary order; sorting keeps the subject order (and therefore whichever
    # entry a consumer treats as the "first" one) stable between runs and machines.
    for data_file in sorted(target_path.iterdir()):
        if data_file.suffix != '.csv':
            continue
        subject_identifier = _get_subject_identifier(data_file.name)
        # The first CSV column is used as the index and the first row as the column header.
        loaded_data_map[subject_identifier] = pd.read_csv(data_file, index_col=0, header=0)

    return loaded_data_map


def _get_subject_identifier(data_file_name: str) -> str:
    """
    Helper function which parses a subject identifier from a data file name.
    
    :param data_file_name: The file name to parse. 
    :return: A subject identifier.
    :raises ValueError: If the identifier could not be parsed.
    """
    identifier_pattern = re.compile(r'(?P<identifier>s\d{2})')
    search_result = re.search(identifier_pattern, data_file_name)
    identifier = search_result.group('identifier')
    if not identifier:
        raise ValueError(f'Unable to parse subject identifier from file: "{data_file_name}"')
    return identifier.upper()
    
    
def window_dataset(dataframe_map: typing.Dict[str, pd.DataFrame]) -> typing.Dict[str, typing.List[pd.DataFrame]]:
    """
    Splits every dataframe of the given data map into windows of 1,200 rows which overlap by 50%.

    :param dataframe_map: The data map.
    :return: The windowed data map, keyed like the input map.
    """
    rows_per_window = 1200
    overlap_fraction = 0.5

    return {
        subject: _window_dataframe(subject_frame, rows_per_window, overlap_fraction)
        for subject, subject_frame in dataframe_map.items()
    }


def _window_dataframe(dataframe: pd.DataFrame, size: int, overlap: float) -> typing.List[pd.DataFrame]:
    """
    Create a series of windows from the given dataframe.
    
    :param dataframe: The dataframe to window.
    :param size: The size of the windows to create.
    :param overlap: The percentage overlap (e.g., 0.5) of the windows.
    :return: The windowed data.
    """
    windowed_data = []
    start = 0
    end = size
    
    while end <= len(dataframe):
        window = dataframe[start:end]
        windowed_data.append(window)
        
        start += int(size * (1 - overlap))
        end += int(size * (1 - overlap))
    
    return windowed_data
    
    
def filter_windows(data_windows: typing.Dict[str, typing.List[pd.DataFrame]]) -> typing.Dict[str, typing.List[typing.List[pd.DataFrame]]]:
    """
    Converts every data window of the given data map to MNE format and retrieves its target bands.

    :param data_windows: the data windows to filter.
    :return: a new data map, wherein the keys are the subject identifiers and the values are
             a 2D list (dimension 1: the data windows/frames, dimension 2: the filtered data for the window/frame).
    """
    return {
        subject: [
            _retrieve_target_bands(convert_dataframe_to_mne(window))
            for window in subject_windows
        ]
        for subject, subject_windows in data_windows.items()
    }
            
            
def _retrieve_target_bands(mne_data: mne.io.RawArray) -> typing.List[pd.DataFrame]:
    """
    Retrieves a set of frequency bands from the given MNE data. A raw band is included in the result.

    The given MNE data is left unmodified: each band is filtered on its own copy.

    :param mne_data: The MNE data to retrieve target bands from.
    :return: The set of target bands, as a list: one dataframe per entry of FREQUENCIES (in that order),
             followed by a dataframe of the unfiltered data.
    """
    # Raw.filter() works in place and returns the very object it was called on. Filtering mne_data directly
    # would therefore apply all filters on top of each other and make every list entry (the "raw" one
    # included) refer to the same, cumulatively filtered data. Each band gets its own copy instead.
    bands = [
        mne_data.copy().filter(
            l_freq=frequency.lower,
            h_freq=frequency.upper,
            verbose=False
        ).to_data_frame()
        for frequency in FREQUENCIES
    ]
    # Raw data: the input was never filtered, so it can be exported as-is.
    # NOTE(review): to_data_frame() is called with its defaults, as before -- confirm that the default
    # scaling and the added time column are what the feature extraction expects.
    bands.append(mne_data.to_data_frame())

    return bands
    
    
def convert_dataframe_to_mne(dataframe: pd.DataFrame) -> mne.io.RawArray:
    """
    Builds a Python-MNE raw array from the given dataframe.

    The dataframe is transposed before conversion, so its columns become the rows of the array handed to MNE.
    The channels are named after DATA_CHANNEL_NAMES and declared as EEG channels sampled at DATASET_SAMPLE_FREQ_HZ.

    :param dataframe: The dataframe to convert.
    :return: A Python-MNE data array.
    """
    channel_rows = dataframe.transpose(copy=True).to_numpy()
    return mne.io.RawArray(
        channel_rows,
        mne.create_info(DATA_CHANNEL_NAMES, DATASET_SAMPLE_FREQ_HZ, ch_types='eeg')
    )


def print_data_map_summary(map_to_print: typing.Dict[str, pd.DataFrame]):
    """
    Helper function which prints an overview of a data map: the number of subjects, every subject identifier
    and the first rows of the first dataframe in the map.

    :param map_to_print: The data map to print data on.
    """
    print(f'Number of subjects: {len(map_to_print)}')
    print('Subject identifiers:')
    for subject in map_to_print.keys():
        print(subject)
    print('Example dataframe:')
    # The first dataframe of the map serves as the example.
    first_dataframe = next(iter(map_to_print.values()))
    print(first_dataframe.head())


def print_windowed_data_summary(windowed_data: typing.Dict[str, typing.List[pd.DataFrame]]):
    """
    Helper function which prints an overview of a windowed data map: the number of windows per subject and the
    first rows of the first window of the first subject in the map.

    :param windowed_data: The windowed data to summarize.
    """
    for subject, subject_windows in windowed_data.items():
        print(f'Subject: {subject}, Windows: {len(subject_windows)}')
    print('Example window:')
    # The very first window of the first subject serves as the example.
    first_subject_windows = next(iter(windowed_data.values()))
    first_window = first_subject_windows[0]
    print(first_window.head())
    
    
def print_filtered_data_summary(filtered_data: typing.Dict[str, typing.List[typing.List[pd.DataFrame]]]):
    """
    Helper function which prints an overview of a filtered and windowed data map: per subject, the number of
    filtered dataframes held by each window, followed by the first rows of the first filtered dataframe of the
    first window of the first subject in the map.

    :param filtered_data: The filtered and windowed data to summarize.
    """
    for subject, subject_windows in filtered_data.items():
        frames_per_window = list(map(len, subject_windows))
        print(f'Subject: {subject}, Filtered data sizes: {frames_per_window}')
    print('Example filtered data frame:')
    # First subject -> first window -> first filtered dataframe serves as the example.
    first_subject_windows = next(iter(filtered_data.values()))
    first_filtered_frame = first_subject_windows[0][0]
    print(first_filtered_frame.head())


# Setup Dataset

In [40]:
# Locate the dataset (downloading it if necessary), load every CSV file into a dataframe keyed by
# subject identifier and print a short summary of what was loaded.
dataset_path = retrieve_dataset()
data_map = read_data(dataset_path)
print_data_map_summary(data_map)

Number of subjects: 20
Subject identifiers:
S01
S02
S03
S04
S05
S06
S07
S08
S09
S10
S11
S12
S13
S14
S15
S16
S17
S18
S19
S20
Example dataframe:
               T7           F8          Cz          P4
13200  431.251617 -1189.493896  454.405334  345.306824
13201  444.240265 -1194.415649  471.231140  363.666016
13202  439.064270 -1188.719727  457.135437  325.425537
13203  442.071136 -1193.476929  458.751099  340.463654
13204  435.933960 -1197.149414  442.688232  333.630859


# Pre-process Data

## Window and Filter Data

In [41]:
# Replace each subject's dataframe with its list of overlapping windows.
# NOTE: data_map is rebound to a differently shaped structure here, so re-running this cell on its own
# would feed the already windowed map back into window_dataset.
data_map = window_dataset(data_map)
print_windowed_data_summary(data_map)

Subject: S01, Windows: 39
Subject: S02, Windows: 39
Subject: S03, Windows: 39
Subject: S04, Windows: 39
Subject: S05, Windows: 65
Subject: S06, Windows: 39
Subject: S07, Windows: 39
Subject: S08, Windows: 39
Subject: S09, Windows: 39
Subject: S10, Windows: 39
Subject: S11, Windows: 39
Subject: S12, Windows: 39
Subject: S13, Windows: 39
Subject: S14, Windows: 39
Subject: S15, Windows: 39
Subject: S16, Windows: 39
Subject: S17, Windows: 39
Subject: S18, Windows: 39
Subject: S19, Windows: 39
Subject: S20, Windows: 39
Example window:
               T7           F8          Cz          P4
13200  431.251617 -1189.493896  454.405334  345.306824
13201  444.240265 -1194.415649  471.231140  363.666016
13202  439.064270 -1188.719727  457.135437  325.425537
13203  442.071136 -1193.476929  458.751099  340.463654
13204  435.933960 -1197.149414  442.688232  333.630859


In [42]:
# Replace each window with its list of per-band dataframes.
# NOTE: data_map is rebound once more, so this cell expects the windowed (not yet filtered) map as input.
data_map = filter_windows(data_map)
print_filtered_data_summary(data_map)

Creating RawArray with float64 data, n_channels=4, n_times=1200
    Range : 0 ... 1199 =      0.000 ...     5.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=1200
    Range : 0 ... 1199 =      0.000 ...     5.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=1200
    Range : 0 ... 1199 =      0.000 ...     5.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=1200
    Range : 0 ... 1199 =      0.000 ...     5.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=1200
    Range : 0 ... 1199 =      0.000 ...     5.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=1200
    Range : 0 ... 1199 =      0.000 ...     5.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=1200
    Range : 0 ... 1199 =      0.000 ...     5.995 secs
Ready.
Creating RawArray with float64 data, n_channels=4, n_times=1200
    Range : 0 ... 1199 =      0.000 ...     5.995 secs

In [None]:
# TODO: extract features (min, max, mean, Zero Crossing Rate)
# TODO: Train verification model for each subject using random forest