#### Load dataset

In [1]:
import pyspark
from pyspark.sql import SparkSession
from pyspark import SparkContext, SQLContext
from pyspark.ml import Pipeline,Transformer
from pyspark.ml.feature import Imputer,StandardScaler,StringIndexer,OneHotEncoder, VectorAssembler
from pyspark.ml.classification import LogisticRegression

from pyspark.sql.functions import *
from pyspark.sql.types import *
import numpy as np

appName = "SparkML-test"
master = "yarn"

# # Create Configuration object for Spark.
# conf = pyspark.SparkConf()\
# .set('spark.driver.host','127.0.0.1')\
# .setAppName(appName)\
# .setMaster(master)

# # Create Spark Context 
# sc = SparkContext.getOrCreate(conf=conf)

# # Create SQL Context to conduct some database operations
# sqlContext = SQLContext(sc)

# # If you have SQL context, you create the session from the Spark Context
# spark = sqlContext.sparkSession.builder\
# .config("spark.jars.packages", "org.postgresql:postgresql:42.2.29") \
# .config("spark.sql.execution.arrow.enabled","true")\
# .getOrCreate() 

spark = SparkSession.builder \
        .master(master) \
        .appName(appName) \
        .getOrCreate()
        

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/10/14 22:20:41 INFO SparkEnv: Registering MapOutputTracker
24/10/14 22:20:42 INFO SparkEnv: Registering BlockManagerMaster
24/10/14 22:20:42 INFO SparkEnv: Registering BlockManagerMasterHeartbeat
24/10/14 22:20:42 INFO SparkEnv: Registering OutputCommitCoordinator


In [2]:
col_names = ["duration","protocol_type","service","flag","src_bytes",
"dst_bytes","land","wrong_fragment","urgent","hot","num_failed_logins",
"logged_in","num_compromised","root_shell","su_attempted","num_root",
"num_file_creations","num_shells","num_access_files","num_outbound_cmds",
"is_host_login","is_guest_login","count","srv_count","serror_rate",
"srv_serror_rate","rerror_rate","srv_rerror_rate","same_srv_rate",
"diff_srv_rate","srv_diff_host_rate","dst_host_count","dst_host_srv_count",
"dst_host_same_srv_rate","dst_host_diff_srv_rate","dst_host_same_src_port_rate",
"dst_host_srv_diff_host_rate","dst_host_serror_rate","dst_host_srv_serror_rate",
"dst_host_rerror_rate","dst_host_srv_rerror_rate","class","difficulty"]

# Split train and test dataset
data_path = 'gs://kdd-dataset/'

nslkdd_raw = spark.read.csv(data_path + 'KDDTrain+.txt',header=False).toDF(*col_names)
nslkdd_test_raw = spark.read.csv(data_path + 'KDDTest+.txt',header=False).toDF(*col_names)

                                                                                

#### Data preprocessing

In [3]:
from pyspark.sql.types import IntegerType

nominal_cols = ['protocol_type','service','flag']
binary_cols = ['land', 'logged_in', 'root_shell', 'su_attempted', 'is_host_login',
'is_guest_login']
continuous_cols = ['duration' ,'src_bytes', 'dst_bytes', 'wrong_fragment' ,'urgent', 'hot',
'num_failed_logins', 'num_compromised', 'num_root' ,'num_file_creations',
'num_shells', 'num_access_files', 'num_outbound_cmds', 'count' ,'srv_count',
'serror_rate', 'srv_serror_rate' ,'rerror_rate' ,'srv_rerror_rate',
'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate' ,'dst_host_count',
'dst_host_srv_count' ,'dst_host_same_srv_rate' ,'dst_host_diff_srv_rate',
'dst_host_same_src_port_rate' ,'dst_host_srv_diff_host_rate',
'dst_host_serror_rate' ,'dst_host_srv_serror_rate', 'dst_host_rerror_rate',
'dst_host_srv_rerror_rate']

class OutcomeCreater(Transformer): # this defines a transformer that creates the outcome column
    
    def __init__(self):
        super().__init__()

    def _transform(self, dataset):
        def attack_category(attack_type):
            if attack_type == 'normal':
                return 0
            elif attack_type in ['port-Sweep', 'ip-Sweep', 'nmap', 'satan', 'saint', 'mscan']:
                return 1  # Probing
            elif attack_type in ['neptune', 'smurf', 'pod', 'teardrop', 'land', 'back', 'apache2',
                                'udpstorm', 'processtable', 'mail-Bomb']:
                return 2  # Dos
            elif attack_type in ['buffer-Overflow', 'load-Module', 'perl', 'rootkit', 'xterm',
                                'ps', 'sqlattack']:
                return 3  # U2R
            else:
                return 4  # R2L
          
        # Convert the function to a UDF, specifying IntegerType for output
        label_to_multiclasses = udf(attack_category, IntegerType())
        output_df = dataset.withColumn('outcome', label_to_multiclasses(col('class'))).drop("class")  
        output_df = output_df.withColumn('outcome', col('outcome').cast(DoubleType()))
        output_df = output_df.drop('difficulty')
        return output_df

class FeatureTypeCaster(Transformer): # this transformer will cast the columns as appropriate types  
    def __init__(self):
        super().__init__()

    def _transform(self, dataset):
        output_df = dataset
        for col_name in binary_cols + continuous_cols:
            output_df = output_df.withColumn(col_name,col(col_name).cast(DoubleType()))

        return output_df
    
class ColumnDropper(Transformer): # this transformer drops unnecessary columns
    def __init__(self, columns_to_drop = None):
        super().__init__()
        self.columns_to_drop=columns_to_drop
    def _transform(self, dataset):
        output_df = dataset
        for col_name in self.columns_to_drop:
            output_df = output_df.drop(col_name)
        return output_df
    
def get_preprocess_muticlass_pipeline():
    # Stage where columns are casted as appropriate types
    stage_typecaster = FeatureTypeCaster()

    # Stage where nominal columns are transformed to index columns using StringIndexer
    nominal_id_cols = [x+"_index" for x in nominal_cols]
    nominal_onehot_cols = [x+"_encoded" for x in nominal_cols]
    stage_nominal_indexer = StringIndexer(inputCols = nominal_cols, outputCols = nominal_id_cols )

    # Stage where the index columns are further transformed using OneHotEncoder
    stage_nominal_onehot_encoder = OneHotEncoder(inputCols=nominal_id_cols, outputCols=nominal_onehot_cols)

    # Stage where all relevant features are assembled into a vector (and dropping a few)
    feature_cols = continuous_cols+binary_cols+nominal_onehot_cols
    corelated_cols_to_remove = ["dst_host_serror_rate","srv_serror_rate","dst_host_srv_serror_rate",
                     "srv_rerror_rate","dst_host_rerror_rate","dst_host_srv_rerror_rate"]
    for col_name in corelated_cols_to_remove:
        feature_cols.remove(col_name)
    stage_vector_assembler = VectorAssembler(inputCols=feature_cols, outputCol="vectorized_features")

    # Stage where we scale the columns
    stage_scaler = StandardScaler(inputCol= 'vectorized_features', outputCol= 'features')
    

    # Stage for creating the outcome column representing whether there is normal, DOS, R2L, U2R, probing.
    stage_outcome = OutcomeCreater()

    # Removing all unnecessary columbs, only keeping the 'features' and 'outcome' columns
    stage_column_dropper = ColumnDropper(columns_to_drop = nominal_cols+nominal_id_cols+
        nominal_onehot_cols+ binary_cols + continuous_cols + ['vectorized_features'])
    
    # Connect the columns into a pipeline
    pipeline = Pipeline(stages=[stage_typecaster,stage_nominal_indexer,stage_nominal_onehot_encoder,
        stage_vector_assembler,stage_scaler,stage_outcome,stage_column_dropper])
    return pipeline 

In [4]:
preprocess_multi_class_pipeline = get_preprocess_muticlass_pipeline()
preprocess_multi_class = preprocess_multi_class_pipeline.fit(nslkdd_raw)

# Trandform train dataset
nslkdd_multi = preprocess_multi_class.transform(nslkdd_raw)
nslkdd_multi.show(3, vertical=True)

24/10/14 22:21:24 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
[Stage 8:>                                                          (0 + 1) / 1]

-RECORD 0------------------------
 features | (113,[1,13,14,17,... 
 outcome  | 0.0                  
-RECORD 1------------------------
 features | (113,[1,13,14,17,... 
 outcome  | 0.0                  
-RECORD 2------------------------
 features | (113,[13,14,15,17... 
 outcome  | 2.0                  
only showing top 3 rows



                                                                                

In [5]:
# Trandform test datset 
nslkdd_test_multi = preprocess_multi_class.transform(nslkdd_test_raw)
nslkdd_test_multi.show(3, vertical=True)

[Stage 9:>                                                          (0 + 1) / 1]

-RECORD 0------------------------
 features | (113,[13,14,16,17... 
 outcome  | 2.0                  
-RECORD 1------------------------
 features | (113,[13,14,16,17... 
 outcome  | 2.0                  
-RECORD 2------------------------
 features | (113,[0,1,13,14,1... 
 outcome  | 0.0                  
only showing top 3 rows



                                                                                

#### Machine Learning process

In [7]:
# Create a logistic regression model
lr = LogisticRegression(featuresCol = 'features', 
                             labelCol = 'outcome', 
                             maxIter=10)

# Fit the model
lrModel = lr.fit(nslkdd_multi)

# Calculate a train accuracy
predictions_train = lrModel.transform(nslkdd_multi)
accuracy_train = (predictions_train.filter(predictions_train.outcome == predictions_train.prediction)
    .count() / float(predictions_train.count()))
# Calculate a test accuracy
lr_predictions_test = lrModel.transform(nslkdd_test_multi)
accuracy_test = (lr_predictions_test.filter(lr_predictions_test.outcome == lr_predictions_test.prediction)
    .count() / float(lr_predictions_test.count()))

print(f"Train Accuracy : {np.round(accuracy_train*100,2)}%")
print(f"Test Accuracy : {np.round(accuracy_test*100,2)}%")

                                                                                

Train Accuracy : 96.98%
Test Accuracy : 71.9%


#### Print the number of partitions of Dataframe

In [6]:
nslkdd_multi.cache()
nslkdd_test_multi.cache()

DataFrame[features: vector, outcome: double]

In [9]:
num_partitions = nslkdd_multi.rdd.getNumPartitions()
print(f"Number of partitions: {num_partitions}")

Number of partitions: 2


In [10]:
spark.stop()