In [0]:
# Clustering Model to segment customers

# Import necessary libraries
import pandas as pd # Still imported, but not used for data creation in this version
from pyspark.sql.functions import col, when, lit, avg, count, sum
from pyspark.sql.window import Window
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.clustering import KMeans
from pyspark.ml import Pipeline
from pyspark.sql.types import DoubleType # Import DoubleType for casting if needed

# --- 1. Source table configuration and customer data load ---
# NOTE: the database and table names below must match the objects that exist
# in your Databricks workspace; the 'customers' table is expected to be there
# already.
#
# Expected 'customers' schema (as documented by the original author):
#   Id                          LongType
#   First Name                  StringType
#   Last Name                   StringType
#   Age                         IntegerType
#   Location                    StringType
#   Annual Income               DoubleType
#   Debt-To-Income Ratio (DTI)  DoubleType
#   Loan-to-Value Ratio (LTV)   DoubleType
#   Average Monthly Spending    DoubleType
#   Credit Score                IntegerType

database_name = "fraud_detection_db"  # Example: "your_data_warehouse"
customer_table_name = "customers"

# Load the customer table. Any failure is reported with a hint and then
# re-raised so the rest of the script does not run without input data.
try:
    spark_customer_df = spark.read.format("delta").table(f"{database_name}.{customer_table_name}")
except Exception as load_error:
    print(f"Error loading customer table: {load_error}")
    print("Please ensure the database and table exist and are accessible with the correct schema.")
    raise
else:
    print(f"Customer table '{database_name}.{customer_table_name}' loaded successfully.")

# Preview the raw input.
print("\n--- Raw Customer Data (first 5 rows) ---")
spark_customer_df.show(5)

# --- 2. Prepare Data for K-Means Clustering ---
# Segmentation is driven by two features only: annual income and average
# monthly spending.

print("\n--- Starting Customer Segmentation (K-Means Clustering) ---")

# Cast both features to double and give them space-free names so they are
# easier to reference in later stages.
income_feature = col("Annual Income").cast(DoubleType()).alias("Annual_Income")
spending_feature = col("Average Monthly Spending").cast(DoubleType()).alias("Average_Monthly_Spending")

# Keep the customer Id alongside the features, then discard every row where
# either feature is missing.
clustering_data_df = (
    spark_customer_df
    .select("Id", income_feature, spending_feature)
    .dropna(subset=["Annual_Income", "Average_Monthly_Spending"])
)

print("\n--- Data for Clustering (first 5 rows) ---")
clustering_data_df.show(5)

# Spark ML estimators consume a single vector column, so the two numeric
# features are packed into "features_clustering" first.
# handleInvalid="skip" makes the assembler drop rows whose feature values are
# invalid instead of failing the whole job.
clustering_assembler = VectorAssembler(
    inputCols=["Annual_Income", "Average_Monthly_Spending"],
    outputCol="features_clustering",
    handleInvalid="skip",
)

# Income and spending live on very different numeric scales; without scaling
# the larger-valued feature would dominate K-Means' distance computation.
# With withStd=True / withMean=False each feature is divided by its standard
# deviation but is NOT mean-centred. Skipping the centring step does not
# affect the clustering: shifting every point by the same constant leaves
# Euclidean distances between points unchanged.
scaler = StandardScaler(
    inputCol="features_clustering",
    outputCol="scaled_features_clustering",
    withMean=False,
    withStd=True,
)

# --- 3. Train K-Means Clustering Model ---
# Four clusters are requested (one per target customer category); the fixed
# seed keeps the centroid initialisation repeatable between runs.
kmeans = KMeans(k=4, seed=42, featuresCol="scaled_features_clustering")

# Chain the stages so a single fit/transform call performs
# assemble -> scale -> cluster.
clustering_stages = [clustering_assembler, scaler, kmeans]
clustering_pipeline = Pipeline(stages=clustering_stages)

# Fit every stage on the prepared data.
# NOTE: despite its name, `kmeans_model` holds the fitted *pipeline* (all
# three stages), not just the K-Means stage.
print("\n--- Training K-Means Clustering Model ---")
kmeans_model = clustering_pipeline.fit(clustering_data_df)
print("K-Means Model training complete.")

# --- 4. Assign Clusters and Interpret Categories ---
# Run the fitted pipeline over the same data it was trained on; the cluster
# id assigned to each customer lands in the "prediction" column.
clustered_customers_df = kmeans_model.transform(clustering_data_df)

print("\n--- Customers with Cluster Assignments (first 10 rows) ---")
assignment_preview = clustered_customers_df.select(
    "Id", "Annual_Income", "Average_Monthly_Spending", "prediction"
)
assignment_preview.show(10)

# Per-cluster means of both features: this is the table to read when deciding
# what each cluster id actually represents.
print("\n--- Average Income and Spending per Cluster (for interpretation) ---")
mean_income_per_cluster = avg("Annual_Income").alias("Avg_Annual_Income")
mean_spending_per_cluster = avg("Average_Monthly_Spending").alias("Avg_Monthly_Spending")
customers_by_cluster = clustered_customers_df.groupBy("prediction")
cluster_summary = customers_by_cluster.agg(
    mean_income_per_cluster, mean_spending_per_cluster
).orderBy("prediction")
cluster_summary.show()

# --- Assign Human-Readable Categories based on Cluster Summary ---
# Target categories:
# - High Income High Spenders
# - High Income Low Spenders
# - Low Income High Spenders
# - Low Income Low Spenders
#
# IMPORTANT: the labelling below is a generic placeholder. It does NOT use the
# K-Means "prediction" column; it splits customers on whether their income and
# spending are at/above or below the overall averages. To label by cluster
# instead, run the script, inspect `cluster_summary.show()`, and map each
# cluster id (0, 1, 2, 3) to a category using the commented template further
# down.

# Compute both overall averages in ONE aggregation so Spark scans the
# (uncached) data once instead of launching a separate job per average.
# A global aggregation always yields exactly one row.
overall_averages = clustered_customers_df.agg(
    avg("Annual_Income").alias("overall_avg_income"),
    avg("Average_Monthly_Spending").alias("overall_avg_spending"),
).first()
overall_avg_income = overall_averages["overall_avg_income"]
overall_avg_spending = overall_averages["overall_avg_spending"]

print(f"\nOverall Average Annual Income: {overall_avg_income:.2f}")
print(f"Overall Average Monthly Spending: {overall_avg_spending:.2f}")

# Name the two threshold tests once so the four quadrants read clearly.
is_high_income = col("Annual_Income") >= overall_avg_income
is_high_spender = col("Average_Monthly_Spending") >= overall_avg_spending

customer_segments_df = clustered_customers_df.withColumn(
    "customer_category",
    when(is_high_income & is_high_spender, lit("High Income High Spenders"))
    .when(is_high_income & ~is_high_spender, lit("High Income Low Spenders"))
    .when(~is_high_income & is_high_spender, lit("Low Income High Spenders"))
    .otherwise(lit("Low Income Low Spenders"))
)

# Alternatively, if you have a clear mapping from 'prediction' to category after reviewing `cluster_summary`:
# customer_segments_df = clustered_customers_df.withColumn(
#     "customer_category",
#     when(col("prediction") == 0, lit("High Income High Spenders")) # Example: if cluster 0 is the high-high group
#     .when(col("prediction") == 1, lit("Low Income Low Spenders"))  # Example: if cluster 1 is the low-low group
#     .when(col("prediction") == 2, lit("High Income Low Spenders")) # Example: if cluster 2 is the high-low group
#     .when(col("prediction") == 3, lit("Low Income High Spenders")) # Example: if cluster 3 is the low-high group
#     .otherwise(lit("Uncategorized")) # Fallback for any unmapped clusters
# )

print("\n--- Customers with Assigned Categories (first 10 rows) ---")
customer_segments_df.select("Id", "Annual_Income", "Average_Monthly_Spending", "prediction", "customer_category").show(10, truncate=False)

# --- 5. Save Customer Segmentation Results to a New Table ---
# The results are written as a Delta table in the same database as the source
# data. "overwrite" mode replaces the contents written by any previous run.
# NOTE: every column of `customer_segments_df` is persisted, not only the
# Id / feature / category columns shown in the previews.
customer_segmentation_output_table_name = "customer_income_spending_segments"
segments_writer = customer_segments_df.write.format("delta").mode("overwrite")
segments_writer.saveAsTable(f"{database_name}.{customer_segmentation_output_table_name}")

print(f"\nCustomer segmentation results saved to '{database_name}.{customer_segmentation_output_table_name}'.")

# Read the table back to confirm the write landed.
print("\n--- Verify Customer Segmentation Output Table ---")
saved_segments_df = spark.sql(f"SELECT * FROM {database_name}.{customer_segmentation_output_table_name}")
saved_segments_df.show(truncate=False)

# --- Clean up temporary views (optional) ---
# It's good practice to drop temporary views if they are no longer needed.
# Nothing in the visible script registers an "all_customers" view, so this
# only removes a leftover session-scoped temporary view of that name, if any.
# dropTempView is used instead of SQL `DROP VIEW IF EXISTS` because the SQL
# statement is not limited to temporary views: with no matching temp view it
# could drop a permanent view called "all_customers" in the current schema.
# dropTempView is a no-op when the temporary view does not exist.
spark.catalog.dropTempView("all_customers")
# The transactions table was not directly used in this model, so no view for it.
print("\nTemporary views dropped.")