# Imports

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    from_unixtime,
    to_timestamp,
    min,
    max,
    sum,
    avg,
    col,
    countDistinct,
    broadcast,
    date_trunc,
    count,
)
from pyspark.sql import Window
import pyspark.sql.functions as F
import plotly.express as px

# Read Files

In [2]:
# Input files for the analysis; both are loaded together by the read cell below.
filepaths = [
    "./iot_malware/CTU-IoT-Malware-Capture-1-1conn.log.labeled.csv",
    "./iot_malware/CTU-IoT-Malware-Capture-3-1conn.log.labeled.csv",
]

# Reuse an active session if there is one, otherwise start one named "iot".
spark = SparkSession.builder.appName("iot").getOrCreate()

# Restrict Spark's own logging to errors so cell output stays readable.
spark_context = spark.sparkContext
spark_context.setLogLevel("ERROR")

# Last expression of the cell: display the Spark version in use.
spark_context.version

23/12/10 13:12:00 WARN Utils: Your hostname, Antonss-MacBook-Pro-6.local resolves to a loopback address: 127.0.0.1; using 192.168.1.143 instead (on interface en0)
23/12/10 13:12:00 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/12/10 13:12:00 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/12/10 13:12:00 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


'3.5.0'

In [3]:
# The files are pipe-delimited with a header row; column types are inferred by Spark.
csv_reader = spark.read.option("delimiter", "|")
df = csv_reader.csv(filepaths, header=True, inferSchema=True)

# Preview the first five rows.
df.show(5)

                                                                                

+-------------------+------------------+---------------+---------+---------------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+---------+--------------------+
|                 ts|               uid|      id.orig_h|id.orig_p|      id.resp_h|id.resp_p|proto|service|duration|orig_bytes|resp_bytes|conn_state|local_orig|local_resp|missed_bytes|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|tunnel_parents|    label|      detailed-label|
+-------------------+------------------+---------------+---------+---------------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+---------+--------------------+
|1.525879831015811E9|CUmrqr4svHuSXJy5z7|192.168.100.103|  51524.0| 65.127.233.163|     23.0|  tcp|      -|2.999051|         0|         0|     

In [4]:
# Inspect the inferred schema. duration, orig_bytes and resp_bytes come back as
# strings here; they are cast to double in the modelling-prep cell further down.
df.printSchema()

root
 |-- ts: double (nullable = true)
 |-- uid: string (nullable = true)
 |-- id.orig_h: string (nullable = true)
 |-- id.orig_p: double (nullable = true)
 |-- id.resp_h: string (nullable = true)
 |-- id.resp_p: double (nullable = true)
 |-- proto: string (nullable = true)
 |-- service: string (nullable = true)
 |-- duration: string (nullable = true)
 |-- orig_bytes: string (nullable = true)
 |-- resp_bytes: string (nullable = true)
 |-- conn_state: string (nullable = true)
 |-- local_orig: string (nullable = true)
 |-- local_resp: string (nullable = true)
 |-- missed_bytes: double (nullable = true)
 |-- history: string (nullable = true)
 |-- orig_pkts: double (nullable = true)
 |-- orig_ip_bytes: double (nullable = true)
 |-- resp_pkts: double (nullable = true)
 |-- resp_ip_bytes: double (nullable = true)
 |-- tunnel_parents: string (nullable = true)
 |-- label: string (nullable = true)
 |-- detailed-label: string (nullable = true)



## Pre-processing

In [5]:
# Derive a timestamp column from the epoch-seconds `ts` field: from_unixtime
# renders it as a datetime string, which to_timestamp parses into a timestamp.
epoch_as_text = from_unixtime("ts")
df = df.withColumn("dt", to_timestamp(epoch_as_text))

In [6]:
# Replace the dotted id.* field names with plain identifiers
# (h -> ip address, p -> port, per the new names chosen here).
readable_names = {
    "id.orig_h": "source_ip",
    "id.orig_p": "source_port",
    "id.resp_h": "dest_ip",
    "id.resp_p": "dest_port",
}
df = df.withColumnsRenamed(readable_names)

## Dataset Quality Checks

### Min, Max datetime

In [7]:
# Time span covered by the data. `min` / `max` are the pyspark.sql.functions
# versions imported at the top of the notebook, not the Python builtins.
earliest = min("dt").alias("min_date")
latest = max("dt").alias("max_date")
df.agg(earliest, latest).show()

[Stage 3:=====>                                                    (1 + 9) / 10]

+-------------------+-------------------+
|           min_date|           max_date|
+-------------------+-------------------+
|2018-05-09 16:30:31|2018-05-21 08:04:46|
+-------------------+-------------------+



                                                                                

### Shape

In [8]:
# Dataset shape as a (row count, column count) tuple.
df.count(), len(df.columns)

(1164851, 24)

### Static Columns

In [9]:
# Columns to profile: every field except ts, uid and the derived dt.
to_analyse = [
    "source_ip",
    "source_port",
    "dest_ip",
    "dest_port",
    "proto",
    "service",
    "duration",
    "orig_bytes",
    "resp_bytes",
    "conn_state",
    "local_orig",
    "local_resp",
    "missed_bytes",
    "history",
    "orig_pkts",
    "orig_ip_bytes",
    "resp_pkts",
    "resp_ip_bytes",
    "tunnel_parents",
    "label",
    "detailed-label",
]

# One-row DataFrame holding the number of distinct values of each profiled column.
distinct_count_exprs = [countDistinct(col(name)).alias(name) for name in to_analyse]
unique_counts = df.agg(*distinct_count_exprs)

# show() prints the table itself and returns None, so wrapping it in print()
# would only append a stray "None" line to the output.
unique_counts.show()



+---------+-----------+-------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+-----+--------------+
|source_ip|source_port|dest_ip|dest_port|proto|service|duration|orig_bytes|resp_bytes|conn_state|local_orig|local_resp|missed_bytes|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|tunnel_parents|label|detailed-label|
+---------+-----------+-------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+-----+--------------+
|    15898|      28259| 661294|    65427|    3|      6|   34789|       200|       601|        12|         1|         1|           1|    167|       58|         1311|       74|         1480|             1|    2|             4|
+---------+-----------+-------+---------+-----+-------+--------+----------+----------+----------+---

                                                                                

In [11]:
# Pull the single summary row to the driver under its own name. Rebinding
# `unique_counts` to the Row would make this cell fail when re-run, because a
# Row has no .first() method.
unique_counts_row = unique_counts.first()

# A column with exactly one distinct value carries no information; collect and drop those.
static_cols = [
    name for name, n_distinct in unique_counts_row.asDict().items() if n_distinct == 1
]
print("Dataset has", len(static_cols), "static columns: ", static_cols)
df = df.drop(*static_cols)



Dataset has 4 static columns:  ['local_orig', 'local_resp', 'missed_bytes', 'tunnel_parents']


                                                                                

### Count Distinct Values

In [12]:
# Distinct addresses seen on each side of a connection.
source_ips = df.select(col("source_ip")).distinct()
dest_ips = df.select(col("dest_ip")).distinct()

# Addresses that occur both as a source and as a destination.
# The broadcast hint belongs on the smaller side of the join: in this dataset
# source_ips has far fewer distinct values than dest_ips, so it is the one that
# should be shipped to every executor. Keeping source_ips on the left preserves
# the (source_ip, dest_ip) column order of the result.
common_ips = broadcast(source_ips).join(
    dest_ips, source_ips.source_ip == dest_ips.dest_ip, how="inner"
)


print("Source IPs count:", source_ips.count())
print("Dest IPs count:", dest_ips.count())
print("IPs as both:", common_ips.count())

Source IPs count: 15898
Dest IPs count: 661294


                                                                                

IPs as both: 9423


In [13]:
# Distinct port numbers seen on each side of a connection.
source_ports = df.select("source_port").distinct()
dest_ports = df.select("dest_port").distinct()

# Ports that occur both as a source port and as a destination port.
same_port = source_ports["source_port"] == dest_ports["dest_port"]
common_ports = source_ports.join(broadcast(dest_ports), on=same_port, how="inner")

# Report the three cardinalities in a fixed order.
for caption, frame in (
    ("Source Ports count:", source_ports),
    ("Dest Ports count:", dest_ports),
    ("Ports as both:", common_ports),
):
    print(caption, frame.count())

Source Ports count: 28259
Dest Ports count: 65427


[Stage 61:=====>                                                   (1 + 9) / 10]

Ports as both: 28212


                                                                                

### Count Nulls

In [14]:
# Treat the "-" placeholder as missing: values equal to "-" become null so the
# null counts in the next cell include them.
df = df.replace("-", None)

In [15]:
# Profile only the columns that survived the static-column drop.
remaining_cols = [name for name in to_analyse if name not in static_cols]

# For each column, count the rows whose value is null or NaN: `when` yields a
# non-null marker only for those rows, and `count` ignores the nulls it yields otherwise.
missing_value_counts = [
    count(F.when(col(name).isNull() | F.isnan(name), name)).alias(name)
    for name in remaining_cols
]
df.select(missing_value_counts).show()

[Stage 69:=====>                                                   (1 + 9) / 10]

+---------+-----------+-------+---------+-----+-------+--------+----------+----------+----------+-------+---------+-------------+---------+-------------+-----+--------------+
|source_ip|source_port|dest_ip|dest_port|proto|service|duration|orig_bytes|resp_bytes|conn_state|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|label|detailed-label|
+---------+-----------+-------+---------+-----+-------+--------+----------+----------+----------+-------+---------+-------------+---------+-------------+-----+--------------+
|        0|          0|      0|        0|    0|1155702|  870244|    870244|    870244|         0|  18628|        0|            0|        0|            0|    0|        473811|
+---------+-----------+-------+---------+-----+-------+--------+----------+----------+----------+-------+---------+-------------+---------+-------------+-----+--------------+



                                                                                

## Time-Series Plots

In [16]:
# Add one column per time granularity, each holding `dt` truncated to that unit.
# The new column is named after the unit it is truncated to.
time_buckets = {
    unit: date_trunc(unit, "dt") for unit in ("day", "hour", "minute", "second")
}
df = df.withColumns(time_buckets)

In [17]:
# One line chart per time granularity: number of events per bucket, split by label.
for agg in ("day", "hour", "minute"):
    events_per_bucket = df.groupBy([agg, "label"]).agg(count("uid").alias("counts"))
    # Sort by time before collecting so the lines are drawn left to right.
    plotting_table = events_per_bucket.orderBy(agg).toPandas()
    fig = px.line(
        plotting_table,
        x=agg,
        y="counts",
        color="label",
        title=f'Event Counts per {agg}',
    )
    fig.show()

## Univariate Data Analysis

In [18]:
def counts(df, var):
    """Show and plot how often each distinct value of ``var`` occurs.

    Args:
        df: Spark DataFrame to summarise.
        var: Name of the column to group by.

    Returns:
        None. Prints a table with one row per distinct value holding its row
        ``count`` and ``percent`` (the share of all rows expressed as a
        fraction rounded to 4 decimals), most frequent first, then renders a
        bar chart of the counts.
    """
    # Empty partition spec: a single window over the whole aggregated table, so
    # every row can see the grand total. `sum` is pyspark.sql.functions.sum.
    whole_table = Window.partitionBy()
    share_of_total = F.round(col("count") / sum(col("count")).over(whole_table), 4)

    var_counts = (
        df.groupBy(var)
        .count()
        .withColumn("percent", share_of_total)
        # Sort last: the window step moves the rows to a single partition, and
        # a sort applied before that shuffle is not guaranteed to be preserved.
        .orderBy("count", ascending=False)
    )
    var_counts.show()
    fig = px.bar(var_counts.toPandas(), x=var, y="count")
    fig.show()


categorical_columns = ["proto", "service", "conn_state", "history", "label"]

for c in categorical_columns:
    counts(df, c)

+-----+------+-------+
|proto| count|percent|
+-----+------+-------+
|  tcp|737292| 0.6329|
|  udp|408931| 0.3511|
| icmp| 18628|  0.016|
+-----+------+-------+



+-------+-------+-------+
|service|  count|percent|
+-------+-------+-------+
|   NULL|1155702| 0.9921|
|    ssh|   5899| 0.0051|
|   http|   3238| 0.0028|
|    irc|      6|    0.0|
|   dhcp|      4|    0.0|
|    dns|      2|    0.0|
+-------+-------+-------+



+----------+-------+-------+
|conn_state|  count|percent|
+----------+-------+-------+
|        S0|1116839| 0.9588|
|        SF|  20569| 0.0177|
|       OTH|  18637|  0.016|
|       REJ|   6669| 0.0057|
|      RSTR|   1526| 0.0013|
|    RSTOS0|    197| 2.0E-4|
|      RSTO|    168| 1.0E-4|
|        S2|     72| 1.0E-4|
|        S1|     69| 1.0E-4|
|     RSTRH|     67| 1.0E-4|
|        SH|     37|    0.0|
|       SHR|      1|    0.0|
+----------+-------+-------+



+----------+------+-------+
|   history| count|percent|
+----------+------+-------+
|         S|715203|  0.614|
|         D|401618| 0.3448|
|      NULL| 18628|  0.016|
|        Dd|  7294| 0.0063|
|        Sr|  6668| 0.0057|
|  ShAdDaFf|  6643| 0.0057|
|  ShAdDafF|  2797| 0.0024|
|  ShADadfF|  1385| 0.0012|
|   ShADafF|   740| 6.0E-4|
|    ShADar|   397| 3.0E-4|
|  ShAdDaFr|   285| 2.0E-4|
|  ShAdDfFr|   228| 2.0E-4|
|         R|   181| 2.0E-4|
|     ShADr|   164| 1.0E-4|
|  ShADdfFa|   151| 1.0E-4|
| ShAdDaftF|   150| 1.0E-4|
|ShAdDaFRfR|   149| 1.0E-4|
|    ShAdDr|   123| 1.0E-4|
|   ShADafr|   120| 1.0E-4|
| ShAdDafrR|   120| 1.0E-4|
+----------+------+-------+
only showing top 20 rows



+---------+------+-------+
|    label| count|percent|
+---------+------+-------+
|Malicious|691040| 0.5932|
|   Benign|473811| 0.4068|
+---------+------+-------+



## Prepare for Modelling

In [19]:
# Feature columns by kind, plus the name of the target column.
numerical_cols = [
    "duration",
    "orig_bytes", "resp_bytes",
    "orig_pkts", "orig_ip_bytes",
    "resp_pkts", "resp_ip_bytes",
]
categorical_cols = [
    "proto",
    "service",
    "conn_state",
]
label = "label"

# Every model input: numerical features first, then categorical ones.
all_cols = [*numerical_cols, *categorical_cols]

In [20]:
# Numerical features are cast to double; some of them were inferred as strings on read.
recast_cols = {name: col(name).cast("double") for name in numerical_cols}

# Sentinel values for missing data: -999999 for numerical features,
# the string 'missing' for categorical ones.
fill_vals = {name: -999999 for name in numerical_cols}
fill_vals.update({name: 'missing' for name in categorical_cols})

# Cast first, then fill, so the numeric sentinel lands in double columns.
df = df.withColumns(recast_cols).fillna(fill_vals)


## Full Pipeline

In [21]:
# Self-contained version of the preprocessing steps explored above:
# raw CSV files in, cleaned DataFrame out.

# Columns found to hold a single distinct value in the quality checks.
static_cols = ["local_orig", "local_resp", "missed_bytes", "tunnel_parents"]

# Columns treated as numeric (cast to double, missing -> -999999) and as
# text (missing -> "missing") by this pipeline.
pipeline_numeric_cols = (
    "duration",
    "orig_bytes",
    "resp_bytes",
    "orig_pkts",
    "orig_ip_bytes",
    "resp_pkts",
    "resp_ip_bytes",
)
pipeline_text_cols = ("history", "proto", "service", "conn_state")

recast_cols = {name: col(name).cast("double") for name in pipeline_numeric_cols}

fill_vals = {
    **{name: -999999 for name in pipeline_numeric_cols},
    **{name: "missing" for name in pipeline_text_cols},
}

# `dt` truncated to each granularity; the new column is named after the unit.
time_buckets = {
    unit: date_trunc(unit, "dt") for unit in ("day", "hour", "minute", "second")
}

readable_names = {
    "id.orig_h": "source_ip",
    "id.orig_p": "source_port",
    "id.resp_h": "dest_ip",
    "id.resp_p": "dest_port",
}

preprocessed_data = (
    spark.read.option("delimiter", "|")
    .csv(filepaths, inferSchema=True, header=True)
    # Epoch seconds -> timestamp, then the truncated time-bucket columns.
    .withColumn("dt", to_timestamp(from_unixtime("ts")))
    .withColumns(time_buckets)
    .withColumnsRenamed(readable_names)
    .drop(*static_cols)
    # "-" is the placeholder for a missing value; turn it into null before casting.
    .replace("-", None)
    .withColumns(recast_cols)
    .fillna(fill_vals)
)

preprocessed_data.show()

+-------------------+------------------+---------------+-----------+---------------+---------+-----+-------+---------+----------+----------+----------+-------+---------+-------------+---------+-------------+---------+--------------------+-------------------+-------------------+-------------------+-------------------+-------------------+
|                 ts|               uid|      source_ip|source_port|        dest_ip|dest_port|proto|service| duration|orig_bytes|resp_bytes|conn_state|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|    label|      detailed-label|                 dt|                day|               hour|             minute|             second|
+-------------------+------------------+---------------+-----------+---------------+---------+-----+-------+---------+----------+----------+----------+-------+---------+-------------+---------+-------------+---------+--------------------+-------------------+-------------------+-------------------+-------------------+----

## Write Out

In [23]:
# Persist the preprocessed data as Parquet. A DataFrame has no `writeparquet`
# attribute; writing goes through the DataFrameWriter exposed as `.write`.
# Overwrite mode lets the notebook be re-run when "processed.pq" already exists
# (the default save mode raises an error in that case).
preprocessed_data.write.mode("overwrite").parquet("processed.pq")

                                                                                

In [24]:
# Read the Parquet output back and preview it to check that the write round-trips.
read_in = spark.read.parquet("processed.pq")
read_in.show()

+-------------------+------------------+---------------+-----------+---------------+---------+-----+-------+---------+----------+----------+----------+-------+---------+-------------+---------+-------------+---------+--------------------+-------------------+-------------------+-------------------+-------------------+-------------------+
|                 ts|               uid|      source_ip|source_port|        dest_ip|dest_port|proto|service| duration|orig_bytes|resp_bytes|conn_state|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|    label|      detailed-label|                 dt|                day|               hour|             minute|             second|
+-------------------+------------------+---------------+-----------+---------------+---------+-----+-------+---------+----------+----------+----------+-------+---------+-------------+---------+-------------+---------+--------------------+-------------------+-------------------+-------------------+-------------------+----