# NLP RECSYS dataset subselection

Sample a small fraction of users and keep all of their interactions, producing a smaller subset of the dataset for more advanced predictions.




In [2]:
import com.johnsnowlabs.nlp.SparkNLP
import com.johnsnowlabs.nlp.annotator._
import com.johnsnowlabs.nlp.base._
import org.apache.spark.ml.{Pipeline, PipelineModel, Transformer}
import org.apache.spark.sql.types._
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.functions.{udf,to_timestamp}
import org.apache.spark.storage._
import org.apache.spark.ml.feature._
import org.apache.spark.ml.classification._
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, explode, udf}
import org.apache.spark.sql.types.{DataTypes, StructType}


// Directory that holds the input and output parquet datasets.
// Falls back to the JVM `user.home` property when the HOME environment
// variable is not set, instead of failing with NoSuchElementException.
val dataDir = sys.env.getOrElse("HOME", sys.props("user.home")) + "/recsys2020"
// Base name (without the ".parquet" extension) of the dataset to read.
val dsName = "training"

In [1]:
// Load the parquet dataset located at "<dataDir>/<dsName>.parquet".
val df = spark.read.parquet(s"${dataDir}/${dsName}.parquet")

In [3]:
// Evaluate `df` as the cell result so the notebook echoes its schema summary.
df

[user_id: string, tweet_id: string ... 17 more fields]

In [4]:
// Count the rows belonging to each user_id, then preview the per-user counts.
val num_user_interactions = df.groupBy(col("user_id")).count()
num_user_interactions.show()

+--------------------+-----+
|             user_id|count|
+--------------------+-----+
|6DB36629F09EEF6B0...|    5|
|6DB3CB729521B094E...|    6|
|6DB935FBE5AFB72C0...|    5|
|6DB9B1833E6F6B16D...|    5|
|6DBA6B1E954B4D4AA...|    9|
|6DBC5460B25D693AC...|   15|
|6DBC89235E5F715CF...|    8|
|6DC28C6B6E1E45CE4...|   55|
|6DC66F416E5010527...|   13|
|6DC79920B693257A4...|    4|
|6DC8D079EA2CEDC84...|    3|
|6DCC26641F198AE2C...|   13|
|6DCD8B083F0AA9BF2...|   19|
|6DD212AFBAB897074...|    3|
|6DD3132C30767176B...|    4|
|6DD3CEC47DEA7CC2D...|    8|
|6DD5B319ED8EF914B...|    3|
|6DDC42DEDEA6E6DF2...|    6|
|6DDD8632D48677C3F...|   14|
|6DDFBF4ED82A25E05...|    7|
+--------------------+-----+
only showing top 20 rows



In [5]:
// Number of rows in the per-user count table (one row per user_id group).
println(num_user_interactions.count())
// `show` prints the table itself and returns Unit, so it must not be wrapped
// in `print` -- doing so only appends a stray "()" to the output.
// Average number of rows per user.
num_user_interactions.select(avg($"count")).show()
// Total number of rows across all users.
num_user_interactions.select(sum($"count")).show()

24429989+-----------------+
|       avg(count)|
+-----------------+
|5.591499938866121|
+-----------------+

()+----------+
|sum(count)|
+----------+
| 136600282|
+----------+

()

In [6]:
// Expected proportion of users to keep. Spark's `sample` does not guarantee
// an exact row count, only an approximate fraction.
val fraction = 0.01
// Random sample of the per-user count rows (one row per user_id group).
val sampled_users = num_user_interactions.sample(fraction)
// Number of sampled users.
println(sampled_users.count())
// `show` prints the table itself and returns Unit, so it must not be wrapped
// in `print` -- doing so only appends a stray "()" to the output.
// Average and total number of rows for the sampled users.
sampled_users.select(avg($"count")).show()
sampled_users.select(sum($"count")).show()

243927+-----------------+
|       avg(count)|
+-----------------+
|5.601483230638674|
+-----------------+

()+----------+
|sum(count)|
+----------+
|   1366353|
+----------+

()

In [7]:
// Collect the sampled user ids to the driver as a typed List[String].
// `Row.apply(i)` returns Any, which would silently widen the result to
// List[Any]; `getString` keeps the element type of the string column.
val allowed_user_list: List[String] =
  sampled_users.select("user_id").rdd.map(_.getString(0)).collect().toList

In [8]:
// Keep only the rows whose user_id is one of the sampled users, then preview them.
val filtered_df = df.where(col("user_id").isin(allowed_user_list: _*))
filtered_df.show()

+--------------------+--------------------+----------+------------------------------+---------------------+----------------------+------------------+-------------------+--------------------+----------------+-------+---------------+--------------------+--------------+--------------------+-----------+------------------------+--------+---------+
|             user_id|            tweet_id|tweet_type|                    tweet_text|author_follower_count|author_following_count|author_is_verified|user_follower_count|user_following_count|user_is_verified|follows|tweet_timestamp|            hashtags| present_media|     present_domains|has_retweet|has_retweet_with_comment|has_like|has_reply|
+--------------------+--------------------+----------+------------------------------+---------------------+----------------------+------------------+-------------------+--------------------+----------------+-------+---------------+--------------------+--------------+--------------------+-----------+----------

In [9]:
// Base name (without the ".parquet" extension) of the sampled output dataset.
val outDsName = "user_sampled"
// Write the filtered rows to "<dataDir>/<outDsName>.parquet", overwriting any
// existing output at that path.
filtered_df.write
  .mode(SaveMode.Overwrite)
  .parquet(s"${dataDir}/${outDsName}.parquet")
// Disabled alternative that writes directly under HOME instead of dataDir:
// filtered_df.write.mode(SaveMode.Overwrite).parquet(sys.env("HOME") + s"/${outDsName}.parquet")

In [10]:
// `count()` returns a Long, which has no `show` method; print the number of
// rows in the filtered dataset directly.
println(filtered_df.count())