In [1]:
import pandas as pd
import numpy as np
import os

import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql import types as T

# Set before the SparkSession below is created; PySpark presumably uses it to
# locate the JVM it launches.
# NOTE(review): absolute, user-specific path — make it configurable if the
# notebook has to run on another machine.
os.environ['JAVA_HOME'] = '/home/jovyan/conda/kalash/lib/jvm'

# Extra Spark settings, applied in this order after master / app name.
# NOTE(review): with a local[*] master the executors live inside the driver
# JVM, so the spark.executor.* and spark.cores.max values are likely ignored —
# confirm before tuning them.
spark_settings = [
    ("spark.driver.maxResultSize", "16g"),
    ("spark.executor.memory", "32g"),
    ("spark.executor.memoryOverhead", "16g"),
    ("spark.driver.memory", "32g"),
    ("spark.driver.memoryOverhead", "16g"),
    ("spark.cores.max", "24"),
    ("spark.sql.shuffle.partitions", "200"),
    ("spark.local.dir", "../../spark_local_dir"),
]

spark_conf = (
    pyspark.SparkConf()
    .setMaster("local[*]")
    .setAppName("JoinModality")
    .setAll(spark_settings)
)

spark = SparkSession.builder.config(conf=spark_conf).getOrCreate()

# Show the effective configuration as the cell's output.
spark.sparkContext.getConf().getAll()

In [4]:
# Per-modality input folders; each one contains `fold=<k>` sub-folders.
PTLS_DATA_ROOT = 'scenario_mbd/data/dataset-huggingface/ptls/'
TRX_DATA_PATH = PTLS_DATA_ROOT + 'trx/'
GEO_DATA_PATH = PTLS_DATA_ROOT + 'geo/'
DIAL_DATA_PATH = PTLS_DATA_ROOT + 'dialog/'

# Folder holding the joined multimodal dataset (also split by `fold=<k>`).
# NOTE(review): this root ('scenario_mbd_data') differs from the input root
# ('scenario_mbd/data') — confirm that is intentional.
MM_DATA_PATH = 'scenario_mbd_data/mm_dataset/'

In [5]:
def rename_col(df, prefix, col_id='client_id'):
    """Prefix every column except the key column with ``f"{prefix}_"``.

    Used to namespace each modality's columns (``trx_``, ``geo_``, ``dial_``)
    before the modalities are joined on ``col_id``.

    Args:
        df: Spark DataFrame whose columns are renamed.
        prefix: String prepended, followed by an underscore, to every column
            other than ``col_id``.
        col_id: Name of the key column, which keeps its name. Defaults to
            ``'client_id'``.

    Returns:
        A new DataFrame with the same columns, in the same order, renamed.
    """
    # One positional ``toDF`` renames everything in a single projection.
    # Chaining ``withColumnRenamed`` per column adds a plan node per column.
    # It also renames *every* column matching the old name, so renaming
    # ``a`` -> ``trx_a`` next to an existing ``trx_a`` would later rename
    # both to ``trx_trx_a`` and lose a column.
    new_names = [col if col == col_id else f"{prefix}_{col}" for col in df.columns]
    return df.toDF(*new_names)

In [None]:
# For each fold (-1 .. 4): load the three modalities, prefix their columns,
# and full-outer-join them on client_id, in the order trx -> geo -> dial.
modality_sources = [
    ('trx', TRX_DATA_PATH),
    ('geo', GEO_DATA_PATH),
    ('dial', DIAL_DATA_PATH),
]

for fold in range(-1, 5):
    fold_dir = f'fold={fold}'

    raw_frames = [
        (prefix, spark.read.parquet(os.path.join(base_path, fold_dir)))
        for prefix, base_path in modality_sources
    ]
    prefixed_frames = [rename_col(frame, prefix) for prefix, frame in raw_frames]

    mm_dataset = prefixed_frames[0]
    for other in prefixed_frames[1:]:
        mm_dataset = mm_dataset.join(other, on='client_id', how='outer')

    mm_dataset.write.mode('overwrite').parquet(os.path.join(MM_DATA_PATH, fold_dir))
    

# Targets

Join the per-fold targets onto the multimodal dataset written above and save the result.

In [7]:
from ptls.preprocessing import PysparkDataPreprocessor


libgomp: Invalid value for environment variable OMP_NUM_THREADS

libgomp: Invalid value for environment variable OMP_NUM_THREADS


In [8]:
# Per-fold target tables (read as `fold=<k>` sub-folders).
TARGETS_DATA_PATH = 'scenario_mbd/data/dataset-huggingface/targets'

# Output folder for the multimodal dataset joined with targets. The
# targets-join loop writes here; without this definition a fresh kernel
# fails there with NameError.
# TODO(review): confirm the intended location (mirrors MM_DATA_PATH's root).
MMT_DATA_PATH = 'scenario_mbd_data/mmt_dataset/'

In [9]:
# Target columns target_1 .. target_4, passed to the preprocessor as identity
# columns (presumably kept without encoding — see the ptls documentation).
TARGET_COLS = [f"target_{i}" for i in range(1, 5)]

# Targets are keyed by client_id; `mon` is the event-time column and is
# converted with the "dt_to_timestamp" transformation.
preprocessor_target = PysparkDataPreprocessor(
    col_id="client_id",
    col_event_time="mon",
    event_time_transformation="dt_to_timestamp",
    cols_identity=TARGET_COLS,
)

In [10]:
# For folds 0..4 (fold -1 is not processed here): preprocess the targets,
# left-join them onto the multimodal dataset, and save the result.
for fold in range(5):
    fold_dir = f'fold={fold}'
    targets_raw = spark.read.parquet(os.path.join(TARGETS_DATA_PATH, fold_dir))
    mm_dataset = spark.read.parquet(os.path.join(MM_DATA_PATH, fold_dir))

    targets = preprocessor_target.fit_transform(targets_raw)

    # The left join keeps every client of the multimodal dataset, including
    # clients without targets. The dropped columns are presumably by-products
    # of the target preprocessor; the modality columns are all prefixed.
    mmt_dataset = (
        mm_dataset
        .join(targets, on='client_id', how='left')
        .drop('event_time', 'trans_count', 'diff_trans_date')
    )

    # 'overwrite' matches the modality-join loop and keeps the cell re-runnable.
    # The default 'errorifexists' mode fails once a fold has been written.
    mmt_dataset.write.mode('overwrite').parquet(os.path.join(MMT_DATA_PATH, fold_dir))
    
    

                                                                                

In [11]:
# Stop the SparkSession created in the setup cell so its resources are released.
spark.stop()