In [15]:
import sys
from pathlib import Path

sys.path.append(str(Path("../..").resolve()))

from src.data_ingestion import *
from src.data_preprocessing import *
from src.descriptive_analytics import *

from pyspark.sql import DataFrame
from pyspark.sql.functions import col
from pyspark.sql.types import NumericType, StringType
from pyspark.sql import functions as F

import seaborn as sns

import numpy as np

from itertools import combinations

from scipy import stats

import matplotlib.pyplot as plt

import pandas as pd
from pyspark.sql.window import Window

from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier




In [16]:
spark = init_spark()
df = load_data(spark, "../../data/US_Accidents_March23.csv")

In [17]:
# df = preprocess_data(df)

In [20]:
from pyspark.sql import functions as F
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml import Pipeline

state_features = df.groupBy("State").agg(
    F.avg("Visibility(mi)").alias("Avg_Visibility"),
    F.avg(F.when(F.col("Sunrise_Sunset") == "Night", 1).otherwise(0)).alias("Prop_Night_Accidents"),
    F.avg("Precipitation(in)").alias("Avg_Precipitation"),
    F.avg("Temperature(F)").alias("Avg_Temperature"),
    F.avg("Distance(mi)").alias("Avg_Accident_Distance"),
    F.countDistinct("City").alias("Num_Unique_Cities"),
    F.avg(F.unix_timestamp("End_Time") - F.unix_timestamp("Start_Time")).alias("Avg_Accident_Duration_Seconds"),
    F.count("*").alias("Total_Accidents"),
    F.avg("Severity").alias("Avg_Severity")
)

state_features = state_features.withColumn(
    "Risk_Score",
    F.col("Total_Accidents") * F.col("Avg_Severity")
)

state_features = state_features.drop("Total_Accidents", "Avg_Severity")

state_features = state_features.cache()

feature_cols = [
    "Avg_Visibility", "Avg_Precipitation", "Avg_Temperature",
    "Avg_Accident_Distance", "Num_Unique_Cities"
]
state_features = state_features.na.drop(subset=feature_cols + ["Risk_Score"])

assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features",
    handleInvalid="skip"
)

model = LinearRegression(
    featuresCol="features",
    labelCol="Risk_Score",
    maxIter=100,
    regParam=0.0
)

pipeline = Pipeline(stages=[assembler, model])

train_df, test_df = state_features.randomSplit([0.8, 0.2], seed=123)

try:
    model_fitted = pipeline.fit(train_df)
except Exception as e:
    print(f"Error: {e}")
    spark.stop()
    raise e

predictions = model_fitted.transform(test_df)

evaluators = {
    "rmse": RegressionEvaluator(labelCol="Risk_Score", predictionCol="prediction", metricName="rmse"),
    "mae": RegressionEvaluator(labelCol="Risk_Score", predictionCol="prediction", metricName="mae"),
    "r2": RegressionEvaluator(labelCol="Risk_Score", predictionCol="prediction", metricName="r2")
}

print("Regression Metrics:")
for metric, evaluator in evaluators.items():
    value = evaluator.evaluate(predictions)
    print(f"{metric.upper()}: {value:.4f}")

lr_model = model_fitted.stages[-1]
coefficients = lr_model.coefficients.toArray()
abs_coefficients = [abs(coef) for coef in coefficients]
total = sum(abs_coefficients)
importance = [coef / total if total > 0 else 0 for coef in abs_coefficients]
print("\nFeature Importance:")
for feature, imp in zip(feature_cols, importance):
    print(f"{feature}: {imp:.4f}")

predictions.select("State", "Risk_Score", "prediction").show(10, truncate=False)

spark.stop()

Regression Metrics:
RMSE: 317208.9238
MAE: 200413.4663
R2: 0.5838

Feature Importance:
Avg_Visibility: 0.0034
Avg_Precipitation: 0.9957
Avg_Temperature: 0.0007
Avg_Accident_Distance: 0.0001
Num_Unique_Cities: 0.0000
+-----+------------------+------------------+
|State|Risk_Score        |prediction        |
+-----+------------------+------------------+
|MN   |415289.0          |498768.71314730117|
|NJ   |314331.0          |288799.1398893113 |
|DC   |39936.0           |52651.29227274125 |
|NE   |62955.99999999999 |8818.791077522372 |
|NC   |721657.0          |803182.9571533577 |
|MO   |185545.0          |243860.13372003625|
|IL   |402678.00000000006|839219.4194777964 |
|MS   |35589.0           |80191.04793966173 |
|OH   |278007.0          |784559.627471513  |
|NY   |786233.0          |830852.8840830902 |
+-----+------------------+------------------+
only showing top 10 rows

