In [None]:
"""
GNN Traffic Network Graph Construction

This Python script constructs a Graph Neural Network (GNN) structure to represent traffic points within a road network. It leverages spatial data processing, shortest path analysis, and spatial indexing techniques. The core goal is to generate a graph suitable for use in subsequent traffic pattern analyses using a Graph Neural Network (GNN).

The processing workflow includes the following steps:

Step 1: Loading Preprocessed Data
-----------------------------------
- Load previously prepared data from a pickle file (`preprocessed_data.pkl`).
- This data includes:
  - `traffic_node_mapping`: Mapping of unique traffic point IDs to their spatial coordinates (longitude, latitude).
  - `G_road`: A NetworkX graph representing the detailed road network of the study area.
  - `global_distance_threshold`: Threshold distance (in meters) used to verify the validity of edges between traffic points.

Step 2: Spatial Indexing Preparation
-----------------------------------
- Construct a list of Shapely Point geometries (`global_points`) representing the traffic nodes.
- Create a spatial index (`global_STRtree`) using Shapely’s STRtree for efficient spatial queries.
- Construct a dictionary (`global_point_to_site`) mapping each Point geometry's WKT (Well-Known Text) representation to its corresponding site ID for quick identification.

Step 3: Candidate Edges Generation using KDTree
-----------------------------------
- Use SciPy’s KDTree to efficiently identify candidate edges between traffic points based on proximity.
- For each traffic point, identify its nearest 100 neighbors to create candidate edges.
- Avoid duplicate candidate edges by enforcing an order condition (`site < neighbor`).

Step 4: Candidate Edges Verification
-----------------------------------
- For each candidate edge between two traffic points:
  - Compute the shortest path along the road network (`G_road`) between the two points using NetworkX.
  - Convert this shortest path into a continuous LineString geometry.
  - Create a buffer zone around the path geometry using the predefined `global_distance_threshold`.
  - Query the spatial index (`global_STRtree`) to detect if any other traffic points (excluding the two endpoints) lie within this buffer zone.
  - If another traffic point is found within this threshold, discard this candidate edge (invalid edge).
  - If no other points are detected, the candidate edge is valid. Calculate the edge weight as the inverse of the path length to reflect proximity.

Step 5: GNN Graph Construction
-----------------------------------
- Initialize an empty NetworkX Graph (`G_gnn_parallel`) to represent the final GNN graph structure.
- Add all traffic points as nodes, embedding their spatial coordinates as node attributes.
- Add all verified, valid candidate edges to the graph, storing their computed inverse distance as edge weights, and including their geometry as an attribute.

Final Result:
-----------------------------------
- A comprehensive NetworkX graph (`G_gnn_parallel`) ready for further traffic pattern analysis with GNN models.
- Each node corresponds to a traffic monitoring site.
- Each edge accurately reflects direct traffic interaction based on real-world road connectivity and spatial proximity.

Dependencies:
-----------------------------------
- geopandas
- shapely
- networkx
- numpy
- scipy
- tqdm (for progress monitoring)

Note:
-----------------------------------
- This script processes candidate edges sequentially to ensure compatibility and stability in environments where parallel computing might pose issues (e.g., Jupyter Notebooks).

Author: Peter Guo
Date: 3.27.2025
"""


"\nGNN Traffic Network Graph Construction\n\nThis Python script constructs a Graph Neural Network (GNN) structure to represent traffic points within a road network. It leverages spatial data processing, shortest path analysis, and spatial indexing techniques. The core goal is to generate a graph suitable for use in subsequent traffic pattern analyses using a Graph Neural Network (GNN).\n\nThe processing workflow includes the following steps:\n\nStep 1: Loading Preprocessed Data\n-----------------------------------\n- Load previously prepared data from a pickle file (`preprocessed_data.pkl`).\n- This data includes:\n  - `traffic_node_mapping`: Mapping of unique traffic point IDs to their spatial coordinates (longitude, latitude).\n  - `G_road`: A NetworkX graph representing the detailed road network of the study area.\n  - `global_distance_threshold`: Threshold distance (in meters) used to verify the validity of edges between traffic points.\n\nStep 2: Spatial Indexing Preparation\n----

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import pandas as pd
import numpy as np
import osmnx as ox
import networkx as nx
import os
import geopandas as gpd
from shapely.geometry import Point, LineString
from shapely.ops import nearest_points
from scipy.spatial import KDTree
from shapely.strtree import STRtree
import pickle
from tqdm.notebook import tqdm

In [None]:
pip install osmnx

Collecting osmnx
  Downloading osmnx-2.0.2-py3-none-any.whl.metadata (4.9 kB)
Downloading osmnx-2.0.2-py3-none-any.whl (99 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m99.9/99.9 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: osmnx
Successfully installed osmnx-2.0.2


In [None]:
# ---------------------------
# 1. Data merging
# ---------------------------

In [None]:
df1 = pd.read_csv("/content/drive/MyDrive/BAGFormer/all_data_combined.csv")

  df1 = pd.read_csv("/content/drive/MyDrive/BAGFormer/all_data_combined.csv")


In [None]:
df1.head()

Unnamed: 0,Site_ID,Bound_Category,Time,Date,MeltedValue
0,1215468,All Northbound,00:00:00,2023-05-01,1063.0
1,1215468,All Northbound,00:15:00,2023-05-01,936.0
2,1215468,All Northbound,00:30:00,2023-05-01,728.0
3,1215468,All Northbound,00:45:00,2023-05-01,749.0
4,1215468,All Northbound,01:00:00,2023-05-01,661.0


In [None]:
df1.isnull().sum()

Unnamed: 0,0
Site_ID,0
Bound_Category,0
Time,0
Date,0
MeltedValue,2016012


In [None]:
df1.describe()

Unnamed: 0,MeltedValue
count,35612820.0
mean,185.3307
std,332.5297
min,0.0
25%,12.0
50%,58.0
75%,195.0
max,3616.0


In [None]:
unique_ids = df1['Site_ID'].unique()

In [None]:
# Read data (note: use the original file path and prefix with raw string r"")
df1 = pd.read_csv("/content/drive/MyDrive/BAGFormer/all_data_combined.csv")
df2 = pd.read_csv("/content/drive/MyDrive/BAGFormer/gdot_sites_information.csv")

# Convert data types to strings and strip whitespace from both ends
df1['Site_ID'] = df1['Site_ID'].astype(str).str.strip()
df2['cosit'] = df2['cosit'].astype(str).str.strip()

# Create a mapping dictionary: keys are 'cosit', values are (latitude, longitude)
mapping_dict = df2.set_index('cosit')[['latitude', 'longitude']].to_dict(orient='index')
# Example: {'000000010183': {'latitude': xxx, 'longitude': yyy}, ...}

# Define a function to match Site_ID to latitude and longitude using substring containment
def get_lat_lon(site_id):
    for key, latlon in mapping_dict.items():
        if site_id in key:
            return latlon['latitude'], latlon['longitude']
    return None, None

# Extract all unique Site_IDs and build a mapping to avoid redundant computation
unique_ids = df1['Site_ID'].unique()
siteid_to_latlon = {}
for sid in unique_ids:
    lat, lon = get_lat_lon(sid)
    siteid_to_latlon[sid] = (lat, lon)

# Map the matched results back to df1 and add latitude and longitude columns
df1['latitude'] = df1['Site_ID'].map(lambda x: siteid_to_latlon.get(x, (None, None))[0])
df1['longitude'] = df1['Site_ID'].map(lambda x: siteid_to_latlon.get(x, (None, None))[1])

# Check the result
print(df1.head())

# Save the result to a new file
df1.to_csv(r"/content/drive/MyDrive/BAGFormer/all_dot_with_lat_lon.csv", index=False)

  df1 = pd.read_csv("/content/drive/MyDrive/BAGFormer/all_data_combined.csv")


   Site_ID  Bound_Category      Time        Date  MeltedValue   latitude  \
0  1215468  All Northbound  00:00:00  2023-05-01       1063.0  33.708828   
1  1215468  All Northbound  00:15:00  2023-05-01        936.0  33.708828   
2  1215468  All Northbound  00:30:00  2023-05-01        728.0  33.708828   
3  1215468  All Northbound  00:45:00  2023-05-01        749.0  33.708828   
4  1215468  All Northbound  01:00:00  2023-05-01        661.0  33.708828   

   longitude  
0 -84.402372  
1 -84.402372  
2 -84.402372  
3 -84.402372  
4 -84.402372  


In [None]:
unique_ids.shape

(16071,)

In [None]:
import pandas as pd

In [None]:
# Read the data
traffic_df = pd.read_csv("/content/drive/MyDrive/BAGFormer/all_dot_with_lat_lon.csv")

# Check the data types
traffic_df['latitude'] = pd.to_numeric(traffic_df['latitude'], errors='coerce')
traffic_df['longitude'] = pd.to_numeric(traffic_df['longitude'], errors='coerce')


  traffic_df = pd.read_csv("/content/drive/MyDrive/BAGFormer/all_dot_with_lat_lon.csv")


In [None]:
traffic_df = traffic_df.drop_duplicates(subset=["Site_ID"])

In [None]:
traffic_df

Unnamed: 0,Site_ID,Bound_Category,Time,Date,MeltedValue,latitude,longitude
0,1215468,All Northbound,00:00:00,2023-05-01,1063.0,33.708828,-84.402372
5952,899171,All Northbound,00:00:00,2024-01-01,28.0,33.672880,-84.332530
8928,899611,All Southbound,00:00:00,2023-11-01,104.0,33.893670,-84.259550
11808,1219804,All Westbound,00:00:00,2024-05-01,218.0,33.766420,-84.498400
14784,830209,All Eastbound,00:00:00,2024-02-01,127.0,34.977550,-85.456550
...,...,...,...,...,...,...,...
37625376,0000217_8062,All Northbound,00:00:00,2023-03-14,,33.716388,-83.898008
37625952,0000199_0132,All Northbound,00:00:00,2023-03-21,,33.016570,-84.709718
37626528,0000137_0229,All Northbound,00:00:00,2023-07-25,,34.658077,-83.536567
37627680,0000311_0143,All Northbound,00:00:00,2024-07-09,,34.596278,-83.772855


In [None]:
traffic_df.to_csv(r"/content/drive/MyDrive/BAGFormer/site_id_with_lat_lon.csv", index=False)

In [None]:

unique_location_ids = traffic_df.drop_duplicates(subset=["Site_ID", "latitude", "longitude"])
print("具有唯一经纬度的 Site_ID 数量：", unique_location_ids["Site_ID"].nunique())

具有唯一经纬度的 Site_ID 数量： 16085


In [None]:
# ---------------------------
# 2. Data Loading and Preprocessing
# ---------------------------

In [None]:
# Read the GeoPackage data (already merged with CSV and road data, includes projected point information)
gdf_traffic = gpd.read_file("/content/drive/MyDrive/BAGFormer/finalmergedata.gpkg")
# Remove duplicate traffic points (using Site_ID as unique identifier)
gdf_traffic = gdf_traffic.drop_duplicates(subset=["Site_ID"])

# Construct projected point geometry using existing nearest_x and nearest_y (ensure projection is in meters, otherwise reproject first)
gdf_traffic["geometry"] = gdf_traffic.apply(lambda row: Point(row["nearest_x"], row["nearest_y"]), axis=1)

# Read the road network data (Shapefile)
gdf_roads = gpd.read_file("/content/drive/MyDrive/BAGFormer/gis_osm_roads_free_1.shp")

In [None]:
# Print a sample of traffic data for verification
print("部分交通数据样本:")
print(gdf_traffic.head())

print("道路网络数据记录数:", len(gdf_roads))
print("部分道路数据样本:")
print(gdf_roads.head())

部分交通数据样本:
   Site_ID  Bound_Category      Time       Date  MeltedValue   latitude  \
0  1215468  All Northbound  00:00:00 2023-05-01       1063.0  33.708828   
1   899171  All Northbound  00:00:00 2024-01-01         28.0  33.672880   
2   899611  All Southbound  00:00:00 2023-11-01        104.0  33.893670   
3  1219804   All Westbound  00:00:00 2024-05-01        218.0  33.766420   
4   830209   All Eastbound  00:00:00 2024-02-01        127.0  34.977550   

   longitude  feature_x  feature_y     osm_id  ...  layer bridge tunnel  n  \
0 -84.402372 -84.402372  33.708828  879776187  ...      0      F      F  1   
1 -84.332530 -84.332530  33.672880   41312116  ...      0      F      F  1   
2 -84.259550 -84.259550  33.893670    9164185  ...      0      F      F  1   
3 -84.498400 -84.498400  33.766420  437378654  ...      0      F      F  1   
4 -85.456550 -85.456550  34.977550    9162308  ...      0      F      F  1   

   distance  feature_x_2  feature_y_2  nearest_x  nearest_y  \
0  0.00

In [None]:
# ---------------------------
# 3. Build the road network graph
# ---------------------------
# Use NetworkX to build an undirected graph, where each road forms an edge between its endpoints, and the edge weight is the road length


In [None]:
G_road = nx.Graph()

for idx, row in gdf_roads.iterrows():
    geom = row.geometry
    if geom is None:
        continue
    # Handle LineString and MultiLineString
    if geom.geom_type == "LineString":
        coords = list(geom.coords)
        start, end = coords[0], coords[-1]
        G_road.add_node(start, pos=start)
        G_road.add_node(end, pos=end)
        G_road.add_edge(start, end, weight=geom.length, geometry=geom, osm_id=row["osm_id"])
    elif geom.geom_type == "MultiLineString":
        for line in geom:
            coords = list(line.coords)
            start, end = coords[0], coords[-1]
            G_road.add_node(start, pos=start)
            G_road.add_node(end, pos=end)
            G_road.add_edge(start, end, weight=line.length, geometry=line, osm_id=row["osm_id"])

In [None]:
print("G_road节点数:", G_road.number_of_nodes())
print("G_road边数:", G_road.number_of_edges())

G_road节点数: 2269386
G_road边数: 1452202


In [None]:
# ---------------------------
# 4. Embed traffic points into the road network
# ---------------------------
# Define a function to insert traffic points onto corresponding road edges


In [None]:
def insert_traffic_point(G, traffic_row, roads_gdf):
    """
    Find the road edge using the traffic point's corresponding osm_id and insert the traffic point onto that edge
    Return: the coordinates of the inserted node (tuple) or None
    """
    osm_id = traffic_row["osm_id"]
    pt = traffic_row.geometry
    # Find the corresponding road from the roads data
    road_match = roads_gdf[roads_gdf["osm_id"] == osm_id]
    if road_match.empty:
        return None
    road_geom = road_match.iloc[0].geometry
    # Obtain the start and end points of the road
    coords = list(road_geom.coords)
    start, end = coords[0], coords[-1]

    # Check whether the edge exists in G_road
    if G.has_edge(start, end):
        # Coordinates of the new node
        pt_coord = (pt.x, pt.y)
        # To avoid duplicate insertions (e.g., the same traffic point appearing multiple times), check if it already exists first
        if pt_coord in G.nodes:
            return pt_coord
        # Retrieve the original edge data
        edge_data = G.get_edge_data(start, end)
        # Remove the original edge and insert a new node along that edge
        G.remove_edge(start, end)
        # Split the original edge into two segments
        line1 = LineString([start, pt_coord])
        line2 = LineString([pt_coord, end])
        G.add_node(pt_coord, pos=pt_coord)
        G.add_edge(start, pt_coord, weight=line1.length, geometry=line1, osm_id=osm_id)
        G.add_edge(pt_coord, end, weight=line2.length, geometry=line2, osm_id=osm_id)
        return pt_coord
    else:
        # If the edge is not found, alternative methods can be tried (e.g., finding the nearest edge), but this is omitted here
        return None

# Use the above function to insert all traffic points into the road network
traffic_node_mapping = {}  # Site_ID -> Node coordinates
for idx, row in gdf_traffic.iterrows():
    node = insert_traffic_point(G_road, row, gdf_roads)
    if node is not None:
        traffic_node_mapping[row["Site_ID"]] = node

In [None]:
print("成功插入的交通点数量:", len(traffic_node_mapping))
print("部分交通点映射关系（Site_ID : 节点坐标）:")
print(list(traffic_node_mapping.items())[:5])

成功插入的交通点数量: 13649
部分交通点映射关系（Site_ID : 节点坐标）:
[('630096', (-84.4207659, 33.5278901)), ('0000039_7048', (-81.74085551472646, 30.95978038437354)), ('0000039_7053', (-81.5519744876392, 30.751729703356318)), ('0000039_7056', (-81.65938318533466, 30.76374932136274)), ('0000039_7058', (-81.63024015933692, 30.810615791621217))]


In [None]:
# ---------------------------
# 5. Save the road network graph
# ---------------------------

In [None]:
import pickle
from shapely.geometry import Point

global_distance_threshold = 10

# Generate global_points and global_point_to_site
global_points = []
global_point_to_site = {}
for site, coord in traffic_node_mapping.items():
    pt = Point(coord)
    global_points.append(pt)
    global_point_to_site[id(pt)] = site

data_to_save = {
    "traffic_node_mapping": traffic_node_mapping,
    "G_road": G_road,
    "global_points": global_points,
    "global_point_to_site": global_point_to_site,
    "global_distance_threshold": global_distance_threshold,
}

with open("/content/drive/MyDrive/BAGFormer/preprocessed_data.pkl", "wb") as f:
    pickle.dump(data_to_save, f)

print("前三部分的结果已保存到 preprocessed_data.pkl")

前三部分的结果已保存到 preprocessed_data.pkl


In [None]:
# ---------------------------
# 6. Build the GNN graph structure
# ---------------------------
# In the final GNN graph, each node represents a traffic point. An edge connects two traffic points if and only if:
# 1) The two traffic points are connected by the shortest path in the road network;
# 2) No other traffic points lie along this shortest path (determined by checking that all other traffic points are more than 10 meters away from the path)

In [None]:
# 1. Load the preprocessed data
with open("/content/drive/MyDrive/BAGFormer/preprocessed_data.pkl", "rb") as f:
    saved_data = pickle.load(f)

traffic_node_mapping = saved_data["traffic_node_mapping"]
G_road = saved_data["G_road"]
global_distance_threshold = saved_data["global_distance_threshold"]

# Rebuild global_points and global_point_to_site, using the WKT (Well-Known Text) of each point as the mapping key
global_points = []
global_point_to_site = {}
for site, coord in traffic_node_mapping.items():
    pt = Point(coord)
    global_points.append(pt)
    global_point_to_site[pt.wkt] = site

# Build a global spatial index
global_STRtree = STRtree(global_points)

print("加载并重建预处理数据完成。")
print("交通点数量：", len(traffic_node_mapping))
print("道路网络节点数：", G_road.number_of_nodes())
print("道路网络边数：", G_road.number_of_edges())


# 2. Utility function definitions
def get_path_geometry(G, path):
    """
    将道路网络图中节点路径拼接成 LineString 对象。
    """
    coords = []
    for i in range(len(path) - 1):
        edge_data = G.get_edge_data(path[i], path[i+1])
        if edge_data is not None and "geometry" in edge_data:
            geom = edge_data["geometry"]
            if i == 0:
                coords.extend(list(geom.coords))
            else:
                coords.extend(list(geom.coords)[1:])
    return LineString(coords)

def process_candidate(candidate):
    """
    For each candidate edge (site, neighbor):
      1. Compute the shortest path between the two points in the road network;
      2. Merge the path geometries;
      3. Query traffic points within a buffer of the path (using STRtree); if any point other than the endpoints lies within the buffer and is within a threshold distance from the path, the edge is considered invalid;
      4. If the edge is valid, return (site, neighbor, inv_weight, path_geometry), where inv_weight is the inverse of the path length.
    """
    site, neighbor = candidate
    node_coord = traffic_node_mapping[site]
    neighbor_coord = traffic_node_mapping[neighbor]

    try:
        path = nx.shortest_path(G_road, source=node_coord, target=neighbor_coord, weight="weight")
        path_geom = get_path_geometry(G_road, path)
    except nx.NetworkXNoPath:
        return None
    except Exception:
        return None

    buffer_geom = path_geom.buffer(global_distance_threshold)
    candidate_points = global_STRtree.query(buffer_geom)

    for candidate_pt in candidate_points:
        # Objects returned by STRtree may be different instances, so match using WKT
        if not hasattr(candidate_pt, "wkt"):
            continue
        candidate_site = global_point_to_site.get(candidate_pt.wkt)
        if candidate_site in (site, neighbor):
            continue
        if candidate_pt.distance(path_geom) <= global_distance_threshold:
            return None

    length = path_geom.length
    inv_weight = 1 / length if length > 0 else float('inf')
    return (site, neighbor, inv_weight, path_geom)


# 3. Construct candidate edges (using KDTree to filter nearest neighbors)
site_ids = list(traffic_node_mapping.keys())
points_list = [traffic_node_mapping[site] for site in site_ids]
tree = KDTree(points_list)

candidate_pairs = []
k = 101  # For each point, find the 101 nearest neighbors (including itself), resulting in approximately 100 actual candidate neighbors
for idx, site in enumerate(site_ids):
    node_coord = traffic_node_mapping[site]
    distances, indices = tree.query(node_coord, k=k)
    for d, neighbor_index in zip(distances, indices):
        if neighbor_index == idx:
            continue
        neighbor_site = site_ids[neighbor_index]
        # Only process cases where site < neighbor to avoid duplicates
        if site >= neighbor_site:
            continue
        candidate_pairs.append((site, neighbor_site))

print("候选边总数：", len(candidate_pairs))


# 4. Process candidate edges sequentially (use tqdm to show progress)
results = []
for candidate in tqdm(candidate_pairs, total=len(candidate_pairs)):
    result = process_candidate(candidate)
    results.append(result)

valid_edges = [res for res in results if res is not None]
print("有效候选边数：", len(valid_edges))


# 5. Build the final GNN graph
G_gnn_parallel = nx.Graph()
for site, node in traffic_node_mapping.items():
    G_gnn_parallel.add_node(site, pos=node)
for site, neighbor, inv_weight, geom in valid_edges:
    G_gnn_parallel.add_edge(site, neighbor, weight=inv_weight, geometry=geom)

print("GNN 图构建完成。")
print("图节点数：", G_gnn_parallel.number_of_nodes())
print("图边数：", G_gnn_parallel.number_of_edges())


加载并重建预处理数据完成。
交通点数量： 13649
道路网络节点数： 2284165
道路网络边数： 1466978
候选边总数： 687511


  0%|          | 0/687511 [00:00<?, ?it/s]

In [None]:
# Save the constructed GNN graph to a pickle file
with open('/content/drive/MyDrive/BAGFormer/gnn_traffic_graph.pkl', 'wb') as f:
    pickle.dump(G_gnn_parallel, f)

print("The GNN graph has been successfully saved as 'gnn_traffic_graph.pkl'.")

In [None]:
with open('/content/drive/MyDrive/BAGFormer/gnn_traffic_graph.pkl', 'rb') as f:
    G_loaded = pickle.load(f)

print("The GNN graph has been successfully loaded.")
print("Number of nodes:", G_loaded.number_of_nodes())
print("Number of edges:", G_loaded.number_of_edges())