In [30]:
import numpy as np
from non_local_detector.core import filter, smoother

# Small synthetic problem; the single-pass filter/smoother results computed
# here are the reference the chunked versions below are checked against.
n_time = 10
n_states = 2
# Seeded generator so Restart & Run All reproduces the same inputs.
rng = np.random.default_rng(seed=0)

initial_distribution = np.ones(n_states) / n_states
# "Sticky" transitions: each state stays put with probability 0.9.
transition_matrix = np.array([[0.9, 0.1], [0.1, 0.9]])
# Arbitrary per-time, per-state log likelihoods (values in [0, 1)).
log_likelihood = rng.random((n_time, n_states))

# Reference: filter and smooth the whole sequence in one pass.
(log_normalizer, predicted_probs_next), (filtered_probs, predicted_probs) = filter(
    initial_distribution, transition_matrix, log_likelihood
)
smoothed_probs_next, smoothed_probs = smoother(transition_matrix, filtered_probs)

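Chunk logic: process time in chunks so that only one chunk of log likelihoods is needed at a time.

1. Split the time indices into chunks and take the first chunk.
2. Compute the log likelihood of the chunk's data.
3. Run the forward filter on the chunk.
4. Get the next chunk.
5. Compute the log likelihood of the chunk's data.
6. Run the forward filter, starting from the prediction carried over from the previous chunk.
7. Run the smoother on the last chunk, starting from its final filtered distribution.
8. Get the previous chunk.
9. Run the smoother, starting from the first smoothed distribution of the chunk after it.
10. Get the previous chunk.
11. Run the smoother (repeat steps 10–11 until the first chunk has been smoothed).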

In [2]:
n_chunks = 2

# Forward filter, one chunk at a time. The first chunk starts from the
# initial distribution; every later chunk starts from the `predicted_probs_next`
# that `filter` returned for the chunk before it.
marginal_likelihood2 = 0.0
filtered_probs2 = []
predicted_probs2 = []

for chunk_id, time_inds in enumerate(np.array_split(np.arange(n_time), n_chunks)):
    print(chunk_id, time_inds)
    initial = initial_distribution if chunk_id == 0 else predicted_probs_next
    (marginal_likelihood_chunk, predicted_probs_next), (
        filtered_probs_chunk,
        predicted_probs_chunk,
    ) = filter(initial, transition_matrix, log_likelihood[time_inds])
    marginal_likelihood2 += marginal_likelihood_chunk
    filtered_probs2.append(filtered_probs_chunk)
    predicted_probs2.append(predicted_probs_chunk)

# Stitch the per-chunk results back into full-length arrays.
filtered_probs2 = np.concatenate(filtered_probs2)
predicted_probs2 = np.concatenate(predicted_probs2)

# Each entry should be True: chunked results match the single-pass reference.
(
    np.allclose(log_normalizer, marginal_likelihood2),
    np.allclose(filtered_probs, filtered_probs2),
    np.allclose(predicted_probs, predicted_probs2),
)

0 [0 1 2 3 4]
1 [5 6 7 8 9]


(True, True, True)

In [4]:
# Backward smoother, one chunk at a time, walking chunks from last to first.
# The last chunk is seeded with the final filtered distribution; each earlier
# chunk is seeded with the first smoothed distribution of the chunk after it.
smoothed_probs2 = []

for chunk_id, time_inds in enumerate(
    reversed(np.array_split(np.arange(n_time), n_chunks))
):
    print(chunk_id, time_inds)
    initial = filtered_probs[-1] if chunk_id == 0 else smoothed_probs_chunk[0]
    _, smoothed_probs_chunk = smoother(
        transition_matrix,
        filtered_probs[time_inds],
        initial=initial,
        ind=time_inds,
        n_time=n_time,
    )
    smoothed_probs2.append(smoothed_probs_chunk)

# Chunks were collected last-to-first; flip back to chronological order.
smoothed_probs3 = np.concatenate(smoothed_probs2[::-1])

np.allclose(smoothed_probs, smoothed_probs3)

0 [5 6 7 8 9]
1 [0 1 2 3 4]


True

Quantities needed to estimate parameters:
- acausal_posterior
- causal_state_probabilities
- predictive_state_probabilities
- acausal_state_probabilities

Quantities that can be discarded:
- log_likelihood (once its chunk has been filtered)
- causal_posterior (once the smoother has run)


Pseudocode:
```python

filtered_probs = []
predictive_state_probabilities = []
causal_state_probabilities = []
marginal_likelihood = 0.0

time_chunks = np.array_split(np.arange(n_time), n_chunks)

for chunk_id, time_inds in enumerate(time_chunks):
    is_missing_chunk = is_missing[time_inds] if is_missing is not None else None
    log_likelihood_chunk = compute_log_likelihood(
        time[time_inds],
        position_time,
        position,
        spike_times,
        is_missing=is_missing_chunk,
    )
    initial = initial_distribution if chunk_id == 0 else predicted_probs_next
    
    (marginal_likelihood_chunk, predicted_probs_next), (
        filtered_probs_chunk,
        predicted_probs_chunk,
    ) = filter(initial, transition_matrix, log_likelihood_chunk)

    filtered_probs.append(filtered_probs_chunk)
    
    causal_state_probabilities.append(convert_to_state_probability(filtered_probs_chunk))
    predictive_state_probabilities.append(convert_to_state_probability(predicted_probs_chunk))

    marginal_likelihood += marginal_likelihood_chunk

filtered_probs = np.concatenate(filtered_probs)
smoothed_probs = []
acausal_state_probabilities = []

for chunk_id, time_inds in enumerate(time_chunks[::-1]):
    initial = filtered_probs[-1] if chunk_id == 0 else smoothed_probs_chunk[0]
    _, smoothed_probs_chunk = smoother(
        transition_matrix,
        filtered_probs[time_inds],
        initial=initial,
        ind=time_inds,
        n_time=n_time,
    )
    smoothed_probs.append(smoothed_probs_chunk)
    acausal_state_probabilities.append(convert_to_state_probability(smoothed_probs_chunk))


return (
    smoothed_probs,
    acausal_state_probabilities,
    causal_state_probabilities,
    predictive_state_probabilities,
)

```


In [63]:
def compute_log_likelihood(time, data, is_missing=None):
    """Toy stand-in for a real log-likelihood model.

    Treats ``data`` as a precomputed log-likelihood table and ``time`` as the
    row indices to pull from it; the selected rows are returned unchanged.
    ``is_missing`` is accepted only to match the intended signature and is
    ignored here.
    """
    selected_rows = data[time]
    return selected_rows


def chunked_filter_smoother(
    time,
    data,
    state_ind,
    initial_distribution,
    transition_matrix,
    is_missing=None,
    n_chunks=1,
):
    """Run the forward filter and backward smoother over time in chunks.

    Log likelihoods are requested from ``compute_log_likelihood`` one time
    chunk at a time, and only the current chunk's log likelihood is kept while
    filtering. The forward pass feeds the ``predicted_probs_next`` that
    ``filter`` returns for one chunk in as the initial distribution of the
    next chunk. The backward pass walks the chunks in reverse. It seeds the last
    chunk with the final filtered distribution, and seeds every earlier chunk
    with the first smoothed distribution of the chunk after it.

    Parameters
    ----------
    time : np.ndarray, shape (n_time,)
        Time values; sliced per chunk and passed to ``compute_log_likelihood``.
    data :
        Passed unchanged to ``compute_log_likelihood``.
    state_ind : np.ndarray of int, shape (n_state_bins,)
        Discrete-state label of each state bin. Assumes the labels are exactly
        ``0 .. n_states - 1`` (a gap in the labels would index past the
        one-hot table below).
    initial_distribution : np.ndarray, shape (n_state_bins,)
        Initial distribution for the first chunk.
    transition_matrix : np.ndarray, shape (n_state_bins, n_state_bins)
    is_missing : np.ndarray, shape (n_time,), optional
        Sliced per chunk and forwarded to ``compute_log_likelihood``.
    n_chunks : int, optional
        Number of chunks to split time into. Chunks that would be empty
        (when ``n_chunks > n_time``) are dropped.

    Returns
    -------
    acausal_posterior : np.ndarray, shape (n_time, n_state_bins)
    acausal_state_probabilities : np.ndarray, shape (n_time, n_states)
    causal_state_probabilities : np.ndarray, shape (n_time, n_states)
    predictive_state_probabilities : np.ndarray, shape (n_time, n_states)
    marginal_likelihood : float
        Sum over chunks of the log normalizer returned by ``filter`` (a log
        marginal likelihood, despite the name).
    """
    n_time = len(time)
    # np.array_split yields empty trailing chunks when n_chunks > n_time. An
    # empty chunk would be smoothed first in the backward pass and leave
    # nothing to seed the next chunk with, so drop them.
    time_chunks = [
        time_inds
        for time_inds in np.array_split(np.arange(n_time), n_chunks)
        if time_inds.size > 0
    ]

    n_states = len(np.unique(state_ind))
    # One-hot map from state bin to discrete state: right-multiplying a
    # posterior by it sums the probability within each discrete state.
    state_mask = np.identity(n_states)[state_ind]  # shape (n_state_bins, n_states)

    causal_posterior = []
    causal_state_probabilities = []
    predictive_state_probabilities = []
    marginal_likelihood = 0.0

    # Forward pass, in time order. `predicted_probs_next` carries the starting
    # distribution for the next chunk.
    predicted_probs_next = initial_distribution
    for time_inds in time_chunks:
        is_missing_chunk = is_missing[time_inds] if is_missing is not None else None
        log_likelihood_chunk = compute_log_likelihood(
            time[time_inds],
            data,
            is_missing=is_missing_chunk,
        )

        (marginal_likelihood_chunk, predicted_probs_next), (
            causal_posterior_chunk,
            predicted_probs_chunk,
        ) = filter(
            initial_distribution=predicted_probs_next,
            transition_matrix=transition_matrix,
            log_likelihoods=log_likelihood_chunk,
        )

        causal_posterior_chunk = np.asarray(causal_posterior_chunk)
        predicted_probs_chunk = np.asarray(predicted_probs_chunk)
        causal_posterior.append(causal_posterior_chunk)
        causal_state_probabilities.append(causal_posterior_chunk @ state_mask)
        predictive_state_probabilities.append(predicted_probs_chunk @ state_mask)

        marginal_likelihood += marginal_likelihood_chunk

    causal_posterior = np.concatenate(causal_posterior)

    # Backward pass, in reverse time order. `smoother_initial` carries the
    # seed for the chunk that precedes the one just smoothed.
    acausal_posterior = []
    acausal_state_probabilities = []
    smoother_initial = causal_posterior[-1]
    for time_inds in reversed(time_chunks):
        _, acausal_posterior_chunk = smoother(
            transition_matrix=transition_matrix,
            filtered_probs=causal_posterior[time_inds],
            initial=smoother_initial,
            ind=time_inds,
            n_time=n_time,
        )
        acausal_posterior_chunk = np.asarray(acausal_posterior_chunk)
        acausal_posterior.append(acausal_posterior_chunk)
        acausal_state_probabilities.append(acausal_posterior_chunk @ state_mask)
        smoother_initial = acausal_posterior_chunk[0]

    # Backward-pass chunks were collected last-to-first; restore time order.
    acausal_posterior = np.concatenate(acausal_posterior[::-1])
    acausal_state_probabilities = np.concatenate(acausal_state_probabilities[::-1])
    # Concatenate the forward-pass outputs as well, so every returned
    # probability array is indexed by time rather than by chunk.
    causal_state_probabilities = np.concatenate(causal_state_probabilities)
    predictive_state_probabilities = np.concatenate(predictive_state_probabilities)

    return (
        acausal_posterior,
        acausal_state_probabilities,
        causal_state_probabilities,
        predictive_state_probabilities,
        marginal_likelihood,
    )

In [65]:
# Validate the chunked routine against the single-pass reference. Use more
# than one chunk: with n_chunks=1 the chunk-to-chunk hand-off in the forward
# and backward passes is never exercised.
(
    acausal_posterior,
    acausal_state_probabilities,
    causal_state_probabilities,
    predictive_state_probabilities,
    marginal_likelihood,
) = chunked_filter_smoother(
    time=np.arange(n_time),
    data=log_likelihood,
    state_ind=np.arange(n_states),
    initial_distribution=initial_distribution,
    transition_matrix=transition_matrix,
    n_chunks=n_chunks,
)

np.allclose(acausal_posterior, smoothed_probs), np.allclose(
    marginal_likelihood, log_normalizer
)

(True, True)