In [None]:
import logging

import numpy as np

# Make the analysis reproducible: seed NumPy's legacy global RNG, which the
# simulation helpers used below presumably draw from.
SEED = 0
np.random.seed(SEED)

# Show INFO-level (and more severe) records via the root logger.
logging.basicConfig(level=logging.INFO)

In [None]:
import pprint

import jax

# List the devices JAX can see on its default backend.
available_devices = jax.devices()
pprint.pprint(available_devices)

In [None]:
# Choose which of the listed JAX devices to use, by its position in jax.devices().
device_id = 0
available_devices = jax.devices()
device = available_devices[device_id]
device

In [4]:
# Make `device` JAX's default device (the `jax_default_device` option), so
# computations without an explicit placement run on the device chosen above.
jax.config.update("jax_default_device", device)

In [None]:
from replay_trajectory_classification.sorted_spikes_simulation import (
    make_simulated_run_data,
)

# Figure-sizing constants, expressed in inches.
MM_TO_INCHES = 1.0 / 25.4
TWO_COLUMN = 178.0 * MM_TO_INCHES  # 178 mm two-column width
# NOTE: despite the name, this is the *reciprocal* golden ratio (~0.618);
# presumably used as a height/width aspect factor for figures.
GOLDEN_RATIO = (np.sqrt(5) - 1.0) / 2.0

# Simulated training session: time stamps, 1-D position, sampling rate,
# binned spike raster, and the ground-truth place fields.
simulated_run = make_simulated_run_data()
time, linear_distance, sampling_frequency, spikes, place_fields = simulated_run

In [6]:
from replay_trajectory_classification.sorted_spikes_simulation import (
    make_fragmented_continuous_fragmented_replay,
)

# Labels for the classifier's two discrete states.
state_names = ["Continuous", "Fragmented"]

# Simulated decoding epoch (time stamps + spike raster); per the helper's name,
# it presumably alternates fragmented -> continuous -> fragmented content.
replay_time, test_spikes = make_fragmented_continuous_fragmented_replay()

In [7]:
from replay_trajectory_classification import (
    Environment,
    RandomWalk,
    Uniform,
    estimate_movement_var,
)

# Movement variance estimated from the simulated position track.
movement_var = estimate_movement_var(linear_distance, sampling_frequency)

# Spatial bins are one standard deviation of that movement wide.
environment = Environment(place_bin_size=np.sqrt(movement_var))

# Multiplier on the random-walk variance; presumably widens the walk because
# replay trajectories move faster than behavior. TODO confirm intended value.
REPLAY_MOVEMENT_VAR_SCALE = 120
replay_random_walk = RandomWalk(movement_var=movement_var * REPLAY_MOVEMENT_VAR_SCALE)

# Presumably rows = from-state, columns = to-state, ordered like `state_names`.
continuous_transition_types = [
    [replay_random_walk, Uniform()],
    [Uniform(), Uniform()],
]

In [8]:
from non_local_detector import ContFragSortedSpikesClassifier
from non_local_detector.discrete_state_transitions import DiscreteNonStationaryDiagonal

# Diagonal (stay-in-state) value for each of the two discrete states.
STAY_PROBABILITY = 0.98
discrete_transition_type = DiscreteNonStationaryDiagonal(
    diagonal_values=np.full(2, STAY_PROBABILITY)
)

# Covariates the non-stationary transition model is fit against.
# NOTE(review): the key says "speed" but the value is position
# (`linear_distance`). If running speed was intended, something like
# np.abs(np.gradient(linear_distance, time)) would be needed — confirm.
discrete_transition_covariate_data = {"speed": linear_distance}

In [9]:
# Convert the binned training raster (time bins x neurons) into one array of
# spike times per neuron. np.repeat emits a bin's time stamp once per spike,
# so bins with more than one spike are not collapsed to a single spike (which
# a boolean mask would do); for a 0/1 raster the result is unchanged.
spike_times = [np.repeat(time, spike_train.astype(int)) for spike_train in spikes.T]

In [10]:
# Convert the binned replay raster (time bins x neurons) into one array of
# spike times per neuron, repeating a time stamp once per spike in its bin
# (a boolean mask would drop extra spikes; 0/1 rasters are unaffected).
test_spike_times = [
    np.repeat(replay_time, spike_train.astype(int)) for spike_train in test_spikes.T
]

In [None]:
# KDE place-field settings; position_std is presumably in the same units as
# `linear_distance` — TODO confirm.
kde_params = {"position_std": 5.0}

# Train the continuous/fragmented classifier on the simulated run.
classifier3 = ContFragSortedSpikesClassifier(
    environments=environment,
    discrete_transition_type=discrete_transition_type,
    continuous_transition_types=continuous_transition_types,
    sorted_spikes_algorithm="sorted_spikes_kde",
    sorted_spikes_algorithm_params=kde_params,
).fit(
    time,
    linear_distance,
    spike_times,
    discrete_transition_covariate_data=discrete_transition_covariate_data,
)

# Decode the simulated replay epoch.
# NOTE(review): the discrete transitions are covariate-driven, but no covariate
# data is supplied for `replay_time` — confirm how `predict` handles this.
results3 = classifier3.predict(test_spike_times, time=replay_time)

In [None]:
# Discrete-transition coefficients right after `fit` (an alias, not a copy).
fitted_transition_coefficients = classifier3.discrete_transition_coefficients_
fitted_transition_coefficients

In [None]:
# Re-estimate the model parameters on the training data; this presumably
# updates `discrete_transition_coefficients_`, which the next cell displays.
# The positional arguments mirror `fit` (the first is presumably the position
# time stamps); `time=` is a separate argument, here also the training `time`.
estimate_kwargs = {
    "time": time,
    "discrete_transition_covariate_data": discrete_transition_covariate_data,
}
classifier3.estimate_parameters(time, linear_distance, spike_times, **estimate_kwargs)

In [None]:
# Discrete-transition coefficients after `estimate_parameters` (an alias, not a copy).
reestimated_transition_coefficients = classifier3.discrete_transition_coefficients_
reestimated_transition_coefficients