Skip to content

Commit

Permalink
[ENH]: allow axis and weights parameters in vonmisesmle
Browse files Browse the repository at this point in the history
Add changes to PR 14533

Move PR changes to right folder

Use ``_length`` to compute the generalized sample length
- This change will allow weighted vonmisesmle and the use of the
  ``axis`` parameter in the correct way

Add test to use axis and weights parameters in `vonmisesmle`
  • Loading branch information
FelipeCybis committed Mar 14, 2023
1 parent 1b41891 commit ddff666
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 10 deletions.
34 changes: 24 additions & 10 deletions astropy/stats/circstats.py
Expand Up @@ -494,15 +494,24 @@ def vtest(data, mu=0.0, axis=None, weights=None):
def _A1inv(x):
# Approximation for _A1inv(x) according R Package 'CircStats'
# See http://www.scienceasia.org/2012.38.n1/scias38_118.pdf, equation (4)
if 0 <= x < 0.53:
return 2.0 * x + x * x * x + (5.0 * x**5) / 6.0
elif x < 0.85:
return -0.4 + 1.39 * x + 0.43 / (1.0 - x)
else:
return 1.0 / (x * x * x - 4.0 * x * x + 3.0 * x)


def vonmisesmle(data, axis=None):
kappa1 = np.where(

Check warning on line 498 in astropy/stats/circstats.py

View check run for this annotation

Codecov / codecov/patch

astropy/stats/circstats.py#L498

Added line #L498 was not covered by tests
np.logical_and(0 <= x, x < 0.53),
2.0 * x + x * x * x + (5.0 * x**5) / 6.0,
0)
kappa2 = np.where(

Check warning on line 502 in astropy/stats/circstats.py

View check run for this annotation

Codecov / codecov/patch

astropy/stats/circstats.py#L502

Added line #L502 was not covered by tests
np.logical_and(0.53 <= x, x < 0.85),
-0.4 + 1.39 * x + 0.43 / (1.0 - x),
0)
kappa3 = np.where(

Check warning on line 506 in astropy/stats/circstats.py

View check run for this annotation

Codecov / codecov/patch

astropy/stats/circstats.py#L506

Added line #L506 was not covered by tests
np.logical_or(x < 0, 0.85 < x),
1.0 / (x * x * x - 4.0 * x * x + 3.0 * x),
0)

return kappa1 + kappa2 + kappa3

Check warning on line 511 in astropy/stats/circstats.py

View check run for this annotation

Codecov / codecov/patch

astropy/stats/circstats.py#L511

Added line #L511 was not covered by tests


def vonmisesmle(data, axis=None, weights=None):
"""Computes the Maximum Likelihood Estimator (MLE) for the parameters of
the von Mises distribution.
Expand All @@ -513,6 +522,11 @@ def vonmisesmle(data, axis=None):
radians whenever ``data`` is ``numpy.ndarray``.
axis : int, optional
Axis along which the mle will be computed.
weights : numpy.ndarray, optional
In case of grouped data, the i-th element of ``weights`` represents a
weighting factor for each group such that ``sum(weights, axis)``
equals the number of observations. See [1]_, remark 1.4, page 22,
for detailed explanation.
Returns
-------
Expand All @@ -538,7 +552,7 @@ def vonmisesmle(data, axis=None):
Circular Statistics (2001)'". 2015.
<https://cran.r-project.org/web/packages/CircStats/CircStats.pdf>
"""
mu = circmean(data, axis=None)
mu = circmean(data, axis=axis, weights=weights)

Check warning on line 555 in astropy/stats/circstats.py

View check run for this annotation

Codecov / codecov/patch

astropy/stats/circstats.py#L555

Added line #L555 was not covered by tests

kappa = _A1inv(np.mean(np.cos(data - mu), axis))
kappa = _A1inv(_length(data, p=1, phi=0., axis=axis, weights=weights))

Check warning on line 557 in astropy/stats/circstats.py

View check run for this annotation

Codecov / codecov/patch

astropy/stats/circstats.py#L557

Added line #L557 was not covered by tests
return mu, kappa
36 changes: 36 additions & 0 deletions astropy/stats/tests/test_circstats.py
Expand Up @@ -146,3 +146,39 @@ def test_vonmisesmle():
data = np.rad2deg(data) * u.deg
answer = np.rad2deg(3.006514) * u.deg
assert_equal(np.around(answer, 3), np.around(vonmisesmle(data)[0], 3))

# testing for weighted vonmisesmle
data = np.array(
[
np.pi/2, np.pi, np.pi/2,
]
) # this data has twice more np.pi/2 than np.pi
# get answer using astropy vonmisesmle to test
answer = vonmisesmle(data)
data_to_weigh = np.array(
[np.pi/2, np.pi]
)
weights = [2, 1]
assert_allclose(answer[0], vonmisesmle(data_to_weigh, weights=weights)[0], atol=1e-5)
assert_allclose(answer[1], vonmisesmle(data_to_weigh, weights=weights)[1], atol=1e-5)

# testing for axis argument (stacking the data from the first test)
data = np.array(
[
[3.3699057, 4.0411630, 0.5014477, 2.6223103, 3.7336524,
1.8136389, 4.1566039, 2.7806317, 2.4672173, 2.8493644],
[3.3699057, 4.0411630, 0.5014477, 2.6223103, 3.7336524,
1.8136389, 4.1566039, 2.7806317, 2.4672173, 2.8493644]
]
)
# answer should be duplicated
answer = (np.array([3.006514, 3.006514]), np.array([1.474132, 1.474132]))
assert_allclose(answer[0], vonmisesmle(data, axis=1)[0], atol=1e-5)
assert_allclose(answer[1], vonmisesmle(data, axis=1)[1], atol=1e-5)

# same test for Quantity
data = np.rad2deg(data) * u.deg
answer = (np.rad2deg(answer[0]) * u.deg, answer[1])
assert_allclose(answer[0], vonmisesmle(data, axis=1)[0], atol=1e-5)
assert_allclose(answer[1], vonmisesmle(data, axis=1)[1], atol=1e-5)

3 changes: 3 additions & 0 deletions docs/changes/stats/14533.feature.rst
@@ -0,0 +1,3 @@
This pull request is to address the lack of "weights" and "axis" parameters on the circstats function ``vonmisesmle``.
The "axis" parameter that existed but was not actually being used allows for fast usage of 2D numpy arrays.
The "weights" parameter is very convenient since it is very common to have data for binned angles.

0 comments on commit ddff666

Please sign in to comment.