In [3]:
# Imports & Spark setup
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import KMeans
from pyspark.sql.types import FloatType, StringType
from pyspark.sql import Row
from pyspark.sql.functions import broadcast, udf
import numpy as np
from tools import setup_spark_config, read_parquet_files

sc, spark = setup_spark_config("Clustering Million Song Dataset")

In [4]:
# read songs data from parquet files
basedir = 'parsed-MillionSongSubset'
songs_df = read_parquet_files(basedir, spark)

Reading songs from parquet files to DataFrame


In [5]:
songs_df.show()

+--------+------------+----+------------------+--------------+---+-------+
|loudness|song_hotness|year|artist_familiarity|artist_hotness|key|  tempo|
+--------+------------+----+------------------+--------------+---+-------+
|  -9.636|  0.54795295|2008|        0.55746025|    0.38615164|  0|124.059|
| -11.061|  0.47563848|2004|         0.6269577|     0.4348596|  1| 80.084|
|  -4.264|   0.7883882|1982|        0.73703754|     0.5392454| 10| 92.897|
|  -4.707|    0.681092|2004|         0.8218443|     0.5924395|  0|157.715|
|  -4.523|  0.40148672|2005|        0.49579692|    0.38949883|  0|146.331|
|  -4.076|   0.6878737|2004|        0.73343325|     0.4555588|  0| 84.992|
|  -3.312|  0.35528553|2001|        0.48433375|     0.3359355|  1| 99.959|
| -25.651|  0.21508032|1982|         0.5772761|    0.37693998|  1|104.989|
|  -6.052|  0.87222904|2000|         0.8873861|      0.791143|  4|105.095|
| -15.433|   0.5968407|1981|         0.6559214|     0.5783016|  5|100.042|
|  -4.325|   0.6248335|20

In [6]:
# transform grouped (by year) data to vector to use for clustering
input_cols = ["loudness", \
              "song_hotness", \
              "artist_familiarity", \
              "artist_hotness", \
              "key", \
              "tempo"]
vecAssembler = VectorAssembler(inputCols=input_cols, \
                               outputCol="features")
vec_df = vecAssembler.transform(songs_df)

In [7]:
vec_df.show()

+--------+------------+----+------------------+--------------+---+-------+--------------------+
|loudness|song_hotness|year|artist_familiarity|artist_hotness|key|  tempo|            features|
+--------+------------+----+------------------+--------------+---+-------+--------------------+
|  -9.636|  0.54795295|2008|        0.55746025|    0.38615164|  0|124.059|[-9.6359996795654...|
| -11.061|  0.47563848|2004|         0.6269577|     0.4348596|  1| 80.084|[-11.060999870300...|
|  -4.264|   0.7883882|1982|        0.73703754|     0.5392454| 10| 92.897|[-4.2639999389648...|
|  -4.707|    0.681092|2004|         0.8218443|     0.5924395|  0|157.715|[-4.7069997787475...|
|  -4.523|  0.40148672|2005|        0.49579692|    0.38949883|  0|146.331|[-4.5229997634887...|
|  -4.076|   0.6878737|2004|        0.73343325|     0.4555588|  0| 84.992|[-4.0760002136230...|
|  -3.312|  0.35528553|2001|        0.48433375|     0.3359355|  1| 99.959|[-3.3120000362396...|
| -25.651|  0.21508032|1982|         0.5

In [8]:
# fit a KMeans model to the vector transform of the grouped (by year) data
kmeans = KMeans(k=len(input_cols), seed=1)
model = kmeans.fit(vec_df.select('features'))

In [9]:
# cluster the vector transform of the grouped (by year) data
transformed_df = model.transform(vec_df)

In [10]:
transformed_df.show()

+--------+------------+----+------------------+--------------+---+-------+--------------------+----------+
|loudness|song_hotness|year|artist_familiarity|artist_hotness|key|  tempo|            features|prediction|
+--------+------------+----+------------------+--------------+---+-------+--------------------+----------+
|  -9.636|  0.54795295|2008|        0.55746025|    0.38615164|  0|124.059|[-9.6359996795654...|         1|
| -11.061|  0.47563848|2004|         0.6269577|     0.4348596|  1| 80.084|[-11.060999870300...|         5|
|  -4.264|   0.7883882|1982|        0.73703754|     0.5392454| 10| 92.897|[-4.2639999389648...|         0|
|  -4.707|    0.681092|2004|         0.8218443|     0.5924395|  0|157.715|[-4.7069997787475...|         4|
|  -4.523|  0.40148672|2005|        0.49579692|    0.38949883|  0|146.331|[-4.5229997634887...|         4|
|  -4.076|   0.6878737|2004|        0.73343325|     0.4555588|  0| 84.992|[-4.0760002136230...|         0|
|  -3.312|  0.35528553|2001|        0

In [11]:
# create dataframe for each centroid
centroids = model.clusterCenters()
centroids = np.array(centroids).T.tolist()
centroids.append([i for i in range(len(input_cols))])

R = Row("loudness", \
        "song_hotness", \
        "artist_familiarity", \
        "artist_hotness", \
        "key", \
        "tempo", \
        "centroid")
centroids_df = sc.parallelize([R(*r) for r in zip(*centroids)]).toDF()

In [12]:
centroids_df.show()

+-------------------+-------------------+------------------+------------------+------------------+------------------+--------+
|           loudness|       song_hotness|artist_familiarity|    artist_hotness|               key|             tempo|centroid|
+-------------------+-------------------+------------------+------------------+------------------+------------------+--------+
| -9.601248617988924|0.43143110260449724|0.6319199618055017|0.4415737438761727|5.3535911602209945| 95.64976029369713|       0|
| -9.210976199227936| 0.4506732383383915| 0.640198458784393|0.4465091046866646| 5.380952380952381|122.39838436239161|       1|
|  -8.62233099131517| 0.4758835130594146|0.6536595206445371|0.4462047878285529| 4.711267605633803|208.00740814208984|       2|
| -8.520272992123132| 0.4563011870901475| 0.651254778278285|0.4543562534896807| 4.847701149425287|173.10547703710097|       3|
| -8.700728515549413| 0.4588742870806039|0.6371752564712077|0.4425835703290735|   5.3929173693086|147.182127902

In [17]:
# add fictional genre to each centroid
genres = ["hot&loud", "plain", "mellow&soft", "mainstream", "temp1", "temp2"]

def add_genre(centroid):
    print(centroid)
    return genres[int(centroid)]

udf_add_genre = udf(add_genre, StringType())
genres_df = centroids_df.withColumn("genre", udf_add_genre("centroid")).select("centroid", "genre")

In [18]:
genres_df.show()

+--------+-----------+
|centroid|      genre|
+--------+-----------+
|       0|   hot&loud|
|       1|      plain|
|       2|mellow&soft|
|       3| mainstream|
|       4|      temp1|
|       5|      temp2|
+--------+-----------+



In [19]:
# add centroid genre to songs
song_genres_df = transformed_df.join(broadcast(genres_df), transformed_df.prediction == genres_df.centroid) \
    .select("loudness", \
            "song_hotness", \
            "artist_familiarity", \
            "artist_hotness", \
            "key", \
            "tempo", \
            "genre")

In [20]:
song_genres_df.show()

+--------+------------+------------------+--------------+---+-------+----------+
|loudness|song_hotness|artist_familiarity|artist_hotness|key|  tempo|     genre|
+--------+------------+------------------+--------------+---+-------+----------+
|  -9.636|  0.54795295|        0.55746025|    0.38615164|  0|124.059|     plain|
| -11.061|  0.47563848|         0.6269577|     0.4348596|  1| 80.084|     temp2|
|  -4.264|   0.7883882|        0.73703754|     0.5392454| 10| 92.897|  hot&loud|
|  -4.707|    0.681092|         0.8218443|     0.5924395|  0|157.715|     temp1|
|  -4.523|  0.40148672|        0.49579692|    0.38949883|  0|146.331|     temp1|
|  -4.076|   0.6878737|        0.73343325|     0.4555588|  0| 84.992|  hot&loud|
|  -3.312|  0.35528553|        0.48433375|     0.3359355|  1| 99.959|  hot&loud|
| -25.651|  0.21508032|         0.5772761|    0.37693998|  1|104.989|  hot&loud|
|  -6.052|  0.87222904|         0.8873861|      0.791143|  4|105.095|  hot&loud|
| -15.433|   0.5968407|     

In [21]:
song_genres_df.createOrReplaceTempView("songs_with_genres")

In [22]:
n_mainstream_songs = spark.sql("SELECT COUNT(*) FROM songs_with_genres WHERE genre = \"mainstream\"").collect()[0][0]
n_hotnloud_songs = spark.sql("SELECT COUNT(*) FROM songs_with_genres WHERE genre = \"hot&loud\"").collect()[0][0]
n_mellownsoft_songs = spark.sql("SELECT COUNT(*) FROM songs_with_genres WHERE genre = \"mellow&soft\"").collect()[0][0]
n_plain_songs = spark.sql("SELECT COUNT(*) FROM songs_with_genres WHERE genre = \"plain\"").collect()[0][0]

print("There are %d mainstream songs in the dataset" % (n_mainstream_songs))
print("There are %d hot&loud songs in the dataset" % (n_hotnloud_songs))
print("There are %d mellow&soft songs in the dataset" % (n_mellownsoft_songs))
print("There are %d plain songs in the dataset" % (n_plain_songs))

There are 354 mainstream songs in the dataset
There are 907 hot&loud songs in the dataset
There are 142 mellow&soft songs in the dataset
There are 876 plain songs in the dataset
