To use Spark, first set up the environment for it.

In [None]:
import os
import sys

# Location of the Spark installation ('...' looks like a placeholder to be
# replaced with a real path before running).
SPARK_HOME = '...'

# Export the Spark location and point both PySpark interpreter variables at
# the Python executable that is running this code.
os.environ['SPARK_HOME'] = SPARK_HOME
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable
os.environ['PYSPARK_PYTHON'] = sys.executable

In [None]:
# Application name handed to SparkSession.builder.appName() when the session
# is created ('...' looks like a placeholder).
SparkAppName = '...'

from pyspark import SparkContext, SparkConf, HiveContext

from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pyspark.sql.types as T
import numpy as np
import pandas as pd
from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import PCA, PCAModel

# Spark settings as (key, value) pairs; every pair is applied to the session
# builder in the loop below.
conf = [
    ('spark.driver.memory', '10g'),
    ('spark.driver.maxResultSize', '8g'),
    ('spark.executor.cores', '3'), 
    ('spark.executor.memory', '10g'), 
    ('spark.executor.memoryOverhead', '4g'),
    ('spark.sql.broadcastTimeout', 3000),
    # Both broadcast-join thresholds are set to -1.
    # NOTE(review): per the Spark SQL configuration docs -1 disables automatic
    # broadcast joins — confirm for the Spark version in use.
    ('spark.sql.autoBroadcastJoinThreshold', -1),
    ('spark.sql.adaptive.autoBroadcastJoinThreshold', -1),
    ('spark.shuffle.service.enabled', 'true'),
    ('spark.dynamicAllocation.enabled', 'true'),
    ('spark.dynamicAllocation.executorIdleTimeout', '120s'),
    ('spark.dynamicAllocation.cachedExecutorIdleTimeout', '600s'),
    ('spark.dynamicAllocation.initialExecutors', '1'),
    ('spark.dynamicAllocation.maxExecutors', '150'),
    ('spark.port.maxRetries', '150'),
]

# `spark` holds the session *builder* until getOrCreate() is called; after
# that line it is the SparkSession itself.
spark = SparkSession.builder
for k, v in conf:
    spark = spark.config(k, v)
spark = spark.appName(SparkAppName).master("yarn").getOrCreate()
# Context object through which the functions in this file read parquet data.
# NOTE(review): HiveContext is deprecated in recent PySpark releases — confirm
# it is still available in the target Spark version.
sqlc = HiveContext(spark.sparkContext)

# Matching

In [6]:
# Path prefixes of the stored embeddings, one per modality. get_paths()
# appends the selected embedding-variant name to these prefixes.
dialogs = 'dialog_embeddings'
trx_general_path = 'trx_embeddings'
geo_general_path = 'geo_embeddings'

In [7]:
# Embedding variants selectable for each modality, addressed by list index.
# Index 0 holds None, which get_paths()/get_name() treat as "modality not used".
sources = {
    'dialog': [None, 'mean_agg', 'last_agg'],
    'trx': [None, 'baseline_agg', 'baseline_gpt', 'baseline_tabformer', 'baseline'],
    'geo': [None, 'baseline_agg', 'baseline_gpt', 'baseline_tabformer', 'baseline']
}

def get_paths(a, b, c):
    """Return the path prefixes of the embeddings selected by the three indices.

    Args:
        a: index into ``sources['dialog']``.
        b: index into ``sources['trx']``.
        c: index into ``sources['geo']``.

    Returns:
        A list with one path prefix per selected modality, in the order
        dialog, trx, geo. A modality whose selected variant is ``None``
        contributes nothing.

    Note:
        The dialog prefix is concatenated with the variant name directly
        (no separator), whereas trx and geo prefixes are joined with ``'_'``.
    """
    dialog_kind = sources['dialog'][a]
    trx_kind = sources['trx'][b]
    geo_kind = sources['geo'][c]

    selected = []
    # Only the two known dialog variants produce a path.
    if dialog_kind in ('mean_agg', 'last_agg'):
        selected.append(dialogs + dialog_kind)
    if trx_kind is not None:
        selected.append(trx_general_path + '_' + trx_kind)
    if geo_kind is not None:
        selected.append(geo_general_path + '_' + geo_kind)
    return selected

In [8]:
from pyspark.sql import Window, Row
from random import randint




def concat_emb_cols(df):
    """Collapse the per-dimension ``emb_NNNN`` columns of *df* into one array column.

    Args:
        df: dataframe with a ``client_id`` column and columns named
            ``emb_`` followed by a zero-padded index (``emb_0000``, ...).

    Returns:
        A dataframe with the columns ``client_id`` and ``embedding``, where
        ``embedding`` is the array of ``emb_0000`` .. ``emb_<max index>``.

    Raises:
        ValueError: if *df* has no ``emb_`` columns.
    """
    # Match on 'emb_' (with the underscore) so that a column without an index
    # part, e.g. 'embedding', is not picked up and fed to split()/int().
    indices = [int(name.split('_')[1])
               for name in df.schema.fieldNames()
               if name.startswith('emb_')]
    if not indices:
        raise ValueError("no 'emb_NNNN' columns found in dataframe")
    n = max(indices) + 1
    # zfill pads to at least 4 digits without cutting longer indices.
    emb_cols = [F.col('emb_' + str(i).zfill(4)) for i in range(n)]
    return df.select('client_id', F.array(*emb_cols).alias('embedding'))

def get_name(a, b, c):
    """Build a label describing the embedding variants selected by the indices.

    Args:
        a: index into ``sources['dialog']``.
        b: index into ``sources['trx']``.
        c: index into ``sources['geo']``.

    Returns:
        ``'<modality>_<variant>'`` pieces for every modality whose selected
        variant is not ``None``, joined with ``'+'`` in the order dialog,
        trx, geo. An empty string when nothing is selected.
    """
    chosen = (
        ('dialog', sources['dialog'][a]),
        ('trx', sources['trx'][b]),
        ('geo', sources['geo'][c]),
    )
    pieces = []
    for modality, variant in chosen:
        if variant is None:
            continue
        pieces.append(modality + '_' + variant)
    return '+'.join(pieces)

def udf_randint():
    """Return a random integer between 0 and 199, both ends included."""
    return randint(0, 199)
# Rebind the name to a Spark UDF producing an integer column. The UDF is
# marked nondeterministic because its result changes on every call; without
# the mark Spark is free to treat it as deterministic and duplicate or
# re-evaluate the call during query optimisation.
udf_randint = F.udf(udf_randint, T.IntegerType()).asNondeterministic()

def goida_filter(emb):
    """Tell whether an embedding is free of missing and infinite values.

    Args:
        emb: iterable of numbers (elements may be ``None``), or ``None``.

    Returns:
        ``False`` if *emb* is ``None`` or any element is ``None`` or
        +/- infinity, ``True`` otherwise (including for an empty embedding).
        NaN elements are not rejected.
    """
    if emb is None:
        return False
    inf = float('inf')
    for x in emb:
        # The None test must come first: abs(None) raises TypeError.
        if x is None or abs(x) == inf:
            return False
    return True
# Spark UDF wrapper around goida_filter with a boolean result; load_dfs()
# uses it to flag and drop rows whose embedding fails the check.
udf_goida_filter = F.udf(goida_filter, T.BooleanType())

def get_cat_embs_f(defaultl1, defaultl2):
    """Create a Spark UDF that concatenates two embedding arrays.

    Args:
        defaultl1: number of zeros substituted for a missing first embedding.
        defaultl2: number of zeros substituted for a missing second embedding.

    Returns:
        A UDF of two array columns returning an array of floats: the first
        array followed by the second, with a ``None`` argument replaced by
        zeros of the corresponding default length.
    """
    def cat_embeddings(x, y):
        # Handle the missing-value cases first, then the regular one.
        if x is None and y is None:
            return [0.] * (defaultl1 + defaultl2)
        if x is None:
            return [0.] * defaultl1 + y
        if y is None:
            return x + [0.] * defaultl2
        return x + y

    return F.udf(cat_embeddings, T.ArrayType(T.FloatType()))


def load_dfs(path):
    """Load the train and test embeddings stored under a path prefix.

    Reads ``<path>_train.parquet`` and ``<path>_test.parquet``, packs the
    embedding columns into a single ``embedding`` array column and drops
    the rows rejected by ``udf_goida_filter``.

    Args:
        path: path prefix shared by the train and test parquet datasets.

    Returns:
        ``(train_df, test_df)``, each with the columns ``client_id`` and
        ``embedding``.
    """
    def keep_valid(df):
        # Flag every row with the validity check, keep the flagged ones and
        # drop the helper column again.
        flagged = df.withColumn('marker', udf_goida_filter(F.col('embedding')))
        valid = flagged.filter(F.col('marker') == True)
        return valid.select('client_id', 'embedding')

    train_raw = concat_emb_cols(sqlc.read.parquet('_'.join([path, 'train.parquet'])))
    test_raw = concat_emb_cols(sqlc.read.parquet('_'.join([path, 'test.parquet'])))

    return keep_valid(train_raw), keep_valid(test_raw)


def late_merge_modalities(df1, df2, udf_cat_embeddings):
    """Inner-join two embedding dataframes on ``client_id`` and concatenate them.

    Args:
        df1: dataframe with ``client_id`` and ``embedding`` columns.
        df2: dataframe with ``client_id`` and ``embedding`` columns.
        udf_cat_embeddings: UDF combining the two embedding columns.

    Returns:
        A dataframe with the columns ``client_id``, ``embedding`` (the UDF
        result), ``embedding_1`` (from *df1*) and ``embedding_2`` (from *df2*).
    """
    left = df1.select('client_id', F.col('embedding').alias('embedding_1'))
    right = df2.select('client_id', F.col('embedding').alias('embedding_2'))
    joined = left.join(right, on='client_id', how='inner')

    combined = udf_cat_embeddings(F.col('embedding_1'), F.col('embedding_2'))
    merged_data = joined.withColumn('embedding', combined)
    return merged_data.select('client_id', 'embedding', 'embedding_1', 'embedding_2')


def split_paired_rows(itr):
    """Turn every paired row into two cross-client rows with target 0.

    Each input row carries the fields of two clients, suffixed ``_a`` and
    ``_b``. Two rows are produced per input: the first embedding of one
    client followed by the second embedding of the other client, once in
    each direction.

    Args:
        itr: iterator of rows as produced by ``pair_rows``.

    Yields:
        Rows with the fields ``client_id_1``, ``client_id_2``, ``embedding``
        and ``target`` (always 0).
    """
    for paired in itr:
        rec = paired.asDict()
        a_with_b = Row(client_id_1=rec['client_id_a'],
                       client_id_2=rec['client_id_b'],
                       embedding=rec['embedding_1_a'] + rec['embedding_2_b'],
                       target=0)
        b_with_a = Row(client_id_1=rec['client_id_b'],
                       client_id_2=rec['client_id_a'],
                       embedding=rec['embedding_1_b'] + rec['embedding_2_a'],
                       target=0)
        yield a_with_b
        yield b_with_a

def pair_rows(itr):
    """Merge consecutive rows of an iterator two by two.

    The fields of the first row of each pair get the suffix ``_a`` and the
    fields of the second row the suffix ``_b``. A trailing row without a
    partner is dropped.

    Args:
        itr: iterator of rows.

    Yields:
        One merged row per pair of consecutive input rows.
    """
    first = None
    for position, row in enumerate(itr):
        if position % 2 == 0:
            # Even position: remember the row and wait for its partner.
            first = row
            continue
        left = {name + '_a': value for name, value in first.asDict().items()}
        right = {name + '_b': value for name, value in row.asDict().items()}
        yield Row(**left, **right)


def get_all_combs(na, nb, nc, done=frozenset({(0, 0, 0)})):
    """Yield the index triples that select exactly two of the three modalities.

    A triple ``(i, j, k)`` is produced when it is not listed in *done*,
    exactly one of its components is 0, and ``j == k`` or ``j == 0`` or
    ``k == 0``.

    Args:
        na: number of values for the first index (``range(na)``).
        nb: number of values for the second index (``range(nb)``).
        nc: number of values for the third index (``range(nc)``).
        done: collection of triples to skip; ``None`` means skip nothing.
            The default is an immutable set so it cannot be altered
            between calls.

    Yields:
        Tuples ``(i, j, k)`` in lexicographic order.
    """
    skipped = () if done is None else done
    for i in range(na):
        for j in range(nb):
            for k in range(nc):
                combo = (i, j, k)
                if combo in skipped:
                    continue
                # Exactly one zero: one modality is left out, two are used.
                if combo.count(0) != 1:
                    continue
                if j == k or j == 0 or k == 0:
                    yield combo


def load_data_to_tmp(paths):
    """Build matching and non-matching pair datasets and write them to parquet.

    For both the train and the test split, the two modalities are joined
    per client. Rows with target 1 pair a client with itself and are
    written to ``temp_<split>_true.parquet``; rows with target 0 are
    produced by ``pair_rows``/``split_paired_rows`` from randomly grouped
    rows and written to ``temp_<split>_false.parquet``.

    Args:
        paths: exactly two path prefixes accepted by ``load_dfs``.

    Returns:
        None. The four parquet datasets are overwritten.
    """
    assert len(paths) == 2
    train_1, test_1 = load_dfs(paths[0])
    train_2, test_2 = load_dfs(paths[1])

    def embedding_length(df):
        # Length of the embedding in the first row of the dataframe.
        return len(df.select('embedding').limit(1).collect()[0]['embedding'])

    udf_cat_embeddings = get_cat_embs_f(embedding_length(train_1),
                                        embedding_length(train_2))

    def dump_pairs(df_1, df_2, true_path, false_path):
        merged = late_merge_modalities(df_1, df_2, udf_cat_embeddings)

        # Matching pairs: the same client on both sides.
        positives = merged.withColumn('target', F.lit(1).cast(T.IntegerType()))\
                          .select(F.col('client_id').alias('client_id_1'),
                                  F.col('client_id').alias('client_id_2'),
                                  'embedding', 'target')
        positives.write.mode('overwrite').parquet(true_path)

        # Non-matching pairs: spread rows over random groups, then combine
        # neighbouring rows of each partition.
        grouped = merged.withColumn('group', udf_randint()).repartition(200, 'group')
        negatives = grouped.rdd.mapPartitions(pair_rows)\
                               .mapPartitions(split_paired_rows).toDF()
        negatives.write.mode('overwrite').parquet(false_path)

    dump_pairs(train_1, train_2, 'temp_train_true.parquet', 'temp_train_false.parquet')
    dump_pairs(test_1, test_2, 'temp_test_true.parquet', 'temp_test_false.parquet')


def solve_clf(name):
    """Train a GBT matching classifier on the temporary datasets and score it.

    Reads the ``temp_*`` parquet datasets, fits a ``GBTClassifier`` on the
    train pairs, writes the test predictions to ``matching_<name>.parquet``
    and stores the fitted model under ``matching_<name>.gbtc_model``.

    Args:
        name: label used in the names of the prediction and model outputs.

    Returns:
        The area under the ROC curve of the model on the test pairs.
    """
    train_true = sqlc.read.parquet('temp_train_true.parquet')
    train_false = sqlc.read.parquet('temp_train_false.parquet')
    train = train_true.union(train_false).withColumn('embedding', array_to_vector(F.col("embedding")))
    # Write the combined train set out and read it back before fitting.
    train.repartition(300, 'client_id_2').write.mode('overwrite')\
         .parquet('temp_train.parquet')
    train = sqlc.read.parquet('temp_train.parquet')

    clf = GBTClassifier(labelCol='target', featuresCol='embedding',
                        stepSize=0.02, maxDepth=6, minInstancesPerNode=50)
    evaluator = BinaryClassificationEvaluator(labelCol='target', metricName='areaUnderROC')

    gbModel = clf.fit(train)

    test_true = sqlc.read.parquet('temp_test_true.parquet')
    test_false = sqlc.read.parquet('temp_test_false.parquet')
    test = test_true.union(test_false).withColumn('embedding', array_to_vector(F.col("embedding")))

    gbPredictions = gbModel.transform(test)
    gbPredictions.select('client_id_1', 'client_id_2', 'target', 'rawPrediction', 'probability')\
                 .write.mode('overwrite').parquet('matching_'+name+'.parquet')

    score = evaluator.evaluate(gbPredictions)

    # Overwrite an existing model like every other output of this function;
    # a plain save() raises when the path already exists, which would abort
    # a rerun after training and before the score is returned.
    gbModel.write().overwrite().save('matching_'+name+'.gbtc_model')

    return score


def write_results(name, score, file='results_matching.csv'):
    """Append one ``name,score`` line to a results file.

    Args:
        name: label of the experiment.
        score: value to record; converted with ``str``.
        file: path of the file to append to (created if missing).
    """
    with open(file, 'a') as out:
        line = name + ',' + str(score) + '\n'
        out.write(line)

In [None]:
# Enumerate the index triples to evaluate; the sizes 3, 5, 5 match the lengths
# of the 'dialog', 'trx' and 'geo' lists in `sources`.
# (2, 4, 0) is skipped in addition to (0, 0, 0) — presumably already
# processed in an earlier run; confirm before rerunning.
g = get_all_combs(3,5,5, done={(0,0,0), (2, 4, 0)})

# For every combination: build the temporary pair datasets, train and score
# the classifier, then append the score to the results file.
for k in g:
    name = get_name(*k)
    paths = get_paths(*k)
    load_data_to_tmp(paths)
    mean_score = solve_clf(name)
    write_results(name, mean_score)