In [None]:
# Notebook setup: working directory, imports, pandas display options and the default cache path.
import os 
# Must run before the `src` import below; presumably makes the repository root the
# working directory so that `src` is importable and relative `data/...` paths resolve.
# NOTE(review): absolute, user-specific path — this cell fails on any other machine.
os.chdir("/home/canyon/Bus-Weather-Impacts")
# NOTE(review): star import hides which names come from src.utils;
# read_parquet_from_tar_gz (used further down) is presumably one of them — confirm
# and import it explicitly.
from src.utils import *
import pandas as pd
# NOTE(review): `os` and `numpy` are imported twice in this cell (harmless, but redundant).
import os
import osmnx as ox
import numpy as np
import geopandas as gpd
import networkx as nx
from sklearn.neighbors import KDTree
import numpy as np
# Disable pandas' chained-assignment (SettingWithCopy) warning for the whole session.
pd.options.mode.chained_assignment = None
# Render floats with two decimals in displayed frames.
pd.set_option('display.float_format', '{:.02f}'.format)
from geopy.distance import geodesic
from shapely.geometry import Point
# Default location of the node-pair cache (same value as CALCULATED_PAIR_PATH defined in a later cell).
calculated_pair_path = "data/node_pairs.parquet"
# Show every column when a frame is displayed.
pd.set_option('display.max_columns', None)


In [None]:
import numpy as np
import pandas as pd
import geopandas as gpd
from shapely.geometry import Point
import networkx as nx
from typing import Dict, List, Tuple, Union
from functools import lru_cache

# The functions below rely on names that must already be in scope from the setup cell:
# `ox` (osmnx), `KDTree` (sklearn.neighbors) and `read_parquet_from_tar_gz` (presumably provided by `from src.utils import *`).

@lru_cache(maxsize=None)
def compute_distance(prev_osmid: int, osmid: int, graph: nx.Graph) -> Union[float, None]:
    """Return the on-network distance between two nodes along the fastest route.

    The route is the ``travel_time``-weighted shortest path (the same route
    ``compute_path`` returns). The value returned is the sum of the edge
    ``length`` attributes along that route, so it is a distance rather than a
    duration. The previous implementation returned the summed ``travel_time``,
    which downstream code then divided by 1609 and by elapsed time as if it
    were a length.

    NOTE(review): parquet caches written by the old implementation hold
    travel times in ``distance_osm`` and should be regenerated.

    Args:
        prev_osmid: Origin node id.
        osmid: Destination node id.
        graph: Graph whose edges carry ``travel_time`` and ``length``
            attributes (OSMnx stores ``length`` in metres).

    Returns:
        Total edge length along the fastest route, or ``None`` when the
        destination is unreachable from the origin.
    """
    try:
        route = nx.shortest_path(graph, prev_osmid, osmid, weight='travel_time')
    except nx.NetworkXNoPath:
        return None
    # For multigraphs path_weight uses the smallest `length` among parallel edges.
    return nx.path_weight(graph, route, weight='length')

@lru_cache(maxsize=None)
def compute_path(prev_osmid: int, osmid: int, graph: nx.Graph) -> Union[List[int], None]:
    """Return the fastest route between two nodes as a list of node ids.

    Edges are weighted by their ``travel_time`` attribute. Results are
    memoised per ``(prev_osmid, osmid, graph)``; the cached list object is
    handed out on every hit, so callers should not modify it in place.

    Returns:
        The node sequence from ``prev_osmid`` to ``osmid``, or ``None`` when
        no route exists.
    """
    try:
        route = nx.shortest_path(graph, source=prev_osmid, target=osmid, weight='travel_time')
    except nx.NetworkXNoPath:
        route = None
    return route

def compute_euclid_dists(node_pairs: pd.DataFrame, nodes_points: gpd.GeoDataFrame) -> pd.Series:
    """Straight-line distance between the two nodes of every pair.

    Node geometries are projected to EPSG:2263 and the planar distance between
    ``osmid`` and ``prev_osmid`` is divided by 3.28 (presumably an approximate
    feet-to-metres conversion — confirm the CRS units).

    Args:
        node_pairs: Frame with ``osmid`` and ``prev_osmid`` columns.
        nodes_points: Node table with ``osmid`` and point ``geometry``.

    Returns:
        One value per row of the merged frame, indexed 0..n-1 in the row
        order of ``node_pairs``; pairs whose node is not found yield NaN.
    """
    projected = nodes_points.to_crs(2263)
    coords = projected[['osmid', 'geometry']].copy()
    coords['x'] = coords['geometry'].x
    coords['y'] = coords['geometry'].y

    # First lookup attaches the destination node, second the origin node.
    with_curr = node_pairs.merge(coords, left_on="osmid", right_on="osmid", how="left")
    with_both = with_curr.merge(
        coords,
        right_on="osmid",
        left_on="prev_osmid",
        how="left",
        suffixes=["_curr", "_prev"],
    )

    dx = with_both['x_curr'] - with_both['x_prev']
    dy = with_both['y_curr'] - with_both['y_prev']
    return np.sqrt(dx**2 + dy**2) / 3.28

def precalculate_node_pair_distances(node_pair_df: pd.DataFrame, calculated_pair_path: str, G: nx.Graph, nodes: gpd.GeoDataFrame) -> None:
    """Compute metrics for node pairs missing from the parquet cache and append them.

    Only pairs whose ``(osmid, prev_osmid)`` key is absent from the cache are
    computed. The previous outer-merge / ``isna`` selection also re-selected
    cached rows whose ``distance_euclid`` was NaN and appended them again on
    every run, so the cache accumulated duplicate keys.

    Args:
        node_pair_df: Frame containing ``osmid`` and ``prev_osmid`` columns.
        calculated_pair_path: Parquet file used as the cache; rewritten when
            new pairs are computed.
        G: Graph passed to ``compute_distance`` / ``compute_path``.
        nodes: Node table passed to ``compute_euclid_dists``.

    Returns:
        None. Progress is printed and the cache file is updated on disk.
    """
    key_cols = ["osmid", "prev_osmid"]
    try:
        calculated_pairs = pd.read_parquet(calculated_pair_path)
    except Exception:
        # Best effort: any failure to read the cache is treated as "no cache yet".
        print("No pre-calculated pairs found")
        calculated_pairs = pd.DataFrame(columns=["osmid", "prev_osmid", "distance_osm", "distance_euclid", "shortest_path", "dist_ratio"])

    requested = node_pair_df[key_cols].drop_duplicates().dropna()
    if calculated_pairs.empty:
        # Nothing cached: skip the merge (avoids joining against untyped empty columns).
        pairs_to_calc = requested.reset_index(drop=True)
    else:
        # Anti-join on the key columns: keep only requested pairs not yet cached.
        known = calculated_pairs[key_cols].drop_duplicates()
        flagged = requested.merge(known, on=key_cols, how="left", indicator=True)
        pairs_to_calc = flagged.loc[flagged["_merge"] == "left_only", key_cols].reset_index(drop=True)
    print(f"Pairs to calculate: {pairs_to_calc.shape[0]}")

    if pairs_to_calc.empty:
        print("No new pairs to calculate")
        return

    origins = pairs_to_calc["prev_osmid"].tolist()
    destinations = pairs_to_calc["osmid"].tolist()

    # Explicit float dtype turns "no path" (None) into NaN so the ratio below
    # cannot fail on an all-None object column.
    pairs_to_calc["distance_osm"] = pd.Series(
        [compute_distance(u, v, G) for u, v in zip(origins, destinations)],
        index=pairs_to_calc.index,
        dtype="float64",
    )
    pairs_to_calc["distance_euclid"] = compute_euclid_dists(pairs_to_calc, nodes)
    pairs_to_calc["shortest_path"] = pd.Series(
        [compute_path(u, v, G) for u, v in zip(origins, destinations)],
        index=pairs_to_calc.index,
        dtype=object,
    )
    pairs_to_calc["dist_ratio"] = pairs_to_calc["distance_euclid"] / pairs_to_calc["distance_osm"]

    frames = [pairs_to_calc] if calculated_pairs.empty else [calculated_pairs, pairs_to_calc]
    calculated_pairs = pd.concat(frames, ignore_index=True)
    calculated_pairs.to_parquet(calculated_pair_path)
    print(f"Wrote calculated pairs to {calculated_pair_path}")

def prep_buses_nodes(buses_with_nodes: gpd.GeoDataFrame, max_distance_to_node: float) -> pd.DataFrame:
    """Reduce tagged bus observations to one row per (trip, node) with neighbour links.

    Rows are ordered by trip and timestamp, only the first observation of each
    node within a trip is kept, and the frame is projected to EPSG:2263.
    Per-trip previous/next columns are then added, and finally rows whose
    ``distance_to_node`` is not below ``max_distance_to_node`` are dropped.

    NOTE(review): the proximity filter runs after the shifts, so ``prev_osmid``
    / ``next_osmid`` can refer to an observation that is itself filtered out,
    while ``calculate_speeds`` diffs timestamps over the surviving rows only —
    confirm this is intended.

    Returns:
        The filtered frame with ``prev_stop_id``, ``prev_osmid`` and
        ``next_osmid`` added; ``osmid`` and ``prev_osmid`` are cast to float.
    """
    keep = ["route_short", "timestamp", "trip_id", "next_stop_id", "osmid", "vehicle_id", "distance_to_node", "geometry"]

    ordered = buses_with_nodes.sort_values(["trip_id", "timestamp"])
    first_visits = ordered.drop_duplicates(subset=["trip_id", "osmid"], keep="first")
    trips = first_visits.to_crs(2263)[keep].copy()

    trips["prev_stop_id"] = trips.groupby("trip_id")["next_stop_id"].shift(1)
    trips["prev_osmid"] = trips.groupby("trip_id")["osmid"].shift(1).astype(float)
    # next_osmid is derived before osmid itself is cast to float.
    trips["next_osmid"] = trips.groupby("trip_id")["osmid"].shift(-1)
    trips["osmid"] = trips["osmid"].astype(float)

    close_enough = trips["distance_to_node"] < max_distance_to_node
    return trips[close_enough]

def tag_feed_with_nodes(buses: pd.DataFrame, tree: KDTree, nodes: gpd.GeoDataFrame, types_to_include: List[Union[str, float]] = [np.nan, "traffic_signals", "stop"]) -> gpd.GeoDataFrame:
    """Attach the nearest graph node to every observation in ``buses``.

    ``tree`` is queried with the ``lat``/``lon`` columns and returns positions,
    which are mapped to ``osmid`` through the index labels of ``nodes``. This
    assumes ``nodes`` is indexed 0..n-1 in the order the tree was built from
    (as ``get_node_data`` produces it). Observations whose nearest node has a
    ``highway`` value outside ``types_to_include`` get no ``osmid`` and are
    dropped by the inner merge.

    Changes from the earlier version:
      * the default uses ``np.nan``; the ``np.NaN`` alias was removed in
        NumPy 2.0 and made this definition fail on import;
      * the caller's ``buses`` frame is no longer modified in place.

    Args:
        buses: Observations with ``lat`` and ``lon`` columns.
        tree: KD-tree whose ``query`` returns node positions.
        nodes: Node table with ``highway``, ``osmid`` and ``geometry`` columns.
        types_to_include: ``highway`` values of nodes to keep (NaN matches
            nodes with no ``highway`` tag). The default list is never mutated.

    Returns:
        GeoDataFrame of the observations joined to their nearest node, using
        the node ``geometry`` column as geometry.
    """
    eligible_nodes = nodes[nodes["highway"].isin(types_to_include)]

    nearest_positions = tree.query(np.array(buses[['lat', 'lon']]), k=1, return_distance=False).flatten()
    # assign() returns a new frame, leaving the caller's `buses` untouched.
    tagged = buses.assign(nearest_node=nearest_positions)

    tagged['nearest_osm_id'] = tagged['nearest_node'].map(eligible_nodes['osmid'])
    tagged = tagged.merge(eligible_nodes, left_on="nearest_osm_id", right_on="osmid")
    return gpd.GeoDataFrame(tagged, geometry='geometry')

def get_node_data(place: str = "New York City, New York, USA") -> Tuple[KDTree, gpd.GeoDataFrame, nx.Graph]:
    """Download the drivable street network for ``place`` and index its nodes.

    Edge speeds and travel times are added to the graph with OSMnx, the nodes
    are exported to a table with a fresh 0..n-1 index, and a Euclidean KD-tree
    is built over their ``y``/``x`` columns in that order.

    Returns:
        ``(tree, nodes, G)``: the KD-tree, the node table and the graph.
    """
    road_graph = ox.graph_from_place(place, network_type='drive')
    road_graph = ox.add_edge_travel_times(ox.add_edge_speeds(road_graph))

    node_table = ox.graph_to_gdfs(road_graph, edges=False).reset_index()
    node_index = KDTree(node_table[['y', 'x']], metric='euclidean')

    return node_index, node_table, road_graph

def prep_coords(df: pd.DataFrame, lat_col: str, lon_col: str) -> gpd.GeoDataFrame:
    """Add projected ``planar_x`` / ``planar_y`` columns derived from lat/lon.

    Points are built from ``lon_col``/``lat_col`` as EPSG:4326 and projected to
    EPSG:2263; the projected coordinates are stored as plain columns and the
    temporary ``geometry`` column is dropped before returning.

    Args:
        df: Input frame.
        lat_col: Name of the latitude column.
        lon_col: Name of the longitude column.

    Returns:
        The frame with ``planar_x`` and ``planar_y`` added and no ``geometry``.
    """
    points = gpd.points_from_xy(df[lon_col], df[lat_col])
    geo = gpd.GeoDataFrame(df, geometry=points, crs=4326)

    projected_points = geo.to_crs(2263).geometry
    geo["planar_x"] = projected_points.x
    geo["planar_y"] = projected_points.y

    return geo.drop(columns='geometry')

def calculate_distance_to_node(buses_with_nodes: gpd.GeoDataFrame) -> pd.Series:
    """Planar distance between each row's geometry and its ``planar_x``/``planar_y``.

    The frame's geometry is projected to EPSG:2263 and compared with the
    ``planar_x`` / ``planar_y`` columns; the result is in the units of that CRS.
    """
    projected_geom = buses_with_nodes.to_crs(2263).geometry
    dx = projected_geom.x - buses_with_nodes["planar_x"]
    dy = projected_geom.y - buses_with_nodes["planar_y"]
    return np.sqrt(dx**2 + dy**2)

def calculate_speeds(prepped_trips: pd.DataFrame, calculated_pair_path: str = "data/node_pairs.parquet", minimum_time_diff: int = 45) -> pd.DataFrame:
    """Join trips to cached node-pair distances and derive speeds.

    For each trip the elapsed seconds since the previous row is computed; rows
    with less than ``minimum_time_diff`` seconds (or no previous row) are
    dropped. The remainder is merged with the node-pair cache on the columns
    the two frames share, and ``speed_osm`` / ``speed_euclid`` are computed as
    (distance / 1609) per (seconds / 3600), i.e. miles per hour if the cached
    distances are in metres.

    The caller's ``prepped_trips`` is left unmodified (earlier versions added a
    ``time_diff_seconds`` column to it in place).

    Args:
        prepped_trips: Rows with ``trip_id``, ``timestamp`` and the node-pair
            key columns, ordered by time within each trip.
        calculated_pair_path: Parquet cache of node-pair distances.
        minimum_time_diff: Minimum elapsed seconds for a row to be kept.

    Returns:
        The merged frame with ``time_diff_seconds``, ``speed_osm`` and
        ``speed_euclid`` columns.
    """
    meters_per_mile = 1609
    seconds_per_hour = 3600

    node_pair_dists = pd.read_parquet(calculated_pair_path)

    elapsed = prepped_trips.groupby("trip_id")["timestamp"].diff().dt.total_seconds()
    # assign() copies, so the input frame does not gain a column as a side effect.
    timed_trips = prepped_trips.assign(time_diff_seconds=elapsed)
    timed_trips = timed_trips[timed_trips["time_diff_seconds"] >= minimum_time_diff]

    buses_with_distances = timed_trips.merge(node_pair_dists)
    hours = buses_with_distances["time_diff_seconds"] / seconds_per_hour
    buses_with_distances["speed_osm"] = (buses_with_distances["distance_osm"] / meters_per_mile) / hours
    buses_with_distances["speed_euclid"] = (buses_with_distances["distance_euclid"] / meters_per_mile) / hours

    return buses_with_distances

def explode_edges(row: pd.Series) -> pd.DataFrame:
    """Expand one row's ``shortest_path`` into one record per consecutive node pair.

    Args:
        row: Mapping-like row with an ``index`` value and a ``shortest_path``
            sequence of node ids.

    Returns:
        A frame with columns ``idx`` (the row's ``index`` as int), ``from`` and
        ``to``, one row per edge of the path. If the path is missing or
        unusable, a single all-``pd.NA`` row is returned instead.
    """
    try:
        nodes = row['shortest_path']
        return pd.DataFrame({
            'idx': [int(row['index'])] * (len(nodes) - 1),
            'from': nodes[:-1],
            'to': nodes[1:]
        })
    except Exception:
        # Best-effort fallback for rows without a usable path. `Exception`
        # (not a bare except) lets KeyboardInterrupt / SystemExit propagate.
        return pd.DataFrame({'idx': [pd.NA], 'from': [pd.NA], 'to': [pd.NA]})

def get_bus_stops(path: str = "/home/data/test/cities/C3562/stops.geojson") -> gpd.GeoDataFrame:
    """Load stops from a GeoJSON file and keep those served by the listed bus agencies.

    Stops are kept when ``agency_ids_serviced`` is one of the three agency
    strings below. Latitude/longitude columns are renamed to ``lat``/``lon``,
    ``stop_id`` is prefixed with ``"MTA_"``, and the result is passed through
    ``prep_coords`` to add planar coordinates.
    """
    served_by_bus = ["MTA NYCT", "MTABC", "MTA NYCT,MTABC"]
    wanted_cols = ["stop_id", "stop_name", "stop_lat", "stop_lon", "geometry"]

    all_stops = gpd.read_file(path)
    is_bus_stop = all_stops["agency_ids_serviced"].isin(served_by_bus)
    stops = all_stops[is_bus_stop][wanted_cols]
    stops = stops.rename(columns={"stop_lat": "lat", "stop_lon": "lon"})
    stops["stop_id"] = "MTA_" + stops["stop_id"]
    return prep_coords(stops, "lat", "lon")

def bus_stops_nodes(bus_stops: gpd.GeoDataFrame, tree: KDTree, nodes: gpd.GeoDataFrame, max_dist_to_node: float = 200) -> pd.DataFrame:
    """Match each bus stop to its nearest graph node and drop distant matches.

    Args:
        bus_stops: Stops with ``lat``/``lon`` and ``planar_x``/``planar_y``
            columns (as produced by ``get_bus_stops``).
        tree: KD-tree over the node coordinates.
        nodes: Node table the tree was built from.
        max_dist_to_node: Stops whose matched node is this far away or farther
            are discarded. Same units as ``calculate_distance_to_node``.
            Defaults to the previously hard-coded value of 200.

    Returns:
        Frame with ``stop_id``, ``stop_name``, ``osmid`` and ``dist_to_node``.
    """
    stops_with_nodes = tag_feed_with_nodes(bus_stops, tree, nodes)
    stops_with_nodes["dist_to_node"] = calculate_distance_to_node(stops_with_nodes)
    stops_with_nodes = stops_with_nodes[stops_with_nodes["dist_to_node"] < max_dist_to_node]
    return stops_with_nodes[["stop_id", "stop_name", "osmid", "dist_to_node"]]

def get_stop_pairs(bus_stops: pd.DataFrame, raw_GTFS_path: str) -> pd.DataFrame:
    """Derive the distinct consecutive (previous stop, next stop) pairs seen in a feed.

    The feed is read from ``raw_GTFS_path`` (via ``read_parquet_from_tar_gz``
    for ``.gz`` paths, otherwise as a parquet file whose ``vehicle.*`` columns
    are renamed and whose stop ids get the ``"MTA_"`` prefix). Within each
    trip, the first sighting of every ``next_stop_id`` is kept in time order
    and paired with the preceding one. Pairs are then joined to ``bus_stops``
    twice to attach the name and node id of both ends.

    Compared with the earlier version, the feed is no longer left-merged with
    ``bus_stops`` up front (none of the merged columns were kept), and pairs
    are de-duplicated before the two lookups instead of only afterwards. The
    resulting rows and their order are unchanged; only the index labels of the
    returned frame differ.

    Args:
        bus_stops: Stop table with ``stop_id``, ``stop_name`` and ``osmid``.
        raw_GTFS_path: Path or URL of the realtime feed.

    Returns:
        Distinct pairs with ``prev_stop_id``, ``next_stop_id``, the
        ``*_prev`` / ``*_next`` stop columns and integer ``prev_osmid`` /
        ``osmid``.
    """
    if raw_GTFS_path.endswith('.gz'):
        gtfs_rt = read_parquet_from_tar_gz(raw_GTFS_path)
    else:
        col_remappings = {
            "vehicle.trip.trip_id": "trip_id",
            "vehicle.timestamp": "timestamp",
            "vehicle.position.latitude": "lat",
            "vehicle.position.longitude": "lon",
            "vehicle.trip.route_id": "route_short",
            "vehicle.stop_id": "next_stop_id",
            "vehicle.vehicle.id": "vehicle_id"
        }
        gtfs_rt = pd.read_parquet(raw_GTFS_path).rename(columns=col_remappings)
        gtfs_rt["next_stop_id"] = "MTA_" + gtfs_rt["next_stop_id"]

    # First sighting of each upcoming stop per trip, in time order.
    visits = (
        gtfs_rt[["trip_id", "route_short", "timestamp", "next_stop_id"]]
        .sort_values(["trip_id", "timestamp"])
        .drop_duplicates(["trip_id", "next_stop_id"])
        .dropna()
    )
    visits["prev_stop_id"] = visits.groupby("trip_id")["next_stop_id"].shift(1)

    # Collapse to distinct pairs before the lookups so the joins stay small.
    stop_pairs = visits[["prev_stop_id", "next_stop_id"]].drop_duplicates()
    stop_lookup = bus_stops[["stop_id", "stop_name", "osmid"]]
    stop_pairs = stop_pairs.merge(stop_lookup, left_on="prev_stop_id", right_on="stop_id")
    stop_pairs = stop_pairs.merge(stop_lookup, left_on="next_stop_id", right_on="stop_id", suffixes=["_prev", "_next"])
    stop_pairs = stop_pairs.rename(columns={"osmid_next": "osmid", "osmid_prev": "prev_osmid"})
    stop_pairs["osmid"] = stop_pairs["osmid"].astype(int)
    stop_pairs["prev_osmid"] = stop_pairs["prev_osmid"].astype(int)

    return stop_pairs.drop_duplicates()

def get_pair_paths(stop_pairs: pd.DataFrame, G: nx.Graph, nodes: gpd.GeoDataFrame, calculated_pair_path: str = "data/node_pairs.parquet") -> pd.DataFrame:
    """Attach cached node-pair metrics (including the route) to each stop pair.

    The cache at ``calculated_pair_path`` is first topped up with any stop
    pairs it lacks, then re-read and merged onto the stop pairs using the
    columns the two frames have in common.
    """
    precalculate_node_pair_distances(stop_pairs[["osmid", "prev_osmid"]], calculated_pair_path, G, nodes)

    cached_pairs = pd.read_parquet(calculated_pair_path)
    wanted_cols = ["osmid", "prev_osmid", "next_stop_id", "prev_stop_id", "stop_name_prev", "stop_name_next"]
    return stop_pairs[wanted_cols].merge(cached_pairs)

def full_process_stops(tree: KDTree, nodes: gpd.GeoDataFrame, G: nx.Graph, GTFS_PATH: str, calculated_pair_path: str = "data/node_pairs.parquet", stops_path: str = "/home/data/test/cities/C3562/stops.geojson") -> pd.DataFrame:
    """Run the stop pipeline end to end and return the route between consecutive stops.

    Steps: load the stops, match them to graph nodes, derive consecutive stop
    pairs from the feed at ``GTFS_PATH``, and look up the cached route for
    each pair.

    Returns:
        Frame with ``next_stop_id``, ``prev_stop_id``, both stop names and the
        route, the latter renamed to ``shortest_path_stops``.
    """
    located_stops = bus_stops_nodes(get_bus_stops(stops_path), tree, nodes)
    pairs = get_stop_pairs(located_stops, GTFS_PATH)
    pairs_with_paths = get_pair_paths(pairs, G, nodes, calculated_pair_path)

    output_cols = ["next_stop_id", "prev_stop_id", "stop_name_prev", "stop_name_next", "shortest_path"]
    return pairs_with_paths[output_cols].rename(columns={"shortest_path": "shortest_path_stops"})

def check_in_bus_path(row: pd.Series) -> bool:
    """Tell whether the row's ``osmid`` lies on its ``shortest_path_stops`` route.

    Returns ``False`` when the route is not a list or NumPy array (for example
    when it is missing).
    """
    stop_route = row["shortest_path_stops"]
    if not isinstance(stop_route, (list, np.ndarray)):
        return False
    return row["osmid"] in stop_route

def process_gtfs_rt_main(tree: KDTree, nodes: gpd.GeoDataFrame, G: nx.Graph, gtfs_path: str, calculated_pair_path: str, out_path: str, stops_with_paths: pd.DataFrame = None, max_distance_to_node: float = 100) -> None:
    """Turn a GTFS-realtime feed into per-edge speed observations and write them to parquet.

    Pipeline: read the feed, add planar coordinates, tag each observation with
    its nearest node, reduce to node-to-node hops per trip, optionally keep
    only hops whose node lies on the route between the matching stop pair,
    make sure every hop has cached distances, compute speeds, and expand each
    hop's route into individual edges.

    Changes from the earlier version:
      * ``calculate_speeds`` now reads from ``calculated_pair_path``;
        previously the argument was omitted, so speeds were always read from
        that function's default cache path regardless of what was passed here;
      * an empty result no longer crashes in ``pd.concat``.

    Args:
        tree: KD-tree over the node coordinates.
        nodes: Node table the tree was built from.
        G: Road graph used for routing.
        gtfs_path: Feed location; ``.gz`` paths go through
            ``read_parquet_from_tar_gz``, anything else is read as parquet.
        calculated_pair_path: Parquet cache of node-pair distances.
        out_path: Destination parquet file.
        stops_with_paths: Optional stop-pair routes used to filter hops.
        max_distance_to_node: Maximum observation-to-node distance to keep.

    Returns:
        None. The result is written to ``out_path``.
    """
    print("Preprocessing bus data")
    if gtfs_path.endswith('.gz'):
        buses = read_parquet_from_tar_gz(gtfs_path)
    else:
        col_remappings = {
            "vehicle.trip.trip_id": "trip_id",
            "vehicle.timestamp": "timestamp",
            "vehicle.position.latitude": "lat",
            "vehicle.position.longitude": "lon",
            "vehicle.trip.route_id": "route_short",
            "vehicle.stop_id": "next_stop_id",
            "vehicle.vehicle.id": "vehicle_id"
        }
        buses = pd.read_parquet(gtfs_path).rename(columns=col_remappings)
        buses["next_stop_id"] = "MTA_" + buses["next_stop_id"]
    buses = prep_coords(buses, 'lat', 'lon')

    print("Tagging bus locations with nodes")
    buses_with_nodes = tag_feed_with_nodes(buses, tree, nodes)
    buses_with_nodes["distance_to_node"] = calculate_distance_to_node(buses_with_nodes)

    print("Calculating distance pairs")
    prepped_trips = prep_buses_nodes(buses_with_nodes, max_distance_to_node)

    if stops_with_paths is not None:
        print(f"Initial shape: {prepped_trips.shape}")
        # Joined on the columns the two frames share.
        prepped_trips = prepped_trips.merge(stops_with_paths)
        prepped_trips["in_bus_path"] = prepped_trips.apply(check_in_bus_path, axis=1)
        prepped_trips = prepped_trips[prepped_trips["in_bus_path"]]
        print(f"Shape after filtering: {prepped_trips.shape}")

    node_pair_df = prepped_trips[["osmid", "prev_osmid"]]

    precalculate_node_pair_distances(node_pair_df, calculated_pair_path, G, nodes)
    # Read speeds from the same cache that was just updated.
    buses_with_speeds = calculate_speeds(prepped_trips, calculated_pair_path).reset_index()

    print("Exploding edges")
    edge_frames = [explode_edges(row) for _, row in buses_with_speeds.iterrows()]
    if edge_frames:
        segment_speeds = pd.concat(edge_frames)
    else:
        # pd.concat([]) raises; fall back to an empty edge table with the same columns.
        segment_speeds = pd.DataFrame({'idx': [], 'from': [], 'to': []})
    bus_speed_segmented = buses_with_speeds.merge(segment_speeds, left_on="index", right_on="idx").drop(columns=["index"])

    print("Writing to parquet")
    bus_speed_segmented.to_parquet(out_path)



In [None]:
# Smoke test: read the public one-day sample archive and display whatever the reader returns.
# NOTE(review): the argument is a remote URL, and the identical call is repeated in a later cell.
read_parquet_from_tar_gz("https://urbantech-public.s3.amazonaws.com/DO-NOT-DELETE-BUSOBSERVATORY-PUBLIC-DATASET/one-system-day.tar.gz")

In [None]:
# Run configuration: input feed, node-pair cache, output file and stops file.
# The commented-out alternatives switch the run to the public sample archive / a test output.
#GTFS_PATH = "https://urbantech-public.s3.amazonaws.com/DO-NOT-DELETE-BUSOBSERVATORY-PUBLIC-DATASET/one-system-day.tar.gz"
# NOTE(review): absolute local path — only resolvable on the original machine.
GTFS_PATH = "/home/data/bus-weather/raw_bus_gtfs_rt_202230917_20230930.parquet"
# Same value as the lowercase `calculated_pair_path` set in the setup cell.
CALCULATED_PAIR_PATH = "data/node_pairs.parquet"
OUT_PATH = "data/buses_with_segmented_storm.parquet"
#OUT_PATH = "data/buses_test.parquet"
# Same literal as the default `stops_path` of get_bus_stops / full_process_stops.
STOPS_PATH = "/home/data/test/cities/C3562/stops.geojson"

In [None]:
# NOTE(review): verbatim repeat of the earlier sample-archive cell; one of the two can probably be removed.
read_parquet_from_tar_gz("https://urbantech-public.s3.amazonaws.com/DO-NOT-DELETE-BUSOBSERVATORY-PUBLIC-DATASET/one-system-day.tar.gz")

In [None]:
# Preview the raw feed at GTFS_PATH (reads the whole file just to display it).
pd.read_parquet(GTFS_PATH)

In [None]:
# Build the KD-tree, node table and road graph for get_node_data's default place.
tree, nodes, G = get_node_data()

In [None]:
# Routes between consecutive stops; reuses STOPS_PATH from the configuration cell
# instead of repeating the literal path, so the two cannot drift apart.
stops_with_paths = full_process_stops(tree, nodes, G, GTFS_PATH, calculated_pair_path = CALCULATED_PAIR_PATH, stops_path = STOPS_PATH)

In [None]:
# Display the stop-pair routes computed in the previous cell.
stops_with_paths

In [None]:
# Run the full feed pipeline; the result is written to OUT_PATH (the function returns None).
process_gtfs_rt_main(tree, nodes, G, GTFS_PATH, CALCULATED_PAIR_PATH, OUT_PATH, stops_with_paths)

In [None]:
# Read back and display the pipeline output.
gpd.read_parquet(OUT_PATH)

In [None]:
# Load the fresh output and an earlier output for side-by-side comparison.
new_buses = gpd.read_parquet(OUT_PATH)
# NOTE(review): absolute local path to the earlier result — only resolvable on the original machine.
old_buses = gpd.read_parquet("/home/canyon/Bus-Weather-Impacts/data/buses_with_segmented.parquet")

In [None]:
# Percentiles reported by the describe() calls in the following cells.
quantiles = [0.01, 0.1, 0.25, 0.5, 0.75, 0.9, .99]


In [None]:
# Summary statistics of the speed and ratio columns for the new output.
new_buses[["speed_osm", "speed_euclid", "dist_ratio"]].describe(percentiles=quantiles)

In [None]:
# Same summary for the earlier output, for comparison with the cell above.
old_buses[["speed_osm", "speed_euclid", "dist_ratio"]].describe(percentiles=quantiles)

In [None]:
# Histogram of speed_osm (below 70) for one specific edge.
# Uses `new_buses`, the table read back from OUT_PATH: the previous name
# (`bus_speed_segemented`) was misspelled and only ever existed as a local
# variable inside process_gtfs_rt_main, so this cell raised NameError.
new_buses.query("speed_osm < 70").query("`from` == 4209661118 & to == 4209661121.00")['speed_osm'].hist()

In [None]:
# NOTE(review): `segment_speeds` is only defined as a local inside process_gtfs_rt_main,
# so this cell raises NameError on a fresh kernel — leftover scratch cell? Confirm or remove.
segment_speeds.to_parquet("segments_test.parquet")

In [None]:
# NOTE(review): `prepped_trips` is only defined as a local inside process_gtfs_rt_main,
# so this cell raises NameError on a fresh kernel; it repeats the first step of
# calculate_speeds — leftover scratch cell? Confirm or remove.
prepped_trips["time_diff_seconds"] = prepped_trips.groupby("trip_id")["timestamp"].diff().dt.total_seconds()