In [27]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [28]:
!apt-get -y install openjdk-8-jre-headless
!pip install pyspark

Reading package lists... Done
Building dependency tree       
Reading state information... Done
openjdk-8-jre-headless is already the newest version (8u352-ga-1~18.04).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 20 not upgraded.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [29]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("recommandation").getOrCreate()

In [30]:
df_train = spark.read.csv('/content/drive/MyDrive/2022-CCBDA/train_student.csv',header=True,inferSchema=True)

In [31]:
df_train.printSchema()

root
 |-- item: string (nullable = true)
 |-- user: string (nullable = true)
 |-- rating: integer (nullable = true)



In [33]:
from pyspark.sql.functions import format_number, lit
df_train = df_train.withColumn('rating', df_train.rating.cast('float'))

In [32]:
df_train.printSchema()

root
 |-- item: string (nullable = true)
 |-- user: string (nullable = true)
 |-- rating: integer (nullable = true)



In [34]:
df_train.describe().show()

+-------+--------------------+--------------------+------------------+
|summary|                item|                user|            rating|
+-------+--------------------+--------------------+------------------+
|  count|               83798|               83798|             83798|
|   mean|2.0769135928283918E9|                null| 4.359018114990811|
| stddev|2.0710778913333938E9|                null|0.9940809586163019|
|    min|          0439893577|A012468118FTQAINE...|               1.0|
|    max|          B00LBI9BKA|       AZZYW4YOE1B6E|               5.0|
+-------+--------------------+--------------------+------------------+



In [35]:
df_train.head(10)

[Row(item='B008H54GVE', user='A2OIMJEGOCTQ87', rating=4.0),
 Row(item='B001NLISDG', user='A38KQZS5M1A8T8', rating=4.0),
 Row(item='B007S3S8HO', user='A1F5O1USOUOOXI', rating=3.0),
 Row(item='B008G6OOHA', user='A1FTZ5LLEX7NCM', rating=4.0),
 Row(item='B00A9JNR8E', user='A1R2JUOGIYH6HO', rating=4.0),
 Row(item='B00000JH3R', user='A2X7C89I7YRX1O', rating=5.0),
 Row(item='B000RXPU0U', user='A24BSKCWXC4M6D', rating=5.0),
 Row(item='B00DWXUYN0', user='A3JX1D26WFEXOS', rating=5.0),
 Row(item='B000E9DPVI', user='A2IWHA1XEYSPD6', rating=5.0),
 Row(item='B001EB9F3C', user='A2340917M1HHZ3', rating=4.0)]

Useful function: indexing ID(string) to numerical index

In [36]:
from pyspark.ml.feature import StringIndexer, IndexToString
from pyspark.ml import Pipeline


In [37]:
indexers = [StringIndexer(inputCol=column, outputCol=column+"_index").setHandleInvalid("keep").fit(df_train) for column in list(set(df_train.columns)-set(['rating'])) ]
pipeline = Pipeline(stages=indexers)
df_ID = pipeline.fit(df_train).transform(df_train)

In [38]:
df_ID.head(10)

[Row(item='B008H54GVE', user='A2OIMJEGOCTQ87', rating=4.0, user_index=1292.0, item_index=5709.0),
 Row(item='B001NLISDG', user='A38KQZS5M1A8T8', rating=4.0, user_index=4940.0, item_index=1849.0),
 Row(item='B007S3S8HO', user='A1F5O1USOUOOXI', rating=3.0, user_index=669.0, item_index=937.0),
 Row(item='B008G6OOHA', user='A1FTZ5LLEX7NCM', rating=4.0, user_index=107.0, item_index=10888.0),
 Row(item='B00A9JNR8E', user='A1R2JUOGIYH6HO', rating=4.0, user_index=1210.0, item_index=1979.0),
 Row(item='B00000JH3R', user='A2X7C89I7YRX1O', rating=5.0, user_index=4764.0, item_index=979.0),
 Row(item='B000RXPU0U', user='A24BSKCWXC4M6D', rating=5.0, user_index=2866.0, item_index=6385.0),
 Row(item='B00DWXUYN0', user='A3JX1D26WFEXOS', rating=5.0, user_index=8104.0, item_index=9620.0),
 Row(item='B000E9DPVI', user='A2IWHA1XEYSPD6', rating=5.0, user_index=11230.0, item_index=420.0),
 Row(item='B001EB9F3C', user='A2340917M1HHZ3', rating=4.0, user_index=10687.0, item_index=51.0)]

In [144]:
(train, test) = df_ID.randomSplit([0.85, 0.15], seed = 2023)

Useful function: Construct Binary data \\
adding 1 for watched and 0 for not watched

In [145]:
def get_binary_data(ratings):
    ratings = df_ID.withColumn('binary', lit(1))
    userIds = df_ID.select("user_index").distinct()
    itemIds = df_ID.select("item_index").distinct()

    user_item = userIds.crossJoin(itemIds).join(ratings, ['user_index', 'item_index'], "left")
    user_item = user_item.select(['user_index', 'item_index', 'binary']).fillna(0)
    return user_item

user_item = get_binary_data(df_ID)

In [146]:
user_item.show()

+----------+----------+------+
|user_index|item_index|binary|
+----------+----------+------+
|    5776.0|     305.0|     0|
|     305.0|     305.0|     0|
|   12737.0|     305.0|     0|
|    2734.0|     305.0|     0|
|     934.0|     305.0|     0|
|     692.0|     305.0|     0|
|    3980.0|     305.0|     0|
|   13533.0|     305.0|     0|
|     496.0|     305.0|     0|
|    6653.0|     305.0|     0|
|   12172.0|     305.0|     0|
|   16981.0|     305.0|     0|
|   12467.0|     305.0|     0|
|    6433.0|     305.0|     0|
|   13918.0|     305.0|     0|
|    6067.0|     305.0|     0|
|    9753.0|     305.0|     0|
|     299.0|     305.0|     0|
|    4800.0|     305.0|     0|
|   14452.0|     305.0|     0|
+----------+----------+------+
only showing top 20 rows



Useful Function: pyspark recommandation system ML model \\
Remember to set coldStartStrategy="nan" to make sure model not ignore null data.

Spark allows users to set the coldStartStrategy parameter to “drop” in order to drop any rows in the DataFrame of predictions that contain NaN values. The evaluation metric will then be computed over the non-NaN data and will be valid. Usage of this parameter is illustrated in the example below.

In [147]:
# Import the required functions
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS

# Create ALS model
als = ALS(
         userCol="user_index", 
         itemCol="item_index",
         ratingCol="rating", 
         nonnegative = True, 
         implicitPrefs = False,
         coldStartStrategy="nan"
)

In [136]:
# Import the requisite packages
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

# [Custom]
param_grid = ParamGridBuilder() \
            .addGrid(als.rank, [8, 9]) \
            .addGrid(als.regParam, [.05, .1,]) \
            .build()
            # .addGrid(als.regParam, [.005, .01,]) \
            # .addGrid(als.maxIter, [5, 10]) \
           
print ("Num models to be tested: ", len(param_grid))

evaluator=RegressionEvaluator(metricName="rmse",labelCol="rating",predictionCol="prediction")


Num models to be tested:  4


In [137]:
cv = CrossValidator(estimator=als, estimatorParamMaps=param_grid, evaluator=evaluator, numFolds=5)
print(cv)

CrossValidator_c1fb5be011bb


In [148]:
#Fit cross validator to the 'train' dataset
model = als.fit(train)

#[Custom]
#model = cv.fit(train) 
#model = model.bestModel


Useful function: fill value to NaN part \\
example : \\
a = a.na.fill(value=999)

In [149]:
# View the predictions
test_predictions = model.transform(test)
print(test_predictions.na)
# rememver to fix nan in prediction
test_predictions=test_predictions.na.fill(value=4.35)

RMSE = evaluator.evaluate(test_predictions)
print(RMSE)

<pyspark.sql.dataframe.DataFrameNaFunctions object at 0x7f3795f8bd00>
1.278774629300528


In [150]:
df_public = spark.read.csv('/content/drive/MyDrive/2022-CCBDA/test_public.csv',header=True,inferSchema=True)
df_public_ID = pipeline.fit(df_public).transform(df_public)

In [151]:
evaluator=RegressionEvaluator(metricName="rmse",labelCol="rating",predictionCol="prediction")
predictions=model.transform(df_public_ID)
# rememver to fix nan in prediction
predictions=predictions.na.fill(value=4.35)

rmse=evaluator.evaluate(predictions)
print("RMSE="+str(rmse))
predictions.show()

RMSE=1.260688436786773
+----------+--------------+------+----------+----------+----------+
|      item|          user|rating|user_index|item_index|prediction|
+----------+--------------+------+----------+----------+----------+
|B0006IRTU0|A2CYXQOAR1EJRQ|     5|   18105.0|     471.0|  4.291501|
|B000TK8440|A35B72PSA30R67|     5|    1699.0|     183.0|  4.396342|
|B0002BSTY6|A194UXXKM11698|     4|    1160.0|    7919.0| 3.1965234|
|B006HCVT5A|A3ISFBZ5UFK81I|     5|   16376.0|    3045.0|  4.237437|
|B004R1ZUNA| A3Y0IB3VYLD6A|     5|    5296.0|     297.0| 3.4934626|
|B00428LJ06|A1MLBMJSFK6BIJ|     5|    6338.0|     142.0| 4.8054934|
|B004GXIDYM|A34BONVNM07TRG|     5|   18436.0|    2537.0| 3.9557204|
|B0040GK7NK| AR3EVUQF0AC7R|     5|    1433.0|    1901.0| 3.8232243|
|B00A88EPCI|A3L249C56OJI7D|     5|    8140.0|     838.0| 4.0189342|
|B000IBPD76| ARSNAGZWXP7GN|     5|   13708.0|      39.0| 3.9411035|
|B0063NC3N0|A3AZPAZXGOD4VL|     5|   16190.0|    1495.0|  4.326409|
|B001Q1A2P0|A168O2YKPE9BE

In [153]:
df_testall = spark.read.csv('/content/drive/MyDrive/2022-CCBDA/test_all.csv',header=True,inferSchema=True)


In [154]:
df_test_ID = pipeline.fit(df_testall).transform(df_testall)

In [155]:
df_test_ID.printSchema()

root
 |-- item: string (nullable = true)
 |-- user: string (nullable = true)
 |-- user_index: double (nullable = false)
 |-- item_index: double (nullable = false)



In [156]:
df_test_ID.describe().show()

+-------+--------------------+--------------------+-----------------+------------------+
|summary|                item|                user|       user_index|        item_index|
+-------+--------------------+--------------------+-----------------+------------------+
|  count|               83799|               83799|            83799|             83799|
|   mean|2.1813556526264625E9|                null|8276.219012160049|3844.3237150801324|
| stddev|2.1882518602245975E9|                null|6463.700444398778| 3647.800641218535|
|    min|          0439893577|A012468118FTQAINE...|              0.0|               0.0|
|    max|          B00LBI9BKA|       AZZYW4YOE1B6E|          19127.0|           11824.0|
+-------+--------------------+--------------------+-----------------+------------------+



In [157]:
df_test_ID.show()

+----------+--------------+----------+----------+
|      item|          user|user_index|item_index|
+----------+--------------+----------+----------+
|B0015FRC32|A28QKOPBDPSHE5|   10890.0|   10076.0|
|B000EUKRY0|A1PTTEYFE49BQM|    6406.0|     104.0|
|B007XPLI56|A2C5VTBNC6I5MY|   15278.0|     404.0|
|B003A5RTHO|A12IOCD2A7OC7K|    2525.0|    6845.0|
|B006HCVT5A|A3ISFBZ5UFK81I|   16376.0|    3045.0|
|B004GXIDYM|A34BONVNM07TRG|   18436.0|    2537.0|
|B004R1ZUNA| A3Y0IB3VYLD6A|    5296.0|     297.0|
|B0006IRTU0|A2CYXQOAR1EJRQ|   18105.0|     471.0|
|B00A88EPCI|A3L249C56OJI7D|    8140.0|     838.0|
|B00428LJ06|A1MLBMJSFK6BIJ|    6338.0|     142.0|
|B0002BSTY6|A194UXXKM11698|    1160.0|    7919.0|
|B000IBPD76| ARSNAGZWXP7GN|   13708.0|      39.0|
|B000TK8440|A35B72PSA30R67|    1699.0|     183.0|
|B000EQGT00|A304ILYRZ145SI|     710.0|     135.0|
|B00005YVRN|A3118YKNMNAS33|    2209.0|     134.0|
|B0027FFMBS|A3LHE5MHDF7X2R|   12545.0|   11442.0|
|B00AAPHZVW| ANOST6C92T7HB|   18970.0|     557.0|


In [158]:
testall_predictions=model.transform(df_test_ID)

In [159]:
testall_predictions.printSchema()

root
 |-- item: string (nullable = true)
 |-- user: string (nullable = true)
 |-- user_index: double (nullable = false)
 |-- item_index: double (nullable = false)
 |-- prediction: float (nullable = false)



In [160]:
testall_predictions.describe().show()

+-------+--------------------+--------------------+------------------+------------------+----------+
|summary|                item|                user|        user_index|        item_index|prediction|
+-------+--------------------+--------------------+------------------+------------------+----------+
|  count|               83799|               83799|             83799|             83799|     83799|
|   mean|2.1813556526264625E9|                null| 8276.219012160049|3844.3237150801324|       NaN|
| stddev|2.1882518602245984E9|                null|6463.7004443987325|3647.8006412185114|       NaN|
|    min|          0439893577|A012468118FTQAINE...|               0.0|               0.0|0.04960001|
|    max|          B00LBI9BKA|       AZZYW4YOE1B6E|           19127.0|           11824.0|       NaN|
+-------+--------------------+--------------------+------------------+------------------+----------+



In [161]:
testall_predictions.show()

+----------+--------------+----------+----------+----------+
|      item|          user|user_index|item_index|prediction|
+----------+--------------+----------+----------+----------+
|B0006IRTU0|A2CYXQOAR1EJRQ|   18105.0|     471.0|  4.291501|
|B000TK8440|A35B72PSA30R67|    1699.0|     183.0|  4.396342|
|B0002BSTY6|A194UXXKM11698|    1160.0|    7919.0| 3.1965234|
|B006HCVT5A|A3ISFBZ5UFK81I|   16376.0|    3045.0|  4.237437|
|B004R1ZUNA| A3Y0IB3VYLD6A|    5296.0|     297.0| 3.4934626|
|B00428LJ06|A1MLBMJSFK6BIJ|    6338.0|     142.0| 4.8054934|
|B004GXIDYM|A34BONVNM07TRG|   18436.0|    2537.0| 3.9557204|
|B0040GK7NK| AR3EVUQF0AC7R|    1433.0|    1901.0| 3.8232243|
|B00A88EPCI|A3L249C56OJI7D|    8140.0|     838.0| 4.0189342|
|B000IBPD76| ARSNAGZWXP7GN|   13708.0|      39.0| 3.9411035|
|B0063NC3N0|A3AZPAZXGOD4VL|   16190.0|    1495.0|  4.326409|
|B001Q1A2P0|A168O2YKPE9BE8|      25.0|   10170.0|   5.03905|
|B007XPLI56|A2C5VTBNC6I5MY|   15278.0|     404.0| 3.0067468|
|B003A5RTHO|A12IOCD2A7OC

Useful function: concate User Item to generate U_I column

In [162]:
from pyspark.sql.functions import concat, col, lit

out = testall_predictions.select(concat(col("user"), lit("_"), col("item")) , testall_predictions.prediction)

In [163]:
out.printSchema()

root
 |-- concat(user, _, item): string (nullable = true)
 |-- prediction: float (nullable = false)



In [164]:
out.show()

+---------------------+----------+
|concat(user, _, item)|prediction|
+---------------------+----------+
| A2CYXQOAR1EJRQ_B0...|  4.291501|
| A35B72PSA30R67_B0...|  4.396342|
| A194UXXKM11698_B0...| 3.1965234|
| A3ISFBZ5UFK81I_B0...|  4.237437|
| A3Y0IB3VYLD6A_B00...| 3.4934626|
| A1MLBMJSFK6BIJ_B0...| 4.8054934|
| A34BONVNM07TRG_B0...| 3.9557204|
| AR3EVUQF0AC7R_B00...| 3.8232243|
| A3L249C56OJI7D_B0...| 4.0189342|
| ARSNAGZWXP7GN_B00...| 3.9411035|
| A3AZPAZXGOD4VL_B0...|  4.326409|
| A168O2YKPE9BE8_B0...|   5.03905|
| A2C5VTBNC6I5MY_B0...| 3.0067468|
| A12IOCD2A7OC7K_B0...| 3.8784966|
| A304ILYRZ145SI_B0...| 3.2616067|
| ANOST6C92T7HB_B00...|       NaN|
| A3J1CEZ30ZOJ7S_B0...| 3.1128843|
| A28QKOPBDPSHE5_B0...| 4.3424253|
| A1PTTEYFE49BQM_B0...| 3.6763585|
| A3118YKNMNAS33_B0...| 4.1702085|
+---------------------+----------+
only showing top 20 rows



Remember to fix NaN value

In [165]:
out=out.na.fill(value=4.35)

In [166]:
out.show()

+---------------------+----------+
|concat(user, _, item)|prediction|
+---------------------+----------+
| A2CYXQOAR1EJRQ_B0...|  4.291501|
| A35B72PSA30R67_B0...|  4.396342|
| A194UXXKM11698_B0...| 3.1965234|
| A3ISFBZ5UFK81I_B0...|  4.237437|
| A3Y0IB3VYLD6A_B00...| 3.4934626|
| A1MLBMJSFK6BIJ_B0...| 4.8054934|
| A34BONVNM07TRG_B0...| 3.9557204|
| AR3EVUQF0AC7R_B00...| 3.8232243|
| A3L249C56OJI7D_B0...| 4.0189342|
| ARSNAGZWXP7GN_B00...| 3.9411035|
| A3AZPAZXGOD4VL_B0...|  4.326409|
| A168O2YKPE9BE8_B0...|   5.03905|
| A2C5VTBNC6I5MY_B0...| 3.0067468|
| A12IOCD2A7OC7K_B0...| 3.8784966|
| A304ILYRZ145SI_B0...| 3.2616067|
| ANOST6C92T7HB_B00...|      4.35|
| A3J1CEZ30ZOJ7S_B0...| 3.1128843|
| A28QKOPBDPSHE5_B0...| 4.3424253|
| A1PTTEYFE49BQM_B0...| 3.6763585|
| A3118YKNMNAS33_B0...| 4.1702085|
+---------------------+----------+
only showing top 20 rows



Remember to sort by U_I before output

In [167]:
out.orderBy("concat(user, _, item)").show()

+---------------------+----------+
|concat(user, _, item)|prediction|
+---------------------+----------+
| A012468118FTQAINE...| 4.6847363|
| A012468118FTQAINE...|  4.544454|
| A012468118FTQAINE...| 4.2808523|
| A0182108CPDLPRCXQ...| 4.1718836|
| A0182108CPDLPRCXQ...|  4.174871|
| A0182108CPDLPRCXQ...|  3.510594|
| A026961431MGW0616...| 1.9600806|
| A026961431MGW0616...|  4.127269|
| A026961431MGW0616...| 3.0147252|
| A034597326Z83X79S...| 3.5298998|
| A034597326Z83X79S...| 3.5314264|
| A034597326Z83X79S...| 3.8800292|
| A034597326Z83X79S...| 3.8184423|
| A04295422T2ZG087R...|      4.35|
| A04295422T2ZG087R...| 2.4695177|
| A060131923OZAPX4N...|  3.419979|
| A060131923OZAPX4N...|  3.471266|
| A060131923OZAPX4N...| 3.4758053|
| A060131923OZAPX4N...| 3.5014637|
| A060131923OZAPX4N...|  2.802917|
+---------------------+----------+
only showing top 20 rows



In [168]:
out.orderBy("concat(user, _, item)").coalesce(1).write.format("csv").mode('overwrite').save("/content/drive/MyDrive/2022-CCBDA/sample_submission.csv")