In [None]:
import mne
import numpy as np
import xarray as xr

from megspikes.pipeline import (iz_prediction_pipeline,
                                read_detection_iz_prediction_pipeline)
from megspikes.visualization.visualization import ClusterSlopeViewer
from sklearn import set_config

from utils.utils import setup_case_manager

set_config(display='diagram')
set_config(print_changed_only=False)

%load_ext autoreload
%autoreload 2


### Setup params

In [None]:
# Overrides for the automatic-detection localization run below.
# NOTE(review): detection_sfreq presumably is the sampling rate (Hz) that the
# stored detection indices refer to — TODO confirm against PrepareClustersDataset.
params = dict(
    PrepareClustersDataset=dict(detection_sfreq=200.0),
)

### Run clusters localization for all cases

NOTE: if results from a previous run already exist and are not set to be
overwritten, the algorithm will throw an error and stop.

In [None]:
# Run the clustering + IZ-prediction pipeline on the saved automatic
# detections of subjects 1..7 (range end is exclusive).
for subj in range(1, 8):
    case = setup_case_manager(subj)

    pipe = read_detection_iz_prediction_pipeline(case, params)
    # load_dataset reads the detections into memory and closes the file.
    # open_dataset would leave one handle open per subject and keep the
    # netCDF file locked if the same path is written to later.
    detection_results = xr.load_dataset(case.dataset)

    raw = mne.io.read_raw_fif(case.fif_file)
    clusters, _ = pipe.fit_transform((detection_results, raw.copy()))


### View clusters

In [None]:
subject = 5  # subject whose clusters are inspected in the viewer below
case = setup_case_manager(subject)
# load_dataset loads the clusters into memory and closes the file, so the
# file is not held open while the viewer is in use. An open netCDF handle can
# block re-running the localization, which writes to this path.
clusters = xr.load_dataset(case.cluster_dataset)

In [None]:
# Build the viewer for the cluster dataset and case loaded in the cell above.
pc = ClusterSlopeViewer(clusters, case)

In [None]:
# Use the Qt5 matplotlib backend so figures open in interactive windows.
%matplotlib qt5

app = pc.view()
# Alternative (presumably opens the viewer outside the notebook): app.show()
app

### Localize manual spikes

In [None]:
# Kept under its own name so the global ``params`` used by the
# automatic-detection cell is not clobbered. Re-running that cell afterwards
# would otherwise silently use 1000 Hz instead of 200 Hz.
manual_params = {
    # manual spike times are stored in ms, i.e. an effective rate of 1000 Hz
    'PrepareClustersDataset': {'detection_sfreq': 1000.}
}

for subj in range(1, 8):
    case = setup_case_manager(subj)

    # Point the output path at the manual-cluster file. Presumably the
    # pipeline writes to case.cluster_dataset, so this keeps the
    # automatic-detection clusters from being overwritten.
    case.cluster_dataset = case.manual_cluster_dataset

    # NOTE(review): allow_pickle=True can execute arbitrary code embedded in
    # the file; only load trusted annotation files.
    manual_spikes = np.load(str(case.manual_detections), allow_pickle=True)
    manual_detections = {
        'spikes': manual_spikes,  # spikes in ms
        'clusters': np.int32([0]),  # presumably: all manual spikes form cluster 0
    }
    pipe = iz_prediction_pipeline(case, manual_params)

    raw = mne.io.read_raw_fif(case.fif_file)
    clusters, _ = pipe.fit_transform((manual_detections, raw.copy()))




### View manual cluster

In [None]:
%matplotlib qt5

case = setup_case_manager(5)

clusters = xr.open_dataset(case.manual_cluster_dataset)
pc = ClusterSlopeViewer(clusters, case)

app = pc.view()
# app.show()
app