<a href="https://colab.research.google.com/github/ananthakrishnagopal/Computational-Geometry/blob/main/Line_Sweep_Algorithm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import unittest
import math
from dataclasses import dataclass
from enum import Enum
from sortedcontainers import SortedDict


'''

Reference:

Mark de Berg, Computational Geometry : Chapter 2

Algorithm to find all intersections amongst line segments

'''

@dataclass
class Point:
    """Represents a 2D point with x and y coordinates.

    Equality, hashing and ordering are all computed from the coordinates
    rounded to ``_DIGITS`` decimal places. Using a single key for all three
    keeps them mutually consistent: points that compare equal always hash
    equal (as ``set``/``dict`` require) and are never ordered before one
    another, so they collapse to a single entry in sorted containers.
    """
    x: float
    y: float

    # Decimal places kept when comparing/hashing; absorbs floating-point noise
    # in computed coordinates. Unannotated, so it is not a dataclass field.
    _DIGITS = 10

    def _key(self):
        """Return the rounded (x, y) pair shared by ==, hash() and <."""
        return (round(self.x, self._DIGITS), round(self.y, self._DIGITS))

    def __lt__(self, other):
        """Sweep order: larger y first; for the same y, smaller x first."""
        if not isinstance(other, Point):
            return NotImplemented
        self_x, self_y = self._key()
        other_x, other_y = other._key()
        if self_y != other_y:
            return self_y > other_y
        return self_x < other_x

    def __hash__(self):
        """Hash the rounded coordinates (consistent with ``__eq__``)."""
        return hash(self._key())

    def __eq__(self, other):
        """Points are equal when their rounded coordinates are identical."""
        if not isinstance(other, Point):
            return NotImplemented
        return self._key() == other._key()

    def __repr__(self):
        """String representation of the point with 3 decimal places."""
        return f"Point({self.x:.3f}, {self.y:.3f})"


def ccw(A, B, C):
    """Return True when the turn A -> B -> C is strictly counter-clockwise.

    Compares the two products forming the 2D cross product of (B - A) and
    (C - A); collinear points (equal products) give False.
    """
    lhs = (B.x - A.x) * (C.y - A.y)
    rhs = (C.x - A.x) * (B.y - A.y)
    return lhs > rhs


@dataclass
class Segment:
    """Represents a line segment with two endpoints (u, v) and an ID.

    ``__post_init__`` swaps the endpoints if needed so that ``u`` sorts first
    under ``Point.__lt__`` (the upper endpoint; on a horizontal segment, the
    left one). Equality and hashing use only ``id``.
    """
    u: Point
    v: Point
    id: int = 0

    def __post_init__(self):
        # Ensure u is the upper endpoint (higher y, or same y with smaller x)
        if self.v < self.u:
            self.u, self.v = self.v, self.u

    def __hash__(self):
        """Hash based on segment ID for uniqueness."""
        return hash(self.id)

    def __eq__(self, other):
        """Check equality with another segment based on ID."""
        if not isinstance(other, Segment):
            return NotImplemented
        return self.id == other.id

    def __repr__(self):
        """String representation of the segment with endpoints."""
        return f"Seg{self.id}({self.u} -> {self.v})"

    def x_at_y(self, y):
        """Return the x-coordinate where the segment's line crosses height ``y``.

        Horizontal and vertical segments return ``u.x``. Other segments are
        linearly interpolated; ``y`` is not clamped, so values outside the
        segment's y-range are extrapolated along its line.
        """
        if math.isclose(self.u.y, self.v.y):
            # Horizontal segment: return x-coordinate of endpoint
            return self.u.x

        if math.isclose(self.u.x, self.v.x):
            # Vertical segment: return constant x-coordinate
            return self.u.x

        # Linear interpolation for non-horizontal/vertical segments
        t = (y - self.u.y) / (self.v.y - self.u.y)
        return self.u.x + t * (self.v.x - self.u.x)

    def contains_point_interior(self, p):
        """Check if point ``p`` lies strictly inside the segment (excluding endpoints).

        Returns False for the endpoints themselves and for zero-length
        segments, which have no interior.
        """
        if p == self.u or p == self.v:
            return False

        dx = self.v.x - self.u.x
        dy = self.v.y - self.u.y
        length_sq = dx * dx + dy * dy
        if length_sq == 0:
            # Degenerate segment: without this guard every point would pass
            # both the collinearity and the bounding-box test below.
            return False

        # Cross product of (v - u) and (p - u); zero when p is on the line.
        cross_product = (p.y - self.u.y) * dx - (p.x - self.u.x) * dy
        # |cross| / length is p's distance from the line, so scale the
        # tolerance with length_sq (keeping the old 1e-10 floor for short
        # segments) to make long segments no stricter than unit ones.
        if abs(cross_product) > 1e-10 * max(1.0, length_sq):
            return False

        # Check if point is between endpoints in x and y ranges
        min_x, max_x = min(self.u.x, self.v.x), max(self.u.x, self.v.x)
        min_y, max_y = min(self.u.y, self.v.y), max(self.u.y, self.v.y)

        return (min_x < p.x < max_x or math.isclose(min_x, max_x)) and \
               (min_y < p.y < max_y or math.isclose(min_y, max_y))

    def intersects(self, other):
        """Return the point where this segment properly crosses ``other``, else None.

        Returns None for the same segment, for segments sharing an endpoint,
        for (near-)parallel or collinear segments, and when the crossing is
        not strictly inside both segments.
        """
        if self == other:
            return None

        # Check if segments share an endpoint (no intersection reported)
        if self.u == other.u or self.u == other.v or self.v == other.u or self.v == other.v:
            return None

        # Extract coordinates for intersection calculation
        x1, y1 = self.u.x, self.u.y
        x2, y2 = self.v.x, self.v.y
        x3, y3 = other.u.x, other.u.y
        x4, y4 = other.v.x, other.v.y

        # Denominator of the line-intersection formula: |d1| * |d2| * sin(angle).
        denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)

        # Compare against the product of lengths so the parallel test depends
        # only on the angle; an absolute threshold would treat any pair of
        # short segments (denom ~ length**2) as parallel and miss their crossing.
        scale = math.hypot(x1 - x2, y1 - y2) * math.hypot(x3 - x4, y3 - y4)
        if abs(denom) <= 1e-10 * scale:
            return None  # Parallel, collinear, or degenerate

        # Parameters along self (t) and along other (s)
        t = ((x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)) / denom
        s = -((x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)) / denom

        # Check if intersection is strictly inside both segments
        if 0 < t < 1 and 0 < s < 1:
            px = x1 + t * (x2 - x1)
            py = y1 + t * (y2 - y1)
            return Point(px, py)

        return None


class EventType(Enum):
    """Kinds of sweep-line events used by the Bentley-Ottmann sweep.

    An event keeps the kind it was created with; the kind appears in the
    event's printed representation.
    """

    START = 1  # queued at a segment's upper endpoint
    END = 2  # queued at a segment's lower endpoint
    INTERSECTION = 3  # queued for a crossing discovered during the sweep


@dataclass
class Event:
    """One stop of the sweep line.

    Holds the location ``point``, the set of ``segments`` registered at that
    location, and the ``event_type`` the event was first created with.
    """
    point: Point
    segments: set
    event_type: EventType

    def __lt__(self, other):
        """Events order exactly as their points do."""
        mine, theirs = self.point, other.point
        return mine < theirs

    def __repr__(self):
        """Render as ``KIND@point segs=[ids...]``."""
        ids = [segment.id for segment in self.segments]
        return f"{self.event_type.name}@{self.point} segs={ids}"


def find_intersections(segments):
    """Find all intersection points among segments using the Bentley-Ottmann algorithm.

    The sweep follows de Berg et al., *Computational Geometry*, Chapter 2:
    a horizontal sweep line moves from top to bottom, events are processed
    in ``Point`` order (larger y first, then smaller x), and only segments
    that are adjacent just below the sweep line are tested for new
    crossings. A trace of every event and of the status before/after it is
    printed to stdout.

    Args:
        segments: list of ``Segment`` objects. Modified in place: each
            segment's ``id`` is set to its list index and its endpoints are
            re-normalized so that ``u`` is the upper endpoint.

    Returns:
        A set of ``Point`` objects at which two or more segments meet:
        proper crossings, shared endpoints, and endpoints lying on the
        interior of another segment. Overlaps of collinear segments are not
        reported as crossings (``Segment.intersects`` returns None for them).
    """
    # Segments are ordered by their x slightly below the event's y, so that
    # segments crossing at the event appear in their post-crossing order.
    below_offset = 1e-9

    # Assign IDs and normalize segments (ensure u is upper endpoint)
    for i, seg in enumerate(segments):
        seg.id = i
        seg.__post_init__()

    # Event queue keyed by point; each event collects every segment that
    # starts or ends there. Intersection events are added during the sweep.
    event_queue = SortedDict()
    for seg in segments:
        if seg.u not in event_queue:
            event_queue[seg.u] = Event(seg.u, set(), EventType.START)
        event_queue[seg.u].segments.add(seg)

        if seg.v not in event_queue:
            event_queue[seg.v] = Event(seg.v, set(), EventType.END)
        event_queue[seg.v].segments.add(seg)

    # Status: the segments currently crossing the sweep line. Keys are the
    # (x, id) position computed at insertion time and only determine the
    # printed order; neighbor queries re-sort the segments at each event.
    status = SortedDict()
    intersections = set()
    # Read by add_intersection_event; reassigned for every event below.
    current_x = float('-inf')

    def is_horizontal(seg):
        """True if both endpoints of seg share the same y (within isclose)."""
        return math.isclose(seg.u.y, seg.v.y)

    def get_status_key(seg, y):
        """Generate sorting key for a segment based on x-coordinate at y."""
        return (seg.x_at_y(y), seg.id)

    def order_key(seg, below_y, horizontal_x):
        """Left-to-right position of seg just below the sweep line.

        Non-horizontal segments are placed by their x at ``below_y``.
        Horizontal segments lie on the sweep line itself; as in de Berg they
        go after every other segment through the event point, at
        ``horizontal_x``, with the trailing 1 making them lose ties.
        """
        if is_horizontal(seg):
            return (horizontal_x, 1)
        return (seg.x_at_y(below_y), 0)

    def add_intersection_event(seg1, seg2, sweep_y):
        """Queue the crossing of seg1 and seg2 if it lies below the sweep line,
        or on it to the right of the current event point."""
        intersection = seg1.intersects(seg2)
        if intersection is None:
            return

        if intersection.y < sweep_y or (math.isclose(intersection.y, sweep_y) and intersection.x > current_x):
            if intersection not in event_queue:
                event_queue[intersection] = Event(intersection, {seg1, seg2}, EventType.INTERSECTION)
            else:
                event_queue[intersection].segments.update({seg1, seg2})

    # Process events in the queue
    while event_queue:
        current_point, event = event_queue.popitem(0)
        current_y = current_point.y
        current_x = current_point.x
        below_y = current_y - below_offset

        print(f"\nProcessing {event}")
        print(f"Status before: {[f'Seg{s.id}' for s in status.values()]}")

        # Classify segments: U(p) start here, L(p) end here, C(p) contain p
        # in their interior.
        starting_segs = set()
        ending_segs = set()
        interior_segs = set()

        for seg in event.segments:
            if seg.u == current_point:
                starting_segs.add(seg)
            elif seg.v == current_point:
                ending_segs.add(seg)
            elif seg.contains_point_interior(current_point):
                interior_segs.add(seg)

        # C(p) must also include status segments passing through p although p
        # was never registered for them, e.g. when p is an endpoint of another
        # segment lying on their interior (a T-junction). Zero-length segments
        # have no interior and are skipped.
        for seg in status.values():
            if seg.u != seg.v and seg.contains_point_interior(current_point):
                interior_segs.add(seg)

        # Mark intersection if multiple segments meet at this point
        all_segs = starting_segs | ending_segs | interior_segs
        if len(all_segs) > 1:
            intersections.add(current_point)
            print(f"Found intersection at {current_point}")

        # Remove segments that end or pass through this point
        to_remove = ending_segs | interior_segs
        for key in [k for k, v in status.items() if v in to_remove]:
            status.pop(key)

        # (Re-)insert segments that start or continue through this point
        to_add = starting_segs | interior_segs
        for seg in to_add:
            status[get_status_key(seg, below_y)] = seg

        print(f"Status after: {[f'Seg{s.id}' for s in status.values()]}")

        if not to_add:
            # Only lower endpoints here: the nearest segments left and right of
            # p have become adjacent and must be tested against each other.
            positions = [(seg.x_at_y(below_y), seg) for seg in status.values()]
            left = [item for item in positions if item[0] < current_x]
            right = [item for item in positions if item[0] > current_x]
            if left and right:
                left_seg = max(left, key=lambda item: item[0])[1]
                right_seg = min(right, key=lambda item: item[0])[1]
                add_intersection_event(left_seg, right_seg, current_y)
        else:
            # Horizontal segments are placed right of every non-horizontal
            # segment through p (and never left of p itself).
            horizontal_x = max(
                [current_x] + [s.x_at_y(below_y) for s in to_add if not is_horizontal(s)]
            )
            ordered = sorted(status.values(), key=lambda s: order_key(s, below_y, horizontal_x))
            added = [i for i, s in enumerate(ordered) if s in to_add]
            leftmost, rightmost = min(added), max(added)

            # Test the outermost added segments against their outer neighbors
            if leftmost > 0:
                add_intersection_event(ordered[leftmost - 1], ordered[leftmost], current_y)
            if rightmost < len(ordered) - 1:
                add_intersection_event(ordered[rightmost], ordered[rightmost + 1], current_y)

    return intersections


if __name__ == "__main__":
    def _report(found, expectation, ordered=False):
        """Print how many intersections were found, list them, then the expectation."""
        print(f"\nResult: {len(found)} intersections found")
        listing = sorted(found, key=lambda pt: (pt.y, pt.x)) if ordered else found
        for point in listing:
            print(f"  {point}")
        print(expectation)

    # Test Case 1: two segments crossing in an X shape.
    print("=== Test 1: Simple X intersection ===")
    segments1 = [
        Segment(Point(0, 2), Point(2, 0), 0),  # Top-left to bottom-right
        Segment(Point(0, 0), Point(2, 2), 1),  # Bottom-left to top-right
    ]
    intersections1 = find_intersections(segments1)
    _report(intersections1, "Expected: 1 intersection at (1.0, 1.0)")

    # Test Case 2: two parallel horizontal segments.
    print("\n" + "=" * 50)
    print("=== Test 2: No intersection ===")
    segments2 = [
        Segment(Point(0, 3), Point(1, 3), 0),  # Horizontal line at y=3
        Segment(Point(0, 1), Point(1, 1), 1),  # Horizontal line at y=1
    ]
    intersections2 = find_intersections(segments2)
    _report(intersections2, "Expected: 0 intersections")

    # Test Case 3: three segments through one common point.
    print("\n" + "=" * 50)
    print("=== Test 3: Three segments meeting ===")
    segments3 = [
        Segment(Point(0, 0), Point(2, 2), 0),  # Diagonal up-right
        Segment(Point(2, 0), Point(0, 2), 1),  # Diagonal up-left
        Segment(Point(1, 0), Point(1, 2), 2),  # Vertical through center
    ]
    intersections3 = find_intersections(segments3)
    _report(intersections3, "Expected: 1 intersection at (1.0, 1.0)", ordered=True)

=== Test 1: Simple X intersection ===

Processing START@Point(0.000, 2.000) segs=[0]
Status before: []
Status after: ['Seg0']

Processing START@Point(2.000, 2.000) segs=[1]
Status before: ['Seg0']
Status after: ['Seg0', 'Seg1']

Processing INTERSECTION@Point(1.000, 1.000) segs=[0, 1]
Status before: ['Seg0', 'Seg1']
Found intersection at Point(1.000, 1.000)
Status after: ['Seg1', 'Seg0']

Processing END@Point(0.000, 0.000) segs=[1]
Status before: ['Seg1', 'Seg0']
Status after: ['Seg0']

Processing END@Point(2.000, 0.000) segs=[0]
Status before: ['Seg0']
Status after: []

Result: 1 intersections found
  Point(1.000, 1.000)
Expected: 1 intersection at (1.0, 1.0)

=== Test 2: No intersection ===

Processing START@Point(0.000, 3.000) segs=[0]
Status before: []
Status after: ['Seg0']

Processing END@Point(1.000, 3.000) segs=[0]
Status before: ['Seg0']
Status after: []

Processing START@Point(0.000, 1.000) segs=[1]
Status before: []
Status after: ['Seg1']

Processing END@Point(1.000, 1.000) s