# NYC Traffic Crash Prediction with STM-Graph

This notebook demonstrates how to use the STM-Graph library to analyze New York City traffic crash data and predict crash occurrence. We'll go through the complete workflow:

1. Loading and preprocessing the raw data
2. Creating spatial mappings using grid-based partitioning
3. Extracting OpenStreetMap features and creating the urban features graph
4. Building a graph representation of the data
5. Creating a temporal graph dataset
6. Visualizing spatial and temporal patterns
7. Training a GNN model for crash prediction

Let's get started!

In [None]:
# Path to the STM-Graph source tree (adjust to your checkout)
stm_graph_path = "/home/ubuntu/STM-Graph/src"

# Import required libraries
import os
import sys
from datetime import timedelta

import contextily as ctx
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Make the local STM-Graph package importable
sys.path.append(stm_graph_path)
import stm_graph

# Define the data and output directories
DATA_DIR = "/mnt/data/nyc_crash_311"
DATASET = "Motor_Vehicle_Collisions_-_Crashes_20241203.csv"
OUTPUT_DIR = "/mnt/data/nyc_crash_311/stm_graph/nyc/crash"

# Define geographic boundaries for NYC
NYC_BOUNDS = {
    "min_lat": 40.4774,  # Southern boundary
    "max_lat": 40.9176,  # Northern boundary
    "min_lon": -74.2591,  # Western boundary
    "max_lon": -73.7004,  # Eastern boundary
}

os.makedirs(OUTPUT_DIR, exist_ok=True)

## 1. Data Loading and Initial Preprocessing

First, we'll load the NYC crash data from a CSV file and perform some initial preprocessing. We'll convert the separate date and time columns into a single datetime column for analysis.

In [None]:
# Step 1: Read the raw data
print(f"Reading raw crash data from {os.path.join(DATA_DIR, DATASET)}")
raw_df = pd.read_csv(os.path.join(DATA_DIR, DATASET), low_memory=False)

# Display basic information about the dataset
print(f"Dataset shape: {raw_df.shape}")
print("\nColumn names:")
print(raw_df.columns.tolist())
print("\nSample data:")
raw_df.head()

In [None]:
# Step 2: Combine date and time to create a datetime column
print("Combining date and time columns...")
raw_df["created_time"] = pd.to_datetime(
    raw_df["CRASH DATE"] + " " + raw_df["CRASH TIME"], errors="coerce"
)

# Create a temporary CSV file with the combined datetime column
temp_csv = os.path.join(DATA_DIR, "temp_combined_crash.csv")
raw_df.to_csv(temp_csv, index=False)

print(f"Time range of data: {raw_df['created_time'].min()} to {raw_df['created_time'].max()}")
raw_df[["CRASH DATE", "CRASH TIME", "created_time"]].head()
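The combination step can be sketched in plain pandas. Passing an explicit `format` skips per-row format inference, and `errors="coerce"` turns malformed rows into `NaT` so they can be counted and dropped before saving. The `%m/%d/%Y %H:%M` format is an assumption about how the CSV stores `CRASH DATE` and `CRASH TIME`; check a few raw rows before relying on it. The toy rows below are made up.

```python
import pandas as pd

# Made-up rows mimicking the "CRASH DATE" / "CRASH TIME" columns
# (assumed formats: MM/DD/YYYY and H:MM); the last row is malformed.
toy_df = pd.DataFrame({
    "CRASH DATE": ["09/11/2021", "12/14/2021", "not a date"],
    "CRASH TIME": ["2:39", "17:05", "8:00"],
})

# An explicit format avoids format inference; errors="coerce" yields NaT
# for rows that do not match instead of raising an exception.
toy_df["created_time"] = pd.to_datetime(
    toy_df["CRASH DATE"] + " " + toy_df["CRASH TIME"],
    format="%m/%d/%Y %H:%M",
    errors="coerce",
)

# Count and drop rows whose timestamp could not be parsed.
n_bad = int(toy_df["created_time"].isna().sum())
toy_df = toy_df.dropna(subset=["created_time"])
print(f"Dropped {n_bad} rows with unparseable timestamps")
```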

## STM-Graph Preprocessing

Now we'll use STM-Graph's preprocessing function to clean the data, convert it to a GeoDataFrame, and filter it by time. This step handles missing coordinates and prepares the points for spatial mapping. Here, `filter_dates=(None, "2019-12-31 23:59:59")` sets no start date and an end date of 31 December 2019.

In [None]:
# Process with STM-Graph
print("Processing with STM-Graph...")
gdf_crash = stm_graph.preprocess_dataset(
    data_path=DATA_DIR,
    dataset="temp_combined_crash.csv",
    time_col="created_time",
    lat_col="LATITUDE",
    lng_col="LONGITUDE",
    column_mapping={
        "created_time": "created_time",
        "BOROUGH": "borough",
        "ZIP CODE": "zip_code",
        "LATITUDE": "latitude",
        "LONGITUDE": "longitude",
        "LOCATION": "location",
        "COLLISION_ID": "collision_id",
        "NUMBER OF PERSONS INJURED": "injured_count",
        "NUMBER OF PERSONS KILLED": "killed_count",
    },
    filter_dates=(None, "2019-12-31 23:59:59"),
    testing_mode=True,
    test_bounds=NYC_BOUNDS,
    visualize=True,
    fig_format="png",
    output_dir=OUTPUT_DIR,
    show_background_map=True,
)

print(f"Processed dataset shape: {gdf_crash.shape}")
print(f"Time range: {gdf_crash['created_time'].min()} to {gdf_crash['created_time'].max()}")
gdf_crash.head()
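To illustrate roughly what this kind of cleaning involves, here is a pandas-only sketch that drops rows with missing coordinates, keeps only points inside the NYC bounding box, and applies an end date. It is not STM-Graph's implementation; `filter_points` and the toy rows are hypothetical.

```python
import pandas as pd

# Same bounding box as NYC_BOUNDS above.
bounds = {"min_lat": 40.4774, "max_lat": 40.9176,
          "min_lon": -74.2591, "max_lon": -73.7004}

# Made-up rows: valid, missing latitude, after the end date, placeholder (0, 0).
toy_points = pd.DataFrame({
    "created_time": pd.to_datetime(["2019-06-01 08:00", "2019-07-01 09:30",
                                    "2020-01-05 12:00", "2019-03-02 23:10"]),
    "LATITUDE": [40.71, None, 40.75, 0.0],
    "LONGITUDE": [-74.00, -73.95, -73.98, 0.0],
})

def filter_points(df, bounds, end=None):
    """Drop rows with missing or out-of-box coordinates, then cut by end date."""
    out = df.dropna(subset=["LATITUDE", "LONGITUDE"])
    in_box = (out["LATITUDE"].between(bounds["min_lat"], bounds["max_lat"])
              & out["LONGITUDE"].between(bounds["min_lon"], bounds["max_lon"]))
    out = out[in_box]
    if end is not None:
        out = out[out["created_time"] <= pd.Timestamp(end)]
    return out

filtered = filter_points(toy_points, bounds, end="2019-12-31 23:59:59")
print(len(filtered))
```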

## 2. Spatial Mapping

In this step, we'll apply a spatial mapping to divide NYC into grid cells. STM-Graph currently supports several mapping approaches, with more planned for later releases:

1. Administrative boundaries (such as community districts, census tracts)
2. Regular grid cells
3. Degree-based Voronoi partitioning

For this example, we'll use grid-based mapping with 1-kilometer cells, computed in UTM zone 18N (`EPSG:32618`) so that distances are measured in meters.

In [None]:
# Apply grid mapping to the data
print("Applying spatial mapping...")

# Create a grid mapping with 1000-meter cells
mapper = stm_graph.GridMapping(cell_size=1000.0, target_crs="EPSG:32618")

# Apply the mapping to get district geometries and point-to-partition mapping
district_gdf, point_to_partition = mapper.create_mapping(gdf_crash)

print(f"Created mapping with {len(district_gdf)} regions")
print(f"Points with valid mapping: {(point_to_partition >= 0).sum()} of {len(point_to_partition)}")

# Visualize the mapping
mapper.visualize(
    points_gdf=gdf_crash,
    partition_gdf=district_gdf,
    point_to_partition=point_to_partition,
    out_dir=OUTPUT_DIR,
    file_format="png"
)

district_gdf.head()
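The idea behind grid mapping can be sketched with NumPy: convert each point to meters, divide by the cell size, and combine row and column into one cell ID, using -1 for points outside the grid (matching the `point_to_partition >= 0` check used later). This sketch swaps the UTM projection for a local equirectangular approximation, so cell edges are only approximately 1 km; `grid_cell_ids` is a hypothetical helper, not part of STM-Graph.

```python
import numpy as np

EARTH_RADIUS_M = 6_371_000.0

def grid_cell_ids(lat, lon, bounds, cell_size=1000.0):
    """Assign points to square cells over a bounding box (row-major IDs).

    Uses an equirectangular approximation around the box's mean latitude
    instead of a true UTM projection. Points outside the grid get -1.
    """
    lat = np.asarray(lat, dtype=float)
    lon = np.asarray(lon, dtype=float)
    lat0, lon0 = bounds["min_lat"], bounds["min_lon"]
    mean_lat = np.radians((bounds["min_lat"] + bounds["max_lat"]) / 2)

    def to_xy(la, lo):
        # Meters east and north of the box's south-west corner.
        x = EARTH_RADIUS_M * np.radians(lo - lon0) * np.cos(mean_lat)
        y = EARTH_RADIUS_M * np.radians(la - lat0)
        return x, y

    width, height = to_xy(bounds["max_lat"], bounds["max_lon"])
    n_cols = int(np.ceil(width / cell_size))
    n_rows = int(np.ceil(height / cell_size))

    x, y = to_xy(lat, lon)
    col = np.floor(x / cell_size).astype(int)
    row = np.floor(y / cell_size).astype(int)
    inside = (col >= 0) & (col < n_cols) & (row >= 0) & (row < n_rows)
    ids = np.where(inside, row * n_cols + col, -1)
    return ids, n_rows, n_cols

nyc_bounds = {"min_lat": 40.4774, "max_lat": 40.9176,
              "min_lon": -74.2591, "max_lon": -73.7004}
# One point near the south-west corner, one north of the box.
ids, n_rows, n_cols = grid_cell_ids([40.48, 41.0], [-74.25, -73.9], nyc_bounds)
print(ids, n_rows, n_cols)
```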

## 3. OSM Feature Extraction / Urban Features Graph Creation

Next, we'll extract features from OpenStreetMap (OSM) to enrich the model with contextual information about each area. The features can be used directly as normalized per-region values, or they can be embedded first. They include information about:

- Points of interest (POIs)
- Road networks
- Road junctions

These features describe the urban environment, which may influence event patterns. Later releases may add more feature types for users to choose from.

In [None]:
# Define the feature types to extract
feature_types = ['poi', 'road', 'junction']

# Extract OSM features
osm_cache_dir = os.path.join(OUTPUT_DIR, "osm_cache")
osm_features = stm_graph.extract_osm_features(
    regions_gdf=district_gdf,
    bounds=NYC_BOUNDS,
    cache_dir=osm_cache_dir,
    feature_types=feature_types,
    normalize=True,
    meter_crs="EPSG:32618",
    lat_lon_crs="EPSG:4326"
)

# Print available features
print(f"Extracted {len(osm_features.columns)} OSM features")
print("\nFeature sample:")
osm_features.head()
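The exact scheme behind `normalize=True` isn't described here; per-feature min-max scaling is one common choice, sketched below on made-up counts. `minmax_columns` and the column names are hypothetical and need not match the columns STM-Graph produces.

```python
import pandas as pd

# Made-up raw per-region features (one row per region).
raw = pd.DataFrame(
    {"poi_count": [120, 30, 0],
     "road_length_m": [5400.0, 2100.0, 800.0],
     "junction_count": [40, 40, 40]},
    index=["cell_0", "cell_1", "cell_2"],
)

def minmax_columns(df):
    """Scale each column to [0, 1]; constant columns map to 0."""
    span = df.max() - df.min()
    span = span.replace(0, 1)  # avoid division by zero for constant columns
    return (df - df.min()) / span

normalized = minmax_columns(raw)
print(normalized)
```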

## 4. Graph Construction

Now we'll build a graph representation of our data. In this graph:
- Nodes represent grid cells
- Edges represent adjacency relationships between cells
- Node features include OSM features and crash statistics

This graph structure allows us to use Graph Neural Networks (GNNs) to model the spatial relationships between different areas of the city.

In [None]:
# Filter points that have valid mappings
gdf_crash_valid = gdf_crash[point_to_partition >= 0].copy()
point_to_partition_valid = point_to_partition[point_to_partition >= 0].copy()

print(f"Using {len(gdf_crash_valid)} valid points for graph construction")

# Build graph with static features
graph_data = stm_graph.build_graph_and_augment(
    grid_gdf=district_gdf,
    points_gdf=gdf_crash_valid,
    point_to_cell=point_to_partition_valid,
    adj_matrix=None,
    remove_empty_nodes=True,
    out_dir=OUTPUT_DIR,
    save_flag=True,
    static_features=osm_features,
)

# Extract graph components
edge_index = graph_data["edge_index"]
edge_weight = graph_data["edge_weight"]
node_features = graph_data["node_features"]
augmented_df = graph_data["augmented_df"]
node_ids = graph_data["node_ids"]

print(f"Built graph with {edge_index.shape[1]} edges and {graph_data['num_nodes']} nodes")
print(f"Node features shape: {node_features.shape}")

# Display augmented dataframe
augmented_df.head()
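With `adj_matrix=None`, the adjacency is presumably derived from the grid itself. A minimal sketch, assuming rook (shared-edge) adjacency on a row-major grid, builds a `2 x E` edge index in the COO layout that PyTorch Geometric uses, listing each undirected pair in both directions. It then drops nodes without events and relabels the rest, which is one plausible reading of `remove_empty_nodes=True`. Both helpers are hypothetical.

```python
import numpy as np

def grid_edge_index(n_rows, n_cols):
    """Rook (4-neighbour) adjacency of a row-major grid as a 2 x E array."""
    src, dst = [], []
    for r in range(n_rows):
        for c in range(n_cols):
            node = r * n_cols + c
            if c + 1 < n_cols:  # neighbour to the right
                src += [node, node + 1]
                dst += [node + 1, node]
            if r + 1 < n_rows:  # neighbour in the next row
                src += [node, node + n_cols]
                dst += [node + n_cols, node]
    return np.array([src, dst], dtype=np.int64)

def drop_empty_nodes(edge_index, counts):
    """Keep nodes with at least one event, relabel them 0..K-1, prune edges."""
    keep = np.flatnonzero(np.asarray(counts) > 0)
    new_id = -np.ones(len(counts), dtype=np.int64)
    new_id[keep] = np.arange(len(keep))
    mask = (new_id[edge_index[0]] >= 0) & (new_id[edge_index[1]] >= 0)
    return new_id[edge_index[:, mask]], keep

# A 2 x 3 grid has 7 undirected edges, i.e. 14 directed entries.
toy_edge_index = grid_edge_index(2, 3)
toy_sub, toy_keep = drop_empty_nodes(toy_edge_index, [5, 0, 2, 1, 0, 3])
print(toy_edge_index.shape, toy_sub.shape, toy_keep)
```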

## 5. Temporal Dataset Creation

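With the graph structure in place, we'll now create a temporal dataset for time-aware analysis and prediction. We'll:

1. Bin crashes into daily intervals (`bin_type="daily"`)
2. Create sliding windows that use the previous 3 bins (`history_window=3`) to predict the next one (`horizon=1`)
3. Optionally add time-based features (day of week, hour of day); this example disables them with `use_time_features=False`
4. Normalize the features with min-max scaling (`scaler_type="minmax"`)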

In [None]:
# Create temporal dataset
temporal_dataset, dataset_path, metadata = stm_graph.create_temporal_dataset(
    edge_index=edge_index,
    augmented_df=augmented_df,
    edge_weights=edge_weight,
    node_ids=node_ids,
    static_features=osm_features,
    time_col="created_time",
    cell_col="cell_id",
    bin_type="daily",
    interval_hours=1,
    history_window=3,
    use_time_features=False,
    task="classification",
    horizon=1,
    downsample_factor=1,
    normalize=True,
    scaler_type="minmax",
    dataset_name="nyc_crash_dataset",
    output_format="4d",
)
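The windowing in steps 2 and 4 can be sketched with NumPy. Under the settings above (`history_window=3`, `horizon=1`, `task="classification"`), each sample uses three consecutive bins as input and labels each node 1 if any crash occurs in the following bin. This is an illustrative sketch with hypothetical helper names, not the library's code.

```python
import numpy as np

def make_windows(counts, history_window=3, horizon=1, classification=True):
    """Turn a [T, N] matrix of per-bin counts into (X, y) samples.

    X[i] has shape [N, history_window]; y[i] is the bin `horizon` steps
    after the window, binarized to "any event" when classification=True.
    """
    counts = np.asarray(counts, dtype=float)
    T = counts.shape[0]
    xs, ys = [], []
    for t in range(history_window, T - horizon + 1):
        xs.append(counts[t - history_window:t].T)  # [N, history_window]
        target = counts[t + horizon - 1]            # [N]
        ys.append((target > 0).astype(float) if classification else target)
    return np.stack(xs), np.stack(ys)

def minmax(x):
    """Scale an array to [0, 1] using its global min and max."""
    lo, hi = x.min(), x.max()
    return (x - lo) / (hi - lo) if hi > lo else np.zeros_like(x)

# Made-up daily counts: 6 days (rows) x 2 cells (columns).
toy_counts = np.array([[0, 1], [2, 0], [0, 0], [1, 3], [0, 0], [4, 1]])
X_toy, y_toy = make_windows(toy_counts, history_window=3, horizon=1)
X_toy_scaled = minmax(X_toy)
print(X_toy.shape, y_toy.shape)  # (3, 2, 3) (3, 2)
```

Fitting the scaler on the training bins only avoids leaking test-period statistics into the inputs.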

## 6. Visualization

Now let's create some visualizations to better understand our data. We'll create:

1. Time series plots showing crash trends over time
2. Spatial network visualizations showing crash density across NYC
3. Temporal heatmaps showing patterns across time and space

In [None]:

# Convert the 4D dataset into the 3D layout used by the plotting helpers
temporal_dataset_3d = stm_graph.convert_4d_to_3d_dataset(
    temporal_dataset, static_features_count=osm_features.shape[1]
)

# Plot time series for the most active nodes
stm_graph.plot_node_time_series(
    temporal_dataset_3d,
    num_nodes=5,  # Show 5 nodes
    selection_method="highest_activity",  # Select most active nodes
    feature_idx=0,  # Event count feature
    plot_type="2d",  # 2D line plot
    start_time="2019-11-01",  # Start date for x-axis
    time_delta=timedelta(days=1),  # One time step per daily bin
    title="Crash Events Over Time (Most Active Nodes)",
    figsize=(15, 8),
    out_dir=OUTPUT_DIR,
    filename="time_series_top",
    file_format="png",
)
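`selection_method="highest_activity"` presumably ranks nodes by how many events they record. One way to reproduce such a ranking by hand, assuming a `[T, N]` array of per-bin counts, is sketched below; `top_active_nodes` is hypothetical.

```python
import numpy as np

def top_active_nodes(series, k=5):
    """Indices of the k nodes with the largest total over time.

    series: array of shape [T, N] (time steps x nodes). A stable sort on the
    negated totals gives descending order with ties broken by lower index.
    """
    totals = np.asarray(series).sum(axis=0)
    return np.argsort(-totals, kind="stable")[:k]

# Made-up counts: 3 time steps x 4 nodes; totals are [4, 1, 9, 3].
toy_series = np.array([[1, 0, 5, 2], [2, 0, 1, 0], [1, 1, 3, 1]])
print(top_active_nodes(toy_series, k=2))
```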

In [None]:
# Plot 3D visualization for most active nodes
stm_graph.plot_node_time_series(
    temporal_dataset_3d,
    num_nodes=3,  # Show 3 nodes
    selection_method="highest_activity",  # Select most active nodes
    feature_idx=0,  # Event count feature
    plot_type="3d",  # 3D surface plot
    n_steps=7,  # First week (7 daily bins)
    title="3D Visualization of Crashes Over Time",
    figsize=(15, 10),
    out_dir=OUTPUT_DIR,
    filename="time_series_3d",
    file_format="png",
)

In [None]:
# Extract the event count for each region (node) at a specific time step
time_step = 1  # Second daily bin, i.e. 24 hours after the start
node_counts = np.array(
    [
        temporal_dataset_3d.features[time_step][node, 0].item()
        for node in range(graph_data["num_nodes"])
    ]
)

# Plot spatial network with region and edge colors
stm_graph.plot_spatial_network(
    regions_gdf=district_gdf,
    edge_index=edge_index,
    edge_weights=edge_weight,
    node_values=node_counts,
    node_ids=node_ids,
    time_step=time_step,
    title="Crash Density After 24 Hours",
    node_cmap="YlOrRd",  # Yellow-orange-red colormap for crash density
    edge_cmap="viridis",  # Purple-to-yellow colormap for edge weights
    map_style=ctx.providers.CartoDB.Positron,
    figsize=(15, 15),
    out_dir=OUTPUT_DIR,
    filename="spatial_network",
    file_format="png",
)

In [None]:
# Create a temporal heatmap to see patterns across time and nodes
stm_graph.plot_temporal_heatmap(
    temporal_dataset_3d,
    num_nodes=10,
    feature_idx=0,  # Event count feature
    selection_method="highest_activity",
    start_time="2019-11-01",
    time_delta=timedelta(days=1),  # One time step per daily bin
    n_steps=7,  # First week (7 daily bins)
    title="Crash Events Temporal Heatmap (First Week)",
    figsize=(15, 8),
    out_dir=OUTPUT_DIR,
    filename="temporal_heatmap",
    file_format="png",
)

## 7. Model Training

Finally, we'll train a Graph Neural Network (GNN) model to predict crash events. We'll use the STGCN (Spatio-Temporal Graph Convolutional Network) model.

The model predicts whether crashes will occur in each area in the next time step. You can also plug in other custom models or any supported model from PyTorch Geometric Temporal, and more custom models can be added. Training logs are saved locally in a `logs` folder inside the output directory. [Weights & Biases](https://wandb.ai/) integration is also available: log in to use the online dashboard to monitor training and track metrics live.

In [None]:
# Discover available model options
print("Available models in STM-Graph:")
stm_graph.list_available_models()

In [None]:
# Build an STGCN model from STM-Graph's custom model collection
model = stm_graph.create_model(
    model_name="stgcn",
    source="custom",
    num_nodes=temporal_dataset.features[0].shape[0],
    in_channels=temporal_dataset.features[0].shape[2],
    out_channels=1,
    hidden_dim=64,
    k=3,
    embedding_dimensions=16,
    dropout=0.2,
    task="classification",
)

# Train the model
results = stm_graph.train_model(
    model=model,
    dataset=temporal_dataset,
    optimizer_name="adam",
    learning_rate=0.0001,
    task="classification",
    num_epochs=500,
    batch_size=10,
    batch_to_device=True,
    test_size=0.15,
    val_size=0.15,
    use_nested_tqdm=True,
    early_stopping=True,
    patience=50,
    scheduler_type="step",
    lr_decay_epochs=50,
    lr_decay_factor=1,  # If applied multiplicatively, 1 keeps the rate constant; e.g. 0.5 halves it every 50 epochs
    wandb_project="stm_graph_crash",
    experiment_name="stgcn",
    use_wandb=True,  # Requires a Weights & Biases login (`wandb login`)
    log_dir=OUTPUT_DIR,
)
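The `scheduler_type="step"` and `early_stopping` options correspond to two standard training controls. The sketch below assumes that `lr_decay_factor` behaves like the `gamma` of PyTorch's `StepLR` (the rate is multiplied by the factor every `lr_decay_epochs` epochs) and that `patience` counts epochs without validation improvement. Neither assumption is confirmed by the code above, and `step_lr` and `EarlyStopper` are hypothetical.

```python
def step_lr(base_lr, epoch, step_size=50, gamma=1.0):
    """Step decay: base_lr * gamma ** (epoch // step_size)."""
    return base_lr * gamma ** (epoch // step_size)

class EarlyStopper:
    """Signal a stop once validation loss fails to improve for `patience` epochs."""

    def __init__(self, patience=50, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True to stop training."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Toy run with made-up validation losses and a short patience.
stopper = EarlyStopper(patience=2)
for epoch, val_loss in enumerate([1.0, 0.9, 0.95, 0.92]):
    lr = step_lr(1e-4, epoch, step_size=50, gamma=0.5)
    if stopper.step(val_loss):
        print(f"Stopping at epoch {epoch}, lr={lr}")
        break
```

With `gamma=1.0`, `step_lr` returns `base_lr` for every epoch, which is why a decay factor of 1 leaves the learning rate unchanged under this assumption.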