In [None]:
import pandas as pd
import numpy as np
import os

import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql import types as T
from ptls.preprocessing import PysparkDataPreprocessor

# JAVA_HOME is set before the Spark session is created below.
# NOTE(review): this is a machine-specific absolute path — adjust it when the
# notebook is run in another environment.
os.environ['JAVA_HOME'] = '/home/jovyan/conda/kalash/lib/jvm'

# Resource and shuffle settings applied on top of the master / app name.
# NOTE(review): the master is local[*]; the executor-level and cores.max
# settings presumably have no effect in local mode — confirm before tuning.
spark_settings = [
    ("spark.driver.maxResultSize", "16g"),
    ("spark.executor.memory", "32g"),
    ("spark.executor.memoryOverhead", "16g"),
    ("spark.driver.memory", "32g"),
    ("spark.driver.memoryOverhead", "16g"),
    ("spark.cores.max", "24"),
    ("spark.sql.shuffle.partitions", "200"),
    ("spark.local.dir", "../../spark_local_dir"),
]

spark_conf = (
    pyspark.SparkConf()
    .setMaster("local[*]")
    .setAppName("JoinModality")
    .setAll(spark_settings)
)

spark = SparkSession.builder.config(conf=spark_conf).getOrCreate()

# Display the configuration held by the running Spark context.
spark.sparkContext.getConf().getAll()

In [6]:
def _read_raw_csv(path):
    """Read one CSV file, using its first row as the header, into a Spark DataFrame."""
    return spark.read.csv(path, header=True)


# Raw inputs: two event streams plus the matching and education tables.
transactions = _read_raw_csv('../data/raw_data/transactions.csv')
clickstream = _read_raw_csv('../data/raw_data/clickstream.csv')
train_matching = _read_raw_csv('../data/raw_data/train_matching.csv')
train_edu = _read_raw_csv('../data/raw_data/train_edu.csv')

# Attach the category attributes to every click. This is an inner join (the
# default), so clicks whose cat_id is absent from click_categories are dropped.
click_categories = _read_raw_csv('../data/raw_data/click_categories.csv')
clickstream = clickstream.join(click_categories, on='cat_id', how='inner')

# Preprocessing

In [5]:
def _count_distinct(df, column):
    """Return how many distinct values `column` takes in `df`."""
    return df.select(column).distinct().count()


# Distinct ids per source, in order: transaction users, clickstream users,
# bank ids in train_edu, bank ids in train_matching, rtk ids in train_matching.
(
    _count_distinct(transactions, 'user_id'),
    _count_distinct(clickstream, 'user_id'),
    _count_distinct(train_edu, 'bank'),
    _count_distinct(train_matching, 'bank'),
    _count_distinct(train_matching, 'rtk'),
)

24/09/28 15:03:37 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
                                                                                

(22533, 19623, 8509, 17581, 14672)

# Transform to PTLS format

In [6]:
# Both modalities use the same id column and the same event-time conversion;
# only the timestamp column and the categorical columns differ.
_shared_preprocessor_args = dict(
    col_id='user_id',
    event_time_transformation='dt_to_timestamp',
)

# Transactions: categorical features are the MCC code and the currency.
preprocessor_trx = PysparkDataPreprocessor(
    col_event_time='transaction_dttm',
    cols_category=["mcc_code", "currency_rk"],
    **_shared_preprocessor_args,
)

# Clickstream: categorical features are the category id and its three levels.
preprocessor_click = PysparkDataPreprocessor(
    col_event_time='timestamp',
    cols_category=['cat_id', 'level_0', 'level_1', 'level_2'],
    **_shared_preprocessor_args,
)

In [None]:
# Fit each preprocessor on its own modality and keep the transformed frames.
# Judging by the `show` outputs further down, the results hold array-valued
# feature columns alongside the id column.
transactions_prepared = preprocessor_trx.fit_transform(transactions)
clickstream_prepared = preprocessor_click.fit_transform(clickstream)


# Replace clickstream ids (`rtk`) with the matched bank `user_id`

In [9]:
# In the matching table, `rtk` is the clickstream-side id; rename it so it can
# serve as the join key against the prepared clickstream.
train_matching = train_matching.withColumnRenamed('rtk', 'user_id')

# Swap the clickstream id for the bank id: join on the clickstream id, drop it,
# then expose `bank` under the name `user_id`.
# NOTE(review): this is an outer join, so clickstream rows without a match end
# up with a null `user_id`, and matching rows without clicks get null features.
# NOTE(review): the cell rebinds its own inputs, so running it twice joins bank
# ids against rtk ids — re-run from the preprocessing cell instead.
clickstream_prepared = (
    clickstream_prepared
    .join(train_matching, on='user_id', how='outer')
    .drop('user_id')
    .withColumnRenamed('bank', 'user_id')
)
clickstream_prepared.show(2)

[Stage 84:>                                                         (0 + 1) / 1]

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|          event_time|              cat_id|             new_uid|             level_0|             level_1|             level_2|             user_id|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|[1616456760, 1616...|[1, 1, 1, 1, 1, 1...|[411399, 411399, ...|[1, 1, 1, 1, 1, 1...|[1, 1, 1, 1, 1, 1...|[1, 1, 1, 1, 1, 1...|95f2446d41fc4536b...|
|[1612148340, 1612...|[3, 12, 5, 12, 40...|[1840824, 1840824...|[3, 5, 5, 5, 38, ...|[1, 2, 1, 2, 1, 1...|[1, 1, 1, 1, 1, 1...|89d5b991d5dc4c5d8...|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
only showing top 2 rows



                                                                                

In [10]:
# Each prepared frame has its own `event_time` column; give each a
# modality-specific name so they stay distinct once the frames are joined.
clickstream_prepared = clickstream_prepared.withColumnRenamed(
    existing='event_time',
    new='click_event_time',
)
transactions_prepared = transactions_prepared.withColumnRenamed(
    existing='event_time',
    new='trx_event_time',
)

# Join

In [11]:
# Outer join on the bank id: a user present in only one modality is kept,
# with nulls in the columns of the missing modality.
mm_dataset = transactions_prepared.join(
    other=clickstream_prepared,
    on='user_id',
    how='outer',
)
mm_dataset.show(n=5)

[Stage 96:>                                                         (0 + 1) / 1]

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|             user_id|      trx_event_time|            mcc_code|         currency_rk|     transaction_amt|    click_event_time|              cat_id|             new_uid|             level_0|             level_1|             level_2|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|0012e60b16f14da4b...|[1596243049, 1596...|[1, 23, 7, 1, 1, ...|[1, 1, 1, 1, 1, 1...|[-398.97632, -195...|[1611561647, 1611...|[29, 1, 27, 9, 9,...|[1439071, 1079747...|[26, 1, 25, 9, 9,...|[1, 1, 1, 1, 1, 1...|[1, 1, 1, 1, 1, 1...|
|003d93fb918846ada...|[1596247679, 1596...|[5, 5, 8, 2, 29, ...|[1, 

                                                                                

In [12]:
# Align the label table's id column name with the join key used by mm_dataset.
train_edu = train_edu.withColumnRenamed(existing='bank', new='user_id')

In [13]:
# Attach the train_edu columns; the outer join keeps users that have no row
# in train_edu (their label columns are null) and vice versa.
# NOTE(review): mm_dataset is rebound from itself, so re-running this cell
# joins train_edu a second time.
mm_dataset = mm_dataset.join(
    other=train_edu,
    on='user_id',
    how='outer',
)

In [14]:
# Preview one row to check that the columns joined from train_edu are present.
mm_dataset.show(1)



+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----------------+
|             user_id|      trx_event_time|            mcc_code|         currency_rk|     transaction_amt|    click_event_time|              cat_id|             new_uid|             level_0|             level_1|             level_2|higher_education|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----------------+
|0012e60b16f14da4b...|[1596243049, 1596...|[1, 23, 7, 1, 1, ...|[1, 1, 1, 1, 1, 1...|[-398.97632, -195...|[1611561647, 1611...|[29, 1, 27, 9, 9,...|[1439071, 1079747...|[26, 1, 25, 9, 9,...|[1, 1, 1, 1, 1, 1...|[1, 1, 1, 1, 1, 1...|            null|


                                                                                

In [15]:
# Persist the joined multimodal dataset, replacing any previous export.
mm_dataset.write.parquet('../data/mm_dataset.parquet', mode='overwrite')

                                                                                

# Split into folds

In [2]:
# Reload the dataset written above instead of reusing the in-memory frame.
# NOTE(review): the execution counts restart here, so this section appears to
# have been run in a fresh kernel; it still needs the `spark` session created
# in the first cell.
mm_dataset = spark.read.parquet('../data/mm_dataset.parquet')

In [3]:
# Five equally weighted random folds, drawn with a fixed seed.
fold_weights = [0.2, 0.2, 0.2, 0.2, 0.2]
(
    mm_dataset_fold0,
    mm_dataset_fold1,
    mm_dataset_fold2,
    mm_dataset_fold3,
    mm_dataset_fold4,
) = mm_dataset.randomSplit(fold_weights, seed=42)

# Write each fold to its own `fold=<index>` directory, replacing earlier output.
for fold_idx, fold_df in enumerate((
    mm_dataset_fold0,
    mm_dataset_fold1,
    mm_dataset_fold2,
    mm_dataset_fold3,
    mm_dataset_fold4,
)):
    fold_df.write.mode('overwrite').parquet(f'../data/mm_dataset_fold/fold={fold_idx}')

24/09/28 15:21:22 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
                                                                                

In [7]:
# Shut down the Spark session now that all outputs have been written.
spark.stop()