In [1]:
import heapq
import warnings
from collections import defaultdict

import numpy as np # type: ignore
import matplotlib.patches as patches # type: ignore
import matplotlib.pyplot as plt # type: ignore
from matplotlib.collections import LineCollection # type: ignore
from tqdm import tqdm # type: ignore

In [None]:
def is_point_in_zone(point, zone):
    """Return True if ``point`` lies inside the axis-aligned rectangle ``zone`` (edges inclusive).

    Args:
        point: ``(x, y)`` pair.
        zone: mapping with numeric ``"x"``, ``"y"``, ``"width"`` and ``"height"`` entries;
            the rectangle spans ``[x, x + width]`` by ``[y, y + height]``.

    Only meaningful for the fake city, whose zones are rectangles. The caveat is reported via
    ``warnings`` rather than ``print``: callers invoke this once per node per zone, and the
    default warning filter shows it once per call site instead of flooding the output.
    """
    warnings.warn("Warning - this function only works with the fake city", stacklevel=2)
    x, y = point
    return (
        zone["x"] <= x <= zone["x"] + zone["width"]
        and zone["y"] <= y <= zone["y"] + zone["height"]
    )

# Precompute street points inside each zone
def get_street_points_by_zone(street_graph, zones):
    """Map each zone id to the street-graph nodes whose coordinates fall inside that zone.

    Args:
        street_graph: graph whose ``nodes(data=True)`` yields ``(node_id, attrs)`` pairs with
            ``"x"`` and ``"y"`` entries.
        zones: iterable of zone dicts with ``"id"``, ``"x"``, ``"y"``, ``"width"`` and
            ``"height"`` entries.

    Returns:
        dict ``zone_id -> [{"x": ..., "y": ..., "id": node_id}, ...]`` in graph node order.
        Containment is decided by ``is_point_in_zone`` (inclusive bounds), so a node on a
        shared boundary can appear under more than one zone.
    """
    # Reported through ``warnings`` (not print) for consistency with is_point_in_zone.
    warnings.warn("Warning - this function only works with the fake city", stacklevel=2)
    return {
        zone["id"]: [
            {"x": data["x"], "y": data["y"], "id": node_id}
            for node_id, data in street_graph.nodes(data=True)
            if is_point_in_zone((data["x"], data["y"]), zone)
        ]
        for zone in zones
    }


In [None]:
def get_points_by_zone(G):
    """Group the graph's node ids by their ``"zone"`` attribute.

    Args:
        G: graph whose ``nodes`` can be iterated for node ids and indexed for attribute dicts.

    Returns:
        Plain dict ``zone -> [node_id, ...]``. Zones appear in first-seen order and node ids
        keep graph iteration order. Nodes lacking a ``"zone"`` attribute are grouped under ``None``.
    """
    grouped = {}
    for node_id in G.nodes:
        zone_key = G.nodes[node_id].get("zone")
        grouped.setdefault(zone_key, []).append(node_id)
    return grouped

In [None]:
def multi_source_dijkstra_exact_pairs(graph, sources, targets, weight='length'):
    """Shortest-path distances from every source to the targets, using one shared heap.

    An independent Dijkstra search is run per source, all interleaved in a single priority
    queue ordered by distance. The search stops once every (source, target) pair has been
    settled, or when the heap is exhausted (some targets unreachable from some source).

    Args:
        graph: multigraph-style adjacency: ``graph[u]`` maps each neighbour ``v`` to a dict
            of parallel edges ``{key: attrs}``, each ``attrs`` holding ``weight``
            (e.g. a networkx ``MultiGraph``/``MultiDiGraph``).
        sources: iterable of start node ids; duplicates are ignored.
        targets: iterable of node ids whose distances are wanted.
        weight: edge attribute used as the edge length.

    Returns:
        ``defaultdict(dict)`` with ``distances[source][node] = shortest distance``. Besides
        the targets it also holds every other node settled before the search stopped, so
        callers should filter to the targets they need. Unreachable targets are absent.

    Raises:
        KeyError: if an explored edge lacks the ``weight`` attribute.
    """
    distances = defaultdict(dict)
    targets = set(targets)
    # Deduplicate while keeping the caller's order.
    sources = list(dict.fromkeys(sources))

    # Every source has to settle every target, so the stop condition counts pairs.
    # Counting only targets ended the search as soon as *some* source had reached each
    # target, leaving the remaining sources' distances missing.
    num_remaining = len(sources) * len(targets)

    heap = [(0, source, source) for source in sources]
    heapq.heapify(heap)

    while heap and num_remaining > 0:
        dist_u, u, origin = heapq.heappop(heap)
        settled = distances[origin]

        # Lazy deletion: a stale entry for a node already settled from this origin is skipped.
        if u in settled:
            continue
        settled[u] = dist_u

        if u in targets:
            num_remaining -= 1

        for v, parallel_edges in graph[u].items():
            if v in settled:
                continue
            # Between u and v only the shortest of any parallel edges can be on a shortest path.
            length = min(attrs[weight] for attrs in parallel_edges.values())
            heapq.heappush(heap, (dist_u + length, v, origin))

    return distances


In [None]:
# Pre-compute distances between edge nodes in different zones

# Find the nodes that are on the "edge" of a zone
# These are nodes which are in a zone but have at least one neighbour that is not in the zone
def find_edge_nodes(street_graph):
    """Find, for each zone, the nodes that have at least one neighbour in a different zone.

    Args:
        street_graph: graph whose nodes may carry a ``"zone"`` attribute.

    Returns:
        dict ``zone_id -> set(node_id)``. Zone keys come from ``get_points_by_zone``, so
        nodes without a ``"zone"`` attribute are grouped under ``None``.
    """
    def zone_of(node):
        # ``.get`` matches get_points_by_zone, which tolerates nodes without a zone;
        # indexing with ["zone"] raised KeyError for such neighbours.
        return street_graph.nodes[node].get("zone")

    edge_nodes_lookup = {}
    points_by_zone = get_points_by_zone(street_graph) # type: ignore

    for zone_id, points in points_by_zone.items():
        # NOTE(review): if street_graph is directed, ``neighbors`` may yield successors
        # only, so a node entered from another zone solely via an incoming edge would not
        # be counted — confirm this is intended for the graph type in use.
        edge_nodes_lookup[zone_id] = {
            point
            for point in points
            if any(zone_of(nbr) != zone_id for nbr in street_graph.neighbors(point))
        }
    return edge_nodes_lookup


def generate_edge_distances_lookup(street_graph, edge_nodes_lookup):
    """Shortest-path distances between the edge nodes of every ordered pair of distinct zones.

    Args:
        street_graph: graph passed to ``multi_source_dijkstra_exact_pairs``.
        edge_nodes_lookup: mapping ``zone_id -> collection of edge node ids``
            (as produced by ``find_edge_nodes``).

    Returns:
        Nested dict ``edge_distances[start_zone][end_zone][start_node][end_node] = distance``.
        Every ordered pair of distinct zones gets an entry (possibly empty); unreachable
        ``(start_node, end_node)`` pairs are absent.
    """
    # Normalise to sets once so the membership filters below are O(1).
    edge_sets = {zone_id: set(nodes) for zone_id, nodes in edge_nodes_lookup.items()}
    edge_distances = {}

    for start_zone_id, start_edge_nodes in tqdm(edge_sets.items(), desc="Processing start zones"):
        other_zones = {
            zone_id: nodes for zone_id, nodes in edge_sets.items() if zone_id != start_zone_id
        }

        # Shortest distances from a source don't depend on which targets are requested, so
        # one search towards all other zones' edge nodes replaces one search per end zone.
        pairwise_dists = multi_source_dijkstra_exact_pairs(
            street_graph,
            sources=start_edge_nodes,
            targets=set().union(*other_zones.values()),
            weight='length'
        )

        edge_distances[start_zone_id] = {}
        for end_zone_id, end_edge_nodes in other_zones.items():
            if not end_edge_nodes:
                # Nothing to reach in this zone.
                edge_distances[start_zone_id][end_zone_id] = {}
                continue
            edge_distances[start_zone_id][end_zone_id] = {
                start_node: {
                    end_node: dist
                    for end_node, dist in node_dists.items()
                    if end_node in end_edge_nodes
                }
                for start_node, node_dists in pairwise_dists.items()
            }

    return edge_distances

# edge_nodes_lookup = find_edge_nodes(street_graph)
# edge_distance_lookup = generate_edge_distances_lookup(street_graph, edge_nodes_lookup)


In [None]:
def draw_on_lines(graph, lines, colour, ax=None):
    """Draw straight segments between pairs of graph nodes.

    Args:
        graph: graph whose nodes carry ``"x"`` and ``"y"`` attributes.
        lines: iterable of ``(u, v)`` node pairs, e.g. ``graph.edges()`` or a ``zip``.
        colour: matplotlib colour spec applied to every segment.
        ax: axes to draw on; defaults to the current axes (``plt.gca()``).
    """
    if ax is None:
        ax = plt.gca()

    def xy(node):
        attrs = graph.nodes[node]
        return (attrs["x"], attrs["y"])

    # Look up only the endpoints actually drawn instead of building a position dict for
    # the whole graph on every call (add_route calls this once per route).
    segments = [(xy(u), xy(v)) for u, v in lines]

    # A single LineCollection renders far faster than one plot() call per segment.
    lc = LineCollection(segments, colors=colour, linewidths=2, alpha=1)
    ax.add_collection(lc)

def add_route(graph, route, colour, is_actual_route=False):
    """Plot a route as connected segments with start/end markers on the current axes.

    Args:
        graph: graph whose nodes carry ``"x"`` and ``"y"`` attributes.
        route: sequence of node ids in travel order.
        colour: colour of the route's segments.
        is_actual_route: start/end markers are green/red when True, blue/orange otherwise.

    An empty route draws nothing.
    """
    # len() rather than truthiness so numpy arrays are accepted as routes too.
    if len(route) == 0:
        return

    # Direct lookups instead of building a position dict for every node in the graph.
    start_attrs = graph.nodes[route[0]]
    end_attrs = graph.nodes[route[-1]]
    start_colour, end_colour = ("green", "red") if is_actual_route else ("blue", "orange")

    plt.scatter(start_attrs["x"], start_attrs["y"], color=start_colour, s=100, zorder=5)
    plt.scatter(end_attrs["x"], end_attrs["y"], color=end_colour, s=100, zorder=5)

    draw_on_lines(graph, zip(route[:-1], route[1:]), colour)

def visualise(street_graph, zones, width, height, trips=None):
    """Plot the street network, the zone rectangles and, optionally, trip routes.

    Args:
        street_graph: graph whose nodes carry ``"x"``/``"y"`` attributes.
        zones: iterable of zone dicts with ``"x"``, ``"y"``, ``"width"``, ``"height"``,
            ``"colour"`` and ``"name"`` entries.
        width, height: plot extent; the axes span ``[0, width]`` by ``[0, height]``.
        trips: optional iterable of trip dicts with a ``"route"`` (node sequence) and an
            optional ``"estimated_route"`` (node sequence or None). Actual routes get
            green/red start/end markers, estimated routes blue/orange.
    """
    from matplotlib.lines import Line2D # type: ignore

    # None instead of a mutable [] default.
    if trips is None:
        trips = []

    plt.figure(figsize=(15, 10))
    ax = plt.gca()
    ax.set_xlim(0, width)
    ax.set_ylim(0, height)
    ax.set_facecolor("#f8f9fa")

    draw_on_lines(street_graph, street_graph.edges(), "#8080807f")

    for zone in zones:
        rect = patches.Rectangle(
            (zone["x"], zone["y"]),
            zone["width"],
            zone["height"],
            linewidth=1,
            edgecolor="#333333",
            facecolor=zone["colour"],
            alpha=0.5,
        )
        ax.add_patch(rect)
        ax.text(
            zone["x"] + zone["width"] / 2,
            zone["y"] + zone["height"] / 2,
            zone["name"],
            ha="center",
            va="center",
            fontweight="bold",
        )

    drew_actual = drew_estimated = False
    for trip in trips:
        add_route(street_graph, trip["route"], "black", is_actual_route=True)
        drew_actual = True

        estimated_route = trip.get("estimated_route")
        if estimated_route is not None:
            add_route(street_graph, estimated_route, "black", is_actual_route=False)
            drew_estimated = True

    # The route markers carry no labels, so a bare plt.legend() had nothing to show and
    # warned instead; build proxy entries matching add_route's marker colours.
    def marker(colour, label):
        return Line2D([], [], marker="o", linestyle="", color=colour, markersize=10, label=label)

    handles = []
    if drew_actual:
        handles += [marker("green", "Actual start"), marker("red", "Actual end")]
    if drew_estimated:
        handles += [marker("blue", "Estimated start"), marker("orange", "Estimated end")]
    if handles:
        ax.legend(handles=handles)

    plt.tight_layout()
    plt.show()
