### Import Packages ###

In [0]:
# Comment code below after running the mlflow update
%pip install "mlflow-skinny[databricks]>=2.4.1"
dbutils.library.restartPython()

[43mNote: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.[0m
[43mNote: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.[0m


In [0]:
import pandas as pd
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
import warnings
warnings.filterwarnings("ignore")
import mlflow
from mlflow.models.signature import infer_signature
from pyspark.ml import Pipeline
from pyspark.sql.functions import col
from mlflow import MlflowClient

### Define Catalog and Schema in Unity Catalog ###

In [0]:
catalog = "data_science"
schema = "models"

### Load Kaggle Data Set ###

https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset

In [0]:
data = pd.read_csv('/Workspace/Users/awnish.choudhary@anthology.ai/healthcare-dataset-stroke-data.csv')

### Clean the Data ###

In [0]:
# Convert string columns to numeric values in data
data['gender'] = data['gender'].map({'Male': 0, 'Female': 1, 'Other': 2})
data['ever_married'] = data['ever_married'].map({'No': 0, 'Yes': 1})
data['work_type'] = data['work_type'].map({'Private': 0, 'Self-employed': 1, 'Govt_job': 2, 'children': 3, 'Never_worked': 4})
data['Residence_type'] = data['Residence_type'].map({'Urban': 0, 'Rural': 1})
data['smoking_status'] = data['smoking_status'].map({'Unknown': 0, 'never smoked': 1, 'formerly smoked': 2, 'smokes': 3})

data

Unnamed: 0,id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
0,9046,0,67.0,0,1,1,0,0,228.69,36.6,2,1
1,51676,1,61.0,0,0,1,1,1,202.21,,1,1
2,31112,0,80.0,0,1,1,0,1,105.92,32.5,1,1
3,60182,1,49.0,0,0,1,0,0,171.23,34.4,3,1
4,1665,1,79.0,1,0,1,1,1,174.12,24.0,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...
5105,18234,1,80.0,1,0,1,0,0,83.75,,1,0
5106,44873,1,81.0,0,0,1,1,0,125.20,40.0,1,0
5107,19723,1,35.0,0,0,1,1,1,82.99,30.6,1,0
5108,37544,0,51.0,0,0,1,0,1,166.29,25.6,2,0


In [0]:
# Drop rows with missing values
data.dropna(inplace=True)

In [0]:
data = spark.createDataFrame(data)

In [0]:
# Select features and target
featureCols = [col for col in data.columns if col != "stroke"]
assembler = VectorAssembler(inputCols=featureCols, outputCol="features")
data_prepared = assembler.transform(data).select(col("features"), col("stroke").alias("label"))

# Split the data
(train_data, test_data) = data_prepared.randomSplit([0.7, 0.3])

### Model training ###

In [0]:
# Train a RandomForest model
rf = RandomForestClassifier(labelCol="label", featuresCol="features", numTrees=100)
model = rf.fit(train_data)
model_name = 'random_forest_classifier'
# Make predictions
predictions = model.transform(test_data)
# Evaluate the model
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print(f"Test Accuracy = {accuracy}")


Downloading artifacts:   0%|          | 0/20 [00:00<?, ?it/s]

Uploading artifacts:   0%|          | 0/4 [00:00<?, ?it/s]

Test Accuracy = 0.9583921015514809


### Add model signature which is required to register the model ###

In [0]:
# Convert Spark DataFrame to Pandas DataFrame
train_data_pd = train_data.toPandas()
# Infer model signature using the Pandas DataFrame
signature = infer_signature(train_data_pd, model.transform(train_data).toPandas())

### Log and Register model to MLflow ###

In [0]:
with mlflow.start_run():
    mlflow.spark.log_model(model, model_name, signature=signature)
    uri = mlflow.get_artifact_uri(model_name)
    # Log metrics
    mlflow.log_metric("accuracy_score", accuracy)
    mlflow.set_registry_uri("databricks-uc")
    # Register Model
    mlflow.register_model(
        model_uri=uri,
        name=f"{catalog}.{schema}.{model_name}"
    )

2024/09/12 18:16:00 INFO mlflow.spark: Inferring pip requirements by reloading the logged model from the databricks artifact repository, which can be time-consuming. To speed up, explicitly specify the conda_env or pip_requirements when calling log_model().


Downloading artifacts:   0%|          | 0/20 [00:00<?, ?it/s]

2024/09/12 18:16:03 INFO mlflow.store.artifact.artifact_repo: The progress bar can be disabled by setting the environment variable MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR to false


Uploading artifacts:   0%|          | 0/4 [00:00<?, ?it/s]

Successfully registered model 'data_science.models.random_forest_classifier'.


Downloading artifacts:   0%|          | 0/24 [00:00<?, ?it/s]

2024/09/12 18:16:32 INFO mlflow.store.artifact.artifact_repo: The progress bar can be disabled by setting the environment variable MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR to false


Uploading artifacts:   0%|          | 0/24 [00:00<?, ?it/s]

2024/09/12 18:16:33 INFO mlflow.store.artifact.cloud_artifact_repo: The progress bar can be disabled by setting the environment variable MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR to false
Created version '1' of model 'data_science.models.random_forest_classifier'.


### Create "Champion" alias for latest version of the model currently in production  ###

In [0]:
# Initialize the MLflow client
client = MlflowClient()

# Search for all versions of the model and fetch the latest one
model_version_infos = client.search_model_versions(f"name='{catalog}.{schema}.{model_name}'")
new_model_version = max(model_version_info.version for model_version_info in model_version_infos)

# Set the alias for the latest model version
client.set_registered_model_alias(
    name=f"{catalog}.{schema}.{model_name}",
    alias="Champion",
    version=new_model_version
)

### Load Registered Model for inference ###

In [0]:
model_version_uri = 'models:/'+f"{catalog}.{schema}.{model_name}@Champion"
champion_version = mlflow.spark.load_model(model_version_uri)

2024/09/12 18:16:35 INFO mlflow.spark: 'models:/data_science.models.random_forest_classifier@Champion' resolved as 's3://caden-os-prod-databricks-metastore-storage/7adaa701-ad6e-4968-af13-eddf695e53bc/models/ea2c3592-9769-4b2c-8bd0-240daacff071/versions/fc2ea51a-1b8d-4ded-a58a-cb324b168f57'


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/24 [00:00<?, ?it/s]

2024/09/12 18:16:38 INFO mlflow.store.artifact.artifact_repo: The progress bar can be disabled by setting the environment variable MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR to false


In [0]:
display(champion_version.transform(test_data))

features,label,rawPrediction,probability,prediction
"Map(vectorType -> sparse, length -> 11, indices -> List(0, 1, 2, 5, 8, 9), values -> List(7055.0, 1.0, 58.0, 1.0, 80.92, 19.4))",0,"Map(vectorType -> dense, length -> 2, values -> List(97.60272202023313, 2.397277979766864))","Map(vectorType -> dense, length -> 2, values -> List(0.9760272202023313, 0.02397277979766864))",0.0
"Map(vectorType -> sparse, length -> 11, indices -> List(0, 1, 2, 5, 8, 9), values -> List(26134.0, 1.0, 28.0, 1.0, 111.22, 25.5))",0,"Map(vectorType -> dense, length -> 2, values -> List(98.1361388860891, 1.863861113910894))","Map(vectorType -> dense, length -> 2, values -> List(0.981361388860891, 0.01863861113910894))",0.0
"Map(vectorType -> sparse, length -> 11, indices -> List(0, 1, 2, 5, 8, 9), values -> List(32689.0, 1.0, 48.0, 1.0, 84.38, 27.1))",0,"Map(vectorType -> dense, length -> 2, values -> List(98.04020799944801, 1.959792000551984))","Map(vectorType -> dense, length -> 2, values -> List(0.9804020799944801, 0.01959792000551984))",0.0
"Map(vectorType -> sparse, length -> 11, indices -> List(0, 1, 2, 5, 8, 9), values -> List(37413.0, 1.0, 39.0, 1.0, 77.54, 32.7))",0,"Map(vectorType -> dense, length -> 2, values -> List(98.08996268695677, 1.910037313043227))","Map(vectorType -> dense, length -> 2, values -> List(0.9808996268695677, 0.01910037313043227))",0.0
"Map(vectorType -> sparse, length -> 11, indices -> List(0, 1, 2, 5, 8, 9), values -> List(58767.0, 1.0, 37.0, 1.0, 91.45, 25.8))",0,"Map(vectorType -> dense, length -> 2, values -> List(98.1361388860891, 1.863861113910894))","Map(vectorType -> dense, length -> 2, values -> List(0.981361388860891, 0.01863861113910894))",0.0
"Map(vectorType -> sparse, length -> 11, indices -> List(0, 1, 2, 5, 8, 9), values -> List(60491.0, 1.0, 78.0, 1.0, 58.57, 24.2))",1,"Map(vectorType -> dense, length -> 2, values -> List(92.04088658313232, 7.959113416867667))","Map(vectorType -> dense, length -> 2, values -> List(0.9204088658313233, 0.07959113416867668))",0.0
"Map(vectorType -> sparse, length -> 11, indices -> List(0, 1, 2, 5, 8, 9), values -> List(62783.0, 1.0, 76.0, 1.0, 198.02, 38.7))",0,"Map(vectorType -> dense, length -> 2, values -> List(85.21088103059292, 14.789118969407086))","Map(vectorType -> dense, length -> 2, values -> List(0.8521088103059291, 0.14789118969407086))",0.0
"Map(vectorType -> sparse, length -> 11, indices -> List(0, 1, 2, 5, 8, 9), values -> List(66362.0, 1.0, 61.0, 1.0, 129.31, 41.2))",0,"Map(vectorType -> dense, length -> 2, values -> List(97.41252369867509, 2.5874763013249025))","Map(vectorType -> dense, length -> 2, values -> List(0.974125236986751, 0.02587476301324903))",0.0
"Map(vectorType -> sparse, length -> 11, indices -> List(0, 1, 2, 5, 8, 9), values -> List(72836.0, 1.0, 59.0, 1.0, 65.98, 31.1))",0,"Map(vectorType -> dense, length -> 2, values -> List(95.65024901223958, 4.349750987760412))","Map(vectorType -> dense, length -> 2, values -> List(0.9565024901223959, 0.043497509877604125))",0.0
"Map(vectorType -> sparse, length -> 11, indices -> List(0, 1, 2, 6, 8, 9), values -> List(10324.0, 1.0, 5.0, 3.0, 93.88, 14.6))",0,"Map(vectorType -> dense, length -> 2, values -> List(98.5731799846236, 1.4268200153764163))","Map(vectorType -> dense, length -> 2, values -> List(0.9857317998462358, 0.014268200153764161))",0.0
