In [None]:
import osmnx as ox
from shapely.geometry import Point
import zipfile
import requests
import geopandas as gpd
import random
import os
import pandas as pd

# Directory two levels above the current working directory. get_zones() stores
# the downloaded taxi-zone archive and the extracted shapefile folder here.
parent_dir = os.path.dirname(os.path.dirname(os.getcwd()))

In [None]:
def fetch_road_graph():
    """Download the drivable street network covering the five NYC boroughs.

    Returns:
        The graph produced by ``ox.graph_from_place`` for all five boroughs,
        after ``ox.add_edge_speeds`` and ``ox.add_edge_travel_times`` have
        been applied to it.
    """
    place_queries = [
        'Bronx, New York, USA',
        'Brooklyn, New York, USA',
        'Manhattan, New York, USA',
        'Queens, New York, USA',
        'Staten Island, New York, USA',
    ]

    print("Getting graph of NYC boroughs")

    road_graph = ox.graph_from_place(place_queries, network_type='drive', simplify=True)

    # Speeds are added first; travel times are then added to that result.
    return ox.add_edge_travel_times(ox.add_edge_speeds(road_graph))

In [None]:
def assign_zones_to_nodes(G, zones):
    """Tag every graph node with the ID of the zone polygon that contains it.

    Each node receives a ``"zone"`` attribute holding the ``LocationID`` of the
    zone it lies within, or ``None`` when no zone contains it. The graph is
    modified in place and also returned.

    Args:
        G: Graph whose nodes carry ``"x"`` and ``"y"`` attributes; they are
            interpreted here as EPSG:4326 coordinates.
        zones: GeoDataFrame of zone polygons with a ``LocationID`` column.
            When it declares a CRS other than EPSG:4326 it is reprojected
            first (the caller's frame is not modified).

    Returns:
        The same graph object ``G``.
    """
    node_ids = list(G.nodes)
    nodes = gpd.GeoDataFrame(
        {
            "node": node_ids,
            "geometry": [Point(G.nodes[n]["x"], G.nodes[n]["y"]) for n in node_ids],
        },
        crs="EPSG:4326",
    )

    # Same CRS guard used by fabricate_trips and display_real_city: joining
    # lon/lat points against polygons in another CRS would match nothing.
    if zones.crs is not None and zones.crs != "EPSG:4326":
        zones = zones.to_crs("EPSG:4326")

    # Left join keeps nodes that fall outside every zone (LocationID is NaN).
    nodes_with_zones = gpd.sjoin(nodes, zones, how="left", predicate="within")

    # Iterate the two needed columns directly instead of DataFrame.iterrows(),
    # which builds a full Series per node. If a node matched several rows the
    # last one wins, as before.
    matched = zip(
        nodes_with_zones["node"].tolist(),
        nodes_with_zones["LocationID"].tolist(),
    )
    for node_id, location_id in matched:
        G.nodes[node_id]["zone"] = None if pd.isna(location_id) else location_id

    return G


In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import networkx as nx
import osmnx as ox

def _fastest_edge_attrs(G, u, v):
    """Return the attribute dict of the quickest edge from ``u`` to ``v``, or None if there is no such edge."""
    edge_data = G.get_edge_data(u, v)
    if not edge_data:
        return None
    # Multigraphs map edge key -> attribute dict; simple graphs return the
    # attribute dict itself.
    candidates = list(edge_data.values()) if G.is_multigraph() else [edge_data]
    return min(candidates, key=lambda attrs: attrs.get("travel_time", float("inf")))


def display_real_city(G, zones=None, trips=None, valid_points=None,
                      save_path="nyc_roads_map_with_valid_points.png", dpi=600):
    """Draw the road network with optional zones, trips and candidate points.

    Edges are coloured by their ``speed_kph`` attribute. The figure is saved
    to ``save_path`` and then shown.

    Args:
        G: Road graph whose nodes carry ``"x"``/``"y"`` attributes.
        zones: Optional GeoDataFrame with a ``LocationID`` column; reprojected
            to EPSG:4326 for plotting when it uses another CRS.
        trips: Optional iterable of dicts; each dict's ``"route"`` entry is a
            sequence of node IDs. Trips without a route are skipped.
        valid_points: Optional dict with ``"valid_start"`` and/or
            ``"valid_end"`` lists of node IDs to highlight.
        save_path: File the figure is written to.
        dpi: Resolution passed to ``savefig``.
    """
    fig, ax = plt.subplots(figsize=(30, 30))

    speeds = nx.get_edge_attributes(G, 'speed_kph')
    if speeds:
        min_speed = min(speeds.values())
        max_speed = max(speeds.values())
        norm = mpl.colors.Normalize(vmin=min_speed, vmax=max_speed)
        colormap = plt.cm.viridis
        # Edges without a speed are drawn with the colour of the slowest edge.
        edge_colors = [colormap(norm(speeds.get(edge, min_speed))) for edge in G.edges]
    else:
        # No edge has a speed attribute: min()/max() would fail, use one colour.
        edge_colors = "gray"

    # Plot the road network
    ox.plot_graph(G, ax=ax, node_size=0, edge_linewidth=0.1, edge_color=edge_colors, show=False)

    # Plot zones
    if zones is not None:
        if zones.crs != "EPSG:4326":
            zones = zones.to_crs("EPSG:4326")
        zones.plot(ax=ax, column='LocationID', cmap='tab20', alpha=0.5, edgecolor='black')

    # Plot trips
    if trips is not None:
        for trip in trips:
            route_nodes = trip.get("route")
            if not route_nodes:
                continue

            for u, v in zip(route_nodes[:-1], route_nodes[1:]):
                # Pick the fastest parallel edge instead of assuming key 0
                # exists and is the one the route used.
                attrs = _fastest_edge_attrs(G, u, v)
                geom = attrs.get("geometry") if attrs is not None else None
                if geom:
                    x, y = geom.xy
                    ax.plot(x, y, color="red", linewidth=1.5, linestyle="-", alpha=0.8)
                else:
                    # No stored geometry: draw a straight segment between the nodes.
                    x1, y1 = G.nodes[u]["x"], G.nodes[u]["y"]
                    x2, y2 = G.nodes[v]["x"], G.nodes[v]["y"]
                    ax.plot([x1, x2], [y1, y2], color="red", linewidth=1.5, linestyle="-", alpha=0.8)

            # Mark start and end points of trip
            start_x, start_y = G.nodes[route_nodes[0]]["x"], G.nodes[route_nodes[0]]["y"]
            end_x, end_y = G.nodes[route_nodes[-1]]["x"], G.nodes[route_nodes[-1]]["y"]
            ax.scatter(start_x, start_y, marker="o", color="lime", s=50, label="Trip Start")
            ax.scatter(end_x, end_y, marker="x", color="red", s=50, label="Trip End")

    # Plot valid start/end nodes
    if valid_points:
        valid_start = valid_points.get("valid_start", [])
        valid_end = valid_points.get("valid_end", [])

        if valid_start:
            xs = [G.nodes[n]["x"] for n in valid_start]
            ys = [G.nodes[n]["y"] for n in valid_start]
            ax.scatter(xs, ys, marker="o", color="green", s=40, alpha=0.7, zorder=5, label="Valid Starts")

        if valid_end:
            xs = [G.nodes[n]["x"] for n in valid_end]
            ys = [G.nodes[n]["y"] for n in valid_end]
            ax.scatter(xs, ys, marker="x", color="crimson", s=40, alpha=0.7, zorder=5, label="Valid Ends")

    # Each trip adds its own "Trip Start"/"Trip End" artists; collapse repeated
    # labels to one legend entry and skip the legend when nothing is labelled.
    handles, labels = ax.get_legend_handles_labels()
    if handles:
        unique_entries = dict(zip(labels, handles))
        ax.legend(list(unique_entries.values()), list(unique_entries.keys()))

    ax.set_title("NYC Road Network with Taxi Zones, Trips, and Valid Points")
    fig.tight_layout()
    fig.savefig(save_path, dpi=dpi)
    print(f"Map saved as '{save_path}'")
    plt.show()


In [None]:
def get_zones(display=False):
    """Load the NYC taxi-zone polygons, downloading them on first use.

    The archive is saved as ``taxi_zones.zip`` and extracted into a
    ``taxi_zones`` folder, both under ``parent_dir``. The download is skipped
    when a ``.shp`` file is already present in that folder.

    Args:
        display: When true, plot the zones coloured by ``LocationID`` before
            returning (plotted in the shapefile's own CRS).

    Returns:
        GeoDataFrame of the zones reprojected to EPSG:4326.

    Raises:
        requests.HTTPError: The download returned an error status.
        FileNotFoundError: No ``.shp`` file exists after extraction.
    """
    url = "https://d37ci6vzurychx.cloudfront.net/misc/taxi_zones.zip"
    zip_path = os.path.join(parent_dir, "taxi_zones.zip")
    shapefile_dir = os.path.join(parent_dir, "taxi_zones")

    def find_shapefile():
        # Walk the whole tree so a shapefile nested in a sub-folder of the
        # archive is found too; os.walk yields nothing for a missing directory.
        for root, _dirs, files in os.walk(shapefile_dir):
            for file in sorted(files):
                if file.endswith(".shp"):
                    return os.path.join(root, file)
        return None

    shapefile_path = find_shapefile()

    # Download and extract only when no shapefile is available yet
    if shapefile_path is None:
        response = requests.get(url, timeout=60)
        # Fail here rather than writing an HTTP error page to disk as a "zip".
        response.raise_for_status()
        with open(zip_path, "wb") as f:
            f.write(response.content)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(shapefile_dir)

        shapefile_path = find_shapefile()
        if shapefile_path is None:
            raise FileNotFoundError(
                f"No .shp file found under '{shapefile_dir}' after extracting '{zip_path}'"
            )

    zones = gpd.read_file(shapefile_path)

    if display:
        fig, ax = plt.subplots(figsize=(10, 10))
        zones.plot(ax=ax, column='LocationID', cmap='tab20')
        plt.title("NYC Taxi Zones")
        plt.tight_layout()
        plt.show()

    return zones.to_crs("EPSG:4326")

#get_zones(True)

In [None]:
def get_nodes_in_zone(G, zones, zone_id):
    """Return the IDs of all graph nodes that lie within one zone.

    Args:
        G: Graph whose nodes carry ``"x"``/``"y"`` attributes, interpreted as
            EPSG:4326 coordinates.
        zones: GeoDataFrame with a ``LocationID`` column; expected to be in
            EPSG:4326 so its coordinates are comparable with the nodes.
        zone_id: ``LocationID`` value of the zone to look up.

    Returns:
        List of node IDs within the zone; empty when the zone ID is unknown or
        no node falls inside it.
    """
    zone = zones[zones["LocationID"] == zone_id]
    if zone.empty:
        return []

    # Cheap pre-filter: a point within the zone necessarily lies inside the
    # zone's bounding box, so only those nodes need Point objects and the
    # spatial join. This avoids building geometry for the whole graph per call.
    min_x, min_y, max_x, max_y = zone.total_bounds
    candidates = [
        (node_id, attrs["x"], attrs["y"])
        for node_id, attrs in G.nodes(data=True)
        if min_x <= attrs["x"] <= max_x and min_y <= attrs["y"] <= max_y
    ]
    if not candidates:
        return []

    nodes = gpd.GeoDataFrame(
        {
            "node": [node_id for node_id, _, _ in candidates],
            "geometry": [Point(x, y) for _, x, y in candidates],
        },
        crs="EPSG:4326",
    )

    # Spatial join keeps only the candidates that are truly within the polygon
    nodes_in_zone = gpd.sjoin(nodes, zone, how="inner", predicate="within")

    return nodes_in_zone["node"].tolist()


def get_random_node_in_zone(G, zones, zone_id):
    """Pick one node of the given zone at random.

    Returns:
        A node ID chosen with ``random.choice`` from the nodes reported by
        ``get_nodes_in_zone``, or ``None`` when that list is empty.
    """
    candidates = get_nodes_in_zone(G, zones, zone_id)
    if not candidates:
        return None
    return random.choice(candidates)


def _route_cost(G, route):
    """Sum travel time and length along ``route``, using the fastest parallel edge for each hop."""
    total_time = 0
    total_length = 0
    for u, v in zip(route[:-1], route[1:]):
        edge_data = G.get_edge_data(u, v)
        # Multigraphs map edge key -> attribute dict; simple graphs return the
        # attribute dict itself.
        candidates = edge_data.values() if G.is_multigraph() else [edge_data]
        fastest = min(candidates, key=lambda attrs: attrs["travel_time"])
        total_time += fastest["travel_time"]
        # Length of the same edge whose travel time was counted, so distance
        # and time describe one physical route.
        total_length += fastest["length"]
    return total_time, total_length


def fabricate_trips(G, zones, num_trips=10):
    """Generate synthetic trips between random nodes of two different zones.

    For each attempt two zones with different ``LocationID`` values are drawn,
    a random node is picked in each, and the fastest route (by
    ``travel_time``) between them is computed. Attempts are skipped when
    either zone contains no node or no path exists, so fewer than
    ``num_trips`` trips may be returned.

    Args:
        G: Road graph with ``travel_time`` and ``length`` edge attributes.
        zones: GeoDataFrame with ``LocationID`` and ``zone`` columns;
            reprojected to EPSG:4326 when it uses another CRS.
        num_trips: Number of attempts to make.

    Returns:
        List of trip dicts with keys ``id``, ``start_zone_id``,
        ``start_zone_name``, ``end_zone_id``, ``end_zone_name``,
        ``travel_time``, ``distance``, ``start_node``, ``end_node`` and
        ``route``. ``id`` is the attempt index, so IDs may have gaps.

    Raises:
        ValueError: Trips were requested but ``zones`` has fewer than two
            distinct ``LocationID`` values.
    """
    if zones.crs != "EPSG:4326":
        zones = zones.to_crs("EPSG:4326")

    # Without two distinct zones the sampling loop below could never finish.
    if num_trips > 0 and zones["LocationID"].nunique() < 2:
        raise ValueError("fabricate_trips needs at least two zones with distinct LocationID values")

    trips = []

    for i in range(num_trips):
        print(f"Fabricating trip {i}")

        # Ensure start zone and end zone are different. Compare by LocationID:
        # nodes are looked up by LocationID, so two rows sharing an ID are the
        # same zone even when the rows themselves differ.
        while True:
            start_zone = zones.sample().iloc[0]
            end_zone = zones.sample().iloc[0]
            if start_zone["LocationID"] != end_zone["LocationID"]:
                break

        # Get random nodes in both start and end zones
        start_point = get_random_node_in_zone(G, zones, start_zone["LocationID"])
        end_point = get_random_node_in_zone(G, zones, end_zone["LocationID"])

        # Explicit None check so a node whose ID is 0 is not treated as missing.
        if start_point is None or end_point is None:
            continue

        try:
            route = nx.shortest_path(G, start_point, end_point, weight="travel_time")
        except nx.NetworkXNoPath:
            continue

        travel_time, distance = _route_cost(G, route)

        trips.append({
            "id": i,
            "start_zone_id": start_zone["LocationID"],
            "start_zone_name": start_zone["zone"],
            "end_zone_id": end_zone["LocationID"],
            "end_zone_name": end_zone["zone"],
            "travel_time": travel_time,
            "distance": distance,
            "start_node": start_point,
            "end_node": end_point,
            "route": route
        })

    return trips


In [None]:
def generate_real_city_data(force=False, do_fabricate_trips=True):
    """Assemble the road graph, the taxi zones and, optionally, sample trips.

    The road graph is cached as ``nyc_roads.graphml`` in the current working
    directory and reloaded from there on later calls.

    Args:
        force: When true, fetch the road graph again even if the cached file
            exists, and overwrite the cache.
        do_fabricate_trips: When true, generate trips with
            ``fabricate_trips``; otherwise the returned trip list is empty.

    Returns:
        Tuple ``(zones, street_graph, trips)``.
    """
    graph_file = "nyc_roads.graphml"

    print("Getting graph...")
    # Use the cached file unless it is missing or a refresh was requested.
    if force or not os.path.exists(graph_file):
        street_graph = fetch_road_graph()

        ox.save_graphml(street_graph, filepath=graph_file)
        print("Graph saved as 'nyc_roads.graphml'")
    else:
        street_graph = ox.load_graphml(graph_file)
        print("Graph loaded from 'nyc_roads.graphml'")

    print("Getting zones...")
    zones = get_zones()

    # Zone tags are written onto the graph's nodes in place.
    print("Assigning zones to nodes...")
    assign_zones_to_nodes(street_graph, zones)

    if not do_fabricate_trips:
        return zones, street_graph, []

    print("Fabricating trips...")
    return zones, street_graph, fabricate_trips(street_graph, zones)

# zones, street_graph, trips = generate_real_city_data()

In [None]:
#print("Displaying city...")
#display_real_city(street_graph, zones, trips)