In [None]:
# Point findspark at the Spark installation before any pyspark import below.
# NOTE(review): '/opt/spark' is a hardcoded absolute path; it must match where
# Spark is installed on the machine running this kernel.
import findspark
findspark.init('/opt/spark')

In [None]:
from pyspark.sql import SparkSession

# Maven coordinates Spark must fetch at session start-up.
# `spark.jars.packages` is a single comma-separated list: calling
# .config() twice with the same key keeps only the last value, which
# previously discarded the spark-sql-kafka connector and left only
# kafka-clients on the classpath.
_spark_packages = ','.join([
    'org.apache.spark:spark-sql-kafka-0-10_2.12:3.0.1',
    'org.apache.kafka:kafka-clients:2.4.1',
])

# Build the session used by every later cell (or reuse one that already
# exists in this kernel -- getOrCreate() returns the existing session).
spark = (
    SparkSession.builder.appName('SCARP_Kafka_Test')
    .config('spark.jars.packages', _spark_packages)
    .config("spark.driver.memory", "16g")
    .config("spark.executor.memory", "8g")
    .getOrCreate()
)

In [None]:
# Streaming DataFrame backed by the Kafka topic "SCARP_kafka_notebook" on the
# broker at localhost:9092. Nothing is read until a query is started on it.
# NOTE(review): 'includeTimestamp' does not appear among the Kafka source
# options in the Spark integration guide -- confirm whether it has any effect.
df = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "localhost:9092")
    .option("subscribe", "SCARP_kafka_notebook")
    .option('includeTimestamp', 'true')
    .load()
)

In [None]:
# Print the column names and types of the streaming DataFrame loaded from Kafka.
df.printSchema()

In [None]:
from pyspark.sql.functions import *

# Keep only the two columns the rest of the notebook uses: the raw message
# payload ('value') and its 'timestamp'.
data_df = df.select('value', 'timestamp')

In [None]:
# Windowed dataframes
#
# Count how many times each distinct `value` appears per 10-second window
# (window duration and slide are both "10 seconds"). The aggregation is the
# same for both frames; they differ only in how the result is sorted.
_windowed_counts = data_df.groupBy(
    window(data_df['timestamp'], "10 seconds", "10 seconds"),
    data_df['value'],
).count()

# Ascending by window.
df_window = _windowed_counts.orderBy('window')

# Descending by window.
df_window2 = _windowed_counts.orderBy(desc('window'))

In [None]:
# Windowed query. Doesn't work for our purpose.
#
# Starts a streaming query that writes df_window2 to an in-memory table named
# 'window_test' using the 'complete' output mode. The returned handle is kept
# in `query` so a later cell can stop it.
query = (
    df_window2.writeStream
    .format('memory')
    .queryName('window_test')
    .outputMode('complete')
    .start()
)



In [None]:
# Viewing windowed query
#
# Re-query the in-memory 'window_test' table 30 times, pausing 4 seconds
# between polls, and show up to 10 rows each time (cell values cut at 50 chars).

import time

for attempt in range(30):
    _df = spark.sql('SELECT * FROM window_test')
    _df.show(10, truncate=50)
    time.sleep(4)

In [None]:
# Stop the streaming query that was started for the 'window_test' table.
query.stop()

In [None]:
# Original attempt at inference
#
# Polls a SQL table in an endless loop; whenever the collected timestamps
# differ from the previous poll, the first row's payload is turned into a
# feature vector and classified, and running totals are printed.
#
# NOTE(review): this cell never terminates on its own (`while True` with no
# break) -- interrupt the kernel to stop it.

import time
from IPython.display import clear_output

# Names this cell uses; imported here so the cell does not rely on cells
# further down having been run first.
import numpy as np
from pyspark.ml.linalg import DenseVector
from ember_modified.features import PEFeatureExtractor

#GBTmodel = GBTClassificationModel.load("./models/gbt_best_model")
# NOTE(review): with the load above commented out, `GBTmodel` must already be
# defined in the kernel before this cell runs.

prev_timestamp = []
totalMalware = 0
totalBenign = 0

# Built once and reused: the constructor argument never changes between polls.
extractor = PEFeatureExtractor(2)

while True:
    # NOTE(review): the WHERE clause is a literal "..." placeholder and none of
    # the cells shown here registers a table called projectML -- this statement
    # needs a real filter and table name before the cell can run.
    _df = spark.sql('SELECT * FROM projectML WHERE ... ORDER BY timestamp DESC')

    # Fetch the first row once; it is both the emptiness check and the row
    # that gets classified.
    latest_row = _df.first()

    if latest_row is not None:

        cur_timestamp = _df.select('timestamp').collect()

        # Only act when the set of timestamps changed since the last poll.
        if prev_timestamp != cur_timestamp:
            prev_timestamp = cur_timestamp
            print("New Data Found...")
            features = np.array(extractor.feature_vector(bytes(latest_row.value)), dtype=np.float32)

            prediction = GBTmodel.predict(DenseVector(features))
            print("Prediction: ", prediction)

            if prediction==1:
                totalMalware += 1
            elif prediction==0:
                totalBenign += 1

            # Was `print("...\t",+str(totalMalware))`: the stray comma made the
            # `+` a unary plus on a str, which raises TypeError.
            print("Total Malware:\t"+str(totalMalware))
            print("Total Benign:\t"+str(totalBenign))


In [None]:
# Working query
#
# Streams the un-aggregated (value, timestamp) rows of data_df into an
# in-memory table named 'stream_inference'. No outputMode is set here, so the
# writer's default applies. `query2` is the handle used to stop it later.
query2 = (
    data_df.writeStream
    .queryName('stream_inference')
    .format('memory')
    .start()
)


In [None]:
# Set the Spark log level to ERROR so lower-severity log output is suppressed.
spark.sparkContext.setLogLevel("ERROR")

In [None]:
from pyspark.ml.classification import GBTClassificationModel
from pyspark.ml.classification import GBTClassifier
import ember_modified
from ember_modified.features import PEFeatureExtractor
import time
from IPython.display import clear_output

In [None]:
# Load in model
# Reads a previously saved GBTClassificationModel from the relative path
# "./gbt500k"; the inference loop below calls GBTmodel.predict() on each row.
GBTmodel = GBTClassificationModel.load("./gbt500k")

In [None]:
# Live inference over the 'stream_inference' in-memory table.
#
# 1. Wait until the table holds at least one row and remember its newest
#    timestamp.
# 2. Poll forever for rows newer than that timestamp, classify each one with
#    GBTmodel, and print running malware/benign totals.
#
# The loop has no exit condition and does not sleep between polls; interrupt
# the kernel to stop it. Rows already present when the cell starts are used
# only to seed `prev_timestamp` and are not classified.

totalMalware = 0
totalBenign = 0

extractor = PEFeatureExtractor(2)

# Initialise _df from the right table before testing it. The previous version
# read `_df` before assigning it, so it either raised NameError on a fresh
# kernel or reused a DataFrame left over from a different table.
_df = spark.sql('SELECT * FROM stream_inference ORDER BY timestamp DESC')
while _df.count() == 0:
    _df = spark.sql('SELECT * FROM stream_inference ORDER BY timestamp DESC')

# Rows are sorted newest-first, so the first row carries the latest timestamp.
prev_timestamp = _df.select('timestamp').first()['timestamp']

print("Entering loop")
while True:

    _df = spark.sql(
        f"SELECT * FROM stream_inference WHERE timestamp > '{prev_timestamp}' ORDER BY timestamp DESC"
    )

    # Collect once and test the result; a separate count() would run the
    # query a second time.
    df_data = _df.collect()
    if not df_data:
        continue

    # Newest-first ordering: element 0 is the high-water mark for the next poll.
    prev_timestamp = df_data[0]['timestamp']

    for row in df_data:
        raw_data = row['value']
        features = extractor.feature_vector_spark(bytes(raw_data))
        prediction = GBTmodel.predict(features)

        # Anything other than class 1 is tallied as benign.
        if prediction == 1:
            totalMalware += 1
        else:
            totalBenign += 1

    # Replace the previous totals line instead of appending a new one.
    clear_output(wait=True)
    print("Total Malware: "+str(totalMalware)+"\tTotal Benign: "+str(totalBenign))

In [None]:
# Stop the streaming query that feeds the 'stream_inference' table.
query2.stop()

In [None]:
# Performance Improvements
# Partition raw data into different features
# Feature importance study for potentially removing non-essential features