<a href="https://colab.research.google.com/github/ananthakrishnagopal/Computational-Geometry/blob/main/Line_Sweep_Algorithm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Algorithm Structure:

1. Create events for all segment endpoints
2. Sort events by x-coordinate  
3. Maintain active segments in y-order
4. Process each event:
   - Add/remove segments from active set
   - Check adjacent segments for intersections
   - Add new intersection events as discovered

Data Structures:

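- Event queue: Min-heap of events ordered by x-coordinate (sweep position); events at the same x are ordered start, then intersection, then end
- Active segments: Sorted list (standing in for a balanced BST) ordered by y-coordinate at the current x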
- Result set: Store found intersection points

In [1]:
import math
import heapq
from dataclasses import dataclass
from typing import Optional, List, Set, Tuple
from enum import Enum
from sortedcontainers import SortedList

EPS = 1e-9


@dataclass
class Point:
    """Point in the plane with tolerance-based equality.

    Two points compare equal when both coordinates agree within EPS, either
    relatively or absolutely. The hash rounds each coordinate to 9 decimal
    places; two points that compare equal can still round differently and
    therefore hash differently, in which case sets and dicts treat them as
    distinct entries.
    """

    x: float
    y: float

    def __hash__(self):
        return hash((round(self.x, 9), round(self.y, 9)))

    def __eq__(self, other):
        if not isinstance(other, Point):
            return False
        # rel_tol alone scales the tolerance with the magnitude of the
        # operands, so it shrinks to nothing next to 0.0; abs_tol keeps
        # coordinates such as 0.0 and 1e-12 equal.
        return math.isclose(
            self.x, other.x, rel_tol=EPS, abs_tol=EPS
        ) and math.isclose(self.y, other.y, rel_tol=EPS, abs_tol=EPS)

    def __repr__(self):
        # One decimal place only: the repr is for display, not round-tripping.
        return f"Point({self.x:.1f}, {self.y:.1f})"


@dataclass
class Segment:
    """Line segment carrying an integer id.

    After construction the endpoints are ordered so that ``start.x`` is not
    greater than ``end.x``.
    """

    start: Point
    end: Point
    id: int

    def __post_init__(self):
        # Normalise the direction: the left endpoint becomes ``start``.
        if self.end.x < self.start.x:
            self.start, self.end = self.end, self.start

    def y_at_x(self, x: float) -> float:
        """Return the y-coordinate of the segment's supporting line at ``x``.

        ``x`` is not clamped to the segment's x-range. When both endpoints
        have (nearly) the same x, ``start.y`` is returned.
        """
        left, right = self.start, self.end
        if math.isclose(left.x, right.x, rel_tol=EPS):
            return left.y
        t = (x - left.x) / (right.x - left.x)
        return left.y + t * (right.y - left.y)

    def intersects(self, other: "Segment") -> Optional[Point]:
        """Return the point where this segment meets ``other``, or None.

        Endpoints count as part of both segments. Parallel segments,
        including collinear overlapping ones, yield None.
        """
        x1, y1 = self.start.x, self.start.y
        x2, y2 = self.end.x, self.end.y
        x3, y3 = other.start.x, other.start.y
        x4, y4 = other.end.x, other.end.y

        denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
        if abs(denom) < EPS:
            # (Near-)parallel directions: no single crossing point.
            return None

        # t runs along this segment, u along ``other``; both must lie in
        # [0, 1] for the crossing to fall on the segments themselves.
        t = ((x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)) / denom
        u = -((x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)) / denom

        if not (0 <= t <= 1 and 0 <= u <= 1):
            return None
        return Point(x1 + t * (x2 - x1), y1 + t * (y2 - y1))

    def __repr__(self):
        return f"Seg{self.id}({self.start} -> {self.end})"


class EventType(Enum):
    """Kinds of sweep events.

    The numeric values double as a tie-break rank: Event.__lt__ compares
    ``type.value`` when two events have (approximately) the same x, so START
    sorts before INTERSECTION, which sorts before END.
    """

    START = 0
    INTERSECTION = 1
    END = 2


@dataclass
class Event:
    """A sweep event at position ``x``.

    ``segments`` lists the segments involved (one for START/END events, two
    for INTERSECTION events as created by LineSweep); ``point`` holds the
    intersection point when there is one.
    """

    x: float
    type: EventType
    segments: List[Segment]
    point: Optional[Point] = None

    def __lt__(self, other):
        """Order by x; at (nearly) equal x, order by the event type's value."""
        same_position = math.isclose(self.x, other.x, rel_tol=EPS)
        if same_position:
            return self.type.value < other.type.value
        return self.x < other.x


class EventQueue:
    """Min-heap of pending events, ordered by the events' own ``<``."""

    def __init__(self):
        # Plain list maintained as a binary heap by heapq.
        self.events = []

    def add_event(self, event: Event):
        """Push ``event`` onto the heap."""
        heapq.heappush(self.events, event)

    def pop_next(self) -> Optional[Event]:
        """Remove and return the smallest event, or None when empty."""
        if self.is_empty():
            return None
        return heapq.heappop(self.events)

    def is_empty(self) -> bool:
        """Return True when no events are pending."""
        return len(self.events) == 0


class ActiveSegments:
    """Segments currently crossed by the sweep line, ordered bottom to top.

    Members are kept in a SortedList keyed by ``(y at current_x, id)``.
    Because the key depends on ``current_x``, the list is only searchable
    while every member was ordered at the present ``current_x``: lookups
    recompute the key of the value being searched for and compare it with
    the members' keys. ``update_x`` keeps that invariant.
    """

    def __init__(self):
        self.current_x = 0
        self.segments = SortedList(key=self._sort_key)

    def _sort_key(self, seg: Segment):
        # The id breaks ties between segments with the same y at current_x.
        return (seg.y_at_x(self.current_x), seg.id)

    def update_x(self, x: float):
        """Move the sweep line to ``x`` and re-order the members for it.

        The list is rebuilt whenever ``x`` differs from the current position
        at all. Skipping the rebuild for a merely "close" x would leave the
        members ordered for the old position while lookups use the new one,
        so ``remove`` and ``get_neighbors`` could fail to find a segment
        that is in fact present.
        """
        if x == self.current_x:
            return
        members = list(self.segments)
        self.current_x = x
        self.segments.clear()
        self.segments.update(members)

    def add(self, seg: Segment):
        """Insert ``seg`` at its position for the current sweep x."""
        self.segments.add(seg)

    def remove(self, seg: Segment):
        """Remove ``seg`` if present; do nothing otherwise."""
        try:
            self.segments.remove(seg)
        except ValueError:
            # Not a member: removal is best-effort by design.
            pass

    def get_neighbors(self, seg: Segment) -> Tuple[Optional[Segment], Optional[Segment]]:
        """Return ``(above, below)``, the members adjacent to ``seg``.

        Either entry is None at the corresponding end of the list; both are
        None when ``seg`` is not a member.
        """
        try:
            idx = self.segments.index(seg)
        except ValueError:
            return None, None
        above = self.segments[idx + 1] if idx + 1 < len(self.segments) else None
        below = self.segments[idx - 1] if idx > 0 else None
        return above, below

    def get_topmost_bottommost(self, segs: List[Segment]) -> Tuple[Optional[Segment], Optional[Segment]]:
        """Return ``(top, bottom)`` of ``segs`` by y at the current sweep x.

        Returns ``(None, None)`` for an empty list. ``segs`` need not be
        members of the active set.
        """
        if not segs:
            return None, None
        ordered = sorted(segs, key=lambda s: s.y_at_x(self.current_x))
        return ordered[-1], ordered[0]


class LineSweep:
    """Left-to-right sweep that collects segment intersection points.

    Typical use: ``LineSweep(segments).find_intersections()``.
    """

    def __init__(self, segments: List[Segment]):
        self.segments = segments
        self.event_queue = EventQueue()
        self.active = ActiveSegments()
        self.intersections: Set[Point] = set()
        # Id pairs whose intersection has already been queued or recorded.
        self.processed_pairs = set()

    def _initialize_events(self):
        """Queue one START and one END event for every input segment."""
        for seg in self.segments:
            self.event_queue.add_event(Event(seg.start.x, EventType.START, [seg]))
            self.event_queue.add_event(Event(seg.end.x, EventType.END, [seg]))

    def _handle_start(self, event: Event):
        """Insert the starting segment and test it against both neighbours."""
        seg = event.segments[0]
        self.active.add(seg)
        above, below = self.active.get_neighbors(seg)
        if above is not None:
            self._check_intersection(seg, above)
        if below is not None:
            self._check_intersection(seg, below)

    def _handle_end(self, event: Event):
        """Remove the ending segment and test its former neighbours."""
        seg = event.segments[0]
        # Neighbours must be looked up while the segment is still a member.
        above, below = self.active.get_neighbors(seg)
        self.active.remove(seg)
        if above is not None and below is not None:
            self._check_intersection(above, below)

    def _handle_intersection(self, event: Event):
        """Record the crossing and re-insert its segments in swapped order."""
        if event.point is not None:
            self.intersections.add(event.point)

        for seg in event.segments:
            self.active.remove(seg)

        # Re-insert just to the right of the crossing so the segments are
        # ordered the way they leave it rather than the way they entered.
        self.active.update_x(event.x + EPS)
        for seg in event.segments:
            self.active.add(seg)

        top, bottom = self.active.get_topmost_bottommost(event.segments)
        if top is not None:
            above, _ = self.active.get_neighbors(top)
            if above is not None:
                self._check_intersection(top, above)
        if bottom is not None and bottom != top:
            _, below = self.active.get_neighbors(bottom)
            if below is not None:
                self._check_intersection(bottom, below)

    def _check_intersection(self, s1: Segment, s2: Segment):
        """Test one pair of segments and queue or record their intersection.

        An intersection strictly ahead of the sweep line becomes an
        INTERSECTION event. One at or behind the sweep line (for example a
        segment that starts exactly on another segment) can no longer be
        scheduled, so it is added to the result directly instead of being
        dropped.
        """
        # Segment.intersects derives the coordinates from the receiver's
        # endpoints, so always call it on the lower-id segment: the same pair
        # then yields the same floats however it was discovered. This is
        # also the call order brute_force uses when ids grow with list
        # position, as parse_segments assigns them.
        first, second = (s1, s2) if s1.id <= s2.id else (s2, s1)
        pair = (first.id, second.id)
        if pair in self.processed_pairs:
            return

        inter = first.intersects(second)
        if inter is None:
            return

        if inter.x > self.active.current_x + EPS:
            self.event_queue.add_event(
                Event(inter.x, EventType.INTERSECTION, [s1, s2], inter)
            )
        else:
            self.intersections.add(inter)
        self.processed_pairs.add(pair)

    def _process_event(self, event: Event):
        """Advance the sweep line to the event and dispatch on its type."""
        self.active.update_x(event.x)
        if event.type == EventType.START:
            self._handle_start(event)
        elif event.type == EventType.END:
            self._handle_end(event)
        else:
            self._handle_intersection(event)

    def find_intersections(self) -> Set[Point]:
        """Run the sweep over all segments and return the points found."""
        self._initialize_events()
        while not self.event_queue.is_empty():
            self._process_event(self.event_queue.pop_next())
        return self.intersections


def parse_segments(text: str) -> List[Segment]:
    """Parse segments from text with one ``x1 y1 x2 y2`` record per line.

    Lines that do not consist of exactly four whitespace-separated numbers
    are skipped. Each segment's id is its 1-based line number within the
    stripped text, so skipped lines leave gaps in the id sequence.
    """
    parsed = []
    for line_no, raw in enumerate(text.strip().splitlines(), start=1):
        fields = raw.split()
        if len(fields) != 4:
            continue
        try:
            coords = [float(tok) for tok in fields]
        except ValueError:
            # A non-numeric token: ignore the whole line.
            continue
        x1, y1, x2, y2 = coords
        parsed.append(Segment(Point(x1, y1), Point(x2, y2), line_no))
    return parsed


def brute_force(segments: List[Segment]) -> Set[Point]:
    """Return the intersection points of all unordered pairs of segments.

    Every pair is tested once, with the earlier segment in the list as the
    receiver of ``intersects``. Quadratic in the number of segments; used as
    the reference result for the sweep.
    """
    found = set()
    for idx, first in enumerate(segments):
        for second in segments[idx + 1:]:
            hit = first.intersects(second)
            if hit:
                found.add(hit)
    return found


if __name__ == "__main__":
    # Sample input: one segment per line as "x1 y1 x2 y2". The indentation of
    # the continuation lines is harmless because parse_segments splits each
    # line on whitespace.
    text = """91 179 760 353
              874 890 648 114
              687 715 939 747
              703 692 2 675
              87 616 149 23
              463 450 878 233
              255 695 51 823
              580 716 271 427
              35 318 383 639
              439 750 850 558
              314 491 247 283
              701 107 364 127
              850 538 672 225
              897 914 172 214
              131 747 481 151
              634 896 233 68
              128 290 294 668
              215 444 206 148
              722 649 638 243
              832 799 409 469
              73 914 626 779
              308 601 828 75
              654 626 810 543
              780 139 932 188
              463 58 786 433"""

    # Run the sweep and the quadratic reference on the same segments.
    segs = parse_segments(text)
    sweep = LineSweep(segs)
    inters = sweep.find_intersections()
    bf = brute_force(segs)

    print(f"Sweep: {len(inters)}, Brute force: {len(bf)}")
    # Self-check: the sweep must find exactly the points brute force finds.
    # NOTE: these asserts are skipped when Python runs with -O.
    assert len(inters) == len(bf)
    assert inters == bf

Sweep: 60, Brute force: 60
