In [93]:
import datetime
import logging
import multiprocessing as mp
from functools import partial
from pathlib import Path
from typing import Union

import geopandas as gpd
import pandas as pd
from geopy.distance import great_circle
from gtfsblocks import Feed, filter_blocks_by_route
from mappymatch.constructs.geofence import Geofence
from mappymatch.constructs.trace import Trace
from mappymatch.maps.nx.nx_map import NetworkType, NxMap
from mappymatch.matchers.lcss.lcss import LCSSMatcher

from nrel.routee.transit.prediction.grade.add_grade import run_gradeit_parallel
from nrel.routee.transit.prediction.grade.tile_resolution import TileResolution
from nrel.routee.transit.prediction.create_depot_deadhead_trips import create_depot_deadhead_trips
from nrel.routee.transit.prediction.create_depot_deadhead_stops import create_depot_deadhead_stops
from nrel.routee.transit.prediction.create_betweenTrip_deadhead_trips import create_betweenTrip_deadhead_trips
from nrel.routee.transit.prediction.create_betweenTrip_deadhead_stops import create_betweenTrip_deadhead_stops
from nrel.routee.transit.prediction.add_depot_to_blocks import add_depot_to_blocks  
from nrel.routee.transit.prediction.generate_deadhead_traces import add_deadhead_trips

In [94]:
# Per-table GTFS column lists handed to Feed.from_dir through its `columns` argument.
req_cols = dict(
    stop_times=[
        "arrival_time",
        "departure_time",
        "shape_dist_traveled",
        "stop_id",
    ],
    shapes=["shape_dist_traveled"],
)
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
gtfs_dir = '/Users/yhe/github_repo/routee-transit/sample-inputs/saltlake/gtfs'
feed = Feed.from_dir(gtfs_dir, columns=req_cols)

In [95]:
feed.stops

Unnamed: 0,stop_id,stop_lat,stop_lon
0,391,40.663345,-111.987053
1,506,40.660280,-111.987056
2,870,40.692636,-111.958136
3,887,40.693910,-111.957789
4,1022,40.772535,-111.837594
...,...,...,...
5229,25407,40.718417,-111.983917
5230,25410,40.580596,-111.829459
5231,25421,41.094656,-112.013359
5232,25423,40.983302,-111.911566


In [96]:
feed.trips

Unnamed: 0,trip_id,route_id,service_id,block_id,shape_id
0,5182559,19930,2,1155937,226296
1,5182560,19930,2,1155939,226296
2,5182561,19930,2,1155941,226296
3,5182562,19930,2,1155938,226296
4,5182563,19930,2,1155943,226296
...,...,...,...,...,...
14902,5206452,91962,22,1157708,227290
14903,5206453,91962,22,1157709,227290
14904,5206454,91962,22,1157708,227290
14905,5206455,91962,22,1157709,227290


In [99]:
df_stop_times = feed.stop_times 

In [100]:
# Earliest and latest arrival_time per trip_id.
df_trip_time = (
    df_stop_times.groupby("trip_id")
    .agg(
        start_time=pd.NamedAgg(column="arrival_time", aggfunc="min"),
        end_time=pd.NamedAgg(column="arrival_time", aggfunc="max"),
    )
    .reset_index()
)

In [101]:
df_trip_time

Unnamed: 0,trip_id,start_time,end_time
0,5167455,0 days 11:30:00,0 days 13:24:00
1,5167457,0 days 21:00:00,0 days 22:54:00
2,5167458,0 days 16:00:00,0 days 18:03:00
3,5167459,0 days 20:30:00,0 days 22:24:00
4,5167460,0 days 15:30:00,0 days 17:33:00
...,...,...,...
14902,5231656,0 days 22:09:00,0 days 22:18:00
14903,5231657,0 days 22:39:00,0 days 22:48:00
14904,5231658,0 days 23:09:00,0 days 23:18:00
14905,5231659,0 days 05:25:00,0 days 09:35:00


In [102]:
# Trip start time expressed as fractional hours since midnight of the service day.
start_seconds = df_trip_time["start_time"].dt.total_seconds()
start_hours = start_seconds / 3600

In [103]:
start_hours

0        11.500000
1        21.000000
2        16.000000
3        20.500000
4        15.500000
           ...    
14902    22.150000
14903    22.650000
14904    23.150000
14905     5.416667
14906    16.416667
Name: start_time, Length: 14907, dtype: float64

In [24]:
# Manual selection of every trip on route_id '27614'.
# NOTE: the following cell rebinds `trips_df` whenever `date_incl` is set, so this
# selection is discarded on a top-to-bottom run.
all_trips = feed.trips
trips_df = all_trips[all_trips['route_id'] == '27614']

In [25]:
# Service date and route short names used to restrict the feed (None disables a filter).
date_incl = "2023/08/02"
routes_incl = ["205"]

# Select trips by service date, or fall back to every service_id present in the feed.
if date_incl is None:
    trips_df = feed.get_trips_from_sids(feed.trips.service_id.unique().tolist())
else:
    trips_df = feed.get_trips_from_date(date_incl)
    if trips_df.shape[0] == 0:
        raise ValueError(f"Feed does not contain any trips on {date_incl}")

# Optionally narrow the selection to the requested routes.
if routes_incl is not None:
    trips_df = filter_blocks_by_route(
        trips=trips_df,
        routes=routes_incl,
        route_column="route_short_name",
        route_method="exclusive",
    )
    if trips_df.shape[0] == 0:
        raise ValueError(
            "There are no active trips on your selected routes and date."
        )

# Keep only the shapes referenced by the remaining trips.
shapes_incl = trips_df.shape_id.unique()
shapes_df = feed.shapes[feed.shapes.shape_id.isin(shapes_incl)]

In [26]:
len(trips_df)

56

In [28]:
trips_df

Unnamed: 0,trip_id,route_id,service_id,block_id,shape_id,route_short_name,route_type,route_desc,agency_id
6065,5170968,27614,4,1155702,226466,205,3,,
6066,5170969,27614,4,1155703,226466,205,3,,
6067,5170970,27614,4,1155704,226466,205,3,,
6069,5170972,27614,4,1155700,226466,205,3,,
6070,5170973,27614,4,1155705,226466,205,3,,
6071,5170974,27614,4,1155701,226466,205,3,,
6072,5170975,27614,4,1155702,226466,205,3,,
6073,5170976,27614,4,1155703,226466,205,3,,
6074,5170977,27614,4,1155704,226466,205,3,,
6076,5170979,27614,4,1155700,226466,205,3,,


In [29]:
deadhead_trips_df = create_depot_deadhead_trips(trips_df)

In [34]:
depot_directory = '/Users/yhe/github_repo/routee-transit/FTA_Depot/Transit_Depot.shp'

In [35]:
def _buffered_bbox(od_gdf, lat_buffer_deg=0.018, lon_buffer_deg=0.022):
    """Return [minx, miny, maxx, maxy] around all origin/destination points.

    The box is padded by the given buffers (roughly 2 km expressed in degrees).
    """
    endpoints = pd.concat([od_gdf["geometry_origin"], od_gdf["geometry_destination"]])
    xs = endpoints.apply(lambda pt: pt.x)
    ys = endpoints.apply(lambda pt: pt.y)
    return [
        xs.min() - lon_buffer_deg,
        ys.min() - lat_buffer_deg,
        xs.max() + lon_buffer_deg,
        ys.max() + lat_buffer_deg,
    ]


first_stops_gdf, last_stops_gdf = add_depot_to_blocks(
    trips_df, feed, path_to_depots=depot_directory
)
deadhead_stop_times_df, deadhead_stops_df = create_depot_deadhead_stops(
    first_stops_gdf, last_stops_gdf, deadhead_trips_df
)

# Deadhead shapes for the trips from the depot to each block's first stop
from_depot_deadhead_shapes_df = add_deadhead_trips(
    df=first_stops_gdf,
    n_processes=1,
    bbox=_buffered_bbox(first_stops_gdf),
)
from_depot_deadhead_shapes_df["shape_id"] = from_depot_deadhead_shapes_df[
    "shape_id"
].map(lambda sid: 'from_depot_' + sid)

# Deadhead shapes for the trips from each block's last stop back to the depot
to_depot_deadhead_shapes_df = add_deadhead_trips(
    df=last_stops_gdf,
    n_processes=1,
    bbox=_buffered_bbox(last_stops_gdf),
)
to_depot_deadhead_shapes_df["shape_id"] = to_depot_deadhead_shapes_df[
    "shape_id"
].map(lambda sid: 'to_depot_' + sid)

# Combine all deadhead shapes
deadhead_shapes_df = pd.concat(
    [from_depot_deadhead_shapes_df, to_depot_deadhead_shapes_df],
    ignore_index=True,
)

# Keep only the deadhead trips for which a shape was generated.
# Presumably some trips get no shape when origin and destination coincide; TODO confirm.
generated_shape_ids = deadhead_shapes_df["shape_id"].unique()
deadhead_trips_df = deadhead_trips_df[
    deadhead_trips_df["shape_id"].isin(generated_shape_ids)
]

# Append the deadhead records to the working tables and to the feed.
# NOTE: this rebinds the feed tables, so re-running the cell appends the same rows again.
trips_df_1 = pd.concat([trips_df, deadhead_trips_df], ignore_index=True)
shapes_df = pd.concat([shapes_df, deadhead_shapes_df], ignore_index=True)
feed.trips = pd.concat([feed.trips, deadhead_trips_df], ignore_index=True)
feed.shapes = pd.concat([feed.shapes, deadhead_shapes_df], ignore_index=True)
feed.stop_times = pd.concat(
    [feed.stop_times, deadhead_stop_times_df], ignore_index=True
)
feed.stops = pd.concat([feed.stops, deadhead_stops_df], ignore_index=True)

In [36]:
# `stop_times_df` is not defined anywhere in this notebook, so the original call raised a
# NameError on a fresh kernel. Use `df_stop_times`, the stop_times table taken from the
# feed in an earlier cell.
# NOTE(review): confirm this is the stop_times table the helper is meant to receive.
betweenTrip_deadhead_trips_df = create_betweenTrip_deadhead_trips(trips_df, df_stop_times)

# Create between-trip deadhead stop_times and stops, plus the origin/destination pairs
(
    betweenTrip_deadhead_stop_times_df,
    betweenTrip_deadhead_stops_df,
    betweenTrip_ODs,
) = create_betweenTrip_deadhead_stops(feed, betweenTrip_deadhead_trips_df)

# Bounding box around every between-trip origin and destination point
all_points = pd.concat(
    [betweenTrip_ODs['geometry_origin'], betweenTrip_ODs['geometry_destination']]
)
lons = all_points.apply(lambda p: p.x)
lats = all_points.apply(lambda p: p.y)
min_lon, max_lon = lons.min(), lons.max()
min_lat, max_lat = lats.min(), lats.max()
buffer_deg_lat = 0.018     # Roughly 2 km buffer in degrees
buffer_deg_lon = 0.022     # Roughly 2 km buffer in degrees
miny = min_lat - buffer_deg_lat
maxy = max_lat + buffer_deg_lat
minx = min_lon - buffer_deg_lon
maxx = max_lon + buffer_deg_lon

# Remove ODs with the same origin and destination: there is nothing to route for them
betweenTrip_ODs = betweenTrip_ODs[
    betweenTrip_ODs.geometry_origin != betweenTrip_ODs.geometry_destination
]

# Generate deadhead shapes for the repositioning moves between consecutive trips
betweenTrip_deadhead_shapes_df = add_deadhead_trips(
    df=betweenTrip_ODs,
    n_processes=1,
    bbox=[minx, miny, maxx, maxy],
)

# Keep only the between-trip deadhead trips for which a shape was generated
# (pairs with identical origin and destination were removed above).
betweenTrip_deadhead_trips_df = betweenTrip_deadhead_trips_df[
    betweenTrip_deadhead_trips_df['shape_id'].isin(
        betweenTrip_deadhead_shapes_df['shape_id'].unique()
    )
]

# Append the between-trip deadhead records to the working tables and to the feed.
# NOTE: this rebinds the feed tables, so re-running the cell appends the same rows again.
trips_df_2 = pd.concat([trips_df_1, betweenTrip_deadhead_trips_df], ignore_index=True)
shapes_df = pd.concat([shapes_df, betweenTrip_deadhead_shapes_df], ignore_index=True)
feed.trips = pd.concat([feed.trips, betweenTrip_deadhead_trips_df], ignore_index=True)
feed.shapes = pd.concat([feed.shapes, betweenTrip_deadhead_shapes_df], ignore_index=True)
feed.stop_times = pd.concat(
    [feed.stop_times, betweenTrip_deadhead_stop_times_df], ignore_index=True
)
feed.stops = pd.concat([feed.stops, betweenTrip_deadhead_stops_df], ignore_index=True)

In [50]:
test = feed.shapes
len(test['shape_id'].unique())

274

In [54]:
len(shapes_df['shape_id'].unique())

38

In [55]:
shapes_df['shape_id'].unique()

array(['226466', '226467', 'from_depot_1155700', 'from_depot_1155701',
       'from_depot_1155702', 'from_depot_1155703', 'from_depot_1155704',
       'from_depot_1155705', 'to_depot_1155700', 'to_depot_1155701',
       'to_depot_1155702', 'to_depot_1155703', 'to_depot_1155704',
       'to_depot_1155705', '5171000_to_5170972', '5171006_to_5170979',
       '5171013_to_5170986', '5171020_to_5170992', '5171026_to_5170998',
       '5171001_to_5170974', '5171008_to_5170981', '5171015_to_5170988',
       '5171022_to_5170994', '5171002_to_5170975', '5171009_to_5170982',
       '5171016_to_5170989', '5171023_to_5170995', '5171003_to_5170976',
       '5171010_to_5170983', '5171017_to_5170990', '5171024_to_5170996',
       '5171004_to_5170977', '5171011_to_5170984', '5171018_to_5170991',
       '5171025_to_5170997', '5171007_to_5170980', '5171014_to_5170987',
       '5171021_to_5170993'], dtype=object)

In [39]:
def upsample_shape(shape_df: pd.DataFrame, speed_kmh: float = 30) -> pd.DataFrame:
    """Upsample a GTFS shape DataFrame to generate a roughly 1 Hz GPS trace.

    Interpolates latitude, longitude, and distance traveled, assuming a constant speed.
    The function performs the following steps:

    * Calculates the distance between consecutive shape points using great-circle distance
    * Computes the cumulative distance traveled along the shape (in kilometers)
    * Assigns timestamps based on the constant speed ``speed_kmh``
    * Resamples and interpolates the shape to 1-second intervals
    * Drops the helper timestamps again before returning

    The input frame is not modified. A shape with zero total length (a single point,
    or only repeated points) yields a single-row result instead of raising.

    Args:
        shape_df: DataFrame containing GTFS shape points with columns
            'shape_pt_lat', 'shape_pt_lon', and 'shape_id'. Rows are expected
            to be in travel order and to belong to a single shape.
        speed_kmh: Assumed constant travel speed in km/h. Defaults to 30, which
            is about 8.33 m between consecutive 1 Hz points.

    Returns:
        Upsampled DataFrame with columns 'shape_pt_lat', 'shape_pt_lon',
        'shape_dist_traveled' (kilometers, recomputed from the coordinates) and
        'shape_id', with a fresh 0..n-1 index and one row per second of travel.
    """
    # Work on a copy so the caller's frame (often a groupby slice) is not mutated.
    pts = shape_df.copy()

    # Shift latitude and longitude to get the previous point
    pts["prev_latitude"] = pts["shape_pt_lat"].shift()
    pts["prev_longitude"] = pts["shape_pt_lon"].shift()

    # Distance between consecutive points; the first point has no predecessor.
    # TODO: move away from apply() for speed
    pts["distance_km"] = pts.apply(
        lambda row: great_circle(
            (row["prev_latitude"], row["prev_longitude"]),  # Previous point
            (row["shape_pt_lat"], row["shape_pt_lon"]),  # Current point
        ).kilometers
        if pd.notnull(row["prev_latitude"])
        else 0.0,
        axis=1,
    )

    total_distance_km = pts["distance_km"].sum()

    # Use the calculated cumulative distance instead of the feed's shape_dist_traveled
    pts["shape_dist_traveled"] = pts["distance_km"].cumsum()

    total_duration = datetime.timedelta(
        seconds=round(total_distance_km / speed_kmh * 3600)
    )
    if total_distance_km > 0:
        progress = pts["shape_dist_traveled"] / pts["shape_dist_traveled"].max()
    else:
        # Zero-length shape: 0/0 would give NaN and break the rounding below.
        progress = pts["shape_dist_traveled"] * 0.0

    # Elapsed time at each point, rounded to whole seconds
    pts["segment_duration_delta"] = (progress * total_duration).apply(
        lambda x: datetime.timedelta(seconds=round(x.total_seconds()))
    )

    # Anchor on an arbitrary date to convert from timedelta to datetime
    date_tmp = datetime.datetime(2023, 9, 3)
    pts["timestamp"] = pts["segment_duration_delta"] + date_tmp

    # Upsample to 1 s; points that share a rounded timestamp collapse to the first one
    shape_id_tmp = pts.shape_id.iloc[0]
    trace = (
        pts[["shape_pt_lat", "shape_pt_lon", "timestamp", "shape_dist_traveled"]]
        .drop_duplicates(subset=["timestamp"])
        .set_index("timestamp")
    )
    if len(trace) > 1:
        trace = trace.resample("1s").interpolate(method="linear")

    # The timestamps only served as the interpolation axis and are dropped here
    trace = trace.reset_index(drop=True)
    trace["shape_id"] = shape_id_tmp

    return trace

In [40]:
# One DataFrame per shape_id, in groupby key order.
shapes_by_id = shapes_df.groupby("shape_id")
df_shape_list = [shape_points for _shape_id, shape_points in shapes_by_id]

In [42]:
# Upsample every shape, keeping the order of df_shape_list.
upsampled_shapes_list = [upsample_shape(one_shape) for one_shape in df_shape_list]

In [51]:
len(upsampled_shapes_list)

38

In [89]:
upsampled_shapes_list[-20]

Unnamed: 0,shape_pt_lat,shape_pt_lon,shape_dist_traveled,shape_id
0,40.761126,-111.939181,0.0,5171018_to_5170991
1,40.761127,-111.939095,0.007272,5171018_to_5170991
2,40.761127,-111.939008,0.014543,5171018_to_5170991
3,40.761128,-111.938922,0.021815,5171018_to_5170991


In [44]:
def match_shape_to_osm(upsampled_shape_df: pd.DataFrame) -> pd.DataFrame:
    """Map-match an upsampled GTFS shape to the OpenStreetMap drive network.

    A mappymatch ``Trace`` is built from the shape points, a geofence is drawn
    around it (``padding=1e3``), the drivable road network inside the geofence is
    loaded, and the LCSS matcher is run on the trace. The match table is then
    placed next to the input columns, aligned on the index. The docstring this
    replaces recommended upsampling the shape to roughly 1 Hz / 8 m beforehand
    for best matching results.

    Args:
        upsampled_shape_df (pd.DataFrame): Shape points with latitude and
            longitude in the "shape_pt_lat" and "shape_pt_lon" columns.
    Returns:
        pd.DataFrame: The input columns followed by the columns of the matcher's
            result table.
    """
    gps_trace = Trace.from_dataframe(
        upsampled_shape_df,
        lat_column="shape_pt_lat",
        lon_column="shape_pt_lon",
    )
    # Pull only the part of the road network that surrounds this trace
    road_network = NxMap.from_geofence(
        Geofence.from_trace(gps_trace, padding=1e3),
        network_type=NetworkType.DRIVE,
    )
    match_table = LCSSMatcher(road_network).match_trace(gps_trace).matches_to_dataframe()
    # Column-wise concat: rows are paired by index label
    return pd.concat([upsampled_shape_df, match_table], axis=1)

In [45]:
# Map-match every upsampled shape, keeping the order of upsampled_shapes_list.
matched_shapes_list = [match_shape_to_osm(trace_df) for trace_df in upsampled_shapes_list]

  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)
  df = df.fillna(np.nan)


In [46]:
matched_shapes_df = pd.concat(matched_shapes_list)

In [91]:
matched_shapes_df[matched_shapes_df['shape_id'] == '5171018_to_5170991']

Unnamed: 0,shape_pt_lat,shape_pt_lon,shape_dist_traveled,shape_id,coordinate_id,distance_to_road,road_id,geom,origin_junction_id,destination_junction_id,road_key,kilometers,travel_time
0,40.761126,-111.939181,0.0,5171018_to_5170991,0,,,,,,,,
1,40.761127,-111.939095,0.007272,5171018_to_5170991,1,,,,,,,,
2,40.761127,-111.939008,0.014543,5171018_to_5170991,2,,,,,,,,
3,40.761128,-111.938922,0.021815,5171018_to_5170991,3,,,,,,,,


In [52]:
len(matched_shapes_df['shape_id'].unique())

38

In [56]:
matched_shapes_df['shape_id'].unique()

array(['226466', '226467', '5171000_to_5170972', '5171001_to_5170974',
       '5171002_to_5170975', '5171003_to_5170976', '5171004_to_5170977',
       '5171006_to_5170979', '5171007_to_5170980', '5171008_to_5170981',
       '5171009_to_5170982', '5171010_to_5170983', '5171011_to_5170984',
       '5171013_to_5170986', '5171014_to_5170987', '5171015_to_5170988',
       '5171016_to_5170989', '5171017_to_5170990', '5171018_to_5170991',
       '5171020_to_5170992', '5171021_to_5170993', '5171022_to_5170994',
       '5171023_to_5170995', '5171024_to_5170996', '5171025_to_5170997',
       '5171026_to_5170998', 'from_depot_1155700', 'from_depot_1155701',
       'from_depot_1155702', 'from_depot_1155703', 'from_depot_1155704',
       'from_depot_1155705', 'to_depot_1155700', 'to_depot_1155701',
       'to_depot_1155702', 'to_depot_1155703', 'to_depot_1155704',
       'to_depot_1155705'], dtype=object)

In [57]:
def extend_trip_traces(
    trips_df: pd.DataFrame,
    matched_shapes_df: pd.DataFrame,
    feed: Feed,
    add_stop_flag: bool = False,
    n_processes: int | None = mp.cpu_count(),
) -> pd.DataFrame:
    """Extend trip shapes with stop details and estimated timestamps from GTFS.

    This function processes GTFS trip and shape data to:

    * Summarize stop times for each trip (first/last stop and times)
    * Merge stop time summaries into the trips DataFrame
    * Merge trip and shape data to create ordered trip traces
    * Optionally, attach stop indicators to shape trace points
    * Estimate timestamps for each trace point based on scheduled trip duration and distance

    Args:
        trips_df: DataFrame containing trip information, including
            'trip_id' and 'shape_id'.
        matched_shapes_df: DataFrame with shape points matched to trips,
            including 'shape_id' and 'shape_dist_traveled'.
        feed: GTFS feed object containing 'stop_times' and 'stops'
            DataFrames.
        add_stop_flag: If True, attaches stop indicators to shape trace
            points. Defaults to False.
        n_processes: Number of processes to run in parallel using
            multiprocessing. Defaults to mp.cpu_count(). With a value of 1 the
            work runs serially in the current process and no pool is created.

    Returns:
        A single DataFrame with the traces of all trips concatenated, including
        the estimated timestamps.
    """
    # Resolve the logger here: `logging` is imported but no module-level `logger`
    # is defined in this notebook, so a bare `logger.info` raised a NameError.
    trace_logger = logging.getLogger(__name__)

    def _map_over_trips(func, frames):
        """Apply ``func`` to each frame, serially when ``n_processes == 1``."""
        if n_processes == 1:
            return [func(frame) for frame in frames]
        with mp.Pool(n_processes) as pool:
            return pool.map(func, frames)

    # Start by summarizing stop times: get first and last stop, plus start/end times.
    # NOTE(review): "first"/"last" follow row order, so this assumes each trip's
    # stop_times rows are already in stop order; confirm for the feed in use.
    stop_times_by_trip = (
        feed.stop_times.groupby("trip_id")
        .agg(
            {
                "arrival_time": "first",
                "departure_time": "last",
                "stop_id": ["first", "last"],
            }
        )
        .reset_index()
    )
    stop_times_by_trip.columns = [
        "trip_id",
        "o_time",
        "d_time",
        "o_stop_id",
        "d_stop_id",
    ]

    # Add start/end times and stops to trips DF
    # TODO: consider doing this with gtfsblocks add_trip_data()
    trips_df = pd.merge(trips_df, stop_times_by_trip, how="left", on="trip_id")
    trips_df["o_time"] = pd.to_timedelta(trips_df["o_time"])
    trips_df["d_time"] = pd.to_timedelta(trips_df["d_time"])
    trips_df["trip_duration"] = trips_df["d_time"] - trips_df["o_time"]

    # Pair every trip with the points of its shape, ordered along the shape.
    # The left merge keeps trips whose shape has no matched points (NaN columns).
    # TODO: I think this big merge can be avoided
    trip_shape = pd.merge(
        trips_df[["trip_id", "shape_id", "o_time", "d_time"]],
        matched_shapes_df,
        how="left",
        on="shape_id",
    )
    trip_shape = trip_shape.sort_values(
        by=["trip_id", "shape_dist_traveled"]
    ).reset_index(drop=True)
    trip_shapes_list = [item for _, item in trip_shape.groupby("trip_id")]

    # Attach stops to shape traces. Note that this just adds a dummy variable column
    # indicating whether or not a stop is located at a given point on the shape.
    if add_stop_flag:
        # Stop coordinates are only needed for the stop flags, so build them here
        stop_times_ext = feed.stop_times[
            ["trip_id", "stop_sequence", "stop_id"]
        ].merge(feed.stops[["stop_id", "stop_lat", "stop_lon"]], on="stop_id")
        # NOTE(review): add_stop_flags_to_shape is not defined in the visible part
        # of this notebook; it must be in scope before enabling add_stop_flag.
        attach_stop_partial = partial(
            add_stop_flags_to_shape, stop_times_ext=stop_times_ext
        )
        trip_shapes_list = _map_over_trips(attach_stop_partial, trip_shapes_list)

    # Attach timestamps to each trip. These are simply based on the scheduled trip
    # duration and shape_dist_traveled, assuming a constant speed for the entire trip.
    # TODO: improve timestamp estimates
    trips_with_timestamps_list = _map_over_trips(
        estimate_trip_timestamps, trip_shapes_list
    )
    trace_logger.info("Finished attaching timestamps")
    return pd.concat(trips_with_timestamps_list)

In [59]:
def estimate_trip_timestamps(trip_shape_df: pd.DataFrame) -> pd.DataFrame:
    """Estimate timestamps for each shape point of a trip based on distance traveled.

    The scheduled trip duration (``d_time - o_time``) is spread over the points in
    proportion to ``shape_dist_traveled``, i.e. a constant speed is assumed. The
    input frame is left untouched; a modified copy is returned. Rows whose distance
    or times are missing get missing (NaT/NaN) results instead of raising.

    Args:
        trip_shape_df (pd.DataFrame): DataFrame containing trip shape data with columns:
            - 'shape_dist_traveled': Cumulative distance traveled along the shape.
            - 'o_time': Origin time of the trip, as a timedelta since the start of
              the service day.
            - 'd_time': Destination time of the trip, as a timedelta since the
              start of the service day.
    Returns:
        pd.DataFrame: Copy of the input with additional columns:
            - 'segment_duration_delta': Elapsed time since the trip start at each
              point, rounded to whole seconds.
            - 'timestamp': Estimated time (timedelta) at each point.
            - 'Datetime_nearest5': Timestamp rounded to the nearest 5 minutes.
            - 'hour': Hour component of the rounded timestamp.
            - 'minute': Minute component of the rounded timestamp.
    """
    # Copy first: callers typically pass groupby slices that should stay unchanged.
    result = trip_shape_df.copy()

    # Fraction of the trip completed at each point. The small constant in the
    # denominator keeps zero-length shapes from dividing by zero.
    progress = result["shape_dist_traveled"] / (
        result["shape_dist_traveled"].max() + 0.0001
    )
    scheduled_duration = result["d_time"] - result["o_time"]

    # Round to whole seconds with the vectorized accessor, which tolerates NaT
    result["segment_duration_delta"] = (progress * scheduled_duration).dt.round("1s")
    result["timestamp"] = result["o_time"] + result["segment_duration_delta"]

    # Hour and minute of the timestamp, snapped to the nearest 5 minutes
    result["Datetime_nearest5"] = result["timestamp"].dt.round("5min")
    nearest5_parts = result["Datetime_nearest5"].dt.components
    result["hour"] = nearest5_parts["hours"]
    result["minute"] = nearest5_parts["minutes"]

    return result

In [61]:
# Per-trip summary of stop_times: first arrival, last departure, first and last stop
stop_times_by_trip = (
    feed.stop_times.groupby("trip_id")
    .agg(
        o_time=("arrival_time", "first"),
        d_time=("departure_time", "last"),
        o_stop_id=("stop_id", "first"),
        d_stop_id=("stop_id", "last"),
    )
    .reset_index()
)

# Attach start/end times and stops to the trips (service trips plus deadheads)
# TODO: consider doing this with gtfsblocks add_trip_data()
# NOTE: this rebinds `trips_df`, which earlier cells use for the service trips only.
trips_df = trips_df_2.merge(stop_times_by_trip, how="left", on="trip_id")
for time_col in ("o_time", "d_time"):
    trips_df[time_col] = pd.to_timedelta(trips_df[time_col])
trips_df["trip_duration"] = trips_df["d_time"] - trips_df["o_time"]

# Stop coordinates joined onto stop_times
stop_coords = feed.stops[["stop_id", "stop_lat", "stop_lon"]]
stop_times_ext = feed.stop_times[["trip_id", "stop_sequence", "stop_id"]].merge(
    stop_coords, on="stop_id"
)

# Pair every trip with the matched points of its shape, ordered along the shape
# TODO: I think this big merge can be avoided
trip_keys = trips_df[["trip_id", "shape_id", "o_time", "d_time"]]
trip_shape = (
    trip_keys.merge(matched_shapes_df, how="left", on="shape_id")
    .sort_values(by=["trip_id", "shape_dist_traveled"])
    .reset_index(drop=True)
)
trip_shapes_list = [trace for _trip_id, trace in trip_shape.groupby("trip_id")]

In [63]:
# Estimate timestamps trip by trip, keeping the order of trip_shapes_list.
trips_with_timestamps_list = [
    estimate_trip_timestamps(one_trip_trace) for one_trip_trace in trip_shapes_list
]


In [64]:
trips_df_ext = pd.concat(trips_with_timestamps_list)

In [78]:
len(trips_df_ext['trip_id'].unique())

92

In [81]:
trips_df_ext['trip_id'].unique()

array(['5170968', '5170969', '5170970', '5170972', '5170973', '5170974',
       '5170975', '5170976', '5170977', '5170979', '5170980', '5170981',
       '5170982', '5170983', '5170984', '5170986', '5170987', '5170988',
       '5170989', '5170990', '5170991', '5170992', '5170993', '5170994',
       '5170995', '5170996', '5170997', '5170998', '5171000',
       '5171000_to_5170972', '5171001', '5171001_to_5170974', '5171002',
       '5171002_to_5170975', '5171003', '5171003_to_5170976', '5171004',
       '5171004_to_5170977', '5171006', '5171006_to_5170979', '5171007',
       '5171007_to_5170980', '5171008', '5171008_to_5170981', '5171009',
       '5171009_to_5170982', '5171010', '5171010_to_5170983', '5171011',
       '5171011_to_5170984', '5171013', '5171013_to_5170986', '5171014',
       '5171014_to_5170987', '5171015', '5171015_to_5170988', '5171016',
       '5171016_to_5170989', '5171017', '5171017_to_5170990', '5171018',
       '5171018_to_5170991', '5171020', '5171020_to_5170992', 

In [84]:
trips_df_ext[trips_df_ext['trip_id'] == '5171030_to_depot']

Unnamed: 0,trip_id,shape_id,o_time,d_time,shape_pt_lat,shape_pt_lon,shape_dist_traveled,coordinate_id,distance_to_road,road_id,...,origin_junction_id,destination_junction_id,road_key,kilometers,travel_time,segment_duration_delta,timestamp,Datetime_nearest5,hour,minute
169378,5171030_to_depot,to_depot_1155703,0 days 21:37:00,0 days 21:41:43.509442594,40.760645,-111.939030,0.000000,0,0.000000e+00,"(83634074, 10031770361, 0)",...,83634074.0,1.003177e+10,0.0,0.202348,15.088042,0 days 00:00:00,0 days 21:37:00,0 days 21:35:00,21,35
169379,5171030_to_depot,to_depot_1155703,0 days 21:37:00,0 days 21:41:43.509442594,40.760645,-111.938934,0.008056,1,9.306796e-10,"(83634074, 10031770361, 0)",...,83634074.0,1.003177e+10,0.0,0.202348,15.088042,0 days 00:00:01,0 days 21:37:01,0 days 21:35:00,21,35
169380,5171030_to_depot,to_depot_1155703,0 days 21:37:00,0 days 21:41:43.509442594,40.760645,-111.938839,0.016112,2,0.000000e+00,"(83634074, 10031770361, 0)",...,83634074.0,1.003177e+10,0.0,0.202348,15.088042,0 days 00:00:02,0 days 21:37:02,0 days 21:35:00,21,35
169381,5171030_to_depot,to_depot_1155703,0 days 21:37:00,0 days 21:41:43.509442594,40.760645,-111.938735,0.024881,3,2.226966e-03,"(83634074, 10031770361, 0)",...,83634074.0,1.003177e+10,0.0,0.202348,15.088042,0 days 00:00:02,0 days 21:37:02,0 days 21:35:00,21,35
169382,5171030_to_depot,to_depot_1155703,0 days 21:37:00,0 days 21:41:43.509442594,40.760645,-111.938630,0.033650,4,1.781574e-03,"(83634074, 10031770361, 0)",...,83634074.0,1.003177e+10,0.0,0.202348,15.088042,0 days 00:00:03,0 days 21:37:03,0 days 21:35:00,21,35
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
169725,5171030_to_depot,to_depot_1155703,0 days 21:37:00,0 days 21:41:43.509442594,40.762832,-111.911768,2.895650,347,1.363034e-08,"(359380575, 83650176, 0)",...,359380575.0,8.365018e+07,0.0,0.254002,24.350907,0 days 00:04:40,0 days 21:41:40,0 days 21:40:00,21,40
169726,5171030_to_depot,to_depot_1155703,0 days 21:37:00,0 days 21:41:43.509442594,40.762834,-111.911671,2.903815,348,1.422973e-08,"(359380575, 83650176, 0)",...,359380575.0,8.365018e+07,0.0,0.254002,24.350907,0 days 00:04:41,0 days 21:41:41,0 days 21:40:00,21,40
169727,5171030_to_depot,to_depot_1155703,0 days 21:37:00,0 days 21:41:43.509442594,40.762837,-111.911574,2.911981,349,0.000000e+00,"(359380575, 83650176, 0)",...,359380575.0,8.365018e+07,0.0,0.254002,24.350907,0 days 00:04:42,0 days 21:41:42,0 days 21:40:00,21,40
169728,5171030_to_depot,to_depot_1155703,0 days 21:37:00,0 days 21:41:43.509442594,40.762840,-111.911481,2.919825,350,1.576058e-08,"(359380575, 83650176, 0)",...,359380575.0,8.365018e+07,0.0,0.254002,24.350907,0 days 00:04:43,0 days 21:41:43,0 days 21:40:00,21,40


In [68]:
# Collapse the per-point traces to one row per (trip, shape, road link) to reduce
# the computational burden downstream.
# NOTE: groupby excludes NaN keys by default, so points without a matched road_id
# are dropped here, and a trip with no matched point at all disappears entirely.
links_grouped = trips_df_ext.groupby(by=["trip_id", "shape_id", "road_id"])
trip_links_df = links_grouped.agg(
    start_lat=("shape_pt_lat", "first"),
    start_lon=("shape_pt_lon", "first"),
    end_lat=("shape_pt_lat", "last"),
    end_lon=("shape_pt_lon", "last"),
    geom=("geom", "first"),
    start_timestamp=("timestamp", "first"),
    end_timestamp=("timestamp", "last"),
    kilometers=("kilometers", "mean"),
    travel_time_minutes=("travel_time", "mean"),
).reset_index()
# Divide by 60 to express the link travel time in minutes (assumes travel_time is in
# seconds; TODO confirm).
trip_links_df["travel_time_minutes"] = trip_links_df["travel_time_minutes"] / 60
trips_df_list = [links for _trip_id, links in trip_links_df.groupby("trip_id")]

In [79]:
len(trip_links_df['trip_id'].unique())

68

In [75]:
len(trips_df_list)

68

In [71]:
# Grade settings: whether to add road grade, the elevation tile resolution to use,
# and how many worker processes to run.
add_road_grade = True
tile_resolution = TileResolution.ONE_THIRD_ARC_SECOND
n_processes = mp.cpu_count()

# Either add grade to every trip's links, or just stack the per-trip link tables.
if not add_road_grade:
    result_df = pd.concat(trips_df_list)
else:
    result_df = run_gradeit_parallel(
        trip_dfs_list=trips_df_list,
        tile_resolution=tile_resolution,
        n_processes=n_processes,
    )

INFO:/Users/yhe/github_repo/routee-transit/nrel/routee/transit/prediction/grade/add_grade.py:Running gradeit on 68 trips with 10 processes.
2025-10-22 09:44:44,434 [INFO] - Running gradeit on 68 trips with 10 processes.
INFO:nrel.routee.transit.prediction.grade.download:Downloading 1 USGS tiles at ONE_THIRD_ARC_SECOND resolution.
2025-10-22 09:44:44,667 [INFO] - Downloading 1 USGS tiles at ONE_THIRD_ARC_SECOND resolution.
INFO:nrel.routee.transit.prediction.grade.download:downloading n41w112
2025-10-22 09:44:45,322 [INFO] - downloading n41w112
  grade = d_elev / distances
  grade = d_elev / distances
  grade = d_elev / distances
  grade = d_elev / distances
  grade = d_elev / distances
  grade = d_elev / distances
  grade = d_elev / distances
  grade = d_elev / distances


In [76]:
result_df['trip_id'].unique()

array(['5170968', '5170969', '5170970', '5170972', '5170973', '5170974',
       '5170975', '5170976', '5170977', '5170979', '5170980', '5170981',
       '5170982', '5170983', '5170984', '5170986', '5170987', '5170988',
       '5170989', '5170990', '5170991', '5170992', '5170993', '5170994',
       '5170995', '5170996', '5170997', '5170998', '5171000', '5171001',
       '5171002', '5171003', '5171004', '5171006', '5171007', '5171008',
       '5171009', '5171010', '5171011', '5171013', '5171014', '5171015',
       '5171016', '5171017', '5171018', '5171020', '5171021', '5171022',
       '5171023', '5171024', '5171025', '5171025_to_depot', '5171026',
       '5171026_to_depot', '5171027', '5171027_to_depot', '5171028',
       '5171028_to_depot', '5171029', '5171029_to_depot', '5171030',
       '5171030_to_depot', 'depot_to_5170968', 'depot_to_5170969',
       'depot_to_5170970', 'depot_to_5170972', 'depot_to_5170973',
       'depot_to_5170974'], dtype=object)

In [92]:
result_df

Unnamed: 0,trip_id,shape_id,road_id,start_lat,start_lon,end_lat,end_lon,geom,start_timestamp,end_timestamp,kilometers,travel_time_minutes,grade
0,5170968,226466,"(83521718, 9251870522, 0)",40.782386,-111.908402,40.782375,-111.903465,LINESTRING (-12457592.782659424 4980285.490429...,0 days 06:17:39,0 days 06:18:52,0.425530,0.453279,-0.0137
1,5170968,226466,"(83535259, 9251870518, 0)",40.782383,-111.912357,40.782380,-111.910880,LINESTRING (-12458038.171942087 4980283.564529...,0 days 06:16:39,0 days 06:17:01,0.140789,0.149970,-0.0050
2,5170968,226466,"(83538461, 83549941, 0)",40.782392,-111.922465,40.782400,-111.921244,LINESTRING (-12459159.860527169 4980296.472477...,0 days 06:14:08,0 days 06:14:27,0.118044,0.125741,0.0073
3,5170968,226466,"(83541843, 1585109177, 0)",40.767170,-111.879620,40.767170,-111.876820,LINESTRING (-12454385.378698993 4978059.594665...,0 days 06:34:09,0 days 06:34:51,0.240245,0.298564,-0.0042
4,5170968,226466,"(83542119, 83628841, 0)",40.733442,-111.876760,40.732726,-111.876760,LINESTRING (-12454066.849108037 4973114.850671...,0 days 06:45:56,0 days 06:46:10,0.093838,0.139940,0.0025
...,...,...,...,...,...,...,...,...,...,...,...,...,...
16,depot_to_5170974,from_depot_1155701,"(503841269, 367036112, 0)",40.677987,-111.892865,40.677967,-111.894324,LINESTRING (-12455848.294919202 4964960.263052...,0 days 08:25:50.111650640,0 days 08:26:00.111650640,0.129476,0.193087,-0.0013
17,depot_to_5170974,from_depot_1155701,"(1924089334, 83553370, 0)",40.678813,-111.891248,40.677989,-111.891251,LINESTRING (-12455676.740451941 4965090.682365...,0 days 08:25:31.111650640,0 days 08:25:38.111650640,0.098564,0.104991,-0.0017
18,depot_to_5170974,from_depot_1155701,"(3547578572, 83546067, 0)",40.687010,-111.909511,40.687012,-111.908656,LINESTRING (-12457721.401199088 4966284.676023...,0 days 08:22:14.111650640,0 days 08:22:20.111650640,0.080869,0.075375,-0.0046
19,depot_to_5170974,from_depot_1155701,"(3547578584, 83576238, 0)",40.689563,-111.910744,40.688708,-111.909866,LINESTRING (-12457896.79618878 4966659.7695765...,0 days 08:21:42.111650640,0 days 08:21:56.111650640,0.099764,0.092986,0.0008
