## Payment Delay Inference
This notebook guides you through setting up the prediction of payment delays. To understand the factors influencing each delay prediction, we will use [SHAP values (SHapley Additive exPlanations)](https://shap.readthedocs.io/en/latest/index.html), a method used in machine learning to explain the output of predictive models. They help you understand how much each feature contributes to a particular prediction.

These are the main steps of this exercise:
1. Install and Import Packages
2. Predict the Payment Delays
3. Derive key drivers for payment delay prediction
4. Merge prediction data and SHAP values into a single result set
5. Persist the result set into a table

## 1. Install and Import Packages

In [0]:
%pip install xgboost
%pip install shap
%restart_python

In [0]:
import mlflow
from pyspark.sql import SparkSession
from pyspark.sql.functions import row_number, lit, col, dateadd, monotonically_increasing_id
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType, StructField, StructType, MapType, StringType
from pathlib import Path
import pandas as pd
import xgboost as xgb
import shap
from tqdm.auto import tqdm
import databricks.connect as db_connect
import mlflow.tracking._model_registry.utils

## 2. Predict the Payment Delays

#### Set Parameters
Please replace the values `<CATALOG_NAME>` and `<SCHEMA_NAME>` with the values that match our use case and group. You can find the correct names in the **Unity Catalog** by looking for the catalog and schema names of the form `uc_XXX` and `grpX`.

In [0]:
%sql
-- CREATE CATALOG IF NOT EXISTS <CATALOG_NAME>;
SET CATALOG <CATALOG_NAME>;
CREATE SCHEMA IF NOT EXISTS <SCHEMA_NAME>;
USE SCHEMA <SCHEMA_NAME>;

#### Load data model
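Replace the value `<PREPARED_TABLE_NAME>` with the prepared table name. We will load a random 1% sample of the dataset. Replace `<SEED_PARAMETER>` with an integer seed, e.g. `42`, so that the sample is reproducible. The primary keys are sampled again in step 4 with `seed=42`, so use the same value here if the rows are to line up.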

In [0]:
mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")

In [0]:
inference_data_model = (spark.read.table("<PREPARED_TABLE_NAME>")
                        .drop("CompanyCode", "AccountingDocument", "FiscalYear", "AccountingDocumentItem", "delay", "NetDueDate", "ClearingDate")
                        .sample(0.01, seed=<SEED_PARAMETER>)
                        .toPandas())
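
The seed matters because the same sample is drawn a second time when the primary keys are loaded. As a rough illustration in plain Python (this is not Spark's sampler, and `bernoulli_sample` is a hypothetical helper), a fixed seed reproduces the same selection for the same input order:

```python
import random

def bernoulli_sample(rows, fraction, seed):
    """Keep each row independently with probability `fraction` (hypothetical helper)."""
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fraction]

rows = list(range(1000))

first = bernoulli_sample(rows, 0.01, seed=42)
second = bernoulli_sample(rows, 0.01, seed=42)
other = bernoulli_sample(rows, 0.01, seed=7)

# Same seed and same input order give an identical selection.
print(first == second)
# The sample size is only approximately fraction * len(rows).
print(len(first), len(other))
```

Note that the sample size is approximate, so the row count shown in the next cell will usually not be exactly 1% of the table.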

In [0]:
inference_data_model.count()

In [0]:
def infer_column_dtype(series):
    """Guess whether a column holds numeric, datetime, boolean, or string data."""
    values = series.dropna()

    # Try to convert to numeric
    try:
        pd.to_numeric(values)
        return 'numeric'
    except Exception:
        pass

    # Try to convert to datetime
    # (infer_datetime_format is deprecated since pandas 2.0 and no longer needed)
    try:
        pd.to_datetime(values, errors='raise')
        return 'datetime'
    except Exception:
        pass

    # If all unique values are boolean-like ('true', 'false', '1', '0')
    lower_vals = set(str(v).strip().lower() for v in values.unique())
    if lower_vals <= {'true', 'false', '1', '0'}:
        return 'boolean'

    return 'string'

In [0]:
# The loop variable is named `column` to avoid shadowing `col` imported from pyspark.sql.functions
for column in tqdm(inference_data_model.columns):
    inferred = infer_column_dtype(inference_data_model[column])
    if inferred == 'numeric':
        inference_data_model[column] = pd.to_numeric(inference_data_model[column], errors='coerce')
    elif inferred == 'datetime':
        inference_data_model[column] = pd.to_datetime(inference_data_model[column], errors='coerce')
    elif inferred == 'boolean':
        inference_data_model[column] = inference_data_model[column].astype('bool')
    else:
        inference_data_model[column] = inference_data_model[column].astype('category')
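
One caveat about the boolean branch: in Python any non-empty string is truthy, so a truthiness-based cast can turn the text `'false'` into `True`. If a column holds textual flags, an explicit mapping is safer. A minimal sketch in plain Python, where `parse_flag` is a hypothetical helper and not part of the notebook:

```python
TRUE_TOKENS = {"true", "1"}
FALSE_TOKENS = {"false", "0"}

def parse_flag(value):
    """Map a textual flag to bool; return None for missing or unknown values."""
    if value is None:
        return None
    token = str(value).strip().lower()
    if token in TRUE_TOKENS:
        return True
    if token in FALSE_TOKENS:
        return False
    return None

raw = ["True", "false", " 1 ", "0", None]

# Truthiness treats every non-empty string as True, including "false" and "0".
by_truthiness = [bool(v) for v in raw]
# The explicit mapping keeps the intended meaning of each token.
by_mapping = [parse_flag(v) for v in raw]

print(by_truthiness)
print(by_mapping)
```

Whether this matters here depends on how the columns were typed during training; the inference data should be converted the same way as the training data.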

In [0]:
# Convert the __TIMESTAMP column to a supported type
inference_data_model['__TIMESTAMP'] = inference_data_model['__TIMESTAMP'].astype('int64')


#### Load Trained Model
Load the trained model by replacing the variable `<TRAINED_MODEL_NAME>` with the name of the trained model from the previous exercise.
Hint: You can also find the name under `models` in the **Unity Catalog**, in the corresponding catalog and schema.

In [0]:
# Set the registry URI manually
mlflow.tracking._model_registry.utils._get_registry_uri_from_spark_session = lambda: "databricks-uc"
mlflow.login()

In [0]:
model_delay = mlflow.xgboost.load_model("models:/<TRAINED_MODEL_NAME>@prod")

#### Run Prediction
We then perform inference using the trained model on our prepared data.

In [0]:
prediction = model_delay.predict(inference_data_model, output_margin=True)

## 3. Derive key drivers for payment delay prediction

To determine the key influencing factors on payment delays, we will use the SHAP TreeExplainer together with the XGBoost (Extreme Gradient Boosting) model.

- [**SHAP TreeExplainer**](https://shap.readthedocs.io/en/latest/tabular_examples.html#tree-based-models) is a specialized explainer designed for tree-based machine learning models. It provides fast and exact explanations of how each feature contributes to a prediction.
- [**XGBoost**](https://xgboost.ai/about) is a powerful and efficient machine learning algorithm based on gradient boosting. We use it to predict the payment delays.
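
To make "how each feature contributes to a prediction" concrete, here is a brute-force sketch of Shapley values for a tiny model with three features. It evaluates the textbook definition over all feature orderings and is not the TreeExplainer algorithm. The model `toy_model`, the feature names, and the baseline values are all made up for illustration.

```python
from itertools import permutations

def toy_model(x):
    """Hypothetical delay model with one interaction term."""
    return 2.0 * x["amount"] + 3.0 * x["risk"] + x["amount"] * x["terms"]

def shapley_values(model, instance, baseline):
    """Average each feature's marginal contribution over all feature orderings."""
    features = list(instance)
    orderings = list(permutations(features))
    contributions = {f: 0.0 for f in features}
    for ordering in orderings:
        current = dict(baseline)           # start from the reference input
        previous_output = model(current)
        for feature in ordering:
            current[feature] = instance[feature]   # switch this feature to its actual value
            output = model(current)
            contributions[feature] += output - previous_output
            previous_output = output
    return {f: total / len(orderings) for f, total in contributions.items()}

baseline = {"amount": 0.0, "risk": 0.0, "terms": 0.0}
instance = {"amount": 1.0, "risk": 2.0, "terms": 4.0}

phi = shapley_values(toy_model, instance, baseline)
base_value = toy_model(baseline)

print(phi)
# Additivity: base value plus all contributions equals the prediction.
print(base_value + sum(phi.values()), toy_model(instance))
```

The interaction term `amount * terms` is split evenly between the two features involved, and the base value plus the contributions adds up to the model output. This additivity is the property that makes SHAP values readable as per-feature contributions.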

In [0]:
explainer = shap.TreeExplainer(model_delay)

Next, we convert the data into an `xgb.DMatrix`, the data structure XGBoost uses to efficiently store and process input data for training and prediction. It’s optimized for performance and memory usage, especially with large datasets.

In [0]:

Xd = xgb.DMatrix(inference_data_model, enable_categorical=True)

In [0]:
explanation = explainer(Xd)

In [0]:
explanation.feature_names = Xd.feature_names

#### Visualizing the SHAP values using beeswarm plot

The [beeswarm plot](https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/beeswarm.html#A-simple-beeswarm-summary-plot) is designed to display an information-dense summary of how the top features in a dataset impact the model’s output. Each instance of the given explanation is represented by a single dot on each feature row. The x position of the dot is determined by the SHAP value (`shap_values.values[instance,feature]`) of that feature, and dots “pile up” along each feature row to show density. Color is used to display the original value of a feature (`shap_values.data[instance,feature]`).

In [0]:
shap.plots.beeswarm(explanation)

Sometimes a more compact summary is helpful. Below we draw the standard bar plot, which shows the mean absolute SHAP value per feature (`shap_values.abs.mean(0)`), i.e. the mean of the absolute values of the dots in the beeswarm plot.

In [0]:
shap.plots.bar(explanation)
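
The ranking shown by the bar plot can be reproduced by hand: take the absolute SHAP value of every row and average per feature. A minimal sketch in plain Python, using a small made-up matrix in place of `explanation.values` (the feature names are hypothetical):

```python
feature_names = ["amount", "risk", "terms"]

# One row per instance, one column per feature (made-up SHAP values).
shap_matrix = [
    [0.5, -2.0, 0.0],
    [-1.5, 1.0, 0.0],
    [1.0, -3.0, 0.0],
]

def mean_abs_by_feature(matrix, names):
    """Return the mean absolute value of each column, keyed by feature name."""
    n_rows = len(matrix)
    return {
        name: sum(abs(row[j]) for row in matrix) / n_rows
        for j, name in enumerate(names)
    }

importance = mean_abs_by_feature(shap_matrix, feature_names)
ranking = sorted(importance, key=importance.get, reverse=True)

print(importance)
print(ranking)
```

Taking the absolute value first matters: positive and negative contributions would otherwise cancel out and hide features that push predictions strongly in both directions.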

## 4. Merge prediction data and SHAP values into a single result set

#### Convert the SHAP values into a Spark DataFrame

In [0]:
# Extract SHAP values from the explanation object
shap_values = explanation.values

# Get the list of columns from the inference data model
column_list = list(inference_data_model.columns)

# Create a Spark DataFrame from the SHAP values with appropriate schema
spark_shap_df = spark.createDataFrame(
    shap_values, 
    schema=StructType([StructField(f"SHAP_Values_{column}", DoubleType()) for column in column_list]))

spark_shap_df.display()

#### Calculate Min/Max Values for each column

In [0]:
# Calculate the minimum and maximum values for each column in the SHAP values DataFrame
mins = spark_shap_df.agg(*[F.min(c).alias(c) for c in spark_shap_df.columns])
maxs = spark_shap_df.agg(*[F.max(c).alias(c) for c in spark_shap_df.columns])

# Transpose the min/max values DataFrame for easier access
mins_transposed = mins.pandas_api().transpose()
maxs_transposed = maxs.pandas_api().transpose()

# Identify columns where the minimum/maximum value is zero
min_column_set = set(mins_transposed.loc[(mins_transposed[0] == 0)].index.tolist())
max_column_set = set(maxs_transposed.loc[(maxs_transposed[0] == 0)].index.tolist())

# Find and drop columns where both the minimum and maximum values are zero
zero_only = list(min_column_set.intersection(max_column_set))
spark_shap_df = spark_shap_df.drop(*zero_only)

# Extract the original column names from the remaining SHAP value columns
selected_columns = [column.split("SHAP_Values_")[-1] for column in spark_shap_df.columns]

# Add a unique ID column to the SHAP values DataFrame
spark_shap_df = spark_shap_df.withColumn("ID", F.monotonically_increasing_id())
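
The filtering idea, independent of Spark: a SHAP column whose minimum and maximum are both zero never influenced any prediction in the sample, so it can be dropped. A small sketch with made-up values (the column names are hypothetical):

```python
shap_columns = {
    "SHAP_Values_amount": [0.5, -1.5, 1.0],
    "SHAP_Values_risk": [-2.0, 1.0, -3.0],
    "SHAP_Values_unused": [0.0, 0.0, 0.0],
}

# A column is "zero only" when both its minimum and its maximum are zero.
zero_only = [
    name for name, values in shap_columns.items()
    if min(values) == 0 and max(values) == 0
]

kept = {name: values for name, values in shap_columns.items() if name not in zero_only}

# Recover the original feature names from the remaining SHAP columns.
selected_columns = [name.split("SHAP_Values_")[-1] for name in kept]

print(zero_only)
print(selected_columns)
```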

#### Retrieve additional semantics
In this step, we enrich the table metadata with additional business semantics by providing additional comments / descriptions per column.

> [IMPORTANT]
> This step is required as preparation for a later exercise, in which we will interpret the result set using an LLM!

1. Navigate to the Unity Catalog and find the prepared table `prepared_accounting_document` in your `<SCHEMA>`.
2. Right-click on the table to open the `Catalog Explorer`.
3. Go to the tab `Overview` to see the list of all columns.
4. Click on the `AI generate` button on the top right.
5. Confirm the dialog with the AI generated comments.

After successful generation of the column comments, proceed with the exercise.

In [0]:
column_comments = spark.sql("DESCRIBE TABLE prepared_accounting_document").filter(F.col("col_name").isin(selected_columns))

In [0]:
display(column_comments)

Replace the value `<PREPARED_TABLE_NAME>` with the appropriate table name from our preprocessed data. 


In [0]:
dataset_primary_keys = spark.read.table("<PREPARED_TABLE_NAME>").select("CompanyCode", "AccountingDocument", "FiscalYear", "AccountingDocumentItem", "NetDueDate", *selected_columns).sample(0.01, seed=42)

# Add a unique identifier column "ID" to the dataset
dataset_primary_keys = dataset_primary_keys.withColumn("ID", monotonically_increasing_id())
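
A caution on joining by this ID: `monotonically_increasing_id` produces values that are unique and increasing but not consecutive, because the partition index is encoded in the upper bits and the row position within the partition in the lower 33 bits. Two DataFrames therefore only receive matching IDs if they have the same row order and the same partitioning. The sketch below mimics that layout in plain Python (`simulated_ids` is a hypothetical helper, not a Spark API):

```python
def simulated_ids(partition_sizes):
    """Mimic the ID layout: (partition index << 33) + row position in the partition."""
    ids = []
    for partition_index, size in enumerate(partition_sizes):
        for position in range(size):
            ids.append((partition_index << 33) + position)
    return ids

# Five rows held in a single partition versus the same rows split across two partitions.
single_partition = simulated_ids([5])
two_partitions = simulated_ids([3, 2])

print(single_partition)
print(two_partitions)
# IDs that an inner join on "ID" would be able to match.
print(sorted(set(single_partition) & set(two_partitions)))
```

In this example only three of the five rows would find a partner. Comparing the row count after the join with the size of the sample is a cheap way to detect such a mismatch.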

Replace the value `<PRIMARY_KEYS>` with the variable that we defined before to display a preview of the primary keys.

In [0]:
display(<PRIMARY_KEYS>)

Set the `<JOIN_TYPE>` to the value `inner`.

In [0]:
joined_spark_shap_df = spark_shap_df.join(
    dataset_primary_keys, on="ID", how="<JOIN_TYPE>"
)
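
With `inner`, only IDs present in both DataFrames are kept; rows without a partner on the other side are dropped without any warning. A plain-Python sketch of that behaviour, with made-up rows:

```python
# Rows keyed by ID (made-up values).
shap_rows = {
    0: {"SHAP_Values_amount": 0.5},
    1: {"SHAP_Values_amount": -1.5},
    2: {"SHAP_Values_amount": 1.0},
}
key_rows = {
    0: {"AccountingDocument": "1900000001"},
    1: {"AccountingDocument": "1900000002"},
    3: {"AccountingDocument": "1900000004"},
}

# Inner join: keep an ID only if it exists on both sides, then merge the columns.
joined = {
    row_id: {**shap_rows[row_id], **key_rows[row_id]}
    for row_id in shap_rows
    if row_id in key_rows
}

print(sorted(joined))
print(joined[0])
```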

Next, select the following columns for the table by replacing the value `<COL_NAME>`:
- ID
- CompanyCode
- AccountingDocument
- FiscalYear
- AccountingDocumentItem
- NetDueDate



In [0]:
# Collect the column comments once instead of running one Spark query per column
comment_by_column = {row["col_name"]: row["comment"] for row in column_comments.collect()}

spark_shap_structure = (joined_spark_shap_df.select(
                F.col("<COL_NAME>"),
                F.col("<COL_NAME>"),
                F.col("<COL_NAME>"),
                F.col("<COL_NAME>"),
                F.col("<COL_NAME>"),
                F.col("<COL_NAME>"),
                F.array([
                   F.struct(F.lit(column).alias("column_name"),
                            F.col(f"SHAP_Values_{column}").alias("shap_value"),
                            # A column without a comment yields a null description
                            F.lit(comment_by_column.get(column)).cast("string").alias("column_description"),
                            F.col(column).cast("string").alias("column_value")
                            )
                             for column in selected_columns]).alias("shap_array")
                ))
display(spark_shap_structure)
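
For orientation, each element of `shap_array` pairs a feature with its SHAP value, its description, and its original value as a string. In plain Python, the array for a single row would be shaped roughly like this (all values and comments below are made up):

```python
selected_columns = ["amount", "risk"]

# One input row, its SHAP values, and the column comments (all hypothetical).
row = {"amount": 1200.5, "risk": "high"}
shap_row = {"SHAP_Values_amount": 0.5, "SHAP_Values_risk": -2.0}
comments = {"amount": "Invoice amount in document currency", "risk": None}

shap_array = [
    {
        "column_name": column,
        "shap_value": shap_row[f"SHAP_Values_{column}"],
        "column_description": comments.get(column),
        # Values are cast to string so all entries share one type.
        "column_value": str(row[column]),
    }
    for column in selected_columns
]

for entry in shap_array:
    print(entry)
```

Keeping name, value, description, and contribution together in one entry means a later consumer does not need to look up the table schema separately.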

In [0]:
spark_prediction_df = (spark.createDataFrame(prediction, schema=StructType([StructField("delay_prediction", DoubleType())]))
                       .select(monotonically_increasing_id().alias("ID"), F.round("delay_prediction").alias("delay_prediction")))
display(spark_prediction_df)

In [0]:
spark_shap_data = (spark_shap_structure
                   .join(spark_prediction_df, on="ID", how="inner")
                   .drop("ID")
                   .select("CompanyCode", "AccountingDocument", "FiscalYear", "AccountingDocumentItem", "shap_array", "delay_prediction", "NetDueDate"))
display(spark_shap_data)

## 5. Persist the result set into a table
We create a prediction table with a primary key constraint consisting of the columns CompanyCode, AccountingDocument, FiscalYear, and AccountingDocumentItem. As the name of the constraint has to be unique across the complete Databricks catalog, please replace the constant `<CONSTRAINT_NAME>` with an appropriate name for the constraint.

In the following code, set the parameters `<BOOL>` as follows to make sure we prepare the table correctly for sharing it back to BDC:
- enableChangeDataFeed = true
- enableDeletionVectors = false 

In [0]:
spark_shap_data.write.format("delta").\
    mode("overwrite").\
    option("delta.enableChangeDataFeed", "<BOOL>").\
    option("delta.enableDeletionVectors", "<BOOL>").\
    saveAsTable("delay_prediction_shap")

In [0]:
%sql
ALTER TABLE delay_prediction_shap SET TBLPROPERTIES ('delta.columnMapping.mode' = 'name');

ALTER TABLE delay_prediction_shap ALTER COLUMN CompanyCode SET NOT NULL;
ALTER TABLE delay_prediction_shap ALTER COLUMN AccountingDocument SET NOT NULL;
ALTER TABLE delay_prediction_shap ALTER COLUMN FiscalYear SET NOT NULL;
ALTER TABLE delay_prediction_shap ALTER COLUMN AccountingDocumentItem SET NOT NULL;
ALTER TABLE delay_prediction_shap ADD CONSTRAINT <CONSTRAINT_NAME> PRIMARY KEY (CompanyCode, AccountingDocument, FiscalYear, AccountingDocumentItem);

ALTER TABLE delay_prediction_shap SET TBLPROPERTIES (
  delta.enableChangeDataFeed = true,
  delta.enableDeletionVectors = false
);
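
Primary key constraints in Unity Catalog are informational and are not enforced on write, so it can be worth verifying that the key columns are actually unique before relying on them. A minimal sketch of such a check in plain Python, with made-up rows (`find_duplicate_keys` is a hypothetical helper):

```python
from collections import Counter

KEY_COLUMNS = ("CompanyCode", "AccountingDocument", "FiscalYear", "AccountingDocumentItem")

def find_duplicate_keys(rows, key_columns=KEY_COLUMNS):
    """Return every composite key that occurs more than once."""
    counts = Counter(tuple(row[c] for c in key_columns) for row in rows)
    return [key for key, n in counts.items() if n > 1]

rows = [
    {"CompanyCode": "1010", "AccountingDocument": "1900000001", "FiscalYear": "2024", "AccountingDocumentItem": "1"},
    {"CompanyCode": "1010", "AccountingDocument": "1900000001", "FiscalYear": "2024", "AccountingDocumentItem": "2"},
    {"CompanyCode": "1010", "AccountingDocument": "1900000001", "FiscalYear": "2024", "AccountingDocumentItem": "1"},
]

duplicates = find_duplicate_keys(rows)
print(duplicates)
```

On the real table, the same idea can be expressed as a group-by on the four key columns followed by a filter on counts greater than one.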