In [78]:
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.Dataset

import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.Dataset


In [79]:
val rawPlayCounts = spark.sparkContext.
    textFile("user_artist_data.txt")

rawPlayCounts: org.apache.spark.rdd.RDD[String] = user_artist_data.txt MapPartitionsRDD[1259] at textFile at <console>:58


In [80]:
val pcSchemaString = "user_id artist_id playcount"
val pcFields = 
    pcSchemaString.split(" ").
    map(fieldName => StructField(fieldName, IntegerType, nullable = true))
val pcSchema = StructType(pcFields)

pcSchemaString: String = user_id artist_id playcount
pcFields: Array[org.apache.spark.sql.types.StructField] = Array(StructField(user_id,IntegerType,true), StructField(artist_id,IntegerType,true), StructField(playcount,IntegerType,true))
pcSchema: org.apache.spark.sql.types.StructType = StructType(StructField(user_id,IntegerType,true), StructField(artist_id,IntegerType,true), StructField(playcount,IntegerType,true))


In [81]:
val rowRDD = rawPlayCounts.
             map{_.split(' ')}.
             map{x => Row(x(0).toInt, x(1).toInt, x(2).toInt)}

val playcounts = spark.createDataFrame(rowRDD, pcSchema).cache()
playcounts.createOrReplaceTempView("playcounts")

rowRDD: org.apache.spark.rdd.RDD[org.apache.spark.sql.Row] = MapPartitionsRDD[1261] at map at <console>:66
playcounts: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [user_id: int, artist_id: int ... 1 more field]


In [82]:
playcounts.show()

+-------+---------+---------+
|user_id|artist_id|playcount|
+-------+---------+---------+
|1000002|        1|       55|
|1000002|  1000006|       33|
|1000002|  1000007|        8|
|1000002|  1000009|      144|
|1000002|  1000010|      314|
|1000002|  1000013|        8|
|1000002|  1000014|       42|
|1000002|  1000017|       69|
|1000002|  1000024|      329|
|1000002|  1000025|        1|
|1000002|  1000028|       17|
|1000002|  1000031|       47|
|1000002|  1000033|       15|
|1000002|  1000042|        1|
|1000002|  1000045|        1|
|1000002|  1000054|        2|
|1000002|  1000055|       25|
|1000002|  1000056|        4|
|1000002|  1000059|        2|
|1000002|  1000062|       71|
+-------+---------+---------+
only showing top 20 rows



In [83]:
val userartistDF = playcounts.toDF("user","artist","count")

userartistDF: org.apache.spark.sql.DataFrame = [user: int, artist: int ... 1 more field]


In [84]:
val uaDF = userartistDF.sample(false, 0.01)

uaDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [user: int, artist: int ... 1 more field]


In [85]:
uaDF.head()

res35: org.apache.spark.sql.Row = [1000002,1000726,5]


In [86]:
uaDF.printSchema()

root
 |-- user: integer (nullable = true)
 |-- artist: integer (nullable = true)
 |-- count: integer (nullable = true)



In [87]:
val Array(training, test) = uaDF.randomSplit(Array(0.8, 0.2))

training: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [user: int, artist: int ... 1 more field]
test: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [user: int, artist: int ... 1 more field]


In [88]:
training.count()

res37: Long = 193799


In [89]:
test.count()

res38: Long = 48034


In [90]:
training.printSchema()

root
 |-- user: integer (nullable = true)
 |-- artist: integer (nullable = true)
 |-- count: integer (nullable = true)



In [91]:
test.printSchema()

root
 |-- user: integer (nullable = true)
 |-- artist: integer (nullable = true)
 |-- count: integer (nullable = true)



In [97]:
val als = new ALS().
            setMaxIter(10).
            setRegParam(0.01).
            setUserCol("user").
            setItemCol("artist").
            setRatingCol("count")

als: org.apache.spark.ml.recommendation.ALS = als_fce13fde40ab


In [98]:
val model = als.fit(training)

model: org.apache.spark.ml.recommendation.ALSModel = als_fce13fde40ab


In [99]:
val predictions = model.transform(test)

predictions: org.apache.spark.sql.DataFrame = [user: int, artist: int ... 2 more fields]


In [100]:
predictions.sort("user").show(1000)

+-------+--------+-----+-------------+
|   user|  artist|count|   prediction|
+-------+--------+-----+-------------+
|    350| 6873847|    1|          NaN|
|    521| 1001821|    4|    24.287743|
|    521| 1013240|    2|    2.8621044|
|    536|    1811|    2|     14.62809|
|    581| 1035142|   10|  -0.24406591|
|    581| 7013164|    8|          NaN|
|    659| 1003557|    3|   -20.479895|
|    713|10238643|    1|          NaN|
|    803| 1016526|   12|          NaN|
|    825| 1006633|    2|          NaN|
|    863| 1001805|    1|    1.3264787|
|   1197| 1001459|    1|    -7.037362|
|   1197| 1008430|    1|     9.593417|
|   1298| 1003572|    1|   0.84852237|
|   1298|10217925|    2|          NaN|
|   1502| 2162933|    2|          NaN|
|   1502| 9899586|    1|          NaN|
|   1502| 1297503|    2|          NaN|
|   1502| 1010112|   24|    1.3905177|
|   1502|    2888|    1|  -0.74651164|
|   2527| 1018713|    2|  -0.33474118|
|   3048| 1001230|    9|    6.2425833|
|   3290| 1006672|   78| 

|1000846| 1004616|   79|    4.4945893|
|1000846| 2091261|   22|          NaN|
|1000846| 1013296|    8|    24.598007|
|1000856| 1333231|   52|          NaN|
|1000858| 1015553|    1|          NaN|
|1000858| 1011818|    7|          NaN|
|1000860| 1001525|    2|   -6.1990905|
|1000860|     622|    1|          NaN|
|1000873| 1000699|    3|   -0.6479933|
|1000879|    3402|   13|   -2.7507753|
|1000879| 1255028|    4|    -5.713773|
|1000879| 1176914|   50|          NaN|
|1000895|     386|    3|          NaN|
|1000895|     463|   11|          NaN|
|1000902| 1002470|    7|   -8.7216835|
|1000909| 6858560|    1| -0.026468053|
|1000921| 1023963|    1|          NaN|
|1000924| 1229617|    1|     0.528763|
|1000926| 1014532|    2|          NaN|
|1000928| 1249332|    2|          NaN|
|1000928| 1007204|    1|  -0.13994604|
|1000928|    1003|   16|   -13.703352|
|1000932| 1043228|   15|   -3.1613386|
|1000932| 1078236|   24|    1.4304411|
|1000935| 1006039|   15| -0.040727675|
|1000938| 1000445|    3| 

|1001523| 1053680|    2|          NaN|
|1001523| 1014863|    6|  -0.46177542|
|1001526| 1035626|    3|          NaN|
|1001529| 1250079|   12|      82.0258|
|1001533| 1254936|    1| -0.075037524|
|1001534|    4179|    1|   0.58743805|
|1001534| 1016514|    3|   -0.2574722|
|1001535|    3674|    1|     37.37452|
|1001535| 1037110|   14|          NaN|
|1001535|     251|   12|    14.522642|
|1001535|    4137|   18|     5.897642|
|1001539|    4195|    2|          NaN|
|1001540| 1239951|   13|     -5.40349|
|1001547| 2081924|    4|          NaN|
|1001549| 1031182|    1|    25.296633|
|1001555| 1006302|    1|          NaN|
|1001562| 1007286|   29|    -111.7603|
|1001562|      61|  117|    27.433655|
|1001562| 1002772|  357|    -69.76492|
|1001563| 1084759|   52|          NaN|
|1001568| 1018729|    2|          NaN|
|1001568| 1047430|    1|          NaN|
|1001568|    1177|   10|          NaN|
|1001569| 1000677|    1|          NaN|
|1001575| 1073862|    2|          NaN|
|1001577| 1007006|    6| 

|1002260| 1018931|   56|    20.605175|
|1002267| 1002530|    3|          NaN|
|1002271| 1000313|    6|    3.5732522|
|1002284| 1004775|    1|          NaN|
|1002284| 1020686|    1|    3.1682334|
|1002284| 1006024|    2|    17.058002|
|1002291| 1004147|    5|          NaN|
|1002296| 1014626|    7|  0.109076664|
|1002308|     733|   50|    0.5847292|
|1002310| 1000487|    6|   -39.712376|
|1002312| 1004129|    1|   -27.013977|
|1002313| 1000113|   15|   0.44461071|
|1002316|      54|   25|   -3.6066797|
|1002316|    1786|    1|   -13.528242|
|1002317| 1000295|    1|     7.221228|
|1002319| 1002291|  242|    51.814087|
|1002319|    1212|   15|    23.854301|
|1002322|    5837|  100|          NaN|
|1002325|    3329|    1|    -3.010603|
|1002325| 1001580|    3|     4.951286|
|1002330| 1006871|    1|   -1.4198209|
|1002340|     979|    1|          NaN|
|1002341|     224|   28|   -52.996254|
|1002345| 1142863|    2|   -23.731514|
|1002345| 6685157|    1|          NaN|
|1002345| 1160066|    3| 