# Mean encoding in Spark SQL on GPU

In [1]:
import json
import os
import sys

from argparse import ArgumentParser
from collections import OrderedDict
from contextlib import contextmanager
from operator import itemgetter
from time import time
import os

from pyspark import broadcast, SparkConf
from pyspark.sql import Row, SparkSession, Window
from pyspark.sql.functions import *
from pyspark.sql.types import *


In [2]:
def assign_id_with_window(df):
    """Number the values of every column by descending count using a window.

    Adds an ``id`` column holding the 1-based ``row_number`` of each row
    inside its ``column_id`` partition (ordered by ``count`` descending) and
    renames ``count`` to ``model_count``.
    """
    per_column_by_count = Window.partitionBy('column_id').orderBy(desc('count'))
    ranked = df.withColumn('id', row_number().over(per_column_by_count))
    return ranked.withColumnRenamed('count', 'model_count')


def assign_low_mem_partial_ids(df):
    """Tag every row with its Spark partition id and an in-partition counter.

    The input is globally ordered by ``column_id`` ascending and ``count``
    descending, then two columns are added:

    * ``part_id``  - the id of the Spark partition the row landed in.
    * ``mono_id``  - ``monotonically_increasing_id()`` with the partition
      prefix removed, i.e. the row's increasing counter inside its partition.

    Returns the ordered DataFrame with both columns appended.
    """
    # To avoid some scaling issues with a simple window operation, we use a more complex method
    # to compute the same thing, but in a more distributed, Spark-specific way.
    df = df.orderBy(asc('column_id'), desc('count'))
    # monotonically_increasing_id() keeps the partition id in the upper 31 bits and an
    # increasing count of the rows within that partition in the lower 33 bits.  So we split
    # it into two parts: the partition id (part_id) and the count (mono_id).
    df = df.withColumn('part_id', spark_partition_id())
    # part_id is a 32-bit integer, and a 32-bit left shift only honours the low five bits of
    # the shift amount, so shifting it by 33 directly would really shift by 1.  Widen it to a
    # long first so the full partition prefix is subtracted.
    partition_prefix = shiftLeft(col('part_id').cast(LongType()), 33)
    return df.withColumn('mono_id', monotonically_increasing_id() - partition_prefix)


def assign_low_mem_final_ids(df):
    """Turn per-partition ``mono_id`` counters into contiguous per-column ids.

    ``df`` must carry ``column_id``, ``part_id``, ``mono_id``, ``data`` and
    ``count``.  The result has the columns ``column_id``, ``data``, ``id``
    (1-based, cast to integer) and ``model_count`` (the renamed ``count``).
    """
    # Smallest and largest mono_id of every column/partition pair; only the
    # number of rows in the pair ('diff') and its first mono_id are kept.
    bounds = df.groupBy('column_id', 'part_id').agg(
        max('mono_id').alias('top'),
        min('mono_id').alias('bottom'))
    bounds = bounds.withColumn('diff', col('top') - col('bottom') + 1).drop('top')

    # Running sum, within a column, of the row counts of all earlier
    # partitions.  The first partition has no predecessor, hence the fill.
    earlier_partitions = (Window
                          .partitionBy('column_id')
                          .orderBy('part_id')
                          .rowsBetween(Window.unboundedPreceding, -1))
    bounds = (bounds
              .withColumn('running_sum', sum('diff').over(earlier_partitions))
              .na.fill(0, ["running_sum"]))

    # Rename the input's key columns so they do not clash with the
    # aggregated table during the join.
    rows = (df
            .withColumnRenamed('column_id', 'i_column_id')
            .withColumnRenamed('part_id', 'i_part_id')
            .withColumnRenamed('count', 'model_count'))

    # Attach to every input row the pair it belongs to.
    same_pair = (col('i_column_id') == col('column_id')) & (col('part_id') == col('i_part_id'))
    paired = rows.join(bounds, same_pair)

    # mono_id - bottom restarts the counter at 0 inside each partition;
    # running_sum makes it contiguous across the column's partitions; the
    # extra 1 matches the 1-based numbering of row_number.
    contiguous_id = col('mono_id') - col('bottom') + col('running_sum') + 1
    return paired.select(col('column_id'),
                         col('data'),
                         contiguous_id.cast(IntegerType()).alias('id'),
                         col('model_count'))

In [3]:
def get_column_models(combined_model):
    """Yield ``(column index, model)`` for every categorical column.

    Each model is the slice of ``combined_model`` whose ``column_id`` equals
    the column's offset from the first entry of ``CAT_COLS``, with the
    ``column_id`` column dropped.
    """
    for cat_col in CAT_COLS:
        offset = cat_col - CAT_COLS[0]
        selected = combined_model.filter('column_id == %d' % offset)
        yield cat_col, selected.drop('column_id')


def col_of_rand_long():
    """Return a column expression of ``rand()`` scaled by 2**52 and cast to long."""
    scale = 1 << 52
    scaled = rand() * scale
    return scaled.cast(LongType())

def skewed_join(df, model, col_name, cutoff):
    """Replace ``df[col_name]`` with the matching model ``id``, tolerating skew.

    The model is split at ``cutoff`` on ``model_count``: rows at or above the
    cutoff are joined with a broadcast join, the remaining rows with a
    regular join.  The two results are then merged with ``coalesce``.

    Parameters:
        df: input DataFrame containing ``col_name`` (and optionally an
            ``ordinal`` column, which is reused as the spreading key).
        model: DataFrame with the columns ``data``, ``id`` and ``model_count``.
        col_name: name of the column in ``df`` to look up in ``model.data``.
        cutoff: ``model_count`` threshold separating the broadcast part of
            the model from the rest.

    Returns:
        ``df`` with ``col_name`` holding the id found by either join (null
        when neither matched); all temporary columns are dropped.
    """
    # Most versions of Spark don't have a good way
    # to deal with a skewed join out of the box.
    # Some do, and if you want to replace this with
    # one of those that would be great.
    
    # Because we have statistics about the skew
    # that we can use, we divide the model into two parts.
    # One part is the highly skewed part; we do a
    # broadcast join for that part, but keep the result in
    # a separate column (id_tmp).
    b_model = broadcast(model.filter(col('model_count') >= cutoff)
            .withColumnRenamed('data', col_name)
            .drop('model_count'))
    
    df = (df
            .join(b_model, col_name, how='left')
            .withColumnRenamed('id', 'id_tmp'))
    
    # We also need to spread the skewed data that matched
    # evenly.  We use a source of randomness for this
    # but use -1 for anything that still needs to be matched.
    if 'ordinal' in df.columns:
        rand_column = col('ordinal')
    else:
        rand_column = col_of_rand_long()

    df = df.withColumn('join_rand',
            # Null values are not in the model (they are filtered out)
            # but can be a source of skew, so include them in
            # the even distribution.
            when(col('id_tmp').isNotNull() | col(col_name).isNull(), rand_column)
            .otherwise(lit(-1)))
    
    # Null out the string data that already matched to save memory
    df = df.withColumn(col_name,
            when(col('id_tmp').isNotNull(), None)
            .otherwise(col(col_name)))
    
    # Now do the second join, which will be a non-broadcast join.
    # Sadly Spark is too smart for its own good and will optimize out
    # joining on a column it knows will always be a constant value.
    # So we have to make a convoluted version of assigning -1 to the
    # randomness column for the model itself to work around that.
    # (The -2 branch can never survive the filter that follows.)
    nb_model = (model
            .withColumn('join_rand', when(col('model_count') < cutoff, lit(-1)).otherwise(lit(-2)))
            .filter(col('model_count') < cutoff)
            .withColumnRenamed('data', col_name)
            .drop('model_count'))
    
    df = (df
            .join(nb_model, ['join_rand', col_name], how='left')
            .drop(col_name, 'join_rand')
            # Pick either join result as an answer
            .withColumn(col_name, coalesce(col('id'), col('id_tmp')))
            .drop('id', 'id_tmp'))

    return df


def apply_models(df, models, broadcast_model=False, skew_broadcast_pct=1.0):
    """Encode every categorical column of ``df`` with its model.

    ``models`` is an iterable of ``(column index, model, original_rows,
    would_broadcast)`` tuples.  Each model replaces the values of column
    ``_c<index>`` with the model's ``id``.  Categorical columns left null
    after all joins are filled with 0.
    """
    # Process the models that would be broadcast first.  This reduces the
    # amount of shuffle data sooner rather than later.  If the string hex
    # values were parsed to ints early on this would not make a difference.
    broadcast_first = sorted(models, key=itemgetter(3), reverse=True)
    for i, model, original_rows, fits_broadcast in broadcast_first:
        col_name = '_c%d' % i
        if fits_broadcast or broadcast_model:
            # Broadcast joins can handle skewed data, so nothing special
            # is needed here.
            lookup = model.drop('model_count').withColumnRenamed('data', col_name)
            if broadcast_model:
                lookup = broadcast(lookup)
            joined = df.join(lookup, col_name, how='left')
            df = joined.drop(col_name).withColumnRenamed('id', col_name)
            continue
        # The data is highly skewed, so offset that with the two-stage join.
        cutoff = int(original_rows * skew_broadcast_pct/100.0)
        df = skewed_join(df, model, col_name, cutoff)
    cat_names = ['_c%d' % i for i in CAT_COLS]
    return df.fillna(0, cat_names)


def transform_log(df, transform_log = False):
    """Fill null integer features with 0, optionally log-transforming first.

    When ``transform_log`` is truthy every ``INT_COLS`` column is replaced by
    ``log(value + 3)`` before the nulls are filled.  Note that the flag
    shadows this function's own name inside the body.
    """
    int_names = ['_c%d' % i for i in INT_COLS]
    for name in (int_names if transform_log else ()):
        df = df.withColumn(name, log(df[name] + 3))
    return df.fillna(0, int_names)

def would_broadcast(spark, str_path):
    """Tell whether the data stored under ``str_path`` is small enough to broadcast.

    Sums the length of every file found recursively under ``str_path`` and
    compares it against ``autoBroadcastJoinThreshold * fileCompressionFactor``.

    Parameters:
        spark: the active SparkSession.
        str_path: path (as a string) of the data source to measure.

    Returns:
        True when the total size in bytes does not exceed the cutoff.
    """
    sc = spark.sparkContext
    config = sc._jsc.hadoopConfiguration()
    path = sc._jvm.org.apache.hadoop.fs.Path(str_path)
    # Resolve the file system from the path itself so a path whose scheme is
    # not the default file system is still handled; a scheme-less path falls
    # back to the default one.
    fs = path.getFileSystem(config)
    files = fs.listFiles(path, True)
    # Not named `sum`: that would shadow both the builtin and the
    # pyspark.sql.functions.sum pulled in by the star import.
    total_bytes = 0
    while files.hasNext():
        total_bytes += files.next().getLen()
    # NOTE(review): this builds a fresh SQLConf object; presumably it does not
    # see overrides set on the session -- confirm.
    sql_conf = sc._jvm.org.apache.spark.sql.internal.SQLConf()
    cutoff = sql_conf.autoBroadcastJoinThreshold() * sql_conf.fileCompressionFactor()
    return total_bytes <= cutoff

def delete_data_source(spark, path):
    """Recursively delete ``path`` through the Hadoop FileSystem API.

    Parameters:
        spark: the active SparkSession.
        path: path (as a string) of the file or directory to remove.
    """
    sc = spark.sparkContext
    config = sc._jsc.hadoopConfiguration()
    hadoop_path = sc._jvm.org.apache.hadoop.fs.Path(path)
    # Resolve the file system from the path so a path outside the default
    # file system is deleted on the file system it actually lives on.
    hadoop_path.getFileSystem(config).delete(hadoop_path, True)
    
def load_raw(spark, folder, day_range):
    """Read the tab-separated ``day_<n>`` files of ``folder`` into a DataFrame.

    The schema is the label column, then the integer columns, then the
    categorical (string) columns, all named ``_c<index>``.
    """
    fields = [StructField('_c%d' % LABEL_COL, IntegerType())]
    fields.extend(StructField('_c%d' % i, IntegerType()) for i in INT_COLS)
    fields.extend(StructField('_c%d' % i, StringType()) for i in CAT_COLS)
    schema = StructType(fields)

    day_paths = [os.path.join(folder, 'day_%d' % day) for day in day_range]
    reader = spark.read.schema(schema).option('sep', '\t')
    return reader.csv(day_paths)

def rand_ordinal(df):
    """Append an ``ordinal`` column of random longs to ``df``."""
    # The random long is derived from a double precision float.  The
    # fraction part of a double is 52 bits, so col_of_rand_long tries to
    # capture as much of that as possible.
    ordinal = col_of_rand_long()
    return df.withColumn('ordinal', ordinal)

def day_from_ordinal(df, num_days):
    """Append an integer ``day`` column computed as ``ordinal % num_days``."""
    day = col('ordinal') % num_days
    return df.withColumn('day', day.cast(IntegerType()))

def day_from_input_file(df):
    """Append an integer ``day`` column parsed from the input file name.

    The day is whatever follows the last ``_`` of ``input_file_name()``.
    """
    suffix = substring_index(input_file_name(), '_', -1)
    return df.withColumn('day', suffix.cast(IntegerType()))

def psudo_sort_by_day_plus(spark, df, num_days):
    """Repartition ``df`` by day and sort each partition by ``day``, ``ordinal``.

    A full sort is very expensive because it has to calculate the
    partitions, which in our case may involve rereading all of the data.
    Repartitioning and sorting within each partition avoids that.
    """
    shuffle_parts = int(spark.conf.get('spark.sql.shuffle.partitions'))
    parts_per_day = int(shuffle_parts/num_days)
    if parts_per_day > 0:
        # Spread the computation over roughly shuffle_parts partitions by
        # adding a second, ordinal-derived repartitioning key.
        bucket = (col('ordinal') / num_days).cast(LongType()) % parts_per_day
        spread = df.repartition(col('day'), bucket)
    else:
        spread = df.repartition('day')
    return spread.sortWithinPartitions('day', 'ordinal')


def load_combined_model(spark, model_folder):
    """Read ``combined.parquet`` from ``model_folder``."""
    reader = spark.read
    return reader.parquet(os.path.join(model_folder, 'combined.parquet'))


def save_combined_model(df, model_folder, mode=None):
    """Write ``df`` to ``combined.parquet`` in ``model_folder`` using ``mode``."""
    target = os.path.join(model_folder, 'combined.parquet')
    writer = df.write
    writer.parquet(target, mode=mode)


def delete_combined_model(spark, model_folder):
    """Delete ``combined.parquet`` from ``model_folder``."""
    target = os.path.join(model_folder, 'combined.parquet')
    delete_data_source(spark, target)


def load_low_mem_partial_ids(spark, model_folder):
    """Read ``partial_ids.parquet`` from ``model_folder``."""
    reader = spark.read
    return reader.parquet(os.path.join(model_folder, 'partial_ids.parquet'))


def save_low_mem_partial_ids(df, model_folder, mode=None):
    """Write ``df`` to ``partial_ids.parquet`` in ``model_folder`` using ``mode``."""
    target = os.path.join(model_folder, 'partial_ids.parquet')
    writer = df.write
    writer.parquet(target, mode=mode)


def delete_low_mem_partial_ids(spark, model_folder):
    """Delete ``partial_ids.parquet`` from ``model_folder``."""
    target = os.path.join(model_folder, 'partial_ids.parquet')
    delete_data_source(spark, target)


def load_column_models(spark, model_folder, count_required):
    """Yield the stored model of every categorical column with its statistics.

    For each index in ``CAT_COLS`` this yields ``(index, model DataFrame,
    aggregate row, would_broadcast flag)``.  The aggregate row always holds
    ``sum`` (the sum of ``model_count``) and, when ``count_required`` is
    truthy, also ``size`` (the number of rows).
    """
    for i in CAT_COLS:
        path = os.path.join(model_folder, '%d.parquet' % i)
        df = spark.read.parquet(path)
        aggregates = [sum('model_count').alias('sum')]
        if count_required:
            aggregates.append(count('*').alias('size'))
        summary = df.agg(*aggregates).collect()[0]
        yield i, df, summary, would_broadcast(spark, path)

def save_column_models(column_models, model_folder, mode=None):
    """Write every ``(index, model)`` pair to ``<index>.parquet`` in ``model_folder``."""
    for col_index, col_model in column_models:
        target = os.path.join(model_folder, '%d.parquet' % col_index)
        writer = col_model.write
        writer.parquet(target, mode=mode)


def save_model_size(model_size, path, write_mode):
    """Dump ``model_size`` as indented JSON to ``path``.

    When ``write_mode`` is ``'errorifexists'`` and ``path`` already exists, an
    error is printed and the process exits with status 1.  Missing parent
    directories are created.
    """
    refuse_overwrite = write_mode == 'errorifexists'
    if refuse_overwrite and os.path.exists(path):
        print('Error: model size file %s exists' % path)
        sys.exit(1)

    parent_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(parent_dir, exist_ok=True)
    with open(path, 'w') as fp:
        json.dump(model_size, fp, indent=4)

In [4]:

# Wall-clock duration in seconds of every timed step, keyed by step name.
_benchmark = {}


@contextmanager
def _timed(step):
    """Context manager that records how long its body ran under ``step``.

    The elapsed wall-clock time (seconds) is stored in ``_benchmark[step]``.
    It is recorded even when the body raises, so a failing step still shows
    up in the benchmark; the exception itself is propagated unchanged.
    """
    start = time()
    try:
        yield
    finally:
        _benchmark[step] = time() - start

### Use the label ratio (sum of label / count) instead of the raw count

In [5]:
# Column indices used to build the `_c<index>` column names: the label,
# the integer feature columns and the categorical feature columns.
LABEL_COL = 0
INT_COLS = [*range(1, 14)]
CAT_COLS = [*range(14, 40)]
LABEL_COL_NAME = '_c%d' % LABEL_COL
def get_column_counts_with_frequency_limit(df, frequency_limit):
    """Compute, per categorical value, the mean of the label column.

    Every categorical column of ``df`` is exploded into ``(column_id, data)``
    pairs (nulls are discarded).  For each pair the number of occurrences and
    the sum of the label column are aggregated, and pairs seen fewer than
    ``frequency_limit`` times are dropped.

    Parameters:
        df: raw input DataFrame with the ``CAT_COLS`` columns and the label.
        frequency_limit: minimum number of occurrences (converted with
            ``int``) a value needs in order to be kept.

    Returns:
        DataFrame with the columns ``column_id``, ``data`` and ``count``.
        Note that ``count`` holds the float ratio ``label sum / occurrences``
        -- not the occurrence count -- under the column name the rest of the
        pipeline expects.
    """
    cols = ['_c%d' % i for i in CAT_COLS]

    df = (df.select(posexplode(array(*cols)), LABEL_COL_NAME)
          .withColumnRenamed('pos', 'column_id')
          .withColumnRenamed('col', 'data')
          .filter('data is not null')
          .groupBy('column_id', 'data')
          # Aggregate the same label column that was selected above instead
          # of a separately hard-coded name.
          .agg(sum(LABEL_COL_NAME).alias('num_pos'),
               count('*').alias('count'))
          .withColumn('ratio', (col('num_pos') / col('count')).cast('float'))
          .drop('num_pos'))

    min_count = int(frequency_limit)
    df = df.filter(col('count') >= min_count)
    # The ratio takes over the 'count' column name.
    return df.drop('count').withColumnRenamed('ratio', 'count')

In [6]:
def _main(spark, mode):
    """Run one phase of the example pipeline and print the benchmark.

    ``mode`` selects the phase: ``'generate models'`` builds and stores the
    per-column models, ``'transform'`` applies the stored models to the input
    and writes the result as parquet.  Any other value only prints the
    benchmark collected so far.
    """
    write_mode = "overwrite"
    # You need to update them to your real paths!
    data_root = os.getenv("DATA_ROOT", "/data")
    input_folder = data_root + "/criteo/model_test/input"
    model_folder = data_root + "/criteo/model_test/model"
    output_folder = data_root + "/criteo/model_test/output"

    # Number of input days to read: range(day_range) selects the files
    # day_0 .. day_<day_range - 1>.  Five days are used as an example; change
    # this to match the data sets you have.
    day_range = 5
    dict_build_shuffle_parallel_per_day = 2
    apply_shuffle_parallel_per_day = 25

    origin_df = load_raw(spark, input_folder, range(day_range))

    if mode == 'generate models':
        spark.conf.set('spark.sql.shuffle.partitions', day_range*dict_build_shuffle_parallel_per_day)
        with _timed('generate models'):
            print('generate models')
            col_counts = get_column_counts_with_frequency_limit(origin_df, 100)
            # In low memory mode an intermediate result has to be saved:
            # done in one query, Spark ends up assigning the partial ids in
            # two different locations that are not guaranteed to line up.
            # Assigning the partial ids and writing them out prevents that.
            partial_ids = assign_low_mem_partial_ids(col_counts)
            save_low_mem_partial_ids(partial_ids, model_folder, write_mode)

            stored_partial_ids = load_low_mem_partial_ids(spark, model_folder)
            final_ids = assign_low_mem_final_ids(stored_partial_ids)
            save_combined_model(final_ids, model_folder, write_mode)

            stored_combined = load_combined_model(spark, model_folder)
            column_models = get_column_models(stored_combined)
            save_column_models(column_models, model_folder, write_mode)

    if mode == 'transform':
        spark.conf.set('spark.sql.shuffle.partitions', day_range*apply_shuffle_parallel_per_day)
        with _timed('transform'):
            print("transform")
            with_ordinal = origin_df.withColumn('ordinal', monotonically_increasing_id())

            loaded = list(load_column_models(spark, model_folder, False))
            models = [(i, model, agg.sum, flag) for i, model, agg, flag in loaded]

            encoded = apply_models(with_ordinal, models, False, 1.0)
            encoded = transform_log(encoded, False)
            encoded = encoded.orderBy('ordinal').drop('ordinal').drop('day')

            print("writing final output")
            encoded.write.parquet(output_folder, mode=write_mode)

    print('=' * 100)
    print(_benchmark)


if __name__ == '__main__':
    # Resource settings for the example run.
    conf = (SparkConf().setAppName('ETL')
            .set('spark.executor.memory', '20g')
            .set('spark.executor.resource.gpu.amount', 1)
            .set('spark.executor.cores', 2)
            .set('spark.cores.max', 2)
            .set('spark.task.cpus', 1)
            .set('spark.task.resource.gpu.amount', 0.5)
            .set('spark.rapids.sql.concurrentGpuTasks', 1))

    spark = SparkSession.builder.config(conf=conf).getOrCreate()
    try:
        # generate model for cat columns
        _main(spark, "generate models")
        # transform data with model and save to new parquet
        _main(spark, "transform")
    finally:
        # Always release the session, even when one of the phases fails.
        spark.stop()

generate models
{'generate models': 11.397159337997437}
transform
writing final output
{'generate models': 11.397159337997437, 'transform': 4.923393726348877}
