In [1]:
%matplotlib inline
%reload_ext autoreload

In [2]:
import logging
import os

import matplotlib
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

STATE_COLORS = {
    'stationary': '#9f043a',
    'fragmented': '#ff6944',
    'continuous': '#521b65',
    'stationary-continuous-mix': '#61c5e6',
    'fragmented-continuous-mix': '#2a586a',
    '': '#c7c7c7',
}

# Set background and fontsize
rc_params = {
    'pdf.fonttype': 42, # Make fonts editable in Adobe Illustrator
    'ps.fonttype': 42, # Make fonts editable in Adobe Illustrator
    'axes.labelcolor': '#222222',
    'text.color': '#222222',
    }
sns.set(style="white", context='paper', rc=rc_params,
        font_scale=1.3)

# Make analysis reproducible
np.random.seed(0)

# Enable logging
logging.basicConfig(level=logging.INFO)

In [3]:
from replay_trajectory_classification.clusterless_simulation import make_simulated_run_data

(time, linear_distance, sampling_frequency,
 multiunits, multiunits_spikes) = make_simulated_run_data()



In [4]:
from replay_trajectory_classification import ClusterlessClassifier
from replay_trajectory_classification.environments import Environment
from replay_trajectory_classification.continuous_state_transitions import RandomWalk, Uniform, Identity, estimate_movement_var

movement_var = estimate_movement_var(linear_distance, sampling_frequency)


# If your marks are integers, use this algorithm because it is much faster
clusterless_algorithm = 'multiunit_likelihood'
clusterless_algorithm_params = {
    'mark_std': 1.0,
    'position_std': 12.5,
}

environment = Environment(place_bin_size=np.sqrt(movement_var))

continuous_transition_types = [[RandomWalk(movement_var=movement_var * 120),  Uniform(), Identity()],
                                [Uniform(),                                   Uniform(), Uniform()],
                                [RandomWalk(movement_var=movement_var * 120), Uniform(), Identity()],
                               ]


classifier = ClusterlessClassifier(
    environments=environment,
    continuous_transition_types=continuous_transition_types,
    clusterless_algorithm=clusterless_algorithm,
    clusterless_algorithm_params=clusterless_algorithm_params)
classifier.fit(linear_distance, multiunits)

INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting continuous state transition...
INFO:replay_trajectory_classification.classifier:Fitting discrete state transition
INFO:replay_trajectory_classification.classifier:Fitting multiunits...


In [5]:
from replay_trajectory_classification.clusterless_simulation import make_hover_continuous_hover_replay

state_names = ['continuous', 'fragmented', 'stationary']

replay_time, test_multiunits = make_hover_continuous_hover_replay()

results = classifier.predict(test_multiunits, time=replay_time, state_names=state_names)

INFO:replay_trajectory_classification.classifier:Estimating likelihood...


n_electrodes:   0%|          | 0/5 [00:00<?, ?it/s]

INFO:replay_trajectory_classification.classifier:Estimating causal posterior...
INFO:replay_trajectory_classification.classifier:Estimating acausal posterior...


In [6]:
results.likelihood.isel(time=0)

In [7]:
def viterbi(initial_conditions, continuous_state_transition,
            discrete_state_transition, likelihood):
    """Adaptive filter to iteratively calculate the posterior probability
    of a state variable using past information.

    Parameters
    ----------
    initial_conditions : np.ndarray, shape (n_states, n_bins, 1)
    continuous_state_transition : np.ndarray, shape (n_states, n_states,
                                                  n_bins, n_bins)
    discrete_state_transition : np.ndarray, shape (n_states, n_states)
    likelihood : np.ndarray, shape (n_time, n_states, n_bins, 1)

    Returns
    -------
    causal_posterior : np.ndarray, shape (n_time, n_states, n_bins, 1)

    """
    n_time, n_states, n_bins, _ = likelihood.shape
    path_prob = np.zeros_like(likelihood)
    back_pointer = np.zeros_like(likelihood)

    path_prob[0] = initial_conditions.copy() * likelihood[0]
    path_prob[0] /= np.nansum(path_prob[0])

    for k in np.arange(1, n_time):
        prior = np.zeros((n_states, n_bins, 1))
        for state_k in np.arange(n_states):
            for state_k_1 in np.arange(n_states):
                blah = (
                    discrete_state_transition[state_k_1, state_k] *
                    continuous_state_transition[state_k_1, state_k].T *
                    path_prob[k - 1, state_k_1])
                
                # need argmax/max over all previous state/positions
                back_pointer[state_k, :] = np.argmax(blah, axis=1, keepdims=True)
                prior[state_k, :] = np.max(blah, axis=1, keepdims=True)
                
            path_prob[k] = prior * likelihood[k]
        path_prob[k] /= np.nansum(path_prob[k])
    
    # path back-tracking
    best_path = np.zeros((n_time,), dtype=int)
    best_path[-1] = np.argmax(path_prob[-1])
    
    for k in range(n_time - 1, -1, -1):  
        best_path[k - 1] = back_pointer[k, best_path[k]]

    return np.stack(np.unravel_index(best_path, (n_states, n_bins)))

In [8]:
initial_conditions = classifier.initial_conditions_
continuous_state_transition = classifier.continuous_state_transition_
discrete_state_transition = classifier.discrete_state_transition_
likelihood = results.likelihood.values[..., np.newaxis]

In [50]:
likelihood.shape

(197, 3, 333, 1)

In [97]:
(np.unravel_index(np.asarray([0, 1]), (n_states, n_bins)))

(array([0, 0]), array([0, 1]))

In [78]:
n_time, n_states, n_bins, _ = likelihood.shape

best_path = np.zeros((n_time, 2), dtype=int)

best_path[n_time - 1] = np.unravel_index(np.argmax(likelihood[n_time - 1]), (n_states, n_bins))

In [None]:
# discrete state transition, shape (n_prev_state, n_state)
# continuous state_transition, shape (n_prev_state, n_state, n_prev_position, n_position)
# likelihood, shape (n_time, n_state, n_position, 1)

# discrete state transition * continuous state_transition * likelihood[k, :, :, :]
# take max over all previous states, shape (n_states, n_position, 1)

In [48]:
viterbi(initial_conditions, continuous_state_transition,
            discrete_state_transition, likelihood).shape

(197, 3, 333, 1)

In [9]:
state_probability = initial_conditions.copy() * likelihood[0]
state_probability.shape

(3, 333, 1)

In [19]:
state_k_1 = 0
state_k = 0

blah = discrete_state_transition[state_k_1, state_k] * continuous_state_transition[state_k_1, state_k].T @ state_probability[state_k]
blah.shape

(333, 1)

In [21]:
blah2 = discrete_state_transition[state_k_1, state_k] * continuous_state_transition[state_k_1, state_k].T * state_probability[state_k]
blah2.shape

(333, 333)

In [12]:
for k in np.arange(1, n_time):
    prior = np.zeros((n_states, n_states, 1))
    for state_k in np.arange(n_states):
        for state_k_1 in np.arange(n_states):
            prior[state_k_1, state_k] = np.max(
                discrete_state_transition[state_k_1, state_k] *
                continuous_state_transition[state_k_1, state_k].T @
                posterior[k - 1, state_k_1])
    posterior[k] = np.max(prior, axis=0) * likelihood[k]

[0;31mSignature:[0m [0mnp[0m[0;34m.[0m[0margmax[0m[0;34m([0m[0ma[0m[0;34m,[0m [0maxis[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m [0mout[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m [0;34m*[0m[0;34m,[0m [0mkeepdims[0m[0;34m=[0m[0;34m<[0m[0mno[0m [0mvalue[0m[0;34m>[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Returns the indices of the maximum values along an axis.

Parameters
----------
a : array_like
    Input array.
axis : int, optional
    By default, the index is into the flattened array, otherwise
    along the specified axis.
out : array, optional
    If provided, the result will be inserted into this array. It should
    be of the appropriate shape and dtype.
keepdims : bool, optional
    If this is set to True, the axes which are reduced are left
    in the result as dimensions with size one. With this option,
    the result will broadcast correctly against the array.

    .. versionadded:: 1.22.0

Returns
-------
index_array : nda

NameError: name 'posterior' is not defined