# Malware Detection in Network Traffic Data

## Imports and spark setup

In [18]:
import pyspark
from tqdm import tqdm
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.sql import SQLContext, SparkSession

spark = SparkSession.builder.appName('Malware').config("spark.executor.memory", "14g").getOrCreate()

#sc = spark.sparkContext
#sqlsc = SQLContext(spark)

import os
# os.environ['HADOOP_HOME'] = 'C:/dummy/hadoop_home'

In [19]:
print(spark.sparkContext.getConf().getAll())

[('spark.app.submitTime', '1702480620972'), ('spark.sql.warehouse.dir', 'file:/C:/Users/Vincenzo/Projects/DDAM_Project_23-24/code/spark-warehouse'), ('spark.executor.memory', '14g'), ('spark.app.startTime', '1702480621048'), ('spark.app.id', 'local-1702480621676'), ('spark.driver.port', '57701'), ('spark.executor.id', 'driver'), ('spark.driver.extraJavaOptions', '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.act

## Dataset concatenation / data assesment / schema corrections [preprocessing]

In [20]:
from pyspark.sql.types import StructType, StructField, StringType, DoubleType

# Define your custom schema
custom_schema = StructType([
    StructField("ts", DoubleType(), True),
    StructField("uid", StringType(), True),
    StructField("id.orig_h", StringType(), True),
    StructField("id.orig_p", StringType(), True),
    StructField("id.resp_h", StringType(), True),
    StructField("id.resp_p", StringType(), True),
    StructField("proto", StringType(), True),
    StructField("service", StringType(), True),
    StructField("duration", DoubleType(), True),
    StructField("orig_bytes", DoubleType(), True),
    StructField("resp_bytes", DoubleType(), True),
    StructField("conn_state", StringType(), True),
    StructField("local_orig", StringType(), True),
    StructField("local_resp", StringType(), True),
    StructField("missed_bytes", DoubleType(), True),
    StructField("history", StringType(), True),
    StructField("orig_pkts", DoubleType(), True),
    StructField("orig_ip_bytes", DoubleType(), True),
    StructField("resp_pkts", DoubleType(), True),
    StructField("resp_ip_bytes", DoubleType(), True),
    StructField("tunnel_parents", StringType(), True),
    StructField("label", StringType(), True),
    StructField("detailed-label", StringType(), True),
])

# List of file paths
file_paths = [
    r"C:\Users\Vincenzo\Projects\DDAM_data\malware\CTU-IoT-Malware-Capture-1-1conn.log.labeled.csv",
    r"C:\Users\Vincenzo\Projects\DDAM_data\malware\CTU-IoT-Malware-Capture-3-1conn.log.labeled.csv",
    r"C:\Users\Vincenzo\Projects\DDAM_data\malware\CTU-IoT-Malware-Capture-9-1conn.log.labeled.csv",
    r"C:\Users\Vincenzo\Projects\DDAM_data\malware\CTU-IoT-Malware-Capture-20-1conn.log.labeled.csv",
    r"C:\Users\Vincenzo\Projects\DDAM_data\malware\CTU-IoT-Malware-Capture-21-1conn.log.labeled.csv",
    r"C:\Users\Vincenzo\Projects\DDAM_data\malware\CTU-IoT-Malware-Capture-34-1conn.log.labeled.csv",
    r"C:\Users\Vincenzo\Projects\DDAM_data\malware\CTU-IoT-Malware-Capture-35-1conn.log.labeled.csv",
    r"C:\Users\Vincenzo\Projects\DDAM_data\malware\CTU-IoT-Malware-Capture-42-1conn.log.labeled.csv",
    r"C:\Users\Vincenzo\Projects\DDAM_data\malware\CTU-IoT-Malware-Capture-44-1conn.log.labeled.csv",
    r"C:\Users\Vincenzo\Projects\DDAM_data\malware\CTU-IoT-Malware-Capture-48-1conn.log.labeled.csv",
    r"C:\Users\Vincenzo\Projects\DDAM_data\malware\CTU-IoT-Malware-Capture-60-1conn.log.labeled.csv",
]

# Initialize an empty DataFrame with the custom schema
df = spark.createDataFrame(spark.sparkContext.emptyRDD(), custom_schema)

# Use tqdm for progress bar
for file_path in tqdm(file_paths, desc="Reading files", unit="file"):
    df_temp = spark.read.option("escape", "\"").option("delimiter", "|").csv(file_path, header=True, schema=custom_schema)
    df = df.union(df_temp)

df = df.drop("uid")

df.show(5)
df.printSchema()


Reading files: 100%|██████████| 11/11 [00:00<00:00, 47.06file/s]


+-------------------+---------------+---------+---------------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+---------+--------------------+
|                 ts|      id.orig_h|id.orig_p|      id.resp_h|id.resp_p|proto|service|duration|orig_bytes|resp_bytes|conn_state|local_orig|local_resp|missed_bytes|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|tunnel_parents|    label|      detailed-label|
+-------------------+---------------+---------+---------------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+---------+--------------------+
|1.525879831015811E9|192.168.100.103|    51524| 65.127.233.163|       23|  tcp|      -|2.999051|       0.0|       0.0|        S0|         -|         -|         0.0|      S|      3.0|        180.0|   

In [21]:
#replacing dots with underscores in column names to avoid errors
df = df.toDF(*(c.replace('.', '_') for c in df.columns))
df = df.toDF(*(c.replace('-', '_') for c in df.columns))

In [22]:
# #looking for the coluns that have the '-' value to later work on them
# from pyspark.sql.functions import col

# columns_with_dash = [col_name for col_name in df.columns if df.filter(col(col_name) == "-").count() > 0]
# print("Columns with '-' values:", columns_with_dash)

In [23]:
from pyspark.sql.functions import col, when
from pyspark.sql import SparkSession

columns_with_dash = ['service', 'local_orig', 'local_resp', 'history', 'tunnel_parents', 'detailed_label']

def replace_dash_with_nan(df, columns):

    result_df = df
    for column in columns:
        result_df = result_df.withColumn(column, when(col(column) == '-', None).otherwise(col(column)))

    return result_df

df = replace_dash_with_nan(df, columns_with_dash)

In [24]:
# from pyspark.sql.functions import isnan, when, count, col, isnull
# missing = df.select([count(when(isnull(c), c)).alias(c) for c in df.columns]).show()

In [25]:
# from pyspark.sql.functions import from_unixtime, col

# df = df.withColumn("formatted_ts", from_unixtime("ts").cast("timestamp"))
# df.show(5)

In [26]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

window_spec = Window.partitionBy('id_orig_h', 'id_resp_h')

# Extract features from IP addresses
df_features = df.withColumn('id_orig_h_features', F.split('id_orig_h', r'\.')[0]) \
                .withColumn('id_resp_h_features', F.split('id_resp_h', r'\.')[0]) \
                .withColumn('is_internal_orig', F.when(F.split('id_orig_h', r'\.')[0] == '192', 1).otherwise(0)) \
                .withColumn('is_internal_resp', F.when(F.split('id_resp_h', r'\.')[0] == '192', 1).otherwise(0)) \
                .withColumn('common_orig_dest_pairs', F.count('id_orig_h').over(window_spec)) \
                .withColumn('hour_of_day', F.hour(F.from_unixtime('ts'))) \
                .withColumn('day_of_week', F.dayofweek(F.from_unixtime('ts')))
                

# Select relevant columns for prediction
selected_columns_feature_extraction = ['id_orig_h_features', 'id_resp_h_features', 'is_internal_orig', 'is_internal_resp', 'common_orig_dest_pairs', 'hour_of_day', 'day_of_week']

df_selected_IPfeatures = df_features.select(selected_columns_feature_extraction).drop('id_orig_h', 'id_resp_h')

# # Convert columns to integer
df_selected_IPfeatures = df_selected_IPfeatures.withColumn("id_orig_h_features", col("id_orig_h_features").cast("integer"))
df_selected_IPfeatures = df_selected_IPfeatures.withColumn("id_resp_h_features", col("id_resp_h_features").cast("integer"))

df_features = df_features.withColumn("id_orig_h_features", col("id_orig_h_features").cast("integer"))
df_features = df_features.withColumn("id_resp_h_features", col("id_resp_h_features").cast("integer"))

#print schema
df_selected_IPfeatures.printSchema()

root
 |-- id_orig_h_features: integer (nullable = true)
 |-- id_resp_h_features: integer (nullable = true)
 |-- is_internal_orig: integer (nullable = false)
 |-- is_internal_resp: integer (nullable = false)
 |-- common_orig_dest_pairs: long (nullable = false)
 |-- hour_of_day: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)



In [27]:
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

# Columns for StringIndexing

string_indexing_cols = [ 'id_orig_p', 'id_resp_p', 'proto', 'conn_state', 'label', 'detailed_label']
indexers = [StringIndexer(inputCol=col_name, outputCol=col_name+'_index',handleInvalid='keep') for col_name in string_indexing_cols]

# Create a pipeline with StringIndexer stages
pipeline = Pipeline(stages=indexers)
df_features = pipeline.fit(df_features).transform(df_features)

# Select relevant columns for prediction
selected_columns_string_index = ['id_orig_p', 'id_resp_p', 'proto', 'conn_state', 'label', 'detailed_label']
selected_columns_string_index += [f'{col_name}_index' for col_name in string_indexing_cols]

df_selected_string_index = df_features.select(selected_columns_string_index).drop(*string_indexing_cols)

#print schema
df_selected_string_index.printSchema()

root
 |-- id_orig_p_index: double (nullable = false)
 |-- id_resp_p_index: double (nullable = false)
 |-- proto_index: double (nullable = false)
 |-- conn_state_index: double (nullable = false)
 |-- label_index: double (nullable = false)
 |-- detailed_label_index: double (nullable = false)



In [28]:
ip = df_selected_IPfeatures.schema.names
string_index = df_selected_string_index.schema.names
num_cols = [item[0] for item in df_features.dtypes if item[1] != 'string']

selected_features = ip + string_index + num_cols + ['history']

df_selected_features = df_features.select(selected_features)
df_selected_features.printSchema()

root
 |-- id_orig_h_features: integer (nullable = true)
 |-- id_resp_h_features: integer (nullable = true)
 |-- is_internal_orig: integer (nullable = false)
 |-- is_internal_resp: integer (nullable = false)
 |-- common_orig_dest_pairs: long (nullable = false)
 |-- hour_of_day: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- id_orig_p_index: double (nullable = false)
 |-- id_resp_p_index: double (nullable = false)
 |-- proto_index: double (nullable = false)
 |-- conn_state_index: double (nullable = false)
 |-- label_index: double (nullable = false)
 |-- detailed_label_index: double (nullable = false)
 |-- ts: double (nullable = true)
 |-- duration: double (nullable = true)
 |-- orig_bytes: double (nullable = true)
 |-- resp_bytes: double (nullable = true)
 |-- missed_bytes: double (nullable = true)
 |-- orig_pkts: double (nullable = true)
 |-- orig_ip_bytes: double (nullable = true)
 |-- resp_pkts: double (nullable = true)
 |-- resp_ip_bytes: double (nullable 

In [29]:
# # dropping the useless columns to start performing the analysis

# to_drop_binary = ['ts', 'service', 'duration', 'orig_bytes', 'resp_bytes','local_orig', 'local_resp', 'tunnel_parents', 'detailed_label']
# to_drop_multiclass = ['ts', 'service', 'local_orig', 'local_resp', 'tunnel_parents']

# df_binary = df_features.drop(*to_drop_binary)
# df_multiclass = df_features.drop(*to_drop_multiclass)

In [30]:
from pyspark.sql.functions import isnan, when, count, col, isnull
missing = df_selected_features.select([count(when(isnull(c), c)).alias(c) for c in df_selected_features.columns]).show()

+------------------+------------------+----------------+----------------+----------------------+-----------+-----------+---------------+---------------+-----------+----------------+-----------+--------------------+---+--------+----------+----------+------------+---------+-------------+---------+-------------+------------------+------------------+----------------+----------------+----------------------+-----------+-----------+---------------+---------------+-----------+----------------+-----------+--------------------+-------+
|id_orig_h_features|id_resp_h_features|is_internal_orig|is_internal_resp|common_orig_dest_pairs|hour_of_day|day_of_week|id_orig_p_index|id_resp_p_index|proto_index|conn_state_index|label_index|detailed_label_index| ts|duration|orig_bytes|resp_bytes|missed_bytes|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|id_orig_h_features|id_resp_h_features|is_internal_orig|is_internal_resp|common_orig_dest_pairs|hour_of_day|day_of_week|id_orig_p_index|id_resp_p_index|proto_i

In [31]:
to_drop = ['duration', 'orig_bytes', 'resp_bytes']
df_selected_features = df_selected_features.drop(*to_drop)

In [33]:
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml import Pipeline
from pyspark.sql import functions as F

# Step 1: Drop rows with missing labels
df = df_selected_features.filter(F.col('history').isNotNull())

# Step 2: Convert "history" to numerical indices
history_indexer = StringIndexer(inputCol='history', outputCol='history_index')
df = history_indexer.fit(df).transform(df)

# Step 3: Prepare features and labels
features = df.schema.names
features.remove('history')
vector_assembler = VectorAssembler(inputCols=features, outputCol='features')
df = vector_assembler.transform(df)

# Step 4: Train RandomForest model
rf_classifier = RandomForestClassifier(featuresCol='features', labelCol='history_index')
pipeline = Pipeline(stages=[rf_classifier])
model = pipeline.fit(df)

# Step 5: Use the trained model to predict missing values
df_pred = model.transform(df_selected_features)


Py4JJavaError: An error occurred while calling o2283.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 3 in stage 64.0 failed 1 times, most recent failure: Lost task 3.0 in stage 64.0 (TID 1516) (192.168.1.143 executor driver): java.lang.OutOfMemoryError: Java heap space
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader.<init>(UnsafeSorterSpillReader.java:50)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter.getReader(UnsafeSorterSpillWriter.java:159)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.getSortedIterator(UnsafeExternalSorter.java:555)
	at org.apache.spark.sql.execution.UnsafeExternalRowSorter.sort(UnsafeExternalRowSorter.java:172)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage13.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at org.apache.spark.sql.execution.window.WindowExec$$anon$1.fetchNextRow(WindowExec.scala:118)
	at org.apache.spark.sql.execution.window.WindowExec$$anon$1.<init>(WindowExec.scala:127)
	at org.apache.spark.sql.execution.window.WindowExec.$anonfun$doExecute$3(WindowExec.scala:107)
	at org.apache.spark.sql.execution.window.WindowExec$$Lambda$4731/20633237.apply(Unknown Source)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:855)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:855)
	at org.apache.spark.rdd.RDD$$Lambda$3275/24342096.apply(Unknown Source)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.sql.execution.SQLExecutionRDD.$anonfun$compute$1(SQLExecutionRDD.scala:52)
	at org.apache.spark.sql.execution.SQLExecutionRDD$$Lambda$4756/13573322.apply(Unknown Source)
	at org.apache.spark.sql.internal.SQLConf$.withExistingConf(SQLConf.scala:158)
	at org.apache.spark.sql.execution.SQLExecutionRDD.compute(SQLExecutionRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2844)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2780)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2779)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2779)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1242)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1242)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1242)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3048)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2982)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2971)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:984)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2398)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2493)
	at org.apache.spark.rdd.RDD.$anonfun$aggregate$1(RDD.scala:1225)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:407)
	at org.apache.spark.rdd.RDD.aggregate(RDD.scala:1218)
	at org.apache.spark.ml.tree.impl.DecisionTreeMetadata$.buildMetadata(DecisionTreeMetadata.scala:125)
	at org.apache.spark.ml.tree.impl.RandomForest$.run(RandomForest.scala:274)
	at org.apache.spark.ml.classification.RandomForestClassifier.$anonfun$train$1(RandomForestClassifier.scala:168)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:139)
	at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:47)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:114)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:78)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(Unknown Source)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source)
	at java.lang.reflect.Method.invoke(Unknown Source)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Unknown Source)
Caused by: java.lang.OutOfMemoryError: Java heap space
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader.<init>(UnsafeSorterSpillReader.java:50)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter.getReader(UnsafeSorterSpillWriter.java:159)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.getSortedIterator(UnsafeExternalSorter.java:555)
	at org.apache.spark.sql.execution.UnsafeExternalRowSorter.sort(UnsafeExternalRowSorter.java:172)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage13.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at org.apache.spark.sql.execution.window.WindowExec$$anon$1.fetchNextRow(WindowExec.scala:118)
	at org.apache.spark.sql.execution.window.WindowExec$$anon$1.<init>(WindowExec.scala:127)
	at org.apache.spark.sql.execution.window.WindowExec.$anonfun$doExecute$3(WindowExec.scala:107)
	at org.apache.spark.sql.execution.window.WindowExec$$Lambda$4731/20633237.apply(Unknown Source)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:855)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:855)
	at org.apache.spark.rdd.RDD$$Lambda$3275/24342096.apply(Unknown Source)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.sql.execution.SQLExecutionRDD.$anonfun$compute$1(SQLExecutionRDD.scala:52)
	at org.apache.spark.sql.execution.SQLExecutionRDD$$Lambda$4756/13573322.apply(Unknown Source)
	at org.apache.spark.sql.internal.SQLConf$.withExistingConf(SQLConf.scala:158)
	at org.apache.spark.sql.execution.SQLExecutionRDD.compute(SQLExecutionRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)


ERROR:root:Exception while sending command.
Traceback (most recent call last):
  File "c:\Users\Vincenzo\Projects\DDAM_Project_23-24\venv\Lib\site-packages\py4j\clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Vincenzo\AppData\Local\Programs\Python\Python312\Lib\socket.py", line 707, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
ConnectionResetError: [WinError 10054] An existing connection was forcibly closed by the remote host

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "c:\Users\Vincenzo\Projects\DDAM_Project_23-24\venv\Lib\site-packages\py4j\java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Vincenzo\Projects\DDAM_Project_23-24\venv\Lib\site-packages\py4j\client

In [None]:
# from pyspark.sql.functions import col

# # Assume 'prediction' is the column containing the predicted values
# # and 'history' is the original column with missing values

# # Select only the relevant columns for substitution
# df_substitute = df_pred.select('history_index', 'prediction')

# # Replace the missing values in the 'history' column with the predicted values
# df_result = df_binary.join(df_substitute, on='history_index', how='left_outer') \
#     .withColumn('history', F.when(col('history').isNotNull(), col('history')).otherwise(col('prediction'))) \
#     .drop('history_index', 'prediction')

# # Show the resulting DataFrame
# df_result.show()


## Data understanding

### dimension and content analysis

In [None]:
# rows, cols = df.count(), len(df.columns)
# print(f'Dimension of the whole Dataframe is: {(rows,cols)}')

# rows, cols = df_binary.count(), len(df_binary.columns)
# print(f'Dimension of the binary Dataframe is: {(rows,cols)}')

# rows, cols = df_multiclass.count(), len(df_multiclass.columns)
# print(f'Dimension of the multiclass Dataframe is: {(rows,cols)}')

In [None]:
# num_cols = [item[0] for item in df.dtypes if item[1] != 'string']
# print('Le colonne numeriche sono {}'.format(len(num_cols)))
# print(num_cols)

In [None]:
# non_num_cols = [item[0] for item in df.dtypes if item[1] == 'string']
# print('Le colonne non numeriche sono {}'.format(len(non_num_cols)))
# print(non_num_cols)

### Check for missing labels for each file to understandand labelling consistency

In [None]:
# from pyspark.sql.functions import isnan, when, count, col, isnull


# def check_nan(spark, column_to_check="", label_only=False):
#     # List of file numbers
#     file_numbers = [1, 3, 9, 20, 21, 34, 35, 42, 44, 48, 60]
    
#     for number in tqdm(file_numbers, desc="Processing files", unit="file"):
    
#         file_path = r"C:\Users\Vincenzo\Projects\DDAM_data\malware\CTU-IoT-Malware-Capture-{}-1conn.log.labeled.csv".format(number)
#         df = spark.read.option("escape", "\"").option("delimiter", "|").csv(file_path, header='true', inferSchema='true')
#         df = df.toDF(*(c.replace('.', '_') for c in df.columns))

#         print("", flush=True)
#         # Check for missing values
#         if label_only:
#             missing = df.select([count(when(isnull("detailed-label"), "detailed-label")).alias("missing_count")])
#             print("Missing values in file {}: ".format(number))
#             missing.select("missing_count").show()

#         elif column_to_check:
#             missing = df.select([count(when(isnull(column_to_check), column_to_check)).alias("missing_count")])
#             print("Missing values in file {}: ".format(number))
#             missing.select("missing_count").show()

#         elif column_to_check == "all":
#             missing = df.select([count(when(isnull(c), c)).alias(c) for c in df.columns])
#             print("Missing values in file {}: ".format(number))
#             missing.show()

#         else:
#             print("Please enter a valid column name or enter True to check the label column only")

# # Check for missing values in the label column only
# check_nan(spark, label_only=True)


### Check for distinct values and raw statistics

In [None]:
# from pyspark.sql.functions import col, countDistinct

# non_num_df = df.select(non_num_cols)

# def count_distinct_values(df):
#     result = {}
#     columns = df.columns

#     # Use tqdm to create a progress bar
#     for column in tqdm(columns, desc="Counting Distinct Values", unit="column"):
#         distinct_count = df.select(column).agg(countDistinct(column)).collect()[0][0]
#         result[column] = distinct_count

#     return result

# distinct_counts = count_distinct_values(non_num_df)

# for column, count in distinct_counts.items():
#     print(f"Column '{column}' has {count} distinct values.")

In [None]:
# # Use the describe function to get statistical summary

# summary = df.describe()
# summary.show()


In [None]:
# # num_summary = summary.select(*num_cols)
# num_summary = summary.select(*(["summary"] + num_cols))
# # non_num_summary = summary.select(*non_num_cols)
# non_num_summary = summary.select(*(["summary"] + non_num_cols))

# print("Summary of numeric columns:")
# num_summary.show()

# print("Summary of non-numeric columns:")
# non_num_summary.show()