This is an auto-generated notebook to perform batch inference on a Spark DataFrame using a selected model from the model registry. This feature is in preview, and we would greatly appreciate any feedback through this form: https://databricks.sjc1.qualtrics.com/jfe/form/SV_1H6Ovx38zgCKAR0.

## Instructions:
1. Run the notebook against a cluster with Databricks ML Runtime version 13.3.x-cpu, to best re-create the training environment.
2. Add additional data processing on your loaded table to match the model schema if necessary (see the "Define input and output" section below).
3. "Run All" the notebook.
4. Note: If `%pip` does not work for your model (i.e. the model was logged without a `requirements.txt` file), modify the notebook to use `%conda` if possible.

In [0]:
model_name = "WA_Best_Model"

In [0]:
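import mlflow

# The model is referenced as "WA_Best_Model" (no <catalog>.<schema>. prefix), so it is
# loaded from the workspace model registry. To use a Unity Catalog model instead, set
# "databricks-uc" here and use a three-level model name.
mlflow.set_registry_uri("databricks")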

## Environment Recreation
Run the notebook against a cluster with Databricks ML Runtime version 13.3.x-cpu to best re-create the training environment. The cell below downloads the model artifacts associated with your model from the remote registry, including the `conda.yaml` and `requirements.txt` files. By default, this notebook uses `pip` to reinstall dependencies.

### (Optional) Conda Instructions
Models logged with an MLflow client version earlier than 1.18.0 do not have a `requirements.txt` file. If you are using a Databricks ML runtime (versions 7.4-8.x), you can replace the `pip install` command below with the following lines to recreate your environment using `%conda` instead of `%pip`.
```
conda_yml = os.path.join(local_path, "conda.yaml")
%conda env update -f $conda_yml
```

In [0]:
from mlflow.store.artifact.models_artifact_repo import ModelsArtifactRepository
import os

model_uri = f"models:/{model_name}/1"
local_path = ModelsArtifactRepository(model_uri).download_artifacts("") # download model from remote registry

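requirements_path = os.path.join(local_path, "requirements.txt")
if not os.path.exists(requirements_path):
  # models logged without a requirements.txt (MLflow < 1.18.0) get an empty one so %pip still runs
  dbutils.fs.put("file:" + requirements_path, "", True)

In [0]:
%pip install -r $requirements_path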
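After the dependencies are reinstalled, you may want to confirm that the cluster actually matches the versions pinned in the model's `requirements.txt`. The sketch below shows one way to do that using only the standard library. `parse_pins` and `find_mismatches` are hypothetical helpers, not part of MLflow. The parser only understands simple `name==version` pins and silently skips other specifiers, comments, and options such as `-r`.

```python
# Hypothetical helpers: compare exact pins in the model's requirements.txt with the
# versions installed on the current cluster, using only the standard library.
import re
from importlib import metadata

# Matches simple exact pins such as "mlflow==2.5.0"; other specifiers are skipped.
_PIN = re.compile(r"^\s*([A-Za-z0-9_.\-]+)\s*==\s*([^\s;#]+)")


def parse_pins(lines):
    """Return {package: version} for every 'name==version' line."""
    pins = {}
    for line in lines:
        match = _PIN.match(line)
        if match:
            pins[match.group(1)] = match.group(2)
    return pins


def find_mismatches(pins, installed_version=metadata.version):
    """Return {package: (pinned, installed)} for pins that differ; installed is None if absent."""
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = installed_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


# Example (requirements_path is defined by the download cell):
#   with open(requirements_path) as f:
#       print(find_mismatches(parse_pins(f)))  # {} means every exact pin matches
```

The `installed_version` parameter defaults to `importlib.metadata.version`. It can be swapped for a stub when testing the comparison logic.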

## Define input and output
The table named in `input_table_name` is used for batch inference, and the predictions are saved to `output_table_path`. After the table is loaded, you can perform additional data processing, such as renaming or removing columns, so that the table schema matches the model's input schema.
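The schema-matching step can be planned before you modify the DataFrame. The sketch below is a hypothetical, pure-Python helper, `plan_schema_alignment`, that takes three inputs: the table's column names, the input names the model expects, and an optional rename mapping you supply. It reports which renames apply, which columns to keep (in the model's declared order), and which model inputs the table still cannot supply. The commented example reads the expected names from the model signature, which assumes the model was logged with one.

```python
# Hypothetical helper: plan the renames and column selection needed so the loaded
# table lines up with the input names the model expects.


def plan_schema_alignment(table_columns, model_inputs, renames=None):
    """Return (rename_map, columns_to_select, missing_inputs).

    table_columns: column names of the loaded table, e.g. table.columns
    model_inputs:  input names the model expects, in the model's order
    renames:       optional {table_column: model_input} mapping you supply
    """
    renames = renames or {}
    # Keep only renames whose source column exists in the table.
    rename_map = {old: new for old, new in renames.items() if old in table_columns}
    available = {rename_map.get(col, col) for col in table_columns}
    # Select model inputs in declared order; report any that the table cannot supply.
    columns_to_select = [name for name in model_inputs if name in available]
    missing_inputs = [name for name in model_inputs if name not in available]
    return rename_map, columns_to_select, missing_inputs


# Possible use in the notebook, assuming the model was logged with a signature:
#   from mlflow.models import get_model_info
#   inputs = get_model_info(f"models:/{model_name}/1").signature.inputs.input_names()
#   rename_map, cols, missing = plan_schema_alignment(table.columns, inputs)
#   for old, new in rename_map.items():
#       table = table.withColumnRenamed(old, new)
#   table = table.select(*cols)
```

If `missing_inputs` is non-empty, the table needs additional processing before inference.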

In [0]:
# redefining key variables here because %pip and %conda restart the Python interpreter
model_name = "WA_Best_Model"
input_table_name = "wildfire.prediction_aus_wildfires.wa_prediction_data"
output_table_path = "/FileStore/batch-inference/WA_Best_Model"

In [0]:
# load table as a Spark DataFrame
table = spark.table(input_table_name)

# optionally, perform additional data processing (may be necessary to conform the schema)


## Load model and run inference
**Note**: If the model does not return double values, override `result_type` to the desired type.
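One way to choose a value for `result_type` is to inspect a single prediction from the model, for example by calling `predict` on a few rows loaded locally. The hypothetical `suggest_result_type` helper below maps a plain Python value to a Spark DDL type string. `spark_udf` accepts DDL strings such as `"double"`, but the set of supported types (in particular `"boolean"`) depends on your MLflow version, so check the `spark_udf` documentation before relying on the suggestion. Convert NumPy scalars with `.item()` first, because types such as `numpy.int64` are not subclasses of Python's `int`.

```python
# Hypothetical helper: suggest a Spark DDL type string for spark_udf's result_type
# based on one sample prediction value (a plain Python scalar or list).


def suggest_result_type(sample):
    # bool is checked before int because bool is a subclass of int in Python.
    if isinstance(sample, bool):
        return "boolean"
    if isinstance(sample, int):
        return "long"
    if isinstance(sample, float):
        return "double"
    if isinstance(sample, str):
        return "string"
    if isinstance(sample, (list, tuple)) and sample:
        # Arrays are typed by their first element.
        return f"array<{suggest_result_type(sample[0])}>"
    raise TypeError(f"no suggested result_type for {type(sample).__name__}")


# Possible use (value is one prediction you inspected):
#   predict = mlflow.pyfunc.spark_udf(spark, model_uri, result_type=suggest_result_type(value))
```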

In [0]:
import mlflow
from pyspark.sql.functions import struct

model_uri = f"models:/{model_name}/1"

# create spark user-defined function for model prediction
predict = mlflow.pyfunc.spark_udf(spark, model_uri, result_type="double")

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
File <command-179967625422493>, line 4
      1 import mlflow
      2 from pyspark.sql.functions import struct
----> 4 model_uri = f"models:/{model_name}/1"
      6 # create spark user-defined function for model prediction
      7 predict = mlflow.pyfunc.spark_udf(spark, model_uri, result_type="double"

In [0]:
output_df = table.withColumn("prediction", predict(struct(*table.columns)))

## Save predictions
**The default output path on DBFS is accessible to everyone in this Workspace. If you want to limit access to the output you must change the path to a protected location.**
The cell below will save the output table to the specified FileStore path. `datetime.now()` is appended to the path to prevent overwriting the table in the event that this notebook is run in a batch inference job. To overwrite existing tables at the path, replace the cell below with:
```python
output_df.write.mode("overwrite").save(output_table_path)
```
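The timestamp suffix can be factored into a small function. The hypothetical `timestamped_path` helper below differs slightly from the notebook's expression. The notebook replaces colons with periods, presumably because colons are not safe in file paths. This helper replaces colons only in the timestamp, so a base path with a scheme such as `dbfs:/` keeps its prefix intact. For the notebook's default `/FileStore/...` path, the result is identical to the original expression.

```python
from datetime import datetime


def timestamped_path(base_path, now=None):
    """Return base_path + '_' + an ISO-8601 timestamp with ':' replaced by '.'."""
    if now is None:
        now = datetime.now()
    # Only the timestamp is rewritten, so a scheme such as "dbfs:/" is left untouched.
    stamp = now.isoformat().replace(":", ".")
    return f"{base_path}_{stamp}"


# Possible use: output_df.write.save(timestamped_path(output_table_path))
```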

### (Optional) Write predictions to Unity Catalog
If you have access to any UC catalogs, you can also save predictions to UC by specifying a table in the format `<catalog>.<database>.<table>`.
```python
output_table = ""  # Example: "ml.batch_inference.WA_Best_Model"
output_df.write.saveAsTable(output_table)
```
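A Unity Catalog table name has three parts, `<catalog>.<database>.<table>`. If a part contains characters other than letters, digits, and underscores (for example, a hyphen, as in `batch-inference`), it must be wrapped in backticks, or Spark will fail to parse the name. The hypothetical `uc_table_name` helper below builds the name and quotes parts conservatively. It does not check whether the catalog or schema exists or whether you have permission to write to it.

```python
import re

# Conservative rule for parts that can be left unquoted: letters, digits, underscores.
_SIMPLE = re.compile(r"^[A-Za-z0-9_]+$")


def uc_table_name(catalog, schema, table):
    """Build '<catalog>.<schema>.<table>', backtick-quoting parts that need it."""
    parts = []
    for part in (catalog, schema, table):
        if not part:
            raise ValueError("catalog, schema and table must all be non-empty")
        if _SIMPLE.match(part):
            parts.append(part)
        else:
            # A backtick inside a quoted identifier is escaped by doubling it.
            parts.append("`" + part.replace("`", "``") + "`")
    return ".".join(parts)


# Possible use: output_df.write.saveAsTable(uc_table_name("wildfire", "output_wildfires", "wa"))
```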

In [0]:
from datetime import datetime

# To write to a unity catalog table, see instructions above
output_df.write.save(f"{output_table_path}_{datetime.now().isoformat()}".replace(":", "."))

In [0]:
output_table = "wildfire.output_wildfires.wa"

output_df.write.mode("overwrite").saveAsTable(output_table)

In [0]:
output_df.display()

Date,Region,Precipitation_Max,Precipitation_Mean,Precipitation_Min,RelativeHumidity_Mean,SoilWaterContent_Max,SoilWaterContent_Min,SolarRadiation_Mean,Temperature_Mean,WindSpeed_Max,WindSpeed_Mean,Year,Month,Shrubs,Cultivated_and_managed_vegetation/agriculture__cropland_,Bare_/_sparse_vegetation,Permanent_water_bodies,Closed_forest__deciduous_broad_leaf,Vegetation_index_mean,Vegetation_index_max,Vegetation_index_min,Vegetation_index_std,prediction
2020-11-01T00:00:00Z,6,28.22734832763672,2.175524320403373,0.0,44.85979814987711,0.330396324396133,8.951561770000001e-07,22.56828239855768,27.6122590716821,9.258060455322266,5.014015765850021,2020,11,31.3,5.6,1.0,0.4,2.4,0.205688176,0.991899967,-0.199999988,0.092951559,43.07334641997653
2020-11-02T00:00:00Z,6,30.36119270324707,1.3862143657180042,0.0,45.89435733226736,0.378079682588577,7.277510009999999e-07,23.340846743129216,25.7086365434748,8.168209075927733,4.997224826194089,2020,11,31.3,5.6,1.0,0.4,2.4,0.205688176,0.991899967,-0.199999988,0.092951559,39.367760669429
2020-11-03T00:00:00Z,6,3.019692659378052,0.142167314238111,0.0,39.881772522584,0.347833096981049,9.28140537e-07,26.315288049742165,26.0035235785804,9.957096099853516,4.304180640610112,2020,11,31.3,5.6,1.0,0.4,2.4,0.205688176,0.991899967,-0.199999988,0.092951559,38.16354196948847
2020-11-04T00:00:00Z,6,4.246981620788573,0.046967263476724,0.0,36.57614505144551,0.3251309990882869,9.82157303e-07,27.468957918133977,26.65371117778029,9.818195343017578,5.296806393800943,2020,11,31.3,5.6,1.0,0.4,2.4,0.205688176,0.991899967,-0.199999988,0.092951559,37.27411275164639
2020-11-05T00:00:00Z,6,6.323696613311768,0.275793056479504,0.0,35.784382644564985,0.3020512461662289,1.437651576e-06,27.82793854813431,27.41218388531529,9.342352867126465,5.166340720129363,2020,11,31.3,5.6,1.0,0.4,2.4,0.205688176,0.991899967,-0.199999988,0.092951559,35.68173122735241
2020-11-06T00:00:00Z,6,8.278383255004883,0.611896216208416,0.0,39.0636770151684,0.28729385137558,8.54333223e-07,26.7431411723662,26.47718080247064,9.184889793395996,5.3532104698610565,2020,11,31.3,5.6,1.0,0.4,2.4,0.205688176,0.991899967,-0.199999988,0.092951559,32.15500346029771
2020-11-07T00:00:00Z,6,12.680755615234377,0.5199550547403401,0.0,40.95983604438168,0.300840735435486,1.544520842e-06,27.56101831186201,26.87634011900948,9.090863227844238,4.536585301831883,2020,11,31.3,5.6,1.0,0.4,2.4,0.205688176,0.991899967,-0.199999988,0.092951559,30.857557101965472
2020-11-08T00:00:00Z,6,19.91455268859864,0.179646255596465,0.0,36.67261591736411,0.2854874730110169,0.0,28.699409193529583,28.268846861907694,7.25596809387207,4.217241318939084,2020,11,31.3,5.6,1.0,0.4,2.4,0.205688176,0.991899967,-0.199999988,0.092951559,29.165757270063807
2020-11-09T00:00:00Z,6,36.8853874206543,0.8930521008309179,0.0,35.10020960821809,0.351660043001175,3.248142775e-06,27.621475577669848,27.651635432111288,11.105335235595703,6.2090673142047415,2020,11,31.3,5.6,1.0,0.4,2.4,0.205688176,0.991899967,-0.199999988,0.092951559,26.782394593176463
2020-11-10T00:00:00Z,6,13.73735523223877,0.1037913325570269,0.0,33.51399955652655,0.323210656642914,9.68342647e-07,27.81783686323569,24.432638971028613,13.934606552124023,5.3908112422328935,2020,11,31.3,5.6,1.0,0.4,2.4,0.205688176,0.991899967,-0.199999988,0.092951559,26.324963605733604

