In [26]:
#setup based on this: https://t-redactyl.io/blog/2020/08/reading-s3-data-into-a-spark-dataframe-using-sagemaker.html
import os
import boto3
import json 
import time
import pandas as pd
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
import matplotlib.pyplot as plt
import sagemaker_pyspark
import botocore.session

## Set Spark Session Configuration

In [18]:
# Resolve the notebook's default AWS credentials through botocore's credential chain.
# NOTE(review): `credentials` is not referenced by the cells visible here -- the S3A keys
# passed to Spark come from Secrets Manager instead. Confirm whether this is still needed.
session = botocore.session.get_session()
credentials = session.get_credentials()

In [19]:
# Fetch the S3 access key pair from AWS Secrets Manager. The secret's string payload
# is a JSON document holding both halves of the key pair.
client = boto3.client('secretsmanager')
response = client.get_secret_value(SecretId='sapient-s3-access')
response = json.loads(response['SecretString'])
access_key, secret_key = response["aws_access_key_id"], response["aws_secret_access_key"]

In [20]:
# Ask PySpark to pull in the graphframes package when it launches
# (presumably this has to be set before the first Spark session is created -- TODO confirm).
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages graphframes:graphframes:0.8.2-spark3.2-s_2.12 pyspark-shell'

# Base Spark configuration: put the sagemaker_pyspark jars on the driver classpath.
conf = SparkConf().set(
    "spark.driver.extraClassPath",
    ":".join(sagemaker_pyspark.classpath_jars()),
)

In [21]:
# Memory settings reference:
# https://spark.apache.org/docs/latest/configuration.html#memory-management
spark = (
    SparkSession.builder
    .config(conf=conf)
    # S3A credentials obtained from Secrets Manager
    .config('fs.s3a.access.key', access_key)
    .config('fs.s3a.secret.key', secret_key)
    # networking and scratch space
    .config('spark.network.timeout', 300)
    .config('spark.local.dir', '/home/ec2-user/SageMaker/tmp')
    # memory sizing (on-heap and off-heap)
    .config("spark.executor.memory", "16g")
    .config("spark.driver.memory", "8g")
    .config("spark.memory.offHeap.enabled", "true")
    .config("spark.memory.offHeap.size", "20g")
    .appName("sapient")
    .getOrCreate()
)

# Keep notebook output readable: only surface Spark errors.
spark.sparkContext.setLogLevel("ERROR")

## Functions to Load and Read Data

In [22]:
# Data flow: read from the raw bucket, write to the refined bucket,
# and aggregate the final results into the trusted bucket.
# Each URL ends with a trailing slash.
s3_url_raw = "s3a://sapient-bucket-raw/"
s3_url_refined = "s3a://sapient-bucket-refined/"
s3_url_trusted = "s3a://sapient-bucket-trusted/"

# Column names for Bro/Zeek conn logs, in field order.
# Fixed: the fifth entry was the truncated 'id.resp_'; the responder host field is 'id.resp_h'.
# NOTE(review): if any existing data or downstream cell already uses the truncated name,
# it must be updated alongside this fix.
bro_cols_conn = [
    'ts', 'uid', 'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
    'proto', 'service', 'duration', 'orig_bytes', 'resp_bytes', 'conn_state',
    'local_orig', 'local_resp', 'missed_bytes', 'history',
    'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes', 'tunnel_parents',
]

# Column names for Bro/Zeek reporter logs.
bro_cols_rep = ['ts', 'level', 'message', 'location']

In [25]:
def readCheckpoint(type='ecar-bro', env='dev', size='small'):
    """
    Load a checkpoint from the refined bucket and return it as a cached DataFrame.

    type: ecar, ecar-bro, bro, or labels ('labels' is stored without a size subfolder)
    env:  environment folder under the refined bucket
    size: size subfolder; ignored when type is 'labels'

    Only the columns id, timestamp, objectID, actorID, object, action and hostname
    are selected. Prints a timestamped line with the elapsed wall-clock time.
    """
    s3_parquet_loc = (
        f"{s3_url_refined}/{env}/{type}"
        if type == 'labels'
        else f"{s3_url_refined}/{env}/{type}/{size}"
    )
    wanted_columns = ('id', 'timestamp', 'objectID', 'actorID', 'object', 'action', 'hostname')

    began = time.time()
    # cache() only marks the frame for caching; Spark fills the cache on the first action,
    # so the time reported below does not include a full scan of the data.
    checkpoint = spark.read.parquet(s3_parquet_loc).select(*wanted_columns).cache()

    stamp = time.strftime('%l:%M%p %Z on %b %d, %Y')
    elapsed = time.time() - began
    print(stamp + " --- read and cache time: %s seconds ---" % elapsed)
    return checkpoint

In [27]:
def readCheckpoint_bcm(type='ecar-bro', env='dev', size='small'):
    """
    Load a checkpoint from the refined bucket, with all columns, as a cached DataFrame.

    type: ecar, ecar-bro, bro, or labels ('labels' and 'ecar' are read from
          {env}/{type} without a size subfolder)
    env:  environment folder under the refined bucket
    size: size subfolder; ignored when type is 'labels' or 'ecar'

    Returns the DataFrame read from the resolved parquet location. Prints a
    timestamped line with the elapsed wall-clock time.
    """
    if type in ('labels', 'ecar'):
        s3_parquet_loc = f"{s3_url_refined}/{env}/{type}"
    else:
        s3_parquet_loc = f"{s3_url_refined}/{env}/{type}/{size}"
    start_time = time.time()
    # Fixed: both read statements had been commented out, so `df` was never bound here and
    # `return df` either raised NameError or handed back an unrelated notebook-global `df`.
    # Restored the full-column read for both branches.
    # NOTE(review): the other commented-out attempt tried to select only 'id'; add
    # `.select('id')` after `.parquet(...)` if that narrower projection was the intent.
    df = spark.read.parquet(s3_parquet_loc).cache()
    print(time.strftime('%l:%M%p %Z on %b %d, %Y') + " --- read and cache time: %s seconds ---" % (time.time() - start_time))
    return df