Skip to content

Commit

Permalink
Merge pull request #561 from ego-thales/develop
Browse files Browse the repository at this point in the history
`FDataGrid.restrict` option `with_bounds`
  • Loading branch information
vnmabus committed Aug 17, 2023
2 parents 6a066f7 + e1dfe7d commit 02d26bc
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 20 deletions.
43 changes: 23 additions & 20 deletions skfda/representation/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,12 +1082,16 @@ def copy( # noqa: WPS211
def restrict(
self: T,
domain_range: DomainRangeLike,
*,
with_bounds: bool = False,
) -> T:
"""
Restrict the functions to a new domain range.
Args:
domain_range: New domain range.
with_bounds: Whether or not to ensure domain boundaries
appear in `grid_points`.
Returns:
Restricted function.
Expand All @@ -1101,30 +1105,29 @@ def restrict(
for ((a, b), (c, d)) in zip(domain_range, self.domain_range)
)

index_list = []
new_grid_points = []

# Eliminate points outside the new range.
for dr, grid_points in zip(
domain_range,
self.grid_points,
):
keep_index = (
(dr[0] <= grid_points)
& (grid_points <= dr[1])
)

index_list.append(keep_index)

new_grid_points.append(
grid_points[keep_index],
)

data_matrix = self.data_matrix[(slice(None),) + tuple(index_list)]
slice_list = []
for (a, b), dim_points in zip(domain_range, self.grid_points):
ia = np.searchsorted(dim_points, a)
ib = np.searchsorted(dim_points, b, 'right')
slice_list.append(slice(ia, ib))
grid_points = [g[s] for g, s in zip(self.grid_points, slice_list)]
data_matrix = self.data_matrix[(slice(None),) + tuple(slice_list)]

# Ensure that boundaries are in grid_points.
if with_bounds:
# Update `grid_points`
for dim, (a, b) in enumerate(domain_range):
dim_points = grid_points[dim]
left = [a] if a < dim_points[0] else []
right = [b] if b > dim_points[-1] else []
grid_points[dim] = np.concatenate((left, dim_points, right))
# Evaluate
data_matrix = self(grid_points, grid=True)

return self.copy(
domain_range=domain_range,
grid_points=new_grid_points,
grid_points=grid_points,
data_matrix=data_matrix,
)

Expand Down
18 changes: 18 additions & 0 deletions skfda/tests/test_grid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test FDataGrid behaviour."""
import unittest
from typing import Sequence, Tuple

import numpy as np
import scipy.stats.mstats
Expand Down Expand Up @@ -337,6 +338,23 @@ def test_evaluate_grid_unaligned(self) -> None:

np.testing.assert_allclose(res, expected)

def test_restrict(self) -> None:
"""Test FDataGrid.restrict with bounds."""
# Test 1 sample function R^3 -> R^5.
grid_points = ([0, 1], [0, 1, 2], [0, 1, 2, 3])
data_matrix = np.ones((1, 2, 3, 4, 5))
fd = FDataGrid(data_matrix, grid_points)
restricted_domain = ((0, 1), (0.5, 1.5), (0.5, 2))
fd_restricted = fd.restrict(restricted_domain, with_bounds=True)
res = fd_restricted.grid_points
expected: Tuple[Sequence[float], ...] = (
[0, 1],
[0.5, 1, 1.5],
[0.5, 1, 2],
)
for r, e in zip(res, expected):
np.testing.assert_array_equal(r, e)


if __name__ == '__main__':
unittest.main()

0 comments on commit 02d26bc

Please sign in to comment.