In [None]:
import mne
import numpy as np
import xarray as xr
import yaml
from megspikes.casemanager.casemanager import CaseManager
from megspikes.pipeline import (iz_prediction_pipeline,
                                read_detection_iz_prediction_pipeline)
from sklearn import set_config

set_config(display='diagram')
set_config(print_changed_only=False)

%load_ext autoreload
%autoreload 2

In [None]:
# Case paths / names used by every cell below (`cases_path`, `case_name`,
# `free_surfer_path`). safe_load avoids constructing arbitrary Python objects.
with open('case_info.yml', 'rt') as case_info_file:
    cases = yaml.safe_load(case_info_file)


### Set up parameters

In [None]:
# Pipeline step options; detection_sfreq is presumably the sampling rate (Hz)
# of the detection results — confirm against PrepareClustersDataset.
params = dict(PrepareClustersDataset=dict(detection_sfreq=200.0))

### Run clusters localization for all cases

NOTE: if previous results already exist and are not overwritten, the algorithm
will throw an error and stop.

In [None]:
# Run the clusters-localization pipeline on the saved detection results of
# every case listed in case_info.yml (indices 1..6).
for subj in range(1, 7):
    case = CaseManager(root=cases['cases_path'], case=cases['case_name'][subj],
                       free_surfer=cases['free_surfer_path'])

    case.set_basic_folders()
    case.select_fif_file(case.run)
    case.prepare_forward_model()

    pipe = read_detection_iz_prediction_pipeline(case, params)

    # The context manager releases the detection-results file after each case.
    # Previously one handle leaked per iteration, and lingering netCDF handles
    # can block re-running this cell.
    with xr.open_dataset(case.dataset) as detection_results:
        raw = mne.io.read_raw_fif(case.fif_file)
        clusters, _ = pipe.fit_transform((detection_results, raw.copy()))


### View clusters

In [None]:
# Rebuild the CaseManager for the single case (index 6) whose clusters are
# inspected in the cells below.
case = CaseManager(
    root=cases['cases_path'],
    free_surfer=cases['free_surfer_path'],
    case=cases['case_name'][6],
)

case.set_basic_folders()
case.select_fif_file(case.run)
case.prepare_forward_model()

In [None]:
# Load the clusters dataset written by the localization pipeline for this case.
clusters = xr.open_dataset(filename_or_obj=case.cluster_dataset)

In [None]:
# Imported here rather than in the top import cell so the processing cells
# above do not require the visualization dependencies.
# NOTE(review): confirm that visualization extras are optional in megspikes.
from megspikes.visualization.visualization import ClusterSlopeViewer
# Build the viewer from the loaded clusters dataset and the current case.
pc = ClusterSlopeViewer(clusters, case)

In [None]:
# Switch matplotlib to the Qt5 backend, presumably needed for the viewer's
# interactive figures. NOTE(review): confirm the viewer uses matplotlib.
%matplotlib qt5

app = pc.view()
# app.show()  # NOTE(review): disabled alternative to displaying `app` below; remove or document
app

### Localize manual spikes

In [None]:
# Manual detections are stored in ms (see the 'spikes' comment below), so the
# clusters dataset is prepared with detection_sfreq=1000, presumably one sample
# per ms. A separate name keeps the 200 Hz `params` used by the automatic-
# detection cells intact if those cells are re-run after this one.
manual_params = {
    'PrepareClustersDataset': {'detection_sfreq': 1000.}
}

for subj in range(1, 7):
    case_name = cases['case_name'][subj]
    case = CaseManager(root=cases['cases_path'], case=case_name,
                       free_surfer=cases['free_surfer_path'])

    case.set_basic_folders()
    case.select_fif_file(case.run)
    case.prepare_forward_model()

    # Save to '<case_name>_manual.nc' so the clusters dataset from the
    # automatic detections is not replaced.
    case.cluster_dataset = case.cluster_dataset.with_name(f'{case_name}_manual.nc')

    manual = case.basic_folders['MANUAL'] / f"{case_name}_manual_detections.npy"

    # NOTE(review): allow_pickle=True lets the .npy file execute arbitrary code
    # when loaded; only use detection files from a trusted source, and drop the
    # flag if the files contain plain numeric arrays.
    manual_detections = {
        'spikes': np.load(str(manual), allow_pickle=True),  # spikes in ms
        'clusters': np.int32([0])  # all manual spikes treated as one cluster (id 0)
    }
    pipe = iz_prediction_pipeline(case, manual_params)

    raw = mne.io.read_raw_fif(case.fif_file)
    clusters, _ = pipe.fit_transform((manual_detections, raw.copy()))
