upgrade to release candidate and use the RdaKdeModel
tab-cmd committed Apr 19, 2023
1 parent bde4b57 commit 1e49f2a
Showing 4 changed files with 127 additions and 59 deletions.
alpha/alpha_experiment.py (21 additions & 11 deletions)
@@ -13,7 +13,7 @@
 from bcipy.signal.process import get_default_transform, filter_inquiries
 from loguru import logger
 from preprocessing import AlphaTransformer
-from base_model import BasePcaRdaKdeModel
+from base_model import BaseRdaKdeModel
 from pyriemann.classification import TSclassifier
 from pyriemann.estimation import Covariances
 from rich.console import Console
@@ -29,7 +29,11 @@
 from sklearn.utils._testing import ignore_warnings
 
 from bcipy.helpers.acquisition import analysis_channels
-from bcipy.config import DEFAULT_PARAMETERS_PATH, TRIGGER_FILENAME, RAW_DATA_FILENAME, STATIC_AUDIO_PATH
+from bcipy.helpers.stimuli import update_inquiry_timing
+from bcipy.config import (TRIGGER_FILENAME,
+                          RAW_DATA_FILENAME,
+                          DEFAULT_DEVICE_SPEC_FILENAME)
+import bcipy.acquisition.devices as devices
 
 
 def cwt(data: np.ndarray, freq: int, fs: int) -> np.ndarray:
@@ -109,6 +113,9 @@ def load_data(data_folder: Path, trial_length=None, pre_stim=0.0, alpha=False):
     type_amp = raw_data.daq_type
     sample_rate = raw_data.sample_rate
 
+    devices.load(Path(data_folder, DEFAULT_DEVICE_SPEC_FILENAME))
+    device_spec = devices.preconfigured_device(raw_data.daq_type)
+
     # setup filtering
     default_transform = get_default_transform(
         sample_rate_hz=sample_rate,
@@ -123,7 +130,7 @@ def load_data(data_folder: Path, trial_length=None, pre_stim=0.0, alpha=False):
     logger.info(f"Device type: {type_amp}, fs={sample_rate}")
 
     k_folds = parameters.get("k_folds")
-    model = BasePcaRdaKdeModel(k_folds=k_folds)
+    model = BaseRdaKdeModel(k_folds=k_folds)
 
     # Process triggers.txt files
     trigger_targetness, trigger_timing, trigger_symbols = trigger_decoder(
@@ -133,7 +140,7 @@ def load_data(data_folder: Path, trial_length=None, pre_stim=0.0, alpha=False):
     )
     # Channel map can be checked from raw_data.csv file or the devices.json located in the acquisition module
     # The timestamp column [0] is already excluded.
-    channel_map = analysis_channels(channels, type_amp)
+    channel_map = analysis_channels(channels, device_spec)
     data, fs = raw_data.by_channel()
 
     inquiries, inquiry_labels, inquiry_timing = model.reshaper(
@@ -149,13 +156,14 @@ def load_data(data_folder: Path, trial_length=None, pre_stim=0.0, alpha=False):
     )
 
     inquiries, fs = filter_inquiries(inquiries, default_transform, sample_rate)
+    inquiry_timing = update_inquiry_timing(inquiry_timing, downsample_rate)
     trial_duration_samples = int(poststim_length * fs)
     if alpha:
         pre_stim_duration_samples = int(pre_stim * fs)
     else:
         pre_stim_duration_samples = 0
     data = model.reshaper.extract_trials(
-        inquiries, trial_duration_samples, inquiry_timing, downsample_rate, prestimulus_samples=pre_stim_duration_samples)
+        inquiries, trial_duration_samples, inquiry_timing, prestimulus_samples=pre_stim_duration_samples)
 
     # define the training classes using integers, where 0=nontargets/1=targets
     labels = inquiry_labels.flatten()
@@ -251,7 +259,8 @@ def flatten(data):
 
 
 @ignore_warnings(category=ConvergenceWarning)
-def main(input_path: Path, freq: float, hparam_tuning: bool, z_score_per_trial: bool, output_path: Optional[Path] = None):
+def main(input_path: Path, freq: float, hparam_tuning: bool,
+         z_score_per_trial: bool, output_path: Optional[Path] = None):
     data, labels, fs = load_data(input_path, trial_length=1.25, pre_stim=1.25, alpha=True)
 
     # set output path to input path if not specified
@@ -322,9 +331,9 @@ def main(input_path: Path, freq: float, hparam_tuning: bool, z_score_per_trial:
     z_transformed_entire_data = z_scorer.transform(data, do_slice=False)
     # Copy of entire window for plotting
     logger.info(
-        f"{z_transformed_target_window.min()=}, "
-        + f"{z_transformed_target_window.mean()=}, "
-        + f"{z_transformed_target_window.max()=}"
+        f"{z_transformed_target_window.min()=}, " +
+        f"{z_transformed_target_window.mean()=}, " +
+        f"{z_transformed_target_window.max()=}"
     )
 
     make_plots(z_transformed_target_window, labels, output_path / "2.z_target_window.png")
@@ -391,7 +400,9 @@ def main(input_path: Path, freq: float, hparam_tuning: bool, z_score_per_trial:
     import argparse
 
     p = argparse.ArgumentParser()
-    # trial length in seconds for alpha band: 1.25s before and 1.25s after response; z-scored per trial is False and hparam tuning is True/False (make sure both work)
+    # trial length in seconds for alpha band: 1.25s before and 1.25s after
+    # response; z-scored per trial is False and hparam tuning is True/False
+    # (make sure both work)
     p.add_argument("--input", type=Path, help="Path to data folder", required=False, default=None)
     p.add_argument("--output", type=Path, help="Path to save outputs", required=False, default=None)
     p.add_argument("--freq", type=float, help="Frequency to keep after CWT", required=True)
@@ -407,7 +418,6 @@ def main(input_path: Path, freq: float, hparam_tuning: bool, z_score_per_trial:
     else:
         folder_path = args.input
 
-
     if not folder_path.exists():
         raise ValueError("data path does not exist")
 
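Note on the device-spec change above: analysis_channels now takes the full DeviceSpec rather than the daq_type string. A minimal sketch of the new lookup pattern, assuming a session folder that contains the devices.json spec file next to the raw data; the folder path and the load_raw_data helper (from bcipy.helpers.load) are illustrative, not part of this commit:

from pathlib import Path

import bcipy.acquisition.devices as devices
from bcipy.config import DEFAULT_DEVICE_SPEC_FILENAME, RAW_DATA_FILENAME
from bcipy.helpers.acquisition import analysis_channels
from bcipy.helpers.load import load_raw_data

data_folder = Path("path/to/session")  # hypothetical session folder

# Raw data saved by BciPy during the session
raw_data = load_raw_data(str(Path(data_folder, f"{RAW_DATA_FILENAME}.csv")))

# Register the specs recorded with the session, then look up the entry
# matching the amplifier that produced this recording
devices.load(Path(data_folder, DEFAULT_DEVICE_SPEC_FILENAME))
device_spec = devices.preconfigured_device(raw_data.daq_type)

# Channel selection now keys off the DeviceSpec
channel_map = analysis_channels(raw_data.channels, device_spec)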
alpha/base_model.py (102 additions & 44 deletions)
@@ -1,31 +1,25 @@
 # Update the BciPy models with methods for calculating performance across models using sklearn's metrics module.
 # This will allow us to compare models using the same metrics.
 
-import pickle
-from pathlib import Path
-from typing import List
-
 import numpy as np
-# from bcipy.helpers.task import TrialReshaper
-from bcipy.signal.exceptions import SignalException
+from bcipy.helpers.exceptions import SignalException
 from bcipy.signal.model import ModelEvaluationReport, SignalModel
-from bcipy.signal.model.pca_rda_kde.classifier import RegularizedDiscriminantAnalysis
-from bcipy.signal.model.pca_rda_kde.pca_rda_kde import PcaRdaKdeModel
-from bcipy.signal.model.pca_rda_kde.cross_validation import (
+from bcipy.signal.model.classifier import RegularizedDiscriminantAnalysis
+from bcipy.signal.model import RdaKdeModel, PcaRdaKdeModel
+from bcipy.signal.model.cross_validation import (
     cost_cross_validation_auc,
     cross_validation,
 )
-from bcipy.signal.model.pca_rda_kde.density_estimation import KernelDensityEstimate
-from bcipy.signal.model.pca_rda_kde.dimensionality_reduction import (
+from bcipy.signal.model.density_estimation import KernelDensityEstimate
+from bcipy.signal.model.dimensionality_reduction import (
     ChannelWisePrincipalComponentAnalysis,
-    # MockPCA,
+    MockPCA,
 )
-from bcipy.signal.model.pca_rda_kde.pipeline import Pipeline
+from bcipy.signal.model.pipeline import Pipeline
 from sklearn.utils.multiclass import unique_labels
 
 
-class BasePcaRdaKdeModel(PcaRdaKdeModel):
-    # reshaper = TrialReshaper()
+class BaseRdaKdeModel(RdaKdeModel):
 
     def fit(self, train_data: np.array, train_labels: np.array) -> SignalModel:
         """
@@ -38,7 +32,7 @@ def fit(self, train_data: np.array, train_labels: np.array) -> SignalModel:
         """
         model = Pipeline(
             [
-                ChannelWisePrincipalComponentAnalysis(n_components=self.pca_n_components, num_ch=train_data.shape[0]),
+                MockPCA(),
                 RegularizedDiscriminantAnalysis(),
             ]
         )
@@ -99,7 +93,8 @@ def evaluate(self, test_data: np.array, test_labels: np.array) -> ModelEvaluationReport:
 
         lam_gam = (self.model.pipeline[1].lam, self.model.pipeline[1].gam)
         tmp, _, _ = cost_cross_validation_auc(
-            tmp_model, 1, test_data, test_labels, lam_gam, k_folds=self.k_folds, split="uniform"
+            tmp_model, self.optimization_elements, test_data, test_labels,
+            lam_gam, k_folds=self.k_folds, split="uniform"
         )
         auc = -tmp
         return ModelEvaluationReport(auc)
@@ -115,32 +110,95 @@ def predict(self, data: np.array) -> np.array:
         probs = self.predict_proba(data)
         return probs.argmax(-1)
 
-    def predict_proba(self, data: np.array) -> np.array:
-        """
-        sklearn-compatible method for predicting probabilities
-        """
-        if not self._ready_to_predict:
-            raise SignalException("must use model.fit() before model.predict_proba()")
-
-        # p(l=1 | e) = p(e | l=1) p(l=1) / p(e)
-        # log(p(l=1 | e)) = log(p(e | l=1)) + log(p(l=1)) - log(p(e))
-        log_scores_class_0 = self.model.transform(data)[:, 0]
-        log_scores_class_1 = self.model.transform(data)[:, 1]
-        log_post_0 = log_scores_class_0 + self.log_prior_class_0
-        log_post_1 = log_scores_class_1 + self.log_prior_class_1
-        denom = np.logaddexp(log_post_0, log_post_1)
-        log_post_0 -= denom
-        log_post_1 -= denom
-        posterior = np.exp(np.stack([log_post_0, log_post_1], axis=-1))
-        return posterior
-
-    def save(self, path: Path):
-        """Save model weights (e.g. after training) to `path`"""
-        with open(path, "wb") as f:
-            pickle.dump(self.model, f)
-
-    def load(self, path: Path):
-        """Load pretrained model weights from `path`"""
-        with open(path, "rb") as f:
-            self.model = pickle.load(f)
+
+class BasePcaRdaKdeModel(PcaRdaKdeModel):
+    # reshaper = TrialReshaper()
+
+    def fit(self, train_data: np.array, train_labels: np.array) -> SignalModel:
+        """
+        Train on provided data using K-fold cross validation and return self.
+        Parameters:
+            train_data: shape (Channels, Trials, Trial_length) preprocessed data
+            train_labels: shape (Trials,) binary labels
+        Returns:
+            trained likelihood model
+        """
+        model = Pipeline(
+            [
+                ChannelWisePrincipalComponentAnalysis(n_components=self.pca_n_components, num_ch=train_data.shape[0]),
+                RegularizedDiscriminantAnalysis(),
+            ]
+        )
+
+        # Find the optimal gamma + lambda values
+        arg_cv = cross_validation(train_data, train_labels, model=model, k_folds=self.k_folds)
+
+        # Get the AUC using those optimized gamma + lambda
+        rda_index = 1  # the index in the pipeline
+        model.pipeline[rda_index].lam = arg_cv[0]
+        model.pipeline[rda_index].gam = arg_cv[1]
+        tmp, sc_cv, y_cv = cost_cross_validation_auc(
+            model, rda_index, train_data, train_labels, arg_cv, k_folds=self.k_folds, split="uniform"
+        )
+        self.auc = -tmp
+        # After finding cross validation scores do one more round to learn the
+        # final RDA model
+        model.fit(train_data, train_labels)
+
+        # Insert the density estimates to the model and train using the cross validated
+        # scores to avoid over fitting. Observe that these scores are not obtained using
+        # the final model
+        model.add(KernelDensityEstimate(scores=sc_cv))
+        model.pipeline[-1].fit(sc_cv, y_cv)
+
+        self.model = model
+
+        if self.prior_type == "uniform":
+            self.log_prior_class_1 = self.log_prior_class_0 = np.log(0.5)
+        elif self.prior_type == "empirical":
+            prior_class_1 = np.sum(train_labels == 1) / len(train_labels)
+            self.log_prior_class_1 = np.log(prior_class_1)
+            self.log_prior_class_0 = np.log(1 - prior_class_1)
+        else:
+            raise ValueError("prior_type must be 'empirical' or 'uniform'")
+
+        self.classes_ = unique_labels(train_labels)
+        self._ready_to_predict = True
+        return self
+
+    def evaluate(self, test_data: np.array, test_labels: np.array) -> ModelEvaluationReport:
+        """Computes AUROC of the intermediate RDA step of the pipeline using k-fold cross-validation
+        Args:
+            test_data (np.array): shape (Channels, Trials, Trial_length) preprocessed data.
+            test_labels (np.array): shape (Trials,) binary labels.
+        Raises:
+            SignalException: error if called before model is fit.
+        Returns:
+            ModelEvaluationReport: stores AUC
+        """
+        if not self._ready_to_predict:
+            raise SignalException("must use model.fit() before model.evaluate()")
+
+        tmp_model = Pipeline([self.model.pipeline[0], self.model.pipeline[1]])
+
+        lam_gam = (self.model.pipeline[1].lam, self.model.pipeline[1].gam)
+        tmp, _, _ = cost_cross_validation_auc(
+            tmp_model, self.optimization_elements, test_data, test_labels,
+            lam_gam, k_folds=self.k_folds, split="uniform"
+        )
+        auc = -tmp
+        return ModelEvaluationReport(auc)
+
+    def predict(self, data: np.array) -> np.array:
+        """
+        sklearn-compatible method for predicting
+        """
+        if not self._ready_to_predict:
+            raise SignalException("must use model.fit() before model.predict()")
+
+        # p(l=1 | e) = p(e | l=1) p(l=1)
+        probs = self.predict_proba(data)
+        return probs.argmax(-1)
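For reference, a minimal end-to-end sketch of the reworked BaseRdaKdeModel on synthetic data. Shapes follow the fit() docstring; the random data and sizes are illustrative, and predict_proba is assumed to be inherited from RdaKdeModel (the predict method above relies on it):

import numpy as np

from base_model import BaseRdaKdeModel

rng = np.random.default_rng(0)
train_data = rng.normal(size=(8, 200, 50))   # (Channels, Trials, Trial_length)
train_labels = rng.integers(0, 2, size=200)  # binary labels, 0=nontarget / 1=target

model = BaseRdaKdeModel(k_folds=10)
model.fit(train_data, train_labels)  # cross-validates lam/gam, then fits the KDE stage
print(model.auc)                     # AUC from the cross-validation folds

report = model.evaluate(train_data, train_labels)
print(report.auc)                    # ModelEvaluationReport stores AUC

probs = model.predict_proba(train_data)  # per-class posteriors, shape (Trials, 2)
preds = model.predict(train_data)        # argmax over the two posteriors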
alpha/baseline.py (2 additions & 2 deletions)
@@ -5,7 +5,7 @@
 
 import numpy as np
 from alpha_experiment import load_data, load_experimental_data
-from base_model import BasePcaRdaKdeModel
+from base_model import BasePcaRdaKdeModel, BaseRdaKdeModel
 from loguru import logger
 from rich.console import Console
 from rich.table import Table
@@ -23,7 +23,7 @@ def main(input_path, output_path):
     # extract relevant session information from parameters file
     data, labels, _ = load_data(input_path, alpha=False)
 
-    model = make_pipeline(FunctionTransformer(reorder), BasePcaRdaKdeModel(k_folds=10))
+    model = make_pipeline(FunctionTransformer(reorder), BaseRdaKdeModel(k_folds=10))
 
     n_folds = 10
     np.random.seed(1)
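baseline.py keeps its manual n_folds = 10 evaluation; a sketch of how the swapped-in BaseRdaKdeModel could drive such a loop. StratifiedKFold, the hypothetical session path, and the loop body are assumptions for illustration, not the file's actual code:

from pathlib import Path

import numpy as np
from loguru import logger
from sklearn.model_selection import StratifiedKFold

from alpha_experiment import load_data
from base_model import BaseRdaKdeModel

input_path = Path("path/to/session")  # hypothetical session folder
data, labels, _ = load_data(input_path, alpha=False)  # data: (Channels, Trials, Samples)

np.random.seed(1)
aucs = []
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=1)
for train_idx, test_idx in skf.split(np.zeros(len(labels)), labels):
    model = BaseRdaKdeModel(k_folds=10)
    model.fit(data[:, train_idx], labels[train_idx])
    aucs.append(model.evaluate(data[:, test_idx], labels[test_idx]).auc)
logger.info(f"mean held-out AUC: {np.mean(aucs):.3f}")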
requirements.txt (2 additions & 2 deletions)
@@ -1,7 +1,7 @@
-bcipy
+bcipy==2.0.0rc3
 loguru==0.6.0
 numpy
 pyriemann==0.2.7
-PyWavelets==1.2.0
+PyWavelets==1.4.1
 rich==11.2.0
 scikit-learn
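With the pins above in place, a quick sanity check that the release candidate actually resolved; importlib.metadata is standard library and the distribution names are as listed in requirements.txt:

from importlib.metadata import version

for pkg in ("bcipy", "loguru", "pyriemann", "PyWavelets", "rich"):
    print(pkg, version(pkg))  # expect bcipy 2.0.0rc3, PyWavelets 1.4.1, etc.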
