In [1]:
#setup based on this: https://t-redactyl.io/blog/2020/08/reading-s3-data-into-a-spark-dataframe-using-sagemaker.html
import boto3
import json 
import time
import ntpath
import pandas as pd
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql.types import *
from pyspark.sql.functions import *
import matplotlib.pyplot as plt
import sagemaker_pyspark
import botocore.session

## Set Spark Session Configuration

In [2]:
session = botocore.session.get_session()
credentials = session.get_credentials()

In [3]:
client = boto3.client('secretsmanager')
response = client.get_secret_value(
    SecretId='sapient-s3-access'
)
response = json.loads(response['SecretString'])
access_key = response["aws_access_key_id"]
secret_key = response["aws_secret_access_key"]

In [4]:
conf = (SparkConf()
        .set("spark.driver.extraClassPath", ":".join(sagemaker_pyspark.classpath_jars()))
       )

In [5]:
# https://spark.apache.org/docs/latest/configuration.html#memory-management
spark = (
    SparkSession
    .builder
    .config(conf=conf) \
    .config('fs.s3a.access.key', access_key)
    .config('fs.s3a.secret.key', secret_key)
    .config('spark.network.timeout', 300)
    .config('spark.local.dir', '/home/ec2-user/SageMaker/tmp')
    .config("spark.executor.memory", "25g")
    .config("spark.driver.memory", "20g")
    .config("spark.memory.offHeap.enabled", "true")
    .config("spark.memory.offHeap.size","20g")
    .appName("sapient")
    .getOrCreate()
)
spark.sparkContext.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/03/15 07:06:44 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/03/15 07:06:44 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
23/03/15 07:06:45 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


## Functions to Load and Read Data

In [6]:
# read from raw bucket + write to refined bucket + aggregate final to the trusted bucket
s3_url_raw = "s3a://sapient-bucket-raw/"
s3_url_refined = "s3a://sapient-bucket-refined/"
s3_url_trusted = "s3a://sapient-bucket-trusted/"
ecar_cols = [
    'id','timestamp','objectID','actorID','object','action','hostname', 'user_name', 'privileges', 'image_path', 
    'parent_image_path', 'new_path', 'file_path', 'direction', 'logon_id', 'requesting_domain', 'requesting_user'
            ]
bro_cols_conn = ['ts', 'uid', 'id.orig_h', 'id.orig_p', 'id.resp_', 'id.resp_p', 'proto', 'service', 'duration', 'orig_bytes', 'resp_bytes', 'conn_state', 
                 'local_orig', 'local_resp', 'missed_bytes', 'history', 'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes', 'tunnel_parents']
bro_cols_rep = ['ts', 'level', 'message', 'location']

In [7]:
# Create a new dataframe with distinct objectIDs to identify malcious ObjectIds
df_labels = spark.read.parquet(f"{s3_url_refined}/prod/labels").cache()
df_malcious_objectIDs = df_labels.select('id').distinct()
df_labels.unpersist()

                                                                                

DataFrame[hostname: string, id: string, objectID: string, actorID: string, timestamp: timestamp, object: string, action: string]

In [8]:
def readCheckpoint(type='ecar', env='prod', size='small'):
    """
    type: ecar, ecar-bro, bro
    """
    if type == 'labels':
        s3_parquet_loc = f"{s3_url_trusted}/{env}/{type}"
    else:
        s3_parquet_loc = f"{s3_url_trusted}/{env}/{type}/{size}"
    start_time = time.time()
    df = spark.read.parquet(s3_parquet_loc).cache()
    print(time.strftime('%l:%M%p %Z on %b %d, %Y') + " --- read and cache time: %s seconds ---" % (time.time() - start_time))
    return df

In [9]:
def getFirstEvents(df):
    """
    this creates then drops duplicates and gets the first appearance of each relationship entry
    input - dataframe with columns objectID and actorID
    output dataframe
    """
    window = Window.partitionBy("relationship").orderBy("timestamp")
    df_new = df.withColumn('relationship', concat(df.actorID, lit('->'),df.objectID) ) \
                .withColumn('rank', rank().over(window)) \
                .filter(col('rank') == 1) \
                .drop('rank')
    return df_new

In [10]:
def write_firsts(size='all'):
    """
    this creates then drops duplicates and writes them to file in S3
    """
    start_time = time.time()
    df = readCheckpoint(size=size).cache()
    df_first_events = get_firsts(df)
    df_first_events.write.option("maxRecordsPerFile", 300000).mode("overwrite").parquet(f"{s3_url_trusted}/prod/graph/first_events")
    print(time.strftime('%l:%M%p %Z on %b %d, %Y') + " --- read and write time: %s seconds ---" % (time.time() - start_time))

In [11]:
def writeFirstEvents(size='all'):
    """
    this creates then drops duplicates and writes them to file in S3
    """
    start_time = time.time()
    df = readCheckpoint(size=size).cache()
    df_first_events = getFirstEvents(df)
    df_first_events.write.option("maxRecordsPerFile", 300000).mode("overwrite").parquet(f"{s3_url_trusted}/prod/graph/first_events")
    print(time.strftime('%l:%M%p %Z on %b %d, %Y') + " --- read and write time: %s seconds ---" % (time.time() - start_time))

In [12]:
def getFile(str):
    """
    udf to get the file from a full file path
    similar udf (non-windows): https://stackoverflow.com/questions/40848681/udf-to-extract-only-the-file-name-from-path-in-spark-sql
    """
    if str == None:
        pass
    else:
        new_str = ntpath.basename(str)
        return new_str
getFileUDF = udf(lambda z: getFile(z),StringType())

In [13]:
def readFirstEvents():
    start_time = time.time()
    df = spark.read.parquet(f"{s3_url_trusted}/prod/graph/first_events")\
            .withColumn("image_path", getFileUDF(col("image_path"))) \
            .withColumn("parent_image_path", getFileUDF(col("parent_image_path"))) \
            .withColumn("new_path", getFileUDF(col("new_path"))) \
            .withColumn("file_path", getFileUDF(col("file_path")))
    # filter first event hosts to only those with malicious events
    mal_hosts = list(df.select('hostname').filter( col('malicious') == 1 ).distinct().toPandas()['hostname'])
    df = df.filter(col("hostname").isin(mal_hosts)).cache()
    print(time.strftime('%l:%M%p %Z on %b %d, %Y') + " --- read time: %s seconds ---" % (time.time() - start_time))
    return df

In [14]:
# writeFirstEvents()

In [15]:
# df = readFirstEvents().cache()

                                                                                

 7:07AM UTC on Mar 15, 2023 --- read time: 14.19411849975586 seconds ---
