Skip to content

Commit

Permalink
Feat: Cylinder.best_fit (#326)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Andrew Hynes <andrewjhynes@gmail.com>
  • Loading branch information
3 people committed Mar 25, 2023
1 parent 9412103 commit fc9c7c9
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9, "3.10"]
python-version: [3.8, 3.9, "3.10"]

steps:
- uses: actions/checkout@v2
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ include = ["tests"]


[tool.poetry.dependencies]
python = "^3.7"
python = ">=3.8,<3.12"

numpy = "^1.16"
numpy = "^1.17.3"
matplotlib = "^3"
scipy = "^1.8"

# To keep __version__ in sync with the version in this file.
importlib-metadata = { version = "~1", python = "<3.8" }
Expand Down
12 changes: 0 additions & 12 deletions src/skspatial/objects/_base_array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Private base classes for arrays."""
import warnings
from typing import Type
from typing import TypeVar

Expand All @@ -22,17 +21,6 @@ class _BaseArray(np.ndarray, _BaseSpatial):

def __new__(cls: Type[Array], array: array_like) -> Array:

with warnings.catch_warnings():

warnings.filterwarnings("error")

try:
np.array(array)

except np.VisibleDeprecationWarning as error:
if str(error).startswith("Creating an ndarray from ragged nested sequences"):
raise ValueError("The array must not contain sequences with different lengths.")

if np.size(array) == 0:
raise ValueError("The array must not be empty.")

Expand Down
2 changes: 1 addition & 1 deletion src/skspatial/objects/circle.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def intersect_line(self, line: Line) -> Tuple[Point, Point]:
@classmethod
def best_fit(cls, points: array_like) -> Circle:
"""
Return the sphere of best fit for a set of 2D points.
Return the circle of best fit for a set of 2D points.
Parameters
----------
Expand Down
184 changes: 183 additions & 1 deletion src/skspatial/objects/cylinder.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
"""Module for the Cylinder class."""
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from scipy.optimize import minimize

from skspatial._functions import _solve_quadratic
from skspatial.objects._base_spatial import _BaseSpatial
from skspatial.objects._mixins import _ToPointsMixin
from skspatial.objects.line import Line
from skspatial.objects.plane import Plane
from skspatial.objects.point import Point
from skspatial.objects.points import Points
from skspatial.objects.vector import Vector
from skspatial.typing import array_like

Expand Down Expand Up @@ -476,6 +481,166 @@ def to_mesh(self, n_along_axis: int = 100, n_angles: int = 30) -> Tuple[np.ndarr

return X, Y, Z

@classmethod
def best_fit(cls, points: array_like) -> Cylinder:
"""
Return the cylinder of best fit for a set of 3D points.
The points are assumed to lie close to the cylinder surface. The algorithm is not guaranteed to produce a
meaningful solution with random points.
Parameters
----------
points : array_like
Input 3D points. At least six points must be provided.
Returns
-------
Cylinder
The cylinder of best fit.
Raises
------
ValueError
If the points are not 3D.
If there are fewer than six points.
If the points are coplanar.
References
----------
https://www.geometrictools.com/Documentation/LeastSquaresFitting.pdf
https://github.com/xingjiepan/cylinder_fitting
https://github.com/CristianoPizzamiglio/py-cylinder-fitting
Examples
--------
>>> from skspatial.objects import Cylinder
>>> points = [[0, 2, 0], [0, -2, 0], [0, 0, 2], [5, 2, 0], [5, -2, 0], [5, 0, 2]]
>>> cylinder = Cylinder.best_fit(points)
>>> cylinder.point.round()
Point([0., 0., 0.])
>>> cylinder.vector.round()
Vector([5., 0., 0.])
>>> cylinder.radius
2.0
"""

def _best_fit(points_centered: Points, centroid: Point) -> Tuple[Vector, Point, float, float]:
"""Return the cylinder of best fit for a set of 3D points."""
best_fit = minimize(
lambda x: _compute_g(_spherical_to_cartesian(_SphericalCoordinates(x[0], x[1])), points_centered),
x0=_compute_initial_direction(points_centered),
method="Powell",
)
direction = _spherical_to_cartesian(_SphericalCoordinates(best_fit.x[0], best_fit.x[1]))
center = _compute_center(direction, points_centered) + centroid
return direction, center, _compute_radius(direction, points_centered), best_fit.fun

def _compute_initial_direction(points: Points) -> np.ndarray:
"""Compute the initial direction as the best fit line."""
initial_direction = Line.best_fit(points).vector.unit()
spherical_coordinates = _cartesian_to_spherical(*initial_direction)
return np.array([spherical_coordinates.theta, spherical_coordinates.phi])

def _compute_projection_matrix(direction: Vector) -> np.ndarray:

return np.identity(3) - np.dot(np.reshape(direction, (3, 1)), np.reshape(direction, (1, 3)))

def _compute_skew_matrix(direction: Vector) -> np.ndarray:

return np.array(
[
[0.0, -direction[2], direction[1]],
[direction[2], 0.0, -direction[0]],
[-direction[1], direction[0], 0.0],
],
)

def _compute_a_matrix(input_samples: List[np.ndarray]) -> np.ndarray:

return sum(np.dot(np.reshape(sample, (3, 1)), np.reshape(sample, (1, 3))) for sample in input_samples)

def _compute_a_hat_matrix(a_matrix: np.ndarray, skew_matrix: np.ndarray) -> np.ndarray:

return np.dot(skew_matrix, np.dot(a_matrix, np.transpose(skew_matrix)))

def _compute_g(direction: Vector, points: Points) -> float:

projection_matrix = _compute_projection_matrix(direction)
skew_matrix = _compute_skew_matrix(direction)
input_samples = [np.dot(projection_matrix, x) for x in points]
a_matrix = _compute_a_matrix(input_samples)
a_hat_matrix = _compute_a_hat_matrix(a_matrix, skew_matrix)

u = sum(np.dot(sample, sample) for sample in input_samples) / len(points)
v = np.dot(a_hat_matrix, sum(np.dot(sample, sample) * sample for sample in input_samples)) / np.trace(
np.dot(a_hat_matrix, a_matrix),
)
return sum((np.dot(sample, sample) - u - 2 * np.dot(sample, v)) ** 2 for sample in input_samples)

def _compute_center(direction: Vector, points: Points) -> Point:

projection_matrix = _compute_projection_matrix(direction)
skew_matrix = _compute_skew_matrix(direction)
input_samples = [np.dot(projection_matrix, x) for x in points]
a_matrix = _compute_a_matrix(input_samples)
a_hat_matrix = _compute_a_hat_matrix(a_matrix, skew_matrix)

return np.dot(a_hat_matrix, sum(np.dot(sample, sample) * sample for sample in input_samples)) / np.trace(
np.dot(a_hat_matrix, a_matrix),
)

def _compute_radius(direction: Vector, points) -> float:

projection_matrix = _compute_projection_matrix(direction)
center = _compute_center(direction, points)
return np.sqrt(
sum(np.dot(center - point, np.dot(projection_matrix, center - point)) for point in points)
/ len(points),
)

def _cartesian_to_spherical(x: float, y: float, z: float) -> _SphericalCoordinates:
"""Convert cartesian to spherical coordinates."""
theta = np.arccos(z / np.sqrt(x**2 + y**2 + z**2))

if math.isclose(x, 0.0, abs_tol=1e-9) and math.isclose(y, 0.0, abs_tol=1e-9):
phi = 0.0
else:
phi = np.sign(y) * np.arccos(x / np.sqrt(x**2 + y**2))
return _SphericalCoordinates(theta, phi)

def _spherical_to_cartesian(spherical_coordinates: _SphericalCoordinates) -> Vector:
"""Convert spherical to cartesian coordinates."""
theta = spherical_coordinates.theta
phi = spherical_coordinates.phi
return Vector([np.cos(phi) * np.sin(theta), np.sin(phi) * np.sin(theta), np.cos(theta)])

points = Points(points)

if points.dimension != 3:
raise ValueError("The points must be 3D.")

if points.shape[0] < 6:
raise ValueError("There must be at least 6 points.")

if points.are_coplanar():
raise ValueError("The points must not be coplanar.")

points_centered, centroid = points.mean_center(return_centroid=True)
unit_vector, center, radius, _ = _best_fit(points_centered, centroid)
axis = Line(point=center, direction=unit_vector)
points_1d = axis.transform_points(points)
point_a = axis.project_point(points[np.argmin(points_1d)])
length = point_a.distance_point(center) * 2
vector_ab = unit_vector * length

return cls(point_a, vector_ab, radius)

def plot_3d(self, ax_3d: Axes3D, n_along_axis: int = 100, n_angles: int = 30, **kwargs) -> None:
"""
Plot a 3D cylinder.
Expand Down Expand Up @@ -515,6 +680,24 @@ def plot_3d(self, ax_3d: Axes3D, n_along_axis: int = 100, n_angles: int = 30, **
ax_3d.plot_surface(X, Y, Z, **kwargs)


@dataclass
class _SphericalCoordinates:
"""
Spherical coordinates.
Attributes
----------
theta : float
Inclination in radians.
phi : float
Azimuth in radians.
"""

theta: float
phi: float


def _between_cap_planes(cylinder: Cylinder, point: array_like) -> bool:
"""Check if a point lies between the cylinder cap planes."""
plane_base = Plane(cylinder.point, cylinder.vector)
Expand All @@ -528,7 +711,6 @@ def _intersect_line_with_infinite_cylinder(
line: Line,
n_digits: Optional[int],
) -> Tuple[Point, Point]:

p_c = cylinder.point
v_c = cylinder.vector.unit()
r = cylinder.radius
Expand Down
4 changes: 1 addition & 3 deletions tests/unit/objects/test_base_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
)
def test_failure_from_different_lengths(class_spatial, array):

message_expected = "The array must not contain sequences with different lengths."

with pytest.raises(ValueError, match=message_expected):
with pytest.raises(ValueError): # noqa: PT011
class_spatial(array)


Expand Down
33 changes: 33 additions & 0 deletions tests/unit/objects/test_cylinder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from math import isclose
from math import pi
from math import sqrt
Expand All @@ -8,6 +9,7 @@
from skspatial.objects import Line
from skspatial.objects import Point
from skspatial.objects import Points
from skspatial.objects import Vector

LINE_DOES_NOT_INTERSECT_CYLINDER = "The line does not intersect the cylinder."
LINE_MUST_BE_3D = "The line must be 3D."
Expand Down Expand Up @@ -321,3 +323,34 @@ def test_to_points(cylinder, n_along_axis, n_angles, points_expected):
points_unique = Points(array_rounded).unique()

assert points_unique.is_close(points_expected)


@pytest.mark.parametrize(
("points", "vector_expected", "radius_expected"),
[
([[2, 0, 0], [0, 2, 0], [0, -2, 0], [2, 0, 4], [0, 2, 4], [0, -2, 4]], Vector([0, 0, 4]), 2.0),
([[-2, 0, 1], [-2, 1, 0], [-2, -1, 0], [3, 0, 1], [3, 1, 0], [3, -1, 0]], Vector([5, 0, 0]), 1.0),
([[-3, 3, 0], [0, 3, 3], [0, 3, -3], [-3, -12, 0], [0, -12, 3], [0, -12, -3]], Vector([0, -15, 0]), 3.0),
],
)
def test_best_fit(points, vector_expected, radius_expected):

cylinder = Cylinder.best_fit(points)

assert isclose(cylinder.vector.norm(), vector_expected.norm())
assert cylinder.vector.is_parallel(vector_expected)
assert math.isclose(cylinder.radius, radius_expected)


@pytest.mark.parametrize(
("points", "message_expected"),
[
([[1, 0], [-1, 0], [0, 1]], "The points must be 3D."),
([[2, 0, 1], [-2, 0, -3]], "There must be at least 6 points."),
([[0, 0, 1], [1, 1, 1], [2, 1, 1], [3, 3, 1], [4, 4, 1], [5, 5, 1]], "The points must not be coplanar."),
],
)
def test_best_fit_failure(points, message_expected):

with pytest.raises(ValueError, match=message_expected):
Cylinder.best_fit(points)

0 comments on commit fc9c7c9

Please sign in to comment.