# Priority Search Tree — Construction Animation

**Priority Search Tree** combines a **BST on x-coordinates** with a **min-heap on y-coordinates**:

1. Among all points in the current subset, extract the one with the **smallest y** → it becomes the node's stored point (heap property).
2. The **remaining** points are split by their **median x-coordinate** ($x_{med}$) into left ($x < x_{med}$) and right ($x \ge x_{med}$) subsets.
3. Recurse on each subset.

Result: root-to-leaf paths are sorted by y (min-heap), and left/right splits partition x-space (BST).

In [1]:
# Import libraries
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import warnings
warnings.filterwarnings('ignore')
np.random.seed(42)

In [2]:
# Generate random points (unique x-values for clean BST splits)
def generate_random_points(n, x_range=(5, 95), y_range=(5, 95)):
    xs = np.round(np.random.uniform(*x_range, size=n * 2), 1)
    xs = list(set(xs))[:n]  # ensure unique x
    ys = np.round(np.random.uniform(*y_range, size=len(xs)), 1)
    return list(zip(xs, ys))

In [3]:
# ── Priority Search Tree Data Structure ────────────────────────────

class PSTNode:
    def __init__(self, point, split_x=None):
        self.point = point        # the min-y point (heap winner)
        self.split_x = split_x    # median x used to partition remaining points
        self.left = None
        self.right = None

class PrioritySearchTree:
    def __init__(self, points):
        self.build_steps = []      # record construction steps for animation
        self.root = self._build(sorted(points, key=lambda p: p[0]),
                                depth=0, x_lo=0, x_hi=100)

    def _build(self, pts, depth, x_lo, x_hi):
        if not pts:
            return None

        # 1. Extract point with smallest y (heap property)
        min_idx = min(range(len(pts)), key=lambda i: pts[i][1])
        winner = pts[min_idx]
        remaining = pts[:min_idx] + pts[min_idx + 1:]

        # 2. Split remaining by median x
        split_x = None
        left_pts, right_pts = [], []
        if remaining:
            remaining_sorted = sorted(remaining, key=lambda p: p[0])
            mid = len(remaining_sorted) // 2
            split_x = remaining_sorted[mid][0]
            left_pts  = remaining_sorted[:mid]
            right_pts = remaining_sorted[mid:]

        node = PSTNode(winner, split_x)

        # Record this construction step
        self.build_steps.append(dict(
            winner=winner,
            remaining=list(remaining),
            split_x=split_x,
            depth=depth,
            x_lo=x_lo,
            x_hi=x_hi,
            left_pts=list(left_pts),
            right_pts=list(right_pts),
        ))

        # 3. Recurse
        if split_x is not None:
            node.left  = self._build(left_pts,  depth + 1, x_lo, split_x)
            node.right = self._build(right_pts, depth + 1, split_x, x_hi)

        return node

print("Priority Search Tree defined ✓")

Priority Search Tree defined ✓


In [4]:
# ── Helper: assign (x, y) positions for tree nodes for drawing ─────

def layout_tree(node, x=0.5, y=0, dx=0.25, depth=0):
    """Assign drawing positions to each node in the tree (top-down)."""
    positions = {}
    edges = []
    if node is None:
        return positions, edges
    positions[id(node)] = (x, -depth, node.point, node.split_x)
    if node.left:
        child_x = x - dx
        edges.append(((x, -depth), (child_x, -(depth + 1))))
        p, e = layout_tree(node.left, child_x, 0, dx / 2, depth + 1)
        positions.update(p)
        edges += e
    if node.right:
        child_x = x + dx
        edges.append(((x, -depth), (child_x, -(depth + 1))))
        p, e = layout_tree(node.right, child_x, 0, dx / 2, depth + 1)
        positions.update(p)
        edges += e
    return positions, edges

# Map each node to its insertion order so we can reveal progressively
def order_tree_nodes(node, order=None):
    """DFS pre-order to match build_steps order."""
    if order is None:
        order = []
    if node is None:
        return order
    order.append(id(node))
    order_tree_nodes(node.left, order)
    order_tree_nodes(node.right, order)
    return order

print("Tree layout helpers defined")

Tree layout helpers defined


In [5]:
# ── Animated PST Construction ──────────────────────────────────────

def animate_pst(points):
    pst = PrioritySearchTree(points)
    steps = pst.build_steps
    n_steps = len(steps)

    # Pre-compute tree layout
    positions, edges = layout_tree(pst.root)
    node_order = order_tree_nodes(pst.root)  # id list in DFS pre-order

    all_x = [p[0] for p in points]
    all_y = [p[1] for p in points]
    pad = 6
    x_lo_g, x_hi_g = min(all_x) - pad, max(all_x) + pad
    y_lo_g, y_hi_g = min(all_y) - pad, max(all_y) + pad

    # Colors per depth
    cmap = plt.cm.tab10
    max_depth = max(s['depth'] for s in steps)

    # ── Figure: left = 2D space, right = tree diagram ──
    fig, (ax_pts, ax_tree) = plt.subplots(1, 2, figsize=(16, 7),
                                           gridspec_kw={'width_ratios': [1, 1.1]})
    fig.subplots_adjust(wspace=0.3)

    # ── Static background on left panel ──
    ax_pts.scatter(all_x, all_y, c='lightgray', s=50, zorder=1,
                   edgecolors='gray', linewidths=0.5)
    for p in points:
        ax_pts.annotate(f'({p[0]},{p[1]})', xy=p, fontsize=6,
                        color='gray', ha='left', va='bottom')
    ax_pts.set_xlim(x_lo_g, x_hi_g)
    ax_pts.set_ylim(y_lo_g, y_hi_g)
    ax_pts.set_xlabel('x')
    ax_pts.set_ylabel('y')
    ax_pts.set_title('2D Point Space')
    ax_pts.grid(True, alpha=0.2)

    def draw_tree_up_to(ax, frame):
        """Draw the tree diagram showing only nodes revealed up to `frame`."""
        ax.cla()
        ax.set_title('Priority Search Tree', fontsize=11, fontweight='bold')
        ax.set_xlim(-0.05, 1.05)
        min_y_tree = -(max_depth + 1)
        ax.set_ylim(min_y_tree - 0.5, 0.8)
        ax.axis('off')

        revealed = set(node_order[:frame + 1])

        # Draw edges (only if both ends revealed)
        for (x1, y1), (x2, y2) in edges:
            src_ids = [nid for nid, (px, py, _, _) in positions.items()
                       if (px, py) == (x1, y1)]
            dst_ids = [nid for nid, (px, py, _, _) in positions.items()
                       if (px, py) == (x2, y2)]
            if (src_ids and dst_ids and
                    src_ids[0] in revealed and dst_ids[0] in revealed):
                ax.plot([x1, x2], [y1, y2], 'k-', linewidth=1, alpha=0.4)

        # Draw nodes
        for nid, (nx, ny, pt, sx) in positions.items():
            if nid not in revealed:
                continue
            is_current = (nid == node_order[frame])
            depth_of_node = int(-ny)
            color = cmap(depth_of_node / max(max_depth, 1))
            size = 18 if is_current else 14
            edge_w = 2.5 if is_current else 1
            ax.scatter(nx, ny, s=size**2, c=[color], edgecolors='black',
                       linewidths=edge_w, zorder=5)
            label = f'({pt[0]},{pt[1]})'
            if sx is not None:
                label += f'\nxm={sx}'
            ax.text(nx, ny + 0.3, label, fontsize=7, ha='center',
                    va='bottom',
                    fontweight='bold' if is_current else 'normal',
                    color=color)

    def update(frame):
        step = steps[frame]
        winner = step['winner']
        remaining = step['remaining']
        split_x = step['split_x']
        depth = step['depth']
        x_lo = step['x_lo']
        x_hi = step['x_hi']
        color = cmap(depth / max(max_depth, 1))

        # ── Left panel: highlight winner with star ──
        ax_pts.scatter([winner[0]], [winner[1]], c=[color], s=200, zorder=5,
                       marker='*', edgecolors='black', linewidths=1)
        ax_pts.annotate(f'min-y\nd={depth}', xy=winner, fontsize=7,
                        color=color, fontweight='bold',
                        xytext=(5, 8), textcoords='offset points')

        # Highlight remaining points in this partition
        if remaining:
            rx = [p[0] for p in remaining]
            ry = [p[1] for p in remaining]
            ax_pts.scatter(rx, ry, c=[color], s=60, zorder=3,
                           edgecolors='black', linewidths=0.6, alpha=0.6)

        # Draw the x_med split line
        if split_x is not None:
            ax_pts.plot([split_x, split_x], [y_lo_g, y_hi_g],
                        color=color, linewidth=1.5, linestyle='--', alpha=0.5)
            ax_pts.text(split_x, y_hi_g + 0.5, f'xm={split_x}',
                        fontsize=7, ha='center', color=color,
                        fontweight='bold')

        # Light shading for the x-partition region
        rect = patches.Rectangle((x_lo, y_lo_g), x_hi - x_lo,
                                  y_hi_g - y_lo_g,
                                  linewidth=0, facecolor=color, alpha=0.04)
        ax_pts.add_patch(rect)

        # ── Right panel: tree so far ──
        draw_tree_up_to(ax_tree, frame)

        fig.suptitle(
            f"PST Construction  step {frame+1}/{n_steps}  |  "
            f"depth={depth}  |  winner=({winner[0]}, {winner[1]})"
            + (f"  |  xm={split_x}" if split_x else "  |  leaf"),
            fontsize=12, fontweight='bold')

    anim = FuncAnimation(fig, update, frames=n_steps,
                         interval=1200, repeat=False)
    plt.close(fig)
    return HTML(anim.to_jshtml())

# ── Run ──
pts = generate_random_points(12)
print("Points:", pts)
animate_pst(pts)

Points: [(np.float64(6.9), np.float64(46.0)), (np.float64(10.2), np.float64(75.7)), (np.float64(17.6), np.float64(23.0)), (np.float64(19.0), np.float64(51.3)), (np.float64(21.4), np.float64(58.3)), (np.float64(21.5), np.float64(9.2)), (np.float64(24.1), np.float64(59.7)), (np.float64(31.2), np.float64(20.3)), (np.float64(32.4), np.float64(10.9)), (np.float64(31.3), np.float64(90.4)), (np.float64(38.7), np.float64(91.9)), (np.float64(38.0), np.float64(77.8))]
