This repository has been archived by the owner. It is now read-only.
Update main_tfrecord_generator.py
Replace pandas with spark
radibnia77 committed Feb 2, 2022
1 parent 40da3b1 commit 4ee29bcd32ed9780818b14b0bed3de7c96cd468e
Showing 1 changed file with 48 additions and 65 deletions.
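The heart of the change is visible in the diff below: the old code pulled the train-ready rows to the driver with toPandas() and built the (user, keyword history, candidate keyword, label) tuples in a Python loop, while the new code keeps that work in Spark by having a UDF return an array of structs that is then exploded into one row per example. A minimal sketch of that pattern follows, assuming a local SparkSession, toy input rows, and a fixed LENGTH constant in place of cfg['pipeline']['length']; the column names mirror the diff, everything else is illustrative rather than code from this commit.

# Sketch only: build each user's training examples inside a UDF that returns an
# array of structs, then explode into one row per example.
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, udf
from pyspark.sql.types import ArrayType, IntegerType, StructField, StructType

LENGTH = 2  # history window; the real pipeline reads this from cfg['pipeline']['length']

example_schema = ArrayType(StructType([
    StructField("aid_index", IntegerType(), True),
    StructField("keyword_list", ArrayType(IntegerType()), True),
    StructField("keyword", IntegerType(), True),
    StructField("label", IntegerType(), True),
]))

def create_trainset(aid_index, click_counts, keyword_int):
    # Positive example for every clicked slot, down-sampled negatives elsewhere;
    # the history is the flattened keywords of the next LENGTH time intervals.
    rows = []
    for m in range(len(click_counts)):
        history = [kw for kws in keyword_int[m + 1:m + 1 + LENGTH] for kw in kws]
        if not history:
            continue
        for n in range(len(click_counts[m])):
            if click_counts[m][n] != 0:
                rows.append((aid_index, history, keyword_int[m][n], 1))
            elif m % 5 == 0 and n % 2 == 0:
                rows.append((aid_index, history, keyword_int[m][n], 0))
    return rows

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(7, [[1, 0], [0, 1]], [[11, 12], [13, 14], [15, 16]])],
    ["aid_index", "click_counts", "_kwi"])

train = (df
         .withColumn("train_set",
                     udf(create_trainset, example_schema)("aid_index", "click_counts", "_kwi"))
         .select(explode("train_set").alias("dataset"))
         .select("dataset.aid_index", "dataset.keyword_list", "dataset.keyword", "dataset.label"))
train.show(truncate=False)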
@@ -16,31 +16,17 @@
# under the License.

"""
spark-submit --master yarn --executor-memory 16G --driver-memory 24G --num-executors 10 --executor-cores 5 --jars spark-tensorflow-connector_2.11-1.15.0.jar --conf spark.hadoop.hive.exec.dynamic.partition=true --conf spark.hadoop.hive.exec.dynamic.partition.mode=nonstrict pipeline/_main_trainready_india.py config.yml
spark-submit --master yarn --executor-memory 16G --driver-memory 24G --num-executors 10 --executor-cores 5 --jars spark-tensorflow-connector_2.11-1.15.0.jar --conf spark.hadoop.hive.exec.dynamic.partition=true --conf spark.hadoop.hive.exec.dynamic.partition.mode=nonstrict pipeline/main_trainready_generator.py config.yml
input: trainready table
output: dataset readable by trainer in tfrecord format
"""

import yaml
import argparse
import os
import timeit
from pyspark import SparkContext
from pyspark.sql import functions as fn
from pyspark.sql.functions import lit, col, udf, collect_list, concat_ws, first, create_map, monotonically_increasing_id
from pyspark.sql.functions import count, lit, col, udf, expr, collect_list, explode
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType, ArrayType, StringType,BooleanType
from pyspark.sql import HiveContext
from pyspark.sql.session import SparkSession
from datetime import datetime, timedelta
from lookalike_model.pipeline.util import write_to_table, write_to_table_with_partition, print_batching_info, resolve_placeholder, load_config, load_batch_config, load_df
from itertools import chain
from pyspark.sql.types import IntegerType, ArrayType, StringType, BooleanType, FloatType, DoubleType
from util import write_to_table, write_to_table_with_partition, save_pickle_file


def generate_tfrecord(sc, hive_context, tf_statis_path, keyword_table, cutting_date, length, trainready_table, tfrecords_hdfs_path_train, tfrecords_hdfs_path_test):
from pyspark.sql.functions import lit, udf, explode
from pyspark.sql.types import IntegerType, ArrayType, StructType, StructField
from util import save_pickle_file, resolve_placeholder, load_config


def generate_tfrecord(hive_context, tf_stat_path, keyword_table, cutting_date, length, trainready_table, tfrecords_hdfs_path_train, tfrecords_hdfs_path_test):

    def str_to_intlist(table):
        ji = []
@@ -59,55 +45,51 @@ def list_of_list_toint(table):
            ji.append(s)
        return ji

    def flatten(lst):
        f = [y for x in lst for y in x]
        return f

    def padding(kwlist, length):
        diff = length - len(kwlist)
        print(len(kwlist))
        print(length)
        print(diff)
        temp_list = [0 for i in range(diff)]
        padded_keyword = kwlist + temp_list
        return padded_keyword

    def create_dataset(df_panda, click, keyword):
        t_set = []
        for i in range(len(df_panda.aid_index)):
            click_counts = click[i]
            keyword_int = keyword[i]
            aid_index = df_panda.aid_index[i]
            for m in range(len(click_counts)):
                for n in range(len(click_counts[m])):
                    if (click_counts[m][n] != 0):
                        pos = (aid_index, flatten(keyword_int[m + 1:m + 1 + length]), keyword_int[m][n], 1)
                        if len(pos[1]) >= 1:
                            t_set.append(pos)
                    elif (m % 5 == 0 and n % 2 == 0):
                        neg = (aid_index, flatten(keyword_int[m + 1:m + 1 + length]), keyword_int[m][n], 0)
                        if len(neg[1]) >= 1:
                            t_set.append(neg)
        return t_set

    def generating_dataframe(dataset, spark):
        data_set = [(int(tup[0]), tup[1], int(tup[2]), int(tup[3])) for tup in dataset]
        df = spark.createDataFrame(data=data_set, schema=deptColumns)
    def generating_dataframe(df):
        df = df.withColumn("sl", udf(lambda x: len(x), IntegerType())(df.keyword_list))
        df = df.where(df.sl > 5)
        df = df.withColumn('max_length', lit(df.agg({'sl': 'max'}).collect()[0][0]))
        df = df.withColumn('keyword_list_padded',
                           udf(padding, ArrayType(IntegerType()))(df.keyword_list, df.max_length))
        return df

    def generate_tf_statistics(testsetDF, trainDF, keyword_df, tf_statis_path):
    def generate_tf_statistics(testsetDF, trainDF, keyword_df, tf_stat_path):
        tfrecords_statistics = {}
        tfrecords_statistics['test_dataset_count'] = testsetDF.count()
        tfrecords_statistics['train_dataset_count'] = trainDF.count()
        tfrecords_statistics['user_count'] = trainDF.select('aid').distinct().count()
        tfrecords_statistics['item_count'] = keyword_df.distinct().count() + 1
        save_pickle_file(tfrecords_statistics, tf_statis_path)
        save_pickle_file(tfrecords_statistics, tf_stat_path)

    def create_trainset(aid_index, click_counts, keyword_int):
        def flatten(lst):
            f = [y for x in lst for y in x]
            return f
        t_set = []
        for m in range(len(click_counts)):
            for n in range(len(click_counts[m])):
                if (click_counts[m][n] != 0):
                    pos = (aid_index, flatten(keyword_int[m + 1:m + 1 + length]), keyword_int[m][n], 1)
                    if len(pos[1]) >= 1:
                        t_set.append(pos)
                elif (m % 5 == 0 and n % 2 == 0):
                    neg = (aid_index, flatten(keyword_int[m + 1:m + 1 + length]), keyword_int[m][n], 0)
                    if len(neg[1]) >= 1:
                        t_set.append(neg)
        return t_set

    schema = StructType([
        StructField("aid_index", IntegerType(), True),
        StructField("keyword_list", ArrayType(IntegerType()), True),
        StructField("keyword", IntegerType(), True),
        StructField("label", IntegerType(), True)
    ])

command = """SELECT * FROM {}"""
df = hive_context.sql(command.format(trainready_table))
@@ -122,27 +104,28 @@ def generate_tf_statistics(testsetDF, trainDF, keyword_df, tf_statis_path):
    df = df.withColumn('keyword_int_test', udf(lambda x, y: x[:y], ArrayType(ArrayType(IntegerType())))(df._kwi, df.indicing))
    df = df.withColumn('click_counts_train', udf(lambda x, y: x[y:], ArrayType(ArrayType(IntegerType())))(df.click_counts, df.indicing))
    df = df.withColumn('click_counts_test', udf(lambda x, y: x[:y], ArrayType(ArrayType(IntegerType())))(df.click_counts, df.indicing))
    df = df.withColumn('train_set', udf(create_trainset, ArrayType(schema))(df.aid_index, df.click_counts_train, df.keyword_int_train))
    df = df.withColumn('test_set', udf(create_trainset, ArrayType(schema))(df.aid_index, df.click_counts_test, df.keyword_int_test))
    trainDF = df.select(df.aid_index, explode(df.train_set).alias('dataset'))
    testDF = df.select(df.aid_index, explode(df.test_set).alias('dataset'))

    spark = SparkSession(sc)
    deptColumns = ["aid", "keyword_list", "keyword", "label"]
    train_set = trainDF.select('aid_index', trainDF.dataset['aid_index'].alias('aid'), trainDF.dataset['keyword_list'].alias('keyword_list'), trainDF.dataset['keyword'].alias('keyword'), trainDF.dataset['label'].alias('label'))
    test_set = testDF.select('aid_index', testDF.dataset['aid_index'].alias('aid'), testDF.dataset['keyword_list'].alias('keyword_list'), testDF.dataset['keyword'].alias('keyword'), testDF.dataset['label'].alias('label'))

    df_panda = df.select('click_counts_train', 'keyword_int_train', 'aid_index').toPandas()
    train_set = create_dataset(df_panda, df_panda.click_counts_train, df_panda.keyword_int_train)
    trainDF = generating_dataframe(train_set, spark=spark)
    trainDF.write.format("tfrecords").option("recordType", "Example").mode("overwrite").save(tfrecords_hdfs_path_train)
    train_set = generating_dataframe(train_set)
    train_set.write.option("header", "true").option("encoding", "UTF-8").mode("overwrite").format('hive').saveAsTable(tfrecords_hdfs_path_train)

    train_set.write.format("tfrecords").option("recordType", "Example").mode("overwrite").save(tfrecords_hdfs_path_train)

    df_panda = df.select('click_counts_test', 'keyword_int_test', 'aid_index').toPandas()
    test_set = create_dataset(df_panda, df_panda.click_counts_test, df_panda.keyword_int_test)
    testsetDF = generating_dataframe(test_set, spark=spark)
    testsetDF = generating_dataframe(test_set)
    testsetDF.write.format("tfrecords").option("recordType", "Example").mode("overwrite").save(tfrecords_hdfs_path_test)


command = "SELECT * from {}"
keyword_df = hive_context.sql(command.format(keyword_table))
generate_tf_statistics(testsetDF, trainDF, keyword_df, tf_statis_path)
generate_tf_statistics(testsetDF, trainDF, keyword_df, tf_stat_path)

def run(sc, hive_context, cfg):
def run(hive_context, cfg):
    cfgp = cfg['pipeline']
    cfg_train = cfg['pipeline']['main_trainready']
    trainready_table = cfg_train['trainready_output_table']
@@ -151,11 +134,11 @@ def run(sc, hive_context, cfg):
    tfrecords_hdfs_path_test = cfg_tfrecord['tfrecords_hdfs_path_test']
    cutting_date = cfg['pipeline']['cutting_date']
    length = cfg['pipeline']['length']
    tf_statis_path = cfgp['tfrecords']['tfrecords_statistics_path']
    tf_stat_path = cfgp['tfrecords']['tfrecords_statistics_path']
    keyword_table = cfgp['main_keywords']['keyword_output_table']


    generate_tfrecord(sc, hive_context, tf_statis_path, keyword_table, cutting_date, length, trainready_table, tfrecords_hdfs_path_train, tfrecords_hdfs_path_test)
    generate_tfrecord(hive_context, tf_stat_path, keyword_table, cutting_date, length, trainready_table, tfrecords_hdfs_path_train, tfrecords_hdfs_path_test)


if __name__ == "__main__":
@@ -166,5 +149,5 @@ def run(sc, hive_context, cfg):
"""
sc, hive_context, cfg = load_config(description="pre-processing train ready data")
resolve_placeholder(cfg)
run(sc=sc, hive_context=hive_context, cfg=cfg)
run(hive_context=hive_context, cfg=cfg)
sc.stop()
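
The train and test DataFrames are written with the spark-tensorflow-connector ("tfrecords" format, recordType "Example"), so each row becomes a tf.train.Example that the trainer can stream back with tf.data. Below is a hedged sketch of one way to parse those records; the feature names mirror the DataFrame columns in this commit, while the file glob, batch size, and the use of VarLenFeature for the padded keyword list are assumptions, not this repository's trainer code.

# Sketch only: reading the generated TFRecords back on the trainer side.
import tensorflow as tf

def parse_example(serialized):
    features = {
        "aid": tf.io.FixedLenFeature([], tf.int64),
        "keyword": tf.io.FixedLenFeature([], tf.int64),
        "label": tf.io.FixedLenFeature([], tf.int64),
        # Variable-length history; densify after parsing.
        "keyword_list_padded": tf.io.VarLenFeature(tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized, features)
    parsed["keyword_list_padded"] = tf.sparse.to_dense(parsed["keyword_list_padded"])
    return parsed

files = tf.io.gfile.glob("lookalike_tfrecords_train/part-r-*")  # hypothetical output path
dataset = (tf.data.TFRecordDataset(files)
           .map(parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
           .batch(512)
           .prefetch(1))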
