In [0]:
##################################################################################
# Batch Inference Notebook
#
# This notebook is an example of applying a model for batch inference against an input Delta table.
# It is configured and can be executed as the batch_inference_job in the batch_inference_job workflow defined under
# ``mlops_dbx/resources/batch-inference-workflow-resource.yml``
#
# Parameters:
#
#  * env (optional)  - String name of the current environment (dev, staging, or prod). Defaults to "dev"
#  * input_table_name (required)  - Delta table name containing your input data.
#  * output_table_name (required) - Delta table name where the predictions will be written to.
#                                   Note that this will create a new version of the Delta table if
#                                   the table already exists
#  * label_table_name (required) - Table providing the `churn` labels; joined to the predictions on customer_id.
#  * inference_table_name (required) - Table to which the joined predictions and labels are appended.
#  * model_name (required) - Full three-level (Unity Catalog) name of the model to be used in batch inference.
#                            The version currently holding the "champion" alias is used.
##################################################################################


# Notebook parameters. When the notebook runs as a job they arrive as notebook
# arguments; when run interactively they are exposed as Databricks widgets.

# Environment the notebook runs in (dev / staging / prod); defaults to "dev".
dbutils.widgets.dropdown("env", "dev", ["dev", "staging", "prod"], "Environment Name")

# Table-name parameters, each created as a text widget with an empty default:
#   input_table_name     - Hive-registered Delta table containing the input features
#   output_table_name    - Delta table the output predictions are stored in
#   label_table_name     - table providing the `churn` labels, joined on customer_id
#   inference_table_name - table the joined predictions + labels are appended to
for _widget_name, _widget_label in (
    ("input_table_name", "Input Table Name"),
    ("output_table_name", "Output Table Name"),
    ("label_table_name", "Label Table Name"),
    ("inference_table_name", "Inference Table Name"),
):
    dbutils.widgets.text(_widget_name, "", label=_widget_label)

# Unity Catalog registered model name (three-level: catalog.schema.model) of the trained model.
dbutils.widgets.text(
    "model_name", "dev.mlops_dbx.mlops_dbx-model", label="Full (Three-Level) Model Name"
)

In [0]:

# Read the notebook parameters and fail fast if a required one is missing.
env = dbutils.widgets.get("env")
input_table_name = dbutils.widgets.get("input_table_name")
output_table_name = dbutils.widgets.get("output_table_name")
label_table_name = dbutils.widgets.get("label_table_name")
inference_table_name = dbutils.widgets.get("inference_table_name")
model_name = dbutils.widgets.get("model_name")
assert input_table_name != "", "input_table_name notebook parameter must be specified"
assert output_table_name != "", "output_table_name notebook parameter must be specified"
# The label and inference tables are used unconditionally by the inference cell
# (spark.table / saveAsTable), so an empty value would otherwise only surface
# later as an unclear Spark error.
assert label_table_name != "", "label_table_name notebook parameter must be specified"
assert inference_table_name != "", "inference_table_name notebook parameter must be specified"
assert model_name != "", "model_name notebook parameter must be specified"

# Score with whichever registered-model version currently carries this alias.
alias = "champion"
model_uri = f"models:/{model_name}@{alias}"

In [0]:
from mlflow import MlflowClient

# Resolve the "@alias" reference to the concrete version number it currently
# points at in the Unity Catalog model registry; that number is handed to
# predict_batch alongside the model URI.
client = MlflowClient(registry_uri="databricks-uc")
model_version = (
    client
    .get_model_version_by_alias(name=model_name, alias=alias)
    .version
)

In [0]:
from datetime import datetime

# Timestamp of this batch run as "YYYY-MM-DD HH:MM:SS", taken from the local
# clock of the machine executing the cell (naive; no timezone attached).
ts = f"{datetime.now():%Y-%m-%d %H:%M:%S}"

In [0]:
import sys

import pyspark.sql.functions as f

# Make the repository root importable so the project's batch_inference package
# resolves. Guarded so re-running this cell does not keep appending duplicates.
if "../.." not in sys.path:
    sys.path.append("../..")
from batch_inference.predict import predict_batch

# Only the customer IDs are selected from the input table, renamed to the
# `customer_id` key used for the label join below.
# NOTE(review): any feature lookup presumably happens inside predict_batch — confirm.
input_table = spark.table(input_table_name).select(
    f.col("customerID").alias("customer_id")
)

# Score with the aliased model; the predictions are written to output_table_name.
predict_batch(model_uri, input_table, output_table_name, model_version, ts)

# Attach the ground-truth `churn` labels to the predictions.
# NOTE(review): this reads the whole output table; if predict_batch appends rather
# than overwrites, earlier runs' predictions get re-appended every run — confirm.
df_curr_preds = (
    spark.table(output_table_name)
    .join(spark.table(label_table_name), on="customer_id", how="inner")
    .select(
        "customer_id",
        "model_id",
        f.col("prediction").cast("integer").alias("prediction"),
        "churn",
        "timestamp",
    )
)

df_curr_preds.write.mode("append").saveAsTable(inference_table_name)

# Enable Delta Change Data Feed on the inference table. The SQL is built from two
# adjacent literals: a single-quoted f-string cannot span lines (SyntaxError).
# NOTE(review): CDF only records changes committed after it is enabled, so the
# append that first creates this table is not captured in the feed.
spark.sql(
    f"ALTER TABLE {inference_table_name} "
    "SET TBLPROPERTIES (delta.enableChangeDataFeed = true)"
)

# Return the output table name as this notebook's exit value.
dbutils.notebook.exit(output_table_name)