In [1]:
import findspark
findspark.init()
import pyspark
from pyspark.sql import SparkSession
from pyspark.conf import SparkConf
from pyspark import SparkContext

from pyspark.sql import functions as f

conf = SparkConf()\
    .setMaster("local[*]")\
    .setAppName("Working with DF")\
    .setExecutorEnv("spark.driver.memory","2g")\
    .setExecutorEnv("spark.executor.memory","4g")

spark = SparkSession\
    .builder\
    .config(conf=conf)\
    .getOrCreate()

spark

In [1]:
##### data source: http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html

In [2]:
# we can read zip file with read.text() method
df = spark.read.text('data/kddcup.data.gz')

In [3]:
# let's explore schema
df.printSchema()

root
 |-- value: string (nullable = true)



In [4]:
df.show()

+--------------------+
|               value|
+--------------------+
|0,tcp,http,SF,215...|
|0,tcp,http,SF,162...|
|0,tcp,http,SF,236...|
|0,tcp,http,SF,233...|
|0,tcp,http,SF,239...|
|0,tcp,http,SF,238...|
|0,tcp,http,SF,235...|
|0,tcp,http,SF,234...|
|0,tcp,http,SF,239...|
|0,tcp,http,SF,181...|
|0,tcp,http,SF,184...|
|0,tcp,http,SF,185...|
|0,tcp,http,SF,239...|
|0,tcp,http,SF,181...|
|0,tcp,http,SF,236...|
|0,tcp,http,SF,233...|
|0,tcp,http,SF,238...|
|0,tcp,http,SF,235...|
|0,tcp,http,SF,234...|
|0,tcp,http,SF,239...|
+--------------------+
only showing top 20 rows



In [11]:


split_col = f.split(df['value'], ',')

df = df.withColumn('protocol', split_col.getItem(1))\
       .withColumn('service', split_col.getItem(2))\
       .withColumn('flag', split_col.getItem(3))\
       .withColumn('src_bytes', split_col.getItem(4))\
       .withColumn('dst_bytes', split_col.getItem(5))\
       .withColumn('urgent', split_col.getItem(8))\
       .withColumn('num_failed_logins', split_col.getItem(10))\
       .withColumn('root_shell', split_col.getItem(13))\
       .withColumn('guest_login', split_col.getItem(21))\
       .withColumn('label', split_col.getItem(41))\
       .drop('value') 

df.show()
       

+--------+-------+----+---------+---------+------+-----------------+----------+-----------+-------+
|protocol|service|flag|src_bytes|dst_bytes|urgent|num_failed_logins|root_shell|guest_login|  label|
+--------+-------+----+---------+---------+------+-----------------+----------+-----------+-------+
|     tcp|   http|  SF|      215|    45076|     0|                0|         0|          0|normal.|
|     tcp|   http|  SF|      162|     4528|     0|                0|         0|          0|normal.|
|     tcp|   http|  SF|      236|     1228|     0|                0|         0|          0|normal.|
|     tcp|   http|  SF|      233|     2032|     0|                0|         0|          0|normal.|
|     tcp|   http|  SF|      239|      486|     0|                0|         0|          0|normal.|
|     tcp|   http|  SF|      238|     1282|     0|                0|         0|          0|normal.|
|     tcp|   http|  SF|      235|     1337|     0|                0|         0|          0|normal.|


In [12]:
df.rdd.getNumPartitions()

1

In [13]:
df = df.repartition(10)

In [14]:
df.rdd.getNumPartitions()

10

In [15]:
df.createOrReplaceTempView('d_kdd_cup')

### 1. Count number of connections for each label

In [16]:
df.groupBy('label').count().orderBy('count', ascending=False).show()

+----------------+-------+
|           label|  count|
+----------------+-------+
|          smurf.|2807886|
|        neptune.|1072017|
|         normal.| 972781|
|          satan.|  15892|
|        ipsweep.|  12481|
|      portsweep.|  10413|
|           nmap.|   2316|
|           back.|   2203|
|    warezclient.|   1020|
|       teardrop.|    979|
|            pod.|    264|
|   guess_passwd.|     53|
|buffer_overflow.|     30|
|           land.|     21|
|    warezmaster.|     20|
|           imap.|     12|
|        rootkit.|     10|
|     loadmodule.|      9|
|      ftp_write.|      8|
|       multihop.|      7|
+----------------+-------+
only showing top 20 rows



2. Get the list of `Protocols` that are `normal` and `vulnerable to attacks`, where there is NOT `guest login` to the destination addresses

In [21]:
query = '''

    select protocol,
           case label
               when 'normal.' then 'no_attack'
               else 'attack'
            end as state,
            count(*) as freq
    from    d_kdd_cup
    where   guest_login != '1'
    group by protocol, 
             case label
               when 'normal.' then 'no_attack'
               else 'attack'
             end 
    order by protocol desc

'''

spark.sql(query).show()



+--------+---------+-------+
|protocol|    state|   freq|
+--------+---------+-------+
|     udp|no_attack| 191348|
|     udp|   attack|   2940|
|     tcp|no_attack| 764894|
|     tcp|   attack|1101613|
|    icmp|no_attack|  12763|
|    icmp|   attack|2820782|
+--------+---------+-------+



### 3. Apply Some Descriptive Statistics on Numerical Data

In [23]:
summary = df.select(
    f.mean(df['src_bytes']).alias('mean_src_bytes'),
    f.min(df['src_bytes']).alias('min_src_bytes'),
    f.max(df['src_bytes']).alias('max_src_bytes'),
    f.count(df['src_bytes']).alias('count_src_bytes'),
    f.stddev(df['src_bytes']).alias('std_src_bytes'),
    f.skewness(df['src_bytes']).alias('skew_src_bytes')  
    )

summary.show()

+------------------+-------------+-------------+---------------+-----------------+------------------+
|    mean_src_bytes|min_src_bytes|max_src_bytes|count_src_bytes|    std_src_bytes|    skew_src_bytes|
+------------------+-------------+-------------+---------------+-----------------+------------------+
|1834.6211752293746|            0|          999|        4898431|941431.0744911298|1188.9519100465736|
+------------------+-------------+-------------+---------------+-----------------+------------------+



In [25]:
df.groupBy('protocol')\
    .agg({
        'src_bytes': 'mean',
        'dst_bytes': 'stddev'
    }).show()

+--------+-----------------+------------------+
|protocol|   avg(src_bytes)| stddev(dst_bytes)|
+--------+-----------------+------------------+
|     tcp|3388.569965326596|1043771.3100418103|
|     udp|97.22772893848308| 55.43318653434132|
|    icmp|927.8916893855577|               0.0|
+--------+-----------------+------------------+



### 4. A Descriptive Stats based on `Protocols` and `Labels`

In [28]:
query2 = '''

    select protocol,
        case label
            when 'normal.' then 'no_attack'
            else 'attack'
            end as state,
        count(*) as total_freq,
        round(avg(src_bytes),2) as mean_src_bytes,
        round(avg(dst_bytes),2) as mean_dst_bytes,
        sum(urgent) as sum_urgent,
        sum(num_failed_logins) as sum_num_failed_logins,
        sum(root_shell) as sum_root_shell,
        sum(guest_login) as sum_guest_login
    from d_kdd_cup
    group by protocol, state

'''

spark.sql(query2).show()

+--------+---------+----------+--------------+--------------+----------+---------------------+--------------+---------------+
|protocol|    state|total_freq|mean_src_bytes|mean_dst_bytes|sum_urgent|sum_num_failed_logins|sum_root_shell|sum_guest_login|
+--------+---------+----------+--------------+--------------+----------+---------------------+--------------+---------------+
|     tcp|no_attack|    768670|       1844.29|       4071.32|      35.0|                 96.0|         302.0|         3776.0|
|     udp|no_attack|    191348|         98.32|         89.41|       0.0|                  0.0|           0.0|            0.0|
|     udp|   attack|      2940|          26.4|          0.82|       0.0|                  0.0|           0.0|            0.0|
|    icmp|no_attack|     12763|         90.68|           0.0|       0.0|                  0.0|           0.0|            0.0|
|     tcp|   attack|   1101928|       4465.81|       2005.96|       4.0|                 61.0|          32.0|         

### 5. Get the frquency of `services` for the original `UDP and ICMP` based `attacks`
(original attacks: `[dos, u2r, r2l, probe]`)

(returns `services` and `protocols` center justified)

In [33]:
from pyspark.sql.types import StringType

def attack_category(conn):
    conn = conn.replace('.','')
    if conn in ['back', 'land', 'neptune', 'pod', 'smurf', 'teardrop']:
        return 'DoS'
    elif conn in ['buffer_overflow', 'load_module', 'perl', 'rootkit']:
        return 'U2R'
    elif conn in ['ftp_write', 'guess_password', 'multihop', 'phf', 'spy', 'warezclient', 'warezmaster']:
        return 'R2L'
    else:
        return 'probs'

def center_justify(item):
    return item.center(10)

spark.udf.register('get_original_attack', attack_category, StringType())
spark.udf.register('center_justify', center_justify, StringType())

query3 = '''

    select center_justify(service) as service,
           center_justify(protocol) as protocol,
           get_original_attack(label) as original_label,
           count(*) as freq
           from d_kdd_cup
           where protocol in ('udp', 'icmp')
           and label != 'normal.'
           group by service, original_label, protocol
           order by freq desc

'''

spark.sql(query3).show()

+----------+----------+--------------+-------+
|   service|  protocol|original_label|   freq|
+----------+----------+--------------+-------+
|  ecr_i   |   icmp   |           DoS|2808145|
|  eco_i   |   icmp   |         probs|  12570|
| private  |   udp    |         probs|   1688|
| private  |   udp    |           DoS|    979|
|  other   |   udp    |         probs|    261|
|  ecr_i   |   icmp   |         probs|     59|
| domain_u |   udp    |         probs|      9|
|  tim_i   |   icmp   |           DoS|      5|
|  other   |   udp    |           U2R|      3|
|  urp_i   |   icmp   |         probs|      3|
+----------+----------+--------------+-------+

