In [1]:
import joblib
import pandas as pd
import numpy as np
import logging

# Root-logger setup: every record is rendered as "<timestamp> <message>",
# with the timestamp shown as day-month-year hour:minute:second.
FORMAT = "%(asctime)s %(message)s"
logging.basicConfig(
    format=FORMAT,
    datefmt="%d-%b-%y %H:%M:%S",
    level="INFO",
)


# Directories searched by default, in priority order: the local download
# folder first, then the shared notebook folder.
_DEFAULT_DATA_DIRS = (
    "/Users/edeno/Downloads/",
    "/cumulus/edeno/non_local_paper/notebooks/",
)


def _load_session_from(directory):
    """Read the four Jaq_03_16 files from ``directory`` and convert them to arrays.

    Raises ``FileNotFoundError`` if any of the files is missing.
    """
    import os

    # NOTE(security): pickle/joblib files can execute arbitrary code when
    # loaded; only point this at data you trust.
    position_info = pd.read_pickle(
        os.path.join(directory, "Jaq_03_16_position_info.pkl")
    )
    spikes = pd.read_pickle(os.path.join(directory, "Jaq_03_16_spikes.pkl"))
    is_ripple = pd.read_pickle(os.path.join(directory, "Jaq_03_16_is_ripple.pkl"))
    env = joblib.load(os.path.join(directory, "Jaq_03_16_environment.pkl"))

    # Index divided by a one-second timedelta.
    # Assumes the index is timedelta-like, giving seconds -- TODO confirm.
    time = np.asarray(position_info.index / np.timedelta64(1, "s"))
    spikes = np.asarray(spikes).astype(float)
    position = np.asarray(position_info.linear_position).astype(float)
    position2D = np.asarray(position_info[["nose_x", "nose_y"]]).astype(float)
    is_ripple = np.asarray(is_ripple).squeeze()
    speed = np.asarray(position_info.nose_vel).astype(float)

    return is_ripple, spikes, position, speed, env, time, position2D


def load_data(search_paths=None):
    """Load the Jaq_03_16 session data as NumPy arrays.

    Each directory in ``search_paths`` is tried in order and the first one
    from which all four files load is used. This replaces two copy-pasted
    loading branches that differed only in the directory.

    Parameters
    ----------
    search_paths : str, path-like, or iterable of those, optional
        Directory (or directories) to search. When omitted, the local download
        folder is tried first and the shared notebook folder second.

    Returns
    -------
    is_ripple : np.ndarray
        Contents of ``Jaq_03_16_is_ripple.pkl`` with singleton axes squeezed out.
    spikes : np.ndarray of float
        Contents of ``Jaq_03_16_spikes.pkl``.
    position : np.ndarray of float
        The ``linear_position`` column of the position table.
    speed : np.ndarray of float
        The ``nose_vel`` column of the position table.
    env : object
        Whatever ``joblib.load`` returns for ``Jaq_03_16_environment.pkl``.
    time : np.ndarray
        The position table's index divided by a one-second timedelta.
    position2D : np.ndarray of float
        The ``nose_x`` and ``nose_y`` columns of the position table.

    Raises
    ------
    FileNotFoundError
        If no searched directory contains all of the files. The message lists
        every directory that was tried.
    """
    import os

    if search_paths is None:
        search_paths = _DEFAULT_DATA_DIRS
    elif isinstance(search_paths, (str, os.PathLike)):
        # A single directory: wrap it so a string is not iterated per character.
        search_paths = [search_paths]
    search_paths = list(search_paths)

    last_error = None
    for directory in search_paths:
        try:
            return _load_session_from(directory)
        except FileNotFoundError as error:
            # Keep going: a later directory may hold the files.
            last_error = error

    searched = ", ".join(str(directory) for directory in search_paths)
    raise FileNotFoundError(
        f"Could not find the Jaq_03_16 data files in any of: {searched}"
    ) from last_error

In [2]:
import jax
import pprint

# Show every device JAX can see, one per line, so a device index can be
# picked for the rest of the notebook.
pprint.pprint(jax.devices())

2023-06-26 15:31:14.199855: W external/xla/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 0 and 9; status: INTERNAL: failed to enable peer access from 0x7f140c626e80 to 0x7f14106248a0: CUDA_ERROR_TOO_MANY_PEERS: peer mapping resources exhausted
2023-06-26 15:31:14.201239: W external/xla/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 1 and 9; status: INTERNAL: failed to enable peer access from 0x7f142c624ef0 to 0x7f14106248a0: CUDA_ERROR_TOO_MANY_PEERS: peer mapping resources exhausted
2023-06-26 15:31:14.202461: W external/xla/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 2 and 9; status: INTERNAL: failed to enable peer access from 0x7f142059ea10 to 0x7f14106248a0: CUDA_ERROR_TOO_MANY_PEERS: peer mapping resources exhausted
2023-06-26 15:31:14.203529: W external/xla/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 3 and 9; status: INTERNAL: failed to enable peer access from 0x7f13

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=4, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=5, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=6, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=7, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=8, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=9, process_index=0, slice_index=0)]


In [3]:
# Preferred device index (id=1 on the 10-GPU host in the recorded output).
# Clamp it to the devices actually present so that a host with a single device
# (e.g. a CPU-only laptop) falls back to device 0 instead of raising IndexError.
PREFERRED_DEVICE_ID = 1
available_devices = jax.devices()
device_id = min(PREFERRED_DEVICE_ID, len(available_devices) - 1)
device = available_devices[device_id]

# Make the chosen device the default placement for JAX computations.
jax.config.update("jax_default_device", device)
device

StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)

In [4]:
# Load the session arrays; the names match the order of load_data()'s return tuple.
is_ripple, spikes, position, speed, env, time, position2D = load_data()

26-Jun-23 15:31:16 Cupy is not installed or GPU is not detected. Ignore this message if not using GPU
  from tqdm.autonotebook import tqdm


In [5]:
from non_local_detector import NonLocalSortedSpikesDetector

# Build a sorted-spikes detector with the KDE algorithm and fit it on the 2D
# nose position and the spikes, training only on samples where is_ripple is
# not set (assumes is_ripple is a boolean array -- TODO confirm).
# NOTE(review): the recorded output of this cell ends in
# KeyError: 'sorted_spikes_kde_jax2', which presumably comes from inside
# non_local_detector -- confirm the installed version supports this
# algorithm name before relying on the cells below.
detector = NonLocalSortedSpikesDetector(
    sorted_spikes_algorithm="sorted_spikes_kde",
    sorted_spikes_algorithm_params={
        "position_std": 6.0,
        "use_diffusion": False,
        "block_size": 8000,
        "interpolate_local": False,
    },
).fit(position2D, spikes, is_training=~is_ripple)

26-Jun-23 15:31:18 Fitting initial conditions...
26-Jun-23 15:31:18 Fitting discrete state transition
26-Jun-23 15:31:18 Fitting continuous state transition...
26-Jun-23 15:31:20 Fitting place fields...


KeyError: 'sorted_spikes_kde_jax2'

In [None]:
# Plot the detector's discrete state transition as it stands after fit().
detector.plot_discrete_state_transition()

In [None]:
# Run the fitted detector over the whole session and display the result.
# NOTE(review): spikes come first here but position came first in fit() --
# confirm this matches the library's predict() signature.
results = detector.predict(spikes, position2D, time=time)
results

In [None]:
# Overwrites `results` from the predict() cell with the output of
# estimate_parameters() (presumably this also updates the detector's
# parameters -- verify against the library documentation).
results = detector.estimate_parameters(spikes, position2D, time=time)

In [None]:
# Plot the discrete state transition again, this time after estimate_parameters(),
# presumably to compare against the post-fit plot -- TODO confirm intent.
detector.plot_discrete_state_transition()