In [1]:
import logging

import numpy as np
import matplotlib.pyplot as plt

# Make analysis reproducible
np.random.seed(0)

# Enable logging
logging.basicConfig(level=logging.INFO)

In [2]:
import jax
import pprint

pprint.pprint(jax.devices())

2023-06-26 17:34:42.014675: W external/xla/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 0 and 9; status: INTERNAL: failed to enable peer access from 0x7f58887d5bf0 to 0x7f5898624a80: CUDA_ERROR_TOO_MANY_PEERS: peer mapping resources exhausted
2023-06-26 17:34:42.029491: W external/xla/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 1 and 9; status: INTERNAL: failed to enable peer access from 0x7f5890625440 to 0x7f5898624a80: CUDA_ERROR_TOO_MANY_PEERS: peer mapping resources exhausted
2023-06-26 17:34:42.043813: W external/xla/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 2 and 9; status: INTERNAL: failed to enable peer access from 0x7f58a0624f60 to 0x7f5898624a80: CUDA_ERROR_TOO_MANY_PEERS: peer mapping resources exhausted
2023-06-26 17:34:42.055050: W external/xla/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 3 and 9; status: INTERNAL: failed to enable peer access from 0x7f58

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=4, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=5, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=6, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=7, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=8, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=9, process_index=0, slice_index=0)]


In [3]:
device_id = 0
device = jax.devices()[device_id]
device

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

In [4]:
jax.config.update("jax_default_device", device)

In [5]:
from replay_trajectory_classification.sorted_spikes_simulation import (
    make_simulated_run_data,
)

MM_TO_INCHES = 1.0 / 25.4
TWO_COLUMN = 178.0 * MM_TO_INCHES
GOLDEN_RATIO = (np.sqrt(5) - 1.0) / 2.0

(
    time,
    linear_distance,
    sampling_frequency,
    spikes,
    place_fields,
) = make_simulated_run_data()

INFO:numexpr.utils:Note: detected 96 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
INFO:numexpr.utils:Note: NumExpr detected 96 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
  from tqdm.autonotebook import tqdm


In [6]:
from replay_trajectory_classification.sorted_spikes_simulation import (
    make_fragmented_continuous_fragmented_replay,
)

replay_time, test_spikes = make_fragmented_continuous_fragmented_replay()

state_names = ["Continuous", "Fragmented"]

In [7]:
from replay_trajectory_classification import (
    Environment,
    RandomWalk,
    Uniform,
    estimate_movement_var,
)


movement_var = estimate_movement_var(linear_distance, sampling_frequency)

environment = Environment(place_bin_size=np.sqrt(movement_var))
continuous_transition_types = [
    [RandomWalk(movement_var=movement_var * 120), Uniform()],
    [Uniform(), Uniform()],
]

In [8]:
from non_local_detector import ContFragSortedSpikesClassifier
from non_local_detector.discrete_state_transitions import DiscreteNonStationaryDiagonal

discrete_transition_type = DiscreteNonStationaryDiagonal(
    diagonal_values=np.array([0.98, 0.98])
)


discrete_transition_covariate_data = {"speed": linear_distance}

In [9]:
classifier3 = ContFragSortedSpikesClassifier(
    environments=environment,
    discrete_transition_type=discrete_transition_type,
    continuous_transition_types=continuous_transition_types,
    sorted_spikes_algorithm="sorted_spikes_kde",
    sorted_spikes_algorithm_params={"position_std": 5.0},
).fit(
    linear_distance,
    spikes,
    discrete_transition_covariate_data=discrete_transition_covariate_data,
)
results3 = classifier3.predict(test_spikes, time=replay_time)

INFO:non_local_detector.models.base:Fitting initial conditions...
INFO:non_local_detector.models.base:Fitting discrete state transition
INFO:non_local_detector.models.base:Fitting continuous state transition...
INFO:non_local_detector.models.base:Fitting place fields...


Encoding models:   0%|          | 0/19 [00:00<?, ?cell/s]

INFO:non_local_detector.models.base:Computing log likelihood...


Non-Local Likelihood:   0%|          | 0/19 [00:00<?, ?cell/s]

INFO:non_local_detector.models.base:Computing posterior...
INFO:non_local_detector.models.base:Finished computing posterior...


In [10]:
classifier3.discrete_transition_coefficients_

array([[[ 3.8918203],
        [-3.8918203]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]],

       [[ 0.       ],
        [ 0.       ]]])

In [11]:
classifier3.estimate_parameters(spikes, linear_distance, time=replay_time, discrete_transition_covariate_data=discrete_transition_covariate_data)

INFO:non_local_detector.models.base:Fitting initial conditions...
INFO:non_local_detector.models.base:Fitting discrete state transition
INFO:non_local_detector.models.base:Fitting continuous state transition...
INFO:non_local_detector.models.base:Fitting place fields...


Encoding models:   0%|          | 0/19 [00:00<?, ?cell/s]

INFO:non_local_detector.models.base:Computing log likelihood...


Non-Local Likelihood:   0%|          | 0/19 [00:00<?, ?cell/s]

INFO:non_local_detector.models.base:Expectation step...
INFO:non_local_detector.models.base:Computing posterior...
INFO:non_local_detector.models.base:Finished computing posterior...
INFO:non_local_detector.models.base:Maximization step..
INFO:non_local_detector.models.base:Computing stats..
INFO:non_local_detector.models.base:iteration 1, likelihood: -58030.05859375
INFO:non_local_detector.models.base:Expectation step...
INFO:non_local_detector.models.base:Computing posterior...
INFO:non_local_detector.models.base:Finished computing posterior...
INFO:non_local_detector.models.base:Maximization step..
INFO:non_local_detector.models.base:Computing stats..
INFO:non_local_detector.models.base:iteration 2, likelihood: -56397.63671875, change: 1632.421875
INFO:non_local_detector.models.base:Expectation step...
INFO:non_local_detector.models.base:Computing posterior...
INFO:non_local_detector.models.base:Finished computing posterior...
INFO:non_local_detector.models.base:Maximization step..


ValueError: static arguments should be comparable using __eq__.The following error was raised during a call to 'get_transition_matrix' when comparing two objects of types <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'> and <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. The error was:
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type int32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was _step at /cumulus/edeno/non_local_detector/src/non_local_detector/core.py:314 traced for scan.
------------------------------
The leaked intermediate value was created on line /cumulus/edeno/non_local_detector/src/non_local_detector/core.py:340 (hmm_smoother). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/tmp/ipykernel_653891/3007681538.py:12 (<module>)
/cumulus/edeno/non_local_detector/src/non_local_detector/models/base.py:1123 (predict)
/cumulus/edeno/non_local_detector/src/non_local_detector/models/base.py:470 (_predict)
/cumulus/edeno/non_local_detector/src/non_local_detector/core.py:340 (hmm_smoother)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

At:
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py(1579): _assert_live
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/core.py(476): full_raise
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/core.py(383): bind_with_trace
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/core.py(2647): bind
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/pjit.py(161): _python_pjit_helper
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/pjit.py(249): cache_miss
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/traceback_util.py(170): reraise_with_filtered_traceback
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py(258): deferring_binary_op
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py(791): op
  /cumulus/edeno/non_local_detector/src/non_local_detector/core.py(181): get_trans_mat
  /cumulus/edeno/non_local_detector/src/non_local_detector/core.py(258): _step
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/linear_util.py(188): call_wrapped
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py(2172): trace_to_subjaxpr_dynamic
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py(2150): trace_to_jaxpr_dynamic
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/profiler.py(314): wrapper
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py(54): _initial_style_open_jaxpr
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py(60): _initial_style_jaxpr
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py(236): _create_jaxpr
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py(250): scan
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback
  /cumulus/edeno/non_local_detector/src/non_local_detector/core.py(268): hmm_filter
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/linear_util.py(188): call_wrapped
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py(2172): trace_to_subjaxpr_dynamic
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py(2150): trace_to_jaxpr_dynamic
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/profiler.py(314): wrapper
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/pjit.py(913): _create_pjit_jaxpr
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/linear_util.py(345): memoized_fun
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/pjit.py(960): _pjit_jaxpr
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/pjit.py(498): common_infer_params
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/api.py(301): infer_params
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/pjit.py(155): _python_pjit_helper
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/pjit.py(249): cache_miss
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback
  /cumulus/edeno/non_local_detector/src/non_local_detector/core.py(309): hmm_smoother
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/linear_util.py(188): call_wrapped
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py(2172): trace_to_subjaxpr_dynamic
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py(2150): trace_to_jaxpr_dynamic
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/profiler.py(314): wrapper
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/pjit.py(913): _create_pjit_jaxpr
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/linear_util.py(345): memoized_fun
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/pjit.py(960): _pjit_jaxpr
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/pjit.py(498): common_infer_params
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/api.py(301): infer_params
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/pjit.py(155): _python_pjit_helper
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/pjit.py(249): cache_miss
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback
  /cumulus/edeno/non_local_detector/src/non_local_detector/models/base.py(470): _predict
  /cumulus/edeno/non_local_detector/src/non_local_detector/models/base.py(525): estimate_parameters
  /cumulus/edeno/non_local_detector/src/non_local_detector/models/base.py(1156): estimate_parameters
  /tmp/ipykernel_653891/2097719467.py(1): <module>
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3505): run_code
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3445): run_ast_nodes
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3266): run_cell_async
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3061): _run_cell
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3006): run_cell
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/ipykernel/zmqshell.py(531): run_cell
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/ipykernel/ipkernel.py(411): do_execute
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/ipykernel/kernelbase.py(729): execute_request
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/ipykernel/kernelbase.py(406): dispatch_shell
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/ipykernel/kernelbase.py(499): process_one
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/ipykernel/kernelbase.py(510): dispatch_queue
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/asyncio/events.py(80): _run
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/asyncio/base_events.py(1909): _run_once
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/asyncio/base_events.py(603): run_forever
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/tornado/platform/asyncio.py(215): start
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/ipykernel/kernelapp.py(711): start
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/traitlets/config/application.py(992): launch_instance
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/site-packages/ipykernel_launcher.py(17): <module>
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/runpy.py(86): _run_code
  /home/edeno/miniconda3/envs/non_local_detector_gpu/lib/python3.10/runpy.py(196): _run_module_as_main


In [None]:
classifier3.discrete_transition_coefficients_

In [None]:
%debug