# Imports

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    from_unixtime,
    to_timestamp,
    min,
    max,
    sum,
    avg,
    col,
    countDistinct,
    broadcast,
    date_trunc,
    count,
)
from pyspark.sql import Window
import pyspark.sql.functions as F
import plotly.express as px

# Read Files

In [None]:
# Input files: the labelled connection logs for two captures.
filepaths = [
    "./iot_malware/CTU-IoT-Malware-Capture-1-1conn.log.labeled.csv",
    "./iot_malware/CTU-IoT-Malware-Capture-3-1conn.log.labeled.csv",
]

spark = SparkSession.builder.appName("iot").getOrCreate()
sc = spark.sparkContext
sc.setLogLevel("ERROR")  # only show errors from Spark in the notebook output
sc.version

In [None]:
# Load both files into one DataFrame: "|"-separated fields, a header row, and
# column types inferred from the data (inference costs an extra pass).
df = spark.read.options(delimiter="|", header=True, inferSchema=True).csv(filepaths)
df.show(5)

In [None]:
# Print the column names and the types inferred by the CSV reader as a tree.
df.printSchema()

## Pre-processing

In [None]:
# Event time as a TimestampType column, converted directly from the epoch
# seconds in "ts" so any fractional seconds are kept. The alternative
# from_unixtime -> to_timestamp round trip goes through a whole-second
# "yyyy-MM-dd HH:mm:ss" string in the session time zone. That drops sub-second
# precision and cannot disambiguate the repeated local hour at a DST fall-back.
# NOTE(review): assumes inferSchema typed "ts" as numeric; confirm with printSchema.
df = df.withColumn("dt", F.timestamp_seconds("ts"))

In [None]:
# Replace the dotted "id.*" endpoint columns with plain names. Dotted names
# would otherwise need backtick-quoting in Spark column expressions.
endpoint_renames = {
    "id.orig_h": "source_ip",
    "id.orig_p": "source_port",
    "id.resp_h": "dest_ip",
    "id.resp_p": "dest_port",
}
df = df.withColumnsRenamed(endpoint_renames)

## Dataset Quality Checks

### Min, Max datetime

In [None]:
# Earliest and latest event timestamps across the whole dataset.
date_range = df.agg(
    min("dt").alias("min_date"),
    max("dt").alias("max_date"),
)
date_range.show()

### Shape

In [None]:
# (row count, column count); the trailing expression is displayed as cell output.
n_rows, n_cols = df.count(), len(df.columns)
(n_rows, n_cols)

### Static Columns

In [None]:
# Columns to profile for cardinality (distinct-value counts).
to_analyse = [
    "source_ip",
    "source_port",
    "dest_ip",
    "dest_port",
    "proto",
    "service",
    "duration",
    "orig_bytes",
    "resp_bytes",
    "conn_state",
    "local_orig",
    "local_resp",
    "missed_bytes",
    "history",
    "orig_pkts",
    "orig_ip_bytes",
    "resp_pkts",
    "resp_ip_bytes",
    "tunnel_parents",
    "label",
    "detailed-label",
]

# One aggregation pass giving the number of distinct non-null values per column.
unique_counts = df.agg(*(countDistinct(col(c)).alias(c) for c in to_analyse))
# DataFrame.show() prints the table itself and returns None, so it must not be
# wrapped in print(); that would emit a stray "None" line.
unique_counts.show()

In [None]:
# Bring the single row of distinct counts to the driver. It is bound to a new
# name so `unique_counts` stays a DataFrame and this cell can be re-run.
unique_row = unique_counts.first()

# A column with at most one distinct value carries no information. countDistinct
# ignores nulls, so an entirely-null column reports 0; treat that as static too.
static_cols = [
    c for c, n_distinct in unique_row.asDict().items() if n_distinct <= 1
]
print("Dataset has", len(static_cols), "static columns: ", static_cols)
df = df.drop(*static_cols)

### Count Distinct Values

In [None]:
# Distinct source and destination IPs, plus the IPs seen on both sides.
source_ips = df.select("source_ip").distinct()
dest_ips = df.select("dest_ip").distinct()
same_ip = source_ips["source_ip"] == dest_ips["dest_ip"]
# dest_ips is broadcast to every executor for the join.
common_ips = source_ips.join(broadcast(dest_ips), on=same_ip, how="inner")

for caption, frame in (
    ("Source IPs count:", source_ips),
    ("Dest IPs count:", dest_ips),
    ("IPs as both:", common_ips),
):
    print(caption, frame.count())

In [None]:
# Distinct source and destination ports, plus the ports seen on both sides.
source_ports = df.select("source_port").distinct()
dest_ports = df.select("dest_port").distinct()
same_port = source_ports["source_port"] == dest_ports["dest_port"]
# dest_ports is broadcast to every executor for the join.
common_ports = source_ports.join(broadcast(dest_ports), on=same_port, how="inner")

for caption, frame in (
    ("Source Ports count:", source_ports),
    ("Dest Ports count:", dest_ports),
    ("Ports as both:", common_ports),
):
    print(caption, frame.count())

### Count Nulls

In [None]:
# The literal placeholder "-" becomes a real null so it is counted as missing below.
df = df.na.replace("-", None)

In [None]:
# Count missing values (null, or NaN for floating-point columns) in every
# profiled column that was not dropped as static.
remaining_cols = [f for f in to_analyse if f not in static_cols]

# isnan only makes sense for float/double columns. Applying it to string or
# integer columns forces an implicit cast to double, which is wasted work.
# NOTE(review): under ANSI mode that cast may raise on values such as IP strings.
float_cols = {name for name, dtype in df.dtypes if dtype in ("float", "double")}


def _is_missing(c):
    """Return a boolean Column that is true where column ``c`` is missing."""
    missing = col(c).isNull()
    if c in float_cols:
        missing = missing | F.isnan(c)
    return missing


df.select(
    [count(F.when(_is_missing(c), F.lit(1))).alias(c) for c in remaining_cols]
).show()

## Time-Series Plots

In [None]:
# Copies of "dt" truncated to each granularity, used as time buckets below.
df = df.withColumns(
    {unit: date_trunc(unit, "dt") for unit in ("day", "hour", "minute", "second")}
)

In [None]:
# One line chart per granularity: number of connections ("uid" values) per
# time bucket, split by label.
for granularity in ("day", "hour", "minute"):
    per_bucket = (
        df.groupBy(granularity, "label")
        .agg(count("uid").alias("counts"))
        .orderBy(granularity)
    )
    fig = px.line(
        per_bucket.toPandas(),
        x=granularity,
        y="counts",
        color="label",
        title=f"Event Counts per {granularity}",
    )
    fig.show()

## Univariate Data Analysis

In [None]:
def counts(df, var):
    """Print and plot how often each value of column ``var`` occurs in ``df``.

    Shows a table of value counts, most frequent first, with a "percent" column
    holding each value's share of all rows as a fraction rounded to 4 decimals.
    Then draws a bar chart of the raw counts.
    """
    # An un-partitioned window spans every row, giving the grand total per row.
    grand_total = sum(col("count")).over(Window.partitionBy())
    tally = (
        df.groupBy(var)
        .count()
        .orderBy("count", ascending=False)
        .withColumn("percent", F.round(col("count") / grand_total, 4))
    )
    tally.show()
    px.bar(tally.toPandas(), x=var, y="count").show()


# Frequency table and bar chart for each categorical column of interest.
categorical_columns = ["proto", "service", "conn_state", "history", "label"]

for column_name in categorical_columns:
    counts(df, column_name)

## Prepare for Modelling

In [None]:
# Feature groups for modelling; the target column is "label".
numerical_cols = [
    "duration",
    "orig_bytes",
    "resp_bytes",
    "orig_pkts",
    "orig_ip_bytes",
    "resp_pkts",
    "resp_ip_bytes",
]
categorical_cols = ["proto", "service", "conn_state"]
label = "label"

all_cols = [*numerical_cols, *categorical_cols]

In [None]:
# Cast numeric features to double, then impute missing values with sentinels:
# -999999 for numeric features and "missing" for categorical ones.
recast_cols = {c: col(c).cast("double") for c in numerical_cols}
fill_vals = {c: -999999 for c in numerical_cols}
fill_vals.update({c: 'missing' for c in categorical_cols})

df = df.withColumns(recast_cols).fillna(fill_vals)


## Full Pipeline

In [None]:
# Stand-alone version of the preprocessing above. It re-reads the raw files and
# chains every step into one lazily evaluated plan.

# Static columns found during the quality checks, hard-coded so this pipeline
# does not depend on the exploratory cells.
static_cols = ["local_orig", "local_resp", "missed_bytes", "tunnel_parents"]

# Numeric features. They may be read as strings (they contain "-" placeholders)
# and are cast to double after "-" has been replaced with null.
pipeline_numeric_cols = [
    "duration",
    "orig_bytes",
    "resp_bytes",
    "orig_pkts",
    "orig_ip_bytes",
    "resp_pkts",
    "resp_ip_bytes",
]
recast_cols = {c: col(c).cast("double") for c in pipeline_numeric_cols}

# Sentinel imputation: -999999 for numeric columns, "missing" for categorical ones.
fill_vals = {c: -999999 for c in pipeline_numeric_cols}
fill_vals.update(
    {c: "missing" for c in ("history", "proto", "service", "conn_state")}
)

preprocessed_data = (
    spark.read.option("delimiter", "|")
    .csv(filepaths, inferSchema=True, header=True)
    # Convert epoch seconds directly so fractional seconds are kept. This avoids
    # the whole-second, time-zone-dependent string round trip of
    # from_unixtime -> to_timestamp.
    .withColumn("dt", F.timestamp_seconds("ts"))
    .withColumns(
        {unit: date_trunc(unit, "dt") for unit in ("day", "hour", "minute", "second")}
    )
    .withColumnsRenamed(
        {
            "id.orig_h": "source_ip",
            "id.orig_p": "source_port",
            "id.resp_h": "dest_ip",
            "id.resp_p": "dest_port",
        }
    )
    .drop(*static_cols)
    .replace("-", None)
    .withColumns(recast_cols)
    .fillna(fill_vals)
)

preprocessed_data.show()

## Write Out

In [None]:
# Persist the preprocessed data as Parquet. DataFrame has no `writeparquet`
# attribute; the writer is reached through `.write`. Overwrite mode lets the
# cell be re-run; the default mode raises if "processed.pq" already exists.
preprocessed_data.write.mode("overwrite").parquet("processed.pq")

In [None]:
# Read the Parquet output back and display it to check the round trip.
read_in = spark.read.format("parquet").load("processed.pq")
read_in.show()