In [1]:
# Install Java
!apt-get install openjdk-11-jdk-headless -qq > /dev/null

# Download Spark 3.3.2 (confirmed working)
!wget https://archive.apache.org/dist/spark/spark-3.3.2/spark-3.3.2-bin-hadoop3.tgz -O spark.tgz

# Extract Spark
!tar -xvzf spark.tgz

# Install findspark
!pip install -q findspark


--2025-04-16 03:11:28--  https://archive.apache.org/dist/spark/spark-3.3.2/spark-3.3.2-bin-hadoop3.tgz
Resolving archive.apache.org (archive.apache.org)... 65.108.204.189, 2a01:4f9:1a:a084::2
Connecting to archive.apache.org (archive.apache.org)|65.108.204.189|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 299360284 (285M) [application/x-gzip]
Saving to: ‘spark.tgz’


2025-04-16 03:12:05 (7.97 MB/s) - ‘spark.tgz’ saved [299360284/299360284]

spark-3.3.2-bin-hadoop3/
spark-3.3.2-bin-hadoop3/LICENSE
spark-3.3.2-bin-hadoop3/NOTICE
spark-3.3.2-bin-hadoop3/R/
spark-3.3.2-bin-hadoop3/R/lib/
spark-3.3.2-bin-hadoop3/R/lib/SparkR/
spark-3.3.2-bin-hadoop3/R/lib/SparkR/DESCRIPTION
spark-3.3.2-bin-hadoop3/R/lib/SparkR/INDEX
spark-3.3.2-bin-hadoop3/R/lib/SparkR/Meta/
spark-3.3.2-bin-hadoop3/R/lib/SparkR/Meta/Rd.rds
spark-3.3.2-bin-hadoop3/R/lib/SparkR/Meta/features.rds
spark-3.3.2-bin-hadoop3/R/lib/SparkR/Meta/hsearch.rds
spark-3.3.2-bin-hadoop3/R/lib/SparkR/Meta/links.rd

In [2]:
import os
import findspark

os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.3.2-bin-hadoop3"

findspark.init()


In [3]:
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("ChicagoCrimeAnalysis") \
    .getOrCreate()


In [4]:
# Load the dataset
data = spark.read.csv("/content/Crimes_-_2001_to_Present.csv", header=True, inferSchema=True)

# Preview the data
data.show(5)
data.printSchema()


+--------+-----------+--------------------+--------------------+----+------------+--------------------+--------------------+------+--------+----+--------+----+--------------+--------+------------+------------+----+--------------------+------------+-------------+--------------------+
|      ID|Case Number|                Date|               Block|IUCR|Primary Type|         Description|Location Description|Arrest|Domestic|Beat|District|Ward|Community Area|FBI Code|X Coordinate|Y Coordinate|Year|          Updated On|    Latitude|    Longitude|            Location|
+--------+-----------+--------------------+--------------------+----+------------+--------------------+--------------------+------+--------+----+--------+----+--------------+--------+------------+------------+----+--------------------+------------+-------------+--------------------+
|10224738|   HY411648|09/05/2015 01:30:...|     043XX S WOOD ST|0486|     BATTERY|DOMESTIC BATTERY ...|           RESIDENCE| false|    true| 924|   

Data Cleaning


In [5]:
data = data.drop("ID", "Case Number", "IUCR", "Updated On", "X Coordinate", "Y Coordinate", "Location")

In [6]:
from pyspark.sql.functions import col, when, count

null_counts = data.select([
    count(when(col(c).isNull(), c)).alias(c)
    for c in data.columns
])
null_counts.show()


+----+-----+------------+-----------+--------------------+------+--------+----+--------+----+--------------+--------+----+--------+---------+
|Date|Block|Primary Type|Description|Location Description|Arrest|Domestic|Beat|District|Ward|Community Area|FBI Code|Year|Latitude|Longitude|
+----+-----+------------+-----------+--------------------+------+--------+----+--------+----+--------------+--------+----+--------+---------+
|   0|    0|           0|          0|                 747|     0|       0|   0|       0|  11|            12|       0|   0|    3419|     3419|
+----+-----+------------+-----------+--------------------+------+--------+----+--------+----+--------------+--------+----+--------+---------+



In [7]:
data = data.fillna({'Location Description': 'UNKNOWN'})
data = data.na.drop(subset=["District", "Latitude", "Longitude"])
data = data.drop("Community Area", "Ward")

In [8]:
null_counts = data.select([
    count(when(col(c).isNull(), c)).alias(c)
    for c in data.columns
])
null_counts.show()


+----+-----+------------+-----------+--------------------+------+--------+----+--------+--------+----+--------+---------+
|Date|Block|Primary Type|Description|Location Description|Arrest|Domestic|Beat|District|FBI Code|Year|Latitude|Longitude|
+----+-----+------------+-----------+--------------------+------+--------+----+--------+--------+----+--------+---------+
|   0|    0|           0|          0|                   0|     0|       0|   0|       0|       0|   0|       0|        0|
+----+-----+------------+-----------+--------------------+------+--------+----+--------+--------+----+--------+---------+



In [9]:
data.groupBy(data.columns).count().filter("count > 1").show()

+--------------------+--------------------+--------------------+--------------------+--------------------+------+--------+----+--------+--------+----+------------+-------------+-----+
|                Date|               Block|        Primary Type|         Description|Location Description|Arrest|Domestic|Beat|District|FBI Code|Year|    Latitude|    Longitude|count|
+--------------------+--------------------+--------------------+--------------------+--------------------+------+--------+----+--------+--------+----+------------+-------------+-----+
|10/08/2015 11:30:...|047XX W ARTHINGTO...|        PROSTITUTION|SOLICIT ON PUBLIC...|              STREET|  true|   false|1131|      11|      16|2015|41.869559676|-87.743762659|    2|
|11/19/2015 08:20:...|  039XX W GLADYS AVE|           NARCOTICS|ATTEMPT POSSESSIO...|            SIDEWALK|  true|   false|1132|      11|      18|2015|41.876196845|-87.724082848|    3|
|12/18/2015 02:30:...|  023XX W LOGAN BLVD|               THEFT|      $500 AND U

In [10]:
data = data.dropDuplicates()

In [11]:
data.printSchema()

root
 |-- Date: string (nullable = true)
 |-- Block: string (nullable = true)
 |-- Primary Type: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Location Description: string (nullable = false)
 |-- Arrest: boolean (nullable = true)
 |-- Domestic: boolean (nullable = true)
 |-- Beat: integer (nullable = true)
 |-- District: integer (nullable = true)
 |-- FBI Code: string (nullable = true)
 |-- Year: integer (nullable = true)
 |-- Latitude: double (nullable = true)
 |-- Longitude: double (nullable = true)



In [12]:
spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")

In [13]:
from pyspark.sql.functions import to_timestamp

data = data.withColumn("Date", to_timestamp("Date", "MM/dd/yyyy hh:mm:ss a"))

In [14]:
from pyspark.sql.functions import hour, dayofweek, month

data = data.withColumn("Hour", hour("Date")) \
           .withColumn("Weekday", dayofweek("Date")) \
           .withColumn("Month", month("Date"))


In [15]:
data.groupBy("Hour").count().orderBy("Hour").show()

+----+-----+
|Hour|count|
+----+-----+
|   0|12621|
|   1| 7865|
|   2| 6921|
|   3| 5943|
|   4| 4475|
|   5| 3886|
|   6| 4617|
|   7| 6273|
|   8| 8986|
|   9|11792|
|  10|11420|
|  11|11661|
|  12|14802|
|  13|12530|
|  14|12859|
|  15|13874|
|  16|13790|
|  17|13938|
|  18|15267|
|  19|14844|
+----+-----+
only showing top 20 rows



In [16]:
data.groupBy("Weekday").count().orderBy("Weekday").show()

+-------+-----+
|Weekday|count|
+-------+-----+
|      1|36789|
|      2|37278|
|      3|36820|
|      4|37549|
|      5|37497|
|      6|38193|
|      7|36698|
+-------+-----+



In [17]:
data.groupBy("Month").count().orderBy("Month").show()

+-----+-----+
|Month|count|
+-----+-----+
|    1|20021|
|    2|18136|
|    3|22347|
|    4|20778|
|    5| 4795|
|    6|10141|
|    7|37144|
|    8|42025|
|    9|22143|
|   10|23581|
|   11|19700|
|   12|20013|
+-----+-----+



Univariate Analysis

In [18]:
# Top 10 Most Common Crime Types
data.groupBy("Primary Type").count().orderBy("count", ascending=False).show(10)

+-------------------+-----+
|       Primary Type|count|
+-------------------+-----+
|              THEFT|58469|
|            BATTERY|48955|
|    CRIMINAL DAMAGE|30480|
|            ASSAULT|17374|
|      OTHER OFFENSE|17222|
|          NARCOTICS|17089|
| DECEPTIVE PRACTICE|15969|
|           BURGLARY|13955|
|            ROBBERY|11059|
|MOTOR VEHICLE THEFT|10702|
+-------------------+-----+
only showing top 10 rows



In [19]:
# Top 10 Locations Where Crimes Happen
data.groupBy("Location Description").count().orderBy("count", ascending=False).show(10)

+--------------------+-----+
|Location Description|count|
+--------------------+-----+
|              STREET|58760|
|           RESIDENCE|42288|
|           APARTMENT|34799|
|            SIDEWALK|25827|
|               OTHER|10651|
|PARKING LOT/GARAG...| 7714|
|RESIDENTIAL YARD ...| 5721|
|  SMALL RETAIL STORE| 5704|
|          RESTAURANT| 5338|
|               ALLEY| 5220|
+--------------------+-----+
only showing top 10 rows



In [20]:
# Top 10 Crime Types by Hour (Most Active Time for Each Type)
data.groupBy("Primary Type", "Hour").count().orderBy("count", ascending=False).show(10)

+------------+----+-----+
|Primary Type|Hour|count|
+------------+----+-----+
|       THEFT|  18| 3895|
|       THEFT|  12| 3836|
|       THEFT|  17| 3753|
|       THEFT|  15| 3729|
|       THEFT|  16| 3687|
|       THEFT|  14| 3609|
|       THEFT|  13| 3442|
|       THEFT|  19| 3324|
|       THEFT|  20| 2966|
|     BATTERY|  22| 2805|
+------------+----+-----+
only showing top 10 rows



Bivariate Analysis

In [21]:
# Arrest rates per Primary Type
from pyspark.sql.functions import avg, col

data.groupBy("Primary Type") \
    .agg(avg(col("Arrest").cast("int")).alias("Arrest Rate")) \
    .orderBy("Arrest Rate", ascending=False) \
    .show(10, truncate=False)


+---------------------------------+------------------+
|Primary Type                     |Arrest Rate       |
+---------------------------------+------------------+
|HOMICIDE                         |1.0               |
|PUBLIC INDECENCY                 |1.0               |
|GAMBLING                         |1.0               |
|PROSTITUTION                     |1.0               |
|NARCOTICS                        |0.9995318626016736|
|LIQUOR LAW VIOLATION             |0.9924528301886792|
|CONCEALED CARRY LICENSE VIOLATION|0.9736842105263158|
|INTERFERENCE WITH PUBLIC OFFICER |0.9506057781919851|
|WEAPONS VIOLATION                |0.7556492919554083|
|OTHER NARCOTIC VIOLATION         |0.75              |
+---------------------------------+------------------+
only showing top 10 rows



In [22]:
# Crime trends by Hour + Location Description
data.groupBy("Hour", "Location Description") \
    .count() \
    .orderBy("Hour", "count", ascending=[True, False]) \
    .show(20, truncate=False)


+----+------------------------------+-----+
|Hour|Location Description          |count|
+----+------------------------------+-----+
|0   |STREET                        |3081 |
|0   |RESIDENCE                     |2771 |
|0   |APARTMENT                     |1853 |
|0   |SIDEWALK                      |1027 |
|0   |OTHER                         |469  |
|0   |RESIDENCE-GARAGE              |309  |
|0   |ALLEY                         |266  |
|0   |VEHICLE NON-COMMERCIAL        |262  |
|0   |RESIDENTIAL YARD (FRONT/BACK) |261  |
|0   |PARKING LOT/GARAGE(NON.RESID.)|257  |
|0   |BAR OR TAVERN                 |246  |
|0   |RESIDENCE PORCH/HALLWAY       |209  |
|0   |RESTAURANT                    |175  |
|0   |GAS STATION                   |139  |
|0   |SMALL RETAIL STORE            |85   |
|0   |HOTEL/MOTEL                   |69   |
|0   |CONVENIENCE STORE             |61   |
|0   |BANK                          |60   |
|0   |ATM (AUTOMATIC TELLER MACHINE)|57   |
|0   |PARK PROPERTY             

In [23]:
# Domestic vs. Non-Domestic patterns by Crime Type
data.groupBy("Primary Type", "Domestic") \
    .count() \
    .orderBy("Primary Type", "count", ascending=[True, False]) \
    .show(20, truncate=False)


+---------------------------------+--------+-----+
|Primary Type                     |Domestic|count|
+---------------------------------+--------+-----+
|ARSON                            |false   |472  |
|ARSON                            |true    |20   |
|ASSAULT                          |false   |12895|
|ASSAULT                          |true    |4479 |
|BATTERY                          |false   |24569|
|BATTERY                          |true    |24386|
|BURGLARY                         |false   |13804|
|BURGLARY                         |true    |151  |
|CONCEALED CARRY LICENSE VIOLATION|false   |38   |
|CRIM SEXUAL ASSAULT              |false   |1111 |
|CRIM SEXUAL ASSAULT              |true    |248  |
|CRIMINAL DAMAGE                  |false   |27205|
|CRIMINAL DAMAGE                  |true    |3275 |
|CRIMINAL SEXUAL ASSAULT          |false   |19   |
|CRIMINAL SEXUAL ASSAULT          |true    |4    |
|CRIMINAL TRESPASS                |false   |5853 |
|CRIMINAL TRESPASS             

In [24]:
#  Geographic trends (District vs. Crime Type)
data.groupBy("District", "Primary Type") \
    .count() \
    .orderBy("District", "count", ascending=[True, False]) \
    .show(20, truncate=False)


+--------+--------------------------------+-----+
|District|Primary Type                    |count|
+--------+--------------------------------+-----+
|1       |THEFT                           |6289 |
|1       |DECEPTIVE PRACTICE              |1572 |
|1       |BATTERY                         |1280 |
|1       |CRIMINAL DAMAGE                 |759  |
|1       |ASSAULT                         |580  |
|1       |OTHER OFFENSE                   |556  |
|1       |CRIMINAL TRESPASS               |537  |
|1       |ROBBERY                         |411  |
|1       |MOTOR VEHICLE THEFT             |213  |
|1       |BURGLARY                        |188  |
|1       |NARCOTICS                       |184  |
|1       |PUBLIC PEACE VIOLATION          |104  |
|1       |CRIM SEXUAL ASSAULT             |47   |
|1       |SEX OFFENSE                     |46   |
|1       |OFFENSE INVOLVING CHILDREN      |38   |
|1       |WEAPONS VIOLATION               |27   |
|1       |INTERFERENCE WITH PUBLIC OFFICER|13   |


Time Series Analysis

In [25]:
# Monthly/Yearly crime trends
from pyspark.sql.functions import year, month

data.groupBy(year("Date").alias("Year"), month("Date").alias("Month")) \
    .count() \
    .orderBy("Year", "Month") \
    .show(30)


+----+-----+-----+
|Year|Month|count|
+----+-----+-----+
|2001|    9|    1|
|2007|    6|    1|
|2008|    6|    1|
|2008|    9|    1|
|2009|    9|    1|
|2010|    1|   10|
|2010|    2|    3|
|2010|    3|    3|
|2010|    4|    6|
|2010|    5|    2|
|2010|    6|    4|
|2010|    7|    7|
|2010|    8|    3|
|2010|    9|    4|
|2010|   10|    5|
|2010|   11|    4|
|2010|   12|    6|
|2011|    1|   19|
|2011|    2|    8|
|2011|    3|    7|
|2011|    4|    3|
|2011|    5|    3|
|2011|    6|    5|
|2011|    7|    1|
|2011|    8|    6|
|2011|    9|   11|
|2011|   10|    6|
|2011|   11|    7|
|2011|   12|    2|
|2012|    1|   20|
+----+-----+-----+
only showing top 30 rows



In [26]:
# Seasonal Patterns by Month
data.groupBy(month("Date").alias("Month")) \
    .count() \
    .orderBy("Month") \
    .show()

+-----+-----+
|Month|count|
+-----+-----+
|    1|20021|
|    2|18136|
|    3|22347|
|    4|20778|
|    5| 4795|
|    6|10141|
|    7|37144|
|    8|42025|
|    9|22143|
|   10|23581|
|   11|19700|
|   12|20013|
+-----+-----+



In [27]:
# Yearly Trends for THEFT
data.filter(col("Primary Type") == "THEFT") \
    .groupBy(year("Date").alias("Year")) \
    .count() \
    .orderBy("Year") \
    .show()

+----+-----+
|Year|count|
+----+-----+
|2010|    2|
|2011|    1|
|2012|    8|
|2013|    9|
|2014|   54|
|2015|28481|
|2016|29223|
|2017|    2|
|2018|    3|
|2019|  415|
|2020|  249|
|2021|   17|
|2023|    5|
+----+-----+



Geospatial Analysis

In [28]:
# Crime Hotspots
from pyspark.sql.functions import round

data.withColumn("LatRound", round("Latitude", 2)) \
    .withColumn("LongRound", round("Longitude", 2)) \
    .groupBy("LatRound", "LongRound") \
    .count() \
    .orderBy("count", ascending=False) \
    .show(10, truncate=False)


+--------+---------+-----+
|LatRound|LongRound|count|
+--------+---------+-----+
|41.88   |-87.63   |4381 |
|41.89   |-87.63   |3195 |
|41.88   |-87.75   |2262 |
|41.88   |-87.73   |2040 |
|41.9    |-87.63   |2024 |
|41.87   |-87.72   |1833 |
|41.88   |-87.76   |1752 |
|41.89   |-87.62   |1691 |
|41.88   |-87.74   |1680 |
|41.9    |-87.72   |1577 |
+--------+---------+-----+
only showing top 10 rows



In [29]:
data.show(5)
data.printSchema()

+-------------------+--------------------+------------+--------------------+--------------------+------+--------+----+--------+--------+----+------------+-------------+----+-------+-----+
|               Date|               Block|Primary Type|         Description|Location Description|Arrest|Domestic|Beat|District|FBI Code|Year|    Latitude|    Longitude|Hour|Weekday|Month|
+-------------------+--------------------+------------+--------------------+--------------------+------+--------+----+--------+--------+----+------------+-------------+----+-------+-----+
|2015-09-05 02:00:00| 009XX W BELMONT AVE|       THEFT|           OVER $500|            SIDEWALK| false|   false|1924|      19|      06|2015|41.939918899|-87.653369734|   2|      7|    9|
|2015-09-06 16:39:00|    073XX S STATE ST|       THEFT|       FROM BUILDING|  GROCERY FOOD STORE| false|   false| 323|       3|      06|2015|41.761010056|-87.624814287|  16|      1|    9|
|2015-04-09 01:00:00|     005XX E 43RD ST|     ASSAULT|     

In [30]:
from pyspark.sql.functions import col, when
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, MinMaxScaler

In [31]:
# Convert boolean to integers
data = data.withColumn("ArrestInt", when(col("Arrest") == True, 1).otherwise(0))
data = data.withColumn("DomesticInt", when(col("Domestic") == True, 1).otherwise(0))

# Ensure Year and Longitude are numeric
data = data.withColumn("Year", col("Year").cast("int"))
data = data.withColumn("Longitude", col("Longitude").cast("double"))

# 3. Encode categorical features
primary_indexer = StringIndexer(
    inputCol="Primary Type",
    outputCol="PrimaryIndex",
    handleInvalid="keep"
)

location_indexer = StringIndexer(
    inputCol="Location Description",
    outputCol="LocationIndex",
    handleInvalid="keep"
)

primary_encoder = OneHotEncoder(inputCol="PrimaryIndex", outputCol="PrimaryVec")
location_encoder = OneHotEncoder(inputCol="LocationIndex", outputCol="LocationVec")


In [32]:
# 4. Assemble and scale numeric features
numeric_assembler = VectorAssembler(
    inputCols=["Hour", "Weekday", "Month", "District", "Beat", "Latitude", "Longitude"],
    outputCol="numeric_features"
)
scaler = MinMaxScaler(inputCol="numeric_features", outputCol="scaled_numeric")

In [33]:
# 5. Final assembler to combine everything
final_assembler = VectorAssembler(
    inputCols=["PrimaryVec", "LocationVec", "scaled_numeric", "DomesticInt"],
    outputCol="features"
)

In [34]:
from pyspark.sql.functions import col, when, lit

# Count total records and Arrest = 1
total = data.count()
pos_count = data.filter(col("ArrestInt") == 1).count()

# Compute balancing ratio
balancing_ratio = total / (2.0 * pos_count)

# Add weight column
data = data.withColumn(
    "classWeightCol",
    when(col("ArrestInt") == 1, lit(balancing_ratio)).otherwise(lit(1.0))
)


In [35]:
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(
    featuresCol="features",
    labelCol="ArrestInt",
    weightCol="classWeightCol",
    maxIter=10
)


In [36]:
from pyspark.ml import Pipeline

lr_pipeline = Pipeline(stages=[
    primary_indexer,
    location_indexer,
    primary_encoder,
    location_encoder,
    numeric_assembler,
    scaler,
    final_assembler,
    lr  # Logistic Regression
])

In [37]:
# Split into train and test sets (80% training, 20% testing)
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)

In [38]:
lr_model = lr_pipeline.fit(train_data)

In [54]:
lr_predictions = lr_model.transform(test_data)
lr_predictions.select("features", "ArrestInt", "prediction", "probability").show(10, truncate=False)


+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------+----------+------------------------------------------+
|features                                                                                                                                                                                  |ArrestInt|prediction|probability                               |
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------+----------+------------------------------------------+
|(165,[6,35,158,160,161,162,163],[1.0,1.0,0.8333333333333333,0.2,0.2528877887788779,0.954702622809411,0.9689706948057357])                                                                 |0        |0.0       |[0.9472556909465072,0.0527443090

In [55]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Accuracy
evaluator = MulticlassClassificationEvaluator(
    labelCol="ArrestInt",
    predictionCol="prediction",
    metricName="accuracy"
)
accuracy = evaluator.evaluate(lr_predictions)
print(f"Accuracy: {accuracy:.4f}")

# Precision
precision_evaluator = MulticlassClassificationEvaluator(
    labelCol="ArrestInt", predictionCol="prediction", metricName="weightedPrecision"
)
precision = precision_evaluator.evaluate(lr_predictions)
print(f"Precision: {precision:.4f}")

# Recall
recall_evaluator = MulticlassClassificationEvaluator(
    labelCol="ArrestInt", predictionCol="prediction", metricName="weightedRecall"
)
recall = recall_evaluator.evaluate(lr_predictions)
print(f"Recall: {recall:.4f}")

# F1 Score
f1_evaluator = MulticlassClassificationEvaluator(
    labelCol="ArrestInt", predictionCol="prediction", metricName="f1"
)
f1 = f1_evaluator.evaluate(lr_predictions)
print(f"F1 Score: {f1:.4f}")


Accuracy: 0.8591
Precision: 0.8516
Recall: 0.8591
F1 Score: 0.8518


In [42]:
data.write.csv("cleaned_chicago_crime.csv", header=True, mode="overwrite")

In [43]:
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml import Pipeline

rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="ArrestInt",
    weightCol="classWeightCol",
    numTrees=100,
    maxDepth=10,
    seed=42
)

In [44]:
rf_pipeline = Pipeline(stages=[
    primary_indexer,
    location_indexer,
    primary_encoder,
    location_encoder,
    numeric_assembler,
    scaler,
    final_assembler,
    rf  # Random Forest
])

In [57]:
rf_model = rf_pipeline.fit(train_data)
rf_predictions = rf_model.transform(test_data)

In [58]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator

# Accuracy
evaluator = MulticlassClassificationEvaluator(
    labelCol="ArrestInt",
    predictionCol="prediction",
    metricName="accuracy"
)
accuracy = evaluator.evaluate(rf_predictions)

# Precision
precision_evaluator = MulticlassClassificationEvaluator(
    labelCol="ArrestInt",
    predictionCol="prediction",
    metricName="precisionByLabel"
)
precision = precision_evaluator.evaluate(rf_predictions)

# Recall
recall_evaluator = MulticlassClassificationEvaluator(
    labelCol="ArrestInt",
    predictionCol="prediction",
    metricName="recallByLabel"
)
recall = recall_evaluator.evaluate(rf_predictions)

# F1 Score
f1_evaluator = MulticlassClassificationEvaluator(
    labelCol="ArrestInt",
    predictionCol="prediction",
    metricName="f1"
)
f1 = f1_evaluator.evaluate(rf_predictions)

# Output
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")


Accuracy: 0.8735
Precision: 0.8727
Recall: 0.9801
F1 Score: 0.8599


In [47]:
from pyspark.ml.classification import GBTClassifier
# Define the GBTClassifier
gbt = GBTClassifier(
    featuresCol="features",
    labelCol="ArrestInt",
    weightCol="classWeightCol",
    maxIter=20
)

In [48]:
gbt_pipeline = Pipeline(stages=[
    primary_indexer,
    location_indexer,
    primary_encoder,
    location_encoder,
    numeric_assembler,
    scaler,
    final_assembler,
    gbt  # Gradient Boosted Trees
])

In [59]:
gbt_model = gbt_pipeline.fit(train_data)
gbt_predictions = gbt_model.transform(test_data)

In [60]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(
    labelCol="ArrestInt",
    predictionCol="prediction",
    metricName="accuracy"
)
accuracy = evaluator.evaluate(gbt_predictions)

precision_eval = MulticlassClassificationEvaluator(
    labelCol="ArrestInt", predictionCol="prediction", metricName="precisionByLabel")
recall_eval = MulticlassClassificationEvaluator(
    labelCol="ArrestInt", predictionCol="prediction", metricName="recallByLabel")
f1_eval = MulticlassClassificationEvaluator(
    labelCol="ArrestInt", predictionCol="prediction", metricName="f1")

precision = precision_eval.evaluate(gbt_predictions)
recall = recall_eval.evaluate(gbt_predictions)
f1 = f1_eval.evaluate(gbt_predictions)

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")


Accuracy: 0.8776
Precision: 0.8859
Recall: 0.9670
F1 Score: 0.8687


In [66]:
from pyspark.ml.functions import vector_to_array
from pyspark.sql.functions import col

# Convert SparseVector to Dense Array
gbt_predictions_dense = gbt_predictions.withColumn("probability_array", vector_to_array("probability"))

# Extract class 1 probability
gbt_predictions_final = gbt_predictions_dense.withColumn("prob_class_1", col("probability_array")[1])


gbt_predictions_final.select("Date", "District", "ArrestInt", "prediction", "prob_class_1") \
    .write.csv("gbt_predictions_output.csv", header=True)


In [71]:
!cat gbt_predictions_output.csv/part* > gbt_predictions_output_combined.csv


In [72]:
from google.colab import files
files.download("gbt_predictions_output_combined.csv")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>