In [1]:
import os
import glob
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import numpy as np
import random
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
import pprint
import pyspark
import pyspark.sql.functions as F

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import StringType, IntegerType, FloatType, DateType

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import make_scorer, f1_score, roc_auc_score
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

import joblib
from joblib import load

In [5]:
# Build a .py script that takes a snapshot date, loads a model artefact and make an inference and save to datamart

## set up pyspark session

In [2]:
# Initialize SparkSession
spark = pyspark.sql.SparkSession.builder \
    .appName("gold_model_prediction") \
    .master("local[*]") \
    .getOrCreate()

# Set log level to ERROR to hide warnings
spark.sparkContext.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/06/25 06:44:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/06/25 06:44:03 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


## set up config

In [3]:
snapshot_date_str = "2024-06-01"
model_name = "log_reg_churn_model.joblib"

In [4]:
config = {}
config["snapshot_date_str"] = snapshot_date_str
config["snapshot_date"] = datetime.strptime(config["snapshot_date_str"], "%Y-%m-%d")
config["model_name"] = model_name
config["model_directory"] = "model_artifacts/"
config["model_artefact_filepath"] = config["model_directory"] + config["model_name"]

pprint.pprint(config)

{'model_artefact_filepath': 'model_artifacts/log_reg_churn_model.joblib',
 'model_directory': 'model_artifacts/',
 'model_name': 'log_reg_churn_model.joblib',
 'snapshot_date': datetime.datetime(2024, 6, 1, 0, 0),
 'snapshot_date_str': '2024-06-01'}


## load model from model artifacts

In [5]:
# Load the model from the .joblib file
with open(config["model_artefact_filepath"], 'rb') as file:
    model_artefact = joblib.load(file)

print("Model loaded successfully! " + config["model_artefact_filepath"])

Model loaded successfully! model_artifacts/log_reg_churn_model.joblib


In [6]:
# Load the preprocessor from the model_artifacts directory
preprocessor_filepath = "model_artifacts/preprocessor.joblib"
with open(preprocessor_filepath, 'rb') as file:
    preprocessor = joblib.load(file)

print("Scaler loaded successfully!")

Scaler loaded successfully!


## load feature store

In [7]:
features_store_sdf = spark.read.parquet("datamart/gold/feature_store/gold_feature_store_2024-06-01.parquet")

# extract feature store
features_sdf = features_store_sdf.filter((col("snapshot_date") == config["snapshot_date"]))
print("extracted features_sdf", features_sdf.count(), config["snapshot_date"])

# Extract IDs before dropping them
id_cols_pdf = features_sdf.select("customerID", "snapshot_date").toPandas()

features_pdf = features_sdf.toPandas()
features_pdf.drop(columns=[c for c in ["customerID", "snapshot_date"] if c in features_pdf.columns], inplace=True)

                                                                                

extracted features_sdf 6948 2024-06-01 00:00:00


## preprocess data for modeling

In [8]:
# Drop identifiers 
features_pdf.drop(columns=[c for c in ["customerID", "snapshot_date"] if c in features_pdf.columns], inplace=True)

# Recreate tenure groups
def create_tenure_groups(df):
    df = df.copy()
    df["tenure_group"] = pd.cut(
        df["tenure"],
        bins=[0, 12, 24, 36, 48, 60, 72, np.inf],
        labels=["0-1yr", "1-2yr", "2-3yr", "3-4yr", "4-5yr", "5-6yr", "6+yr"]
    )
    return df

features_pdf = create_tenure_groups(features_pdf)

In [9]:
# Handle missing values
numerical_cols = features_pdf.select_dtypes(include=["float64", "int64"]).columns
categorical_cols = features_pdf.select_dtypes(include=["object", "category"]).columns

for col in numerical_cols:
    median_val = features_pdf[col].median()
    features_pdf[col] = features_pdf[col].fillna(median_val)

for col in categorical_cols:
    mode_val = features_pdf[col].mode()[0] if not features_pdf[col].mode().empty else "Unknown"
    features_pdf[col] = features_pdf[col].fillna(mode_val)

## model prediction inference

In [10]:
# Load pipeline
import pickle

with open('model_artifacts/model_pipeline.pkl', 'rb') as f:
    pipeline = pickle.load(f)

preprocessor = pipeline['preprocessor']
model = pipeline['model']

In [11]:
# Transform and predict
X_inference = preprocessor.transform(features_pdf)
y_inference = model.predict(X_inference)

In [12]:
y_inference_pdf = id_cols_pdf.copy()
y_inference_pdf["model_name"] = config["model_name"]
y_inference_pdf["model_predictions"] = y_inference

print(y_inference_pdf.head())
print("Number of predictions:", len(y_inference_pdf))

   customerID snapshot_date                  model_name  model_predictions
0  7590-VHVEG    2024-06-01  log_reg_churn_model.joblib                  0
1  5575-GNVDE    2024-06-01  log_reg_churn_model.joblib                  0
2  3668-QPYBK    2024-06-01  log_reg_churn_model.joblib                  0
3  7795-CFOCW    2024-06-01  log_reg_churn_model.joblib                  0
4  9237-HQITU    2024-06-01  log_reg_churn_model.joblib                  0
Number of predictions: 6948


## save model inference to datamartgold table

In [13]:
gold_directory = f"datamart/gold/model_predictions/{config['model_name'][:-4]}/"

if not os.path.exists(gold_directory):
    os.makedirs(gold_directory)

partition_name = f"{config['model_name'][:-4]}_predictions_{config['snapshot_date_str'].replace('-', '_')}.parquet"
filepath = os.path.join(gold_directory, partition_name)

spark.createDataFrame(y_inference_pdf).write.mode("overwrite").parquet(filepath)
print("Saved inference results to:", filepath)

                                                                                

Saved inference results to: datamart/gold/model_predictions/log_reg_churn_model.jo/log_reg_churn_model.jo_predictions_2024_06_01.parquet


In [14]:
# Save as CSV
csv_filepath = filepath.replace(".parquet", ".csv")  # Change the extension to .csv

# Save the predictions as CSV using pandas 
y_inference_pdf.to_csv(csv_filepath, index=False)
print("Saved inference results to CSV:", csv_filepath)

Saved inference results to CSV: datamart/gold/model_predictions/log_reg_churn_model.jo/log_reg_churn_model.jo_predictions_2024_06_01.csv


## Check datamart

In [15]:
# Initialize Spark session 
spark = SparkSession.builder \
    .appName("lr_model_checker") \
    .master("local[*]") \
    .getOrCreate()

spark.sparkContext.setLogLevel("ERROR")  # Clean output

In [16]:
folder_path = "datamart/gold/model_predictions/"

# Recursively list all files in the subfolders of the folder_path
files_list = glob.glob(os.path.join(folder_path, '**', '*'), recursive=True)

# Filter out only Parquet and CSV files
parquet_files = [f for f in files_list if f.endswith(".parquet")]
csv_files = [f for f in files_list if f.endswith(".csv")]

# Read the Parquet files if they exist
if parquet_files:
    df = spark.read.option("header", "true").parquet(*parquet_files)
    print("Read Parquet files")
elif csv_files:
    # Read the CSV files if they exist
    df = spark.read.option("header", "true").csv(*csv_files)
    print("Read CSV files")
else:
    print("No valid Parquet or CSV files found.")

# Show row count and schema of the DataFrame
print("row_count:", df.count())
df.show(5) 

Read Parquet files


                                                                                

row_count: 78040
+----------+-------------+--------------------+-----------------+
|customerID|snapshot_date|          model_name|model_predictions|
+----------+-------------+--------------------+-----------------+
|4526-ZJJTM|   2024-07-01|log_reg_churn_mod...|                0|
|8384-FZBJK|   2024-07-01|log_reg_churn_mod...|                0|
|3750-RNQKR|   2024-07-01|log_reg_churn_mod...|                0|
|0962-CQPWQ|   2024-07-01|log_reg_churn_mod...|                0|
|3096-YXENJ|   2024-07-01|log_reg_churn_mod...|                0|
+----------+-------------+--------------------+-----------------+
only showing top 5 rows

