In [1]:
!pip install pyspark

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import pandas as pd
from pyspark.sql.functions import col, explode


In [3]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [4]:
from pyspark.sql import SparkSession
#spark.stop()
spark = SparkSession \
.builder \
.config("spark.executor.instances", "1")\
.config('spark.driver.memory','1g')\
.config('spark.executor.memory', '1g') \
.getOrCreate()

In [5]:
df = spark.read.csv('/content/gdrive/MyDrive/musicrecom/hetrec2011-lastfm-2k/user_artists.dat', sep='\t', inferSchema=True, header=True, nullValue='NA', nanValue='NA',emptyValue='NA').dropna()


In [6]:
from pyspark.sql.functions import countDistinct
noUsers = df.select(countDistinct("userID"))
noUsers.show()

+----------------------+
|count(DISTINCT userID)|
+----------------------+
|                  1892|
+----------------------+



In [7]:
df = df.withColumn('userID', col('userID').cast('integer')).\
    withColumn('artistID', col('artistID').cast('integer')).\
    withColumn('weight', col('weight').cast('integer'))


In [8]:
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

In [9]:
(train, test) = df.randomSplit([0.8, 0.2], seed = 1234)
als = ALS(userCol="userID", itemCol="artistID", ratingCol="weight", nonnegative = True, implicitPrefs = False, coldStartStrategy="drop")

In [10]:
param_grid = ParamGridBuilder() \
            .addGrid(als.rank, [50, 100,]) \
            .addGrid(als.regParam, [.1, .15]) \
            .build()
            # .addGrid(als.maxIter, [5, 50, 100, 200]) \

evaluator = RegressionEvaluator(metricName="rmse", labelCol="weight", predictionCol="prediction") 
print ("Num models to be tested: ", len(param_grid))

Num models to be tested:  4


In [11]:
cv = CrossValidator(estimator=als, estimatorParamMaps=param_grid, evaluator=evaluator, numFolds=5)


In [12]:
model = cv.fit(train)
best_model = model.bestModel
best_model

ALSModel: uid=ALS_221447f39b23, rank=100

In [15]:
best_model.write().overwrite().save('/content/gdrive/MyDrive/musicrecom/als')

In [16]:
test_predictions = best_model.transform(test)
?

5597.569265835045


In [17]:
test_predictions.show()


+------+--------+------+----------+
|userID|artistID|weight|prediction|
+------+--------+------+----------+
|   148|     619|   227|  283.2136|
|   148|     632|   195| 350.75766|
|   148|    1118|   214| 417.04828|
|   148|    1409|   282|   138.543|
|   148|    1986|   202| 186.01276|
|   148|    3341|   280|  68.48562|
|   148|    3352|   212|  79.40968|
|   148|    3354|   203| 164.13243|
|   148|    3357|   192| 157.21304|
|   463|      46|    10| 3.2048187|
|   463|     389|    16| 62.228577|
|   463|     726|    14| 36.672028|
|   463|    7159|    26|  4.746778|
|   463|    7170|    13| 16.858574|
|   471|     212|   352|  108.3551|
|   471|     227|  1901| 1814.5715|
|   471|     267|    86| 24.974503|
|   471|     718|   117| 56.381977|
|   471|    1254|    65| 40.655174|
|   471|    1369|   133| 91.567085|
+------+--------+------+----------+
only showing top 20 rows



In [18]:
nrecommendations = best_model.recommendForAllUsers(10)
nrecommendations.limit(10).show()

+------+--------------------+
|userID|     recommendations|
+------+--------------------+
|     3|[{67, 36321.96}, ...|
|     5|[{687, 4289.433},...|
|     6|[{511, 1133.7229}...|
|     9|[{1672, 17271.17}...|
|    12|[{1672, 387618.12...|
|    13|[{2309, 28248.559...|
|    15|[{2044, 4506.3647...|
|    16|[{4271, 6463.8247...|
|    17|[{294, 46319.355}...|
|    20|[{203, 23886.137}...|
+------+--------------------+



In [20]:
nrecommendations = nrecommendations\
    .withColumn("rec_exp", explode("recommendations"))\
    .select('userId', col("rec_exp.artistID"), col("rec_exp.rating"))

nrecommendations.limit(10).show()

+------+--------+---------+
|userId|artistID|   rating|
+------+--------+---------+
|     3|      67| 36321.96|
|     3|    3478| 29221.01|
|     3|     701|29174.752|
|     3|   14986|25873.701|
|     3|     744| 25219.01|
|     3|    2044|24737.242|
|     3|     289|23932.045|
|     3|   14987|23744.113|
|     3|     207|23654.018|
|     3|     154|23618.371|
+------+--------+---------+



In [23]:
artists = spark.read.csv('/content/gdrive/MyDrive/musicrecom/hetrec2011-lastfm-2k/artists.dat', sep='\t', inferSchema=True, header=True, nullValue='NA', nanValue='NA',emptyValue='NA').dropna()


In [30]:
nrecommendations.join(artists, nrecommendations.artistID == artists.id).filter('userId = 3').orderBy("rating").show()


+------+--------+---------+-----+--------------------+--------------------+--------------------+
|userId|artistID|   rating|   id|                name|                 url|          pictureURL|
+------+--------+---------+-----+--------------------+--------------------+--------------------+
|     3|     154|23618.371|  154|           Radiohead|http://www.last.f...|http://userserve-...|
|     3|     207|23654.018|  207|      Arctic Monkeys|http://www.last.f...|http://userserve-...|
|     3|   14987|23744.113|14987|RICHARD DIXON-COM...|http://www.last.f...|http://userserve-...|
|     3|     289|23932.045|  289|      Britney Spears|http://www.last.f...|http://userserve-...|
|     3|    2044|24737.242| 2044|     Sarah Brightman|http://www.last.f...|http://userserve-...|
|     3|     744| 25219.01|  744|            Autechre|http://www.last.f...|http://userserve-...|
|     3|   14986|25873.701|14986|         Dicky Dixon|http://www.last.f...|http://userserve-...|
|     3|     701|29174.752|  7