# Classification of Merchants

This notebook segments merchants in two ways: unsupervised clustering with K-Means, and a manual mapping of merchant categories to segments. The K-Means result is evaluated with the Silhouette Score and the within-cluster sum of squares (WCSS).

---

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf, sum, when
from pyspark.sql.types import DoubleType
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.clustering import KMeans
from pyspark.ml.linalg import Vectors
from pyspark.ml.evaluation import ClusteringEvaluator
import pyspark.sql.functions as F

In [2]:
# Session settings, applied to the builder in the order listed.
spark_settings = [
    ("spark.sql.repl.eagerEval.enabled", True),
    ("spark.sql.parquet.cacheMetadata", "true"),
    ("spark.sql.session.timeZone", "Etc/UTC"),
    ("spark.sql.debug.maxToStringFields", 3000),
    ("spark.network.timeout", "300s"),
    ("spark.driver.maxResultSize", "4g"),
    ("spark.rpc.askTimeout", "300s"),
    ("spark.driver.memory", "8G"),
    ("spark.executor.memory", "8G"),
]

session_builder = SparkSession.builder.appName("Classification")
for setting_name, setting_value in spark_settings:
    session_builder = session_builder.config(setting_name, setting_value)

# Reuse a running session if one exists, otherwise start a new one.
spark = session_builder.getOrCreate()

24/09/20 01:55:30 WARN Utils: Your hostname, Cocos-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 172.16.33.67 instead (on interface en0)
24/09/20 01:55:30 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/20 01:55:30 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


# Read dataset

In [3]:
# Load the curated merchant table and the curated transaction table from parquet.
# Both paths are relative, so the notebook must be run from its own directory.
merchant_detail = spark.read.parquet('../data/curated/merchant_detail')
full_transaction = spark.read.parquet('../data/curated/full_transaction')

                                                                                

In [4]:
# Preview the first 5 merchants (long strings are truncated by show's default).
merchant_detail.show(5)

                                                                                

+--------------------+------------+-------------+---------+--------------------+
|                name|merchant_abn|revenue_level|take_rate|  processed_category|
+--------------------+------------+-------------+---------+--------------------+
|fusce aliquet lim...| 17189523131|            c|     1.59|writing paper sup...|
|    cum sociis corp.| 22528859307|            a|      6.4|shop jewelry repa...|
|morbi vehicula li...| 32413511882|            b|     3.35|vehicle new part ...|
|vel nisl incorpor...| 32897338221|            a|      6.2|souvenir card sho...|
|    egestas sed inc.| 34082818630|            a|     6.86|digital book musi...|
+--------------------+------------+-------------+---------+--------------------+
only showing top 5 rows



In [5]:
# Preview the first 5 transactions without truncating cell contents.
full_transaction.show(5, truncate=False)

+------------+------------------------------------+---------+--------------------------+-------------------+-------------------+---------------+---------------+---------------+---------------+---------------+----------------+------------+-------------------+--------------+--------------+---------------+--------------+---------------+---------------+-----------------+---------------+----------------+----------------+-------------------+------------+----------------+------------------+-------------+-------------+---------------+---------------+-------------+--------------+--------------------+--------------+-------------------+-----------------+--------------+-----------------+--------------+-------------------+--------------+---------------+---------------+-------------+--------------+-------------------+--------------+--------------+----------------+-------------+-------------+---------------+---------------+----------------+----------------+---------------+------------------+---------

In [6]:
# Number of transaction rows before any segment information is attached.
full_transaction.count()

11304463

# K-Means classification

## Create feature vectors

In [7]:
# Encode the category string of every merchant as a numeric index column.
indexer = StringIndexer(inputCol="processed_category", outputCol="category_index")
indexer_model = indexer.fit(merchant_detail)
indexed_df = indexer_model.transform(merchant_detail)

# Quick look at the new `category_index` column.
indexed_df.show(5)

+--------------------+------------+-------------+---------+--------------------+--------------+
|                name|merchant_abn|revenue_level|take_rate|  processed_category|category_index|
+--------------------+------------+-------------+---------+--------------------+--------------+
|fusce aliquet lim...| 17189523131|            c|     1.59|writing paper sup...|          15.0|
|    cum sociis corp.| 22528859307|            a|      6.4|shop jewelry repa...|          11.0|
|morbi vehicula li...| 32413511882|            b|     3.35|vehicle new part ...|          18.0|
|vel nisl incorpor...| 32897338221|            a|      6.2|souvenir card sho...|           5.0|
|    egestas sed inc.| 34082818630|            a|     6.86|digital book musi...|           0.0|
+--------------------+------------+-------------+---------+--------------------+--------------+
only showing top 5 rows



In [8]:
# `indexed_df` (merchant_detail plus the numeric `category_index` column) was
# already built by the StringIndexer in the previous cell. It is reused here
# instead of fitting a second, identical indexer on the same data.

# create feature vectors: pack the numeric column(s) into the single vector
# column named "features", which the clustering and evaluation cells read
feature_columns = ['category_index']
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
feature_data = assembler.transform(indexed_df)

## Perform K-Means clustering with 5 clusters

In [9]:
# K-Means with 5 clusters; the fixed seed keeps the initialisation repeatable.
kmeans = KMeans().setK(5).setSeed(42)

# Fit on the assembled feature vectors, then label every merchant with the
# id of its cluster (added as the `prediction` column).
model = kmeans.fit(feature_data)
predictions = model.transform(feature_data)
predictions.show(5)

24/09/20 01:55:39 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS


+--------------------+------------+-------------+---------+--------------------+--------------+--------+----------+
|                name|merchant_abn|revenue_level|take_rate|  processed_category|category_index|features|prediction|
+--------------------+------------+-------------+---------+--------------------+--------------+--------+----------+
|fusce aliquet lim...| 17189523131|            c|     1.59|writing paper sup...|          15.0|  [15.0]|         4|
|    cum sociis corp.| 22528859307|            a|      6.4|shop jewelry repa...|          11.0|  [11.0]|         0|
|morbi vehicula li...| 32413511882|            b|     3.35|vehicle new part ...|          18.0|  [18.0]|         4|
|vel nisl incorpor...| 32897338221|            a|      6.2|souvenir card sho...|           5.0|   [5.0]|         2|
|    egestas sed inc.| 34082818630|            a|     6.86|digital book musi...|           0.0|   [0.0]|         1|
+--------------------+------------+-------------+---------+-------------

## Inspect the results of K-Means

#### Check the number of instances in each cluster

In [10]:
# k as stored on the fitted model
num_clusters = model.getK()
print(f"Number of clusters (k): {num_clusters}")

# one center vector per cluster
centers = model.clusterCenters()
print("Cluster centers:")
for cluster_center in centers:
    print(cluster_center)

# number of merchants assigned to each cluster
cluster_sizes = predictions.groupBy("prediction").count()
cluster_sizes.show()

# a few merchants together with their category, category index and cluster id
preview_columns = ["name", "processed_category", "category_index", "prediction"]
predictions.select(*preview_columns).show(5)

Number of clusters (k): 5
Cluster centers:
[11.48435171]
[1.9640592]
[6.98102679]
[21.27830832]
[15.95384615]
+----------+-----+
|prediction|count|
+----------+-----+
|         1|  946|
|         3|  733|
|         4|  780|
|         2|  896|
|         0|  671|
+----------+-----+

+--------------------+--------------------+--------------+----------+
|                name|  processed_category|category_index|prediction|
+--------------------+--------------------+--------------+----------+
|fusce aliquet lim...|writing paper sup...|          15.0|         4|
|    cum sociis corp.|shop jewelry repa...|          11.0|         0|
|morbi vehicula li...|vehicle new part ...|          18.0|         4|
|vel nisl incorpor...|souvenir card sho...|           5.0|         2|
|    egestas sed inc.|digital book musi...|           0.0|         1|
+--------------------+--------------------+--------------+----------+
only showing top 5 rows



#### Obtain the category for each cluster

In [11]:
# For each cluster id, list the merchant categories that fall into it,
# most frequent category first.
for cluster_id in range(model.getK()):
    print(f"Cluster {cluster_id} categories:")
    (
        predictions
        .where(F.col("prediction") == cluster_id)
        .groupBy("processed_category")
        .count()
        .orderBy(F.col("count").desc())
        .show(truncate=False)
    )

Cluster 0 categories:
+---------------------------------------------+-----+
|processed_category                           |count|
+---------------------------------------------+-----+
|service shop sale and bicycle                |170  |
|shop jewelry repair clock and watch          |170  |
|instrument musical music shop piano sheet and|167  |
|and beauty health spa                        |164  |
+---------------------------------------------+-----+

Cluster 1 categories:
+--------------------------------------------------------------------------+-----+
|processed_category                                                        |count|
+--------------------------------------------------------------------------+-----+
|digital book music movie goods:                                           |195  |
|artist shop supply craft and                                              |193  |
|computer integrated processing design service system data and programming |191  |
|shop shoe               

#### Compute Silhouette Score

The Silhouette Score ranges from -1 to 1; the closer it is to 1, the more compact and well separated the clusters are:

In [12]:
# Evaluate the clustering with the silhouette coefficient, measured with
# squared Euclidean distance on the `features` column.
evaluator = (
    ClusteringEvaluator()
    .setFeaturesCol('features')
    .setMetricName('silhouette')
    .setDistanceMeasure('squaredEuclidean')
)

silhouette_score = evaluator.evaluate(predictions)
print(f"Silhouette Score: {silhouette_score}")

Silhouette Score: 0.7273731394415569


#### Compute WCSS

In [13]:
# get cluster centers and broadcast them so the UDF can read them by cluster id
centers = model.clusterCenters()
broadcast_centers = spark.sparkContext.broadcast(centers)


def squared_distance(features, center):
    """Return the squared Euclidean distance between two vectors as a float."""
    return float(Vectors.squared_distance(features, center))


@udf(returnType=DoubleType())
def squared_distance_udf(features, cluster):
    """Squared distance from a row's feature vector to its assigned center.

    `cluster` is the row's `prediction` value; it is used as an index into
    the broadcast list of cluster centers.
    """
    assigned_center = Vectors.dense(broadcast_centers.value[int(cluster)])
    return squared_distance(features, assigned_center)


# add distance column to the DataFrame
distance = predictions.withColumn(
    'distance',
    squared_distance_udf(col('features'), col('prediction')),
)

# compute WCSS: the total of all squared distances. A single global sum is
# enough; summing per cluster first and then summing those sums gives the
# same quantity with an extra shuffle (results may differ only by
# floating-point rounding).
wcss = distance.agg(F.sum('distance')).first()[0]
print(f"Within-Cluster Sum of Squares (WCSS): {wcss}")

Within-Cluster Sum of Squares (WCSS): 8114.604724895357


                                                                                

## Show the predictions with cluster assignments

In [14]:
# Print the first 20 rows (show's default row limit) with untruncated columns.
predictions.show(truncate=False)

+---------------------------------+------------+-------------+---------+--------------------------------------------------------------------------+--------------+--------+----------+
|name                             |merchant_abn|revenue_level|take_rate|processed_category                                                        |category_index|features|prediction|
+---------------------------------+------------+-------------+---------+--------------------------------------------------------------------------+--------------+--------+----------+
|fusce aliquet limited            |17189523131 |c            |1.59     |writing paper supply stationery office printing and                       |15.0          |[15.0]  |4         |
|cum sociis corp.                 |22528859307 |a            |6.4      |shop jewelry repair clock and watch                                       |11.0          |[11.0]  |0         |
|morbi vehicula limited           |32413511882 |b            |3.35     |vehicle new p

In [15]:
# Print every cluster center on its own line; `centers` was assigned from
# model.clusterCenters() in an earlier cell.
print("Cluster Centers:")
for center in centers:
    print(center)

Cluster Centers:
[11.48435171]
[1.9640592]
[6.98102679]
[21.27830832]
[15.95384615]


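**K-Means was given only `category_index`, the integer that `StringIndexer` assigns to each category (by default ordered by category frequency). Distances between these integers carry no business meaning, so the clusters simply group categories with neighbouring index values, even though the Silhouette Score looks high. Since this K-Means classification is not reasonable, we decided to classify the merchants manually.**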

# Manual classification

Find how many categories there are in total:

In [16]:
# Single-column view of the merchant categories. Selecting from the existing
# DataFrame avoids converting to an RDD and back.
df = merchant_detail.select("processed_category")

# Get distinct values
unique_descriptions_df = df.distinct()

# Count the number of unique values
unique_count = unique_descriptions_df.count()

# Display all unique values. show() prints only 20 rows unless told otherwise,
# so the row limit is set to the number of distinct categories.
unique_descriptions_df.show(unique_count, truncate=False)

print(f"Total number of categories: {unique_count}")

+--------------------------------------------------------------------------+
|processed_category                                                        |
+--------------------------------------------------------------------------+
|service shop sale and bicycle                                             |
|and beauty health spa                                                     |
|nursery lawn supply outlet and including garden                           |
|instrument musical music shop piano sheet and                             |
|and gallery art dealer                                                    |
|equipment furnishing manufacturer shop furniture appliance home and except|
|digital book music movie goods:                                           |
|shop restoration service repair sale and antique                          |
|computer integrated processing design service system data and programming |
|equipment peripheral computer computer and software                       |

In [17]:
# define the mapping for merchant_segment
# (processed_category string -> integer segment id)
category_to_segment = {
    # Personal Care and Repair Services
    "service shop sale and bicycle": 1,
    "shop jewelry repair clock and watch": 1,
    "and beauty health spa": 1,
    "optician good optical and eyeglass": 1,
    "shop shoe": 1,
    "vehicle new part supply and motor": 1,

    # Arts
    "artist shop supply craft and": 2,
    "instrument musical music shop piano sheet and": 2,
    "and gallery art dealer": 2,
    "and newspaper book periodical": 2,

    # Home and Furniture
    "equipment furnishing manufacturer shop furniture appliance home and except": 3,
    "equipment rent furniture al leasing and tool appliance": 3,
    "shop restoration service repair sale and antique": 3,
    "nursery lawn supply outlet and including garden": 3,
    "and shop awning tent": 3,

    # Gifts and Souvenirs
    "souvenir card shop and novelty gift": 4,
    "nursery florist flower supply stock and": 4,
    "silverware shop jewelry clock and watch": 4,
    "writing paper supply stationery office printing and": 4,
    "game shop hobby toy and": 4,

    # Technology and Electronic Equipment
    "radio other service satellite cable television and pay": 5,
    "digital book music movie goods:": 5,
    "computer integrated processing design service system data and programming": 5,
    "equipment peripheral computer computer and software": 5,
    "telecom": 5
}

# One `when` per known category. A `when` without `otherwise` is NULL when its
# condition is false, so coalescing them yields the matching segment id, or
# NULL for a category that is not in the mapping. No extra isin() guard is
# needed.
segment_expr = F.coalesce(*[
    F.when(F.col("processed_category") == category, F.lit(segment_id))
    for category, segment_id in category_to_segment.items()
])
merchant_with_segment = merchant_detail.withColumn("merchant_segment", segment_expr)

# Fail loudly if any merchant has a category missing from the mapping, rather
# than letting NULL segments flow into the join and one-hot encoding below.
unmapped_categories = [
    row["processed_category"]
    for row in merchant_with_segment
    .filter(F.col("merchant_segment").isNull())
    .select("processed_category")
    .distinct()
    .collect()
]
if unmapped_categories:
    raise ValueError(
        f"No merchant_segment defined for categories: {unmapped_categories}"
    )

# add a new column merchant_segment based on processed_category
merchant_detail = merchant_with_segment

merchant_detail.show(5, truncate=False)

+----------------------+------------+-------------+---------+---------------------------------------------------+----------------+
|name                  |merchant_abn|revenue_level|take_rate|processed_category                                 |merchant_segment|
+----------------------+------------+-------------+---------+---------------------------------------------------+----------------+
|fusce aliquet limited |17189523131 |c            |1.59     |writing paper supply stationery office printing and|4               |
|cum sociis corp.      |22528859307 |a            |6.4      |shop jewelry repair clock and watch                |1               |
|morbi vehicula limited|32413511882 |b            |3.35     |vehicle new part supply and motor                  |1               |
|vel nisl incorporated |32897338221 |a            |6.2      |souvenir card shop and novelty gift                |4               |
|egestas sed inc.      |34082818630 |a            |6.86     |digital book music mov

Count the number of merchants in each sub-category of every segment:

In [18]:
# Human-readable name for each manual segment id.
segment_names = {
    1: "Personal Care and Repair Services",
    2: "Arts",
    3: "Home and Furniture",
    4: "Gifts and Souvenirs",
    5: "Technology and Electronic Equipment",
}

# For every segment, list its categories with merchant counts, largest first.
for cluster_id, segment in segment_names.items():
    print(f"Segment {cluster_id} - {segment}:")
    (
        merchant_detail
        .where(F.col("merchant_segment") == cluster_id)
        .groupBy("processed_category")
        .count()
        .orderBy(F.col("count").desc())
        .show(truncate=False)
    )

Segment 1 - Personal Care and Repair Services:
+-----------------------------------+-----+
|processed_category                 |count|
+-----------------------------------+-----+
|shop shoe                          |185  |
|service shop sale and bicycle      |170  |
|shop jewelry repair clock and watch|170  |
|and beauty health spa              |164  |
|optician good optical and eyeglass |151  |
|vehicle new part supply and motor  |151  |
+-----------------------------------+-----+

Segment 2 - Arts:
+---------------------------------------------+-----+
|processed_category                           |count|
+---------------------------------------------+-----+
|artist shop supply craft and                 |193  |
|instrument musical music shop piano sheet and|167  |
|and newspaper book periodical                |164  |
|and gallery art dealer                       |112  |
+---------------------------------------------+-----+

Segment 3 - Home and Furniture:
+----------------------------

# Merge the classification results to `full_transaction` dataset

Attach the manual `merchant_segment` from `merchant_detail` to `full_transaction` with a left join on the `merchant_abn` column:

In [19]:
# Only the join key and the segment label are needed from the merchant table.
segment_lookup = merchant_detail.select("merchant_abn", "merchant_segment")

# Left join: every transaction row is kept, with its merchant's segment attached.
full_transaction = full_transaction.join(segment_lookup, on="merchant_abn", how="left")
full_transaction.show(5, truncate=False)

+------------+------------------------------------+---------+--------------------------+-------------------+-------------------+---------------+---------------+---------------+---------------+---------------+----------------+------------+-------------------+--------------+--------------+---------------+--------------+---------------+---------------+-----------------+---------------+----------------+----------------+-------------------+------------+----------------+------------------+-------------+-------------+---------------+---------------+-------------+--------------+--------------------+--------------+-------------------+-----------------+--------------+-----------------+--------------+-------------------+--------------+---------------+---------------+-------------+--------------+-------------------+--------------+--------------+----------------+-------------+-------------+---------------+---------------+----------------+----------------+---------------+------------------+---------

# Apply One-Hot Encoding on `merchant_segment` column

In [20]:
# Distinct segment ids present in the transactions.
# - NULL (a transaction without a segment) is skipped: it would otherwise
#   create a meaningless all-zero `merchant_segment_None` column.
# - The ids are sorted because distinct().collect() has no guaranteed order,
#   and the order here decides the order of the generated columns.
segments = sorted(
    row["merchant_segment"]
    for row in full_transaction.select("merchant_segment").distinct().collect()
    if row["merchant_segment"] is not None
)

# One 0/1 indicator column per segment id.
for segment in segments:
    full_transaction = full_transaction.withColumn(
        f"merchant_segment_{segment}",
        when(col("merchant_segment") == segment, 1).otherwise(0),
    )

# The original label column is redundant once the indicators exist.
full_transaction = full_transaction.drop("merchant_segment")
full_transaction.show(5)

+------------+--------------------+---------+--------------------------+-------------------+-------------------+---------------+---------------+---------------+---------------+---------------+----------------+------------+-------------------+--------------+--------------+---------------+--------------+---------------+---------------+-----------------+---------------+----------------+----------------+-------------------+------------+----------------+------------------+-------------+-------------+---------------+---------------+-------------+--------------+--------------------+--------------+-------------------+-----------------+--------------+-----------------+--------------+-------------------+--------------+---------------+---------------+-------------+--------------+-------------------+--------------+--------------+----------------+-------------+-------------+---------------+---------------+----------------+----------------+---------------+------------------+----------------+--------

Check the shape of the dataset:

In [21]:
# Shape of the final table: the row count runs a Spark job, the column count
# is read from the schema.
num_rows = full_transaction.count()
num_columns = len(full_transaction.columns)

print(f"Number of rows: {num_rows}")
print(f"Number of columns: {num_columns}")

Number of rows: 11304463
Number of columns: 132


In [22]:
# List every column of the final table with its type and nullability.
full_transaction.printSchema()

root
 |-- merchant_abn: long (nullable = true)
 |-- order_id: string (nullable = true)
 |-- take_rate: float (nullable = true)
 |-- merchant_fraud_probability: double (nullable = true)
 |-- transaction_revenue: double (nullable = true)
 |-- BNPL_revenue: double (nullable = true)
 |-- revenue_level_e: integer (nullable = true)
 |-- revenue_level_d: integer (nullable = true)
 |-- revenue_level_c: integer (nullable = true)
 |-- revenue_level_b: integer (nullable = true)
 |-- revenue_level_a: integer (nullable = true)
 |-- category_jewelry: integer (nullable = true)
 |-- category_art: integer (nullable = true)
 |-- category_television: integer (nullable = true)
 |-- category_watch: integer (nullable = true)
 |-- category_cable: integer (nullable = true)
 |-- category_repair: integer (nullable = true)
 |-- category_stock: integer (nullable = true)
 |-- category_flower: integer (nullable = true)
 |-- category_office: integer (nullable = true)
 |-- category_souvenir: integer (nullable = true)

Confirm there are no nulls:

In [23]:
# One aggregate expression per column: isNull() is cast to 0/1 and summed,
# giving the number of nulls in that column under the column's own name.
null_count_exprs = [
    F.sum(F.col(column_name).isNull().cast("int")).alias(column_name)
    for column_name in full_transaction.columns
]

# A single agg() pass computes all null counts as one wide row.
null_counts_df = full_transaction.agg(*null_count_exprs)
null_counts_df.show()



+------------+--------+---------+--------------------------+-------------------+------------+---------------+---------------+---------------+---------------+---------------+----------------+------------+-------------------+--------------+--------------+---------------+--------------+---------------+---------------+-----------------+---------------+----------------+----------------+-------------------+------------+----------------+------------------+-------------+-------------+---------------+---------------+-------------+--------------+--------------------+--------------+-------------------+-----------------+--------------+-----------------+--------------+-------------------+--------------+---------------+---------------+-------------+--------------+-------------------+--------------+--------------+----------------+-------------+-------------+---------------+---------------+----------------+----------------+---------------+------------------+----------------+-------------------+-------

                                                                                

In [24]:
# Write the final table as parquet, replacing any existing output at this path.
(
    full_transaction.write
    .mode('overwrite')
    .parquet('../data/curated/full_transaction_with_segments')
)

                                                                                