In [None]:
import logging
import pathlib
import pickle

import numpy as np
import pandas as pd

# Log records are rendered as "<timestamp> <message>"; the timestamp layout is
# set by `datefmt` below.
FORMAT = "%(asctime)s %(message)s"
logging.basicConfig(
    format=FORMAT,
    datefmt="%d-%b-%y %H:%M:%S",
    level="INFO",
)


def load_data(folder_path, prefix="Jaq_03_16"):
    """Load the position table and spike data saved for one session.

    Parameters
    ----------
    folder_path : str or pathlib.Path
        Directory that contains the session's ``.pkl`` files.
    prefix : str, optional
        File-name prefix shared by all of the session's files. Defaults to
        ``"Jaq_03_16"``, so existing calls keep reading the same files.

    Returns
    -------
    position_time : numpy.ndarray
        Index of the position table, cast to float.
    position1D : numpy.ndarray
        The ``linear_position`` column, cast to float.
    position2D : numpy.ndarray
        The ``nose_x`` and ``nose_y`` columns (in that order), cast to float.
    sorted_spike_times : object
        Whatever was pickled in ``<prefix>_sorted_spike_times.pkl``.
    clusterless_spike_times : object
        Whatever was pickled in ``<prefix>_clusterless_spike_times.pkl``.
    clusterless_spike_waveform_features : object
        Whatever was pickled in
        ``<prefix>_clusterless_spike_waveform_features.pkl``.

    Raises
    ------
    FileNotFoundError
        If any of the expected files is missing from ``folder_path``.

    Notes
    -----
    Every file is read through pickle, which can execute arbitrary code while
    loading. Only point this function at files from a trusted source.
    """
    folder_path = pathlib.Path(folder_path)

    def _unpickle(suffix):
        """Unpickle ``<prefix>_<suffix>.pkl`` from ``folder_path``."""
        with open(folder_path / f"{prefix}_{suffix}.pkl", "rb") as f:
            return pickle.load(f)

    # The position table is a pandas object, so it goes through pandas' own
    # pickle reader; the spike containers are plain pickles.
    position_info = pd.read_pickle(folder_path / f"{prefix}_position_info.pkl")

    sorted_spike_times = _unpickle("sorted_spike_times")
    clusterless_spike_times = _unpickle("clusterless_spike_times")
    clusterless_spike_waveform_features = _unpickle(
        "clusterless_spike_waveform_features"
    )

    # Cast everything to float arrays so downstream code gets a uniform dtype
    # regardless of how the columns were stored.
    position_time = np.asarray(position_info.index).astype(float)
    position1D = np.asarray(position_info.linear_position).astype(float)
    position2D = np.asarray(position_info[["nose_x", "nose_y"]]).astype(float)

    return (
        position_time,
        position1D,
        position2D,
        sorted_spike_times,
        clusterless_spike_times,
        clusterless_spike_waveform_features,
    )


# The session folder is resolved from the current user's home directory
# (~/Downloads/Jaq_03_16_data) instead of a hard-coded /Users/<name> path, so
# the notebook runs under any account that keeps the data in the same place.
(
    position_time,
    position1D,
    position2D,
    sorted_spike_times,
    clusterless_spike_times,
    clusterless_spike_waveform_features,
) = load_data(pathlib.Path.home() / "Downloads" / "Jaq_03_16_data")

In [None]:
import joblib

# Folder holding the saved environment, resolved from the current user's home
# directory rather than a hard-coded /Users/<name> path. It stays a string
# ending in "/" so that `path + filename` joins keep working unchanged.
path = f"{pathlib.Path.home() / 'Downloads'}/"

# NOTE: joblib.load unpickles the file, which can execute arbitrary code;
# only load environment files that come from a trusted source.
env = joblib.load(path + "Jaq_03_16_environment.pkl")

In [None]:
from non_local_detector import NonLocalClusterlessDetector

# Build the detector around the single previously loaded environment (`env`).
# The bare `detector` on the last line makes the notebook display the object.
detector = NonLocalClusterlessDetector(environments=[env])
detector

In [None]:
# Fit the detector on the un-sliced arrays: the 2D position samples
# (`position2D`) with their timestamps, plus the clusterless spike times and
# their waveform features.
detector.fit(
    position_time=position_time,
    position=position2D,
    spike_times=clusterless_spike_times,
    spike_waveform_features=clusterless_spike_waveform_features,
)

In [None]:
# Plot the detector's discrete state transition (presumably the transition
# matrix between its discrete states after fitting; confirm in the library docs).
detector.plot_discrete_state_transition()

In [None]:
# Decode a fixed window of position samples. The window is defined once so the
# decoding time grid, the position, and the position timestamps handed to
# `predict` always cover exactly the same samples.
decode_window = slice(90_000, 100_000)
decode_time = position_time[decode_window]

results = detector.predict(
    spike_times=clusterless_spike_times,
    spike_waveform_features=clusterless_spike_waveform_features,
    time=decode_time,
    position=position2D[decode_window],
    position_time=decode_time,
)
results