# Classification of Merchants

Use K-Means to cluster merchants into segments, then evaluate the clustering with the Silhouette Score and the within-cluster sum of squares (WCSS).

---

In [56]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf, sum
from pyspark.sql.types import DoubleType
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.clustering import KMeans
from pyspark.ml.linalg import Vectors
from pyspark.ml.evaluation import ClusteringEvaluator

In [37]:
# Build (or reuse, via getOrCreate) the SparkSession used by every cell below.
spark = (
    SparkSession.builder.appName("Classification")
    # notebook display / SQL behaviour
    .config("spark.sql.repl.eagerEval.enabled", True) 
    .config("spark.sql.parquet.cacheMetadata", "true")
    .config("spark.sql.session.timeZone", "Etc/UTC")
    .config("spark.sql.debug.maxToStringFields", 3000)
    # timeouts and result-size limit
    .config("spark.network.timeout", "300s")
    .config("spark.driver.maxResultSize", "4g")
    .config("spark.rpc.askTimeout", "300s")
    # memory for the driver and each executor
    .config("spark.driver.memory", "8G")
    .config("spark.executor.memory", "8G")
    .getOrCreate()
)

# Read dataset

In [38]:
# Load the two curated parquet datasets used in this notebook (paths are relative
# to the notebook's working directory).
merchant_detail = spark.read.parquet('../data/curated/merchant_detail')
full_transaction = spark.read.parquet('../data/curated/full_transaction')

In [39]:
# Preview the first 5 merchant rows (long strings are truncated by default).
merchant_detail.show(5)

+--------------------+------------+-------------+---------+--------------------+
|                name|merchant_abn|revenue_level|take_rate|  processed_category|
+--------------------+------------+-------------+---------+--------------------+
|fusce aliquet lim...| 17189523131|            c|     1.59|writing paper sup...|
|    cum sociis corp.| 22528859307|            a|      6.4|shop jewelry repa...|
|morbi vehicula li...| 32413511882|            b|     3.35|vehicle new part ...|
|vel nisl incorpor...| 32897338221|            a|      6.2|souvenir card sho...|
|    egestas sed inc.| 34082818630|            a|     6.86|digital book musi...|
+--------------------+------------+-------------+---------+--------------------+
only showing top 5 rows



In [40]:
# Preview the first 5 transaction rows with full, untruncated cell contents.
full_transaction.show(5, truncate=False)

+------------+------------------------------------+---------+--------------------------+-------------------+--------------------+---------------+---------------+---------------+---------------+---------------+----------------+------------+-------------------+--------------+--------------+---------------+--------------+---------------+---------------+-----------------+---------------+----------------+----------------+-------------------+------------+----------------+------------------+-------------+-------------+-------------+---------------+---------------+-------------+--------------+--------------------+--------------+-------------------+-----------------+--------------+-----------------+--------------+-------------------+--------------+--------------+---------------+---------------+-------------+--------------+-------------------+--------------+--------------+----------------+-------------+-------------+---------------+-----------+---------------+---------------+----------------+---

# Create feature vectors

In [41]:
# Encode the string column `processed_category` as a numeric label index:
# fit() learns the label -> index mapping, transform() appends `category_index`.
indexer = StringIndexer(inputCol="processed_category", outputCol="category_index")
indexer_model = indexer.fit(merchant_detail)
indexed_df = indexer_model.transform(merchant_detail)
indexed_df.show(5)

+--------------------+------------+-------------+---------+--------------------+--------------+
|                name|merchant_abn|revenue_level|take_rate|  processed_category|category_index|
+--------------------+------------+-------------+---------+--------------------+--------------+
|fusce aliquet lim...| 17189523131|            c|     1.59|writing paper sup...|          15.0|
|    cum sociis corp.| 22528859307|            a|      6.4|shop jewelry repa...|          11.0|
|morbi vehicula li...| 32413511882|            b|     3.35|vehicle new part ...|          18.0|
|vel nisl incorpor...| 32897338221|            a|      6.2|souvenir card sho...|           5.0|
|    egestas sed inc.| 34082818630|            a|     6.86|digital book musi...|           0.0|
+--------------------+------------+-------------+---------+--------------------+--------------+
only showing top 5 rows



In [42]:
# `indexed_df` from the previous cell already carries `category_index`, so it is
# reused here; fitting an identical StringIndexer a second time would only repeat
# a pass over the data and produce the same DataFrame.

# Assemble the model input: one feature vector per merchant in the `features` column.
# NOTE(review): `category_index` is a nominal label id (StringIndexer orders labels
# by frequency by default), so Euclidean distances between index values carry no
# inherent meaning. Consider one-hot encoding or adding numeric merchant features
# before clustering; confirm this is intended.
feature_columns = ['category_index']
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
feature_data = assembler.transform(indexed_df)

# Perform K-Means clustering with 5 clusters

In [43]:
# Configure K-Means with 5 clusters and a fixed seed so the run is repeatable,
# fit it on the assembled features, then label every merchant with its cluster id
# (written to the `prediction` column).
kmeans = KMeans().setK(5).setSeed(42)
model = kmeans.fit(feature_data)
predictions = model.transform(feature_data)
predictions.show(5)

+--------------------+------------+-------------+---------+--------------------+--------------+--------+----------+
|                name|merchant_abn|revenue_level|take_rate|  processed_category|category_index|features|prediction|
+--------------------+------------+-------------+---------+--------------------+--------------+--------+----------+
|fusce aliquet lim...| 17189523131|            c|     1.59|writing paper sup...|          15.0|  [15.0]|         4|
|    cum sociis corp.| 22528859307|            a|      6.4|shop jewelry repa...|          11.0|  [11.0]|         0|
|morbi vehicula li...| 32413511882|            b|     3.35|vehicle new part ...|          18.0|  [18.0]|         4|
|vel nisl incorpor...| 32897338221|            a|      6.2|souvenir card sho...|           5.0|   [5.0]|         2|
|    egestas sed inc.| 34082818630|            a|     6.86|digital book musi...|           0.0|   [0.0]|         1|
+--------------------+------------+-------------+---------+-------------

# Inspect the results of K-Means

### Check the number of instances in each cluster

In [44]:
# Number of clusters the fitted model was trained with.
print(f"Number of clusters (k): {model.getK()}")

# Centroid of each cluster, one per line.
centers = model.clusterCenters()
print("Cluster centers:")
for centroid in centers:
    print(centroid)

# How many merchants were assigned to each cluster.
cluster_sizes = predictions.groupBy("prediction").count()
cluster_sizes.show()

# Sample of merchants alongside their category, encoded index and assigned cluster.
preview_columns = ["name", "processed_category", "category_index", "prediction"]
predictions.select(*preview_columns).show(5)

Number of clusters (k): 5
Cluster centers:
[11.48435171]
[1.9640592]
[6.98102679]
[21.27830832]
[15.95384615]
+----------+-----+
|prediction|count|
+----------+-----+
|         1|  946|
|         3|  733|
|         4|  780|
|         2|  896|
|         0|  671|
+----------+-----+

+--------------------+--------------------+--------------+----------+
|                name|  processed_category|category_index|prediction|
+--------------------+--------------------+--------------+----------+
|fusce aliquet lim...|writing paper sup...|          15.0|         4|
|    cum sociis corp.|shop jewelry repa...|          11.0|         0|
|morbi vehicula li...|vehicle new part ...|          18.0|         4|
|vel nisl incorpor...|souvenir card sho...|           5.0|         2|
|    egestas sed inc.|digital book musi...|           0.0|         1|
+--------------------+--------------------+--------------+----------+
only showing top 5 rows



### Obtain the category for each cluster

In [45]:
# For every cluster id, list the merchant categories it contains,
# most frequent category first.
for cluster_id in range(model.getK()):
    print(f"Cluster {cluster_id} categories:")
    cluster = predictions.where(col("prediction") == cluster_id)
    category_counts = (
        cluster
        .groupBy("processed_category")
        .count()
        .orderBy(col("count").desc())
    )
    category_counts.show(truncate=False)

Cluster 0 categories:
+---------------------------------------------+-----+
|processed_category                           |count|
+---------------------------------------------+-----+
|service shop sale and bicycle                |170  |
|shop jewelry repair clock and watch          |170  |
|instrument musical music shop piano sheet and|167  |
|and beauty health spa                        |164  |
+---------------------------------------------+-----+

Cluster 1 categories:
+--------------------------------------------------------------------------+-----+
|processed_category                                                        |count|
+--------------------------------------------------------------------------+-----+
|digital book music movie goods:                                           |195  |
|artist shop supply craft and                                              |193  |
|computer integrated processing design service system data and programming |191  |
|shop shoe               

### Compute Silhouette Score

The Silhouette Score ranges from -1 to 1; the closer it is to 1, the more compact and well separated the clusters are:

In [46]:
# Evaluate the clustering with the silhouette coefficient, measured on
# squared Euclidean distances between the `features` vectors.
evaluator = ClusteringEvaluator(
    featuresCol='features',
    metricName='silhouette',
    distanceMeasure='squaredEuclidean',
)

silhouette_score = evaluator.evaluate(predictions)
print(f"Silhouette Score: {silhouette_score}")

Silhouette Score: 0.7273731394415569


### Compute WCSS

In [47]:
# Cluster centroids (indexed by cluster id), broadcast so the UDF below can look
# them up on the executors.
centers = model.clusterCenters()
broadcast_centers = spark.sparkContext.broadcast(centers)


def squared_distance(features, center):
    """Return the squared Euclidean distance between two vectors as a Python float."""
    return float(Vectors.squared_distance(features, center))


def _distance_to_assigned_center(features, cluster):
    """Return the squared distance from `features` to the centroid of cluster `cluster`."""
    assigned_center = Vectors.dense(broadcast_centers.value[int(cluster)])
    return squared_distance(features, assigned_center)


# Spark UDF wrapper: (features, prediction) -> squared distance to the assigned centroid.
squared_distance_udf = udf(_distance_to_assigned_center, DoubleType())

# Add each merchant's squared distance to its own cluster centre.
distance = predictions.withColumn('distance', squared_distance_udf(col('features'), col('prediction')))

# WCSS is the sum of those distances over all rows. One global aggregate gives the
# total directly; grouping by cluster first and then summing the per-cluster sums
# yields the same number with an extra shuffle.
# NOTE(review): for the training data this should agree with
# `model.summary.trainingCost` — verify and consider using it as a cross-check.
wcss = distance.agg({'distance': 'sum'}).collect()[0][0]
print(f"Within-Cluster Sum of Squares (WCSS): {wcss}")

Within-Cluster Sum of Squares (WCSS): 8114.604724895357


# Show the predictions with cluster assignments

In [48]:
# Display the first rows of the predictions (show() defaults to 20) without truncating long strings.
predictions.show(truncate=False)

+---------------------------------+------------+-------------+---------+--------------------------------------------------------------------------+--------------+--------+----------+
|name                             |merchant_abn|revenue_level|take_rate|processed_category                                                        |category_index|features|prediction|
+---------------------------------+------------+-------------+---------+--------------------------------------------------------------------------+--------------+--------+----------+
|fusce aliquet limited            |17189523131 |c            |1.59     |writing paper supply stationery office printing and                       |15.0          |[15.0]  |4         |
|cum sociis corp.                 |22528859307 |a            |6.4      |shop jewelry repair clock and watch                                       |11.0          |[11.0]  |0         |
|morbi vehicula limited           |32413511882 |b            |3.35     |vehicle new p

In [49]:
# Print each cluster centroid on its own line.
print("Cluster Centers:")
for centroid in centers:
    print(centroid)

Cluster Centers:
[11.48435171]
[1.9640592]
[6.98102679]
[21.27830832]
[15.95384615]


# Merge the clustering results into the `full_transaction` dataset

Drop unused columns:

In [50]:
# Columns treated as unused from here on; they are removed before the merge.
cols_to_remove = [
    'name',
    'category_including',
    'category_other',
    'category_and',
    'category_except',
    'category_al',
    'category_shop',
    'state',
]
full_transaction = full_transaction.drop(*cols_to_remove)

Merge the `full_transaction` and `predictions` tables based on the `merchant_abn` column:

In [51]:
# One row per merchant in `predictions`: its ABN and assigned cluster, named
# `merchant_segment` up front so no rename is needed after the join.
merchant_segments = predictions.select(
    "merchant_abn", col("prediction").alias("merchant_segment")
)

# Drop any `merchant_segment` left over from an earlier run of this cell (a no-op
# when the column is absent) so that re-running the cell does not append a second,
# duplicate `merchant_segment` column.
# The left join keeps every transaction; a transaction whose merchant has no
# prediction would get a null segment.
full_transaction = (
    full_transaction
    .drop("merchant_segment")
    .join(merchant_segments, on="merchant_abn", how="left")
)
full_transaction.show(5, truncate=False)

+------------+------------------------------------+---------+--------------------------+-------------------+--------------------+---------------+---------------+---------------+---------------+---------------+----------------+------------+-------------------+--------------+--------------+---------------+--------------+---------------+---------------+-----------------+---------------+----------------+----------------+-------------------+------------+----------------+------------------+-------------+-------------+---------------+---------------+-------------+--------------+--------------------+--------------+-------------------+-----------------+--------------+-----------------+--------------+-------------------+--------------+---------------+---------------+-------------+--------------+-------------------+--------------+--------------+----------------+-------------+-------------+---------------+---------------+----------------+----------------+---------------+------------------+--------

Check the shape of the dataset:

In [52]:
# Shape of the merged dataset: row count (triggers a Spark job) and column count.
num_rows = full_transaction.count()
num_columns = len(full_transaction.columns)

print(f"Number of rows: {num_rows}")
print(f"Number of columns: {num_columns}")

Number of rows: 11372745
Number of columns: 125


In [54]:
# Print column names, types and nullability of the merged dataset.
full_transaction.printSchema()

root
 |-- merchant_abn: long (nullable = true)
 |-- order_id: string (nullable = true)
 |-- take_rate: float (nullable = true)
 |-- merchant_fraud_probability: double (nullable = true)
 |-- transaction_revenue: double (nullable = true)
 |-- BNLP_revenue: double (nullable = true)
 |-- revenue_level_e: integer (nullable = true)
 |-- revenue_level_d: integer (nullable = true)
 |-- revenue_level_c: integer (nullable = true)
 |-- revenue_level_b: integer (nullable = true)
 |-- revenue_level_a: integer (nullable = true)
 |-- category_jewelry: integer (nullable = true)
 |-- category_art: integer (nullable = true)
 |-- category_television: integer (nullable = true)
 |-- category_watch: integer (nullable = true)
 |-- category_cable: integer (nullable = true)
 |-- category_repair: integer (nullable = true)
 |-- category_stock: integer (nullable = true)
 |-- category_flower: integer (nullable = true)
 |-- category_office: integer (nullable = true)
 |-- category_souvenir: integer (nullable = true)

Confirm there are no nulls:

In [57]:
# create a dictionary with column names and their respective null counts
null_count_dict = {col_name: sum(col(col_name).isNull().cast("int")).alias(col_name) for col_name in full_transaction.columns}

# use agg() to calculate null counts for each column
null_counts_df = full_transaction.agg(*null_count_dict.values())
null_counts_df.show()



+------------+--------+---------+--------------------------+-------------------+------------+---------------+---------------+---------------+---------------+---------------+----------------+------------+-------------------+--------------+--------------+---------------+--------------+---------------+---------------+-----------------+---------------+----------------+----------------+-------------------+------------+----------------+------------------+-------------+-------------+---------------+---------------+-------------+--------------+--------------------+--------------+-------------------+-----------------+--------------+-----------------+--------------+-------------------+--------------+---------------+---------------+-------------+--------------+-------------------+--------------+--------------+----------------+-------------+-------------+---------------+---------------+----------------+----------------+---------------+------------------+----------------+-------------------+-------

                                                                                

In [58]:
# save as a parquet file
full_transaction.write.parquet('../data/curated/full_transaction_with_segments', mode='overwrite')

                                                                                