![My Screenshot](enough-is-enough-ending-gun-violence-together-vector.jpg)


In [1]:
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.functions import *

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, to_date, when, year, month, dayofweek, \
    lit, round as spark_round
from pyspark.sql.types import IntegerType, DoubleType

from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, GBTClassifier
from pyspark.sql.functions import monotonically_increasing_id
from pyspark.ml.functions import vector_to_array
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.sql.functions import countDistinct, col, desc, round, sum, count, mean, first, lit, when



#### This section is responsible for initializing the Spark session, which serves as the entry point for using Spark functionality.
# 
##### 1. The code first attempts to **stop any existing SparkSession** by calling `spark.stop()`. This is done inside a try-except block to prevent errors if no Spark session is currently active.
##### 2. Next, a new SparkSession is **configured** and created:
###### -   The application name is set to "PublicSafetySafe", which helps identify the Spark job.
###### -    Certain **configuration options** are set specifically for experimentation and debugging:
             
###### 1.            "spark.sql.adaptive.enabled" and "spark.sql.adaptive.coalescePartitions.enabled" are set to "false" to disable adaptive query execution features. This can help achieve more predictable execution plans and resource allocation during testing.
            
###### 2.           "spark.sql.autoBroadcastJoinThreshold" is set to "-1" to completely disable automatic broadcast joins, preventing Spark from broadcasting small tables, which can be useful for debugging join behavior.
             
######  3.           "spark.sql.shuffle.partitions" is set to "50" to control the number of partitions produced when shuffling data, which can affect parallelism and performance.

##### 3. After creating the SparkSession, the SparkContext's log level is set to **"ERROR" to suppress informational and warning logs**, so the output is less noisy and only shows errors, making the debugging process cleaner.


In [2]:
try:
    spark.stop()
except:
    pass
spark = (
    SparkSession.builder
    .appName("PublicSafetySafe")
    .config("spark.sql.adaptive.enabled", "false")  
    .config("spark.sql.adaptive.coalescePartitions.enabled", "false")
    .config("spark.sql.autoBroadcastJoinThreshold", "-1") 
    .config("spark.sql.shuffle.partitions", "50")
    .getOrCreate()
)

spark.sparkContext.setLogLevel("ERROR")


### **load The Dataset üîÉ**
##### show details about data strucure by `printSchema()`

In [3]:

df = spark.read.option("header", "true").option("inferSchema", "true") \
    .csv(r"C:\Users\Gell G15\.cache\kagglehub\datasets\jameslko\gun-violence-data\versions\1\gun-violence-data_01-2013_03-2018.csv")

print("Number of Rows", df.count())
df.printSchema()
df.show(5)

Number of Rows 246939
root
 |-- incident_id: string (nullable = true)
 |-- date: string (nullable = true)
 |-- state: string (nullable = true)
 |-- city_or_county: string (nullable = true)
 |-- address: string (nullable = true)
 |-- n_killed: string (nullable = true)
 |-- n_injured: string (nullable = true)
 |-- incident_url: string (nullable = true)
 |-- source_url: string (nullable = true)
 |-- incident_url_fields_missing: string (nullable = true)
 |-- congressional_district: string (nullable = true)
 |-- gun_stolen: string (nullable = true)
 |-- gun_type: string (nullable = true)
 |-- incident_characteristics: string (nullable = true)
 |-- latitude: string (nullable = true)
 |-- location_description: string (nullable = true)
 |-- longitude: double (nullable = true)
 |-- n_guns_involved: string (nullable = true)
 |-- notes: string (nullable = true)
 |-- participant_age: string (nullable = true)
 |-- participant_age_group: string (nullable = true)
 |-- participant_gender: string (null

### **Filteration of Columns for Needed Data üßµ**  
#### **Columns Are not used üóëÔ∏è**
| Column Name                     | Description |
|---------------------------------|-------------|
| **incident_id**                  | A unique ID number for each reported gun violence incident. (Not useful for our analysis ü´†) |
| **participant_name**             | Names of the participants (if available; often anonymized). |
| **address**                      | The street address or general location where the event took place. (Not needed ,we selected what is useful) |
| **state_house_district**         | The state‚Äôs house district number for that location. |
| **state_senate_district**        | The state‚Äôs senate district number. |
| **location_description**         | Additional context about where it occurred (e.g., Home, Street, Bar). |
| **congressional_district**      | The U.S. congressional district where the incident occurred. |
| **incident_url**                 | Link to the Gun Violence Archive (GVA) page with detailed info about this incident. |
| **source_url**                   | URL of the news article(s) or reports used as a source. |
| **incident_url_fields_missing**  | A flag indicating if any data fields were missing in the GVA entry. |
| **sources**                      | List of media outlets or reports that confirmed the data. |
| **notes**                        | Free-text field with extra information ‚Äî can include motives, victim names, or event summary. |
| **participant_age_group**        | Age group(s) such as Adult 18+, Teen 12-17, Child 0-11. (we can get it from another one üöÆ)|
| **participant_status**           | Outcome for each participant (Killed, Injured, Arrested, etc.). (we can get it from another one üöÆ) |

##### **Columns May be used in Future üîÆ**

| Column Name                     | Description |
|---------------------------------|-------------|
| **gun_stolen**                   | Whether the gun used was reported stolen (Yes, No, or Unknown). |
| **gun_type**                     | The type(s) of gun(s) involved (e.g., Handgun, Rifle, Shotgun). |
| **n_guns_involved**              | Number of guns used in the incident. |
| **incident_characteristics**     | Descriptive tags about what kind of event it was (e.g., Suicide Attempt, Home Invasion, Gang Involved, Accidental Shooting). |
| **participant_age**              | Age(s) of the participants. |
| **participant_type**             | Whether each person was a Victim, Suspect, or Subject-Suspect. |
| **participant_gender**           | Gender(s) of the participants (Male, Female). |
| **participant_relationship**     | Relationship between participants (e.g., Family, Stranger, Acquaintance). |


##### **Columns Are usedüëå**
| Column Name                     | Description |
|---------------------------------|-------------|
| **date**                         | The date when the incident occurred (e.g., 2015-06-17). |
| **state**                        | The U.S. state where the incident happened (e.g., Texas, California). |
| **city_or_county**               | The city or county of the incident location. |
| **n_killed**                     | Number of people killed in the incident. |
| **n_injured**                    | Number of people injured in the incident. |
| **latitude**                     | Latitude coordinate of the incident. |
| **longitude**                    | Longitude coordinate of the incident. |


In [4]:
df = df.select(
    "date", "state", "city_or_county", "latitude", "longitude", "n_killed", "n_injured"
)
df = df.na.drop()
print(f"Number of rows: {df.count()}")
print(f"Number of columns: {len(df.columns)}")

Number of rows: 231754
Number of columns: 7


### **Preprocessing and Data Cleaning üßπ**
### When converting all data formats, you will encounter two problems ‚¨áÔ∏è‚¨áÔ∏è :


##### 1-**define Date column Format**

In [5]:
#======================================================Check date format============================================================================
# there is a problem with the date format , data in string format so we need to convert it to date format
date_sample = df.select("date").distinct().limit(10).collect()
for row in date_sample:
    print(f"  '{row['date']}' (type: {type(row['date']).__name__})")

  '2/22/2013' (type: str)
  '5/19/2013' (type: str)
  '6/25/2013' (type: str)
  '7/14/2013' (type: str)
  '8/19/2013' (type: str)
  '8/23/2013' (type: str)
  '10/26/2013' (type: str)
  '3/7/2014' (type: str)
  '3/17/2014' (type: str)
  '3/26/2014' (type: str)


##### 2-**there is one record that have url string**

In [6]:
#======================================================Filteration the invalid values============================================================================
# filter the invalid values in the latitude column , because it is url and not a number
df_clean = df.filter(
    col("latitude").rlike("^-?[0-9]*\\.?[0-9]+$")  
)

print(f"Number of rows after removing non-numeric values: {df_clean.count()}")


Number of rows after removing non-numeric values: 231753


##### **Prepare all Data columns format** 
###### after the previous steps , we delete records that not able to convert as a stable format
###### Now, easy convertion using **withColumn**

In [7]:
#================================================================= data format ============================================================================
#all the data is in the correct format 
df_clean = df_clean.withColumn(
    "date", 
    to_date(col("date"), "M/d/yyyy")  #  MM/dd/yyyy
).withColumn("n_killed", col("n_killed").cast(IntegerType())) \
 .withColumn("n_injured", col("n_injured").cast(IntegerType())) \
 .withColumn("latitude", col("latitude").cast(DoubleType())) \
 .withColumn("longitude", col("longitude").cast(DoubleType()))
 
'''print("=== Check date format ===")
df_clean.select("date").show(10)'''

print("check the data format üíØ")
df_clean.printSchema()

check the data format üíØ
root
 |-- date: date (nullable = true)
 |-- state: string (nullable = true)
 |-- city_or_county: string (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)
 |-- n_killed: integer (nullable = true)
 |-- n_injured: integer (nullable = true)



### **(Optional)** The next Step for failed data format conversion
- by this step i covered these two columns problems
    - there is a problem with the date format , data in string format so we need to convert it to date format
    - filter the invalid values in the latitude column , because it is url and not a number


In [8]:
#============================================failed in data format conversion==============================
# Check if there are any rows where the date conversion failed (date is null)
# If there are, try an alternative method to convert the date column

if df_clean.filter(col("date").isNull()).count() > 0:
    print("‚ö†Ô∏è  There is an issue converting some dates, trying the alternative method...")
    
    df_clean = df.filter(
        col("latitude").rlike("^-?[0-9]*\\.?[0-9]+$")
    )
    
    # Split the date into parts and convert it
    from pyspark.sql.functions import split, expr
    
    df_clean = df_clean.withColumn("date_parts", split(col("date"), "/")) \
        .withColumn("month", col("date_parts").getItem(0).cast(IntegerType())) \
        .withColumn("day", col("date_parts").getItem(1).cast(IntegerType())) \
        .withColumn("year", col("date_parts").getItem(2).cast(IntegerType())) \
        .withColumn("date_formatted", 
                   expr("make_date(year, month, day)")) \
        .drop("date_parts", "month", "day", "year")
    
# All conversions
    df_clean = df_clean.withColumn("n_killed", col("n_killed").cast(IntegerType())) \
        .withColumn("n_injured", col("n_injured").cast(IntegerType())) \
        .withColumn("latitude", col("latitude").cast(DoubleType())) \
        .withColumn("longitude", col("longitude").cast(DoubleType())) \
        .withColumnRenamed("date_formatted", "date")

#### **Shown** the data after cleaning and Data format Conversion üíØ

In [9]:
# Cell 6: Final Data Verification
print("=== Final Data Verification ===")
print(f"number of row after cleaning :{df_clean.count()}")

print("\n\nsample of data after cleaning")
df_clean.select("date", "latitude", "longitude", "n_killed", "n_injured").show(10)



=== Final Data Verification ===
number of row after cleaning :231753


sample of data after cleaning
+----------+--------+---------+--------+---------+
|      date|latitude|longitude|n_killed|n_injured|
+----------+--------+---------+--------+---------+
|2013-01-01| 40.3467| -79.8559|       0|        4|
|2013-01-01|  33.909| -118.333|       1|        3|
|2013-01-01| 41.4455| -82.1377|       1|        3|
|2013-01-05| 39.6518| -104.802|       4|        0|
|2013-01-07|  36.114| -79.9569|       2|        2|
|2013-01-07| 36.2405| -95.9768|       4|        0|
|2013-01-19| 34.9791| -106.716|       5|        0|
|2013-01-21| 29.9435| -90.0836|       0|        5|
|2013-01-21| 37.9656| -121.718|       0|        4|
|2013-01-23| 39.2899| -76.6412|       1|        6|
+----------+--------+---------+--------+---------+
only showing top 10 rows


### **'''Define target variable (hotspots)'''ü•Ö**
#### Now,
- This code is creating a **new dataset** called df_with_target based on the cleaned data (df_clean).
    - It **creates a new column** called "total_victims", which **sums** up the number of people killed ("n_killed") and the number of people injured ("n_injured") for each incident.

    
    -  It then **creates another column** called **"IsHotspot"**:(Target (when 5->1))
        - If the **total number of victims** (killed + injured) is **greater than or equal to 5**, OR if the number of people **killed alone is at least 2**, **"IsHotspot" will be set to 1**.
        - Otherwise, "IsHotspot" will be set to 0.


In [10]:
# Define Target Variable (Hotspots) 
df_with_target = df_clean.withColumn(
    "total_victims", col("n_killed") + col("n_injured")
).withColumn(
    "IsHotspot", 
    when((col("total_victims") >= 5) | (col("n_killed") >= 2), 1).otherwise(0)
)

#### This code calculates and displays summary statistics for the "hotspots" in the dataset.
 - For each group **(hotspot->1 and non-hotspot->0)**, it computes:
   - The **total** number of **incidents** (count)
   - The **average** number of people **killed per incident** (avg_killed)
   - The **average** number of people **injured per incident** (avg_injured)


# **pumbüí•**
### can you define the problem ?? ü´† 
##### *think ,man*

In [11]:
# Hotspot Statistics

hotspot_stats = df_with_target.groupBy("IsHotspot").agg(
    count("*").alias("count"),
    avg("n_killed").alias("avg_killed"),
    avg("n_injured").alias("avg_injured")
)
print("=== ÿ•ÿ≠ÿµÿßÿ¶Ÿäÿßÿ™ ÿßŸÑŸÖŸÜÿßÿ∑ŸÇ ÿßŸÑÿ≥ÿßÿÆŸÜÿ© ===")
hotspot_stats.show()


=== ÿ•ÿ≠ÿµÿßÿ¶Ÿäÿßÿ™ ÿßŸÑŸÖŸÜÿßÿ∑ŸÇ ÿßŸÑÿ≥ÿßÿÆŸÜÿ© ===
+---------+------+-------------------+------------------+
|IsHotspot| count|         avg_killed|       avg_injured|
+---------+------+-------------------+------------------+
|        0|226090|0.20117209960635146|0.4863328762882038|
|        1|  5663|  2.065159809288363|0.8087586085113897|
+---------+------+-------------------+------------------+



### **Temporal Feature EngineeringüìÖ**
##### What is done:
   - This cell extracts new temporal features from the "date" column, adding columns for year, month, day of the week, and day of the month to the DataFrame.
##### Why is this done:
   - Temporal features help machine learning models learn and utilize seasonal or periodic patterns in the data.


In [12]:
# Temporal Feature Engineering

df_features = df_with_target.withColumn("year", year(col("date"))) \
    .withColumn("month", month(col("date"))) \
    .withColumn("day_of_week", dayofweek(col("date"))) \
    .withColumn("day_of_month", dayofmonth(col("date")))

print('Display a preview to verify the extracted temporal features')
df_features.select("date", "year", "month", "day_of_week", "day_of_month").show(10)


Display a preview to verify the extracted temporal features
+----------+----+-----+-----------+------------+
|      date|year|month|day_of_week|day_of_month|
+----------+----+-----+-----------+------------+
|2013-01-01|2013|    1|          3|           1|
|2013-01-01|2013|    1|          3|           1|
|2013-01-01|2013|    1|          3|           1|
|2013-01-05|2013|    1|          7|           5|
|2013-01-07|2013|    1|          2|           7|
|2013-01-07|2013|    1|          2|           7|
|2013-01-19|2013|    1|          7|          19|
|2013-01-21|2013|    1|          2|          21|
|2013-01-21|2013|    1|          2|          21|
|2013-01-23|2013|    1|          4|          23|
+----------+----+-----+-----------+------------+
only showing top 10 rows


### **LAt and LOn Feature Engineering üó∫Ô∏è**
- **What** it does: Adds **two new columns** (lat_grid, lon_grid) to df_features by **rounding** latitude and longitude to 3 decimal places.
- **Why**: Rounding creates a spatial grid (binning) so **nearby points fall into the same cell**. This **reduces** cardinality and makes aggregations like counts, heatmaps, and clustering **fast and scalable in Spark**.

In [13]:
# Cell 10: Spatial Feature Engineering (Grid Clustering)
df_spatial = df_features.withColumn(
    "lat_grid", round(col("latitude"), 3)
).withColumn(
    "lon_grid", round(col("longitude"), 3)
)

#### Calculate Grid (Spatial) **Statistics**üìè
- Here we are **grouping** the data by spatial grids defined by **latitude and longitude** (lat_grid, lon_grid).
    - For each grid cell, we calculate: **such as hotspot**
       - incident_count: total number of incidents in that grid cell
       - total_killed: total number of people killed in that grid cell
       - total_injured: total number of people injured in that grid cell
       - avg_killed_per_incident: average number of killed per incident in that cell
#### Why? 
- Calculating these statistics per grid cell **helps us identify spatial patterns**, such as accident hotspots or areas with severe outcomes.
- This summary can be used for heatmaps or as **features for further modeling**.

In [14]:
# Calculate Grid (Spatial) Statistics

grid_counts = df_spatial.groupBy("lat_grid", "lon_grid") \
    .agg(
        count("*").alias("incident_count"),
        sum("n_killed").alias("total_killed"),
        sum("n_injured").alias("total_injured"),
        avg("n_killed").alias("avg_killed_per_incident")
    )

print("=== Grid Aggregation Statistics ===")
# Show the top 10 grid cells with the most incidents
grid_counts.orderBy(desc("incident_count")).show(10)


=== Grid Aggregation Statistics ===
+--------+--------+--------------+------------+-------------+-----------------------+
|lat_grid|lon_grid|incident_count|total_killed|total_injured|avg_killed_per_incident|
+--------+--------+--------------+------------+-------------+-----------------------+
|  33.636| -84.433|           252|           0|            0|                    0.0|
|  39.294|  -76.62|           235|           3|           83|    0.01276595744680851|
|  29.987| -95.348|           169|           0|            0|                    0.0|
|  32.898|  -97.04|           160|           0|            1|                    0.0|
|  33.435|-112.006|           159|           0|            0|                    0.0|
|  38.908| -77.018|           134|           6|           11|    0.04477611940298507|
|  29.955| -90.075|           109|           8|           41|    0.07339449541284404|
|  28.436| -81.307|           104|           0|            1|                    0.0|
|   39.85|-104.674

In [15]:
# Merge Grid Statistics
# What: Merge the calculated grid statistics (incident_count, total_killed, total_injured, avg_killed_per_incident) into the main dataframe, 
#       so every incident record gets these aggregate features from its spatial "grid cell".
# Why: This enriches every row with additional spatial context. It's valuable for analysis and modeling, letting us use "grid features" as predictors.
df_final = df_spatial.join(grid_counts, ["lat_grid", "lon_grid"])


# **ÿßŸÑŸÖŸÜÿ≥Ÿä**
### We gave our attention to numerical data and later to categorical data. Don't be sad, now we are giving it all our attention. üëè
### **Handel unkown data if State and city_or_county columns (Categorical Data)** üö¶ü´∑


In [16]:
# Handle Categorical Data

df_final = df_final.fillna({
    "state": "Unknown",
    "city_or_county": "Unknown"
})

# **FINALLy**
###### **Not the end , because X have no endü¶¶**

In [17]:
# Cell 14: Display Final Results

print("=== ÿßŸÑÿ®ŸäÿßŸÜÿßÿ™ ÿßŸÑŸÜŸáÿßÿ¶Ÿäÿ© ÿ®ÿπÿØ ÿßŸÑÿ™ŸÜÿ∏ŸäŸÅ  ===")
df_final.select(
    "date", "latitude", "longitude", "n_killed", "n_injured", 
    "IsHotspot", "lat_grid", "lon_grid", "incident_count", "state"
).show(15)


=== ÿßŸÑÿ®ŸäÿßŸÜÿßÿ™ ÿßŸÑŸÜŸáÿßÿ¶Ÿäÿ© ÿ®ÿπÿØ ÿßŸÑÿ™ŸÜÿ∏ŸäŸÅ  ===
+----------+--------+---------+--------+---------+---------+--------+--------+--------------+-------+
|      date|latitude|longitude|n_killed|n_injured|IsHotspot|lat_grid|lon_grid|incident_count|  state|
+----------+--------+---------+--------+---------+---------+--------+--------+--------------+-------+
|2015-09-04| 19.4475| -155.189|       0|        0|        0|  19.448|-155.189|             1| Hawaii|
|2016-12-20| 25.5399| -80.5126|       0|        1|        0|   25.54| -80.513|             1|Florida|
|2014-01-03| 25.5514| -80.4047|       0|        0|        0|  25.551| -80.405|             1|Florida|
|2014-09-02| 25.6077| -80.3734|       0|        1|        0|  25.608| -80.373|             1|Florida|
|2014-08-31| 25.6864| -80.3804|       0|        1|        0|  25.686|  -80.38|             1|Florida|
|2016-01-26|  25.727| -80.2496|       0|        2|        0|  25.727|  -80.25|             1|Florida|
|2017-01-02| 25.7

# **The same pumb** üí•

In [18]:
# Final Summary Statistics
print("=== Final Data Summary ===")
print(f"Total rows: {df_final.count()}")
print(f"Number of hotspots: {df_final.filter(col('IsHotspot') == 1).count()}")
print(f"Percentage of hotspots: {(df_final.filter(col('IsHotspot') == 1).count() / df_final.count() * 100):.2f}%")

=== Final Data Summary ===
Total rows: 231753
Number of hotspots: 5663
Percentage of hotspots: 2.44%


### **Oversampling** with SMOTE-like approach
#### are you getting it?üëè pravo


In [19]:
# Cell 16: Balance Classes Using Oversampling 
print("===== Imbalanced Data =====")
#==================================================== Show original class distribution ==========================================================
print("\n Original class distribution:‚≠ï")
original_distribution = df_with_target.groupBy("IsHotspot").count().orderBy("IsHotspot")
original_distribution.show()

hotspot_count = df_with_target.filter(col("IsHotspot") == 1).count()
non_hotspot_count = df_with_target.filter(col("IsHotspot") == 0).count()

print(f"Class 0 (Non-Hotspot): {non_hotspot_count:,} ({non_hotspot_count/(non_hotspot_count+hotspot_count)*100:.2f}%)")
print(f"Class 1 (Hotspot): {hotspot_count:,} ({hotspot_count/(non_hotspot_count+hotspot_count)*100:.2f}%)")
print(f"Ratio: {non_hotspot_count/hotspot_count:.1f}:1")
#======================================================== ÿßŸÑŸÑŸä ÿπÿßŸàÿ≤ ÿØÿß Ÿäÿ±Ÿàÿ≠ ŸÉÿØÿß ŸàÿßŸÑŸÑŸä ÿπÿßŸàÿ≤ ÿØÿß Ÿäÿ¨Ÿä ŸÉÿØÿß====================================

# Split by class
hotspot_df = df_with_target.filter(col("IsHotspot") == 1)
non_hotspot_df = df_with_target.filter(col("IsHotspot") == 0)

print("\nüîÑ Applying oversampling to the minority class...")

#===================================================  Calculate oversampling factor for target ratio (e.g., 1:3 or 1:2, better than 1:40)

target_ratio = 3  # Each hotspot will be repeated to achieve a 1:3 ratio
oversample_factor = (non_hotspot_count / hotspot_count) / target_ratio

print(f"Oversampling factor: {oversample_factor:.2f}x")

#=================================================  Perform oversampling with replacement
hotspot_oversampled = hotspot_df.sample(withReplacement=True, fraction=oversample_factor, seed=42)

# Combine to form balanced DataFrame
df_balanced = non_hotspot_df.union(hotspot_oversampled)

#========================================================== Show new distribution
print("\n Distribution after oversampling:‚òëÔ∏è")
balanced_distribution = df_balanced.groupBy("IsHotspot").count().orderBy("IsHotspot")
balanced_distribution.show()

hotspot_count_new = df_balanced.filter(col("IsHotspot") == 1).count()
non_hotspot_count_new = df_balanced.filter(col("IsHotspot") == 0).count()

print(f"Class 0 (Non-Hotspot): {non_hotspot_count_new:,} ({non_hotspot_count_new/(non_hotspot_count_new+hotspot_count_new)*100:.2f}%)")
print(f"Class 1 (Hotspot): {hotspot_count_new:,} ({hotspot_count_new/(non_hotspot_count_new+hotspot_count_new)*100:.2f}%)")
print(f"New Ratio: {non_hotspot_count_new/hotspot_count_new:.1f}:1")

print("\n‚úÖ Balanced dataset created successfully!")
print(f"Total rows: {df_balanced.count():,}")

# Replace df_final with the balanced data
df_final = df_balanced

# print("\n‚ö†Ô∏è Note: Use the balanced df_final in the next modeling steps.")

===== Imbalanced Data =====

 Original class distribution:‚≠ï
+---------+------+
|IsHotspot| count|
+---------+------+
|        0|226090|
|        1|  5663|
+---------+------+

Class 0 (Non-Hotspot): 226,090 (97.56%)
Class 1 (Hotspot): 5,663 (2.44%)
Ratio: 39.9:1

üîÑ Applying oversampling to the minority class...
Oversampling factor: 13.31x

 Distribution after oversampling:‚òëÔ∏è
+---------+------+
|IsHotspot| count|
+---------+------+
|        0|226090|
|        1| 75587|
+---------+------+

Class 0 (Non-Hotspot): 226,090 (74.94%)
Class 1 (Hotspot): 75,587 (25.06%)
New Ratio: 3.0:1

‚úÖ Balanced dataset created successfully!
Total rows: 301,677


### **ÿßÿ™ŸÜŸÉÿ±** ŸÖÿπŸÑÿ¥
#### Solutions to Reduce Overfitting Risk:
- Add Noise/Variation

In [20]:
from pyspark.sql.functions import rand, randn

# Add small random noise to numerical features
hotspot_oversampled = hotspot_df.sample(withReplacement=True, fraction=oversample_factor, seed=42) \
    .withColumn("latitude", col("latitude") + (randn(seed=42) * 0.001)) \
    .withColumn("longitude", col("longitude") + (randn(seed=43) * 0.001))

### **(OPtional)**

In [21]:
# ============================================
# ÿßŸÑÿ¨ÿ≤ÿ° 1: ÿßŸÑÿ™ÿ≠ŸÇŸÇ ŸÖŸÜ ÿßŸÑÿ®ŸäÿßŸÜÿßÿ™ Ÿàÿ•ÿ∂ÿßŸÅÿ© ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑŸÖŸÅŸÇŸàÿØÿ©
# ============================================

# Cell 1: Check Current Columns
print("=== ŸÅÿ≠ÿµ ÿßŸÑÿ£ÿπŸÖÿØÿ© ÿßŸÑŸÖŸàÿ¨ŸàÿØÿ© ===")
print(f"ÿßŸÑÿ£ÿπŸÖÿØÿ© ÿßŸÑÿ≠ÿßŸÑŸäÿ©: {df_final.columns}")

# Cell 2: Create Temporal Features if Missing
print("\n=== ÿ•ŸÜÿ¥ÿßÿ° ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑÿ≤ŸÖŸÜŸäÿ© ===")

# ÿßŸÑÿ™ÿ≠ŸÇŸÇ ŸÖŸÜ Ÿàÿ¨ŸàÿØ ÿßŸÑŸÖŸäÿ≤ÿßÿ™
if "month" not in df_final.columns:
    print("‚ö†Ô∏è ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑÿ≤ŸÖŸÜŸäÿ© ÿ∫Ÿäÿ± ŸÖŸàÿ¨ŸàÿØÿ©ÿå ÿ¨ÿßÿ±Ÿä ÿ•ŸÜÿ¥ÿßÿ¶Ÿáÿß...")
    
    from pyspark.sql.functions import year, month, dayofweek, dayofmonth
    
    df_final = df_final.withColumn("year", year(col("date"))) \
        .withColumn("month", month(col("date"))) \
        .withColumn("day_of_week", dayofweek(col("date"))) \
        .withColumn("day_of_month", dayofmonth(col("date")))
    
    print("‚úÖ ÿ™ŸÖ ÿ•ŸÜÿ¥ÿßÿ° ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑÿ≤ŸÖŸÜŸäÿ©")
else:
    print("‚úÖ ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑÿ≤ŸÖŸÜŸäÿ© ŸÖŸàÿ¨ŸàÿØÿ© ÿ®ÿßŸÑŸÅÿπŸÑ")

# Cell 3: Create Spatial Features if Missing
print("\n=== ÿ•ŸÜÿ¥ÿßÿ° ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑŸÖŸÉÿßŸÜŸäÿ© ===")

if "lat_grid" not in df_final.columns:
    print("‚ö†Ô∏è ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑŸÖŸÉÿßŸÜŸäÿ© ÿ∫Ÿäÿ± ŸÖŸàÿ¨ŸàÿØÿ©ÿå ÿ¨ÿßÿ±Ÿä ÿ•ŸÜÿ¥ÿßÿ¶Ÿáÿß...")
    
    from pyspark.sql.functions import round as spark_round
    
    df_final = df_final.withColumn("lat_grid", spark_round(col("latitude"), 3)) \
        .withColumn("lon_grid", spark_round(col("longitude"), 3))
    
    print("‚úÖ ÿ™ŸÖ ÿ•ŸÜÿ¥ÿßÿ° ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑŸÖŸÉÿßŸÜŸäÿ©")
else:
    print("‚úÖ ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑŸÖŸÉÿßŸÜŸäÿ© ŸÖŸàÿ¨ŸàÿØÿ© ÿ®ÿßŸÑŸÅÿπŸÑ")

# Cell 4: Create Grid Statistics if Missing
print("\n=== ÿ•ŸÜÿ¥ÿßÿ° ÿ•ÿ≠ÿµÿßÿ¶Ÿäÿßÿ™ ÿßŸÑÿ¥ÿ®ŸÉÿ© ===")

if "incident_count" not in df_final.columns:
    print("‚ö†Ô∏è ÿ•ÿ≠ÿµÿßÿ¶Ÿäÿßÿ™ ÿßŸÑÿ¥ÿ®ŸÉÿ© ÿ∫Ÿäÿ± ŸÖŸàÿ¨ŸàÿØÿ©ÿå ÿ¨ÿßÿ±Ÿä ÿ•ŸÜÿ¥ÿßÿ¶Ÿáÿß...")
    
    from pyspark.sql.functions import count, sum as spark_sum, avg
    
    # ÿ≠ÿ≥ÿßÿ® ÿπÿØÿØ ÿßŸÑÿ≠ŸàÿßÿØÿ´ ŸÅŸä ŸÉŸÑ ÿ¥ÿ®ŸÉÿ©
    grid_counts = df_final.groupBy("lat_grid", "lon_grid") \
        .agg(
            count("*").alias("incident_count"),
            spark_sum("n_killed").alias("total_killed"),
            spark_sum("n_injured").alias("total_injured"),
            avg("n_killed").alias("avg_killed_per_incident")
        )
    
    # ÿØŸÖÿ¨ ÿßŸÑÿ•ÿ≠ÿµÿßÿ¶Ÿäÿßÿ™ ŸÖÿπ ÿßŸÑÿ®ŸäÿßŸÜÿßÿ™ ÿßŸÑÿ£ÿµŸÑŸäÿ©
    df_final = df_final.join(grid_counts, ["lat_grid", "lon_grid"])
    
    print("‚úÖ ÿ™ŸÖ ÿ•ŸÜÿ¥ÿßÿ° ÿ•ÿ≠ÿµÿßÿ¶Ÿäÿßÿ™ ÿßŸÑÿ¥ÿ®ŸÉÿ©")
else:
    print("‚úÖ ÿ•ÿ≠ÿµÿßÿ¶Ÿäÿßÿ™ ÿßŸÑÿ¥ÿ®ŸÉÿ© ŸÖŸàÿ¨ŸàÿØÿ© ÿ®ÿßŸÑŸÅÿπŸÑ")

print(f"\n‚úÖ ÿ¨ŸÖŸäÿπ ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿ¨ÿßŸáÿ≤ÿ©! ÿßŸÑÿ£ÿπŸÖÿØÿ© ÿßŸÑÿ≠ÿßŸÑŸäÿ©: {df_final.columns}")


=== ŸÅÿ≠ÿµ ÿßŸÑÿ£ÿπŸÖÿØÿ© ÿßŸÑŸÖŸàÿ¨ŸàÿØÿ© ===
ÿßŸÑÿ£ÿπŸÖÿØÿ© ÿßŸÑÿ≠ÿßŸÑŸäÿ©: ['date', 'state', 'city_or_county', 'latitude', 'longitude', 'n_killed', 'n_injured', 'total_victims', 'IsHotspot']

=== ÿ•ŸÜÿ¥ÿßÿ° ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑÿ≤ŸÖŸÜŸäÿ© ===
‚ö†Ô∏è ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑÿ≤ŸÖŸÜŸäÿ© ÿ∫Ÿäÿ± ŸÖŸàÿ¨ŸàÿØÿ©ÿå ÿ¨ÿßÿ±Ÿä ÿ•ŸÜÿ¥ÿßÿ¶Ÿáÿß...
‚úÖ ÿ™ŸÖ ÿ•ŸÜÿ¥ÿßÿ° ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑÿ≤ŸÖŸÜŸäÿ©

=== ÿ•ŸÜÿ¥ÿßÿ° ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑŸÖŸÉÿßŸÜŸäÿ© ===
‚ö†Ô∏è ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑŸÖŸÉÿßŸÜŸäÿ© ÿ∫Ÿäÿ± ŸÖŸàÿ¨ŸàÿØÿ©ÿå ÿ¨ÿßÿ±Ÿä ÿ•ŸÜÿ¥ÿßÿ¶Ÿáÿß...
‚úÖ ÿ™ŸÖ ÿ•ŸÜÿ¥ÿßÿ° ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑŸÖŸÉÿßŸÜŸäÿ©

=== ÿ•ŸÜÿ¥ÿßÿ° ÿ•ÿ≠ÿµÿßÿ¶Ÿäÿßÿ™ ÿßŸÑÿ¥ÿ®ŸÉÿ© ===
‚ö†Ô∏è ÿ•ÿ≠ÿµÿßÿ¶Ÿäÿßÿ™ ÿßŸÑÿ¥ÿ®ŸÉÿ© ÿ∫Ÿäÿ± ŸÖŸàÿ¨ŸàÿØÿ©ÿå ÿ¨ÿßÿ±Ÿä ÿ•ŸÜÿ¥ÿßÿ¶Ÿáÿß...
‚úÖ ÿ™ŸÖ ÿ•ŸÜÿ¥ÿßÿ° ÿ•ÿ≠ÿµÿßÿ¶Ÿäÿßÿ™ ÿßŸÑÿ¥ÿ®ŸÉÿ©

‚úÖ ÿ¨ŸÖŸäÿπ ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿ¨ÿßŸáÿ≤ÿ©! ÿßŸÑÿ£ÿπŸÖÿØÿ© ÿßŸÑÿ≠ÿßŸÑŸäÿ©: ['lat_grid', 'lon_grid', 'date', 'state', 'city_or_county', 'latitude', 'longitude', 'n_killed', 'n_injured', 'total_victims', 'IsHotspot', 'year', 'month', 'day_of_week'

### **Stacking**

#### presteps ‚òëÔ∏è
- Enssure all of Featrues are available
- If not , create the columns to handle the error

In [22]:
# step 1: Feature Selection and Assembly Setup
feature_columns = [
    "month", "day_of_week", "day_of_month",
    "lat_grid", "lon_grid", "incident_count"
]

# Add categorical columns if available
# Just ensure
try:
    if "state_encoded" in df_final.columns:
        feature_columns.append("state_encoded")
    if "city_or_county_encoded" in df_final.columns:
        feature_columns.append("city_or_county_encoded")
except:
    print("‚ö†Ô∏è  Categorical columns not found, proceeding without them")

print(f"Features: {feature_columns}")


#   Ÿàÿ¨ŸàÿØ ÿ¨ŸÖŸäÿπ ÿßŸÑŸÖŸäÿ≤ÿßÿ™
missing_features = [f for f in feature_columns if f not in df_final.columns]
if missing_features:
    print(f"‚ö†Ô∏è ÿ™ÿ≠ÿ∞Ÿäÿ±: ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑÿ™ÿßŸÑŸäÿ© ŸÖŸÅŸÇŸàÿØÿ©: {missing_features}")
    feature_columns = [f for f in feature_columns if f in df_final.columns]
    print(f"ÿßŸÑŸÖŸäÿ≤ÿßÿ™ ÿßŸÑŸÖÿ™ÿßÿ≠ÿ©: {feature_columns}")

Features: ['month', 'day_of_week', 'day_of_month', 'lat_grid', 'lon_grid', 'incident_count']


### **Train & Test Split**

In [23]:
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")

# Cell 8: Data Sampling and Split
data_count = df_final.count()

if data_count > 100000:
    print("‚ö†Ô∏è  ÿßŸÑÿ®ŸäÿßŸÜÿßÿ™ ŸÉÿ®Ÿäÿ±ÿ©ÿå ÿßÿ≥ÿ™ÿÆÿØÿßŸÖ ÿπŸäŸÜÿ© 50% ŸÑŸÑÿ™ÿØÿ±Ÿäÿ®")
    sample_df = df_final.sample(0.5, seed=42)
else:
    sample_df = df_final

#split data set to train and test
train_data, test_data = sample_df.randomSplit([0.7, 0.3], seed=42)

print(f"ÿ®ŸäÿßŸÜÿßÿ™ ÿßŸÑÿ™ÿØÿ±Ÿäÿ®: {train_data.count()}")
print(f"ÿ®ŸäÿßŸÜÿßÿ™ ÿßŸÑÿßÿÆÿ™ÿ®ÿßÿ±: {test_data.count()}")

# ÿßŸÑÿ™ÿ≠ŸÇŸÇ ŸÖŸÜ ÿßŸÑÿ™Ÿàÿßÿ≤ŸÜ ŸÅŸä ŸÖÿ¨ŸÖŸàÿπÿßÿ™ ÿßŸÑÿ™ÿØÿ±Ÿäÿ® ŸàÿßŸÑÿßÿÆÿ™ÿ®ÿßÿ±
print(" ÿ™Ÿàÿ≤Ÿäÿπ ÿ®ŸäÿßŸÜÿßÿ™ ÿßŸÑÿ™ÿØÿ±Ÿäÿ®:")
train_data.groupBy("IsHotspot").count().show()
print(" ÿ™Ÿàÿ≤Ÿäÿπ ÿ®ŸäÿßŸÜÿßÿ™ ÿßŸÑÿßÿÆÿ™ÿ®ÿßÿ±:")
test_data.groupBy("IsHotspot").count().show()

‚ö†Ô∏è  ÿßŸÑÿ®ŸäÿßŸÜÿßÿ™ ŸÉÿ®Ÿäÿ±ÿ©ÿå ÿßÿ≥ÿ™ÿÆÿØÿßŸÖ ÿπŸäŸÜÿ© 50% ŸÑŸÑÿ™ÿØÿ±Ÿäÿ®
ÿ®ŸäÿßŸÜÿßÿ™ ÿßŸÑÿ™ÿØÿ±Ÿäÿ®: 105722
ÿ®ŸäÿßŸÜÿßÿ™ ÿßŸÑÿßÿÆÿ™ÿ®ÿßÿ±: 45028
 ÿ™Ÿàÿ≤Ÿäÿπ ÿ®ŸäÿßŸÜÿßÿ™ ÿßŸÑÿ™ÿØÿ±Ÿäÿ®:
+---------+-----+
|IsHotspot|count|
+---------+-----+
|        0|79017|
|        1|26705|
+---------+-----+

 ÿ™Ÿàÿ≤Ÿäÿπ ÿ®ŸäÿßŸÜÿßÿ™ ÿßŸÑÿßÿÆÿ™ÿ®ÿßÿ±:
+---------+-----+
|IsHotspot|count|
+---------+-----+
|        0|33860|
|        1|11168|
+---------+-----+



### **Feature Assembly**

- Feature assembly is the process of **combining multiple feature columns into a single feature vector column**
- that machine learning algorithms can use for training and prediction. 
- This step is critical because **Spark ML models** require input features to be in **vector form**.


In [24]:
# Cell 9: Apply Feature Assembly
train_data_assembled = assembler.transform(train_data).cache()
test_data_assembled = assembler.transform(test_data).cache()

print("Ÿäÿßÿ®ÿ±ŸÜÿ≥ ü¶¶")
print("‚úÖ ÿ™ŸÖ ÿ™ÿ¨ŸÖŸäÿπ ÿßŸÑŸÖŸäÿ≤ÿßÿ™")
#=================================================================================

# Cell 10: Define Base Models with Stronger Regularization
print("\nüîπ(Base Models)...")

# Model 1: Logistic Regression
lr = LogisticRegression(
    featuresCol="features",
    labelCol="IsHotspot",
    maxIter=50,
    regParam=0.1,
    elasticNetParam=0.5,
    probabilityCol="lr_probability",
    predictionCol="lr_prediction",
    rawPredictionCol="lr_rawPrediction"
)

# Model 2: Decision Tree
dt = DecisionTreeClassifier(
    featuresCol="features",
    labelCol="IsHotspot",
    maxDepth=5,
    minInstancesPerNode=10, # ÿπŸäÿßŸÑŸä ŸàÿßÿÆŸàÿßÿ™ ÿπŸäÿßŸÑŸä
    probabilityCol="dt_probability",
    predictionCol="dt_prediction",
    rawPredictionCol="dt_rawPrediction"
)

# Model 3: Gradient Boosted Trees
gbt = GBTClassifier(
    featuresCol="features",
    labelCol="IsHotspot",
    maxIter=20,
    maxDepth=3,
    stepSize=0.1,#learning rate
    predictionCol="gbt_prediction"
)

# Cell 11: Train Base Models
print("train, Logistic Regression...")
model_lr = lr.fit(train_data_assembled)
print("‚úÖ complete")

print("train, Decision Tree...")
model_dt = dt.fit(train_data_assembled)
print("‚úÖ complete")

print("train, Gradient Boosted Trees...")
model_gbt = gbt.fit(train_data_assembled)
print("‚úÖ complete")

#================================================================== meta model

# Cell 12: Create Level 2 Meta Features
print("\nüîπ Create Level 2 Meta Features")
train_with_id = train_data_assembled.withColumn("row_id", monotonically_increasing_id())#id_name of each record

print("add predictions LR...")
lr_preds = model_lr.transform(train_with_id).select(
    "row_id",
    vector_to_array(col("lr_probability"))[1].alias("lr_prob") #[1] = probability ÿ•ŸÜ ÿßŸÑŸÖŸÜÿ∑ŸÇÿ© Hotspot.
)

print("add predictions DT...")
dt_preds = model_dt.transform(train_with_id).select(
    "row_id",
    vector_to_array(col("dt_probability"))[1].alias("dt_prob")
)

print("add predictions GBT...")
gbt_preds = model_gbt.transform(train_with_id).select(
    "row_id",
    col("gbt_prediction").alias("gbt_pred")
)

# Cell 13: Merge Predictions
print("ÿØŸÖÿ¨ ÿßŸÑpredictions...")
train_meta_features = train_with_id.select("row_id", "IsHotspot") \
    .join(lr_preds, "row_id") \
    .join(dt_preds, "row_id") \
    .join(gbt_preds, "row_id") \
    .drop("row_id")

print(f"‚úÖ ÿ™ŸÖ ÿ•ŸÜÿ¥ÿßÿ° ŸÖŸäÿ≤ÿßÿ™ Meta")

meta_feature_cols = ["lr_prob", "dt_prob", "gbt_pred"]
meta_assembler = VectorAssembler(inputCols=meta_feature_cols, outputCol="meta_features")

train_meta_data = meta_assembler.transform(train_meta_features).cache()


# Cell 14: Train Meta Model
print("train, Meta Model (final model)...")
lr_meta = LogisticRegression(
    featuresCol="meta_features",
    labelCol="IsHotspot",
    maxIter=100,
    regParam=0.1,
    elasticNetParam=0.5
)

model_meta = lr_meta.fit(train_meta_data)
print("‚úÖ complete train, Stacking Ensemble")



Ÿäÿßÿ®ÿ±ŸÜÿ≥ ü¶¶
‚úÖ ÿ™ŸÖ ÿ™ÿ¨ŸÖŸäÿπ ÿßŸÑŸÖŸäÿ≤ÿßÿ™

üîπ(Base Models)...
train, Logistic Regression...
‚úÖ complete
train, Decision Tree...
‚úÖ complete
train, Gradient Boosted Trees...
‚úÖ complete

üîπ Create Level 2 Meta Features
add predictions LR...
add predictions DT...
add predictions GBT...
ÿØŸÖÿ¨ ÿßŸÑpredictions...
‚úÖ ÿ™ŸÖ ÿ•ŸÜÿ¥ÿßÿ° ŸÖŸäÿ≤ÿßÿ™ Meta
train, Meta Model (final model)...
‚úÖ complete train, Stacking Ensemble


In [25]:
# Prepare Test Data
print("\nEvaluating the model...")
test_with_id = test_data_assembled.withColumn("row_id", monotonically_increasing_id())# n

test_lr_preds = model_lr.transform(test_with_id).select(
    "row_id",
    vector_to_array(col("lr_probability"))[1].alias("lr_prob"),
    col("lr_rawPrediction").alias("lr_raw")
)

test_dt_preds = model_dt.transform(test_with_id).select(
    "row_id",
    vector_to_array(col("dt_probability"))[1].alias("dt_prob"),
    col("dt_rawPrediction").alias("dt_raw")
)

test_gbt_preds = model_gbt.transform(test_with_id).select(
    "row_id",
    col("gbt_prediction").alias("gbt_pred")
)

# Make Final Predictions
test_meta_features = test_with_id.select("row_id", "IsHotspot") \
    .join(test_lr_preds, "row_id") \
    .join(test_dt_preds, "row_id") \
    .join(test_gbt_preds, "row_id") \
    .drop("row_id")

test_meta_data = meta_assembler.transform(test_meta_features)
predictions = model_meta.transform(test_meta_data)




Evaluating the model...


In [41]:
print (test_data_assembled.show(10))


+--------+--------+----------+-------+--------------------+--------+---------+--------+---------+-------------+---------+----+-----+-----------+------------+--------------+------------+-------------+-----------------------+--------------------+
|lat_grid|lon_grid|      date|  state|      city_or_county|latitude|longitude|n_killed|n_injured|total_victims|IsHotspot|year|month|day_of_week|day_of_month|incident_count|total_killed|total_injured|avg_killed_per_incident|            features|
+--------+--------+----------+-------+--------------------+--------+---------+--------+---------+-------------+---------+----+-----+-----------+------------+--------------+------------+-------------+-----------------------+--------------------+
|  25.727|  -80.25|2016-01-26|Florida|Miami (Coconut Gr...|  25.727| -80.2496|       0|        2|            2|        0|2016|    1|          3|          26|             1|           0|            2|                    0.0|[1.0,3.0,26.0,25....|
|   25.85| -80.225|2

In [26]:
# Comprehensive Evaluation
evaluator = BinaryClassificationEvaluator(
    labelCol="IsHotspot",
    rawPredictionCol="rawPrediction",
    metricName="areaUnderROC"
)

auc = evaluator.evaluate(predictions)
print(f"\nüìä Model AUC-ROC: {auc:.4f}")

accuracy_evaluator = MulticlassClassificationEvaluator(
    labelCol="IsHotspot",
    predictionCol="prediction",
    metricName="accuracy"
)
accuracy = accuracy_evaluator.evaluate(predictions)
print(f"üéØ Classification Accuracy: {accuracy:.4f}")

f1_evaluator = MulticlassClassificationEvaluator(
    labelCol="IsHotspot",
    predictionCol="prediction",
    metricName="f1"
)
f1_score = f1_evaluator.evaluate(predictions)
print(f"üéØ F1-Score: {f1_score:.4f}")

precision_evaluator = MulticlassClassificationEvaluator(
    labelCol="IsHotspot",
    predictionCol="prediction",
    metricName="weightedPrecision"
)
precision = precision_evaluator.evaluate(predictions)
print(f"üéØ Precision: {precision:.4f}")

recall_evaluator = MulticlassClassificationEvaluator(
    labelCol="IsHotspot",
    predictionCol="prediction",
    metricName="weightedRecall"
)
recall = recall_evaluator.evaluate(predictions)
print(f"üéØ Recall: {recall:.4f}")

print("\nSample predictions:")
predictions.select("IsHotspot", "prediction", "probability").show(10)

# Cell 18: Confusion Matrix
print("\nüìä Confusion Matrix:")
print("="*50)

confusion_matrix = predictions.groupBy("IsHotspot", "prediction").count()
confusion_matrix.orderBy("IsHotspot", "prediction").show()

for class_label in [0, 1]:
    class_predictions = predictions.filter(col("IsHotspot") == class_label)
    total = class_predictions.count()
    correct = class_predictions.filter(col("prediction") == class_label).count()
    class_accuracy = correct / total if total > 0 else 0
    
    print(f"\nClass {class_label} Performance:")
    print(f"  Total samples: {total}")
    print(f"  Correct predictions: {correct}")
    print(f"  Class Accuracy: {class_accuracy:.4f}")




üìä Model AUC-ROC: 0.9794
üéØ Classification Accuracy: 0.9623
üéØ F1-Score: 0.9629
üéØ Precision: 0.9647
üéØ Recall: 0.9623

Sample predictions:
+---------+----------+--------------------+
|IsHotspot|prediction|         probability|
+---------+----------+--------------------+
|        0|       1.0|[0.24602423996822...|
|        0|       0.0|[0.93534119884147...|
|        0|       0.0|[0.93534119884147...|
|        1|       1.0|[0.24602423996822...|
|        0|       0.0|[0.93534119884147...|
|        1|       1.0|[0.24602423996822...|
|        0|       1.0|[0.24602423996822...|
|        0|       0.0|[0.93534119884147...|
|        1|       1.0|[0.24602423996822...|
|        1|       1.0|[0.24602423996822...|
+---------+----------+--------------------+
only showing top 10 rows

üìä Confusion Matrix:
+---------+----------+-----+
|IsHotspot|prediction|count|
+---------+----------+-----+
|        0|       0.0|32491|
|        0|       1.0| 1369|
|        1|       0.0|  328|
|        1

In [57]:
print(predictions.show(10))

+---------+-------------------+--------------------+--------------------+----------------+--------+--------------------+--------------------+--------------------+----------+
|IsHotspot|            lr_prob|              lr_raw|             dt_prob|          dt_raw|gbt_pred|       meta_features|       rawPrediction|         probability|prediction|
+---------+-------------------+--------------------+--------------------+----------------+--------+--------------------+--------------------+--------------------+----------+
|        0| 0.2863982059115956|[0.91294192460826...|  0.9095530368641036|[2402.0,24155.0]|     1.0|[0.28639820591159...|[-1.1199301514580...|[0.24602423996822...|       1.0|
|        0|0.22862181130659587|[1.21610961106219...|0.002718351591594857| [73374.0,200.0]|     0.0|[0.22862181130659...|[2.67178714984083...|[0.93534119884147...|       0.0|
|        0| 0.2322056661255283|[1.19589843196526...|0.002718351591594857| [73374.0,200.0]|     0.0|[0.23220566612552...|[2.6717871

In [74]:
from pyspark.sql.functions import col
final_pr=predictions.select(col("probability"), col("prediction"))
print(final_pr.show(5))


+--------------------+----------+
|         probability|prediction|
+--------------------+----------+
|[0.24602423996822...|       1.0|
|[0.93534119884147...|       0.0|
|[0.93534119884147...|       0.0|
|[0.24602423996822...|       1.0|
|[0.93534119884147...|       0.0|
+--------------------+----------+
only showing top 5 rows
None


In [2]:
'''# Alternative approach: use selectExpr to extract probability value from vector and compare
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType

# Assume 'probability' is a vector, so extract probability of class 1 (e.g. probability[1])
extract_prob = udf(lambda v: float(v[1]), DoubleType())
final_pr_with_prob = final_pr.withColumn("prob_1", extract_prob(col("probability")))

high_risk = final_pr_with_prob.filter(
    (col("prediction") == 1) & (col("prob_1") > 0.7)
)'''

'# Alternative approach: use selectExpr to extract probability value from vector and compare\nfrom pyspark.sql.functions import udf\nfrom pyspark.sql.types import DoubleType\n\n# Assume \'probability\' is a vector, so extract probability of class 1 (e.g. probability[1])\nextract_prob = udf(lambda v: float(v[1]), DoubleType())\nfinal_pr_with_prob = final_pr.withColumn("prob_1", extract_prob(col("probability")))\n\nhigh_risk = final_pr_with_prob.filter(\n    (col("prediction") == 1) & (col("prob_1") > 0.7)\n)'

In [27]:
# Cell 19: Compare Models
print("\nüìà Model Comparison:")
print("="*50)

test_predictions_lr = test_with_id.join(test_lr_preds, "row_id").withColumnRenamed("lr_raw", "rawPrediction")
test_predictions_dt = test_with_id.join(test_dt_preds, "row_id").withColumnRenamed("dt_raw", "rawPrediction")

auc_lr = evaluator.evaluate(test_predictions_lr)
acc_lr = MulticlassClassificationEvaluator(labelCol="IsHotspot", predictionCol="lr_prediction").evaluate(
    model_lr.transform(test_with_id)
)
f1_lr = MulticlassClassificationEvaluator(labelCol="IsHotspot", predictionCol="lr_prediction", metricName="f1").evaluate(
    model_lr.transform(test_with_id)
)
print(f"Logistic Regression - AUC: {auc_lr:.4f}, Accuracy: {acc_lr:.4f}, F1: {f1_lr:.4f}")

auc_dt = evaluator.evaluate(test_predictions_dt)
acc_dt = MulticlassClassificationEvaluator(labelCol="IsHotspot", predictionCol="dt_prediction").evaluate(
    model_dt.transform(test_with_id)
)
f1_dt = MulticlassClassificationEvaluator(labelCol="IsHotspot", predictionCol="dt_prediction", metricName="f1").evaluate(
    model_dt.transform(test_with_id)
)
print(f"Decision Tree - AUC: {auc_dt:.4f}, Accuracy: {acc_dt:.4f}, F1: {f1_dt:.4f}")

test_gbt_full = model_gbt.transform(test_with_id)
auc_gbt = evaluator.evaluate(test_gbt_full)
acc_gbt = MulticlassClassificationEvaluator(labelCol="IsHotspot", predictionCol="gbt_prediction").evaluate(test_gbt_full)
f1_gbt = MulticlassClassificationEvaluator(labelCol="IsHotspot", predictionCol="gbt_prediction", metricName="f1").evaluate(test_gbt_full)
print(f"Gradient Boosted Trees - AUC: {auc_gbt:.4f}, Accuracy: {acc_gbt:.4f}, F1: {f1_gbt:.4f}")

print(f"\nüèÜ Stacking Ensemble - AUC: {auc:.4f}, Accuracy: {accuracy:.4f}, F1: {f1_score:.4f}")
print("="*50)

# Cell 20: Memory Cleanup
train_data_assembled.unpersist()
test_data_assembled.unpersist()
train_meta_data.unpersist()

print("\n‚úÖ Complete model evaluation")


üìà Model Comparison:
Logistic Regression - AUC: 0.9557, Accuracy: 0.6424, F1: 0.6424
Decision Tree - AUC: 0.9590, Accuracy: 0.9637, F1: 0.9637
Gradient Boosted Trees - AUC: 0.9876, Accuracy: 0.9629, F1: 0.9629

üèÜ Stacking Ensemble - AUC: 0.9794, Accuracy: 0.9623, F1: 0.9629

‚úÖ Complete model evaluation


In [33]:
df_final.toPandas().to_csv("df_final_export.csv", index=False)


In [None]:
file_path = "df_final_export.csv"  

df = spark.read.csv(file_path, header=True, inferSchema=True)

print(f"‚úÖ ÿ™ŸÖ ÿ™ÿ≠ŸÖŸäŸÑ ÿßŸÑÿ®ŸäÿßŸÜÿßÿ™: {df.count()} ÿµŸÅ")
df.printSchema()
df.show(5)

‚úÖ ÿ™ŸÖ ÿ™ÿ≠ŸÖŸäŸÑ ÿßŸÑÿ®ŸäÿßŸÜÿßÿ™: 301677 ÿµŸÅ
root
 |-- lat_grid: double (nullable = true)
 |-- lon_grid: double (nullable = true)
 |-- date: date (nullable = true)
 |-- state: string (nullable = true)
 |-- city_or_county: string (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)
 |-- n_killed: integer (nullable = true)
 |-- n_injured: integer (nullable = true)
 |-- total_victims: integer (nullable = true)
 |-- IsHotspot: integer (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- day_of_month: integer (nullable = true)
 |-- incident_count: integer (nullable = true)
 |-- total_killed: integer (nullable = true)
 |-- total_injured: integer (nullable = true)
 |-- avg_killed_per_incident: double (nullable = true)

+--------+--------+----------+-------+--------------+--------+---------+--------+---------+-------------+---------+----+-----+-----------+