In [None]:
import mne
import numpy as np
import xarray as xr

from megspikes.pipeline import (iz_prediction_pipeline,
                                read_detection_iz_prediction_pipeline)
from megspikes.visualization.visualization import ClusterSlopeViewer
from sklearn import set_config

from utils.utils import setup_case_manager

set_config(display='diagram')
set_config(print_changed_only=False)

%load_ext autoreload
%autoreload 2


### Setup params

In [None]:
# Parameter overrides handed to the pipeline builder; the outer key presumably
# names the pipeline step to configure (TODO confirm against megspikes).
params = dict(
    PrepareClustersDataset=dict(detection_sfreq=200.),
)

### Run clusters localization for all cases

NOTE: previous results will not be overwritten. If they already exist, the algorithm
will throw an error and stop.

In [None]:
# Run the IZ-prediction pipeline on the stored detection results of every case.
for subj in range(1, 8):
    case = setup_case_manager(subj)

    pipe = read_detection_iz_prediction_pipeline(case, params)

    # Keep the detection dataset open only while the pipeline runs, so that the
    # file handle is released before the next subject is processed.
    with xr.open_dataset(case.dataset) as detection_results:
        raw = mne.io.read_raw_fif(case.fif_file)
        # The pipeline receives a copy of the raw object, as in the original code.
        clusters, _ = pipe.fit_transform((detection_results, raw.copy()))


### View clusters

To manually update the cluster slopes, use the following procedure:

1. Change **selected_for_iz_prediction** values in the table to select clusters for prediction. 1 - selected; 0 - not selected.
2. Change **time_baseline**, **time_slope**, **time_peak** values in the table. The time is in milliseconds.
3. [Optional] Change **Save dataset path**. The file extension should be **.nc** (e.g. **file_name.nc**).
4. Press **Save Dataset** to rewrite or create the file with updated results.
5. [Optional] Check the time of the last modification of the manually checked file to be sure the changes were saved.

NOTE: some operations, such as saving, are slow. Please wait until the Jupyter notebook stops showing that the process is running.

In [None]:
import xarray as xr
from megspikes.visualization.visualization import ClusterSlopeViewer
from utils.utils import setup_case_manager

In [None]:
subject = 7
case = setup_case_manager(subject)

# Read the cluster dataset fully into memory and close the file right away:
# xarray cannot save changes into a file that is still held open lazily
# (https://github.com/pydata/xarray/issues/2029).
with xr.open_dataset(case.cluster_dataset) as ds:
    clusters = ds.load()

pc = ClusterSlopeViewer(clusters, case)

In [None]:
# Select the Qt5 matplotlib backend for the viewer's figures.
%matplotlib qt5

# `pc` is the ClusterSlopeViewer created in the previous cell.
app = pc.view()
# app.show()
# Leave `app` as the last expression so that the notebook displays it.
app

### Open manually checked dataset

NOTE: `ds.load()` is necessary to save new changes to an existing file. See this issue: https://github.com/pydata/xarray/issues/2029

In [None]:
%matplotlib qt5

subject = 2
case = setup_case_manager(subject)

path = case.cluster_dataset.with_name(f'{case.case}_clusters_manually_checked.nc')
with xr.open_dataset(path) as ds:
    clusters_checked = ds.load()
    
pc = ClusterSlopeViewer(clusters_checked, case)
app = pc.view()
app

### Localize manual spikes as one cluster

In [None]:
# NOTE(review): this rebinds the notebook-level `params`; any cell executed
# after this one sees detection_sfreq = 1000.
params = dict(
    PrepareClustersDataset=dict(detection_sfreq=1000.),
)

for subj in range(1, 8):
    case = setup_case_manager(subj)

    # Point the case's cluster dataset at the manual-cluster file, presumably
    # so that the pipeline stores its result there (TODO confirm).
    case.cluster_dataset = case.manual_cluster_dataset

    # NOTE(review): allow_pickle=True can execute arbitrary code while
    # unpickling; only load detection files from a trusted source.
    manual_spikes = np.load(
        str(case.manual_detections), allow_pickle=True)  # spikes in ms

    # All manual spikes are passed together with a single cluster label 0.
    manual_detections = dict(
        spikes=manual_spikes,
        clusters=np.int32([0]),
    )
    pipe = iz_prediction_pipeline(case, params)

    raw = mne.io.read_raw_fif(case.fif_file)
    clusters, _ = pipe.fit_transform((manual_detections, raw.copy()))


### View manual cluster

In [None]:
%matplotlib qt5

case = setup_case_manager(5)

clusters = xr.open_dataset(case.manual_cluster_dataset)
pc = ClusterSlopeViewer(clusters, case)

app = pc.view()
# app.show()
app

### Localize manual spikes on the peak individually
In this case each spike is localized individually. The output is used for the final statistics estimation.

**NOTE**: Manual detections should start from **zero** (not the first sample)!

**NOTE**: The number of smoothing steps could significantly decrease or increase the number of sources in final prediction.


In [None]:
from megspikes.localization.localization import ManualEventsLocalization
from megspikes.utils import PrepareData


# Localize the manually marked spikes of every case and store the result as
# `manual_stc.npy` in the case's MANUAL folder.
for subj in range(1, 8):
    case = setup_case_manager(subj)

    prep_data = PrepareData(data_file=case.fif_file, sensors='grad')
    mel = ManualEventsLocalization(
        case=case, smoothing_steps=10, smoothing_steps_final=10)

    # PrepareData was given the file path above, so it is called with an empty
    # input tuple here.
    meg_data = prep_data.fit_transform(())

    # NOTE(review): allow_pickle=True can execute arbitrary code while
    # unpickling; only load detection files from a trusted source.
    manual_spikes = np.load(str(case.manual_detections), allow_pickle=True)
    manual_stc = mel.fit_transform((manual_spikes, meg_data))

    stc_path = case.basic_folders['MANUAL'] / 'manual_stc.npy'
    np.save(stc_path, manual_stc)
