In [1]:
# Bootstrap cell: findspark.init() is called before `import pyspark`; keep this order.
# NOTE(review): presumably findspark locates the local Spark install and makes pyspark
# importable — confirm against the environment this notebook runs in.
import findspark
findspark.init()
import pyspark
from pyspark.sql import SparkSession
# Reuse the already-running Spark session if there is one, otherwise create it.
spark = SparkSession.builder.getOrCreate()

In [2]:
# Read the interactions CSV and label its five columns in one chained expression,
# then display the total number of interaction rows.
df = (
    spark.read
    .csv('Documents/final-project-team-deep-coral/down_id5000_.csv')
    .toDF('user_id', 'book_id', 'is_read', 'rating', 'is_reviewed')
)
df.count()

2513835

In [3]:
# Single-column frame holding each user_id exactly once; display how many users there are.
user_list = (
    df
    .select("user_id")
    .distinct()
)
user_list.count()

4995

In [4]:
# Split the USERS (not the interactions) 60/20/20 into train/valid/test groups.
# The seed is fixed so the grouping is repeatable.
user_train, user_valid, user_test = user_list.randomSplit(
    weights=[0.6, 0.2, 0.2],
    seed=42,
)

In [5]:
# Number of users that landed in the train, valid and test groups (in that order).
tuple(users.count() for users in (user_train, user_valid, user_test))

(2964, 1000, 1031)

In [6]:
# Expose both frames to Spark SQL, then keep every interaction whose user belongs to
# the training user group. The 'df' view registered here is reused by later cells.
user_train.createOrReplaceTempView('user_train')
df.createOrReplaceTempView('df')
X_train = spark.sql(
    'SELECT * FROM df WHERE user_id IN (SELECT user_id FROM user_train)'
)
X_train.count()

1523935

In [7]:
# All interactions of the validation users (relies on the 'df' view registered earlier),
# followed by a peek at how many interactions each of those users has.
user_valid.createOrReplaceTempView('user_valid')
X_valid = spark.sql(
    'SELECT * FROM df WHERE user_id IN (SELECT user_id FROM user_valid)'
)
(
    X_valid
    .groupby(['user_id'])
    .agg({'book_id': 'count'})
    .show(10)
)

+-------+--------------+
|user_id|count(book_id)|
+-------+--------------+
|   2294|           173|
|   1090|           565|
|    467|           129|
|   3650|          1415|
|   2464|           280|
|   3858|           370|
|   4975|           209|
|   2393|           236|
|   4127|           313|
|   3200|          1196|
+-------+--------------+
only showing top 10 rows



In [8]:
# Stratified sample keyed by user: every validation user is given the same 0.5 fraction,
# so about half of each user's interactions are drawn. The user ids are collected to the
# driver to build the fractions mapping.
X_valid_sampled = X_valid.sampleBy(
    "user_id",
    fractions=dict.fromkeys(
        (row['user_id'] for row in user_valid.rdd.collect()),
        0.5,
    ),
    seed=42,
)
(
    X_valid_sampled
    .groupby(['user_id'])
    .agg({'book_id': 'count'})
    .show(10)
)

+-------+--------------+
|user_id|count(book_id)|
+-------+--------------+
|   2294|            93|
|   1090|           285|
|    467|            64|
|   3650|           703|
|   2464|           132|
|   3858|           197|
|   4975|           110|
|   2393|           136|
|   4127|           163|
|   3200|           631|
+-------+--------------+
only showing top 10 rows



In [11]:
# Validation-user interactions that were NOT drawn into X_valid_sampled; these are the
# rows that go back into the training data.
#
# The difference has to be taken over whole rows: filtering with
# `book_id NOT IN (sampled book_ids)` would also discard every interaction whose book
# happened to be sampled for some *other* user.
#
# exceptAll (multiset difference) is used rather than subtract (EXCEPT DISTINCT):
# subtract de-duplicates its result and removes *every* copy of a row that appears in
# the sample, so duplicated interaction rows would silently disappear from the data.
X_valid_to_train = X_valid.exceptAll(X_valid_sampled)
(
    X_valid_to_train
    .groupby(['user_id'])
    .agg({'book_id': 'count'})
    .show(10)
)

+-------+--------------+
|user_id|count(book_id)|
+-------+--------------+
|   1090|           280|
|    467|            65|
|   2294|            80|
|   2464|           148|
|   3650|           712|
|   3858|           173|
|   4975|            99|
|   2393|           100|
|   4127|           150|
|   2530|            68|
+-------+--------------+
only showing top 10 rows



In [12]:
# All interactions of the test users (relies on the 'df' view registered earlier),
# followed by a peek at how many interactions each of those users has.
user_test.createOrReplaceTempView('user_test')
X_test = spark.sql(
    'SELECT * FROM df WHERE user_id IN (SELECT user_id FROM user_test)'
)
(
    X_test
    .groupby(['user_id'])
    .agg({'book_id': 'count'})
    .show(10)
)

+-------+--------------+
|user_id|count(book_id)|
+-------+--------------+
|   1159|           376|
|   3414|           408|
|   1436|           148|
|   1512|           279|
|   4032|           237|
|   3441|           102|
|    944|           442|
|   1394|           325|
|   2275|           279|
|   4838|           616|
+-------+--------------+
only showing top 10 rows



In [13]:
# Stratified sample keyed by user: every test user is given the same 0.5 fraction,
# so about half of each user's interactions are drawn. The user ids are collected to the
# driver to build the fractions mapping.
X_test_sampled = X_test.sampleBy(
    "user_id",
    fractions=dict.fromkeys(
        (row['user_id'] for row in user_test.rdd.collect()),
        0.5,
    ),
    seed=42,
)
(
    X_test_sampled
    .groupby(['user_id'])
    .agg({'book_id': 'count'})
    .show(10)
)

+-------+--------------+
|user_id|count(book_id)|
+-------+--------------+
|   1159|           202|
|   3414|           214|
|   1436|            76|
|   1512|           137|
|   4032|           114|
|   3441|            48|
|    944|           219|
|   1394|           157|
|   2275|           135|
|   4838|           287|
+-------+--------------+
only showing top 10 rows



In [14]:
# Test-user interactions that were NOT drawn into X_test_sampled; these are the rows
# that go back into the training data.
#
# exceptAll (multiset difference) is used rather than subtract (EXCEPT DISTINCT):
# subtract de-duplicates its result and removes *every* copy of a row that appears in
# the sample, so duplicated interaction rows would silently disappear from the data.
X_test_to_train = X_test.exceptAll(X_test_sampled)
(
    X_test_to_train
    .groupby(['user_id'])
    .agg({'book_id': 'count'})
    .show(10)
)

+-------+--------------+
|user_id|count(book_id)|
+-------+--------------+
|   1159|           174|
|   1512|           142|
|   1436|            72|
|   3414|           194|
|   4032|           123|
|   4838|           329|
|   1394|           168|
|   3441|            54|
|   2275|           144|
|    944|           223|
+-------+--------------+
only showing top 10 rows



In [15]:
def _holdout_split(interactions, users, fraction, seed):
    """Split `interactions` into (held-out sample, remaining rows), sampling per user_id."""
    # Every user in `users` gets the same sampling fraction; the ids are collected to the
    # driver because sampleBy needs a plain dict keyed by stratum value.
    fractions = {row['user_id']: fraction for row in users.collect()}
    held_out = interactions.sampleBy('user_id', fractions=fractions, seed=seed)
    # exceptAll is a multiset difference. subtract (EXCEPT DISTINCT) would de-duplicate the
    # result and remove every copy of a sampled row, silently losing duplicated interactions.
    remainder = interactions.exceptAll(held_out)
    return held_out, remainder


def data_split(spark, file_path, weights=(0.6, 0.2, 0.2), holdout_fraction=0.5, seed=42):
    """
    Split an interactions CSV into train/valid/test DataFrames, grouping by user.

    Users are first partitioned at random into train/valid/test groups according to
    `weights`. The returned splits are then built as follows:

    - train: all interactions of the train users
                + the interactions of valid users that were not held out
                + the interactions of test users that were not held out
    - valid: a per-user random sample (about `holdout_fraction`) of the valid users' interactions
    - test : a per-user random sample (about `holdout_fraction`) of the test users' interactions

    Every input row ends up in exactly one of the three splits, so the three row counts
    add up to the row count of the input file.

    Side effect: (re)registers the Spark temp views `df`, `user_train`, `user_valid`
    and `user_test` on the given session.

    Parameters
    ----------
    spark : spark session object
    file_path : string; The path (in HDFS) to the CSV file, e.g., `hdfs:/user/bm106/pub/people_small.csv`
    weights : sequence of three numbers; relative sizes of the train/valid/test user groups
    holdout_fraction : float; sampling fraction applied to each valid/test user's interactions
    seed : int; seed used for both the user split and the interaction sampling

    Returns
    -------
    (X_train, X_valid_sampled, X_test_sampled) : tuple of three Spark DataFrames with columns
        user_id, book_id, is_read, rating, is_reviewed

    Raises
    ------
    ValueError : if `weights` does not contain exactly three entries
    """
    if len(weights) != 3:
        raise ValueError('weights must have exactly three entries (train, valid, test)')

    # Load the CSV file and set the column names.
    # NOTE(review): the file is read with Spark's default CSV options (no header row,
    # no schema inference) — confirm that matches the input file.
    interactions = spark.read.csv(file_path).toDF(
        'user_id', 'book_id', 'is_read', 'rating', 'is_reviewed'
    )

    # Partition the distinct user ids into the three user groups.
    user_train, user_valid, user_test = (
        interactions.select('user_id').distinct().randomSplit(list(weights), seed=seed)
    )

    # Register the views used by the SQL filters below.
    interactions.createOrReplaceTempView('df')
    user_train.createOrReplaceTempView('user_train')
    user_valid.createOrReplaceTempView('user_valid')
    user_test.createOrReplaceTempView('user_test')

    # All interactions belonging to each user group.
    X_train = spark.sql('SELECT * FROM df WHERE user_id IN (SELECT user_id FROM user_train)')
    X_valid = spark.sql('SELECT * FROM df WHERE user_id IN (SELECT user_id FROM user_valid)')
    X_test = spark.sql('SELECT * FROM df WHERE user_id IN (SELECT user_id FROM user_test)')

    # Hold out part of each valid/test user's interactions; the rest goes back to training.
    X_valid_sampled, X_valid_to_train = _holdout_split(X_valid, user_valid, holdout_fraction, seed)
    X_test_sampled, X_test_to_train = _holdout_split(X_test, user_test, holdout_fraction, seed)

    # Concatenate the remaining records of valid/test to X_train.
    X_train = X_train.union(X_valid_to_train).union(X_test_to_train)

    return X_train, X_valid_sampled, X_test_sampled

In [16]:
# Run the packaged split end to end.
# a = train split, b = validation hold-out, c = test hold-out (data_split's return order).
a, b, c = data_split(
    spark,
    'Documents/final-project-team-deep-coral/down_id5000_.csv',
)

In [18]:
# Row count of `a`, the first DataFrame returned by data_split (the training split).
a.count()

2018797

In [19]:
# Row count of `b`, the second DataFrame returned by data_split (sampled validation interactions).
b.count()

244539

In [20]:
# Row count of `c`, the third DataFrame returned by data_split (sampled test interactions).
c.count()

250499

In [23]:
# Sanity check: the three splits should partition the data, so their row counts must add
# up to the row count of the full dataset. The counts are taken from the live DataFrames
# instead of literals pasted from earlier outputs, so the check cannot go stale when the
# input file, seed or split logic changes.
a.count() + b.count() + c.count() == df.count()

True