In [None]:
import logging
import warnings

import numpy as np
import xarray as xr
import yaml
from megspikes.database.database import LoadDataset, SaveDataset
from megspikes.detection.detection import AspireAlphacscRunsMerging
from megspikes.pipeline import aspire_alphacsc_pipeline, update_default_params
from megspikes.utils import PrepareData
from sklearn.pipeline import Pipeline

from utils.utils import setup_case_manager

warnings.filterwarnings("ignore", category=DeprecationWarning)

# Route log records (e.g. pipeline progress) to stderr and show INFO and above.
# basicConfig() installs a stderr handler only if the root logger has none yet;
# the level is set separately so it also applies when a handler already exists.
# This replaces emitting throw-away "test" messages, which only configured the
# handler as a side effect of the module-level logging.debug/info calls.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Render scikit-learn estimators (such as the Pipeline displayed further down)
# as HTML diagrams instead of plain-text reprs.
from sklearn import set_config
set_config(display='diagram')
# Enable the autoreload extension, then reload all imported modules before each
# cell runs, so edits to megspikes/utils are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2


### Set up detection parameters

In [None]:
# Parameter overrides for the spike-detection pipeline. They are passed to
# aspire_alphacsc_pipeline and merged into the YAML defaults via
# update_default_params. Nested dicts appear to be keyed by pipeline step name
# (e.g. PeakDetection, CleanDetections, SelectAlphacscEvents).
# NOTE(review): `runs` presumably lists run indices and should agree with
# `n_runs` — confirm.
params_for_detection = dict(
    n_ica_components=20,
    n_runs=1,
    runs=[0],
    n_atoms=3,
    PeakDetection=dict(width=2),
    CleanDetections=dict(n_cleaned_peaks=2000),
    SelectAlphacscEvents=dict(
        z_hat_threshold=7.0,
        z_hat_threshold_min=1.5,
    ),
)

### Run the detection pipeline for all cases

In [None]:
# Subject ids to process; each is turned into a case object by setup_case_manager.
SUBJECT_IDS = range(1, 8)

for subj in SUBJECT_IDS:
    # Per-subject progress message so a long multi-subject run can be followed.
    logging.info("Running detection pipeline for subject %d", subj)
    case = setup_case_manager(subj)

    pipe = aspire_alphacsc_pipeline(case, params_for_detection)

    # fit_transform returns a (dataset, raw) pair; only the last subject's pair
    # remains after the loop. NOTE(review): results are presumably persisted to
    # case.dataset by the pipeline itself — confirm.
    dataset, raw = pipe.fit_transform(None)


### Re-run only the runs-merging step for a single case

In [None]:
# Case whose saved detection results are re-merged below.
CASE_TO_MERGE = 6
case = setup_case_manager(CASE_TO_MERGE)

# Load the default pipeline parameters and overlay this notebook's overrides.
# The path is relative to the notebook's working directory; the encoding is
# explicit so decoding does not depend on the platform's default locale.
with open('aspire_alphacsc_default_params.yml', 'rt', encoding='utf-8') as params_file:
    default_params = yaml.safe_load(params_file)
params = update_default_params(default_params, params_for_detection)

# Merge-only pipeline: prepare the raw data, load the dataset stored at
# case.dataset, run AspireAlphacscRunsMerging, and save back to case.dataset.
pipe_merging_only = Pipeline([
    ('prepare_data', PrepareData(data_file=case.fif_file, sensors=True, **params['PrepareData'])),
    ('load_aspire_alphacsc_dataset', LoadDataset(dataset=case.dataset, sensors=None, run=None)),
    ('merge_atoms', AspireAlphacscRunsMerging(**params['AspireAlphacscRunsMerging'])),
    ('save_dataset', SaveDataset(dataset=case.dataset))])

In [None]:
# Display the merge-only pipeline (rendered as a diagram because of
# set_config(display='diagram') in the setup cell).
pipe_merging_only

In [None]:
# Run the merge-only pipeline; its SaveDataset step writes to case.dataset.
# The returned pair is not needed. It is deliberately not unpacked into `_`:
# assigning `_` makes IPython stop updating its `_`/`__`/`___` output-history
# variables, and it would keep the returned objects alive in memory.
pipe_merging_only.fit_transform(())

# load_dataset reads the file into memory and closes it, so re-running the
# pipeline above can overwrite case.dataset without an open handle blocking it.
detection_results = xr.load_dataset(case.dataset)

In [None]:
# 'library_detection' slice of the atoms-library properties, as a numpy array.
spikes = detection_results.alphacsc_atoms_library_properties.sel(
    atoms_library_property='library_detection').values

# Number of non-zero entries. np.count_nonzero is vectorised and counts over all
# axes; the builtin sum() iterated element by element in Python and would have
# returned per-column sums if the array were 2-D.
# NOTE(review): assumes the selected slice is 1-D (e.g. over time) — confirm.
int(np.count_nonzero(spikes))