In [3]:
# pip install model_inference

In [1]:
import os
import glob
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import numpy as np
import random
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
import pprint
import pyspark
import pyspark.sql.functions as F

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import StringType, IntegerType, FloatType, DateType

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import make_scorer, f1_score, roc_auc_score
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

import joblib
from joblib import load
# import model_inference

In [5]:
# Build a .py script that takes a snapshot date, loads a model artefact and make an inference and save to datamart

## set up pyspark session

In [2]:
# Initialize SparkSession
spark = pyspark.sql.SparkSession.builder \
    .appName("gold_model_prediction") \
    .master("local[*]") \
    .getOrCreate()

# Set log level to ERROR to hide warnings
spark.sparkContext.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/06/24 06:45:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/06/24 06:45:07 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


## set up config

In [3]:
snapshot_date_str = "2024-07-01"
model_name = "log_reg_churn_model.joblib"

In [4]:
config = {}
config["snapshot_date_str"] = snapshot_date_str
config["snapshot_date"] = datetime.strptime(config["snapshot_date_str"], "%Y-%m-%d")
config["model_name"] = model_name
config["model_directory"] = "model_artifacts/"
config["model_artefact_filepath"] = config["model_directory"] + config["model_name"]

pprint.pprint(config)

{'model_artefact_filepath': 'model_artifacts/log_reg_churn_model.joblib',
 'model_directory': 'model_artifacts/',
 'model_name': 'log_reg_churn_model.joblib',
 'snapshot_date': datetime.datetime(2024, 7, 1, 0, 0),
 'snapshot_date_str': '2024-07-01'}


## load model from model artifacts

In [5]:
# Load the model from the .joblib file
with open(config["model_artefact_filepath"], 'rb') as file:
    model_artefact = joblib.load(file)

print("Model loaded successfully! " + config["model_artefact_filepath"])

Model loaded successfully! model_artifacts/log_reg_churn_model.joblib


## load feature store

In [6]:
gold_feature_store_path = "datamart/gold/feature_store/"

config["train_test_start_date_str"] = "2024-04-01"
config["oot_end_date_str"] = "2024-07-01"
config["train_test_start_date"] = datetime.strptime(config["train_test_start_date_str"], "%Y-%m-%d")
config["oot_end_date"] = datetime.strptime(config["oot_end_date_str"], "%Y-%m-%d")

available_files = os.listdir(gold_feature_store_path)
target_dates = ["2024-04-01", "2024-05-01", "2024-06-01", "2024-07-01"]

target_files = [
    os.path.join(gold_feature_store_path, f"gold_feature_store_{d}.parquet")
    for d in target_dates if f"gold_feature_store_{d}.parquet" in available_files
]

if not target_files:
    raise FileNotFoundError("No matching Parquet files found for the given date range.")

features_sdf = spark.read.parquet(*target_files)
features_sdf = features_sdf.filter(
    (col("snapshot_date") >= config["train_test_start_date"]) &
    (col("snapshot_date") <= config["oot_end_date"])
)

print("Extracted features_sdf:", features_sdf.count(), config["train_test_start_date"], config["oot_end_date"])

# Extract IDs before dropping them
id_cols_pdf = features_sdf.select("customerID", "snapshot_date").toPandas()

features_pdf = features_sdf.toPandas()
features_pdf.drop(columns=[c for c in ["customerID", "snapshot_date"] if c in features_pdf.columns], inplace=True)

                                                                                

Extracted features_sdf: 25124 2024-04-01 00:00:00 2024-07-01 00:00:00


## preprocess data for modeling

In [11]:
# print(type(model_artefact))

In [7]:
# Drop identifiers 
features_pdf.drop(columns=[c for c in ["customerID", "snapshot_date"] if c in features_pdf.columns], inplace=True)

# Recreate tenure groups
def create_tenure_groups(df):
    df = df.copy()
    df["tenure_group"] = pd.cut(
        df["tenure"],
        bins=[0, 12, 24, 36, 48, 60, 72, np.inf],
        labels=["0-1yr", "1-2yr", "2-3yr", "3-4yr", "4-5yr", "5-6yr", "6+yr"]
    )
    return df

features_pdf = create_tenure_groups(features_pdf)

In [8]:
# Handle missing values
numerical_cols = features_pdf.select_dtypes(include=["float64", "int64"]).columns
categorical_cols = features_pdf.select_dtypes(include=["object", "category"]).columns

for col in numerical_cols:
    median_val = features_pdf[col].median()
    features_pdf[col] = features_pdf[col].fillna(median_val)

for col in categorical_cols:
    mode_val = features_pdf[col].mode()[0] if not features_pdf[col].mode().empty else "Unknown"
    features_pdf[col] = features_pdf[col].fillna(mode_val)

## model prediction inference

In [9]:
# Load model and preprocessor 
pipeline = load("model_artifacts/model_pipeline.pkl")
preprocessor = pipeline["preprocessor"]
model = pipeline["model"]

In [10]:
# Transform and predict
X_inference = preprocessor.transform(features_pdf)
y_inference = model.predict(X_inference)

In [11]:
y_inference_pdf = id_cols_pdf.copy()
y_inference_pdf["model_name"] = config["model_name"]
y_inference_pdf["model_predictions"] = y_inference

print(y_inference_pdf.head())
print("Number of predictions:", len(y_inference_pdf))

   customerID snapshot_date                  model_name  model_predictions
0  7590-VHVEG    2024-06-01  log_reg_churn_model.joblib                  0
1  5575-GNVDE    2024-06-01  log_reg_churn_model.joblib                  0
2  3668-QPYBK    2024-06-01  log_reg_churn_model.joblib                  0
3  7795-CFOCW    2024-06-01  log_reg_churn_model.joblib                  0
4  9237-HQITU    2024-06-01  log_reg_churn_model.joblib                  0
Number of predictions: 25124


## save model inference to datamartgold table

In [12]:
snapshot_range_str = f"{config['train_test_start_date_str']}_to_{config['oot_end_date_str']}".replace("-", "_")
gold_directory = f"datamart/gold/model_predictions/{config['model_name'][:-4]}/"

if not os.path.exists(gold_directory):
    os.makedirs(gold_directory)

partition_name = f"{config['model_name'][:-4]}_predictions_{snapshot_range_str}.parquet"
filepath = os.path.join(gold_directory, partition_name)

spark.createDataFrame(y_inference_pdf).write.mode("overwrite").parquet(filepath)
print("Saved inference results to:", filepath)

                                                                                

Saved inference results to: datamart/gold/model_predictions/log_reg_churn_model.jo/log_reg_churn_model.jo_predictions_2024_04_01_to_2024_07_01.parquet


In [13]:
# Save as CSV
csv_filepath = filepath.replace(".parquet", ".csv")  # Change the extension to .csv

# Save the predictions as CSV using pandas 
y_inference_pdf.to_csv(csv_filepath, index=False)
print("Saved inference results to CSV:", csv_filepath)

Saved inference results to CSV: datamart/gold/model_predictions/log_reg_churn_model.jo/log_reg_churn_model.jo_predictions_2024_04_01_to_2024_07_01.csv


## Check datamart

In [14]:
# Initialize Spark session 
spark = SparkSession.builder \
    .appName("lr_model_checker") \
    .master("local[*]") \
    .getOrCreate()

spark.sparkContext.setLogLevel("ERROR")  # Clean output

# Define model folder and prediction file name pattern 
model_name = "log_reg_churn_model.jo"
snapshot_range = "2024_04_01_to_2024_07_01"

folder_path = f"datamart/gold/model_predictions/{model_name}/"

# Find matching Parquet files
files_list = [
    f for f in glob.glob(os.path.join(folder_path, "*.parquet"))
    if snapshot_range in os.path.basename(f)
]

if not files_list:
    print("No matching prediction files found for:", snapshot_range)
    print("Available files in folder:", os.listdir(folder_path))
else:
    # Load data into Spark DataFrame and inspect
    df = spark.read.option("header", "true").parquet(*files_list)
    print(f"Loaded {df.count()} rows from:")
    for f in files_list:
        print(" -", f)
    
    df.show(10)
    df.printSchema()

Loaded 25124 rows from:
 - datamart/gold/model_predictions/log_reg_churn_model.jo/log_reg_churn_model.jo_predictions_2024_04_01_to_2024_07_01.parquet
+----------+-------------+--------------------+-----------------+
|customerID|snapshot_date|          model_name|model_predictions|
+----------+-------------+--------------------+-----------------+
|4526-ZJJTM|   2024-07-01|log_reg_churn_mod...|                0|
|8384-FZBJK|   2024-07-01|log_reg_churn_mod...|                0|
|3750-RNQKR|   2024-07-01|log_reg_churn_mod...|                0|
|0962-CQPWQ|   2024-07-01|log_reg_churn_mod...|                0|
|3096-YXENJ|   2024-07-01|log_reg_churn_mod...|                0|
|1265-BCFEO|   2024-07-01|log_reg_churn_mod...|                0|
|5837-LXSDN|   2024-07-01|log_reg_churn_mod...|                0|
|5945-AZYHT|   2024-07-01|log_reg_churn_mod...|                0|
|8325-QRPZR|   2024-07-01|log_reg_churn_mod...|                0|
|6384-VMJHP|   2024-07-01|log_reg_churn_mod...|           