# Model Inference

This notebook demonstrates how the current champion CreditKarma Scorer model is retrieved from the MLflow model registry and used to run inference on a feature-store snapshot.

In [36]:
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import pyspark
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta

from tqdm import tqdm
import itertools

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, fbeta_score, confusion_matrix, ConfusionMatrixDisplay
import numpy as np

import pickle

import mlflow
from mlflow.models import infer_signature
from mlflow.tracking import MlflowClient

In [37]:
# Create (or reuse) the Spark session for this notebook, running in local mode.
spark = (
    pyspark.sql.SparkSession.builder
    .appName("model-inference")
    .master("local[*]")
    .getOrCreate()
)

In [38]:
# Snapshot partition (YYYY-MM-DD) of the gold feature store to score.
snapshot_date_str = "2024-10-01"
# Registered model name in the MLflow registry. Kept identical to the value the
# model-loading cell assigns, so the name cannot silently differ between cells.
model_name = "creditkarma-scorer"

# Load model artifact from registry

In [None]:
# Point MLflow at the tracking server, then load the registered model that
# currently carries the "champion" alias.
mlflow.set_tracking_uri(uri="http://mlflow:5001")
model_name = "creditkarma-scorer"
champion_uri = f"models:/{model_name}@champion"
model = mlflow.sklearn.load_model(model_uri=champion_uri)

In [51]:
# Look up the champion's registry entry and read the value of its "train_date"
# tag; that value is what the rest of the notebook uses as the model version.
client = MlflowClient()
champion_entry = client.get_model_version_by_alias(model_name, "champion")
model_version = champion_entry.tags['train_date']
print(f"Current deployed version: {model_version}")

Current deployed version: 2024-09-01


# Load feature store

In [40]:
def read_gold_table(snapshot_date_str, table, gold_db, spark):
    """
    Read one snapshot partition of a gold table as a Spark DataFrame.

    The partition is looked up at ``<gold_db>/<table>/<YYYY_MM_DD>.parquet``,
    i.e. the snapshot date with dashes replaced by underscores.

    Args:
        snapshot_date_str: Snapshot date string, e.g. ``"2024-10-01"``.
        table: Table folder below ``gold_db`` (may include sub-folders).
        gold_db: Root folder of the gold datamart.
        spark: Spark session used to read the data.

    Returns:
        The DataFrame produced by ``spark.read...parquet`` for that partition.
    """
    partition_name = snapshot_date_str.replace("-", "_") + ".parquet"
    folder_path = os.path.join(gold_db, table, partition_name)
    # "header" is a CSV reader option and has no effect on Parquet, so only
    # mergeSchema is passed to the reader.
    return spark.read.option("mergeSchema", "true").parquet(folder_path)


In [41]:
# Load the feature-store partition for the configured snapshot date and
# print a sample of its rows.
df_spark = read_gold_table(
    snapshot_date_str=snapshot_date_str,
    table='feature_store',
    gold_db='datamart/gold',
    spark=spark,
)
df_spark.show()

+-----------+-------------+----+-------------+---------------------+-----------------+---------------+-------------+-----------+-------------------+----------------------+--------------------+--------------------+----------------+------------------------+-------------------+-----------------------+---------------+------------------------+-------------+---------+-------------------+-------------+-------------+------------+----------------+-----------+-----------------------+--------------------+------------------+-----------------------+--------------------+---------------------+---------------------+-----------------+-------------------+-----------------+------------------+------------------------+--------------------+-------------------+-----------------+-------------------+-------------------------+------------------------+------------------------+-------------------+---------------+--------------+---------------------------+----------------------------+-----------------------------+

## Data preprocessing to feed into model

Convert the Spark DataFrame into a pandas DataFrame.

In [42]:
# Collect the Spark DataFrame into pandas and order the rows by customer id.
df_pd = (
    df_spark
    .toPandas()
    .sort_values(by='customer_id')
)

In [43]:
# Preview the pandas feature frame.
df_pd

Unnamed: 0,customer_id,snapshot_date,age,annual_income,monthly_inhand_salary,num_bank_accounts,num_credit_card,interest_rate,num_of_loan,delay_from_due_date,...,avg_fe_11,avg_fe_12,avg_fe_13,avg_fe_14,avg_fe_15,avg_fe_16,avg_fe_17,avg_fe_18,avg_fe_19,avg_fe_20
214,CUS_0x102e,2024-04-01,26.0,50807.441406,4197.953125,8.0,4.0,11.0,4.0,12.0,...,87.3125,96.8750,98.1875,61.6250,145.8750,114.0000,74.9375,101.8125,132.3750,112.1875
21,CUS_0x109d,2024-04-01,45.0,49140.871094,4377.072266,10.0,6.0,27.0,6.0,35.0,...,92.2500,58.6875,117.6250,87.0625,65.3125,62.8750,93.1875,148.1875,76.0000,91.3125
38,CUS_0x112e,2024-04-01,32.0,20162.259766,1872.188354,8.0,4.0,11.0,4.0,13.0,...,109.0625,96.5000,65.7500,117.3125,64.6875,132.8125,81.0625,115.4375,117.3125,123.1875
24,CUS_0x1183,2024-04-01,22.0,134393.734375,10944.477539,1.0,3.0,12.0,2.0,13.0,...,39.1250,91.0625,114.0625,155.5000,105.2500,89.3125,122.0625,53.5000,120.8750,81.3750
127,CUS_0x1220,2024-04-01,31.0,128494.523438,10580.876953,4.0,4.0,9.0,2.0,0.0,...,146.8750,136.2500,114.2500,77.1875,87.3750,96.6875,90.5000,123.4375,109.9375,130.6875
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
465,CUS_0xee1,2024-04-01,34.0,17116.550781,1722.379150,7.0,8.0,17.0,7.0,26.0,...,124.5000,152.0625,102.2500,126.2500,120.0625,75.1875,86.1250,97.5000,76.0000,115.5000
393,CUS_0xf42,2024-04-01,45.0,39225.320312,3265.776611,8.0,5.0,23.0,2.0,19.0,...,112.0000,105.6250,89.6250,95.9375,87.1250,108.0625,132.5625,132.6875,135.9375,116.0000
183,CUS_0xf64,2024-04-01,32.0,56125.500000,4875.125000,8.0,3.0,18.0,2.0,30.0,...,115.4375,109.4375,120.3750,118.5000,106.8125,115.8125,72.1875,72.1250,92.0000,76.0000
174,CUS_0xf8b,2024-04-01,40.0,43922.320312,3869.193359,2.0,6.0,6.0,3.0,0.0,...,120.5000,110.0625,86.0000,22.1875,82.0625,105.9375,117.5000,98.3125,123.3750,110.3125


Convert the features into a NumPy array.

In [44]:
# Drop the identifier columns and pass the remaining columns to the model as a
# NumPy array. to_numpy() is the accessor pandas recommends over .values.
id_columns = ['customer_id', 'snapshot_date']
df_arr = df_pd.drop(columns=id_columns).to_numpy()

# Make Inference

(Note: the preprocessing pipeline (normalization and imputation) is saved as part of the model artifact, so the feature array is passed to the model as-is.)

In [45]:
# Display the loaded model object.
model

In [53]:
# Start the output frame from the identifier columns only. Selecting before
# copying avoids duplicating every feature column just to discard it.
df_pred_pd = df_pd[['customer_id', 'snapshot_date']].copy()

# Column 0 of the probabilities is stored as "no default" and column 1 as
# "default" (assumes the model's class order is [0, 1] — TODO confirm).
pred = model.predict_proba(df_arr)
df_pred_pd['model_version'] = model_version
df_pred_pd['default'] = np.argmax(pred, axis=1)  # column index with the highest probability
df_pred_pd['probability_no_default'] = pred[:, 0]
df_pred_pd['probability_default'] = pred[:, 1]

df_pred_pd

Unnamed: 0,customer_id,snapshot_date,model_version,default,probability_no_default,probability_default
214,CUS_0x102e,2024-04-01,2024-09-01,0,1.000000,7.116696e-08
21,CUS_0x109d,2024-04-01,2024-09-01,0,1.000000,1.470068e-25
38,CUS_0x112e,2024-04-01,2024-09-01,0,1.000000,2.262133e-16
24,CUS_0x1183,2024-04-01,2024-09-01,0,0.957406,4.259414e-02
127,CUS_0x1220,2024-04-01,2024-09-01,0,1.000000,7.245425e-25
...,...,...,...,...,...,...
465,CUS_0xee1,2024-04-01,2024-09-01,0,1.000000,7.688612e-20
393,CUS_0xf42,2024-04-01,2024-09-01,0,1.000000,6.343148e-16
183,CUS_0xf64,2024-04-01,2024-09-01,0,1.000000,1.670092e-21
174,CUS_0xf8b,2024-04-01,2024-09-01,0,1.000000,3.926695e-39


# Save to gold table

In [54]:
# Predictions are stored in one folder per model name and model version, with
# one parquet partition per snapshot date.
gold_directory = f"datamart/gold/model_predictions/{model_name}_{model_version}"

# exist_ok=True creates the folder when it is missing and does nothing when it
# already exists, without a separate check-then-create step.
os.makedirs(gold_directory, exist_ok=True)

df_pred = spark.createDataFrame(df_pred_pd)
partition_name = snapshot_date_str.replace('-','_') + '.parquet'
filepath = os.path.join(gold_directory, partition_name)
# "overwrite" replaces an existing partition for the same snapshot date.
df_pred.write.mode("overwrite").parquet(filepath)

25/06/15 07:19:42 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
25/06/15 07:19:42 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
25/06/15 07:19:42 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 76.00% for 10 writers
25/06/15 07:19:42 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 69.09% for 11 writers
25/06/15 07:19:42 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 63.33% for 12 writers
25/06/15 07:19:42 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 58.46% for 13 writers
25/06/15 07:19:42 WARN MemoryManager: Total allocation exceeds 95.

In [55]:
# Check: read the prediction partition back from the gold datamart and print a
# sample of its rows.
predictions_table = f'model_predictions/{model_name}_{model_version}'
spark_pred = read_gold_table(snapshot_date_str, predictions_table, 'datamart/gold', spark)
spark_pred.show()

[Stage 20:>                                                         (0 + 1) / 1]

+-----------+-------------+-------------+-------+----------------------+--------------------+
|customer_id|snapshot_date|model_version|default|probability_no_default| probability_default|
+-----------+-------------+-------------+-------+----------------------+--------------------+
| CUS_0x41ac|   2024-04-01|   2024-09-01|      0|    0.9999999994244905|5.755094975012801...|
| CUS_0x41be|   2024-04-01|   2024-09-01|      0|                   1.0|5.15441866524182E-25|
| CUS_0x41d3|   2024-04-01|   2024-09-01|      0|                   1.0|6.209628702583512...|
| CUS_0x41eb|   2024-04-01|   2024-09-01|      0|    0.9999999999918484|8.151545762567592...|
| CUS_0x4211|   2024-04-01|   2024-09-01|      0|                   1.0|3.260613491948635...|
| CUS_0x4250|   2024-04-01|   2024-09-01|      0|    0.9854419960486057|0.014558003951394284|
| CUS_0x42ec|   2024-04-01|   2024-09-01|      0|    0.9985190151013671|0.001480984898632...|
| CUS_0x432b|   2024-04-01|   2024-09-01|      0|    0.99999

                                                                                

In [56]:
# Number of prediction rows read back from the gold table.
spark_pred.count()

                                                                                

513