Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
edeno committed Dec 12, 2017
1 parent 2c46f09 commit 55cc4bb
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 31 deletions.
5 changes: 3 additions & 2 deletions replay_classification/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def poisson_log_likelihood(is_spike, conditional_intensity=None,
'''
probability_no_spike = -conditional_intensity * time_bin_size
is_spike = atleast_kd(is_spike, conditional_intensity.ndim)
eps = np.spacing(1)
return (np.log(conditional_intensity + eps) * is_spike +
conditional_intensity[
np.isclose(conditional_intensity, 0.0)] = np.spacing(1)
return (np.log(conditional_intensity) * is_spike +
probability_no_spike)
10 changes: 5 additions & 5 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ def test_normalize_to_probability():


@mark.parametrize('data, exponent, expected', [
(np.arange(1, 9), 1, np.nanprod(np.arange(1, 9))),
(np.arange(1, 9), 2, np.nanprod(np.arange(1, 9) ** 2)), # test kwarg
(np.arange(1, 9).reshape(2, 4), 2, # test product along 1st dimension
np.nanprod(np.arange(1, 9).reshape(2, 4) ** 2, axis=0)),
(2, 2, 4), # test single data point
(np.arange(1, 9), 1, 1),
(np.arange(1, 9), 2, 1), # test kwarg
(np.array([[0.2, 0.4], [0.1, 0.2]]), 2,
np.array([np.exp(-0.15), 1])),
(2, 2, 1), # test single data point
])
def test_combined_likelihood(data, exponent, expected):
def likelihood_function(x, exponent=1):
Expand Down
37 changes: 13 additions & 24 deletions tests/test_sorted_spikes.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,23 @@
import numpy as np
from pytest import mark

from replay_classification.sorted_spikes import poisson_likelihood
from replay_classification.sorted_spikes import poisson_log_likelihood

CONDITIONAL_INTENSITY = np.atleast_2d(np.array([0.20, 0.50, 0.20, 0.25]))
SPIKE_LOG_LIKELIHOOD = np.log(CONDITIONAL_INTENSITY) - CONDITIONAL_INTENSITY


@mark.parametrize('is_spike, expected_likelihood', [
(np.zeros(3,), np.array([[5, 2, 5, 4],
[5, 2, 5, 4],
[5, 2, 5, 4],
])),
(np.array([0, 1, 0]), np.array([[5, 2, 5, 4],
[5 * np.log(0.2), 2 * np.log(0.5),
5 * np.log(0.2), 4 * np.log(0.25)],
[5, 2, 5, 4],
])),
(np.ones(3,), np.array([[5 * np.log(0.2), 2 * np.log(0.5),
5 * np.log(0.2), 4 * np.log(0.25)],
[5 * np.log(0.2), 2 * np.log(0.5),
5 * np.log(0.2), 4 * np.log(0.25)],
[5 * np.log(0.2), 2 * np.log(0.5),
5 * np.log(0.2), 4 * np.log(0.25)],
])),
(np.zeros(3,), -CONDITIONAL_INTENSITY),
(np.array([0, 1, 0]), np.concatenate(
(-CONDITIONAL_INTENSITY,
SPIKE_LOG_LIKELIHOOD,
-CONDITIONAL_INTENSITY))),
(np.ones(3,), SPIKE_LOG_LIKELIHOOD * np.ones((3, 1))),
])
def test_poisson_likelihood_is_spike(is_spike, expected_likelihood):
conditional_intensity = np.array(
[[np.log(0.2), np.log(0.5), np.log(0.2), np.log(0.25)],
[np.log(0.2), np.log(0.5), np.log(0.2), np.log(0.25)],
[np.log(0.2), np.log(0.5), np.log(0.2), np.log(0.25)]
])
likelihood = poisson_likelihood(
is_spike, conditional_intensity=conditional_intensity,
ci = CONDITIONAL_INTENSITY * np.ones((3, 1))
likelihood = poisson_log_likelihood(
is_spike, conditional_intensity=ci,
time_bin_size=1)
assert np.allclose(likelihood, expected_likelihood)

0 comments on commit 55cc4bb

Please sign in to comment.