In [1]:
"""artifact_pipeline.py
A pipeline for detecting and removing artifacts.
"""
# Package Header #
from src.spikedetection.header import *

# Header #
__author__ = __author__
__credits__ = __credits__
__maintainer__ = __maintainer__
__email__ = __email__

# Imports #
# Standard Libraries #
import importlib
import itertools
import pathlib
from typing import NamedTuple

# Third-Party Packages #
from dspobjects.plot import Figure, TimeSeriesPlot, SpectraPlot, BarPlot
from fooof.sim.gen import gen_aperiodic
from fooof import FOOOF, FOOOFGroup
import hdf5objects
import matplotlib.pyplot as plt
import matplotlib.colors
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from scipy.io import loadmat
from scipy.stats import entropy
from scipy.signal import savgol_filter, welch
import sklearn
from sklearn import metrics
from sklearn.decomposition import PCA
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.preprocessing import scale
from sklearn.model_selection import train_test_split
import toml
import torch
from torch import nn
from xltektools.hdf5framestructure import XLTEKStudyFrame

# Local Packages #
from src.spikedetection.artifactrejection.fooof.goodnessauditor import GoodnessAuditor, RSquaredBoundsAudit, SVMAudit
from src.spikedetection.artifactrejection.fooof.ooffitter import OOFFitter, iterdim


# Prefer the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Running on the GPU" if device.type == "cuda" else "Running on the CPU")


# Definitions #
# Data Classes
class ElectrodeLead(NamedTuple):
    """Immutable record describing a single electrode lead.

    Nothing in the visible part of this file constructs or reads this type, so
    the field notes below are derived from the names and annotations only.
    """
    name: str  # lead name
    type: str  # lead type; presumably values such as "grid" (see make_bipolar) - TODO confirm
    contacts: dict  # contacts belonging to the lead; key/value schema is not shown here


# Classes #


# Functions #
def powerset(iterable):
    """Return an iterator over every subset of *iterable*, ordered by subset size.

    Each subset is a tuple; the input is materialised into a list first so any
    iterable (including one-shot generators) is accepted.

    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    pool = list(iterable)
    subsets_by_size = (itertools.combinations(pool, size) for size in range(len(pool) + 1))
    return itertools.chain.from_iterable(subsets_by_size)


def closest_square(n):
    """Factor *n* into two integers, the first being at least ``ceil(sqrt(n))``.

    The first factor is the smallest divisor of ``n`` that is greater than or
    equal to ``ceil(sqrt(n))``; the second is the matching cofactor.

    Args:
        n: Value to factor; it is truncated with ``int`` before use.

    Returns:
        Tuple ``(factor, cofactor)`` with ``factor * cofactor == int(n)``.
    """
    total = int(n)
    factor = int(np.ceil(np.sqrt(total)))

    # Walk upward from the rounded-up square root until a divisor is reached.
    while total % factor != 0:
        factor += 1

    cofactor = total // factor
    assert total == (factor * cofactor)
    return factor, cofactor


def get_lead_groups(el_label, el_type):
    """Group electrode contacts into leads by stripping digits from their labels.

    Args:
        el_label: NumPy array of contact labels; it is indexed with integer
            arrays, so a plain list will not work.
        el_type: NumPy array of the same length giving the type of each contact.

    Returns:
        Dict keyed by lead name (the label with all decimal digits removed).
        Each value is a dict with:
            'Contacts': the labels belonging to the lead,
            'IDs': their positions in ``el_label``,
            'Type': the single type shared by all contacts of the lead.

    Raises:
        AssertionError: if the inputs differ in length or a lead mixes types.
    """
    assert len(el_label) == len(el_type)

    # Lead name of every contact: its label with the contact number removed.
    lead_of_contact = np.array(
        [''.join(ch for ch in label if ch not in '0123456789') for label in el_label])
    contact_index = np.arange(len(el_label))

    lead_group = {}
    for lead_name in np.unique(lead_of_contact):
        members = np.flatnonzero(lead_of_contact == lead_name)

        lead_types = np.unique(el_type[members])
        assert len(lead_types) == 1

        lead_group[lead_name] = {
            'Contacts': el_label[members],
            'IDs': contact_index[members],
            'Type': lead_types[0],
        }

    return lead_group


def make_bipolar(lead_group):
    """Add the list of neighbouring contact pairs to every lead, in place.

    Contacts of a lead whose type contains 'grid' are arranged on a
    ``closest_square`` layout; every other lead is treated as a single column.
    The layout is filled column-first (Fortran order), then vertically and
    horizontally adjacent contacts are paired. Diagonal neighbours are not
    paired.

    Args:
        lead_group: Dict of leads, each with 'Contacts', 'IDs' and 'Type'
            entries (the structure produced by get_lead_groups).

    Returns:
        The same ``lead_group`` object, each lead now carrying a
        'Contact_Pairs_ix' list of ``(id_a, id_b)`` tuples.
    """
    for lead in lead_group.values():
        ids = lead['IDs']
        n_contact = len(ids)

        if 'grid' in lead['Type']:
            n_row, n_col = closest_square(n_contact)
        else:
            n_row, n_col = n_contact, 1

        # Position of each contact on the (n_row, n_col) layout, column-major.
        layout = np.arange(len(lead['Contacts'])).reshape((n_row, n_col), order='F')

        pairs = lead['Contact_Pairs_ix'] = []

        # Neighbours along a column (adjacent rows).
        if n_row > 1:
            uppers = layout[:-1, :].flatten()
            lowers = layout[1:, :].flatten()
            pairs.extend((ids[a], ids[b]) for a, b in zip(uppers, lowers))

        # Neighbours along a row (adjacent columns).
        if n_col > 1:
            lefts = layout[:, :-1].flatten()
            rights = layout[:, 1:].flatten()
            pairs.extend((ids[a], ids[b]) for a, b in zip(lefts, rights))

    return lead_group


def make_bipolar_elecs_all(eleclabels, eleccoords):
    """Build a table with one row per bipolar contact pair across all leads.

    Args:
        eleclabels: 2-D array of electrode labels. Column 0 supplies the
            abbreviated contact label, column 1 the contact label (also used to
            group contacts into leads) and column 2 the lead type.
        eleccoords: Per-contact coordinates indexable by contact index. The
            midpoint of a pair is computed as ``(a + b) / 2`` and its first
            three components become x, y and z. When the midpoint cannot be
            computed, the three coordinates are set to NaN.

    Returns:
        pandas.DataFrame with the columns 'IDX', 'Anode', 'Cathode', 'Lead',
        'Contact', 'Contact_Abbr', 'Type', 'x', 'y' and 'z', sorted by
        ('Anode', 'Cathode') with a fresh integer index.
    """
    lead_group = get_lead_groups(eleclabels[:, 1], eleclabels[:, 2])
    lead_group = make_bipolar(lead_group)

    columns = ('IDX', 'Anode', 'Cathode', 'Lead', 'Contact', 'Contact_Abbr', 'Type', 'x', 'y', 'z')
    bp_elecs_all = {column: [] for column in columns}

    for l_name, lead in lead_group.items():
        for el_ix, el_iy in lead['Contact_Pairs_ix']:
            bp_elecs_all['IDX'].append((el_ix, el_iy))
            bp_elecs_all['Anode'].append(el_ix)
            bp_elecs_all['Cathode'].append(el_iy)

            bp_elecs_all['Lead'].append(l_name)
            bp_elecs_all['Contact'].append(f'{eleclabels[el_ix, 1]}-{eleclabels[el_iy, 1]}')
            bp_elecs_all['Contact_Abbr'].append(f'{eleclabels[el_ix, 0]}-{eleclabels[el_iy, 0]}')
            bp_elecs_all['Type'].append(lead['Type'])

            # Best effort: missing or malformed coordinates become NaN. Catching
            # Exception (rather than a bare except) keeps that fallback without
            # also swallowing KeyboardInterrupt or SystemExit.
            try:
                coord = (eleccoords[el_ix] + eleccoords[el_iy]) / 2
            except Exception:
                coord = [np.nan, np.nan, np.nan]
            bp_elecs_all['x'].append(coord[0])
            bp_elecs_all['y'].append(coord[1])
            bp_elecs_all['z'].append(coord[2])

    bp_elecs_all = pd.DataFrame(bp_elecs_all)

    # Workaround kept as-is: reload numpy's numeric module if its `dtype`
    # attribute has been cleared after the DataFrame was built.
    if np.core.numeric.dtype is None:
        importlib.reload(np.core.numeric)
    return bp_elecs_all.sort_values(by=['Anode', 'Cathode']).reset_index(drop=True)


def get_ECoG_sample(study_frame, time_start, time_end):
    """Collect a time range of data from a study frame together with basic metadata.

    Args:
        study_frame: Object providing ``validate_sample_rate()``,
            ``get_shapes()`` and ``find_data_range()``.
        time_start: Start of the requested range, passed to ``find_data_range``.
        time_end: End of the requested range, passed to ``find_data_range``.

    Returns:
        Dict with:
            'fs': the sample rate (always 1024),
            'min_valid_chan': the smallest second-dimension size among the
                shapes reported by ``get_shapes()``,
            'data': the result of ``find_data_range(..., approx=True)``.
    """
    # The validation is still performed, but its outcome does not change the
    # reported rate, which is fixed at 1024.
    # NOTE(review): presumably the rate should be read from the study frame
    # once validation passes - confirm against the XLTEKStudyFrame API.
    study_frame.validate_sample_rate()
    sample_rate = 1024

    # Smallest channel count found across all recordings.
    fewest_channels = min(shape[1] for shape in study_frame.get_shapes())

    samples = study_frame.find_data_range(time_start, time_end, approx=True)

    return {'fs': sample_rate, 'min_valid_chan': fewest_channels, 'data': samples}


def convert_ECoG_BP(natus_data, BP_ELECS):
    """Replace the stored recording with its bipolar (anode minus cathode) form.

    Args:
        natus_data: Dict whose 'data' entry exposes a 2-D ``.data`` array that is
            indexed along its second axis by channel.
        BP_ELECS: Table with 'Anode' and 'Cathode' columns of channel indices.

    Returns:
        The same ``natus_data`` dict; its 'data' entry is overwritten with the
        array of per-pair differences.
    """
    recording = natus_data['data'].data
    anode_ix = BP_ELECS['Anode'].values
    cathode_ix = BP_ELECS['Cathode'].values

    natus_data['data'] = recording[:, anode_ix] - recording[:, cathode_ix]
    return natus_data


def half_life(duration, fs_state):
    """Return the per-step decay factor ``exp(-ln(2) / samples)``.

    ``samples`` is ``duration / fs_state``; applying the returned factor
    ``samples`` times halves a value.

    NOTE(review): a sample count is usually duration * rate; the division here
    suggests ``fs_state`` is a step period rather than a rate - confirm with
    the caller.
    """
    n_samples = duration / fs_state
    decay_exponent = -(1 / n_samples) * np.log(2)
    return np.exp(decay_exponent)


def do_fitting(foo, freqs, spectrum, freq_range):
    """Fit an aperiodic curve to one spectrum and return the squared correlation.

    Args:
        foo: Fitting object providing ``add_data`` and ``_robust_ap_fit``.
        freqs: Frequency values of the spectrum.
        spectrum: Power values, one per frequency.
        freq_range: Frequency range forwarded to ``foo.add_data``.

    Returns:
        Square of the Pearson correlation between ``spectrum`` and the
        generated aperiodic curve.

    NOTE(review): the fit and the correlation use the ``freqs``/``spectrum``
    arguments as passed in, not whatever ``add_data`` stored on ``foo`` -
    confirm this is intended.
    """
    foo.add_data(freqs, spectrum, freq_range)

    ap_params = foo._robust_ap_fit(freqs, spectrum)
    aperiodic_curve = gen_aperiodic(freqs, ap_params)

    correlation = np.corrcoef(spectrum, aperiodic_curve)[0][1]
    return correlation ** 2


def do_fittings(foo, freqs, spectra, freq_range):
    """Run ``do_fitting`` on every spectrum and collect the results in order.

    Args:
        foo: Fitting object forwarded to ``do_fitting``.
        freqs: Frequency values shared by all spectra.
        spectra: Iterable of spectra; each item is fitted separately.
        freq_range: Frequency range forwarded to ``do_fitting``.

    Returns:
        List with one squared-correlation value per spectrum.
    """
    return [do_fitting(foo, freqs, single, freq_range) for single in spectra]


def load_data(files, info, clips_per_subject=10):
    """Load the artifact clips and the reviewer information.

    Args:
        files: Iterable of path objects pointing at ``.mat`` clip files. The file
            name is split on '.'; the first part is the subject id and the
            third part is the integer clip number.
        info: Path object of the TOML file; its "raters" entry is returned.
        clips_per_subject: Number of clip slots allocated for each subject.
            Defaults to 10, the value that used to be hard-coded.

    Returns:
        Tuple ``(artifact_data, artifact_info)``. ``artifact_data`` maps each
        subject id to a list of ``clips_per_subject`` slots; a slot is either
        ``None`` (no file with that number) or a dict with the keys
        "sample_rate", "channel_labels", "time_axis" and "data".

    Raises:
        IndexError: if a file name has fewer than three '.'-separated parts or
            its clip number does not fit in the allocated slots.
        ValueError: if the clip number is not an integer.
    """
    artifact_info = toml.load(info.as_posix())["raters"]
    artifact_data = {}

    for file in files:
        name_parts = file.name.split('.')
        subject_id = name_parts[0]
        file_number = int(name_parts[2])
        artifact_file = loadmat(file.as_posix(), squeeze_me=True)

        clip_data = {
            "sample_rate": artifact_file["fs"],
            "channel_labels": artifact_file["channels"],
            "time_axis": artifact_file["timestamp vector"],
            "data": artifact_file["data"],
        }

        # Slots are addressed by clip number, so missing clips stay None.
        if subject_id not in artifact_data:
            artifact_data[subject_id] = [None] * clips_per_subject

        artifact_data[subject_id][file_number] = clip_data

    return artifact_data, artifact_info


Running on the GPU


In [2]:
# Parameters #
# File Locations
SVM_PATH = pathlib.Path.cwd() / "all_metric_svm.obj"
ARTIFACT_DIR = pathlib.Path("/home/anthonyfong/ProjectData/EpilepsySpikeDetection/Artifact_Review/")
ARTIFACT_INFO = ARTIFACT_DIR / "Artifact_Info.toml"
ARTIFACT_FILES = ARTIFACT_DIR.glob("*.mat")  # a generator: it can only be iterated once
OUT_DIR = pathlib.Path("/home/anthonyfong/ProjectData/EpilepsySpikeDetection/Artifact_Review/Images")

# Data Layout
TIME_AXIS = 0
CHANNEL_AXIS = 1

# Frequency Band of Interest
LOWER_FREQUENCY = 1
UPPER_FREQUENCY = 250

# Metric Names
METRICS = {
    "r_squared",
    "normal_entropy",
    "mae",
    "mse",
    "rmse",
    "curve_offset",
    "curve_exp",
}
BEST_METRICS = {
    "r_squared",
    "normal_entropy",
    "mae",
    "rmse",
}


In [2]:
# FOOOF
# Group model shared by every clip; at most one peak is fitted per spectrum.
fg = FOOOFGroup(
    peak_width_limits=[4, 8],
    min_peak_height=0.05,
    max_n_peaks=1,
    verbose=True,
)


In [5]:
# Aggregate Data #
# Purpose: for every subject and clip, gather each reviewer's channel selection, fit the
# clip's power spectra, derive per-channel fit metrics, and accumulate everything into
# flat per-channel lists that are finally turned into DataFrames and standardised.
#
# NOTE(review): this cell cannot run to completion as written:
#   * `fit_curves` is used below but is never assigned in the visible code (NameError
#     as soon as a clip is processed).
#   * the result of `fg.get_fooof()` is discarded, and the call passes no index -
#     confirm the required arguments against the fooof API.
#   * when no clip is loaded, the final `scale` call fails with the
#     "Found array with 0 sample(s)" ValueError recorded in the cell output.

# Load Data
# ARTIFACT_FILES is a `glob` generator, so it is exhausted after one pass. Re-running
# this cell without recreating it yields an empty `artifact_data`.
# NOTE(review): this is a likely cause of the empty-array error - confirm.
artifact_data, artifact_info = load_data(ARTIFACT_FILES, ARTIFACT_INFO)

# Create Data Structures
# One flat list per reviewer (plus the two combined views) and one per metric.
ag_reviews = {reviewer["name"]: [] for reviewer in artifact_info}
ag_reviews |= {"Reviewer Intersection": [], "Reviewer Union": []}  # dict `|=` needs Python 3.9+
ag_metrics = {m: [] for m in METRICS}

artifact_metrics = {}
# Holds references to the same dicts as above, not copies.
aggregate_data = {"reviews": ag_reviews, "metrics": ag_metrics}

# Format Data From Files and Create Metrics
for subject_id, data in artifact_data.items():
    artifact_metrics[subject_id] = [None] * len(data)
    # NOTE(review): load_data pre-fills every subject's list with None, so a missing
    # clip number leaves a None here and the subscripting below would raise TypeError.
    for i, artifact_clip in enumerate(data):
        # Format Reviewer Data
        # Channel numbers are shifted down by one (presumably 1-based in the TOML file).
        review_channels = {}
        for reviewer in artifact_info:
            zero_index = tuple(np.array(reviewer["review_channels"][subject_id][i]) - 1)
            review_channels[reviewer["name"]] = zero_index
        # Union starts empty; intersection is seeded with the first reviewer's channels.
        review_union = set()
        review_intersect = set(np.array(artifact_info[0]["review_channels"][subject_id][i]) - 1)
        for rv in review_channels.values():
            review_union |= (set(rv))
            review_intersect.intersection_update(set(rv))
        review_union = tuple(review_union)
        review_intersect = tuple(review_intersect)
        reviews = review_channels.copy()
        reviews.update({"Reviewer Intersection": review_intersect, "Reviewer Union": review_union})

        # Create Metrics
        sample_rate = artifact_clip["sample_rate"]

        # Welch power spectral density computed along the time axis.
        freqs, spectra = welch(artifact_clip["data"], fs=sample_rate, nperseg=2048, axis=TIME_AXIS)

        # Limit Frequency Range
        # lower_limit: last bin at or below LOWER_FREQUENCY; upper_limit: first bin above
        # UPPER_FREQUENCY (exclusive slice end).
        lower_limit = int(np.searchsorted(freqs, LOWER_FREQUENCY, side="right") - 1)
        upper_limit = int(np.searchsorted(freqs, UPPER_FREQUENCY, side="right"))


        # Index tuple that applies the frequency slice on axis TIME_AXIS (the axis welch
        # replaced with frequency) and leaves any preceding axes whole.
        spectra = spectra[(slice(None),) * TIME_AXIS + (slice(lower_limit, upper_limit),)]
        freqs = freqs[lower_limit:upper_limit]

        # Fitting
        # The spectra are transposed before being handed to the group fit.
        fg.fit(freqs, spectra.T)

        # NOTE(review): return value is unused and no index is passed - see header note.
        fg.get_fooof()




        # NOTE(review): `fit_curves` is undefined at this point. The code below expects
        # it to expose spectra, curves, r_squared, mae, mse, rmse and parameters.
        curve_removed = fit_curves.spectra - fit_curves.curves
        curve_2 = curve_removed ** 2

        # Squared residuals normalised along axis 0 into a distribution; its entropy is
        # divided by log(number of rows) to give the normalised entropy.
        prob = curve_2 / np.sum(curve_2, axis=0)
        entro = entropy(prob)
        normal_entropy = entro / np.log(prob.shape[0])

        artifact_metric = {
            "raw_data": artifact_clip["data"],
            "fit_curves": fit_curves,
            "r_squared": fit_curves.r_squared,
            "mae": fit_curves.mae,
            "mse": fit_curves.mse,
            "rmse": fit_curves.rmse,
            "normal_entropy": normal_entropy,
            "curve_offset": fit_curves.parameters[0, :],
            "curve_exp": fit_curves.parameters[1, :],
            "reviews": reviews,
        }

        # Load Data into Data Structures
        artifact_metrics[subject_id][i] = artifact_metric

        # Append this clip's per-channel values to the flat aggregate lists.
        ag_metrics["r_squared"] += list(fit_curves.r_squared)
        ag_metrics["normal_entropy"] += list(normal_entropy)
        ag_metrics["mae"] += list(fit_curves.mae)
        ag_metrics["mse"] += list(fit_curves.mse)
        ag_metrics["rmse"] += list(fit_curves.rmse)
        ag_metrics["curve_offset"] += list(fit_curves.parameters[0, :])
        ag_metrics["curve_exp"] += list(fit_curves.parameters[1, :])
        for reviewer, channels in reviews.items():
            # Indicator vector over channels: 1 where the reviewer listed the channel.
            good_channels = np.zeros((fit_curves.spectra.shape[CHANNEL_AXIS],))
            good_channels[channels,] = 1
            aggregate_data["reviews"][reviewer] += list(good_channels)

# Load Data into Pandas Data Frame
review_dataframe = pd.DataFrame.from_dict(ag_reviews)
metrics_dataframe = pd.DataFrame.from_dict(ag_metrics)
# The lists are converted to arrays only after the DataFrames have been built from them.
for name, metric_ in ag_metrics.items():
    ag_metrics[name] = np.array(metric_)

importlib.reload(np.core.numeric)  # Pandas causes numpy to break which is dumb....
# Standardise the metric columns; raises ValueError when the frame has no rows.
metrics_all_scaled = scale(metrics_dataframe.to_numpy())
metrics_all_scaled = pd.DataFrame(metrics_all_scaled, columns=metrics_dataframe.columns)

ValueError: Found array with 0 sample(s) (shape=(0, 7)) while a minimum of 1 is required by the scale function.