# Routing Engine Prototype

Python implementation of CH routing algorithms for testing and comparison.

## Algorithms
1. **Classic**: Bidirectional Dijkstra with `inside` filtering
2. **Pruned**: + H3 `parent_check` on popped nodes
3. **Many-to-Many**: Multi-source/target for KNN

In [1]:
import pyarrow.parquet as pq
import pandas as pd
import numpy as np
#import h3.api.basic_int as h3  # Use integer API for H3
import h3
from collections import defaultdict
from heapq import heappush, heappop
from dataclasses import dataclass
from typing import Optional
import time
from dijkstra_general import dijkstra_general


## Configuration

In [2]:
# Data paths
SHORTCUTS_PATH = "../../road-to-shortcut-duckdb/output/Somerset_shortcuts"
EDGES_PATH = "../../osm-to-road/data/output/Somerset/Somerset_driving_simplified_edges_with_h3.csv"

## Data Structures

In [3]:
import numpy as np

@dataclass
class QueryResult:
    distance: float
    path: list
    reachable: bool

@dataclass
class HighCell:
    cell: np.int64
    res: np.int8

@dataclass
class Shortcut:
    from_edge: int
    to_edge: int
    cost: float
    via_edge: int
    cell: np.int64     # 64-bit for H3 cell ID
    inside: np.int8
    cell_res: np.int8      # Small value: -1 to 15

## Load Data

In [4]:
# Load shortcuts with explicit int64 for cell column
shortcuts_df = pq.read_table(SHORTCUTS_PATH).to_pandas()
# CRITICAL: Convert cell to int64 immediately to avoid float precision loss
shortcuts_df['cell'] = pd.to_numeric(shortcuts_df['cell'], errors='coerce').fillna(0).astype('int64')
# After: shortcuts_df['cell'] = pd.to_numeric(...)
# Add:
def get_res(c):
    if c == 0:
        return -1
    return h3.get_resolution(h3.int_to_str(c))
shortcuts_df['cell_res'] = shortcuts_df['cell'].apply(get_res)

print(f"Loaded {len(shortcuts_df):,} shortcuts")
print(f"Cell dtype: {shortcuts_df['cell'].dtype}")
print(shortcuts_df.head())

Loaded 481,812 shortcuts
Cell dtype: int64
   from_edge  to_edge       cost  via_edge  inside                cell  \
0       3566      410  23.910975      6100       1                   0   
1       3566     2780  20.249545         0      -2                   0   
2       3566     5217  11.842502      3569       1                   0   
3       3569      727   9.708133       440       1  608661712568057855   
4       3572     5523   9.837917      3586       0  613165312189136895   

   cell_res  
0        -1  
1        -1  
2        -1  
3         7  
4         8  


In [5]:
# Load edge metadata
edges_df = pd.read_csv(EDGES_PATH)
print(f"Loaded {len(edges_df):,} edges")
print(edges_df.head())

Loaded 6,378 edges
   edge_index   length  maxspeed  \
0           0   80.287      60.0   
1           1  102.097      60.0   
2           2  323.104      50.0   
3           3  111.432      60.0   
4           4   84.317      30.0   

                                            geometry      highway       cost  \
0  LINESTRING (-84.60797882080078 37.091793060302...    secondary   4.817220   
1  LINESTRING (-84.60797882080078 37.091793060302...    secondary   6.125820   
2  LINESTRING (-84.60797882080078 37.091793060302...     tertiary  23.263488   
3  LINESTRING (-84.60697937011719 37.091552734375...    secondary   6.685920   
4  LINESTRING (-84.60697937011719 37.091552734375...  residential  10.118040   

              to_cell           from_cell  lca_res  
0  645224977384028320  645224977383611141        8  
1  645224977383614665  645224977383611141       10  
2  645224977384658840  645224977383611141        8  
3  645224977383653531  645224977383614665        9  
4  645224977383600

In [6]:
# Build edge metadata lookup
edge_meta = {}
for _, row in edges_df.iterrows():
    edge_meta[row['edge_index']] = {
        'to_cell': int(row['to_cell']),
        'from_cell': int(row['from_cell']),
        'lca_res': int(row['lca_res']),
        'length': float(row['length']),
        'cost': float(row['cost'])
    }
print(f"Built metadata for {len(edge_meta):,} edges")

Built metadata for 6,378 edges


In [7]:
# Build adjacency lists
fwd_adj = defaultdict(list)
bwd_adj = defaultdict(list)
# Convert cell column to int64 first to avoid float precision loss
# Build adjacency lists from shortcuts
fwd_adj = defaultdict(list)
bwd_adj = defaultdict(list)
for row in shortcuts_df.itertuples():
    sc = Shortcut(
        from_edge=row.from_edge,
        to_edge=row.to_edge,
        cost=row.cost,
        via_edge=row.via_edge,
        inside=row.inside,
        cell=row.cell,
        cell_res=row.cell_res  # Add this line
    )
    fwd_adj[sc.from_edge].append(sc)
    bwd_adj[sc.to_edge].append(sc)
print(f"Built adjacency: {len(fwd_adj)} forward, {len(bwd_adj)} backward")

Built adjacency: 6378 forward, 6378 backward


## H3 Utility Functions

In [8]:
import h3  # Use the standard h3 import (not h3.api.basic_int)
def safe_cell_to_parent(cell: int, target_res: int) -> int:
    """Get parent cell at target resolution, handling edge cases."""
    if cell == 0 or target_res < 0:
        return 0
    cell_str = h3.int_to_str(cell)
    cell_res = h3.get_resolution(cell_str)
    if target_res > cell_res:
        return cell
    return h3.str_to_int(h3.cell_to_parent(cell_str, target_res))
    
def find_lca(cell1: int, cell2: int) -> int:
    """Find lowest common ancestor of two H3 cells."""
    if cell1 == 0 or cell2 == 0:
        return 0
    cell1_str = h3.int_to_str(cell1)
    cell2_str = h3.int_to_str(cell2)
    min_res = min(h3.get_resolution(cell1_str), h3.get_resolution(cell2_str))
    for res in range(min_res, -1, -1):
        if h3.cell_to_parent(cell1_str, res) == h3.cell_to_parent(cell2_str, res):
            return h3.str_to_int(h3.cell_to_parent(cell1_str, res))
    return 0
    
def parent_check(node_cell: int, high_cell: int, high_res: int) -> bool:
    """Check if node is within high_cell region."""
    if high_cell == 0 or high_res < 0:
        return True
    if node_cell == 0:
        return False
    
    try:
        node_str = h3.int_to_str(node_cell)
        node_res = h3.get_resolution(node_str)
        if high_res > node_res:
            return False
        parent = h3.str_to_int(h3.cell_to_parent(node_str, high_res))
        return parent == high_cell
    except:
        print(f"Error node_cell = {node_cell}, high_cell = {high_cell}, high_res = {high_res}")
        return False

def compute_high_cell(source_edge: int, target_edge: int) -> HighCell:
    """Compute highest common H3 cell for source/target."""
    src_meta = edge_meta[source_edge]
    dst_meta = edge_meta[target_edge]
    
    src_cell = src_meta['to_cell']
    dst_cell = dst_meta['to_cell']
    src_res = src_meta['lca_res']
    dst_res = dst_meta['lca_res']

    
    # Get cells at their LCA resolutions
    src_cell = safe_cell_to_parent(src_cell, src_res)
    dst_cell = safe_cell_to_parent(dst_cell, dst_res)

    if src_cell == 0 or dst_cell == 0:
        return HighCell(0, -1)
        
    lca = find_lca(src_cell, dst_cell)
    if lca != 0:
        lca_str = h3.int_to_str(lca)
        res = h3.get_resolution(lca_str)
    else:
        res = -1
    return HighCell(lca, res)

    
def get_edge_cost(edge_id: int) -> float:
    """Get traversal cost for an edge."""
    return edge_meta[edge_id]['cost']

## Algorithm 1: Classic Bidirectional Dijkstra

In [9]:
def query_classic(source_edge: int, target_edge: int) -> QueryResult:
    """Classic bidirectional Dijkstra with inside filtering only."""
    if source_edge == target_edge:
        return QueryResult(get_edge_cost(source_edge), [source_edge], True)
    
    inf = float('inf')
    dist_fwd = {source_edge: 0.0}
    dist_bwd = {target_edge: get_edge_cost(target_edge)}
    parent_fwd = {source_edge: source_edge}
    parent_bwd = {target_edge: target_edge}
    
    pq_fwd = [(0.0, source_edge)]
    pq_bwd = [(dist_bwd[target_edge], target_edge)]
    
    best = inf
    meeting = None
    
    while pq_fwd or pq_bwd:
        # Forward step
        if pq_fwd:
            d, u = heappop(pq_fwd)
            if d > dist_fwd.get(u, inf):
                pass  # stale
            elif d < best:
                for sc in fwd_adj.get(u, []):
                    if sc.inside != 1:
                        continue
                    v = sc.to_edge
                    nd = d + sc.cost
                    if nd < dist_fwd.get(v, inf):
                        dist_fwd[v] = nd
                        parent_fwd[v] = u
                        heappush(pq_fwd, (nd, v))
                        if v in dist_bwd:
                            total = nd + dist_bwd[v]
                            if total < best:
                                best = total
                                meeting = v
        
        # Backward step
        if pq_bwd:
            d, u = heappop(pq_bwd)
            if d > dist_bwd.get(u, inf):
                pass  # stale
            elif d < best:
                for sc in bwd_adj.get(u, []):
                    if sc.inside not in (-1, 0):
                        continue
                    prev = sc.from_edge
                    nd = d + sc.cost
                    if nd < dist_bwd.get(prev, inf):
                        dist_bwd[prev] = nd
                        parent_bwd[prev] = u
                        heappush(pq_bwd, (nd, prev))
                        if prev in dist_fwd:
                            total = dist_fwd[prev] + nd
                            if total < best:
                                best = total
                                meeting = prev
        
        # Early termination
        if pq_fwd and pq_bwd:
            if pq_fwd[0][0] >= best and pq_bwd[0][0] >= best:
                break
        elif not pq_fwd and not pq_bwd:
            break
    
    if meeting is None or best == inf:
        return QueryResult(-1, [], False)
    
    # Reconstruct path
    path = []
    curr = meeting
    while curr != source_edge:
        path.append(curr)
        curr = parent_fwd[curr]
    path.append(source_edge)
    path.reverse()
    
    curr = meeting
    while curr != target_edge:
        curr = parent_bwd[curr]
        path.append(curr)
    
    return QueryResult(best, path, True)

## Algorithm 2: Pruned Bidirectional Dijkstra

In [10]:
def query_pruned(source_edge: int, target_edge: int) -> QueryResult:
    """Pruned bidirectional Dijkstra with H3 parent_check."""
    if source_edge == target_edge:
        return QueryResult(get_edge_cost(source_edge), [source_edge], True)
    
    high = compute_high_cell(source_edge, target_edge)
    
    inf = float('inf')
    dist_fwd = {source_edge: 0.0}
    dist_bwd = {target_edge: get_edge_cost(target_edge)}
    parent_fwd = {source_edge: source_edge}
    parent_bwd = {target_edge: target_edge}
    
   # Calculate initial edge cells using edge's lca_res
    src_meta = edge_meta[source_edge]
    src_cell = safe_cell_to_parent(src_meta['to_cell'], src_meta['lca_res'])
    tgt_meta = edge_meta[target_edge]
    tgt_cell = safe_cell_to_parent(tgt_meta['to_cell'], tgt_meta['lca_res']) 
    
    # Heap entries: (distance, edge_id, cell)
    pq_fwd = [(0.0, source_edge, src_cell)]
    pq_bwd = [(dist_bwd[target_edge], target_edge, tgt_cell)]
    
    best = inf
    meeting = None
    min_arrival_fwd = inf
    min_arrival_bwd = inf
      
    while pq_fwd or pq_bwd:
        # Forward step
        if pq_fwd:
            d, u, u_cell = heappop(pq_fwd)      
            if u in dist_bwd:
                        min_arrival_fwd = min(dist_fwd[u], min_arrival_fwd)
                        min_arrival_bwd = min(dist_bwd[u], min_arrival_bwd)
                        total = d + dist_bwd[u]
                        if total <= best:
                            best = total
                            meeting = u
                            
            if d >= best:
                continue
            if d > dist_fwd.get(u, inf):
                continue  # stale
            
            # PRUNING: Check popped node's cell
            if not parent_check(u_cell, high.cell, high.res):
                min_arrival_fwd = min(dist_fwd[u], min_arrival_fwd)
                continue

            if u_cell==high.cell:
                min_arrival_fwd = min(dist_fwd[u], min_arrival_fwd)

            for sc in fwd_adj.get(u, []):
                if sc.inside != 1:
                    continue
                v = sc.to_edge
                nd = d + sc.cost
                if nd < dist_fwd.get(v, inf):
                    dist_fwd[v] = nd
                    parent_fwd[v] = u
                    heappush(pq_fwd, (nd, v, sc.cell))  # Use sc.cell directly
                    
        
        # Backward step
        if pq_bwd:
            d, u, u_cell = heappop(pq_bwd)
            if u in dist_fwd:
                min_arrival_fwd = min(dist_fwd[u], min_arrival_fwd)
                min_arrival_bwd = min(dist_bwd[u], min_arrival_bwd)
                total = dist_fwd[u] + d
                if total < best:
                    best = total
                    meeting = u
                    
            if d > dist_bwd.get(u, inf):
                continue  # stale
            if d >= best:
                continue
           
            
            # PRUNING: Check popped node's cell
            check = parent_check(u_cell, high.cell, high.res)
            
            if u_cell==high.cell or (not check):
                min_arrival_bwd = min(dist_bwd[u], min_arrival_bwd)
            
            # Check if at high_cell for lateral edges
            at_high_cell = (u_cell == high.cell)
            
            for sc in bwd_adj.get(u, []):
                if sc.inside == -1 and check :
                    pass
                elif sc.inside == 0 and (at_high_cell or (not check)):
                    pass
                elif sc.inside == -2 and (not check):
                    pass
                else:
                    continue
                
                prev = sc.from_edge
                nd = d + sc.cost
                if nd < dist_bwd.get(prev, inf):
                    dist_bwd[prev] = nd
                    parent_bwd[prev] = u
                    heappush(pq_bwd, (nd, prev, sc.cell))  # Use sc.cell directly
                   
        
        # Early termination
        
        # Check if both directions can improve
        if best < inf:
            bound_fwd = min_arrival_fwd
            bound_bwd = min_arrival_bwd
            if pq_fwd:
                bound_fwd = min(bound_fwd, pq_fwd[0][0])
            if pq_bwd:
                bound_bwd = min(bound_bwd, pq_bwd[0][0])
                
            fwd_good = pq_fwd and (pq_fwd[0][0] + bound_bwd < best)
            bwd_good = pq_bwd and (pq_bwd[0][0] + bound_fwd < best)
            if not fwd_good and not bwd_good:
                break

    
    if meeting is None or best == inf:
        return QueryResult(-1, [], False)
    
    # Reconstruct path
    path = []
    curr = meeting
    while curr != source_edge:
        path.append(curr)
        curr = parent_fwd[curr]
    path.append(source_edge)
    path.reverse()
    
    curr = meeting
    while curr != target_edge:
        curr = parent_bwd[curr]
        path.append(curr)
    
    return QueryResult(best, path, True)

# Algorthm mant-to-many

In [11]:
def query_multi(source_edges: list, target_edges: list,
                source_dists: list, target_dists: list) -> QueryResult:
    """Many-to-many bidirectional Dijkstra for KNN routing."""
    inf = float('inf')
    
    # Initialize forward from all sources
    dist_fwd, parent_fwd, pq_fwd = {}, {}, []
    for src, d in zip(source_edges, source_dists):
        if src in edge_meta:
            dist_fwd[src] = d
            parent_fwd[src] = src
            heappush(pq_fwd, (d, src))
    
    # Initialize backward from all targets
    dist_bwd, parent_bwd, pq_bwd = {}, {}, []
    for tgt, d in zip(target_edges, target_dists):
        if tgt in edge_meta:
            init_dist = get_edge_cost(tgt) + d
            dist_bwd[tgt] = init_dist
            parent_bwd[tgt] = tgt
            heappush(pq_bwd, (init_dist, tgt))
    
    best, meeting = inf, None
    
    while pq_fwd or pq_bwd:
        if pq_fwd:
            d, u = heappop(pq_fwd)
            if u in dist_bwd and d + dist_bwd[u] < best:
                best, meeting = d + dist_bwd[u], u
            if d >= best or d > dist_fwd.get(u, inf): continue
            for sc in fwd_adj.get(u, []):
                if sc.inside == 1:
                    nd = d + sc.cost
                    if nd < dist_fwd.get(sc.to_edge, inf):
                        dist_fwd[sc.to_edge], parent_fwd[sc.to_edge] = nd, u
                        heappush(pq_fwd, (nd, sc.to_edge))
        
        if pq_bwd:
            d, u = heappop(pq_bwd)
            if u in dist_fwd and dist_fwd[u] + d < best:
                best, meeting = dist_fwd[u] + d, u
            if d >= best or d > dist_bwd.get(u, inf): continue
            for sc in bwd_adj.get(u, []):
                if sc.inside in (-1, 0):
                    nd = d + sc.cost
                    if nd < dist_bwd.get(sc.from_edge, inf):
                        dist_bwd[sc.from_edge], parent_bwd[sc.from_edge] = nd, u
                        heappush(pq_bwd, (nd, sc.from_edge))
        
        if best < inf:
            if pq_fwd and pq_fwd[0][0] >= best: pq_fwd = []
            if pq_bwd and pq_bwd[0][0] >= best: pq_bwd = []
    
    if not meeting: return QueryResult(-1, [], False)
    
    path = []
    curr = meeting
    while parent_fwd[curr] != curr:
        path.append(curr); curr = parent_fwd[curr]
    path.append(curr); path.reverse()
    curr = meeting
    while parent_bwd[curr] != curr:
        curr = parent_bwd[curr]; path.append(curr)
    return QueryResult(best, path, True)

## Test & Compare Algorithms

In [12]:
# Get sample edges for testing
all_edges = list(edge_meta.keys())
print(f"Total edges: {len(all_edges)}")

# Pick a random sample
np.random.seed(42)
sample_edges = np.random.choice(all_edges, size=min(1000, len(all_edges)), replace=False)
print(f"Sample edges: {sample_edges[:5]}")

Total edges: 6378
Sample edges: [2718 4275 4141 1199 4546]


In [13]:
def compare_algorithms(source: int, target: int):
    """Compare classic vs pruned for a single query."""
    t0 = time.perf_counter()
    r_classic = query_classic(source, target)
    t_classic = (time.perf_counter() - t0) * 1000
    
    t0 = time.perf_counter()
    r_pruned = query_pruned(source, target)
    t_pruned = (time.perf_counter() - t0) * 1000
    
    match = abs(r_classic.distance - r_pruned.distance) < 0.001 if r_classic.reachable and r_pruned.reachable else r_classic.reachable == r_pruned.reachable
    
    return {
        'source': source,
        'target': target,
        'classic_dist': r_classic.distance,
        'pruned_dist': r_pruned.distance,
        'classic_ms': t_classic,
        'pruned_ms': t_pruned,
        'match': match,
        'speedup': t_classic / t_pruned if t_pruned > 0 else 0
    }

In [14]:
# Run comparison on sample pairs
results = []
for i, src in enumerate(sample_edges[:20]):
    for dst in sample_edges[:20]:
        if src != dst:
            r = compare_algorithms(src, dst)
            results.append(r)
            if not r['match']:
                print(f"MISMATCH: {src} -> {dst}: classic={r['classic_dist']:.3f}, pruned={r['pruned_dist']:.3f}")

results_df = pd.DataFrame(results)
print(f"\nTotal queries: {len(results_df)}")
print(f"Matches: {results_df['match'].sum()} / {len(results_df)}")
print(f"Avg speedup: {results_df['speedup'].mean():.2f}x")
results_df.head(10)


Total queries: 380
Matches: 380 / 380
Avg speedup: 1.78x


Unnamed: 0,source,target,classic_dist,pruned_dist,classic_ms,pruned_ms,match,speedup
0,2718,4275,14.98056,14.98056,1.389876,1.447728,True,0.960039
1,2718,4141,78.076793,78.076793,33.536039,18.185876,True,1.844071
2,2718,1199,89.911853,89.911853,23.365032,4.962678,True,4.70815
3,2718,4546,78.539473,78.539473,32.941605,15.845605,True,2.078911
4,2718,233,152.074028,152.074028,50.827872,41.489881,True,1.225067
5,2718,1618,151.102439,151.102439,46.09241,40.570514,True,1.136106
6,2718,296,170.650151,170.650151,44.369624,42.442859,True,1.045397
7,2718,5407,150.061855,150.061855,53.115001,40.491934,True,1.311743
8,2718,5255,126.527228,126.527228,21.940301,4.686399,True,4.681697
9,2718,1840,101.515996,101.515996,46.600108,106.66097,True,0.436899


In [15]:
# Show any mismatches
mismatches = results_df[~results_df['match']]
if len(mismatches) > 0:
    print(f"\n{len(mismatches)} MISMATCHES FOUND:")
    print(mismatches)
else:
    print("\n✓ All results match!")


✓ All results match!


In [16]:
def query_classic_trace(source_edge: int, target_edge: int):
    r_classic = query_classic(source_edge, target_edge)
    path_edges = set(r_classic.path) if r_classic.reachable else set()
    print(f"query_classic: path={r_classic.path}, dist={r_classic.distance:.4f}")
    
    if source_edge == target_edge:
        return pd.DataFrame()
    
    trace = []
    inf = float('inf')
    dist_fwd = {source_edge: 0.0}
    dist_bwd = {target_edge: get_edge_cost(target_edge)}
    pq_fwd = [(0.0, source_edge)]
    pq_bwd = [(dist_bwd[target_edge], target_edge)]
    best = inf
    meeting = None
    step = 0
    
    while pq_fwd or pq_bwd:
        step += 1
        if pq_fwd:
            d, u = heappop(pq_fwd)
            action = 'STALE' if d > dist_fwd.get(u, inf) else ('d>=best' if d >= best else 'EXPAND')
            expanded = []
            if action == 'EXPAND':
                for sc in fwd_adj.get(u, []):
                    if sc.inside != 1: continue
                    v, nd = sc.to_edge, d + sc.cost
                    if nd < dist_fwd.get(v, inf):
                        dist_fwd[v] = nd
                        heappush(pq_fwd, (nd, v))
                        expanded.append(v)
                        if v in dist_bwd and nd + dist_bwd[v] < best:
                            best, meeting = nd + dist_bwd[v], v
            trace.append({'step': step, 'dir': 'FWD', 'edge': u, 'dist': round(d,2), 
                         'action': action, 'expanded': str(expanded) if expanded else '', 
                         'best': round(best,2) if best < inf else 'inf', 'in_path': u in path_edges})
        
        if pq_bwd:
            d, u = heappop(pq_bwd)
            action = 'STALE' if d > dist_bwd.get(u, inf) else ('d>=best' if d >= best else 'EXPAND')
            expanded = []
            if action == 'EXPAND':
                for sc in bwd_adj.get(u, []):
                    if sc.inside not in (-1, 0): continue
                    prev, nd = sc.from_edge, d + sc.cost
                    if nd < dist_bwd.get(prev, inf):
                        dist_bwd[prev] = nd
                        heappush(pq_bwd, (nd, prev))
                        expanded.append(prev)
                        if prev in dist_fwd and dist_fwd[prev] + nd < best:
                            best, meeting = dist_fwd[prev] + nd, prev
            trace.append({'step': step, 'dir': 'BWD', 'edge': u, 'dist': round(d,2), 
                         'action': action, 'expanded': str(expanded) if expanded else '',
                         'best': round(best,2) if best < inf else 'inf', 'in_path': u in path_edges})
        
        if pq_fwd and pq_bwd and pq_fwd[0][0] >= best and pq_bwd[0][0] >= best:
            break
        if step > 500: break
    
    print(f"Final: best={best:.4f}, meeting={meeting}")
    df = pd.DataFrame(trace)
    
    # Show path edges summary
    print(f"\nPath edges in trace:")
    for e in r_classic.path:
        rows = df[df['edge'] == e]
        if len(rows) > 0:
            print(f"  {e}: dir={rows['dir'].values[0]}, dist={rows['dist'].values[0]}, action={rows['action'].values[0]}")
        else:
            print(f"  {e}: NOT IN TRACE")
    
    return df

In [17]:
def query_pruned_trace(source_edge: int, target_edge: int):
    """Trace version of query_pruned - exact same logic with logging."""
    
    r_pruned = query_pruned(source_edge, target_edge)
    path_edges = set(r_pruned.path) if r_pruned.reachable else set()
    
    high = compute_high_cell(source_edge, target_edge)
    print(f"query_pruned: path={r_pruned.path}, dist={r_pruned.distance:.4f}")
    print(f"High cell: {hex(high.cell)}, res={high.res}")
    
    if source_edge == target_edge:
        return pd.DataFrame()
    
    trace = []
    inf = float('inf')
    dist_fwd = {source_edge: 0.0}
    dist_bwd = {target_edge: get_edge_cost(target_edge)}
    
    src_meta = edge_meta[source_edge]
    src_cell = safe_cell_to_parent(src_meta['to_cell'], src_meta['lca_res'])
    tgt_meta = edge_meta[target_edge]
    tgt_cell = safe_cell_to_parent(tgt_meta['to_cell'], tgt_meta['lca_res'])
    
    pq_fwd = [(0.0, source_edge, src_cell)]
    pq_bwd = [(dist_bwd[target_edge], target_edge, tgt_cell)]
    best = inf
    meeting = None
    min_arrival_fwd = inf
    min_arrival_bwd = inf
    step = 0
    
    # Log initial edges
    trace.append({'step': 0, 'dir': 'FWD', 'edge': source_edge, 'dist': 0.0, 'cell': hex(src_cell)[-8:],
                 'action': 'INIT', 'in_path': source_edge in path_edges})
    trace.append({'step': 0, 'dir': 'BWD', 'edge': target_edge, 'dist': round(dist_bwd[target_edge],2), 
                 'cell': hex(tgt_cell)[-8:], 'action': 'INIT', 'in_path': target_edge in path_edges})
    
    while pq_fwd or pq_bwd:
        step += 1
        
        # Forward step
        if pq_fwd:
            d, u, u_cell = heappop(pq_fwd)
            action = None
            expanded = []
            
            # Meeting check
            if u in dist_bwd:
                min_arrival_fwd = min(dist_fwd[u], min_arrival_fwd)
                min_arrival_bwd = min(dist_bwd[u], min_arrival_bwd)
                total = d + dist_bwd[u]
                if total <= best:
                    best, meeting = total, u
            
            if d >= best:
                action = 'd>=best'
            elif d > dist_fwd.get(u, inf):
                action = 'STALE'
            elif not parent_check(u_cell, high.cell, high.res):
                action = 'PARENT_CHECK'
                min_arrival_fwd = min(dist_fwd[u], min_arrival_fwd)
            else:
                action = 'EXPAND'
                if u_cell == high.cell:
                    min_arrival_fwd = min(dist_fwd[u], min_arrival_fwd)
                for sc in fwd_adj.get(u, []):
                    if sc.inside != 1: continue
                    v, nd = sc.to_edge, d + sc.cost
                    if nd < dist_fwd.get(v, inf):
                        dist_fwd[v] = nd
                        heappush(pq_fwd, (nd, v, sc.cell))
                        expanded.append(v)
            
            trace.append({'step': step, 'dir': 'FWD', 'edge': u, 'dist': round(d,2), 
                         'cell': hex(u_cell)[-8:], 'action': action, 
                         'expanded': str(expanded) if expanded else '',
                         'best': round(best,2) if best < inf else 'inf', 'in_path': u in path_edges})
        
        # Backward step
        if pq_bwd:
            d, u, u_cell = heappop(pq_bwd)
            action = None
            expanded = []
            
            # Meeting check
            if u in dist_fwd:
                min_arrival_fwd = min(dist_fwd[u], min_arrival_fwd)
                min_arrival_bwd = min(dist_bwd[u], min_arrival_bwd)
                total = dist_fwd[u] + d
                if total < best:
                    best, meeting = total, u
            
            if d > dist_bwd.get(u, inf):
                action = 'STALE'
            elif d >= best:
                action = 'd>=best'
            else:
                action = 'EXPAND'
                check = parent_check(u_cell, high.cell, high.res)
                if u_cell == high.cell or (not check):
                    min_arrival_bwd = min(dist_bwd[u], min_arrival_bwd)
                at_high_cell = (u_cell == high.cell)
                
                for sc in bwd_adj.get(u, []):
                    allowed = False
                    if sc.inside == -1 and check:
                        allowed = True
                    elif sc.inside == 0 and (at_high_cell or (not check)):
                        allowed = True
                    elif sc.inside == -2 and (not check):
                        allowed = True
                    if not allowed: continue
                    
                    prev, nd = sc.from_edge, d + sc.cost
                    if nd < dist_bwd.get(prev, inf):
                        dist_bwd[prev] = nd
                        heappush(pq_bwd, (nd, prev, sc.cell))
                        expanded.append(prev)
            
            trace.append({'step': step, 'dir': 'BWD', 'edge': u, 'dist': round(d,2), 
                         'cell': hex(u_cell)[-8:], 'action': action,
                         'expanded': str(expanded) if expanded else '',
                         'best': round(best,2) if best < inf else 'inf', 'in_path': u in path_edges})
        
        # Early termination
        if best < inf:
            bound_fwd = min(min_arrival_fwd, pq_fwd[0][0]) if pq_fwd else min_arrival_fwd
            bound_bwd = min(min_arrival_bwd, pq_bwd[0][0]) if pq_bwd else min_arrival_bwd
            fwd_good = pq_fwd and (pq_fwd[0][0] + bound_bwd < best)
            bwd_good = pq_bwd and (pq_bwd[0][0] + bound_fwd < best)
            if not fwd_good and not bwd_good:
                break
        
        if step > 500: break
    
    print(f"Final: best={best:.4f}, meeting={meeting}")
    return pd.DataFrame(trace)

In [18]:
source_edge, target_edge =5099, 2011
trace_df = query_classic_trace(source_edge, target_edge)

query_classic: path=[5099, 5064, 204, 3549, 2011], dist=100.0369
Final: best=100.0369, meeting=3549

Path edges in trace:
  5099: dir=FWD, dist=0.0, action=EXPAND
  5064: dir=FWD, dist=4.31, action=EXPAND
  204: dir=FWD, dist=15.2, action=EXPAND
  3549: dir=BWD, dist=74.73, action=EXPAND
  2011: dir=BWD, dist=54.25, action=EXPAND


In [19]:
trace_df_2 = query_pruned_trace(source_edge, target_edge)

query_pruned: path=[5099, 5064, 204, 3549, 2011], dist=100.0369
High cell: 0x862669a6fffffff, res=6
Final: best=100.0369, meeting=3549


In [20]:
trace_df_2[trace_df_2["in_path"]==True]

Unnamed: 0,step,dir,edge,dist,cell,action,in_path,expanded,best
0,0,FWD,5099,0.0,6b23ffff,INIT,True,,
1,0,BWD,2011,54.25,69ffffff,INIT,True,,
2,1,FWD,5099,0.0,6b23ffff,EXPAND,True,"[3464, 1084, 5080, 1086, 1085, 5064, 4032, 346...",inf
3,1,BWD,2011,54.25,69ffffff,EXPAND,True,"[582, 1803, 596, 770, 4762, 569, 5631, 2005, 3...",inf
4,2,FWD,5064,4.31,6b3fffff,EXPAND,True,"[147, 157, 532, 4026, 5127, 3074, 5136, 3141, ...",inf
5,2,BWD,3549,74.73,6fffffff,EXPAND,True,"[2593, 2850, 5626, 3543, 121, 574, 317, 2546, ...",inf
36,18,FWD,204,15.2,6bffffff,EXPAND,True,"[1708, 1444, 44, 2207, 1456, 2452, 3549, 1516,...",126.6
90,45,FWD,3549,25.31,6fffffff,EXPAND,True,"[1064, 4022, 1477, 2695, 619, 1492, 4759, 1850...",100.04
94,47,FWD,3549,27.84,6fffffff,STALE,True,,100.04


In [21]:
trace_df[trace_df["in_path"]==True]

Unnamed: 0,step,dir,edge,dist,action,expanded,best,in_path
0,1,FWD,5099,0.0,EXPAND,"[3464, 1084, 5080, 1086, 1085, 5064, 4032, 346...",inf,True
1,1,BWD,2011,54.25,EXPAND,"[582, 1435, 1803, 1697, 2474, 596, 1987, 2624,...",inf,True
2,2,FWD,5064,4.31,EXPAND,"[147, 157, 532, 4026, 5127, 3074, 5136, 3141, ...",inf,True
34,18,FWD,204,15.2,EXPAND,"[1708, 1444, 44, 2207, 1456, 2452, 3549, 1516,...",100.04,True
55,28,BWD,3549,74.73,EXPAND,"[2593, 2850, 5626, 3543, 121, 574, 317, 2546, ...",100.04,True
88,45,FWD,3549,25.31,EXPAND,"[1064, 4022, 1477, 2695, 619, 1492, 4759, 1850...",100.04,True
92,47,FWD,3549,27.84,STALE,,100.04,True


In [22]:
# query_classic: path=[100, 95, 93, 3212, 190, 200], dist=43.2927

# Scipy Path: [100, 95,93, 166, 2407, 3212, 1544, 180, 190, 200], dist=40.17

In [23]:
"""
Resolution-based pruning optimization for query_pruned.

Instead of propagating full 8-byte cell through queue, we use 1-byte resolution.
Parent check becomes a simple integer comparison: node_res >= high_res means "in scope"
"""

# No get_cell_res needed - cell_res is precomputed in Shortcut dataclass

def query_pruned_fast(source_edge: int, target_edge: int):
    """
    Optimized pruned search using resolution instead of cell.
    
    Key optimization: parent_check is now just `node_res >= high_res`
    - 1 byte resolution instead of 8 byte cell
    - Simple integer comparison instead of H3 library calls
    """
    if source_edge == target_edge:
        return QueryResult(get_edge_cost(source_edge), [source_edge], True)
    
    high = compute_high_cell(source_edge, target_edge)
    
    inf = float('inf')
    dist_fwd = {source_edge: 0.0}
    dist_bwd = {target_edge: get_edge_cost(target_edge)}
    parent_fwd = {source_edge: source_edge}
    parent_bwd = {target_edge: target_edge}
    
    # Calculate initial resolutions from edge's lca_res
    src_meta = edge_meta[source_edge]
    src_res = src_meta['lca_res']  # Already a resolution!
    tgt_meta = edge_meta[target_edge]
    tgt_res = tgt_meta['lca_res']
    
    # Heap entries: (distance, edge_id, resolution) - 1 byte vs 8 bytes!
    pq_fwd = [(0.0, source_edge, src_res)]
    pq_bwd = [(dist_bwd[target_edge], target_edge, tgt_res)]
    
    best = inf
    meeting = None
    min_arrival_fwd = inf  # Track min arrival distance in forward direction
    min_arrival_bwd = inf  # Track min arrival distance in backward direction
      
    while pq_fwd or pq_bwd:
        # Forward step
        if pq_fwd:
            d, u, u_res = heappop(pq_fwd)
            
            if u in dist_bwd:
                min_arrival_fwd = min(dist_fwd[u], min_arrival_fwd)
                min_arrival_bwd = min(dist_bwd[u], min_arrival_bwd)
                total = d + dist_bwd[u]
                if total <= best:
                    best = total
                    meeting = u
                    
            if d >= best:
                continue
            if d > dist_fwd.get(u, inf):
                continue  # stale
            
            # FAST PRUNING: simple resolution comparison
            if u_res < high.res:
                min_arrival_fwd = min(dist_fwd[u], min_arrival_fwd)
                continue
            
            # At high resolution level - update min_arrival
            if u_res == high.res:
                min_arrival_fwd = min(dist_fwd[u], min_arrival_fwd)

            for sc in fwd_adj.get(u, []):
                if sc.inside != 1:
                    continue
                v = sc.to_edge
                nd = d + sc.cost
                if nd < dist_fwd.get(v, inf):
                    dist_fwd[v] = nd
                    parent_fwd[v] = u
                    # Use precomputed cell_res from shortcut
                    heappush(pq_fwd, (nd, v, sc.cell_res))
                    
        # Backward step
        if pq_bwd:
            d, u, u_res = heappop(pq_bwd)
            
            if u in dist_fwd:
                min_arrival_fwd = min(dist_fwd[u], min_arrival_fwd)
                min_arrival_bwd = min(dist_bwd[u], min_arrival_bwd)
                total = dist_fwd[u] + d
                if total < best:
                    best = total
                    meeting = u
                    
            if d > dist_bwd.get(u, inf):
                continue  # stale
            if d >= best:
                continue
           
            # FAST PRUNING: check = (u_res >= high.res)
            check = (u_res >= high.res)
            
            # Update min_arrival when at high res or outside scope
            if u_res == high.res or (not check):
                min_arrival_bwd = min(dist_bwd[u], min_arrival_bwd)
            
            for sc in bwd_adj.get(u, []):
                if sc.inside == -1 and check:
                    pass
                elif sc.inside == 0 and (u_res <= high.res):
                    pass
                elif sc.inside == -2 and (not check):
                    pass
                else:
                    continue
                
                prev = sc.from_edge
                nd = d + sc.cost
                if nd < dist_bwd.get(prev, inf):
                    dist_bwd[prev] = nd
                    parent_bwd[prev] = u
                    # Use precomputed cell_res from shortcut
                    heappush(pq_bwd, (nd, prev, sc.cell_res))
                   
        # Early termination - check if both directions can improve
        if best < inf:
            bound_fwd = min_arrival_fwd
            bound_bwd = min_arrival_bwd
            if pq_fwd:
                bound_fwd = min(bound_fwd, pq_fwd[0][0])
            if pq_bwd:
                bound_bwd = min(bound_bwd, pq_bwd[0][0])
                
            fwd_good = pq_fwd and (pq_fwd[0][0] + bound_bwd < best)
            bwd_good = pq_bwd and (pq_bwd[0][0] + bound_fwd < best)
            if not fwd_good and not bwd_good:
                break

    if meeting is None or best == inf:
        return QueryResult(-1, [], False)
    
    # Reconstruct path
    path = []
    curr = meeting
    while curr != source_edge:
        path.append(curr)
        curr = parent_fwd[curr]
    path.append(source_edge)
    path.reverse()
    
    curr = meeting
    while curr != target_edge:
        curr = parent_bwd[curr]
        path.append(curr)
    
    return QueryResult(best, path, True)


# Test the optimization
print("Testing resolution-based pruning optimization...")
print()

# Test cases
test_cases = [(100, 200), (216, 1000), (2006, 1828)]

for src, tgt in test_cases:
    result_orig = query_pruned(src, tgt)
    result_fast = query_pruned_fast(src, tgt)
    
    match = abs(result_orig.distance - result_fast.distance) < 0.001
    print(f"{src} -> {tgt}:")
    print(f"  Original: {result_orig.distance:.4f}")
    print(f"  Fast:     {result_fast.distance:.4f}")
    print(f"  Match: {'✓' if match else 'FAIL'}")
    print()


Testing resolution-based pruning optimization...

100 -> 200:
  Original: 31.6731
  Fast:     31.6731
  Match: ✓

216 -> 1000:
  Original: 73.6716
  Fast:     73.6716
  Match: ✓

2006 -> 1828:
  Original: 59.9668
  Fast:     59.9668
  Match: ✓



In [35]:
def compare_with_dijkstra(source: int, target: int):
    """Compare general Dijkstra with classic and pruned algorithms."""
    r_dijkstra = dijkstra_general(source, target, fwd_adj, get_edge_cost)
    r_classic = query_classic(source, target)
    r_pruned = query_pruned(source, target)
    
    print(f"Query: {source} -> {target}")
    print(f"  Dijkstra: cost={r_dijkstra.distance:.4f}, path={r_dijkstra.path}")
    print(f"  Classic:  cost={r_classic.distance:.4f}, path={r_classic.path}")
    print(f"  Pruned:   cost={r_pruned.distance:.4f}, path={r_pruned.path}")
    
    classic_match = abs(r_dijkstra.distance - r_classic.distance) < 0.001 if r_classic.reachable else False
    pruned_match = abs(r_dijkstra.distance - r_pruned.distance) < 0.001 if r_pruned.reachable else False
    
    print(f"  Classic matches Dijkstra: {'✓' if classic_match else '✗ SUBOPTIMAL'}")
    print(f"  Pruned matches Dijkstra:  {'✓' if pruned_match else '✗ SUBOPTIMAL'}")
    print()
    
    return r_dijkstra, r_classic, r_pruned

In [37]:
compare_with_dijkstra(5099, 2011)

Query: 5099 -> 2011
  Dijkstra: cost=100.0369, path=[5099, 5064, 1078, 1028, 3333, 204, 3549, 2011]
  Classic:  cost=100.0369, path=[5099, 5064, 204, 3549, 2011]
  Pruned:   cost=100.0369, path=[5099, 5064, 204, 3549, 2011]
  Classic matches Dijkstra: ✓
  Pruned matches Dijkstra:  ✓



(QueryResult(distance=100.03691711768761, path=[5099, 5064, 1078, 1028, 3333, 204, 3549, 2011], reachable=True),
 QueryResult(distance=100.03691711768761, path=[5099, 5064, 204, 3549, 2011], reachable=True),
 QueryResult(distance=100.03691711768761, path=[5099, 5064, 204, 3549, 2011], reachable=True))

In [31]:
from dijkstra_general import dijkstra_general
result = dijkstra_general(5064, 204, fwd_adj, get_edge_cost)
print(f"Dijkstra: {result.distance:.4f}")  # Should now show 28.4737

Dijkstra: 28.4737


In [32]:
from dijkstra_general import dijkstra_general
# Test the example
source = 5099
target = 204
result = dijkstra_general(source, target, fwd_adj)
print(f"Path: {result.path}")
print(f"Cost: {result.distance}")
# Or compare with your algorithms
r_dijkstra = dijkstra_general(source, target, fwd_adj)
r_classic = query_classic(source, target)
r_pruned = query_pruned(source, target)
print(f"Dijkstra: {r_dijkstra.distance:.4f}")
print(f"Classic:  {r_classic.distance:.4f}")
print(f"Pruned:   {r_pruned.distance:.4f}")

Path: [5099, 5064, 1078, 1028, 3333, 204]
Cost: 15.20357429218741
Dijkstra: 15.2036
Classic:  32.7797
Pruned:   32.7797


In [27]:
from query_pruned_cpp import query_pruned_cpp_style, HighCell
result = query_pruned_cpp_style(
    source_edge=5099,  # or your test edge
    target_edge=2011,
    fwd_adj=fwd_adj,
    bwd_adj=bwd_adj,
    edge_meta=edge_meta,
    compute_high_cell_fn=compute_high_cell,  # Your existing function
    get_edge_cost_fn=get_edge_cost
)
print(f"C++ style: cost={result.distance:.4f}")

C++ style: cost=100.0369
