In [3]:
#%pip install graphframes

In [2]:
import time

from graphframes import GraphFrame
from pyspark.ml import Pipeline
from pyspark.ml.feature import Bucketizer
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import Window
from pyspark.sql.functions import *
from pyspark.sql.types import *

In [4]:
%run read_file.ipynb

In [7]:
def _bucket_categorical(df, column, allowed):
    """Collapse `column` onto the labels in `allowed`; everything else becomes 'other'.

    NULL is relabelled 'None' and the empty string 'empty' *before* the
    membership test, so those two placeholders survive only when `allowed`
    lists them (otherwise they also end up as 'other').
    """
    value = col(column)
    labelled = when(value.isNull(), "None").when(value == '', 'empty').otherwise(value)
    return df.withColumn(column, when(labelled.isin(allowed), labelled).otherwise('other'))


def de_dupe():
    """Load the first events, drop bi-directional duplicates and bucket the path features.

    Steps:
      1. Read the events with readFirstEvents() and keep only the earliest
         event of every unordered {actorID, objectID} pair, so A->B and B->A
         count as the same relationship.
      2. Derive `file_path_ext` from `file_path` and collapse it onto a short
         whitelist of extensions.
      3. Pass `image_path` / `parent_image_path` through getFileUDF and
         collapse each onto a whitelist of process names.

    readFirstEvents and getFileUDF are not defined in this notebook
    (presumably they come from read_file.ipynb -- confirm).

    Returns:
        The de-duplicated DataFrame with `file_path_ext`, `image_path` and
        `parent_image_path` reduced to whitelisted labels or 'other'.
    """
    # sort_array makes the key order-independent, so both directions of a
    # pair land in the same window partition.
    pair_key = sort_array(array("actorID", "objectID"))
    earliest_first = Window.partitionBy("relationship").orderBy("timestamp")

    # The window orders each partition itself, so no global sort is needed
    # beforehand. Rows tied on timestamp are broken arbitrarily.
    df = (
        readFirstEvents()
        .withColumn("relationship", pair_key)
        .withColumn("row_num", row_number().over(earliest_first))
        .filter(col("row_num") == 1)
        .drop("row_num", "relationship")
        .cache()
    )

    # Trailing extension: a dot followed by lower-case letters/digits at the
    # end of the path. Raw string, so the backslash reaches the regex engine
    # without relying on Python tolerating an invalid escape sequence.
    # NOTE(review): upper-case extensions such as '.EXE' do not match this
    # pattern and therefore end up as '' -> 'empty' -- confirm that is intended.
    df = df.withColumn("file_path_ext", regexp_extract("file_path", r"\.[0-9a-z]+$", 0))

    # Whitelists; anything outside them is reported as 'other'.
    file_ext = ['None', 'empty', '.dll', '.pyc', '.bat', '.pf', '.ini', '.exe', '.tmp']
    list_img = ['svchost.exe', 'System', 'cmd.exe', 'None',
                'conhost.exe', 'python.exe', 'csrss.exe', 'GoogleUpdate.exe',
                'firefox.exe', 'backgroundTaskHost.exe', 'PING.EXE', 'python.EXE',
                'geckodriver.exe', 'Explorer.EXE', 'CompatTelRunner.exe', 'taskhostw.exe']
    list_par = ['None', 'cmd.exe', 'csrss.exe', 'svchost.exe', 'GoogleUpdate.exe', 'conhost.exe']

    df = _bucket_categorical(df, "file_path_ext", file_ext)

    # Rewrite both image-path columns with getFileUDF before bucketing them.
    df = df.withColumn("image_path", getFileUDF(col("image_path")))
    df = df.withColumn("parent_image_path", getFileUDF(col("parent_image_path")))

    # 'empty' is not in either whitelist, so empty strings become 'other' here.
    df = _bucket_categorical(df, "image_path", list_img)
    df = _bucket_categorical(df, "parent_image_path", list_par)

    # TODO: return to the pagerank join (prod/graph/pagerank in the trusted
    # bucket, joined on objectID) and to per-day filtering on event_day.

    return df

In [8]:
#function to generate a graphframe
def create_graph2(df):
    """Find every chain of two events that is consistent in time.

    A GraphFrame is built from `df`: the vertices are all ids seen in
    `objectID` or `actorID`, and each row becomes an edge running from
    `objectID` (src) to `actorID` (dst) that carries the event attributes.
    NOTE(review): the earlier comment described actorID as the source, but
    the code does the opposite -- confirm which direction is intended.

    Returns:
        The motif DataFrame for (a)-[e1]->(b)-[e2]->(c), restricted to rows
        where e1 does not happen after e2.
    """
    started = time.time()

    # Event attributes carried on every edge.
    edge_columns = [
        'objectID as src', 'actorID as dst', 'timestamp', 'object', 'action',
        'hostname', 'user_name', 'privileges', 'image_path', 'parent_image_path',
        'new_path', 'file_path', 'direction', 'logon_id', 'requesting_domain',
        'requesting_user', 'malicious', 'file_path_ext',
    ]

    # Vertex set: every id that shows up on either end of an event.
    object_ids = df.selectExpr('objectID as id').distinct()
    actor_ids = df.selectExpr('actorID as id').distinct()
    vertices = object_ids.union(actor_ids).distinct()

    edges = df.selectExpr(*edge_columns)

    graph = GraphFrame(vertices, edges)
    chains = graph.find("(a)-[e1]->(b); (b)-[e2]->(c)")
    # NOTE(review): nothing above triggers a Spark action, so the timings
    # printed here likely cover plan construction only.
    print("found connections: "+ str(time.time() - started))

    # Keep only chains whose edges occur in chronological order.
    ordered_chains = chains.filter("e1.timestamp <= e2.timestamp")
    print("filtered connections: "+ str(time.time() - started))

    return ordered_chains

In [31]:
#function to generate a graphframe
def create_graph3(df):
    """Find every chain of three events that is consistent in time.

    A GraphFrame is built from `df`: the vertices are all ids seen in
    `objectID` or `actorID`, and each row becomes an edge running from
    `objectID` (src) to `actorID` (dst) that carries the event attributes.
    NOTE(review): the earlier comment described actorID as the source, but
    the code does the opposite -- confirm which direction is intended.

    Returns:
        The motif DataFrame for a->b->c->d (edges e1..e3), restricted to rows
        whose edge timestamps are non-decreasing.
    """
    started = time.time()

    # Event attributes carried on every edge.
    edge_columns = [
        'objectID as src', 'actorID as dst', 'timestamp', 'object', 'action',
        'hostname', 'user_name', 'privileges', 'image_path', 'parent_image_path',
        'new_path', 'file_path', 'direction', 'logon_id', 'requesting_domain',
        'requesting_user', 'malicious', 'file_path_ext',
    ]

    # Vertex set: every id that shows up on either end of an event.
    object_ids = df.selectExpr('objectID as id').distinct()
    actor_ids = df.selectExpr('actorID as id').distinct()
    vertices = object_ids.union(actor_ids).distinct()

    edges = df.selectExpr(*edge_columns)

    graph = GraphFrame(vertices, edges)
    chains = graph.find("(a)-[e1]->(b); (b)-[e2]->(c); (c)-[e3]->(d)")
    # NOTE(review): nothing above triggers a Spark action, so this timing
    # likely covers plan construction only.
    print("found connections: "+ str(time.time() - started))

    # Keep only chains whose edges occur in chronological order.
    time_ordered = "e1.timestamp <= e2.timestamp and e2.timestamp <= e3.timestamp"
    return chains.filter(time_ordered)

In [32]:
#function to generate a graphframe
def create_graph4(df):
    """Find every chain of four events that is consistent in time.

    A GraphFrame is built from `df`: the vertices are all ids seen in
    `objectID` or `actorID`, and each row becomes an edge running from
    `objectID` (src) to `actorID` (dst) that carries the event attributes.
    NOTE(review): the earlier comment described actorID as the source, but
    the code does the opposite -- confirm which direction is intended.

    Returns:
        The motif DataFrame for a->b->c->d->e (edges e1..e4), restricted to
        rows whose edge timestamps are non-decreasing.
    """
    started = time.time()

    # Event attributes carried on every edge.
    edge_columns = [
        'objectID as src', 'actorID as dst', 'timestamp', 'object', 'action',
        'hostname', 'user_name', 'privileges', 'image_path', 'parent_image_path',
        'new_path', 'file_path', 'direction', 'logon_id', 'requesting_domain',
        'requesting_user', 'malicious', 'file_path_ext',
    ]

    # Vertex set: every id that shows up on either end of an event.
    object_ids = df.selectExpr('objectID as id').distinct()
    actor_ids = df.selectExpr('actorID as id').distinct()
    vertices = object_ids.union(actor_ids).distinct()

    edges = df.selectExpr(*edge_columns)

    graph = GraphFrame(vertices, edges)
    chains = graph.find("(a)-[e1]->(b); (b)-[e2]->(c); (c)-[e3]->(d); (d)-[e4]->(e)")
    # NOTE(review): nothing above triggers a Spark action, so this timing
    # likely covers plan construction only.
    print("found connections: "+ str(time.time() - started))

    # Keep only chains whose edges occur in chronological order. The text
    # (including its line break) is the same SQL the filter has always used.
    time_ordered = (
        "e1.timestamp <= e2.timestamp and e2.timestamp <= e3.timestamp and \n"
        "                                        e3.timestamp <= e4.timestamp"
    )
    return chains.filter(time_ordered)

In [34]:
#function to generate a graphframe
def create_graph5(df):
    """Find every chain of five events that is consistent in time.

    A GraphFrame is built from `df`: the vertices are all ids seen in
    `objectID` or `actorID`, and each row becomes an edge running from
    `objectID` (src) to `actorID` (dst) that carries the event attributes.
    NOTE(review): the earlier comment described actorID as the source, but
    the code does the opposite -- confirm which direction is intended.

    Returns:
        The motif DataFrame for a->b->c->d->e->f (edges e1..e5), restricted
        to rows whose edge timestamps are non-decreasing.
    """
    started = time.time()

    # Event attributes carried on every edge.
    edge_columns = [
        'objectID as src', 'actorID as dst', 'timestamp', 'object', 'action',
        'hostname', 'user_name', 'privileges', 'image_path', 'parent_image_path',
        'new_path', 'file_path', 'direction', 'logon_id', 'requesting_domain',
        'requesting_user', 'malicious', 'file_path_ext',
    ]

    # Vertex set: every id that shows up on either end of an event.
    object_ids = df.selectExpr('objectID as id').distinct()
    actor_ids = df.selectExpr('actorID as id').distinct()
    vertices = object_ids.union(actor_ids).distinct()

    edges = df.selectExpr(*edge_columns)

    graph = GraphFrame(vertices, edges)
    chains = graph.find("(a)-[e1]->(b); (b)-[e2]->(c); (c)-[e3]->(d); (d)-[e4]->(e); (e)-[e5]->(f)")
    # NOTE(review): nothing above triggers a Spark action, so this timing
    # likely covers plan construction only.
    print("found connections: "+ str(time.time() - started))

    # Keep only chains whose edges occur in chronological order. The text
    # (including its line break) is the same SQL the filter has always used.
    time_ordered = (
        "e1.timestamp <= e2.timestamp and e2.timestamp <= e3.timestamp and \n"
        "                                        e3.timestamp <= e4.timestamp and e4.timestamp <= e5.timestamp"
    )
    return chains.filter(time_ordered)

In [33]:
#function to generate a graphframe
def create_graph6(df):
    """Find every chain of six events that is consistent in time.

    A GraphFrame is built from `df`: the vertices are all ids seen in
    `objectID` or `actorID`, and each row becomes an edge running from
    `objectID` (src) to `actorID` (dst) that carries the event attributes.
    NOTE(review): the earlier comment described actorID as the source, but
    the code does the opposite -- confirm which direction is intended.

    Returns:
        The motif DataFrame for a->b->c->d->e->f->g (edges e1..e6),
        restricted to rows whose edge timestamps are non-decreasing.
    """
    started = time.time()

    # Event attributes carried on every edge.
    edge_columns = [
        'objectID as src', 'actorID as dst', 'timestamp', 'object', 'action',
        'hostname', 'user_name', 'privileges', 'image_path', 'parent_image_path',
        'new_path', 'file_path', 'direction', 'logon_id', 'requesting_domain',
        'requesting_user', 'malicious', 'file_path_ext',
    ]

    # Vertex set: every id that shows up on either end of an event.
    object_ids = df.selectExpr('objectID as id').distinct()
    actor_ids = df.selectExpr('actorID as id').distinct()
    vertices = object_ids.union(actor_ids).distinct()

    edges = df.selectExpr(*edge_columns)

    graph = GraphFrame(vertices, edges)
    chains = graph.find(
        "(a)-[e1]->(b); (b)-[e2]->(c); (c)-[e3]->(d); (d)-[e4]->(e); (e)-[e5]->(f); (f)-[e6]->(g)"
    )
    # NOTE(review): nothing above triggers a Spark action, so this timing
    # likely covers plan construction only.
    print("found connections: "+ str(time.time() - started))

    # Keep only chains whose edges occur in chronological order. The text
    # (including its line breaks) is the same SQL the filter has always used.
    time_ordered = (
        "e1.timestamp <= e2.timestamp and e2.timestamp <= e3.timestamp and \n"
        "                                        e3.timestamp <= e4.timestamp and e4.timestamp <= e5.timestamp and \n"
        "                                        e5.timestamp <= e6.timestamp"
    )
    return chains.filter(time_ordered)

In [31]:
#function to bin the timestamp data in the data frame
from pyspark.sql import Window
from pyspark.ml.feature import Bucketizer

def ts_diff(df):
    """Add `timestamp_difference`: the gap to the previous event of the same trace.

    Events are ordered by `timestamp` within each `Trace`. The gap is computed
    separately on the whole part (cast to long) and the fractional part (cast
    to double, modulo 1) of the timestamp, and each is scaled by 1000 -- i.e.
    milliseconds if `timestamp` casts to seconds (TODO confirm the unit).

    The first event of every trace has no predecessor, so its difference is
    NULL. The previous version wrapped the expression in
    coalesce(..., NULL), which cannot change the value and has been dropped.

    Args:
        df: DataFrame with `Trace` and `timestamp` columns.

    Returns:
        `df` with an extra double column `timestamp_difference`.
    """
    by_trace_in_time = Window.partitionBy("Trace").orderBy("timestamp")

    whole_now = col("timestamp").cast("long")
    whole_before = lag(col("timestamp").cast("long")).over(by_trace_in_time)

    fraction_now = col("timestamp").cast("double") % 1
    fraction_before = lag(col("timestamp").cast("double")).over(by_trace_in_time) % 1

    gap = (whole_now - whole_before) * 1000 + (fraction_now - fraction_before) * 1000

    return df.withColumn("timestamp_difference", gap)

    

In [36]:
#create bins
def bin_it(df, splits=None):
    """Bucket the time gap between consecutive trace events into `timestamp_bins`.

    The gap is recomputed with ts_diff(), bucketed, and then dropped again, so
    the result gains only the `timestamp_bins` column. Any `timestamp_bins` or
    `timestamp_difference` column already present is replaced.

    Args:
        df: DataFrame accepted by ts_diff (needs `Trace` and `timestamp`).
        splits: Optional interior cut points, strictly increasing. They are
            wrapped in -inf/+inf before being handed to Bucketizer. Defaults
            to the four fixed values this notebook has always used --
            apparently the 20/40/60/80% quantiles of `timestamp_difference`
            from an earlier run (see the approxQuantile call they replaced);
            TODO confirm they still fit the current data.

    Returns:
        `df` with a `timestamp_bins` column and without `timestamp_difference`.
    """
    if splits is None:
        splits = [1730209.0001106262, 8254118.999958038,
                  21626879.999876022, 44756893.000125885]
    boundaries = [float("-inf")] + list(splits) + [float("inf")]

    # Start from a clean slate so re-running on an already-binned frame works.
    df = ts_diff(df.drop("timestamp_difference", "timestamp_bins"))

    bucketizer = Bucketizer(splits=boundaries, inputCol="timestamp_difference", outputCol="timestamp_bins")
    df = bucketizer.transform(df)

    # The raw gap is only an intermediate value.
    return df.drop('timestamp_difference')

In [35]:
def bin_it_pr(df):
    """Bucket `pagerank` into quintile bins and return the bin edges used.

    The 20/40/60/80% quantiles of `pagerank` are computed exactly
    (relativeError 0) and used as interior cut points. Repeated quantile
    values are collapsed, because Bucketizer only accepts strictly increasing
    splits; on skewed data this yields fewer than five buckets instead of an
    error.

    Args:
        df: DataFrame with a numeric `pagerank` column.

    Returns:
        (df, bins): `df` with `pagerank_bins` added and `pagerank` removed,
        and the full list of split points including -inf and +inf.

    Raises:
        ValueError: if no quantile could be computed (for example an empty
            DataFrame), since no buckets can be formed.
    """
    # Replace any bins left over from a previous run.
    df = df.drop('pagerank_bins')

    cut_points = df.approxQuantile("pagerank", [0.2, 0.4, .6, .8], 0)
    if not cut_points:
        raise ValueError("bin_it_pr: no pagerank quantiles available (empty DataFrame?)")

    # Duplicate quantiles would make the splits non-increasing.
    bins = [float("-inf")] + sorted(set(cut_points)) + [float("inf")]

    bucketizer = Bucketizer(splits=bins, inputCol="pagerank", outputCol="pagerank_bins")
    df = bucketizer.transform(df)

    # Only the bucket index is kept downstream.
    df = df.drop('pagerank')

    return df, bins
    

In [6]:
# Function that accepts an event-trace DataFrame and encodes it. Ideally this is run twice: once on the malicious
# traces and once on the benign traces, which implies running GraphFrames on each set of events independently.

# TODO: extract only the file extensions, or declare the file path an attribute.

def oneHotCol(df, colm, dict_mapping, cols_sparse):
    """One-hot encode column `colm` and record how it was encoded.

    `colm` is string-indexed into `<colm>_numeric` (unseen or invalid values
    are kept rather than rejected) and then one-hot encoded into
    `<colm>_sparse` without dropping the last category. Both the original
    column and the numeric index are removed afterwards.

    Args:
        df: input DataFrame containing `colm`.
        colm: name of the column to encode.
        dict_mapping: dict updated in place with colm -> indexer labels.
        cols_sparse: list updated in place with the new sparse column name.

    Returns:
        (cached encoded DataFrame, dict_mapping, cols_sparse) -- the latter
        two are the same objects that were passed in.
    """
    index_col = colm+'_numeric'
    onehot_col = colm+'_sparse'

    # Strings -> numeric index.
    index_model = StringIndexer(inputCol=colm, outputCol=index_col, handleInvalid="keep").fit(df)
    indexed = index_model.transform(df)

    # Numeric index -> sparse one-hot vector.
    onehot_model = OneHotEncoder(inputCol=index_col, outputCol=onehot_col, dropLast=False).fit(indexed)
    encoded = onehot_model.transform(indexed).drop(colm, index_col)

    # Remember the label order so vector positions can be decoded later.
    dict_mapping[colm] = index_model.labels
    cols_sparse.append(onehot_col)

    return encoded.cache(), dict_mapping, cols_sparse


In [35]:
#udf functions

# define a user-defined function to convert binary int array to string array
def binary_to_string_array(binary_int_array):
    """Render a sequence of numeric flags as one compact digit string.

    Each element is truncated with int() and its decimal text is appended,
    so [1.0, 0.0, 1.0] becomes '101'.

    Args:
        binary_int_array: any iterable of numbers, or None.

    Returns:
        The concatenated string ('' for an empty input), or None when the
        input is None, so a NULL value passes through instead of raising.
    """
    if binary_int_array is None:
        return None
    return ''.join(str(int(flag)) for flag in binary_int_array)

# Wrap binary_to_string_array as a Spark UDF that returns StringType, so it can be applied to DataFrame columns.
binary_to_string_array_udf = udf(binary_to_string_array, StringType())

def int_cast(num):
    """Convert `num` to a Python int, truncating any fractional part.

    Returns None for a None input, so a NULL value passes through instead of
    raising TypeError.
    """
    if num is None:
        return None
    return int(num)

# Wrap int_cast as a Spark UDF that returns IntegerType.
int_cast_udf = udf(int_cast, IntegerType())


In [46]:
def transp_expl5(df):
    """Reshape five-edge motif rows into one row per event.

    Every input row (vertices a..f, edges e1..e5) receives a `Trace` id of
    the form '<n>-5'. The vertex columns are discarded, the five edge structs
    are exploded into separate rows labelled 'e1'..'e5' in an `event` column,
    and the struct fields are flattened into ordinary columns.

    Returns:
        DataFrame with `Trace`, `event` and the edge attributes, ordered by
        Trace and event.
    """
    # Tag each motif row; the "-5" suffix records the motif length.
    tagged = df.withColumn("Trace", (monotonically_increasing_id() + 1))
    tagged = tagged.withColumn("Trace",concat(col("Trace"), lit("-5")))

    # Trace first, everything else in its existing order, vertices removed.
    other_columns = [name for name in tagged.columns if name != "Trace"]
    edges_only = tagged.select("Trace", *other_columns).drop('a','b','c','d','e','f')

    # One row per edge; `pos` is the edge's 0-based position in the motif.
    exploded = edges_only.selectExpr(
        "Trace",
        "posexplode(array(e1, e2, e3, e4, e5)) as (pos, col)"
    )
    event_label = expr(
        "CASE pos \n"
        "        WHEN 0 THEN 'e1' \n"
        "        WHEN 1 THEN 'e2'\n"
        "        WHEN 2 THEN 'e3'\n"
        "        WHEN 3 THEN 'e4'\n"
        "        ELSE 'e5' END"
    ).alias("event")
    labelled = exploded.select("Trace", event_label, "col").orderBy("Trace","event")

    # Flatten the edge struct into top-level columns.
    return labelled.select(*labelled.columns, "col.*").drop('col')

In [44]:
def transp_expl(df):
    """Reshape six-edge motif rows into one row per event.

    Every input row (vertices a..g, edges e1..e6) receives a `Trace` id of
    the form '<n>-6'. The vertex columns are discarded, the six edge structs
    are exploded into separate rows labelled 'e1'..'e6' in an `event` column,
    and the struct fields are flattened into ordinary columns.

    Returns:
        DataFrame with `Trace`, `event` and the edge attributes, ordered by
        Trace and event.
    """
    # Tag each motif row; the "-6" suffix records the motif length.
    tagged = df.withColumn("Trace", (monotonically_increasing_id() + 1))
    tagged = tagged.withColumn("Trace",concat(col("Trace"), lit("-6")))

    # Trace first, everything else in its existing order, vertices removed.
    other_columns = [name for name in tagged.columns if name != "Trace"]
    edges_only = tagged.select("Trace", *other_columns).drop('a','b','c','d','e','f','g')

    # One row per edge; `pos` is the edge's 0-based position in the motif.
    exploded = edges_only.selectExpr(
        "Trace",
        "posexplode(array(e1, e2, e3, e4, e5, e6)) as (pos, col)"
    )
    event_label = expr(
        "CASE pos \n"
        "        WHEN 0 THEN 'e1' \n"
        "        WHEN 1 THEN 'e2'\n"
        "        WHEN 2 THEN 'e3'\n"
        "        WHEN 3 THEN 'e4'\n"
        "        WHEN 4 THEN 'e5'\n"
        "        ELSE 'e6' END"
    ).alias("event")
    labelled = exploded.select("Trace", event_label, "col").orderBy("Trace","event")

    # Flatten the edge struct into top-level columns.
    return labelled.select(*labelled.columns, "col.*").drop('col')

In [26]:
def transp_expl2(df_transp):
    """Reshape two-edge motif rows into one row per event.

    Every input row (vertices a..c, edges e1..e2) receives a `Trace` id of
    the form '<n>-2'. The vertex columns are discarded, the two edge structs
    are exploded into separate rows labelled 'e1'/'e2' in an `event` column,
    and the struct fields are flattened into ordinary columns.

    Returns:
        DataFrame with `Trace`, `event` and the edge attributes, built on a
        cached, ordered intermediate.
    """
    # Tag each motif row; the "-2" suffix is the motif identifier.
    tagged = df_transp.withColumn("Trace", (monotonically_increasing_id() + 1))
    tagged = tagged.withColumn("Trace",concat(col("Trace"), lit("-2")))

    # Trace first, everything else in its existing order, vertices removed.
    other_columns = [name for name in tagged.columns if name != "Trace"]
    edges_only = tagged.select("Trace", *other_columns).drop('a','b','c')

    # One row per edge; `pos` is the edge's 0-based position in the motif.
    exploded = edges_only.selectExpr(
        "Trace",
        "posexplode(array(e1, e2)) as (pos, col)"
    )
    event_label = expr(
        "CASE pos \n"
        "        WHEN 0 THEN 'e1'\n"
        "        ELSE 'e2' END"
    ).alias("event")
    labelled = exploded.select("Trace", event_label, "col").orderBy("Trace","event").cache()

    # Flatten the edge struct into top-level columns.
    return labelled.select(*labelled.columns, "col.*").drop('col')

In [41]:
def transp_expl3(df_transp):
    """Reshape three-edge motif rows into one row per event.

    Every input row (vertices a..d, edges e1..e3) receives a `Trace` id of
    the form '<n>-3'. The vertex columns are discarded, the three edge
    structs are exploded into separate rows labelled 'e1'..'e3' in an `event`
    column, and the struct fields are flattened into ordinary columns.

    Returns:
        DataFrame with `Trace`, `event` and the edge attributes, ordered by
        Trace and event.
    """
    # Tag each motif row; the "-3" suffix records the motif length.
    tagged = df_transp.withColumn("Trace", (monotonically_increasing_id() + 1))
    tagged = tagged.withColumn("Trace",concat(col("Trace"), lit("-3")))

    # Trace first, everything else in its existing order, vertices removed.
    other_columns = [name for name in tagged.columns if name != "Trace"]
    edges_only = tagged.select("Trace", *other_columns).drop('a','b','c','d')

    # One row per edge; `pos` is the edge's 0-based position in the motif.
    exploded = edges_only.selectExpr(
        "Trace",
        "posexplode(array(e1, e2, e3)) as (pos, col)"
    )
    event_label = expr(
        "CASE pos \n"
        "        WHEN 0 THEN 'e1' \n"
        "        WHEN 1 THEN 'e2'\n"
        "        ELSE 'e3' END"
    ).alias("event")
    labelled = exploded.select("Trace", event_label, "col").orderBy("Trace","event")

    # Flatten the edge struct into top-level columns.
    return labelled.select(*labelled.columns, "col.*").drop('col')

In [42]:
def transp_expl4(df_transp):
    """Reshape four-edge motif rows into one row per event.

    Every input row (vertices a..e, edges e1..e4) receives a `Trace` id of
    the form '<n>-4'. The vertex columns are discarded, the four edge structs
    are exploded into separate rows labelled 'e1'..'e4' in an `event` column,
    and the struct fields are flattened into ordinary columns.

    Returns:
        DataFrame with `Trace`, `event` and the edge attributes, ordered by
        Trace and event.
    """
    # Tag each motif row; the "-4" suffix records the motif length.
    tagged = df_transp.withColumn("Trace", (monotonically_increasing_id() + 1))
    tagged = tagged.withColumn("Trace",concat(col("Trace"), lit("-4")))

    # Trace first, everything else in its existing order, vertices removed.
    other_columns = [name for name in tagged.columns if name != "Trace"]
    edges_only = tagged.select("Trace", *other_columns).drop('a','b','c','d','e')

    # One row per edge; `pos` is the edge's 0-based position in the motif.
    exploded = edges_only.selectExpr(
        "Trace",
        "posexplode(array(e1, e2, e3, e4)) as (pos, col)"
    )
    event_label = expr(
        "CASE pos \n"
        "        WHEN 0 THEN 'e1' \n"
        "        WHEN 1 THEN 'e2'\n"
        "        WHEN 2 THEN 'e3'\n"
        "        ELSE 'e4' END"
    ).alias("event")
    labelled = exploded.select("Trace", event_label, "col").orderBy("Trace","event")

    # Flatten the edge struct into top-level columns.
    return labelled.select(*labelled.columns, "col.*").drop('col')

In [39]:
def _trace_pipeline(traces):
    """Return (event count, graph builder, transpose helper) for a motif length.

    Lengths 2-5 map to their dedicated helpers. Any other value falls back to
    the six-event pipeline, which is what trace_encode has always done for
    unrecognised lengths.
    """
    # Resolved branch by branch, so only the helpers for the requested length
    # need to have been defined in the kernel.
    if traces == 2:
        return 2, create_graph2, transp_expl2
    if traces == 3:
        return 3, create_graph3, transp_expl3
    if traces == 4:
        return 4, create_graph4, transp_expl4
    if traces == 5:
        return 5, create_graph5, transp_expl5
    return 6, create_graph6, transp_expl


def trace_encode(df, traces, list_cols, output = 'vec'):
    """Build event traces of a given length and encode every event.

    Pipeline: find time-ordered motifs of `traces` events, transpose them to
    one row per event, flag traces that contain a malicious event, bin the
    time gap between events, one-hot encode `list_cols` plus the time bin,
    assemble the encodings into one vector per event, and finally pivot back
    to one row per trace.

    Behaviour fixed relative to the previous version:
      * traces == 2 and traces == 5 now use their own graph/transposition
        helpers instead of silently running the six-event pipeline;
      * the index output no longer assumes six events, so it works for every
        supported trace length.

    Args:
        df: event DataFrame accepted by the create_graph* helpers.
        traces: number of events per trace (2-6); other values fall back to 6.
        list_cols: names of the categorical columns to one-hot encode.
        output: 'vec' puts the per-event encoding strings in the sequence;
            any other value puts the integer index of each distinct encoding
            there instead.

    Returns:
        (traces_df, dict_mapping): `traces_df` has the columns `Trace`,
        `mal_trace` and `event_sequence` (array with one entry per event, in
        order e1..eN); `dict_mapping` maps each encoded column to the label
        order used by its indexer.
    """
    start_time = time.time()

    n_events, build_graph, explode_traces = _trace_pipeline(traces)

    df_graph = build_graph(df)
    print("create graph: "+ str(time.time() - start_time))
    # Transpose each motif row to one row per event and flatten the edge struct.
    df_onehot = explode_traces(df_graph)
    print("transp-explode: "+ str(time.time() - start_time))

    # A trace is malicious as soon as any of its events is. `sum` is the Spark
    # aggregate from the pyspark.sql.functions star import, not the builtin.
    w = Window.partitionBy("Trace")
    df_onehot = df_onehot.withColumn('mal_trace', when(sum('malicious').over(w) > 0, 1).otherwise(0)).cache()

    # Bin the time gap between consecutive events of a trace.
    df_onehot = bin_it(df_onehot).cache()

    # Label order per encoded column, and the names of the sparse columns.
    dict_mapping = {}
    list_sparse = []

    print("bin time: "+ str(time.time() - start_time))

    # One-hot encode every requested column plus the time bin.
    for colm in list(list_cols) + ["timestamp_bins"]:
        df_onehot, dict_mapping, list_sparse = oneHotCol(df_onehot, colm, dict_mapping, list_sparse)

    # Concatenate all sparse encodings into a single vector per event.
    assembler = VectorAssembler(inputCols=list_sparse,
                                outputCol="final_vec")
    df_onehot = assembler.transform(df_onehot).cache()

    # Compact string form of the vector, one character per position.
    df_onehot = df_onehot.withColumn("vec2string", binary_to_string_array_udf("final_vec")).cache()

    print("one-hot time: "+ str(time.time() - start_time))

    keep_cols = ['mal_trace','Trace','event','vec2string']
    if output == 'vec':
        value_col = 'vec2string'
    else:
        # Replace each distinct encoding string by an integer index.
        indexer = StringIndexer(inputCol='vec2string', outputCol='event_ind')
        indexer_fitted = indexer.fit(df_onehot)
        df_onehot = indexer_fitted.transform(df_onehot).cache()

        # The indexer emits doubles; the sequence should hold integers.
        df_onehot = df_onehot.withColumn("event_index", int_cast_udf("event_ind")).cache()

        print("indexing time: "+ str(time.time() - start_time))

        value_col = 'event_index'
        keep_cols.append(value_col)

    # Drop every column that is neither requested nor needed for the pivot.
    drop_cols = [name for name in df_onehot.columns
                 if name not in list_cols and name not in keep_cols]
    df_onehot = df_onehot.drop(*drop_cols).cache()

    # Pivot back to one row per trace. Listing the pivot values spares Spark
    # an extra pass to discover them and fixes the set of output columns.
    event_labels = ['e' + str(position) for position in range(1, n_events + 1)]
    pivoted = df_onehot.groupBy('Trace').pivot('event', event_labels)\
        .agg(first('mal_trace'), first(value_col)).cache()

    # Pivot columns are named '<event>_first(<column>)'. mal_trace is the same
    # for every event of a trace, so e1's copy is used.
    sequence_cols = [label + '_first(' + value_col + ')' for label in event_labels]
    df_onehot = pivoted.select('Trace', col('e1_first(mal_trace)').alias('mal_trace'),
                               array(*sequence_cols).alias('event_sequence')).cache()

    print("total elapsed time: "+ str(time.time() - start_time))

    return df_onehot,dict_mapping

In [46]:
def write_after(df,day,trace,div, des):
    """Split `df` into `div` random parts and write each one to S3 as parquet.

    Each part is written (overwriting) to
        s3a://sapient-bucket-trusted/prod/graph/encoded/real/<des>/<day>Sep/<trace>/part_<k>
    for k = 1..div.

    NOTE: the previous version built this path from a multi-line f-string, so
    the key contained a line break plus indentation before <des> (and a
    doubled slash after the bucket name). Data written by that version sits
    under the malformed prefix.

    Args:
        df: DataFrame to write.
        day: day number used in the '<day>Sep' path segment.
        trace: trace length used as a path segment.
        div: number of roughly equal random parts (positive integer).
        des: destination sub-folder.

    Raises:
        ValueError: if `div` is smaller than 1.
    """
    if div < 1:
        raise ValueError("div must be a positive integer, got " + repr(div))

    start_time = time.time()

    # div equally weighted random parts; the fixed seed keeps the split reproducible.
    parts = df.randomSplit(div*[1.0/div], seed=42)

    print("split time: " + str(time.time() - start_time))

    # Common prefix for all parts, assembled without stray whitespace.
    s3_url_trusted = "s3a://sapient-bucket-trusted/"
    target_dir = (s3_url_trusted.rstrip("/") + "/prod/graph/encoded/real/"
                  + str(des) + "/" + str(day) + "Sep/" + str(trace))

    for count, part in enumerate(parts, start=1):
        cnt = str(count)
        print("write: part " + cnt)
        part.write.option("maxRecordsPerFile", 300000).mode("overwrite")\
            .parquet(target_dir + "/part_" + cnt)

        print("part " + cnt + " write time: " + str(time.time() - start_time))

    print("total write time: " + str(time.time() - start_time))
    
    
    