# PySpark Project Step-by-Step: Part 2

This notebook will walk you through 2 more steps in the ML lifecycle - **Feature Engineering** and **Model Fitting & Evaluation**.<br>
* In the feature engineering part you'll see how to perform common aggregates using analytical functions.
* In the modelling part you'll see how to prepare your data for modelling in PySpark, and how to fit a model using MLLib.
* Finally, we'll see how we can evaluate the model we've built.

In [166]:
from pyspark.sql import Window
import pyspark.sql.functions as fn
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    from_unixtime,
    to_timestamp,
    min,
    max,
    sum,
    avg,
    col,
    countDistinct,
    broadcast,
    date_trunc,
    count,
)
from pyspark.sql import Window
import pyspark.sql.functions as F
import plotly.express as px

from pyspark.ml.feature import StringIndexer, VectorAssembler, VectorIndexer
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier

In [2]:
spark = (
    SparkSession.builder.appName("iot")
    .getOrCreate()
)
spark.sparkContext.setLogLevel("ERROR")

23/12/13 18:23:22 WARN Utils: Your hostname, Antonss-MacBook-Pro-6.local resolves to a loopback address: 127.0.0.1; using 10.160.95.31 instead (on interface en0)
23/12/13 18:23:22 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/12/13 18:23:23 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Read Data

In [3]:
df = spark.read.parquet("processed.pq").withColumn(
    "is_bad", F.when(F.col("label") != "Benign", 1).otherwise(0)
)
df.show(5)

+-------------------+------------------+---------------+-----------+---------------+---------+-----+-------+---------+----------+----------+----------+-------+---------+-------------+---------+-------------+---------+--------------------+-------------------+-------------------+-------------------+-------------------+-------------------+------+
|                 ts|               uid|      source_ip|source_port|        dest_ip|dest_port|proto|service| duration|orig_bytes|resp_bytes|conn_state|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|    label|      detailed-label|                 dt|                day|               hour|             minute|             second|is_bad|
+-------------------+------------------+---------------+-----------+---------------+---------+-----+-------+---------+----------+----------+----------+-------+---------+-------------+---------+-------------+---------+--------------------+-------------------+-------------------+-------------------+----------

## Feature Engineering

Since we have a time-component to this data, we can engineer all sorts of rolling features. The ones that I'll cover here are:
* Number of times we've seen this source IP in the last minute
* Number of times we've seen this destination IP in the last minute
* Number of times we've seen this source PORT in the last minute
* Number of times we've seen this destination PORT in the last minute

To calculate these features, we'll need to use analytical functions. 

In [4]:
def mins_to_secs(mins):
    return mins * 60


def generate_window(window_in_minutes: int, partition_by: str, timestamp_col: str):
    window = (
        Window()
        .partitionBy(F.col(partition_by))
        .orderBy(col(timestamp_col).cast("long"))
        .rangeBetween(-mins_to_secs(window_in_minutes), -1)
    )

    return window


def generate_rolling_aggregate(
    col: str,
    partition_by: str | None = None,
    operation: str = "count",
    timestamp_col: str = "dt",
    window_in_minutes: int = 1,
):
    if partition_by is None:
        partition_by = col

    match operation:
        case "count":
            return F.count(col).over(
                generate_window(
                    window_in_minutes=window_in_minutes,
                    partition_by=col,
                    timestamp_col=timestamp_col,
                )
            )
        case "sum":
            return F.sum(col).over(
                generate_window(
                    window_in_minutes=window_in_minutes,
                    partition_by=col,
                    timestamp_col=timestamp_col,
                )
            )
        case "avg":
            return F.avg(col).over(
                generate_window(
                    window_in_minutes=window_in_minutes,
                    partition_by=col,
                    timestamp_col=timestamp_col,
                )
            )
        case _:
            raise ValueError(f"Operation {operation} is not defined")

### Generate Rolling Count Features

Due to the nicely defined functions above, generating rolling averages and counts is a piece of cake!

In [5]:
df = df.withColumns({
    "source_ip_count_last_min": generate_rolling_aggregate(col="source_ip", operation="count", timestamp_col="dt", window_in_minutes=1),
    "source_ip_count_last_30_mins": generate_rolling_aggregate(col="source_ip", operation="count", timestamp_col="dt", window_in_minutes=30),
    "source_port_count_last_min": generate_rolling_aggregate(col="source_port", operation="count", timestamp_col="dt", window_in_minutes=1),
    "source_port_count_last_30_mins": generate_rolling_aggregate(col="source_port", operation="count", timestamp_col="dt", window_in_minutes=30),
    "dest_ip_count_last_min": generate_rolling_aggregate(col="dest_ip", operation="count", timestamp_col="dt", window_in_minutes=1),
    "dest_ip_count_last_30_mins": generate_rolling_aggregate(col="dest_ip", operation="count", timestamp_col="dt", window_in_minutes=30),
    "dest_port_count_last_min": generate_rolling_aggregate(col="dest_port", operation="count", timestamp_col="dt", window_in_minutes=1),
    "dest_port_count_last_30_mins": generate_rolling_aggregate(col="dest_port", operation="count", timestamp_col="dt", window_in_minutes=30),
    "source_ip_avg_pkts_last_min": generate_rolling_aggregate(col="orig_pkts", partition_by="source_ip", operation="avg", timestamp_col="dt", window_in_minutes=1),
    "source_ip_avg_pkts_last_30_mins": generate_rolling_aggregate(col="orig_pkts", partition_by="source_ip", operation="avg", timestamp_col="dt", window_in_minutes=30),
    "source_ip_avg_bytes_last_min": generate_rolling_aggregate(col="orig_ip_bytes", partition_by="source_ip", operation="avg", timestamp_col="dt", window_in_minutes=1),
    "source_ip_avg_bytes_last_30_mins": generate_rolling_aggregate(col="orig_ip_bytes", partition_by="source_ip", operation="avg", timestamp_col="dt", window_in_minutes=30),
})

Now,execute and save the resulting table into a new parquet file

In [6]:
df.write.mode("overwrite").parquet("feature_engineered.pq")

                                                                                

In [119]:
df_fe = spark.read.parquet("feature_engineered.pq")

Let's compare the speed of calling the old `df` vs the new `df_fe`...

In [120]:
df_fe.show(10)

+-------------------+------------------+---------------+-----------+---------------+---------+-----+-------+---------+----------+----------+----------+-------+---------+-------------+---------+-------------+---------+--------------------+-------------------+-------------------+-------------------+-------------------+-------------------+------+------------------------+----------------------------+--------------------------+------------------------------+----------------------+--------------------------+------------------------+----------------------------+---------------------------+-------------------------------+----------------------------+--------------------------------+
|                 ts|               uid|      source_ip|source_port|        dest_ip|dest_port|proto|service| duration|orig_bytes|resp_bytes|conn_state|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|    label|      detailed-label|                 dt|                day|               hour|             minut

Such a drastic difference is because when you call `df.show()` it's going to execute all of the very expensive operations we did. Instead, it's better to construct a new dataframe for the analysis.

## Preprocessing

In [176]:
numerical_features = [
    "duration",
    "orig_bytes",
    "resp_bytes",
    "orig_pkts",
    "orig_ip_bytes",
    "resp_pkts",
    "resp_ip_bytes",
    "source_ip_count_last_min",
    "source_ip_count_last_30_mins",
    "source_port_count_last_min",
    "source_port_count_last_30_mins",
    "dest_ip_count_last_min",
    "dest_ip_count_last_30_mins",
    "dest_port_count_last_min",
    "dest_port_count_last_30_mins",
    "source_ip_avg_pkts_last_min",
    "source_ip_avg_pkts_last_30_mins",
    "source_ip_avg_bytes_last_min",
    "source_ip_avg_bytes_last_30_mins",
]
categorical_features = ["proto", "service", "conn_state", "history"]
categorical_features_indexed = [c + "_index" for c in categorical_features]

input_features = numerical_features + categorical_features_indexed

### Remove rare categories

In [122]:
df_fe.select([F.count_distinct(c) for c in categorical_features]).show()

+---------------------+-----------------------+--------------------------+-----------------------+
|count(DISTINCT proto)|count(DISTINCT service)|count(DISTINCT conn_state)|count(DISTINCT history)|
+---------------------+-----------------------+--------------------------+-----------------------+
|                    3|                      6|                        12|                    167|
+---------------------+-----------------------+--------------------------+-----------------------+



In [123]:
categorical_valid_values = {}

for c in categorical_features:
    # Find frequent values
    categorical_valid_values[c] = (
        df_fe.groupby(c)
        .count()
        .filter(F.col("count") > 100)
        .select(c)
        .toPandas()
        .values.ravel()
    )

    df_fe = df_fe.withColumn(
        c,
        F.when(F.col(c).isin(list(categorical_valid_values[c])), F.col(c)).otherwise(
            F.lit("Other").alias(c)
        ),
    )

In [124]:
df_fe.select([F.count_distinct(c) for c in categorical_features]).show()

+---------------------+-----------------------+--------------------------+-----------------------+
|count(DISTINCT proto)|count(DISTINCT service)|count(DISTINCT conn_state)|count(DISTINCT history)|
+---------------------+-----------------------+--------------------------+-----------------------+
|                    3|                      4|                         8|                     23|
+---------------------+-----------------------+--------------------------+-----------------------+



## Train/Test Split
Train test split will need to be done using the source IP address, otherwise we risk leaking data. The best way to do this is by splitting the IP addresses at random, and then filtering the data frame according to the IP address.

In [125]:
df_fe.groupby("source_ip").agg(F.sum(F.col("is_bad")).alias("bad_sum")).orderBy("bad_sum", ascending=False).show(5)

+---------------+-------+
|      source_ip|bad_sum|
+---------------+-------+
|192.168.100.103| 539473|
|    192.168.2.5| 151566|
|    192.168.2.1|      1|
|173.216.214.224|      0|
|   123.25.17.53|      0|
+---------------+-------+
only showing top 5 rows



In [126]:

# Training non-malicious IPs (80%)
train_ips = (
    df_fe.where(
        ~F.col("source_ip").isin(["192.168.100.103", "192.168.2.5", "192.168.2.1"])
    )
    .select(F.col("source_ip"), F.lit(1).alias("is_train"))
    .dropDuplicates()
    .sample(0.8)
)


df_fe = df_fe.join(train_ips, "source_ip", "left")

# Add 1 malicious IP to training and testing data
df_train = df_fe.where((F.col("is_train") == 1) | (F.col("source_ip") == "192.168.100.103"))
df_test = df_fe.where((F.col("is_train") != 1) | (F.col("source_ip") == "192.168.2.5"))

## Pipeline

In [177]:
ind = StringIndexer(inputCols=categorical_features, outputCols=categorical_features_indexed, handleInvalid='skip')
va = VectorAssembler(inputCols=input_features, outputCol="features", handleInvalid='skip' )
rf = RandomForestClassifier(labelCol="is_bad", numTrees=100)

pipeline = Pipeline(stages=[ind, va, rf])

## Fit and Predict

In [179]:
pipeline = pipeline.fit(df_train)
test_preds = pipeline.transform(df_test)

## Evaluate

In [180]:
from pyspark.mllib.evaluation import BinaryClassificationMetrics
from pyspark.ml.evaluation import BinaryClassificationEvaluator

eval = BinaryClassificationEvaluator(labelCol="is_bad")
eval.evaluate(test_preds)

                                                                                

0.9399533426642217

[2.060563333294077e-05,
 0.00040486459066375977,
 4.090228866343483e-06,
 0.0014795778180817921,
 0.008006220041877213,
 2.3897256938537146e-05,
 0.00011641531516134424,
 5.758253972207269e-06,
 1.3629910605598109e-05,
 0.0832323076310752,
 0.07822060819152368,
 0.003464341118247299,
 0.0070681405666078645,
 0.2700798142876739,
 0.27261585008038103,
 0.00010499559632692276,
 0.0001007746193154143,
 0.024407224026794027,
 0.0407510593620146,
 0.0537559121589615,
 0.0,
 0.00035495985268745777,
 0.15576895345889166]

In [196]:
import pandas as pd

pd.DataFrame(
    {
        "importance": list(pipeline.stages[-1].featureImportances),
        "feature": pipeline.stages[-2].getInputCols(),
    }
).sort_values("importance", ascending=False)

Unnamed: 0,importance,feature
14,0.272616,dest_port_count_last_30_mins
13,0.27008,dest_port_count_last_min
22,0.155769,history_index
9,0.083232,source_port_count_last_min
10,0.078221,source_port_count_last_30_mins
19,0.053756,proto_index
18,0.040751,source_ip_avg_bytes_last_30_mins
17,0.024407,source_ip_avg_bytes_last_min
4,0.008006,orig_ip_bytes
12,0.007068,dest_ip_count_last_30_mins


23/12/13 22:05:19 ERROR Inbox: Ignoring error
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:56)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:310)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRefByURI(RpcEnv.scala:102)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRef(RpcEnv.scala:110)
	at org.apache.spark.util.RpcUtils$.makeDriverRef(RpcUtils.scala:36)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.driverEndpoint$lzycompute(BlockManagerMasterEndpoint.scala:124)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.org$apache$spark$storage$BlockManagerMasterEndpoint$$driverEndpoint(BlockManagerMasterEndpoint.scala:123)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.isExecutorAlive$lzycompute$1(BlockManagerMasterEndpoint.scala:688)
	at org.apache.spark.storage.BlockManagerMasterE