Skip to content

Commit

Permalink
Merge pull request #252 from ajhynes7/add_cylinder_intersect_line
Browse files Browse the repository at this point in the history
Add `Cylinder.intersect_line`
  • Loading branch information
ajhynes7 committed Feb 26, 2021
2 parents b35482f + 88520ed commit 662a1cb
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/source/api_reference/skspatial.objects.Cylinder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Methods
:toctree: Cylinder/methods

~skspatial.objects.Cylinder.from_points
~skspatial.objects.Cylinder.intersect_line
~skspatial.objects.Cylinder.is_point_within
~skspatial.objects.Cylinder.length
~skspatial.objects.Cylinder.plot_3d
Expand Down
65 changes: 65 additions & 0 deletions src/skspatial/_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Private functions for some spatial computations."""
from __future__ import annotations

import math
from functools import wraps
from typing import Any
Expand Down Expand Up @@ -66,4 +68,67 @@ def wrapper(*args):
return wrapper


def _solve_quadratic(a: float, b: float, c: float, n_digits: int | None = None) -> np.ndarray:
"""
Solve a quadratic equation.
The equation has the form
.. math:: ax^2 + bx + c = 0
Parameters
----------
a, b, c : float
Coefficients of the quadratic equation.
n_digits : int, optional
Additional keyword passed to :func:`round` (default None).
Returns
-------
np.ndarray
Array containing the two solutions to the quadratic.
Raises
------
ValueError
If the discriminant is negative.
Examples
--------
>>> from skspatial._functions import _solve_quadratic
>>> _solve_quadratic(-1, 1, 1).round(3)
array([ 1.618, -0.618])
>>> _solve_quadratic(0, 1, 1)
Traceback (most recent call last):
...
ValueError: The coefficient `a` must be non-zero.
>>> _solve_quadratic(1, 1, 1)
Traceback (most recent call last):
...
ValueError: The discriminant must not be negative.
"""
if n_digits:
a = round(a, n_digits)
b = round(b, n_digits)
c = round(c, n_digits)

if a == 0:
raise ValueError("The coefficient `a` must be non-zero.")

discriminant = b ** 2 - 4 * a * c

if discriminant < 0:
raise ValueError("The discriminant must not be negative.")

pm = np.array([-1, 1]) # Array to compute minus/plus.

X = (-b + pm * math.sqrt(discriminant)) / (2 * a)

return X


_allclose = np.vectorize(math.isclose)
2 changes: 1 addition & 1 deletion src/skspatial/objects/circle.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def intersect_line(self, line: Line) -> Tuple[Point, Point]:

root = math.sqrt(discriminant)

pm = np.array([-1, 1]) # Array to compute plus/minus.
pm = np.array([-1, 1]) # Array to compute minus/plus.
sign = -1 if d_y < 0 else 1

coords_x = (determinant * d_y + pm * sign * d_x * root) / d_r_squared
Expand Down
95 changes: 95 additions & 0 deletions src/skspatial/objects/cylinder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Module for the Cylinder class."""
from __future__ import annotations

from typing import Tuple

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from skspatial._functions import _solve_quadratic
from skspatial.objects._base_spatial import _BaseSpatial
from skspatial.objects._mixins import _ToPointsMixin
from skspatial.objects.line import Line
Expand Down Expand Up @@ -252,6 +255,98 @@ def is_point_within(self, point: array_like) -> bool:

return within_radius and within_planes

def intersect_line(self, line: Line, n_digits: int | None = None) -> Tuple[Point, Point]:
"""
Intersect the cylinder with a 3D line.
This method treats the cylinder as infinite along its axis (i.e., without caps).
Parameters
----------
line : Line
Input 3D line.
n_digits : int, optional
Additional keywords passed to :func:`round`.
This is used to round the coefficients of the quadratic equation.
Returns
-------
point_a, point_b: Point
The two intersection points of the line
with the infinite cylinder, if they exist.
Raises
------
ValueError
If the line is not 3D,
or if it does not intersect the cylinder at one or two points.
References
----------
https://mrl.cs.nyu.edu/~dzorin/rendering/lectures/lecture3/lecture3.pdf
Examples
--------
>>> from skspatial.objects import Line, Cylinder
>>> cylinder = Cylinder([0, 0, 0], [0, 0, 1], 1)
>>> line = Line([0, 0, 0], [1, 0, 0])
>>> cylinder.intersect_line(line)
(Point([-1., 0., 0.]), Point([1., 0., 0.]))
>>> line = Line([1, 2, 3], [1, 2, 3])
>>> point_a, point_b = cylinder.intersect_line(line)
>>> point_a.round(3)
Point([-0.447, -0.894, -1.342])
>>> point_b.round(3)
Point([0.447, 0.894, 1.342])
>>> cylinder = Cylinder([0, 0, 0], [0, 0, 1], 1)
>>> cylinder.intersect_line(Line([0, 0], [1, 2]))
Traceback (most recent call last):
...
ValueError: The line must be 3D.
>>> cylinder.intersect_line(Line([0, 0, 2], [0, 0, 1]))
Traceback (most recent call last):
...
ValueError: The line does not intersect the cylinder.
>>> cylinder.intersect_line(Line([2, 0, 0], [0, 1, 1]))
Traceback (most recent call last):
...
ValueError: The line does not intersect the cylinder.
"""
if line.dimension != 3:
raise ValueError("The line must be 3D.")

p_c = self.point
v_c = self.vector.unit()
r = self.radius

p_l = line.point
v_l = line.vector.unit()

delta_p = Vector.from_points(p_c, p_l)

a = (v_l - v_l.dot(v_c) * v_c).norm() ** 2
b = 2 * (v_l - v_l.dot(v_c) * v_c).dot(delta_p - delta_p.dot(v_c) * v_c)
c = (delta_p - delta_p.dot(v_c) * v_c).norm() ** 2 - r ** 2

try:
X = _solve_quadratic(a, b, c, n_digits=n_digits)
except ValueError:
raise ValueError("The line does not intersect the cylinder.")

point_a, point_b = p_l + X.reshape(-1, 1) * v_l

return point_a, point_b

def to_mesh(self, n_along_axis: int = 100, n_angles: int = 30) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Return coordinate matrices for the 3D surface of the cylinder.
Expand Down
2 changes: 1 addition & 1 deletion src/skspatial/objects/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def intersect_line(self, line: Line) -> Tuple[Point, Point]:
if discriminant < 0:
raise ValueError("The line does not intersect the sphere.")

pm = np.array([-1, 1]) # Array to compute plus/minus.
pm = np.array([-1, 1]) # Array to compute minus/plus.
distances = -dot + pm * math.sqrt(discriminant)

point_a, point_b = line.point + distances.reshape(-1, 1) * vector_unit
Expand Down
132 changes: 123 additions & 9 deletions tests/unit/test_intersection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import math
from math import sqrt

import numpy as np
import pytest
from skspatial.objects import Circle
from skspatial.objects import Cylinder
from skspatial.objects import Line
from skspatial.objects import Plane
from skspatial.objects import Point
from skspatial.objects import Sphere


Expand Down Expand Up @@ -72,9 +74,21 @@ def test_intersect_line_plane_failure(line, plane):
@pytest.mark.parametrize(
"plane_a, plane_b, line_expected",
[
(Plane([0, 0, 0], [0, 0, 1]), Plane([0, 0, 0], [1, 0, 0]), Line([0, 0, 0], [0, 1, 0])),
(Plane([0, 0, 0], [0, 0, 1]), Plane([0, 0, 1], [1, 0, 1]), Line([1, 0, 0], [0, 1, 0])),
(Plane([0, 0, 0], [-1, 1, 0]), Plane([8, 0, 0], [1, 1, 0]), Line([4, 4, 0], [0, 0, -1])),
(
Plane([0, 0, 0], [0, 0, 1]),
Plane([0, 0, 0], [1, 0, 0]),
Line([0, 0, 0], [0, 1, 0]),
),
(
Plane([0, 0, 0], [0, 0, 1]),
Plane([0, 0, 1], [1, 0, 1]),
Line([1, 0, 0], [0, 1, 0]),
),
(
Plane([0, 0, 0], [-1, 1, 0]),
Plane([8, 0, 0], [1, 1, 0]),
Line([4, 4, 0], [0, 0, -1]),
),
],
)
def test_intersect_planes(plane_a, plane_b, line_expected):
Expand Down Expand Up @@ -104,7 +118,12 @@ def test_intersect_planes_failure(plane_a, plane_b):
(Circle([0, 0], 1), Line([0, 0], [1, 0]), [-1, 0], [1, 0]),
(Circle([0, 0], 1), Line([0, 0], [0, 1]), [0, -1], [0, 1]),
(Circle([0, 0], 1), Line([0, 1], [1, 0]), [0, 1], [0, 1]),
(Circle([0, 0], 1), Line([0, 0.5], [1, 0]), [-math.sqrt(3) / 2, 0.5], [math.sqrt(3) / 2, 0.5]),
(
Circle([0, 0], 1),
Line([0, 0.5], [1, 0]),
[-sqrt(3) / 2, 0.5],
[sqrt(3) / 2, 0.5],
),
(Circle([1, 0], 1), Line([0, 0], [1, 0]), [0, 0], [2, 0]),
],
)
Expand Down Expand Up @@ -140,14 +159,14 @@ def test_intersect_circle_line_failure(circle, line):
(
Sphere([0, 0, 0], 1),
Line([0, 0, 0], [1, 1, 0]),
-math.sqrt(2) / 2 * np.array([1, 1, 0]),
math.sqrt(2) / 2 * np.array([1, 1, 0]),
-sqrt(2) / 2 * np.array([1, 1, 0]),
sqrt(2) / 2 * np.array([1, 1, 0]),
),
(
Sphere([0, 0, 0], 1),
Line([0, 0, 0], [1, 1, 1]),
-math.sqrt(3) / 3 * np.ones(3),
math.sqrt(3) / 3 * np.ones(3),
-sqrt(3) / 3 * np.ones(3),
sqrt(3) / 3 * np.ones(3),
),
(Sphere([1, 0, 0], 1), Line([0, 0, 0], [1, 0, 0]), [0, 0, 0], [2, 0, 0]),
(Sphere([0, 0, 0], 1), Line([1, 0, 0], [0, 0, 1]), [1, 0, 0], [1, 0, 0]),
Expand Down Expand Up @@ -175,3 +194,98 @@ def test_intersect_sphere_line_failure(sphere, line):

with pytest.raises(Exception):
sphere.intersect_line(line)


@pytest.mark.parametrize(
"cylinder, line, array_expected_a, array_expected_b",
[
(
Cylinder([0, 0, 0], [0, 0, 1], 1),
Line([0, 0, 0], [1, 0, 0]),
[-1, 0, 0],
[1, 0, 0],
),
(
Cylinder([0, 0, 0], [0, 0, 1], 1),
Line([0, 0, 0.5], [1, 0, 0]),
[-1, 0, 0.5],
[1, 0, 0.5],
),
(
Cylinder([0, 0, 0], [0, 0, 1], 2),
Line([0, 0, 0], [1, 0, 0]),
[-2, 0, 0],
[2, 0, 0],
),
(
Cylinder([0, 0, 0], [0, 0, 5], 1),
Line([0, 0, 0], [1, 0, 0]),
[-1, 0, 0],
[1, 0, 0],
),
(
Cylinder([0, 0, 0], [0, 0, 1], 1),
Line([0, 0, 0], [1, 1, 0]),
[-sqrt(2) / 2, -sqrt(2) / 2, 0],
[sqrt(2) / 2, sqrt(2) / 2, 0],
),
(
Cylinder([0, 0, 0], [0, 0, 1], 1),
Line([0, 0, 0], [1, 1, 1]),
3 * [-sqrt(2) / 2],
3 * [sqrt(2) / 2],
),
(
Cylinder([0, 0, 0], [0, 0, 1], 1),
Line([0, -1, 0], [1, 0, 0]),
[0, -1, 0],
[0, -1, 0],
),
(
Cylinder([0, 0, 0], [0, 0, 1], 1),
Line([0, 1, 0], [1, 0, 0]),
[0, 1, 0],
[0, 1, 0],
),
(
Cylinder([1, 0, 0], [0, 0, 1], 1),
Line([0, -1, 0], [1, 0, 0]),
[1, -1, 0],
[1, -1, 0],
),
],
)
def test_intersect_cylinder_line(cylinder, line, array_expected_a, array_expected_b):

point_a, point_b = cylinder.intersect_line(line, n_digits=9)

point_expected_a = Point(array_expected_a)
point_expected_b = Point(array_expected_b)

assert point_a.is_close(point_expected_a)
assert point_b.is_close(point_expected_b)


@pytest.mark.parametrize(
"cylinder, line",
[
(
Cylinder([0, 0, 0], [0, 0, 1], 1),
Line([0, -2, 0], [1, 0, 0]),
),
(
Cylinder([0, 0, 0], [0, 0, 1], 1),
Line([0, -2, 0], [1, 0, 1]),
),
(
Cylinder([3, 10, 4], [-1, 2, -3], 3),
Line([0, -2, 0], [1, 0, 1]),
),
],
)
def test_intersect_cylinder_line_failure(cylinder, line):

message_expected = "The line does not intersect the cylinder."

with pytest.raises(ValueError, match=message_expected):
cylinder.intersect_line(line)

0 comments on commit 662a1cb

Please sign in to comment.