In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, regexp_replace, trim, split, min, max, udf
)
from pyspark.sql.types import StringType, DoubleType
from pyspark.ml.feature import (
    StringIndexer, VectorAssembler
)
from pyspark.ml.clustering import KMeans
from pyspark.ml.linalg import Vectors
from pyspark.ml.evaluation import ClusteringEvaluator
import pandas as pd
import geopandas as gpd
import nltk
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet


In [2]:
# Spark settings for this notebook, applied one by one to the session builder.
spark_settings = [
    ("spark.sql.repl.eagerEval.enabled", True),
    ("spark.driver.memory", "4G"),
    ("spark.executor.memory", "4G"),
    ("spark.sql.parquet.cacheMetadata", "true"),
    ("spark.sql.session.timeZone", "Etc/UTC"),
]

session_builder = SparkSession.builder.appName("Preliminary_Analysis")
for setting_name, setting_value in spark_settings:
    session_builder = session_builder.config(setting_name, setting_value)

# Returns the already-running session if there is one, otherwise starts a new one.
spark = session_builder.getOrCreate()

your 131072x1 screen size is bogus. expect trouble
24/09/17 13:47:03 WARN Utils: Your hostname, LAPTOP-AVQHG9I2 resolves to a loopback address: 127.0.1.1; using 172.21.158.129 instead (on interface eth0)
24/09/17 13:47:03 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/17 13:47:04 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Read the curated merchant detail dataset (Parquet) into a Spark DataFrame
merchant_detail_path = '../data/curated/merchant_detail'
merchant_detail = spark.read.parquet(merchant_detail_path)

                                                                                

In [4]:
# Quick look at the first five rows to confirm the schema loaded as expected
merchant_detail.show(n=5)

                                                                                

+--------------------+------------+-------------+---------+--------------------+
|                name|merchant_abn|revenue_level|take_rate|  processed_category|
+--------------------+------------+-------------+---------+--------------------+
|fusce aliquet lim...| 17189523131|            c|     1.59|writing paper sup...|
|    cum sociis corp.| 22528859307|            a|      6.4|shop jewelry repa...|
|morbi vehicula li...| 32413511882|            b|     3.35|vehicle new part ...|
|vel nisl incorpor...| 32897338221|            a|      6.2|souvenir card sho...|
|    egestas sed inc.| 34082818630|            a|     6.86|digital book musi...|
+--------------------+------------+-------------+---------+--------------------+
only showing top 5 rows



In [5]:
# Map every distinct processed_category string to a numeric index
indexer = StringIndexer(
    inputCol="processed_category",
    outputCol="category_index",
)
indexer_model = indexer.fit(merchant_detail)
indexed_df = indexer_model.transform(merchant_detail)

# Preview the first rows with the new category_index column
indexed_df.show(n=5)


                                                                                

+--------------------+------------+-------------+---------+--------------------+--------------+
|                name|merchant_abn|revenue_level|take_rate|  processed_category|category_index|
+--------------------+------------+-------------+---------+--------------------+--------------+
|fusce aliquet lim...| 17189523131|            c|     1.59|writing paper sup...|          15.0|
|    cum sociis corp.| 22528859307|            a|      6.4|shop jewelry repa...|          11.0|
|morbi vehicula li...| 32413511882|            b|     3.35|vehicle new part ...|          18.0|
|vel nisl incorpor...| 32897338221|            a|      6.2|souvenir card sho...|           5.0|
|    egestas sed inc.| 34082818630|            a|     6.86|digital book musi...|           0.0|
+--------------------+------------+-------------+---------+--------------------+--------------+
only showing top 5 rows



In [6]:
# `indexed_df` (merchant_detail plus `category_index`) comes from the
# StringIndexer cell above; it is reused here rather than re-fitting an
# identical indexer on the same data.

# Assemble the model input: KMeans reads its features from one vector column.
# NOTE(review): `category_index` is only a label id assigned by StringIndexer
# (by default ordered by label frequency), so the distance between two indices
# carries no real meaning. Consider one-hot encoding or richer features before
# relying on these clusters.
feature_columns = ['category_index']
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
feature_data = assembler.transform(indexed_df)

# Fit KMeans with 5 clusters; the fixed seed keeps the initialisation reproducible
kmeans = KMeans(k=5, seed=42, featuresCol="features")
model = kmeans.fit(feature_data)

# Adds a `prediction` column holding the cluster id assigned to each row
predictions = model.transform(feature_data)

predictions.show(5)


24/09/17 13:47:30 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS


+--------------------+------------+-------------+---------+--------------------+--------------+--------+----------+
|                name|merchant_abn|revenue_level|take_rate|  processed_category|category_index|features|prediction|
+--------------------+------------+-------------+---------+--------------------+--------------+--------+----------+
|fusce aliquet lim...| 17189523131|            c|     1.59|writing paper sup...|          15.0|  [15.0]|         4|
|    cum sociis corp.| 22528859307|            a|      6.4|shop jewelry repa...|          11.0|  [11.0]|         0|
|morbi vehicula li...| 32413511882|            b|     3.35|vehicle new part ...|          18.0|  [18.0]|         4|
|vel nisl incorpor...| 32897338221|            a|      6.2|souvenir card sho...|           5.0|   [5.0]|         2|
|    egestas sed inc.| 34082818630|            a|     6.86|digital book musi...|           0.0|   [0.0]|         1|
+--------------------+------------+-------------+---------+-------------

In [7]:
# Report how many clusters the fitted model was configured with
print(f"Number of clusters (k): {model.getK()}")

# List the coordinates of every cluster centre
centers = model.clusterCenters()
print("Cluster centers:")
for center_vector in centers:
    print(center_vector)

# Count how many rows were assigned to each cluster
cluster_sizes = predictions.groupBy("prediction").count()
cluster_sizes.show()

# Preview the cluster assigned to each row alongside its category
preview_columns = ["name", "processed_category", "category_index", "prediction"]
predictions.select(*preview_columns).show(n=5)


Number of clusters (k): 5
Cluster centers:
[11.48435171]
[1.9640592]
[6.98102679]
[21.27830832]
[15.95384615]
+----------+-----+
|prediction|count|
+----------+-----+
|         1|  946|
|         3|  733|
|         4|  780|
|         2|  896|
|         0|  671|
+----------+-----+

+--------------------+--------------------+--------------+----------+
|                name|  processed_category|category_index|prediction|
+--------------------+--------------------+--------------+----------+
|fusce aliquet lim...|writing paper sup...|          15.0|         4|
|    cum sociis corp.|shop jewelry repa...|          11.0|         0|
|morbi vehicula li...|vehicle new part ...|          18.0|         4|
|vel nisl incorpor...|souvenir card sho...|           5.0|         2|
|    egestas sed inc.|digital book musi...|           0.0|         1|
+--------------------+--------------------+--------------+----------+
only showing top 5 rows



In [8]:
# For every cluster, list the categories it contains, most frequent first
for cluster_id in range(model.getK()):
    print(f"Cluster {cluster_id} categories:")
    cluster_samples = predictions.where(col("prediction") == cluster_id)
    category_counts = (
        cluster_samples
        .groupBy("processed_category")
        .count()
        .orderBy(col("count").desc())
    )
    category_counts.show()


Cluster 0 categories:
+--------------------+-----+
|  processed_category|count|
+--------------------+-----+
|service shop sale...|  170|
|shop jewelry repa...|  170|
|instrument musica...|  167|
|and beauty health...|  164|
+--------------------+-----+

Cluster 1 categories:
+--------------------+-----+
|  processed_category|count|
+--------------------+-----+
|digital book musi...|  195|
|artist shop suppl...|  193|
|computer integrat...|  191|
|           shop shoe|  185|
|equipment furnish...|  182|
+--------------------+-----+

Cluster 2 categories:
+--------------------+-----+
|  processed_category|count|
+--------------------+-----+
|souvenir card sho...|  182|
|equipment periphe...|  181|
|nursery florist f...|  180|
|and shop awning tent|  178|
|radio other servi...|  175|
+--------------------+-----+

Cluster 3 categories:
+--------------------+-----+
|  processed_category|count|
+--------------------+-----+
|game shop hobby t...|  142|
|equipment rent fu...|  134|
|shop rest

In [9]:
# Evaluate cluster quality with the silhouette coefficient
# (ranges from -1 to 1; values closer to 1 indicate better-separated clusters)
evaluator = ClusteringEvaluator(
    featuresCol='features',
    metricName='silhouette',
    distanceMeasure='squaredEuclidean',
)

silhouette_score = evaluator.evaluate(predictions)
print(f"Silhouette Score: {silhouette_score}")

Silhouette Score: 0.7273731394415569


In [10]:
# Fetch the fitted cluster centres and broadcast them so the UDF below can
# look them up by cluster id
centers = model.clusterCenters()
broadcast_centers = spark.sparkContext.broadcast(centers)


def squared_distance(features, center):
    """Return the squared Euclidean distance between two vectors as a float."""
    return float(Vectors.squared_distance(features, center))


def _distance_to_assigned_center(features, cluster):
    """Squared distance from `features` to the broadcast centre with id `cluster`."""
    assigned_center = Vectors.dense(broadcast_centers.value[int(cluster)])
    return squared_distance(features, assigned_center)


# Column-level wrapper: takes (features, prediction) and yields a double
squared_distance_udf = udf(_distance_to_assigned_center, DoubleType())

In [11]:
# Attach, for every row, the squared distance to the centre of its assigned cluster
distance = predictions.withColumn(
    'distance', squared_distance_udf(col('features'), col('prediction'))
)

# WCSS is the total of those squared distances over all rows. One global
# aggregation gives that total directly; summing per cluster first and then
# summing the per-cluster totals only adds an extra shuffle.
# NOTE(review): `model.summary.trainingCost` should report the same quantity
# for the training data and could serve as a cross-check. Confirm it is
# available in the Spark version in use.
wcss = distance.agg({'distance': 'sum'}).first()[0]
print(f"Within-Cluster Sum of Squares (WCSS): {wcss}")

# show the predictions with cluster assignments (cell contents not truncated)
predictions.show(truncate=False)

print("Cluster Centers:")
for center in centers:
    print(center)

[Stage 54:>                                                         (0 + 1) / 1]

Within-Cluster Sum of Squares (WCSS): 8114.604724895357
+---------------------------------+------------+-------------+---------+--------------------------------------------------------------------------+--------------+--------+----------+
|name                             |merchant_abn|revenue_level|take_rate|processed_category                                                        |category_index|features|prediction|
+---------------------------------+------------+-------------+---------+--------------------------------------------------------------------------+--------------+--------+----------+
|fusce aliquet limited            |17189523131 |c            |1.59     |writing paper supply stationery office printing and                       |15.0          |[15.0]  |4         |
|cum sociis corp.                 |22528859307 |a            |6.4      |shop jewelry repair clock and watch                                       |11.0          |[11.0]  |0         |
|morbi vehicula limited      

                                                                                