In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.types import FloatType, DateType
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col, when, regexp_replace
from datetime import datetime
from snowflake.snowpark.functions import lit, current_date, current_timestamp
from snowflake.snowpark.types import IntegerType
from snowflake.ml.feature_store import (
    FeatureStore,
    FeatureView,
    Entity,
    CreationMode
)
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.registry import Registry
from snowflake.ml.model import model_signature
# from snowflake.ml.modeling.model_selection import train_test_split

In [None]:
def clean_column_names(df):
    """Return ``df`` with every column re-aliased to a tidied name.

    Each entry of ``df.columns`` has its surrounding whitespace stripped and
    all double-quote characters removed. Columns keep their original order.
    """
    aliased_columns = []
    for original_name in df.columns:
        tidy_name = original_name.strip().replace('"', '')
        aliased_columns.append(col(original_name).alias(tidy_name))
    return df.select(aliased_columns)

def preprocess_before_feature_store(df):
    """Normalise column names and numeric types before feature-store registration.

    Steps, in order:
      1. Strip whitespace and double quotes from every column name.
      2. Cast ``SeniorCitizen`` and ``tenure`` to integer and
         ``MonthlyCharges`` to float.
      3. Replace empty or space-only ``TotalCharges`` entries with NULL and
         cast the remaining (trimmed) values to float.
      4. Drop rows whose ``TotalCharges`` is NULL after step 3.

    Args:
        df: Snowpark DataFrame containing the columns named above.

    Returns:
        The transformed Snowpark DataFrame.
    """
    # Imported locally because `trim` is not part of the notebook's import cell.
    from snowflake.snowpark.functions import trim

    # Same renaming rule as clean_column_names, applied through the shared helper.
    df = clean_column_names(df)

    # Cast numeric fields properly
    df = df.with_column("SeniorCitizen", col("SeniorCitizen").cast(IntegerType()))
    df = df.with_column("tenure", col("tenure").cast(IntegerType()))
    df = df.with_column("MonthlyCharges", col("MonthlyCharges").cast("float"))

    # Trim first so that '', ' ' and longer runs of spaces are all treated as
    # missing instead of being handed to the float cast.
    trimmed_total = trim(col("TotalCharges"))
    df = df.with_column(
        "TotalCharges",
        when(trimmed_total == '', None).otherwise(trimmed_total).cast("float")
    )

    # Drop rows with nulls in critical numeric fields
    df = df.filter(col("TotalCharges").is_not_null())

    return df

def preprocess_after_feature_store(df):
    """Encode the label and the categorical features as numeric columns.

    * ``CHURN`` becomes 1 where it equals ``"Yes"`` and 0 otherwise.
    * Every column listed in ``onehot_columns`` is expanded into one 0/1
      indicator column per distinct non-NULL value, named
      ``<COLUMN>_<value>`` with spaces and hyphens replaced by underscores
      and parentheses removed.
    * The original categorical columns are dropped afterwards.

    Args:
        df: Snowpark DataFrame containing ``CHURN`` and all categorical
            columns listed below.

    Returns:
        The encoded Snowpark DataFrame.
    """
    # Convert target to binary label
    df = df.with_column("CHURN", when(col("CHURN") == "Yes", 1).otherwise(0))

    # Categorical columns to one-hot encode
    onehot_columns = [
        "GENDER", "PARTNER", "DEPENDENTS", "PHONESERVICE", "MULTIPLELINES", "INTERNETSERVICE",
        "ONLINESECURITY", "ONLINEBACKUP", "DEVICEPROTECTION", "TECHSUPPORT",
        "STREAMINGTV", "STREAMINGMOVIES", "CONTRACT", "PAPERLESSBILLING", "PAYMENTMETHOD"
    ]

    # One-hot encoding manually (one DISTINCT query per categorical column)
    for column_name in onehot_columns:
        distinct_rows = df.select(col(column_name)).distinct().collect()
        # DISTINCT gives no ordering guarantee. Sorting the values makes the
        # indicator columns appear in the same order on every run, which keeps
        # the column layout stable for the table this frame is appended to.
        distinct_values = sorted(
            (row[0] for row in distinct_rows if row[0] is not None),
            key=str,
        )
        for value in distinct_values:
            safe_value = str(value).replace(' ', '_').replace('-', '_').replace('(', '').replace(')', '')
            new_col = f"{column_name}_{safe_value}"
            df = df.with_column(new_col, when(col(column_name) == value, 1).otherwise(0))

    # Drop original categorical columns
    df = df.drop(*onehot_columns)
    return df

def clean_total_charges_column(df):
    """
    Cleans the TOTALCHARGES column:
    - Replaces NULL, empty / whitespace-only values and the placeholders
      'No', 'N/A', 'null', 'None' with NULL
    - Removes every character other than digits and '.' from the remaining
      entries (note: this also strips a leading minus sign)
    - Replaces entries that have nothing left after that stripping with NULL
    - Casts to FloatType

    The cleaned column replaces the original one under the same name; because
    the original is dropped and the cleaned column renamed, TOTALCHARGES ends
    up as the last column of the returned DataFrame.
    """
    raw_value = col("TOTALCHARGES")
    numeric_text = regexp_replace(raw_value, r"[^0-9.]", "")

    # `numeric_text == ''` covers '', ' ' and any other value without digits,
    # so such values become NULL instead of reaching the float cast as ''.
    is_missing = (
        raw_value.is_null() |
        raw_value.isin("No", "N/A", "null", "None") |
        (numeric_text == '')
    )

    df = df.with_column(
        "TOTALCHARGES_CLEANED",
        when(is_missing, None).otherwise(numeric_text).cast(FloatType())
    )
    # Swap the cleaned column in under the original name
    df = df.drop("TOTALCHARGES").with_column_renamed("TOTALCHARGES_CLEANED", "TOTALCHARGES")
    return df

In [None]:
session = get_active_session()

# Add a query tag to the session. This helps with debugging and performance monitoring.
# The tag is assigned exactly once so the setter runs a single time.
session.query_tag = {
    "origin": "churn_data_science",
    "name": "churn_prediction_model_training",
    "version": {"major": 1, "minor": 0},
    "attributes": {"training": 1, "source": "notebook"}
}

# Set session context
session.use_role("DEV_DATA_SCIENCE")
# Set the compute warehouse (used to run queries and transformations)
session.use_warehouse("DEV_DATA_SCIENCE_WH")

# Set the active database (logical container for schemas and tables)
session.use_database("DEV_DATA_SCIENCE_DB")

# Set the active schema (namespace within the database)
session.use_schema("CHURN_MODEL_SCHEMA")

# Print the current role, warehouse, and database/schema
print(f"Session role: {session.get_current_role()} \nSession WH: {session.get_current_warehouse()} \nSession DB.SCHEMA: {session.get_fully_qualified_current_schema()}")

# Feature Store schema reference
FEATURE_STORE_SCHEMA = "DEV_DATA_SCIENCE_DB.FEATURE_STORE_SCHEMA"

# Model Registry schema reference
MODEL_REGISTRY_SCHEMA = "DEV_DATA_SCIENCE_DB.MODEL_REGISTRY_SCHEMA"

# Print the fully qualified Feature Store and Model Registry schema names
print(f"FEATURE_STORE_SCHEMA: {FEATURE_STORE_SCHEMA} \nMODEL_REGISTRY_SCHEMA: {MODEL_REGISTRY_SCHEMA}")


In [None]:
# Read the SQL query string
sql_file_path = "TRAIN_CHURN_MODEL/features_ingest.sql"
with open(sql_file_path, "r", encoding="utf-8") as file:
    # Later transformations nest this text inside further SELECTs, where a
    # trailing statement terminator is a syntax error, so strip surrounding
    # whitespace and any trailing ';'.
    sql_query = file.read().strip().rstrip(";")

# Execute the SQL query using Snowpark
df = session.sql(sql_query)

# Clean column names: remove quotes and standardize to uppercase
cleaned_columns = {col_name: col_name.strip('"').upper() for col_name in df.columns}
df = df.select([col(c).alias(cleaned_columns[c]) for c in df.columns])

# TOTALCHARGES is deliberately left as delivered by the query in this cell.
# Blank / placeholder handling lives in the helper functions defined above
# (see clean_total_charges_column and preprocess_before_feature_store).

# Add SNAPSHOT_DATE column with current date
df = df.with_column("SNAPSHOT_DATE", current_date())



In [None]:
# Persist the ingested snapshot by appending its rows to RAW_FEATURES_DATA
df.write.save_as_table("RAW_FEATURES_DATA", mode="append")

In [None]:
# Tidy the column names, then cast the numeric fields and drop unusable rows
df = preprocess_before_feature_store(clean_column_names(df))

# Keep the label next to its join keys so it can be re-attached later
label_df = df.select("CUSTOMERID", "SNAPSHOT_DATE", "CHURN")

df.show()

In [None]:
from datetime import datetime

import pandas as pd
from snowflake.ml.feature_store import CreationMode, Entity, FeatureStore, FeatureView
from snowflake.snowpark.functions import col, current_timestamp

# Point the session at the database and schema that back the feature store
session.use_database("DEV_DATA_SCIENCE_DB")
session.use_schema("FEATURE_STORE_SCHEMA")

# Open the feature store, creating it on first use
fs = FeatureStore(
    session=session,
    database="DEV_DATA_SCIENCE_DB",
    name="FEATURE_STORE_SCHEMA",
    default_warehouse=session.get_current_warehouse(),
    creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
)

# -----------------------------
# Register Entity
# Rows are keyed by snapshot date plus customer id
customer_entity = Entity(
    name="customer_entity",
    join_keys=["SNAPSHOT_DATE", "CUSTOMERID"],
    desc="Entity representing customer identifier",
)
fs.register_entity(customer_entity)


# -----------------------------
# Track ingestion metadata manually
def log_feature_view_registration(name, version, desc, features):
    """Append one audit row describing a feature-view registration.

    Args:
        name: Feature view name, stored in FEATURE_VIEW.
        version: Feature view version string, stored in VERSION.
        desc: Human-readable description, stored in DESC.
        features: Column names of the feature view, stored in FEATURES.

    The row is written to the FEATURE_VIEW_HISTORY table in the session's
    current schema, with INGESTED_AT set to the local time of the call.
    """
    metadata_df = pd.DataFrame([{
        "FEATURE_VIEW": name,
        "VERSION": version,
        "DESC": desc,
        "FEATURES": features,
        "INGESTED_AT": datetime.now()
    }])
    # Session.write_pandas has no `mode` argument; unrecognised keywords are
    # forwarded to the underlying writer rather than selecting a write mode.
    # Appending is the default behaviour (overwrite=False), and
    # auto_create_table=True creates the history table on first use.
    session.write_pandas(
        metadata_df,
        "FEATURE_VIEW_HISTORY",
        auto_create_table=True,
        overwrite=False,
    )

# Make sure SNAPSHOT_DATE exists.
# The ingest cell already stamps every row with current_date(), and label_df
# was selected with that column. Replacing it with current_timestamp() would
# turn the key into a timestamp that no longer equals the date stored with the
# labels, so the column is only added when it is genuinely missing.
if "SNAPSHOT_DATE" not in df.columns:
    df = df.with_column("SNAPSHOT_DATE", current_date())

# Version under which all three feature views are registered and retrieved
FEATURE_VIEW_VERSION = "1.0"
# Entity join keys; selected first in every feature view
ENTITY_KEY_COLUMNS = ["SNAPSHOT_DATE", "CUSTOMERID"]


def _register_feature_view(name, desc, feature_columns):
    """Build, register and log one feature view over the current ``df``.

    Args:
        name: Feature view name.
        desc: Description stored with the feature view and in the history log.
        feature_columns: Feature column names to select after the entity keys.

    Returns:
        The FeatureView object that was passed to ``fs.register_feature_view``.
    """
    feature_view = FeatureView(
        name=name,
        entities=[customer_entity],
        feature_df=df.select([col(c) for c in ENTITY_KEY_COLUMNS + feature_columns]),
        desc=desc,
    )
    fs.register_feature_view(feature_view, version=FEATURE_VIEW_VERSION)
    log_feature_view_registration(
        name, FEATURE_VIEW_VERSION, desc, feature_view.feature_df.schema.names
    )
    return feature_view


# -----------------------------
# Demographic Features
demographic_features = _register_feature_view(
    "demographic_features",
    "Demographic-related features",
    ["GENDER", "SENIORCITIZEN", "PARTNER", "DEPENDENTS"],
)

# -----------------------------
# Financial Features
financial_features = _register_feature_view(
    "financial_features",
    "Financial behavior and billing features",
    ["MONTHLYCHARGES", "TOTALCHARGES", "PAYMENTMETHOD", "PAPERLESSBILLING"],
)

# -----------------------------
# Usage Features
usage_features = _register_feature_view(
    "usage_features",
    "Usage and service interaction features",
    [
        "PHONESERVICE", "MULTIPLELINES", "INTERNETSERVICE", "ONLINESECURITY",
        "ONLINEBACKUP", "DEVICEPROTECTION", "TECHSUPPORT", "STREAMINGTV",
        "STREAMINGMOVIES", "CONTRACT",
    ],
)

# -----------------------------
# Spine Definition
spine_df = df.select(col("CUSTOMERID"), col("SNAPSHOT_DATE"))

# Retrieve views for joining
demographic_fv = fs.get_feature_view(name="demographic_features", version=FEATURE_VIEW_VERSION)
financial_fv = fs.get_feature_view(name="financial_features", version=FEATURE_VIEW_VERSION)
usage_fv = fs.get_feature_view(name="usage_features", version=FEATURE_VIEW_VERSION)


In [None]:
# from snowflake.snowpark.functions import col
# from snowflake.ml.feature_store import FeatureStore, CreationMode, FeatureView, Entity

# # Set session context
# session.use_database("DEV_DATA_SCIENCE_DB")
# session.use_schema("FEATURE_STORE_SCHEMA")

# # Initialize Feature Store
# fs = FeatureStore(
#     session=session,
#     database="DEV_DATA_SCIENCE_DB",
#     name="FEATURE_STORE_SCHEMA",
#     default_warehouse=session.get_current_warehouse(),
#     creation_mode=CreationMode.CREATE_IF_NOT_EXIST
# )

# # Register entity for CUSTOMERID, SNAPSHOT_DATE 
# # (required by Feature Store)
# customer_entity = Entity(
#     name="customer_entity",
#     join_keys=["CUSTOMERID", "SNAPSHOT_DATE"],
#     desc="Entity representing customer identifier"
# )

# fs.register_entity(customer_entity)

# # Use the object, not its name string
# demographic_features = FeatureView(
#     name="demographic_features",
#     entities=[customer_entity],
#     feature_df=df.select(
#         col("CUSTOMERID"),
#         col("SNAPSHOT_DATE"),
#         col("GENDER"),
#         col("SENIORCITIZEN"),
#         col("PARTNER"),
#         col("DEPENDENTS")
#     ),
#     desc="Demographic-related features"
# )

# fs.register_feature_view(demographic_features, version='1.0')

# # Financial Features
# financial_features = FeatureView(
#     name="financial_features",
#     entities=[customer_entity],
#     feature_df=df.select(
#         col("CUSTOMERID"),
#         col("SNAPSHOT_DATE"),
#         col("MONTHLYCHARGES"),
#         col("TOTALCHARGES"),
#         col("PAYMENTMETHOD"),
#         col("PAPERLESSBILLING")
#     ),
#     desc="Financial behavior and billing features"
# )
# fs.register_feature_view(financial_features, version='1.0')

# # Usage Features
# usage_features = FeatureView(
#     name="usage_features",
#     entities=[customer_entity],
#     feature_df=df.select(
#         col("CUSTOMERID"),
#         col("SNAPSHOT_DATE"),
#         col("PHONESERVICE"),
#         col("MULTIPLELINES"),
#         col("INTERNETSERVICE"),
#         col("ONLINESECURITY"),
#         col("ONLINEBACKUP"),
#         col("DEVICEPROTECTION"),
#         col("TECHSUPPORT"),
#         col("STREAMINGTV"),
#         col("STREAMINGMOVIES"),
#         col("CONTRACT")
#     ),
#     desc="Usage and service interaction features"
# )
# fs.register_feature_view(usage_features, version='1.0')

# # Define Spine DataFrame
# # Spine will include CUSTOMERID and the SNAPSHOT_DATE
# spine_df = df.select(col("CUSTOMERID"), col("SNAPSHOT_DATE"))
# # Get registered feature views
# demographic_fv = fs.get_feature_view(name="demographic_features", version="1.0")
# financial_fv = fs.get_feature_view(name="financial_features", version="1.0")
# usage_fv = fs.get_feature_view(name="usage_features", version="1.0")

# #---------------------------------------
# # Deleting DataSet
# from snowflake.ml.dataset import Dataset
# # Load the existing Dataset
# dataset = Dataset.load(session, name="churn_training_dataset")
# # Delete the specific version
# dataset.delete_version("1.0")
# #---------------------------------------

# # Generate training dataset
# dataset = fs.generate_dataset(
#     name="churn_training_dataset",
#     spine_df=spine_df,
#     features=[demographic_fv, financial_fv, usage_fv],
#     version="1.0",
#     desc="Churn training dataset with historical snapshots"
# )

# # Read the generated dataset back as a Snowpark DataFrame and re-attach the labels
# df = dataset.read.to_snowpark_dataframe()
# df = df.join(label_df, on=["CUSTOMERID", "SNAPSHOT_DATE"], how="inner")


In [None]:
# NOTE: the first cell imports RandomForestClassifier and accuracy_score from
# scikit-learn; the two snowflake.ml imports below rebind those names to the
# Snowflake ML versions, which are the ones the following cells call.
from snowflake.snowpark.functions import col, when
from snowflake.ml.modeling.ensemble import RandomForestClassifier
from snowflake.ml.modeling.metrics import accuracy_score

# -----------------------------
# Apply preprocessing to Snowpark DataFrame:
# CHURN becomes a 0/1 label and the categorical columns are one-hot encoded.
# NOTE(review): `df` is rebound to the encoded frame, so this cell is not
# idempotent -- re-running it feeds already-encoded data back into the helper
# (the original categorical columns are gone by then). Re-run from the ingest
# cell instead.
df = preprocess_after_feature_store(df)
df.show()

In [None]:
# Persist the encoded feature set by appending it to POST_PROCESSED_FEATURES_DATA
df.write.save_as_table("POST_PROCESSED_FEATURES_DATA", mode="append")

In [None]:
# -----------------------------
# Identifier, label and snapshot columns are never model inputs
excluded_cols = {"CUSTOMERID", "CHURN", "SNAPSHOT_DATE"}

# Every remaining column (numeric or one-hot) is used as a feature
feature_cols = list(filter(lambda name: name not in excluded_cols, df.columns))

# -----------------------------
# 80/20 train/test split with a fixed seed
train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)

# -----------------------------
# Configure the random forest, then fit it on the training split
rf_model = RandomForestClassifier(
    n_estimators=100,
    random_state=42,
    input_cols=feature_cols,
    label_cols=["CHURN"],
    output_cols=["PREDICTION"],
)
rf_model.fit(train_df)

# -----------------------------
# Score the held-out split
predictions = rf_model.predict(test_df)

# -----------------------------
# Compare the PREDICTION column against the CHURN label
accuracy = accuracy_score(
    df=predictions,
    y_true_col_names=["CHURN"],
    y_pred_col_names=["PREDICTION"],
)

print(f"Model Accuracy: {accuracy:.4f}")


In [None]:
# Open the model registry stored in DEV_DATA_SCIENCE_DB.MODEL_REGISTRY_SCHEMA
registry = Registry(
    session=session,
    database_name="DEV_DATA_SCIENCE_DB",
    schema_name="MODEL_REGISTRY_SCHEMA",
)

In [None]:
from snowflake.ml.modeling.metrics import accuracy_score

model_name = "CHURN_PRED_MODEL"

# A handful of training rows is passed along as sample input for the logged model
sample_input_data = train_df.select(feature_cols).limit(10)

# Log the fitted model under a fixed version name
model_version = registry.log_model(
    model=rf_model,
    model_name=model_name,
    version_name="v4",
    comment="Random Forest model for churn prediction",
    sample_input_data=sample_input_data,
)

# Fetch the same version back from the registry and score the test split with it
model_version = registry.get_model(model_name).version("v4")
predictions = model_version.run(test_df, function_name="predict")
predictions.show()

# Accuracy of the registry-served predictions against the CHURN label
accuracy = accuracy_score(
    df=predictions,
    y_true_col_names=["CHURN"],
    y_pred_col_names=["PREDICTION"],
)
print(f"Model Accuracy: {accuracy:.4f}")