In [1]:
import logging
import warnings

import mne
import xarray as xr
import yaml
from megspikes.pipeline import aspire_alphacsc_pipeline
from megspikes.visualization.report import report_detection, report_atoms_library

from utils.utils import setup_case_manager

warnings.filterwarnings("ignore", category=DeprecationWarning)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.debug("test debug")
logging.info("test info")

from sklearn import set_config
set_config(display='diagram')
%load_ext autoreload
%autoreload 2


INFO:root:test info


### Set up detection parameters

In [None]:
# Detection parameters shared by every case that needs no special treatment.
# Top-level entries are global settings; each nested dict is keyed by a pipeline
# step name and is merged into the defaults by update_default_params.
params_for_detection = dict(
    n_ica_components=20,
    n_runs=1,
    runs=[0],
    n_atoms=3,
    PeakDetection=dict(width=2),
    CleanDetections=dict(n_cleaned_peaks=2000),
    SelectAlphacscEvents=dict(
        z_hat_threshold=7.0,
        z_hat_threshold_min=1.5,
    ),
)

### Run the detection pipeline for all cases except case 5

In [None]:
# Case 5 is skipped on purpose: it needs alpha-notch filtering and its own
# parameters, so it is processed in the next section.
for subj in (1, 2, 3, 4, 6, 7):
    case = setup_case_manager(subj)
    pipe = aspire_alphacsc_pipeline(case, params_for_detection)
    # Run the whole detection pipeline for this case; the returned objects
    # are not used further inside this loop.
    dataset, raw = pipe.fit_transform(None)


### Run detection with alpha-notch filtering for case 5

In [None]:
# Case 5 uses its own parameter set: an alpha notch in data preparation, fewer
# cleaned peaks with an explicit diff threshold, and a lower z_hat threshold
# than params_for_detection.
params_for_detection_case_5 = dict(
    n_ica_components=20,
    n_runs=1,
    runs=[0],
    n_atoms=3,
    PrepareData=dict(alpha_notch=10),
    PeakDetection=dict(width=2),
    CleanDetections=dict(
        n_cleaned_peaks=300,
        diff_threshold=0.3,
    ),
    SelectAlphacscEvents=dict(
        z_hat_threshold=4.0,
        z_hat_threshold_min=1.5,
    ),
)

In [None]:
case = setup_case_manager(5)

# ICA components for case 5 are picked by hand rather than automatically.
# NOTE(review): presumably one inner tuple of component indices per run
# (n_runs=1 here) for each sensor type - TODO confirm against megspikes docs.
pipe = aspire_alphacsc_pipeline(
    case,
    params_for_detection_case_5,
    rewrite_previous_results=True,
    manual_ica_components={
        'grad': ((1, 4),),
        'mag': ((0, 4),),
    },
)

dataset, raw = pipe.fit_transform(None)

### PDF detection reports for each case

The detection report includes the following plots for each run and sensor type (a separate atoms library report is saved alongside it):
1. ICA components
2. AlphaCSC atoms
3. AlphaCSC events for each atom

In [None]:
for subj in range(1, 8):
    case = setup_case_manager(subj)

    # load_dataset reads the results into memory and closes the file right away.
    # open_dataset would keep a lazy handle on case.dataset, and later cells
    # rewrite that same file (pipelines / SaveDataset).
    detection_results = xr.load_dataset(case.dataset)

    raw = mne.io.read_raw_fif(case.fif_file, preload=True)

    # Each report receives its own copy of the recording (presumably so one
    # report cannot alter the data seen by the other).
    pdf_path = case.basic_folders['REPORTS'] / 'detection_report.pdf'
    report_detection(pdf_path, detection_results, raw.copy())

    pdf_path = case.basic_folders['REPORTS'] / 'atoms_library_report.pdf'
    report_atoms_library(pdf_path, detection_results, raw.copy())



### Rerun the merging step for one case
This can be necessary when the atoms library contains only a few atoms.

In [None]:
from sklearn.pipeline import Pipeline
from megspikes.detection.detection import AspireAlphacscRunsMerging
from megspikes.database.database import LoadDataset, SaveDataset
from megspikes.pipeline import update_default_params
from megspikes.utils import PrepareData

case = setup_case_manager(6)

# Rebuild the parameter set used for the main detection run: the YAML defaults
# with params_for_detection applied on top.
with open('aspire_alphacsc_default_params.yml', 'rt') as params_file:
    default_params = yaml.safe_load(params_file)
params = update_default_params(default_params, params_for_detection)

# Pipeline that repeats only the merging step. It prepares the data, loads the
# saved detection dataset, runs AspireAlphacscRunsMerging, and writes the
# dataset back to the same file.
merging_steps = [
    ('prepare_data',
     PrepareData(data_file=case.fif_file, sensors=True, **params['PrepareData'])),
    ('load_aspire_alphacsc_dataset',
     LoadDataset(dataset=case.dataset, sensors=None, run=None)),
    ('merge_atoms',
     AspireAlphacscRunsMerging(**params['AspireAlphacscRunsMerging'])),
    ('save_dataset',
     SaveDataset(dataset=case.dataset)),
]
pipe_merging_only = Pipeline(merging_steps)

In [None]:
# Show the merging-only pipeline; with set_config(display='diagram') from the
# first cell, scikit-learn renders it as a diagram.
pipe_merging_only

In [None]:
# Re-run only the merging step; its SaveDataset step writes the updated results
# back to case.dataset. The outputs are unpacked into named variables rather
# than `_`, because assigning `_` shadows IPython's last-output variable.
merged_dataset, merged_raw = pipe_merging_only.fit_transform(())
# load_dataset reads the file fully into memory and closes it, so re-running
# this cell does not rewrite case.dataset while an open_dataset handle on the
# same file is still held.
detection_results = xr.load_dataset(case.dataset)

In [None]:
spikes = detection_results.alphacsc_atoms_library_properties.loc[
    dict(atoms_library_property='library_detection')].values

# Number of nonzero library detections. The builtin sum() only reduces over the
# first axis, so for a multi-dimensional selection it would show per-column sums
# instead of one count. ndarray.sum() reduces over all elements and gives the
# same value for 1-D input. (The selection's dimensionality is not fixed here.)
int((spikes != 0).sum())


### Regenerate the PDF detection reports for each case

The detection report includes the following plots for each run and sensor type (a separate atoms library report is saved alongside it):
1. ICA components
2. AlphaCSC atoms
3. AlphaCSC events for each atom

In [None]:
for subj in range(1, 8):
    case = setup_case_manager(subj)

    # load_dataset reads the results into memory and closes the file right away.
    # open_dataset would keep a lazy handle on case.dataset, which blocks
    # rewriting that file if the merging step is run again.
    detection_results = xr.load_dataset(case.dataset)

    raw = mne.io.read_raw_fif(case.fif_file, preload=True)

    # Each report receives its own copy of the recording (presumably so one
    # report cannot alter the data seen by the other).
    pdf_path = case.basic_folders['REPORTS'] / 'detection_report.pdf'
    report_detection(pdf_path, detection_results, raw.copy())

    pdf_path = case.basic_folders['REPORTS'] / 'atoms_library_report.pdf'
    report_atoms_library(pdf_path, detection_results, raw.copy())



