# Model Inference

This notebook demonstrates how the latest CreditKarma Scorer model is retrieved from the MLflow model registry and used to run inference.

In [1]:
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import pyspark
from pyspark.sql.functions import col
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta

from tqdm import tqdm
import itertools

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, fbeta_score, confusion_matrix, ConfusionMatrixDisplay
import numpy as np

from utils.data_processing_gold_table import build_feature_store

import mlflow
from mlflow.models import infer_signature
from mlflow.tracking import MlflowClient

In [2]:
# Create (or reuse) the Spark session for this notebook, running in local mode.
spark = (
    pyspark.sql.SparkSession.builder
    .appName("model-inference")
    .master("local[*]")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/06/21 13:20:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/06/21 13:20:07 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [82]:
# Run parameters: the registered model to score with and the snapshot to score.
model_name = "creditkarma-scorer"
snapshot_date_str = "2024-06-01"
# Parsed counterpart of snapshot_date_str, used for date comparisons on Spark columns.
snapshot_date = datetime.strptime(snapshot_date_str, "%Y-%m-%d").date()

# Load model artifact from registry

In [4]:
# Point MLflow at the tracking server used for the registry lookups in this notebook.
mlflow.set_tracking_uri(uri="http://mlflow:5001")

In [None]:
client = MlflowClient()
# Resolve the version behind the "champion" alias and read its training-date tag.
# Dedicated names are used so that `model_version` (a ModelVersion object in the
# other cells of this notebook) is never replaced by a plain string.
champion_version = client.get_model_version_by_alias(model_name, "champion")
champion_train_date = champion_version.tags['train_date']
print(f"Current deployed version: {champion_train_date}")

# Load the sklearn model that the "champion" alias currently points to.
model = mlflow.sklearn.load_model(model_uri=f"models:/{model_name}@champion")

Current deployed version: 2024-09-01


In [44]:
client = MlflowClient()
versions = client.search_model_versions(f"name='{model_name}'")

# Pick the registered version whose `train_date` tag equals the snapshot date.
model_version = next((v for v in versions if v.tags.get("train_date") == snapshot_date_str), None)
if model_version is None:
    # Fail with an explicit message instead of an AttributeError on `None.version`.
    raise LookupError(
        f"No registered version of '{model_name}' has tag train_date={snapshot_date_str}"
    )

# Load exactly that version of the sklearn model from the registry.
model_uri = f"models:/{model_name}/{model_version.version}"
model = mlflow.sklearn.load_model(model_uri=model_uri)

In [None]:
# Get the decision threshold: it is read from the `best_fb_threshold` parameter
# logged on the MLflow run that produced the selected model version.
run = client.get_run(model_version.run_id) 
best_threshold = float(run.data.params['best_fb_threshold'])

# Load feature store

In [83]:
def read_silver_table(table, silver_db, spark):
    """
    Load every partition found under a silver table's folder.

    Args:
        table: name of the table sub-folder inside `silver_db`.
        silver_db: root directory of the silver datamart.
        spark: Spark session used for reading.

    Returns:
        The DataFrame obtained by reading all folder entries as parquet.
    """
    table_dir = os.path.join(silver_db, table)
    partition_paths = []
    for entry in glob.glob(os.path.join(table_dir, '*')):
        partition_paths.append(os.path.join(table_dir, os.path.basename(entry)))
    reader = spark.read.option("header", "true")
    return reader.parquet(*partition_paths)

In [84]:
def read_gold_table(snapshot_date_str, table, gold_db, spark):
    """
    Read a gold table and keep only the rows of one snapshot date.

    Args:
        snapshot_date_str: snapshot date formatted as "YYYY-MM-DD".
        table: name (relative path) of the table folder inside `gold_db`.
        gold_db: root directory of the gold datamart.
        spark: Spark session used for reading.

    Returns:
        DataFrame restricted to rows whose `snapshot_date` equals the given date.
    """
    target_date = datetime.strptime(snapshot_date_str, "%Y-%m-%d").date()
    table_dir = os.path.join(gold_db, table)
    partition_paths = [
        os.path.join(table_dir, os.path.basename(entry))
        for entry in glob.glob(os.path.join(table_dir, '*'))
    ]
    # All partitions are read together with schema merging enabled.
    reader = spark.read.option("header", "true").option("mergeSchema", "true")
    all_rows = reader.parquet(*partition_paths)
    return all_rows.filter(col("snapshot_date") == target_date)


In [85]:
def read_online_feature_store(gold_db, spark, snapshot_date_str=None):
    """
    Read the online feature store and keep only the rows of one snapshot date.

    Args:
        gold_db: root directory of the gold datamart.
        spark: Spark session used for reading.
        snapshot_date_str: optional snapshot date formatted as "YYYY-MM-DD".
            When omitted, the notebook-level `snapshot_date` variable is used,
            which is how this helper behaved before the parameter existed.

    Returns:
        DataFrame read from `<gold_db>/feature_store/online.parquet`, restricted
        to rows whose `snapshot_date` equals the selected date.
    """
    if snapshot_date_str is None:
        # Backward-compatible path: depends on the notebook global.
        target_date = snapshot_date
    else:
        target_date = datetime.strptime(snapshot_date_str, "%Y-%m-%d").date()

    folder_path = os.path.join(gold_db, 'feature_store', 'online.parquet')
    df = spark.read.option("header", "true").option("mergeSchema", "true").parquet(folder_path)
    return df.filter(col("snapshot_date") == target_date)

In [86]:
def create_online_feature_store(snapshot_date_str, gold_db, silver_db, spark):
    """
    Materialise the online feature store for one snapshot date.

    The rows of the gold `feature_store` table for `snapshot_date_str` are used
    when they exist; otherwise the features are built from the silver tables via
    `build_feature_store`. The result is written (overwrite mode) to
    `<gold_db>/feature_store/online.parquet`.

    Args:
        snapshot_date_str: snapshot date formatted as "YYYY-MM-DD".
        gold_db: root directory of the gold datamart.
        silver_db: root directory of the silver datamart.
        spark: Spark session used for reading and writing.
    """
    print("Trying to retrieve records from gold table...")
    df_gold_online = read_gold_table(snapshot_date_str, 'feature_store', gold_db, spark)

    if df_gold_online.count() == 0:
        print("Building online feature store...")
        # Derive the date from the argument rather than from a notebook global,
        # so the silver tables are filtered on the date that was actually requested.
        snapshot_date = datetime.strptime(snapshot_date_str, "%Y-%m-%d").date()
        on_snapshot = col("snapshot_date") == snapshot_date

        df_attributes = read_silver_table('attributes', silver_db, spark).filter(on_snapshot)
        # Clickstream is passed on unfiltered, exactly as before.
        df_clickstream = read_silver_table('clickstream', silver_db, spark)
        df_financials = read_silver_table('financials', silver_db, spark).filter(on_snapshot)
        df_loan_type = read_silver_table('loan_type', silver_db, spark).filter(on_snapshot)

        # create online feature store
        df_gold_online = build_feature_store(df_attributes, df_financials, df_loan_type, df_clickstream)

    # Save into online feature store
    # NOTE(review): read_gold_table globs every entry under `feature_store`, which
    # would include `online.parquet` once it exists — confirm that overwriting a
    # path that is also part of the read is acceptable to Spark here.
    filename = 'online.parquet'
    feature_filepath = os.path.join(gold_db, 'feature_store', filename)
    df_gold_online.write.mode('overwrite').parquet(feature_filepath)

In [88]:
# Build (or refresh) the online feature store for the configured snapshot date.
create_online_feature_store(snapshot_date_str, 'datamart/gold', 'datamart/silver', spark)

Trying to retrieve records from gold table...


                                                                                

In [89]:
# Load the online feature store rows for the snapshot date and preview them.
df_spark = read_online_feature_store('datamart/gold', spark)
df_spark.show()

+-----------+-------------+----+-------------+---------------------+-----------------+---------------+-------------+-----------+-------------------+----------------------+--------------------+--------------------+----------------+------------------------+-------------------+-----------------------+---------------+------------------------+-------------+---------+-------------------+-------------+-------------+------------+----------------+-----------+-----------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+---------------------+-----------------+------------------------+-----------------+--------------------+--------------------+-----------------

## Data preprocessing to feed into model

Convert the Spark DataFrame into a pandas DataFrame.

In [77]:
# Collect the Spark rows into a pandas DataFrame, ordered by customer_id.
df_pd = df_spark.toPandas().sort_values(by='customer_id')

In [78]:
# Display the pandas frame for a visual check.
df_pd

Unnamed: 0,customer_id,snapshot_date,age,annual_income,monthly_inhand_salary,num_bank_accounts,num_credit_card,interest_rate,num_of_loan,delay_from_due_date,...,avg_fe_11,avg_fe_12,avg_fe_13,avg_fe_14,avg_fe_15,avg_fe_16,avg_fe_17,avg_fe_18,avg_fe_19,avg_fe_20
297,CUS_0x1038,2024-10-01,28.0,129473.156250,10959.429688,3.0,4.0,10.0,3.0,28.0,...,,,,,,,,,,
428,CUS_0x104f,2024-10-01,20.0,11336.834961,992.736267,4.0,6.0,14.0,6.0,30.0,...,,,,,,,,,,
350,CUS_0x10a9,2024-10-01,41.0,17603.265625,1324.945923,8.0,6.0,4.0,0.0,9.0,...,,,,,,,,,,
60,CUS_0x111b,2024-10-01,40.0,70277.218750,5633.435059,3.0,6.0,7.0,3.0,5.0,...,,,,,,,,,,
200,CUS_0x1123,2024-10-01,15.0,15264.980469,1266.081665,6.0,5.0,32.0,3.0,23.0,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
336,CUS_0xd50,2024-10-01,41.0,43771.851562,3600.654053,4.0,3.0,9.0,1.0,,...,,,,,,,,,,
436,CUS_0xde4,2024-10-01,30.0,75402.523438,6221.543457,2.0,4.0,3.0,3.0,0.0,...,,,,,,,,,,
358,CUS_0xe16,2024-10-01,18.0,37976.679688,2997.723389,9.0,9.0,18.0,7.0,56.0,...,,,,,,,,,,
197,CUS_0xe7a,2024-10-01,32.0,7711.970215,395.664154,10.0,5.0,29.0,7.0,57.0,...,,,,,,,,,,


Convert the feature columns into a NumPy array.

In [79]:
# Drop the identifier columns; every remaining column is passed to the model.
# NOTE(review): the column order presumably has to match the order used at
# training time — confirm against the training pipeline.
df_arr = df_pd.drop(columns=['customer_id', 'snapshot_date']).values

# Make Inference

(Note: the preprocessing pipeline (normalization and imputation) is already saved as part of the model artifact, so no separate preprocessing step is needed here.)

In [80]:
# Display the loaded model object.
model

In [81]:
# Keep only the identifier columns. Selecting first and copying afterwards avoids
# duplicating every feature column just to discard it.
df_pred_pd = df_pd[['customer_id', 'snapshot_date']].copy()

# Class probabilities: column 0 is used as "no default", column 1 as "default".
pred = model.predict_proba(df_arr)

# Label the predictions as "<model_type>_<train_date>" from the registry tags.
model_version_name = model_version.tags.get('model_type') + '_' + model_version.tags.get('train_date')
df_pred_pd['model_version'] = model_version_name

# Flag a default when the class-1 probability exceeds the stored threshold.
# NOTE(review): this comparison yields booleans — confirm that the `default`
# column type matches partitions already written to the gold table.
df_pred_pd['default'] = pred[:, 1] > best_threshold
df_pred_pd['probability_no_default'] = pred[:, 0]
df_pred_pd['probability_default'] = pred[:, 1]

df_pred_pd

Unnamed: 0,customer_id,snapshot_date,model_version,default,probability_no_default,probability_default
297,CUS_0x1038,2024-10-01,rf_2024-12-01,0,0.580000,0.420000
428,CUS_0x104f,2024-10-01,rf_2024-12-01,0,0.580000,0.420000
350,CUS_0x10a9,2024-10-01,rf_2024-12-01,0,0.553056,0.446944
60,CUS_0x111b,2024-10-01,rf_2024-12-01,0,0.660000,0.340000
200,CUS_0x1123,2024-10-01,rf_2024-12-01,1,0.320000,0.680000
...,...,...,...,...,...,...
336,CUS_0xd50,2024-10-01,rf_2024-12-01,0,0.577500,0.422500
436,CUS_0xde4,2024-10-01,rf_2024-12-01,0,0.537500,0.462500
358,CUS_0xe16,2024-10-01,rf_2024-12-01,0,0.580000,0.420000
197,CUS_0xe7a,2024-10-01,rf_2024-12-01,1,0.460000,0.540000


# Save to gold table

In [60]:
# One folder per model and model version; one parquet partition per snapshot date.
gold_directory = f"datamart/gold/model_predictions/{model_name}_{model_version_name}"

# exist_ok=True creates the folder when needed without a separate existence check.
os.makedirs(gold_directory, exist_ok=True)

df_pred = spark.createDataFrame(df_pred_pd)
# Partition file name is the snapshot date with underscores, e.g. 2024_06_01.parquet.
partition_name = snapshot_date_str.replace('-','_') + '.parquet'
filepath = os.path.join(gold_directory, partition_name)
# Overwrite so that re-running the notebook for the same date replaces the partition.
df_pred.write.mode("overwrite").parquet(filepath)

In [61]:
# Check: read the predictions back from the gold table for the snapshot date
# and preview them.
spark_pred = read_gold_table(snapshot_date_str, f'model_predictions/{model_name}_{model_version_name}', 'datamart/gold', spark)
spark_pred.show()

+-----------+-------------+-------------+-------+----------------------+-------------------+
|customer_id|snapshot_date|model_version|default|probability_no_default|probability_default|
+-----------+-------------+-------------+-------+----------------------+-------------------+
| CUS_0xbe46|   2024-12-01|rf_2024-12-01|      0|                  0.62|               0.38|
| CUS_0xbe4d|   2024-12-01|rf_2024-12-01|      0|                  0.56|               0.44|
| CUS_0xbe62|   2024-12-01|rf_2024-12-01|      0|                0.5375|             0.4625|
| CUS_0xbe9a|   2024-12-01|rf_2024-12-01|      0|    0.6130555555555556| 0.3869444444444444|
| CUS_0xbebf|   2024-12-01|rf_2024-12-01|      0|                0.6175|             0.3825|
| CUS_0xbf3a|   2024-12-01|rf_2024-12-01|      0|                  0.68|               0.32|
| CUS_0xbf60|   2024-12-01|rf_2024-12-01|      1|                  0.46|               0.54|
| CUS_0xbf7f|   2024-12-01|rf_2024-12-01|      0|                0.617

In [62]:
# Number of prediction rows read back for this snapshot date.
spark_pred.count()

515