In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/10/05 02:15:41 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/10/05 02:15:41 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
24/10/05 02:15:41 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


### Reading user-artist interaction data

In [2]:
user_artist_data_path = "data/audioscrobbler/user_artist_data.txt"
raw_user_artist_data = spark.read.text(user_artist_data_path)

In [3]:
raw_user_artist_data.show(3) # user ID, artist ID, and play count

+------------------+
|             value|
+------------------+
|      1000002 1 55|
|1000002 1000006 33|
| 1000002 1000007 8|
+------------------+
only showing top 3 rows



`raw_user_artist_data` is like a list of strings. Not very useful. We parse it into a DataFrame with 3 integer-type columns.

In [4]:
raw_user_artist_data.createOrReplaceTempView("raw_user_artist_data")

query = """
    select 
        cast(split(value, ' ')[0] as int) as user,
        cast(split(value, ' ')[1] as int) as artist,
        cast(split(value, ' ')[2] as int) as count
    from raw_user_artist_data
"""
user_artist_data = spark.sql(query)

In [5]:
user_artist_data.show(3)

+-------+-------+-----+
|   user| artist|count|
+-------+-------+-----+
|1000002|      1|   55|
|1000002|1000006|   33|
|1000002|1000007|    8|
+-------+-------+-----+
only showing top 3 rows



### Reading artist aliases

In [6]:
artist_alias_path = "data/audioscrobbler/artist_alias.txt"
raw_artist_alias = spark.read.text(artist_alias_path)

raw_artist_alias.createOrReplaceTempView("raw_artist_alias")

query = """
    select 
        cast(split(value, '\t')[0] as int) as artist,
        cast(split(value, '\t')[1] as int) as alias
    from raw_artist_alias
"""
artist_alias = spark.sql(query)

In [7]:
artist_alias.show(3)

+-------+-------+
| artist|  alias|
+-------+-------+
|1092764|1000311|
|1095122|1000557|
|6708070|1007267|
+-------+-------+
only showing top 3 rows



### Prepare the training and testing data

Using `artist_alias` to replace artist IDs by their appropriate alias.

In [8]:
from pyspark.sql.functions import when, col

user_artist_data.createOrReplaceTempView("user_artist_data")
artist_alias.createOrReplaceTempView("artist_alias")

df = spark.sql("""
    select * 
    from user_artist_data 
        natural left join artist_alias;
""") 

df = df.withColumn(
    "artist", # column name
    when(col("alias").isNull(), col("artist")).otherwise(col("alias")) # when(condition, value).otherwise(value)
).drop("alias")

In [9]:
df = df.sample(0.02) # only taking 2% of the data because I'm poor
train, test = df.randomSplit([8.0, 2.0])

### Training the model

In [10]:
train.cache()

DataFrame[artist: int, user: int, count: int]

In [11]:
from pyspark.ml.recommendation import ALS

als = ALS(userCol='user', itemCol='artist', ratingCol='count')
model = als.fit(train)

24/10/05 02:15:52 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
24/10/05 02:15:59 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
24/10/05 02:15:59 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS
24/10/05 02:16:01 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK
                                                                                

In [12]:
predictions = model.transform(test)
predictions.na.drop().show(5)

                                                                                

+------+-------+-----+-----------+
|artist|   user|count| prediction|
+------+-------+-----+-----------+
|   721|1045486|    1|-0.30049655|
|   786|1042553|    1| -1.5385172|
|  3379|1004666|    2| -22.581377|
|  4149|1001129|    2| -12.545612|
|  4468|1049740|   42| -20.852589|
+------+-------+-----+-----------+
only showing top 5 rows



In [13]:
from pyspark.ml.evaluation import RegressionEvaluator

evaluator = RegressionEvaluator(labelCol="count")
evaluator.evaluate(predictions.na.drop())

                                                                                

122.6337927898658

### Hyperparameter tuning

In [14]:
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

param_grid = ParamGridBuilder() \
    .addGrid(als.rank, [1, 5, 10]) \
    .addGrid(als.maxIter, [5, 10]) \
    .addGrid(als.regParam, [0.05, 0.1]) \
    .build()

cv = CrossValidator(estimator=als, estimatorParamMaps=param_grid, evaluator=evaluator)

cv_model = cv.fit(train)
predictions = cv_model.transform(test)
evaluator.evaluate(predictions.na.drop())

                                                                                

387.8318339296248