
## Overview

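This notebook will show you how to create and query a table or DataFrame that you uploaded to DBFS. [DBFS](https://docs.databricks.com/user-guide/dbfs-databricks-file-system.html) is a Databricks File System that allows you to store data for querying inside of Databricks. This notebook assumes that you have a file already inside of DBFS that you would like to read from. (In the pipeline below, the raw crimes CSV files are actually read incrementally from S3 with Auto Loader (`cloudFiles`), written through bronze, silver and gold Delta tables, and finally used to train a Random Forest crime-type classifier.)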

This notebook is written in **Python** so the default cell type is Python. However, you can use different languages by using the `%LANGUAGE` syntax. Python, Scala, SQL, and R are all supported.

In [None]:
import os

# S3 credentials for the s3a filesystem. Never hard-code real keys in a notebook
# (they end up in exports, revisions and outputs): read them from the environment
# instead. On Databricks a secret scope is preferable, e.g.
# dbutils.secrets.get(scope="...", key="..."). The "--" defaults are the original
# placeholders, so behavior is unchanged when the variables are not set.
access_key  = os.environ.get("AWS_ACCESS_KEY_ID", "--")
secret_key  = os.environ.get("AWS_SECRET_ACCESS_KEY", "--")
bucket_name = os.environ.get("S3_BUCKET_NAME", "--")

spark.conf.set("fs.s3a.access.key", access_key)
spark.conf.set("fs.s3a.secret.key", secret_key)

In [None]:
file_location = "s3://--/"
schema_location = "s3://--/schema/Crimes___2002_20240724_schema"

infer_schema = "true"
first_row_is_header = "true"
delimiter = ","

# Auto Loader (cloudFiles) reader configuration for the raw CSV drop.
# cloudFiles.schemaLocation is where Auto Loader stores the schema it tracks
# across runs.
# NOTE(review): "inferSchema" is the batch CSV reader's option; Auto Loader's own
# switch for typed inference is "cloudFiles.inferColumnTypes". The schemas printed
# further down show every raw column as string, consistent with that — confirm
# before changing it, since downstream cells (e.g. the StringIndexers) expect strings.
autoloader_options = {
    "cloudFiles.format": "csv",
    "cloudFiles.schemaLocation": schema_location,
    "inferSchema": infer_schema,
    "header": first_row_is_header,
    "sep": delimiter,
}

df = (spark.readStream
      .format("cloudFiles")
      .options(**autoloader_options)
      .load(file_location))

display(df)

Case Number,Date,Block,IUCR,Primary Type,Description,Location Description,Arrest,Domestic,Beat,Ward,FBI Code,X Coordinate,Y Coordinate,Year,Latitude,Longitude,Location,_rescued_data
HH109118,01/05/2002 09:24:00 PM,007XX E 103 ST,0820,THEFT,$500 AND UNDER,GAS STATION,True,False,512,,06,,,2002,,,,
HH141345,01/21/2002 12:00:00 PM,013XX W 18 ST,0820,THEFT,$500 AND UNDER,RESTAURANT,False,False,1222,,06,,,2002,,,,
HH160210,01/31/2002 08:00:29 PM,022XX E 80 ST,2022,NARCOTICS,POSS: COCAINE,STREET,True,False,414,,18,,,2002,,,,
HH494434,07/08/2002 12:50:00 AM,035XX N CLARK ST,0560,ASSAULT,SIMPLE,SIDEWALK,False,False,1923,44.0,08A,,,2002,,,,
JG503809,01/14/2002 09:00:00 AM,076XX S CALUMET AVE,1153,DECEPTIVE PRACTICE,FINANCIAL IDENTITY THEFT OVER $ 300,RESIDENCE,False,False,623,6.0,11,,,2002,,,,
HH537836,07/26/2002 06:35:00 PM,004XX W JACKSON BLVD,0890,THEFT,FROM BUILDING,ATHLETIC CLUB,False,False,111,2.0,06,,,2002,,,,
HH703546,10/09/2002 07:05:00 PM,002XX N PINE AVE,2024,NARCOTICS,POSS: HEROIN(WHITE),ALLEY,True,False,1523,28.0,18,,,2002,,,,
HH701602,10/08/2002 08:00:00 PM,056XX W MADISON ST,2027,NARCOTICS,POSS: CRACK,SIDEWALK,True,False,1513,29.0,18,,,2002,,,,
JG439687,09/23/2002 12:00:00 AM,009XX W 36TH ST,0820,THEFT,$500 AND UNDER,STREET,False,False,915,11.0,06,,,2002,,,,
JG238647,01/01/2002 12:00:00 AM,029XX N MELVINA AVE,1754,OFFENSE INVOLVING CHILDREN,AGGRAVATED SEXUAL ASSAULT OF CHILD BY FAMILY MEMBER,RESIDENCE,True,True,2511,30.0,02,,,2002,,,,


In [None]:
%sql
-- One schema (database) per medallion layer; IF NOT EXISTS keeps this cell re-runnable.
CREATE SCHEMA IF NOT EXISTS bronzedb;
CREATE SCHEMA IF NOT EXISTS silverdb;
CREATE SCHEMA IF NOT EXISTS goldDb;

Renaming columns whose names contain invalid characters (spaces in the CSV headers)

In [None]:
from pyspark.sql.functions import col

def change_name(df):
    """Return ``df`` with the space-containing raw CSV headers renamed.

    Spaces are not allowed in Delta column names by default, so each raw header
    listed below is mapped to a space-free name. The "X_Coorddinate" /
    "Y_Coorddinate" spellings are kept as-is because later cells refer to them.
    Columns that are not listed are left untouched.
    """
    renames = [
        ("Primary Type", "PrimaryType"),
        ("Case Number", "CaseNumber"),
        ("Location Description", "LocationDescription"),
        ("FBI Code", "FBICode"),
        ("X Coordinate", "X_Coorddinate"),
        ("Y Coordinate", "Y_Coorddinate"),
    ]
    result = df
    for old_name, new_name in renames:
        result = result.withColumnRenamed(old_name, new_name)
    return result
    

# Apply the column renames to the raw Auto Loader stream.
df_renamed = change_name(df)

In [None]:
# Verify the renamed columns; at this stage every column is still a string.
df_renamed.printSchema()

root
 |-- CaseNumber: string (nullable = true)
 |-- Date: string (nullable = true)
 |-- Block: string (nullable = true)
 |-- IUCR: string (nullable = true)
 |-- PrimaryType: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- LocationDescription: string (nullable = true)
 |-- Arrest: string (nullable = true)
 |-- Domestic: string (nullable = true)
 |-- Beat: string (nullable = true)
 |-- Ward: string (nullable = true)
 |-- FBICode: string (nullable = true)
 |-- X_Coorddinate: string (nullable = true)
 |-- Y_Coorddinate: string (nullable = true)
 |-- Year: string (nullable = true)
 |-- Latitude: string (nullable = true)
 |-- Longitude: string (nullable = true)
 |-- Location: string (nullable = true)
 |-- _rescued_data: string (nullable = true)



In [None]:
# Bronze layer: continuously write the renamed raw stream into a Delta table.
# The checkpoint directory records stream progress so the query can resume.
checkpoint_location = "dbfs:/mnt/delta/crimestats/checkpoint"
bronze_writer = (df_renamed.writeStream
                 .format("delta")
                 .option("checkpointLocation", checkpoint_location))
query = bronze_writer.table("bronzedb.bronze_table")

In [None]:
%sql
-- Sanity check: number of rows the bronze streaming query has written so far.
SELECT count(*) FROM bronzedb.bronze_table

count(1)
486823


Conversion to Silver Layer

Removing rows with null values and casting coordinate columns to double

In [None]:
def removing_nulls(df1):
    """Silver-layer cleanup for the renamed crimes stream.

    Drops rows where any of PrimaryType, Date, LocationDescription or Latitude is
    null, then casts Latitude, Longitude, X_Coorddinate and Y_Coorddinate from
    string to double.

    NOTE(review): Longitude is cast but not null-filtered — confirm that is intended.
    """
    required_cols = ["PrimaryType", "Date", "LocationDescription", "Latitude"]
    double_cols = ["Latitude", "Longitude", "X_Coorddinate", "Y_Coorddinate"]

    cleaned = df1
    for name in required_cols:
        cleaned = cleaned.filter(col(name).isNotNull())
    for name in double_cols:
        cleaned = cleaned.withColumn(name, col(name).cast("double"))
    return cleaned
    
# Silver-layer stream: null-filtered rows with double-typed coordinate columns.
silver_df_t = removing_nulls(df_renamed)

In [None]:
# The four coordinate columns are now doubles; all other columns remain strings.
silver_df_t.printSchema()

root
 |-- CaseNumber: string (nullable = true)
 |-- Date: string (nullable = true)
 |-- Block: string (nullable = true)
 |-- IUCR: string (nullable = true)
 |-- PrimaryType: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- LocationDescription: string (nullable = true)
 |-- Arrest: string (nullable = true)
 |-- Domestic: string (nullable = true)
 |-- Beat: string (nullable = true)
 |-- Ward: string (nullable = true)
 |-- FBICode: string (nullable = true)
 |-- X_Coorddinate: double (nullable = true)
 |-- Y_Coorddinate: double (nullable = true)
 |-- Year: string (nullable = true)
 |-- Latitude: double (nullable = true)
 |-- Longitude: double (nullable = true)
 |-- Location: string (nullable = true)
 |-- _rescued_data: string (nullable = true)



In [None]:
# Preview the cleaned silver stream.
display(silver_df_t)

CaseNumber,Date,Block,IUCR,PrimaryType,Description,LocationDescription,Arrest,Domestic,Beat,Ward,FBICode,X_Coorddinate,Y_Coorddinate,Year,Latitude,Longitude,Location,_rescued_data
HH588032,08/18/2002 01:10:00 AM,084XX S STATE ST,0281,CRIMINAL SEXUAL ASSAULT,NON-AGGRAVATED,RESIDENCE,False,False,632,6.0,02,1177763.0,1849064.0,2002,41.741151643,-87.624266823,"(41.741151643, -87.624266823)",
HH826321,12/08/2002 10:50:00 PM,008XX N SACRAMENTO AVE,0110,HOMICIDE,FIRST DEGREE MURDER,STREET,True,False,1211,,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",
HH215105,02/28/2002 09:59:00 PM,003XX S CENTRAL PARK AV,0110,HOMICIDE,FIRST DEGREE MURDER,AUTO,True,False,1133,,01A,1152444.0,1898129.0,2002,41.87632793,-87.715742766,"(41.87632793, -87.715742766)",
HH367441,05/13/2002 05:00:00 AM,061XX S ARTESIAN ST,0110,HOMICIDE,FIRST DEGREE MURDER,HOUSE,True,False,825,,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",
HH460597,06/29/2002 04:36:00 AM,130XX S BUFFALO STREET,0110,HOMICIDE,FIRST DEGREE MURDER,STREET,False,False,433,,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",
HH358970,05/09/2002 02:40:00 AM,091XX S RACINE STREET,0110,HOMICIDE,FIRST DEGREE MURDER,HOUSE,True,False,2222,,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",
HH447396,06/17/2002 08:04:00 AM,005XX E BROWNING ST.,0110,HOMICIDE,FIRST DEGREE MURDER,CHA GROUNDS,True,False,212,,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",
HH479856,07/01/2002 12:25:00 PM,011XX S LARAMIE AVE,0110,HOMICIDE,FIRST DEGREE MURDER,STREET,True,False,1522,,01A,1141835.0,1894683.0,2002,41.867074531,-87.75478119,"(41.867074531, -87.75478119)",
HH347787,06/04/2002 01:05:00 AM,056XX S LOOMIS STREET,0110,HOMICIDE,FIRST DEGREE MURDER,APARTMENT,True,False,713,,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",
HH476864,06/29/2002 11:36:00 PM,029XX N ALBANY STREET,0110,HOMICIDE,FIRST DEGREE MURDER,STREET,False,False,1411,,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",


In [None]:
# from pyspark.sql.functions import to_timestamp, col
# def columntransformation(df1):
#     return (df1
#             # .withColumn("Date_of_Occurance", to_timestamp("Date", "MM/dd/yyyy hh:mm:ss a"))
#             .withColumn("Latitude", col("Latitude").cast("double"))
#             .withColumn("Longitude", col("Longitude").cast("double"))
#             .withColumn("X_Coorddinate", col("X_Coorddinate").cast("double"))
#             .withColumn("Y_Coorddinate", col("Y_Coorddinate").cast("double"))
#             # .filter(col("Date_of_Occurance").isNotNull())
#             )
    
# silver_cleaned = columntransformation(silver_df)

In [None]:
# Silver layer: stream the cleaned rows into their own Delta table, using a
# separate checkpoint so this query tracks its progress independently of bronze.
checkpoint_location = "dbfs:/mnt/delta/silver/checkpoint"
silver_writer = (silver_df_t.writeStream
                 .format("delta")
                 .option("checkpointLocation", checkpoint_location))
query = silver_writer.table("silverdb.silver_layer_table")

In [None]:
%sql
-- Sanity check: rows in the silver table; compare with the bronze count to gauge how many rows the null filters removed.
SELECT count(*) FROM silverdb.silver_layer_table

count(1)
471527


Conversion to Gold Layer

In [None]:
from pyspark.sql.functions import to_timestamp

# Gold layer: parse the "MM/dd/yyyy hh:mm:ss AM/PM" strings in Date into a real
# timestamp, and drop the Ward column.
gold_df = (
    silver_df_t
    .withColumn("Date", to_timestamp("Date", "MM/dd/yyyy hh:mm:ss a"))
    .drop("Ward")
)

In [None]:
# Date is now a timestamp and Ward has been dropped.
gold_df.printSchema()

root
 |-- CaseNumber: string (nullable = true)
 |-- Date: timestamp (nullable = true)
 |-- Block: string (nullable = true)
 |-- IUCR: string (nullable = true)
 |-- PrimaryType: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- LocationDescription: string (nullable = true)
 |-- Arrest: string (nullable = true)
 |-- Domestic: string (nullable = true)
 |-- Beat: string (nullable = true)
 |-- FBICode: string (nullable = true)
 |-- X_Coorddinate: double (nullable = true)
 |-- Y_Coorddinate: double (nullable = true)
 |-- Year: string (nullable = true)
 |-- Latitude: double (nullable = true)
 |-- Longitude: double (nullable = true)
 |-- Location: string (nullable = true)
 |-- _rescued_data: string (nullable = true)



In [None]:
import pyspark.sql.functions as F

# Session-wide setting: disable Spark's correctness check for stateful streaming operators.
spark.conf.set("spark.sql.streaming.statefulOperator.checkCorrectness.enabled", "false")

# Plain alias of the gold stream; the ML section later snapshots `ml_test` into memory.
ml_test = gold_df

# Tag every row with its processing time and declare a 10-minute watermark on it.
# NOTE(review): no aggregation or deduplication follows in this query, so the
# watermark presumably has no effect here — confirm it is needed.
ml_test_ = ml_test.withColumn("Timestamp", F.current_timestamp())
gold_df_with_watermark = ml_test_.withWatermark("Timestamp", "10 minutes")

# Append the enriched gold rows to a Delta table; the returned StreamingQuery
# is the cell's displayed output.
(gold_df_with_watermark.writeStream
    .format("delta")
    .outputMode("append")
    .option("checkpointLocation", "/mnt/gold/checkpoints/gold_layer_")
    .table("testing2.goldDb.gold_layer_"))

<pyspark.sql.streaming.query.StreamingQuery at 0x7f640567bd00>

In [None]:
from pyspark.sql.functions import count, col

# Number of crimes per primary crime type.
crime_type_dist = (
    gold_df
    .groupBy("PrimaryType")
    .agg(count("*").alias("CrimeCount"))
)

In [None]:
from pyspark.sql.functions import year, month

# Derive calendar fields from the parsed Date (this replaces the original string
# Year column with an integer year), then count crimes per calendar month,
# pooling all years together.
gold_date_df = (
    gold_df
    .withColumn("Year", year("Date"))
    .withColumn("Month", month("Date"))
)
monthly_crime_count_df = gold_date_df.groupBy("Month").agg(count("*").alias("CrimeCounts"))

In [None]:
import pyspark.sql.functions as F
from pyspark.sql.functions import when, col, count

# Flag arrests as 1/0. Arrest is still a string column at this point (see the
# gold schema), so compare case-insensitively rather than relying on the exact
# spelling "true" vs "True". Null Arrest values fall through to 0.
gold_df_arrest = gold_df.withColumn(
    "ArrestBinary",
    when(F.lower(col("Arrest")) == "true", 1).otherwise(0)
)

# Per crime type: total crimes and how many of them led to an arrest.
# F.sum is used instead of `from pyspark.sql.functions import sum`, which would
# shadow Python's built-in sum() for every later cell.
arrest_rate_df = gold_df_arrest.groupBy(
    col("PrimaryType")
).agg(
    count("ArrestBinary").alias("TotalCrimes"),
    F.sum("ArrestBinary").alias("TotalArrests")
)

In [None]:
# Streaming preview of total crimes and arrests per crime type.
display(arrest_rate_df)

PrimaryType,TotalCrimes,TotalArrests
OFFENSE INVOLVING CHILDREN,2397,727
CRIMINAL SEXUAL ASSAULT,23,4
STALKING,192,42
PUBLIC PEACE VIOLATION,2357,898
OBSCENITY,25,24
ARSON,978,153
GAMBLING,947,946
CRIMINAL TRESPASS,13570,10398
ASSAULT,30734,7127
LIQUOR LAW VIOLATION,1368,1367


In [None]:
import pyspark.sql.functions as F

# Stamp each aggregate with the current processing time; the gold writers below
# declare their watermarks on this "Timestamp" column.
processing_time = F.current_timestamp()
crime_type_dist_ = crime_type_dist.withColumn("Timestamp", processing_time)
monthly_crime_count_df_ = monthly_crime_count_df.withColumn("Timestamp", processing_time)
arrest_rate_df_ = arrest_rate_df.withColumn("Timestamp", processing_time)

In [None]:
# Session-wide setting: disable Spark's correctness check for stateful streaming operators.
spark.conf.set("spark.sql.streaming.statefulOperator.checkCorrectness.enabled", "false")

import pyspark.sql.functions as F

# Persist arrests-per-crime-type to Delta in "complete" mode, i.e. the whole
# aggregate result is rewritten on every trigger.
# NOTE(review): the watermark is declared on a Timestamp column added *after* the
# aggregation, so it presumably does not bound the aggregation state — confirm.
arrest_rate_df_with_watermark = arrest_rate_df_.withWatermark("Timestamp", "10 minutes")
(arrest_rate_df_with_watermark.writeStream
    .format("delta")
    .outputMode("complete")
    .option("checkpointLocation", "/mnt/gold/checkpoints/arrest_rates_per_crime")
    .table("testing2.goldDb.arrest_rates_per_crime"))

<pyspark.sql.streaming.query.StreamingQuery at 0x7f16f81589d0>

In [None]:
# Persist crime counts per type to Delta in "complete" mode (full result rewritten
# on every trigger).
# NOTE(review): the watermark is on a column added after the aggregation, so it
# presumably does not bound state here — confirm.
crime_type_dist_watermark = crime_type_dist_.withWatermark("Timestamp", "10 minutes")
(crime_type_dist_watermark.writeStream
    .format("delta")
    .outputMode("complete")
    .option("checkpointLocation", "/mnt/gold/checkpoints/crime_type_distribution")
    .table("testing2.goldDb.crime_type_distribution"))

<pyspark.sql.streaming.query.StreamingQuery at 0x7f16f85b3700>

In [None]:
# Persist crime counts per calendar month to Delta in "complete" mode (full
# result rewritten on every trigger).
# NOTE(review): the watermark is on a column added after the aggregation, so it
# presumably does not bound state here — confirm.
monthly_crime_count_df_watermark = monthly_crime_count_df_.withWatermark("Timestamp", "10 minutes")
(monthly_crime_count_df_watermark.writeStream
    .format("delta")
    .outputMode("complete")
    .option("checkpointLocation", "/mnt/gold/checkpoints/monthly_crime_distribution")
    .table("testing2.goldDb.monthly_crime_distribution"))

<pyspark.sql.streaming.query.StreamingQuery at 0x7f16f8974f10>

## Visualisations

In [None]:
# Crime counts per primary type (rendered as a Databricks visualization).
display(crime_type_dist)

PrimaryType,CrimeCount
OFFENSE INVOLVING CHILDREN,2397
CRIMINAL SEXUAL ASSAULT,23
STALKING,192
PUBLIC PEACE VIOLATION,2357
OBSCENITY,25
ARSON,978
GAMBLING,947
CRIMINAL TRESPASS,13570
ASSAULT,30734
LIQUOR LAW VIOLATION,1368


Databricks visualization. Run in Databricks to view.

In [None]:
# Crime counts per calendar month, pooled across all years.
display(monthly_crime_count_df)

Month,CrimeCounts
12,37036
1,36450
6,41891
3,36639
5,42045
9,41391
4,38283
8,43261
7,44966
10,41793


Databricks visualization. Run in Databricks to view.

In [None]:
# Arrest aggregates including the processing-time Timestamp column.
display(arrest_rate_df_)

PrimaryType,TotalCrimes,TotalArrests,Timestamp
OFFENSE INVOLVING CHILDREN,2397,727,2024-07-25T14:41:38.277Z
CRIMINAL SEXUAL ASSAULT,23,4,2024-07-25T14:41:38.277Z
STALKING,192,42,2024-07-25T14:41:38.277Z
PUBLIC PEACE VIOLATION,2357,898,2024-07-25T14:41:38.277Z
OBSCENITY,25,24,2024-07-25T14:41:38.277Z
ARSON,978,153,2024-07-25T14:41:38.277Z
GAMBLING,947,946,2024-07-25T14:41:38.277Z
CRIMINAL TRESPASS,13570,10398,2024-07-25T14:41:38.277Z
ASSAULT,30734,7127,2024-07-25T14:41:38.277Z
LIQUOR LAW VIOLATION,1368,1367,2024-07-25T14:41:38.277Z


Databricks visualization. Run in Databricks to view.

Databricks visualization. Run in Databricks to view.

In [None]:
from pyspark.sql import SparkSession

# Reuse the notebook's existing Spark session (getOrCreate returns it if present).
spark = SparkSession.builder.getOrCreate()

# Snapshot the currently active streaming queries, then stop each of them
# before moving on to model training.
active_streams = list(spark.streams.active)
for query in active_streams:
    query.stop()

Development of ML Model

In [None]:
import mlflow
import mlflow.spark
from pyspark.sql.functions import hour, dayofweek, month, year, to_timestamp
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

In [None]:
def feature_selection(df_new, label_col="PrimaryType"):
    """Build the (unfitted) feature pipeline for the crime-type classifier.

    Stages, in order:
      1. a StringIndexer turning ``label_col`` into ``<label_col>Index``, the label
         the RandomForestClassifier trains on (``PrimaryTypeIndex`` by default);
      2. a StringIndexer + OneHotEncoder pair for every *other* categorical column;
      3. a VectorAssembler combining those one-hot vectors with the numeric
         Hour / DayOfWeek / Month / Year / Latitude / Longitude columns into
         ``features``.

    The label column is deliberately kept out of ``features``: one-hot encoding
    PrimaryType into the inputs while training on PrimaryTypeIndex leaks the
    target and makes the reported accuracy meaningless.

    Args:
        df_new: Not used; kept so existing calls ``feature_selection(df)`` still
            work. The DataFrame the pipeline is later fitted on must already
            contain the numeric columns listed above.
        label_col: Categorical column to predict.

    Returns:
        An unfitted ``pyspark.ml.Pipeline``.
    """
    categorical_cols = ["PrimaryType", "LocationDescription", "Arrest", "Domestic", "Beat"]
    # Never feed the label back in as an input feature (target leakage).
    feature_cols = [c for c in categorical_cols if c != label_col]

    # "keep" routes unseen/invalid values to an extra index instead of failing.
    label_indexer = StringIndexer(inputCol=label_col, outputCol=label_col + "Index", handleInvalid="keep")
    stages = [label_indexer]

    for categorical in feature_cols:
        stringIndexer = StringIndexer(inputCol=categorical, outputCol=categorical + "Index", handleInvalid="keep")
        encoder = OneHotEncoder(inputCols=[stringIndexer.getOutputCol()], outputCols=[categorical + "classVec"])
        stages += [stringIndexer, encoder]

    assemblerInput = [c + "classVec" for c in feature_cols] + ["Hour", "DayOfWeek", "Month", "Year", "Latitude", "Longitude"]
    assembler = VectorAssembler(inputCols=assemblerInput, outputCol="features")
    stages.append(assembler)

    return Pipeline(stages=stages)

In [None]:
from pyspark.sql.functions import *
static_df = None


def process_batch(batch_df, batch_id):
    """foreachBatch-style callback that accumulates micro-batches into ``static_df``.

    NOTE(review): not wired into any query in the visible notebook — the snapshot
    below is taken with the memory sink instead. Kept for any caller that uses it.
    """
    global static_df

    if static_df is None:
        static_df = batch_df
    else:
        static_df = static_df.union(batch_df)


# Materialise the gold stream into a static DataFrame for model training: run the
# stream over all currently available input into an in-memory table, wait for it
# to finish, then read that table back.
# `availableNow` replaces the deprecated `once` trigger: it processes everything
# available when the query starts (possibly in several micro-batches) and stops.
# NOTE: the memory sink keeps the whole result in driver memory, so this only
# suits data that fits on the driver.
query = (ml_test
         .writeStream
         .format("memory")
         .queryName("temp_table")
         .trigger(availableNow=True)
         .start())

query.awaitTermination()


static_df = spark.table("temp_table")

In [None]:
# Static snapshot of the gold data, read back from the in-memory table.
display(static_df)

CaseNumber,Date,Block,IUCR,PrimaryType,Description,LocationDescription,Arrest,Domestic,Beat,FBICode,X_Coorddinate,Y_Coorddinate,Year,Latitude,Longitude,Location,_rescued_data
HH588032,2002-08-18T01:10:00Z,084XX S STATE ST,0281,CRIMINAL SEXUAL ASSAULT,NON-AGGRAVATED,RESIDENCE,False,False,632,02,1177763.0,1849064.0,2002,41.741151643,-87.624266823,"(41.741151643, -87.624266823)",
HH826321,2002-12-08T22:50:00Z,008XX N SACRAMENTO AVE,0110,HOMICIDE,FIRST DEGREE MURDER,STREET,True,False,1211,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",
HH215105,2002-02-28T21:59:00Z,003XX S CENTRAL PARK AV,0110,HOMICIDE,FIRST DEGREE MURDER,AUTO,True,False,1133,01A,1152444.0,1898129.0,2002,41.87632793,-87.715742766,"(41.87632793, -87.715742766)",
HH367441,2002-05-13T05:00:00Z,061XX S ARTESIAN ST,0110,HOMICIDE,FIRST DEGREE MURDER,HOUSE,True,False,825,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",
HH460597,2002-06-29T04:36:00Z,130XX S BUFFALO STREET,0110,HOMICIDE,FIRST DEGREE MURDER,STREET,False,False,433,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",
HH358970,2002-05-09T02:40:00Z,091XX S RACINE STREET,0110,HOMICIDE,FIRST DEGREE MURDER,HOUSE,True,False,2222,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",
HH447396,2002-06-17T08:04:00Z,005XX E BROWNING ST.,0110,HOMICIDE,FIRST DEGREE MURDER,CHA GROUNDS,True,False,212,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",
HH479856,2002-07-01T12:25:00Z,011XX S LARAMIE AVE,0110,HOMICIDE,FIRST DEGREE MURDER,STREET,True,False,1522,01A,1141835.0,1894683.0,2002,41.867074531,-87.75478119,"(41.867074531, -87.75478119)",
HH347787,2002-06-04T01:05:00Z,056XX S LOOMIS STREET,0110,HOMICIDE,FIRST DEGREE MURDER,APARTMENT,True,False,713,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",
HH476864,2002-06-29T23:36:00Z,029XX N ALBANY STREET,0110,HOMICIDE,FIRST DEGREE MURDER,STREET,False,False,1411,01A,0.0,0.0,2002,36.619446395,-91.686565684,"(36.619446395, -91.686565684)",


In [None]:
# Add the time-based model inputs derived from Date. The integer year from
# year("Date") replaces the original string Year column.
# NOTE(review): Date is already a timestamp here (it was parsed in the gold cell),
# so the string pattern passed to to_timestamp is presumably not used for
# parsing — confirm; the Timestamp column is not one of the model inputs.
static_df = (
    static_df
    .withColumn("Timestamp", to_timestamp("Date", "MM/dd/yyyy HH:mm:ss"))
    .withColumn("Hour", hour("Date"))
    .withColumn("DayOfWeek", dayofweek("Date"))
    .withColumn("Month", month("Date"))
    .withColumn("Year", year("Date"))
)

In [None]:
# Inspect the derived Timestamp / Hour / DayOfWeek / Month / Year columns.
display(static_df)

In [None]:
# Alias of the gold stream (not referenced again in the visible cells).
ml_df_test = ml_test

# Build the feature pipeline, fit its indexers/encoders on the static snapshot,
# then materialise the "features" vector column used for training.
pipeline = feature_selection(static_df)
pipelineModel = pipeline.fit(static_df)
features = pipelineModel.transform(static_df)


In [None]:
# 70/30 train/test split. The fixed seeds make the split and the forest's
# bootstrap/feature sampling reproducible across re-runs.
(training_data, test_data) = features.randomSplit([0.7, 0.3], seed=42)
rf = RandomForestClassifier(labelCol="PrimaryTypeIndex", featuresCol="features", seed=42)

In [None]:
# Score predictions by plain accuracy against the indexed crime-type label.
evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="PrimaryTypeIndex",
    predictionCol="prediction",
)

In [None]:
# 3 x 3 grid over forest size and tree depth: 9 candidate models per CV fold.
paramGrid = (
    ParamGridBuilder()
    .addGrid(rf.numTrees, [10, 50, 100])
    .addGrid(rf.maxDepth, [5, 10, 15])
    .build()
)

In [None]:
# 3-fold cross-validation over the parameter grid. A fixed seed makes the
# random fold assignment reproducible across re-runs.
crossval = CrossValidator(estimator=rf,
                          estimatorParamMaps=paramGrid,
                          evaluator=evaluator,
                          numFolds=3,
                          seed=42)

In [None]:
with mlflow.start_run(run_name="Random Forest Crime Prediction") as run:
    mlflow.log_param("model_type", "RandomForestClassifier")
    mlflow.log_param("data_version", "gold_crime_data")

    # Grid search with cross-validation on the training split, then score the
    # refitted best model on the held-out test split.
    cvModel = crossval.fit(training_data)
    predictions = cvModel.transform(test_data)
    accuracy = evaluator.evaluate(predictions)
    mlflow.log_metric("accuracy", accuracy)

    best_rf_model = cvModel.bestModel
    # getNumTrees is a property on the fitted forest, but getMaxDepth is a regular
    # Param getter and must be called (without "()" the bound method was printed).
    best_num_trees = best_rf_model.getNumTrees
    best_max_depth = best_rf_model.getMaxDepth()
    mlflow.log_param("best_num_trees", best_num_trees)
    mlflow.log_param("best_max_depth", best_max_depth)

    mlflow.spark.log_model(best_rf_model, "random_forest_model")

    # Log every feature importance in one batched call instead of one tracking
    # request per feature; toArray() densifies the (possibly sparse) vector.
    feature_imp = best_rf_model.featureImportances
    mlflow.log_metrics({
        f"feature_importance_{i}": float(importance)
        for i, importance in enumerate(feature_imp.toArray())
    })

    print(f"Test Accuracy: {accuracy}")
    print("Best Model Parameters:")
    print(f"  NumTrees: {best_num_trees}")
    print(f"  MaxDepth: {best_max_depth}")

    run_id = run.info.run_id

# NOTE(review): the saved output shows this call raising an MlflowException after
# the registered model was created (the message is truncated). The target is a
# Unity Catalog registry ("testing2.default..."), which may impose extra
# requirements on logged models (e.g. a signature) — check the full error.
model_registered = mlflow.register_model(f"runs:/{run_id}/random_forest_model", "RandomForestCrimePredictor")

print(f"Model registered with name: {model_registered.name}")
print(f"Model version: {model_registered.version}")

2024/07/26 03:27:06 INFO mlflow.spark: Inferring pip requirements by reloading the logged model from the databricks artifact repository, which can be time-consuming. To speed up, explicitly specify the conda_env or pip_requirements when calling log_model().


Test Accuracy: 0.9959861981550595
Best Model Parameters:
  NumTrees: 100
  MaxDepth: <bound method _DecisionTreeParams.getMaxDepth of RandomForestClassificationModel: uid=RandomForestClassifier_3bfa5ac34093, numTrees=100, numClasses=31, numFeatures=452>


Successfully registered model 'testing2.default.randomforestcrimepredictor'.


[0;31m---------------------------------------------------------------------------[0m
[0;31mMlflowException[0m                           Traceback (most recent call last)
File [0;32m<command-3017105006983916>, line 23[0m
[1;32m     19[0m     [38;5;28mprint[39m([38;5;124mf[39m[38;5;124m"[39m[38;5;124m  MaxDepth: [39m[38;5;132;01m{[39;00mbest_rf_model[38;5;241m.[39mgetMaxDepth[38;5;132;01m}[39;00m[38;5;124m"[39m)
[1;32m     21[0m     run_id [38;5;241m=[39m run[38;5;241m.[39minfo[38;5;241m.[39mrun_id
[0;32m---> 23[0m model_registered [38;5;241m=[39m mlflow[38;5;241m.[39mregister_model([38;5;124m"[39m[38;5;124mruns:/[39m[38;5;124m"[39m [38;5;241m+[39m run_id [38;5;241m+[39m [38;5;124m"[39m[38;5;124m/random_forest_model[39m[38;5;124m"[39m, [38;5;124m"[39m[38;5;124mRandomForestCrimePredictor[39m[38;5;124m"[39m)
[1;32m     25[0m [38;5;28mprint[39m([38;5;124mf[39m[38;5;124m"[39m[38;5;124mModel registered with name: [39m[38;5