In [None]:
import pandas as pd
import numpy as np
import shap
from snowflake.snowpark.functions import col, when
from snowflake.ml.modeling.ensemble import RandomForestClassifier
from snowflake.ml.modeling.metrics import accuracy_score

from sklearn.model_selection import train_test_split
from sklearn.model_selection import train_test_split
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.types import FloatType, DateType
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col, when, regexp_replace
from datetime import datetime
from snowflake.snowpark.functions import lit, current_date, current_timestamp
from snowflake.snowpark.types import IntegerType
from snowflake.ml.feature_store import (
    FeatureStore,
    FeatureView,
    Entity,
    CreationMode
)
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    average_precision_score
)
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.registry import Registry
from snowflake.ml.model import model_signature
# from snowflake.ml.modeling.model_selection import train_test_split
from datetime import date

In [None]:
def preprocess_after_feature_store(df):
    """
    Prepare the feature-store output for model training.

    Steps:
    - Recode CHURN to an integer label: 1 where the value equals "Yes",
      0 otherwise (NULL also becomes 0).
    - For each categorical column, add one 0/1 indicator column per distinct
      non-NULL value, named "<COLUMN>_<sanitized value>".
    - Drop the original categorical columns.

    Args:
        df: Snowpark DataFrame containing CHURN and every column listed in
            ``onehot_columns`` below.

    Returns:
        A new Snowpark DataFrame with the binary label and indicator columns.
    """
    # Convert target to binary label
    df = df.with_column("CHURN", when(col("CHURN") == "Yes", 1).otherwise(0))

    # Categorical columns to one-hot encode
    onehot_columns = [
        "GENDER", "PARTNER", "DEPENDENTS", "PHONESERVICE", "MULTIPLELINES", "INTERNETSERVICE",
        "ONLINESECURITY", "ONLINEBACKUP", "DEVICEPROTECTION", "TECHSUPPORT",
        "STREAMINGTV", "STREAMINGMOVIES", "CONTRACT", "PAPERLESSBILLING", "PAYMENTMETHOD"
    ]

    # One-hot encoding manually
    for column_name in onehot_columns:
        # One query per column to discover the categories present in the data.
        distinct_rows = df.select(col(column_name)).distinct().collect()

        # DISTINCT returns rows in no guaranteed order. Sort the categories so the
        # generated indicator columns (and therefore the feature order seen by the
        # model) are the same on every run over the same data.
        categories = sorted(
            (row[column_name] for row in distinct_rows if row[column_name] is not None),
            key=str,
        )

        for value in categories:
            # Make the category usable as part of a column name.
            safe_value = str(value).replace(' ', '_').replace('-', '_').replace('(', '').replace(')', '')
            new_col = f"{column_name}_{safe_value}"
            df = df.with_column(new_col, when(col(column_name) == value, 1).otherwise(0))

    # Drop original categorical columns
    df = df.drop(*onehot_columns)
    return df

def clean_total_charges_column(df):
    """
    Cleans the TOTALCHARGES column:
    - Replaces NULL and known invalid values ('', 'No', 'N/A', 'null', 'None') with NULL
    - Replaces values that contain no digits or dots at all (e.g. whitespace-only
      strings) with NULL, so the cast below never receives an empty string
    - Removes every character other than 0-9 and '.' from the remaining entries
    - Casts to FloatType

    Note: because '-' is also removed by the pattern, a negative amount would lose
    its sign.

    Args:
        df: Snowpark DataFrame with a TOTALCHARGES column.

    Returns:
        A new Snowpark DataFrame in which TOTALCHARGES has been replaced by its
        cleaned float version.
    """
    raw_value = col("TOTALCHARGES")
    # Text that is left once everything except digits and dots has been stripped.
    numeric_text = regexp_replace(raw_value, r"[^0-9.]", "")

    df = df.with_column(
        "TOTALCHARGES_CLEANED",
        when(
            (raw_value.is_null()) |
            (raw_value == '') |
            (raw_value.isin("No", "N/A", "null", "None")) |
            # Whitespace-only or otherwise non-numeric input strips down to '',
            # which cannot be cast to a float; treat it as missing instead.
            (numeric_text == ''),
            None
        ).otherwise(
            numeric_text
        ).cast(FloatType())
    )
    # Swap the cleaned column in under the original name
    df = df.drop("TOTALCHARGES").with_column_renamed("TOTALCHARGES_CLEANED", "TOTALCHARGES")
    return df

In [None]:
# Reuse the session that the hosting notebook environment already opened.
session = get_active_session()

# Add a query tag to the session. This helps with debugging and performance monitoring.
session.query_tag = {
    "origin": "churn_data_science",
    "name": "churn_prediction_model_training",
    "version": {"major": 1, "minor": 0},
    "attributes": {"training": 1, "source": "notebook"}
}

# Set session context 
session.use_role("DEV_DATA_SCIENCE") 
# Set the compute warehouse (used to run queries and transformations)
session.use_warehouse("DEV_DATA_SCIENCE_WH")

# Set the active database (logical container for schemas and tables)
session.use_database("DEV_DATA_SCIENCE_DB")

# Set the active schema (namespace within the database)
session.use_schema("CHURN_MODEL_SCHEMA")

# Print the current role, warehouse, and database/schema
print(f"Session role: {session.get_current_role()} \nSession WH: {session.get_current_warehouse()} \nSession DB.SCHEMA: {session.get_fully_qualified_current_schema()}")

# Feature Store schema reference
FEATURE_STORE_SCHEMA = "DEV_DATA_SCIENCE_DB.FEATURE_STORE_SCHEMA"

# Model Registry schema reference
MODEL_REGISTRY_SCHEMA = "DEV_DATA_SCIENCE_DB.MODEL_REGISTRY_SCHEMA"

# Print the feature store and model registry schema references
print(f"FEATURE_STORE_SCHEMA: {FEATURE_STORE_SCHEMA} \nMODEL_REGISTRY_SCHEMA: {MODEL_REGISTRY_SCHEMA}")


In [None]:
# Load the feature-ingest query text from the project folder
sql_file_path = "TRAIN_CHURN_MODEL_V2/features_ingest.sql"
with open(sql_file_path, "r") as sql_file:
    sql_query = sql_file.read()

# Run the query through Snowpark and feed the result straight into the
# label / one-hot preprocessing step
df = preprocess_after_feature_store(session.sql(sql_query))

# Preview the prepared rows
df.show()

In [None]:
# -----------------------------
# Columns that are never used as model inputs: identifier, label, snapshot date
excluded_cols = {"CUSTOMERID", "CHURN", "SNAPSHOT_DATE"}

# Every other column of df becomes a feature (names are compared in upper case).
# No data-type check happens here, so df must already be fully numeric apart
# from the excluded columns.
feature_cols = list(
    filter(lambda name: name.upper() not in excluded_cols, df.columns)
)

# -----------------------------
# 80/20 train/test split with a fixed seed
train_df, test_df = df.random_split([0.8, 0.2], seed=42)

# -----------------------------
# Configure the random forest, then fit it on the training split
rf_model = RandomForestClassifier(
    n_estimators=100,
    random_state=42,
    input_cols=feature_cols,
    label_cols=["CHURN"],
    output_cols=["PREDICTION"],
)

rf_model.fit(train_df)

In [None]:
today_str = date.today().strftime("%Y_%m_%d")

# Initialize Model Registry
registry = Registry(
    session=session,
    database_name="DEV_DATA_SCIENCE_DB", 
    schema_name="MODEL_REGISTRY_SCHEMA"
)

# -----------------------------
# Sample input for registry
sample_input_data = train_df.select(feature_cols).limit(10)

# -----------------------------
# Define model name and version
# NOTE(review): both names are derived from today's date only, so logging a
# second time on the same day targets an already existing model/version -
# confirm how the registry should behave in that case.
version_str = f"V_{today_str}"
model_name = f"CHURN_RF_V1_{version_str}"
dev_comment = "Random Forest model for churn prediction"
# -----------------------------
# Log model to registry
model_version = registry.log_model(
    model=rf_model,
    model_name=model_name,
    version_name=version_str,
    comment=dev_comment,
    sample_input_data=sample_input_data,
)

# -----------------------------
# Fetch versioned model from registry
model_version = registry.get_model(model_name).version(version_str)

# -----------------------------
# Get predicted probabilities
proba_df = model_version.run(test_df, function_name="predict_proba")


# -----------------------------
# Apply threshold manually
threshold = 0.5  # you can change this
# The positive-class probability is read from the "PREDICT_PROBA_1" output column
proba_df = proba_df.with_column(
    "PREDICTION",
    when(col('"PREDICT_PROBA_1"') >= threshold, 1).otherwise(0)
)
# -----------------------------
# Evaluate metrics
# Fetch label, thresholded prediction and probability with ONE query on the
# scored DataFrame. Reading the labels from test_df and the predictions from
# proba_df in separate queries gives no guarantee that the two result sets come
# back in the same row order, which would pair labels with the wrong predictions.
eval_pdf = proba_df.select("CHURN", "PREDICTION", '"PREDICT_PROBA_1"').to_pandas()
true_vals = eval_pdf["CHURN"].astype(int).values
pred_vals = eval_pdf["PREDICTION"].astype(int).values
prob_vals = eval_pdf["PREDICT_PROBA_1"].values

# Plain Python floats keep the values straightforward to serialise downstream
acc = float(accuracy_score(true_vals, pred_vals))
prec = float(precision_score(true_vals, pred_vals, zero_division=0))
rec = float(recall_score(true_vals, pred_vals, zero_division=0))
f1 = float(f1_score(true_vals, pred_vals, zero_division=0))
roc_auc = float(roc_auc_score(true_vals, prob_vals))
pr_auc = float(average_precision_score(true_vals, prob_vals))

print(f"Accuracy: {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall: {rec:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"ROC AUC: {roc_auc:.4f}")
print(f"PR AUC: {pr_auc:.4f}")

# -----------------------------
# Log metrics to model registry
model_version.set_metric("accuracy", acc)
model_version.set_metric("precision", prec)
model_version.set_metric("recall", rec)
model_version.set_metric("f1_score", f1)
model_version.set_metric("roc_auc", roc_auc)
model_version.set_metric("pr_auc", pr_auc)

print("Metrics logged to model registry.")




In [None]:
# Define model metrics and metadata
training_date = datetime.today().date()

# The metric values (prec, rec, f1, acc, roc_auc, pr_auc) and the threshold come
# from the evaluation cell. They must not be reassigned here: hard-coded example
# numbers would be written to the table in place of the measured ROC AUC / PR AUC,
# and a variable named f1_score would shadow the sklearn function of the same name.

# Create a DataFrame to insert the values
metrics_df = session.create_dataframe([[
    model_name,
    training_date,
    float(prec),
    float(rec),
    float(f1),
    float(acc),
    float(roc_auc),
    float(pr_auc),
    threshold,
    dev_comment
]], schema=[
    "MODEL_ID",
    "TRAINING_DATE",
    "PRECISION",
    "RECALL",
    "F1_SCORE",
    "ACCURACY",
    "ROC_AUC",
    "PR_AUC",
    "THRESHOLD",
    "COMMENT"
])

# Insert into the target table
metrics_df.write.mode("append").save_as_table("DEV_DATA_SCIENCE_DB.CHURN_MODEL_SCHEMA.MODEL_PERFORMANCE_METRICS")

In [None]:
# Convert the trained wrapper to its scikit-learn estimator and read the
# per-feature importance scores from it
rf_model_sk = rf_model.to_sklearn()
importances = rf_model_sk.feature_importances_
feature_names = feature_cols  # or manually provided list

# Pair each feature with its score, then order from most to least important
sorted_importances = list(zip(feature_names, importances))
sorted_importances.sort(key=lambda pair: pair[1], reverse=True)

# One row per feature; RANK starts at 1 for the most important feature
# (CREATED_AT is not supplied here - presumably filled in by the target table)
feature_data = [
    (model_name, fname, round(score, 5), position, training_date)
    for position, (fname, score) in enumerate(sorted_importances, start=1)
]

# Build the Snowpark DataFrame for the rows
importance_schema = ["MODEL_ID", "FEATURE_NAME", "IMPORTANCE_SCORE", "RANK", "TRAINING_DATE"]
importance_df = session.create_dataframe(feature_data, schema=importance_schema)

# Append the rows to the Snowflake table
importance_df.write.mode("append").save_as_table("DEV_DATA_SCIENCE_DB.CHURN_MODEL_SCHEMA.FEATURE_IMPORTANCE")