In [1]:
from pyspark.sql.types import *
from pyspark.ml import Pipeline
from pyspark.sql import functions as fn
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.mllib.evaluation import BinaryClassificationMetrics, MulticlassMetrics
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
# from spark_stratifier import StratifiedCrossValidator

In [2]:
# Explicit schema for the TFG flow dataset: raw integer counters, derived
# double-valued rates/sizes and their median-scaled variants, one-hot integer
# indicator columns, and the string target column "Tag". Field order matters:
# it defines the column order of the loaded DataFrame.
# NOTE(review): Spark's file readers are generally documented to treat a
# user-supplied schema as nullable, so the False flags are likely not
# enforced at read time — confirm.
rf_TFG_schema = StructType([
    StructField(column_name, column_type, False)
    for column_type, column_names in (
        (IntegerType(), (
            "totalDestinationBytes",
            "totalDestinationPackets",
            "totalSourceBytes",
            "totalSourcePackets",
            "timeLength",
        )),
        (DoubleType(), (
            "sourceByteRate",
            "destinationByteRate",
            "sourcePacketRate",
            "destinationPacketRate",
            "avgSourcePacketSize",
            "avgDestinationPacketSize",
            "totalDestinationBytesDiffMedianScal",
            "totalDestinationPacketsDiffMedianScal",
            "totalSourceBytesDiffMedianScal",
            "totalSourcePacketsDiffMedianScal",
            "timeLengthDiffMedianScal",
            "avgDestinationPacketSizeDiffMedianScal",
            "avgSourcePacketSizeDiffMedianScal",
            "destinationByteRateDiffMedianScal",
            "destinationPacketRateDiffMedianScal",
            "sourceByteRateDiffMedianScal",
            "sourcePacketRateDiffMedianScal",
        )),
        (IntegerType(), (
            "protocolName_tcp_ip",
            "protocolName_udp_ip",
            "sourceTCPFlag_N/A",
            "sourceTCPFlag_S",
            "destinationResume_external",
            "destinationResume_mainServer",
        )),
        (StringType(), (
            "Tag",
        )),
    )
    for column_name in column_names
])

In [3]:
# dataset = (spark.read.schema(rf_TFG_schema).option("header", "true").csv('/FileStore/tables/test'))

In [4]:
# Load the TFG CSV export with the explicit schema above; the first line of
# each file is treated as a header rather than data.
dataset = spark.read.csv('/FileStore/tables/TFG', schema=rf_TFG_schema, header=True)

In [5]:
# Class-balance check: number of rows for each Tag value.
tag_counts = dataset.groupBy("Tag").count()
display(tag_counts)

Tag,count
Attack,66795
Normal,66795


In [6]:
all_columns = dataset.columns
# Every column except the target is a model input.
features_columns = [column for column in all_columns if column != 'Tag']
# Pin the label encoding alphabetically (Attack -> 0.0, Normal -> 1.0).
# The default "frequencyDesc" order depends on class counts, which are exactly
# tied in this dataset. The meaning of label 1.0 (the positive class scored by
# BinaryClassificationEvaluator) would otherwise hinge on tie-breaking, and it
# would silently flip if the class balance ever changed.
# Rows whose Tag was not seen during fitting are dropped ("skip").
labelIndexer = StringIndexer(inputCol="Tag", outputCol="label",
                             stringOrderType="alphabetAsc").setHandleInvalid("skip")
# Pack all feature columns into the single vector column Spark ML expects.
assembler = VectorAssembler(inputCols=features_columns, outputCol="features")
stages = [labelIndexer, assembler]

In [7]:
# Explicit seed so bootstrap samples, per-split feature subsets and fold
# assignment are reproducible, independent of Spark's default seed.
SEED = 42
NUM_TREES = 700

# featureSubsetStrategy="12": consider 12 randomly chosen features at each split.
# numTrees matches the single value in the grid below, so the estimator's own
# setting is not misleading.
rf = RandomForestClassifier(labelCol="label", featuresCol="features",
                            numTrees=NUM_TREES, featureSubsetStrategy="12",
                            seed=SEED)
# Rebuild the stage list instead of appending to it, so re-running this cell
# does not stack a second classifier onto the pipeline.
stages = [labelIndexer, assembler, rf]
pipeline = Pipeline().setStages(stages)
# Defaults: areaUnderROC computed from "rawPrediction" against "label".
evaluator = BinaryClassificationEvaluator()
# Single-entry grid: CrossValidator is used here to get a 10-fold performance
# estimate rather than to search hyperparameters.
paramGrid = ParamGridBuilder().addGrid(rf.numTrees, [NUM_TREES]).build()
crossval = CrossValidator(estimator=pipeline, estimatorParamMaps=paramGrid,
                          evaluator=evaluator, numFolds=10, seed=SEED)

In [8]:
# Expensive step: trains the pipeline once per fold for every param map in
# paramGrid, then refits the best one on all of `dataset` (cvModel.bestModel).
cvModel = crossval.fit(dataset=dataset)

In [9]:
# Persist the winning fitted pipeline (indexer + assembler + forest),
# replacing whatever is already stored at this path.
best_pipeline_model = cvModel.bestModel
best_pipeline_model.write().overwrite().save('/TFG/result')

In [10]:
# Score every row with the refit best pipeline; CrossValidatorModel.transform
# delegates to this same bestModel.
predictions = cvModel.bestModel.transform(dataset)

In [11]:
# Show the true label, the predicted label and the class probabilities side by side.
selected = predictions.select(["label", "prediction", "probability"])
display(selected)

label,prediction,probability
1.0,1.0,"List(1, 2, List(), List(0.0062006345664021995, 0.9937993654335979))"
1.0,1.0,"List(1, 2, List(), List(0.006529694464140986, 0.9934703055358589))"
1.0,1.0,"List(1, 2, List(), List(0.008056439753508653, 0.9919435602464913))"
1.0,1.0,"List(1, 2, List(), List(0.01009896874965673, 0.9899010312503432))"
1.0,1.0,"List(1, 2, List(), List(0.00647338536003643, 0.9935266146399635))"
1.0,1.0,"List(1, 2, List(), List(0.006785411638027746, 0.9932145883619723))"
1.0,1.0,"List(1, 2, List(), List(0.006315927883009262, 0.9936840721169907))"
1.0,1.0,"List(1, 2, List(), List(0.006327667550694511, 0.9936723324493055))"
1.0,1.0,"List(1, 2, List(), List(0.006372848979940616, 0.9936271510200594))"
1.0,1.0,"List(1, 2, List(), List(0.0062006345664021995, 0.9937993654335979))"


In [12]:
# `predictions` come from cvModel scoring `dataset`, the same DataFrame passed
# to crossval.fit. This AUC is therefore measured on rows the model was trained
# on and is optimistically biased; it is not a generalization estimate.
training_auc = evaluator.evaluate(predictions)
# CrossValidator already scored each param map on its held-out folds.
# areaUnderROC is larger-is-better, so the best model's estimate is the maximum.
cross_validated_auc = max(cvModel.avgMetrics)
{"training_auc": training_auc, "cross_validated_auc": cross_validated_auc}