# Predicting Train Accidents Based on Train Volumes in Indiana
This notebook will guide you through the process of predicting train accidents based on train volumes using GeoPandas, Folium, and other necessary Python libraries.

### Table of Contents
1. Import Libraries
2. Load and Prepare Accident Data
3. Load and Prepare Train Volume Data
4. Spatial Join of Accidents to Track Segments
5. Data Aggregation
6. Statistical Modeling
7. Prediction
8. Conclusion

In [301]:
# pip install geopandas --upgrade 

1. Import Libraries <a id="import-libraries"></a>

In [302]:
# Import necessary libraries
import pandas as pd
import numpy as np
import geopandas as gpd
from shapely.geometry import Point, MultiLineString
from shapely import wkt
import matplotlib.pyplot as plt
import statsmodels.api as sm
import statsmodels.formula.api as smf
import warnings

# Suppress warnings
warnings.filterwarnings('ignore')

# Enable inline plotting
%matplotlib inline

2. Load and Prepare Accident Data <a id="load-accident-data"></a>

In [303]:
# Load accident data from CSV file
accidents_df = pd.read_csv('Indiana_Accidents_Since_2011_v2.csv')  # Replace with your actual file path

# Display first few rows
accidents_df.head()


Unnamed: 0,Reporting Railroad Code,Reporting Railroad Name,Year,Accident Number,PDF Link,Accident Year,Accident Month,Other Railroad Code,Other Railroad Name,Other Accident Number,...,Other Parent Railroad Code,Other Parent Railroad Name,Other Railroad Holding Company,Maintenance Railroad Company Grouping,Maintenance Railroad Class,Maintenance Railroad SMT Grouping,Maintenance Parent Railroad Code,Maintenance Parent Railroad Name,Maintenance Railroad Holding Company,Location
0,NS,Norfolk Southern Railway Company,2016,122005,https://safetydata.fra.dot.gov/Officeofsafety/...,16,9,,,,...,,,,Shippers,,,ACA,Aluminum Company Of America,Not Assigned,POINT (-87.318454 37.91326)
1,ATK,Amtrak (National Railroad Passenger Corporation),2019,158162,https://safetydata.fra.dot.gov/Officeofsafety/...,19,3,,,,...,,,,Passenger (Formerly Commuter),,"SMT-1 - Amtrak, Commuter East",ATK,Amtrak (National Railroad Passenger Corporation),Amtrak,POINT (-86.077778 39.719722)
2,ATK,Amtrak (National Railroad Passenger Corporation),2019,158327,https://safetydata.fra.dot.gov/Officeofsafety/...,19,4,,,,...,,,,Passenger (Formerly Commuter),,"SMT-1 - Amtrak, Commuter East",ATK,Amtrak (National Railroad Passenger Corporation),Amtrak,POINT (-86.077778 39.719722)
3,ATK,Amtrak (National Railroad Passenger Corporation),2017,146245,https://safetydata.fra.dot.gov/Officeofsafety/...,17,2,,,,...,,,,Passenger (Formerly Commuter),,"SMT-1 - Amtrak, Commuter East",ATK,Amtrak (National Railroad Passenger Corporation),Amtrak,POINT (-86.080662 39.721159)
4,ATK,Amtrak (National Railroad Passenger Corporation),2018,151649,https://safetydata.fra.dot.gov/Officeofsafety/...,18,1,,,,...,,,,Passenger (Formerly Commuter),,"SMT-1 - Amtrak, Commuter East",ATK,Amtrak (National Railroad Passenger Corporation),Amtrak,POINT (-86.078521 39.719528)


In [304]:
# Remove ' 0:00' from the 'Date' column if present
accidents_df['Date'] = accidents_df['Date'].str.replace(' 0:00', '')

# Combine 'Date' and 'Time' columns to create a 'Datetime' column
accidents_df['Datetime'] = pd.to_datetime(accidents_df['Date'] + ' ' + accidents_df['Time'], errors='coerce')

# Display the new 'Datetime' column
accidents_df[['Date', 'Time', 'Datetime']].head()



Unnamed: 0,Date,Time,Datetime
0,09/09/2016 12:00:00 AM,11:15 AM,2016-09-09 11:15:00
1,03/21/2019 12:00:00 AM,1:38 PM,2019-03-21 13:38:00
2,04/03/2019 12:00:00 AM,9:55 AM,2019-04-03 09:55:00
3,02/03/2017 12:00:00 AM,12:30 PM,2017-02-03 12:30:00
4,01/23/2018 12:00:00 AM,1:30 AM,2018-01-23 01:30:00


In [305]:
# Remove ' 0:00' from the 'Date' column if present
accidents_df['Date'] = accidents_df['Date'].str.replace(' 0:00', '')

# Combine 'Date' and 'Time' columns to create a 'Datetime' column
accidents_df['Datetime'] = pd.to_datetime(accidents_df['Date'] + ' ' + accidents_df['Time'], errors='coerce')

# Display the new 'Datetime' column
accidents_df[['Date', 'Time', 'Datetime']].head()


Unnamed: 0,Date,Time,Datetime
0,09/09/2016 12:00:00 AM,11:15 AM,2016-09-09 11:15:00
1,03/21/2019 12:00:00 AM,1:38 PM,2019-03-21 13:38:00
2,04/03/2019 12:00:00 AM,9:55 AM,2019-04-03 09:55:00
3,02/03/2017 12:00:00 AM,12:30 PM,2017-02-03 12:30:00
4,01/23/2018 12:00:00 AM,1:30 AM,2018-01-23 01:30:00


In [306]:
# Extract year, month, and day from 'Datetime'
accidents_df['Year'] = accidents_df['Datetime'].dt.year
accidents_df['Month'] = accidents_df['Datetime'].dt.month
accidents_df['Day'] = accidents_df['Datetime'].dt.day

# Compare extracted 'Year' and 'Month' with 'Accident Year' and 'Accident Month'
print(accidents_df[['Accident Year', 'Year', 'Accident Month', 'Month']].head())


   Accident Year  Year  Accident Month  Month
0             16  2016               9      9
1             19  2019               3      3
2             19  2019               4      4
3             17  2017               2      2
4             18  2018               1      1


In [307]:
# Create geometry column from Latitude and Longitude
accidents_df['geometry'] = accidents_df.apply(lambda row: Point(row['Longitude'], row['Latitude']), axis=1)


In [308]:
# Convert to GeoDataFrame
accidents_gdf = gpd.GeoDataFrame(accidents_df, geometry='geometry')

# Set Coordinate Reference System (CRS) to WGS84 Latitude/Longitude
accidents_gdf.set_crs(epsg=4326, inplace=True)


Unnamed: 0,Reporting Railroad Code,Reporting Railroad Name,Year,Accident Number,PDF Link,Accident Year,Accident Month,Other Railroad Code,Other Railroad Name,Other Accident Number,...,Maintenance Railroad Company Grouping,Maintenance Railroad Class,Maintenance Railroad SMT Grouping,Maintenance Parent Railroad Code,Maintenance Parent Railroad Name,Maintenance Railroad Holding Company,Location,Datetime,Month,geometry
0,NS,Norfolk Southern Railway Company,2016,122005,https://safetydata.fra.dot.gov/Officeofsafety/...,16,9,,,,...,Shippers,,,ACA,Aluminum Company Of America,Not Assigned,POINT (-87.318454 37.91326),2016-09-09 11:15:00,9,POINT (-87.31845 37.91326)
1,ATK,Amtrak (National Railroad Passenger Corporation),2019,158162,https://safetydata.fra.dot.gov/Officeofsafety/...,19,3,,,,...,Passenger (Formerly Commuter),,"SMT-1 - Amtrak, Commuter East",ATK,Amtrak (National Railroad Passenger Corporation),Amtrak,POINT (-86.077778 39.719722),2019-03-21 13:38:00,3,POINT (-86.07778 39.71972)
2,ATK,Amtrak (National Railroad Passenger Corporation),2019,158327,https://safetydata.fra.dot.gov/Officeofsafety/...,19,4,,,,...,Passenger (Formerly Commuter),,"SMT-1 - Amtrak, Commuter East",ATK,Amtrak (National Railroad Passenger Corporation),Amtrak,POINT (-86.077778 39.719722),2019-04-03 09:55:00,4,POINT (-86.07778 39.71972)
3,ATK,Amtrak (National Railroad Passenger Corporation),2017,146245,https://safetydata.fra.dot.gov/Officeofsafety/...,17,2,,,,...,Passenger (Formerly Commuter),,"SMT-1 - Amtrak, Commuter East",ATK,Amtrak (National Railroad Passenger Corporation),Amtrak,POINT (-86.080662 39.721159),2017-02-03 12:30:00,2,POINT (-86.08066 39.72116)
4,ATK,Amtrak (National Railroad Passenger Corporation),2018,151649,https://safetydata.fra.dot.gov/Officeofsafety/...,18,1,,,,...,Passenger (Formerly Commuter),,"SMT-1 - Amtrak, Commuter East",ATK,Amtrak (National Railroad Passenger Corporation),Amtrak,POINT (-86.078521 39.719528),2018-01-23 01:30:00,1,POINT (-86.07852 39.71953)
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
850,WC,WISCONSIN CENTRAL LTD.,2014,820641,https://safetydata.fra.dot.gov/Officeofsafety/...,14,7,,,,...,Switching and Terminal,Class 1,SMT-4 - CPKC/CP/CN/CCD,CN,Canadian National - North America,Canadian National - North America,POINT (-87.352193 41.610313),2014-07-21 05:33:00,7,POINT (-87.35219 41.61031)
851,WC,WISCONSIN CENTRAL LTD.,2013,791340,https://safetydata.fra.dot.gov/Officeofsafety/...,13,9,,,,...,Switching and Terminal,Class 1,SMT-4 - CPKC/CP/CN/CCD,CN,Canadian National - North America,Canadian National - North America,POINT (-87.370083 41.618603),2013-09-10 09:00:00,9,POINT (-87.37008 41.6186)
852,WC,WISCONSIN CENTRAL LTD.,2022,1120068,https://safetydata.fra.dot.gov/Officeofsafety/...,22,10,,,,...,Switching and Terminal,Class 1,SMT-4 - CPKC/CP/CN/CCD,CN,Canadian National - North America,Canadian National - North America,POINT (-87.41272 41.629558),2022-10-26 09:06:00,10,POINT (-87.41272 41.62956)
853,IHB,Indiana Harbor Belt Railroad Company,2016,2265,https://safetydata.fra.dot.gov/Officeofsafety/...,16,11,,,,...,,,,XLSR,LAKESHORE RAIL SERVICES LLC,Not Assigned,POINT (-87.456689 41.632231),2016-11-18 20:00:00,11,POINT (-87.45669 41.63223)


In [309]:
# Convert to GeoDataFrame
accidents_gdf = gpd.GeoDataFrame(accidents_df, geometry='geometry')

# Set Coordinate Reference System (CRS) to WGS84 Latitude/Longitude
accidents_gdf.set_crs(epsg=4326, inplace=True)


Unnamed: 0,Reporting Railroad Code,Reporting Railroad Name,Year,Accident Number,PDF Link,Accident Year,Accident Month,Other Railroad Code,Other Railroad Name,Other Accident Number,...,Maintenance Railroad Company Grouping,Maintenance Railroad Class,Maintenance Railroad SMT Grouping,Maintenance Parent Railroad Code,Maintenance Parent Railroad Name,Maintenance Railroad Holding Company,Location,Datetime,Month,geometry
0,NS,Norfolk Southern Railway Company,2016,122005,https://safetydata.fra.dot.gov/Officeofsafety/...,16,9,,,,...,Shippers,,,ACA,Aluminum Company Of America,Not Assigned,POINT (-87.318454 37.91326),2016-09-09 11:15:00,9,POINT (-87.31845 37.91326)
1,ATK,Amtrak (National Railroad Passenger Corporation),2019,158162,https://safetydata.fra.dot.gov/Officeofsafety/...,19,3,,,,...,Passenger (Formerly Commuter),,"SMT-1 - Amtrak, Commuter East",ATK,Amtrak (National Railroad Passenger Corporation),Amtrak,POINT (-86.077778 39.719722),2019-03-21 13:38:00,3,POINT (-86.07778 39.71972)
2,ATK,Amtrak (National Railroad Passenger Corporation),2019,158327,https://safetydata.fra.dot.gov/Officeofsafety/...,19,4,,,,...,Passenger (Formerly Commuter),,"SMT-1 - Amtrak, Commuter East",ATK,Amtrak (National Railroad Passenger Corporation),Amtrak,POINT (-86.077778 39.719722),2019-04-03 09:55:00,4,POINT (-86.07778 39.71972)
3,ATK,Amtrak (National Railroad Passenger Corporation),2017,146245,https://safetydata.fra.dot.gov/Officeofsafety/...,17,2,,,,...,Passenger (Formerly Commuter),,"SMT-1 - Amtrak, Commuter East",ATK,Amtrak (National Railroad Passenger Corporation),Amtrak,POINT (-86.080662 39.721159),2017-02-03 12:30:00,2,POINT (-86.08066 39.72116)
4,ATK,Amtrak (National Railroad Passenger Corporation),2018,151649,https://safetydata.fra.dot.gov/Officeofsafety/...,18,1,,,,...,Passenger (Formerly Commuter),,"SMT-1 - Amtrak, Commuter East",ATK,Amtrak (National Railroad Passenger Corporation),Amtrak,POINT (-86.078521 39.719528),2018-01-23 01:30:00,1,POINT (-86.07852 39.71953)
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
850,WC,WISCONSIN CENTRAL LTD.,2014,820641,https://safetydata.fra.dot.gov/Officeofsafety/...,14,7,,,,...,Switching and Terminal,Class 1,SMT-4 - CPKC/CP/CN/CCD,CN,Canadian National - North America,Canadian National - North America,POINT (-87.352193 41.610313),2014-07-21 05:33:00,7,POINT (-87.35219 41.61031)
851,WC,WISCONSIN CENTRAL LTD.,2013,791340,https://safetydata.fra.dot.gov/Officeofsafety/...,13,9,,,,...,Switching and Terminal,Class 1,SMT-4 - CPKC/CP/CN/CCD,CN,Canadian National - North America,Canadian National - North America,POINT (-87.370083 41.618603),2013-09-10 09:00:00,9,POINT (-87.37008 41.6186)
852,WC,WISCONSIN CENTRAL LTD.,2022,1120068,https://safetydata.fra.dot.gov/Officeofsafety/...,22,10,,,,...,Switching and Terminal,Class 1,SMT-4 - CPKC/CP/CN/CCD,CN,Canadian National - North America,Canadian National - North America,POINT (-87.41272 41.629558),2022-10-26 09:06:00,10,POINT (-87.41272 41.62956)
853,IHB,Indiana Harbor Belt Railroad Company,2016,2265,https://safetydata.fra.dot.gov/Officeofsafety/...,16,11,,,,...,,,,XLSR,LAKESHORE RAIL SERVICES LLC,Not Assigned,POINT (-87.456689 41.632231),2016-11-18 20:00:00,11,POINT (-87.45669 41.63223)


In [310]:
# Select only the necessary columns
accidents_gdf = accidents_gdf[['Accident Number', 'Year', 'Month', 'Day', 'geometry']]

# Display the GeoDataFrame
accidents_gdf.head()


Unnamed: 0,Accident Number,Year,Month,Day,geometry
0,122005,2016,9,9,POINT (-87.31845 37.91326)
1,158162,2019,3,21,POINT (-86.07778 39.71972)
2,158327,2019,4,3,POINT (-86.07778 39.71972)
3,146245,2017,2,3,POINT (-86.08066 39.72116)
4,151649,2018,1,23,POINT (-86.07852 39.71953)


In [311]:
# Select only the necessary columns
accidents_gdf = accidents_gdf[['Accident Number', 'Year', 'Month', 'Day', 'geometry']]

# Display the GeoDataFrame
accidents_gdf.head()


Unnamed: 0,Accident Number,Year,Month,Day,geometry
0,122005,2016,9,9,POINT (-87.31845 37.91326)
1,158162,2019,3,21,POINT (-86.07778 39.71972)
2,158327,2019,4,3,POINT (-86.07778 39.71972)
3,146245,2017,2,3,POINT (-86.08066 39.72116)
4,151649,2018,1,23,POINT (-86.07852 39.71953)


In [312]:
# Check for missing values
print(accidents_gdf.isnull().sum())

# Drop rows with missing geometry or date information
accidents_gdf.dropna(subset=['geometry', 'Year', 'Month', 'Day'], inplace=True)


Accident Number    0
Year               0
Month              0
Day                0
geometry           0
dtype: int64


3. Load and Prepare Train Volume Data <a id="load-train-volume-data"></a>

In [313]:
# Load train volume data from CSV file
track_volumes_df = pd.read_csv('track_segment_volumes.csv')  # Replace with your actual file path

# Display first few rows
track_volumes_df.head()


Unnamed: 0,segment_id,rrowner1,yardname,miles,geometry,annual_estimate
0,299,LIRC,,0.619007,MULTILINESTRING ((-85.75219166999995 38.406361...,3276
1,300,LIRC,,0.613696,MULTILINESTRING ((-85.75235658999998 38.415246...,3276
2,309,KBSR,,2.689108,MULTILINESTRING ((-86.95838045399995 40.446606...,260
3,310,KBSR,,4.374098,MULTILINESTRING ((-87.02970563399998 40.474318...,260
4,333,INRD,,2.625358,MULTILINESTRING ((-87.18613253499996 39.067629...,208


In [314]:
# 3.2 Parse Geometry Column
# Parse the 'geometry' column from WKT to geometry objects
track_volumes_df['geometry'] = track_volumes_df['geometry'].apply(wkt.loads)


In [315]:
## 3.3 Convert to GeoDataFrame

# Convert to GeoDataFrame
track_volumes_gdf = gpd.GeoDataFrame(track_volumes_df, geometry='geometry')

# Set CRS to WGS84 Latitude/Longitude
track_volumes_gdf.set_crs(epsg=4326, inplace=True)


Unnamed: 0,segment_id,rrowner1,yardname,miles,geometry,annual_estimate
0,299,LIRC,,0.619007,"MULTILINESTRING ((-85.75219 38.40636, -85.7520...",3276
1,300,LIRC,,0.613696,"MULTILINESTRING ((-85.75236 38.41525, -85.7523...",3276
2,309,KBSR,,2.689108,"MULTILINESTRING ((-86.95838 40.44661, -86.9493...",260
3,310,KBSR,,4.374098,"MULTILINESTRING ((-87.02971 40.47432, -87.0270...",260
4,333,INRD,,2.625358,"MULTILINESTRING ((-87.18613 39.06763, -87.1859...",208
...,...,...,...,...,...,...
7360,5818,IN,,0.110315,"MULTILINESTRING ((-84.93208 41.72458, -84.9318...",20904
7361,1137,IN,,0.136433,"MULTILINESTRING ((-84.93208 41.72458, -84.9302...",20904
7362,1082,IN,,0.342116,"MULTILINESTRING ((-84.93683 41.72033, -84.9344...",20904
7363,2143,IN,,0.286396,"MULTILINESTRING ((-84.93683 41.72033, -84.9366...",20904


In [316]:
## 3.4 Handle Missing Values

# Check for missing values
print(track_volumes_gdf.isnull().sum())

# Drop rows with missing geometry or annual_estimate
track_volumes_gdf.dropna(subset=['geometry', 'annual_estimate'], inplace=True)


segment_id            0
rrowner1             65
yardname           6390
miles                 0
geometry              0
annual_estimate       0
dtype: int64


4. Spatial Join of Accidents to Track Segments <a id="spatial-join"></a>

In [317]:
# Confirm both are in EPSG:4326
print(f"Accidents CRS: {accidents_gdf.crs}")
print(f"Track Volumes CRS: {track_volumes_gdf.crs}")


Accidents CRS: EPSG:4326
Track Volumes CRS: EPSG:4326


In [318]:
## 4.2 Spatial Join
# Perform spatial join to associate accidents with track segments
accidents_on_tracks = gpd.sjoin(accidents_gdf, track_volumes_gdf, how='inner', predicate='intersects')

# Display result
accidents_on_tracks.head()


Unnamed: 0,Accident Number,Year,Month,Day,geometry,index_right,segment_id,rrowner1,yardname,miles,annual_estimate


In [319]:
# 4.3 Handle Accidents Not Directly on Tracks (Optional)
# If some accidents are not intersecting with the tracks due to precision issues, you can buffer the accidents or tracks slightly.

# Buffer the accidents by a small distance (e.g., 50 meters)
accidents_gdf_buffered = accidents_gdf.copy()
accidents_gdf_buffered['geometry'] = accidents_gdf_buffered.geometry.buffer(0.0005)  # Approx. 50 meters

# Perform spatial join with buffered accidents
accidents_on_tracks = gpd.sjoin(accidents_gdf_buffered, track_volumes_gdf, how='inner', predicate='intersects')


5. Data Aggregation <a id="data-aggregation"></a>

In [320]:
#  5.1 Calculate Accidents per Segment per Year

# Group by segment_id and Year to get accident counts
accidents_per_segment = accidents_on_tracks.groupby(['segment_id', 'Year']).size().reset_index(name='accident_count')

# Display result
accidents_per_segment.head()


Unnamed: 0,segment_id,Year,accident_count
0,14,2021,1
1,20,2012,1
2,31,2023,1
3,40,2020,1
4,54,2016,1


In [321]:
#  5.2 Merge with Train Volumes

# Merge accident counts with train volume data
merged_data = accidents_per_segment.merge(track_volumes_df[['segment_id', 'annual_estimate', 'miles']], on='segment_id', how='left')

# Display merged data
merged_data.head()


Unnamed: 0,segment_id,Year,accident_count,annual_estimate,miles
0,14,2021,1,1040,0.478878
1,20,2012,1,2184,0.703769
2,31,2023,1,1040,0.205927
3,40,2020,1,2184,0.502552
4,54,2016,1,16016,0.371608


In [322]:
# Export to Parquet file
merged_data.to_parquet('merged_accident_volume_data.parquet')

In [323]:
# 5.3 Prepare Data for Modeling

# Calculate log of annual_estimate to use in modeling
merged_data['log_annual_estimate'] = np.log(merged_data['annual_estimate'])

# Handle any infinite values resulting from log(0)
merged_data.replace([np.inf, -np.inf], np.nan, inplace=True)
merged_data.dropna(subset=['log_annual_estimate'], inplace=True)

# Check for missing values
print(merged_data.isnull().sum())


segment_id             0
Year                   0
accident_count         0
annual_estimate        0
miles                  0
log_annual_estimate    0
dtype: int64


6. Statistical Modeling <a id="statistical-modeling"></a>

In [324]:
#  6.1 Split Data into Training and Testing Sets

from sklearn.model_selection import train_test_split

# Split data: 80% training, 20% testing
train_data, test_data = train_test_split(merged_data, test_size=0.2, random_state=42)


In [325]:
#  6.2 Fit Poisson Regression Model

# Define the formula for Poisson regression
formula = 'accident_count ~ log_annual_estimate + miles'

# Fit the model using the training data
poisson_model = smf.glm(formula=formula, data=train_data, family=sm.families.Poisson()).fit()

# Print model summary
print(poisson_model.summary())


                 Generalized Linear Model Regression Results                  
Dep. Variable:         accident_count   No. Observations:                 1445
Model:                            GLM   Df Residuals:                     1442
Model Family:                 Poisson   Df Model:                            2
Link Function:                    Log   Scale:                          1.0000
Method:                          IRLS   Log-Likelihood:                -2104.0
Date:                Mon, 18 Nov 2024   Deviance:                       1073.2
Time:                        14:32:27   Pearson chi2:                 2.04e+03
No. Iterations:                     5   Pseudo R-squ. (CS):            0.02697
Covariance Type:            nonrobust                                         
                          coef    std err          z      P>|z|      [0.025      0.975]
---------------------------------------------------------------------------------------
Intercept              -0.2780    

In [326]:
## 6.3 Check for Overdispersion

# Calculate the ratio of the deviance to the degrees of freedom
deviance = poisson_model.deviance
degrees_of_freedom = poisson_model.df_resid
overdispersion_ratio = deviance / degrees_of_freedom
print(f'Overdispersion ratio: {overdispersion_ratio:.2f}')

# If the ratio is significantly greater than 1, consider Negative Binomial regression
if overdispersion_ratio > 1.5:
    print("Overdispersion detected. Considering Negative Binomial regression.")


Overdispersion ratio: 0.74


In [327]:
# 6.4 Fit Negative Binomial Regression Model (If Necessary)

from statsmodels.genmod.families import NegativeBinomial

# Fit Negative Binomial model
nb_model = smf.glm(formula=formula, data=train_data, family=NegativeBinomial()).fit()

# Print model summary
print(nb_model.summary())


                 Generalized Linear Model Regression Results                  
Dep. Variable:         accident_count   No. Observations:                 1445
Model:                            GLM   Df Residuals:                     1442
Model Family:        NegativeBinomial   Df Model:                            2
Link Function:                    Log   Scale:                          1.0000
Method:                          IRLS   Log-Likelihood:                -2378.7
Date:                Mon, 18 Nov 2024   Deviance:                       321.13
Time:                        14:32:27   Pearson chi2:                     825.
No. Iterations:                     7   Pseudo R-squ. (CS):           0.009886
Covariance Type:            nonrobust                                         
                          coef    std err          z      P>|z|      [0.025      0.975]
---------------------------------------------------------------------------------------
Intercept              -0.1917    

7. Prediction <a id="prediction"></a>

In [328]:
# 7.1 Predict on Test Data

# Predict accident counts using the test data
test_data['predicted_accidents'] = nb_model.predict(test_data)

# Ensure predicted accidents are non-negative
test_data['predicted_accidents'] = test_data['predicted_accidents'].clip(lower=0)


In [329]:
#  7.2 Evaluate Model Performance

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Calculate Mean Squared Error
mse = mean_squared_error(test_data['accident_count'], test_data['predicted_accidents'])
print(f'Mean Squared Error: {mse:.2f}')

# Calculate Root Mean Squared Error
rmse = np.sqrt(mse)
print(f'Root Mean Squared Error: {rmse:.2f}')

# Calculate Mean Absolute Error
mae = mean_absolute_error(test_data['accident_count'], test_data['predicted_accidents'])
print(f'Mean Absolute Error: {mae:.2f}')

# Calculate R-squared
r2 = r2_score(test_data['accident_count'], test_data['predicted_accidents'])
print(f'R-squared: {r2:.2f}')


Mean Squared Error: 2.77
Root Mean Squared Error: 1.66
Mean Absolute Error: 0.75
R-squared: 0.03


In [330]:
# 7.3 Add Predictions to GeoDataFrame

# Since test_data contains 'segment_id' and 'Year', we'll aggregate predictions per segment
predictions_per_segment = test_data.groupby('segment_id')['predicted_accidents'].sum().reset_index()

# Merge predictions back to the track volumes GeoDataFrame
track_volumes_gdf = track_volumes_gdf.merge(predictions_per_segment, on='segment_id', how='left')

# Fill NaN predictions with zeros (segments not in test_data)
track_volumes_gdf['predicted_accidents'].fillna(0, inplace=True)


8. Conclusion <a id="conclusion"></a>
We have successfully:

- Adjusted the data parsing to account for the actual date and time formats in your dataset.
- Loaded and prepared the accident and train volume datasets.
- Performed a spatial join to associate accidents with track segments.
- Aggregated the data to calculate accidents per segment per year.
- Built a statistical model to predict accident counts based on train volume and segment length.
- Evaluated the model and made predictions on test data.
- Merged the predictions back into the track volumes GeoDataFrame for visualization.

# Visualization with Folium

9. Visualization
9.1 Prepare Data for Visualization

In [331]:
# !pip install folium
#!pip install --upgrade branca

In [332]:
# Import Folium
import folium
from folium.plugins import HeatMap
from branca.colormap import linear

# Create a copy of the GeoDataFrame for visualization
viz_gdf = track_volumes_gdf.copy()


In [333]:
# Define a color scale using 'YlOrRd_09' colormap
max_pred = viz_gdf['predicted_accidents'].max()
color_scale = linear.YlOrRd_09.scale(0, max_pred)
color_scale.caption = 'Predicted Accident Counts'



In [334]:
# from branca.colormap import linear

# # Define a color scale using 'Blues' colormap
# max_pred = viz_gdf['predicted_accidents'].max()
# color_scale = linear.Blues_09.scale(0, max_pred)
# color_scale.caption = 'Predicted Accident Counts'



In [335]:
# Initialize Folium Map centered on Indiana
m = folium.Map(location=[39.8283, -86.2790], zoom_start=7)


In [336]:
# Function to style each feature
def style_function(feature):
    predicted_accidents = feature['properties']['predicted_accidents']
    return {
        'color': color_scale(predicted_accidents),
        'weight': 3,
        'opacity': 1
    }

# Add GeoJson layer to the map
folium.GeoJson(
    data=viz_gdf,
    style_function=style_function,
    tooltip=folium.GeoJsonTooltip(
        fields=['segment_id', 'annual_estimate', 'predicted_accidents'],
        aliases=['Segment ID:', 'Annual Train Volume:', 'Predicted Accidents:'],
        localize=True
    )
).add_to(m)

# Add color scale to the map
color_scale.add_to(m)

# Display the map
m


In [337]:
# # Save the map to an HTML file
# m.save('predicted_accidents_map.html')


In [339]:
# Calculate centroids of track segments
viz_gdf['centroid'] = viz_gdf.geometry.centroid

# Filter segments with predicted accidents greater than zero
segments_with_accidents = viz_gdf[viz_gdf['predicted_accidents'] > 0].copy()

# Ensure centroids are valid
segments_with_accidents = segments_with_accidents[
    ~segments_with_accidents['centroid'].is_empty &
    segments_with_accidents['centroid'].notnull()
]

# Initialize Folium Map
m = folium.Map(location=[39.8283, -86.2790], zoom_start=7, tiles='cartodbpositron')

# Add bubbles for predicted accidents
for idx, row in segments_with_accidents.iterrows():
    centroid = row['centroid']
    predicted_accidents = float(row['predicted_accidents'])
    segment_id = str(row['segment_id'])
    
    folium.CircleMarker(
        location=[centroid.y, centroid.x],
        radius=5 + predicted_accidents * 2,
        color='blue',
        fill=True,
        fill_color='blue',
        fill_opacity=0.6,
        tooltip=folium.Tooltip(
            f"<b>High Risk Area</b><br>"
            f"Segment ID: {segment_id}<br>"
            f"Predicted Accidents: {predicted_accidents}"
        )
    ).add_to(m)

# Add color scale to the map (if you're using one)
color_scale.add_to(m)

# Display the map
m
