In [1]:
import logging

import numpy as np
import matplotlib.pyplot as plt

# Seed NumPy's legacy global RNG so that anything drawing from it is
# repeatable after a kernel restart.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)

# Show INFO-level log records (the root logger defaults to WARNING).
logging.basicConfig(level=logging.INFO)

In [2]:
import jax
import pprint

# Pretty-print every device JAX can see on this machine.
available_devices = jax.devices()
pprint.pprint(available_devices)

2023-06-27 15:31:12.585020: W external/xla/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 0 and 9; status: INTERNAL: failed to enable peer access from 0x7f2644624ab0 to 0x7f264c63b230: CUDA_ERROR_TOO_MANY_PEERS: peer mapping resources exhausted
2023-06-27 15:31:12.592772: W external/xla/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 1 and 9; status: INTERNAL: failed to enable peer access from 0x7f2640624c00 to 0x7f264c63b230: CUDA_ERROR_TOO_MANY_PEERS: peer mapping resources exhausted
2023-06-27 15:31:12.599225: W external/xla/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 2 and 9; status: INTERNAL: failed to enable peer access from 0x7f2648624940 to 0x7f264c63b230: CUDA_ERROR_TOO_MANY_PEERS: peer mapping resources exhausted
2023-06-27 15:31:12.604666: W external/xla/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 3 and 9; status: INTERNAL: failed to enable peer access from 0x7f26

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=4, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=5, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=6, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=7, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=8, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=9, process_index=0, slice_index=0)]


In [3]:
# Choose one device by its position in JAX's device list.
device_id = 0
visible_devices = jax.devices()
device = visible_devices[device_id]

# Bare trailing expression: echo the chosen device as the cell output.
device

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

In [4]:
# Set JAX's "jax_default_device" option to the `device` object selected in the
# previous cell, so this depends on that cell having been run first.
jax.config.update("jax_default_device", device)

In [5]:
from replay_trajectory_classification.sorted_spikes_simulation import (
    make_simulated_run_data,
)

# Unit-conversion and layout constants.
MM_PER_INCH = 25.4
MM_TO_INCHES = 1.0 / MM_PER_INCH

# Presumably a two-column figure width given in millimetres -- TODO confirm.
TWO_COLUMN_MM = 178.0
TWO_COLUMN = TWO_COLUMN_MM * MM_TO_INCHES

# Evaluates to ~0.618, i.e. the reciprocal of the golden ratio (~1.618).
GOLDEN_RATIO = (np.sqrt(5) - 1.0) * 0.5

# Build the simulated run with the simulator's default arguments; the call
# yields five values, which are unpacked in the order returned.
simulated_run = make_simulated_run_data()
time, linear_distance, sampling_frequency, spikes, place_fields = simulated_run

INFO:numexpr.utils:Note: detected 96 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
INFO:numexpr.utils:Note: NumExpr detected 96 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
  from tqdm.autonotebook import tqdm


In [6]:
from replay_trajectory_classification.sorted_spikes_simulation import (
    make_fragmented_continuous_fragmented_replay,
)

# Display labels; presumably one per discrete state of the classifier -- TODO confirm.
state_names = ["Continuous", "Fragmented"]

# Simulated replay data used later as the decoding (predict) input.
replay_time, test_spikes = make_fragmented_continuous_fragmented_replay()

In [7]:
from replay_trajectory_classification import (
    Environment,
    RandomWalk,
    Uniform,
    estimate_movement_var,
)


# Multiplier applied to the estimated movement variance for the random-walk
# transition. NOTE(review): presumably a replay speed-up factor -- confirm.
MOVEMENT_VAR_SCALE = 120

# Movement variance estimated from the simulated position trace.
movement_var = estimate_movement_var(linear_distance, sampling_frequency)

# Position bins are one movement standard deviation wide.
place_bin_size = np.sqrt(movement_var)
environment = Environment(place_bin_size=place_bin_size)

# 2x2 grid of continuous-state transitions: a scaled random walk in the
# top-left entry, uniform everywhere else.
# NOTE(review): presumably rows/columns follow the state order
# ["Continuous", "Fragmented"] -- confirm against the classifier.
scaled_random_walk = RandomWalk(movement_var=movement_var * MOVEMENT_VAR_SCALE)
continuous_transition_types = [
    [scaled_random_walk, Uniform()],
    [Uniform(), Uniform()],
]

In [8]:
from non_local_detector import ContFragSortedSpikesClassifier
from non_local_detector.discrete_state_transitions import DiscreteNonStationaryDiagonal

# Diagonal values handed to the non-stationary discrete transition model,
# one entry per discrete state (0.98 for both).
diagonal_values = np.array([0.98, 0.98])
discrete_transition_type = DiscreteNonStationaryDiagonal(
    diagonal_values=diagonal_values
)

# Covariate table for the discrete transition model.
# NOTE(review): the key is "speed" but the values are `linear_distance`
# (the position trace), not a speed signal -- confirm this is intentional.
discrete_transition_covariate_data = {"speed": linear_distance}

In [9]:
# Configure the classifier with the environment and transition models built
# in the earlier cells, using the KDE sorted-spikes algorithm.
unfitted_classifier3 = ContFragSortedSpikesClassifier(
    environments=environment,
    discrete_transition_type=discrete_transition_type,
    continuous_transition_types=continuous_transition_types,
    sorted_spikes_algorithm="sorted_spikes_kde",
    sorted_spikes_algorithm_params={"position_std": 5.0},
)

# Fit on the simulated run; `classifier3` is whatever `fit` returns.
classifier3 = unfitted_classifier3.fit(
    linear_distance,
    spikes,
    discrete_transition_covariate_data=discrete_transition_covariate_data,
)

# Decode the simulated replay spikes on the replay time base.
results3 = classifier3.predict(test_spikes, time=replay_time)

INFO:non_local_detector.models.base:Fitting initial conditions...
INFO:non_local_detector.models.base:Fitting discrete state transition
INFO:non_local_detector.models.base:Fitting continuous state transition...
INFO:non_local_detector.models.base:Fitting place fields...


Encoding models:   0%|          | 0/19 [00:00<?, ?cell/s]

INFO:non_local_detector.models.base:Computing log likelihood...


Non-Local Likelihood:   0%|          | 0/19 [00:00<?, ?cell/s]

INFO:non_local_detector.models.base:Computing posterior...
INFO:non_local_detector.models.base:Finished computing posterior...


In [10]:
# Display the discrete-transition coefficients as they stand after fit/predict,
# before estimate_parameters is run.
classifier3.discrete_transition_coefficients_

array([[[ 3.8918203],
        [-3.8918203]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]]])

In [11]:
# Run parameter estimation on the simulated run data; the log output shows
# alternating expectation and maximization steps with a likelihood per iteration.
# NOTE(review): positional arguments here are (spikes, linear_distance), whereas
# the earlier `fit` call passes (linear_distance, spikes) -- confirm this matches
# the signature of `estimate_parameters`.
classifier3.estimate_parameters(
    spikes,
    linear_distance,
    time=time,
    discrete_transition_covariate_data=discrete_transition_covariate_data,
)

INFO:non_local_detector.models.base:Fitting initial conditions...
INFO:non_local_detector.models.base:Fitting discrete state transition
INFO:non_local_detector.models.base:Fitting continuous state transition...
INFO:non_local_detector.models.base:Fitting place fields...


Encoding models:   0%|          | 0/19 [00:00<?, ?cell/s]

INFO:non_local_detector.models.base:Computing log likelihood...


Non-Local Likelihood:   0%|          | 0/19 [00:00<?, ?cell/s]

INFO:non_local_detector.models.base:Expectation step...
INFO:non_local_detector.models.base:Computing posterior...
INFO:non_local_detector.models.base:Finished computing posterior...
INFO:non_local_detector.models.base:Maximization step..
INFO:non_local_detector.models.base:Computing stats..
INFO:non_local_detector.models.base:iteration 1, likelihood: -58030.05859375
INFO:non_local_detector.models.base:Expectation step...
INFO:non_local_detector.models.base:Computing posterior...
INFO:non_local_detector.models.base:Finished computing posterior...
INFO:non_local_detector.models.base:Maximization step..
INFO:non_local_detector.models.base:Computing stats..
INFO:non_local_detector.models.base:iteration 2, likelihood: -56397.63671875, change: 1632.421875
INFO:non_local_detector.models.base:Expectation step...
INFO:non_local_detector.models.base:Computing posterior...
INFO:non_local_detector.models.base:Finished computing posterior...
INFO:non_local_detector.models.base:Maximization step..


In [12]:
# Display the discrete-transition coefficients again, now after
# estimate_parameters has run, for comparison with the post-fit values.
classifier3.discrete_transition_coefficients_

array([[[10.94868448],
        [ 2.62752951]],

       [[ 0.52799244],
        [-0.04175963]],

       [[ 0.83108227],
        [ 0.28845227]],

       [[ 1.42042038],
        [ 0.85278268]],

       [[ 0.8569469 ],
        [ 0.20866752]],

       [[-0.28800777],
        [-0.181746  ]],

       [[-3.5468464 ],
        [-3.35430653]],

       [[-0.59506997],
        [-0.4567826 ]],

       [[ 1.88704301],
        [ 1.02678134]]])