# Machine Learning Jupyter Notebook with SHAP and LIME

# 1. Import Libraries

In [None]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import OneHotEncoder, LabelEncoder, StandardScaler, MinMaxScaler
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset
import os
import time
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm
import random
import shap
from sklearn.metrics import accuracy_score, f1_score

In [None]:
def plot_metrics(train, val, epochs, title):
    """
    Draw the training and validation curves of one metric on a single figure.

    Parameters:
        - train (sequence): Training metric values; expected to hold one value per epoch.
        - val (sequence): Validation metric values; expected to hold one value per epoch.
        - epochs (int): The number of epochs. The x-axis runs from 1 to epochs inclusive.
        - title (str): Text used for both the y-axis label and the figure title.
    """
    sns.set_theme(style="whitegrid", palette="pastel")
    plt.figure(figsize=(8, 8))

    epoch_axis = range(1, epochs + 1)
    # Each curve takes the pastel palette colour at its own position.
    curves = (('Training', train), ('Validation', val))
    for position, (legend_name, values) in enumerate(curves):
        plt.plot(epoch_axis, values, label=legend_name, color=sns.color_palette("pastel")[position])

    plt.xlabel('Epochs')
    plt.ylabel(f"{title}")
    plt.title(f"{title}")
    plt.legend()
    plt.show()

# 2. Sample Dataset
This section loads the data and creates a train, validation, and test set. Because Normal records greatly outnumber the other classes, the Normal class is downsampled.

In [None]:
class DataSampler:
    """
    DataSampler: Manages sampling and splitting of the UNSW-NB15 dataset. Returns a training, validation, and test set.

    Initialisation:
        - train (None): Holds the train set after sample_data has run.
        - val (None): Holds the validation set after sample_data has run.
        - test (None): Holds the test set after sample_data has run.
    """

    def __init__(self):
        self.train = None
        self.val = None
        self.test = None

    def sample_data(self, data, size, rs):
        """
        sample_data: Loads the UNSW-NB15 1, 2, 3, 4 sets, downsamples the Normal rows of each set towards
        its Generic count, concatenates and cleans the result, then cuts one contiguous validation slice
        and one contiguous test slice out of it. Everything left over is the train set.

        Parameters:
            - data (string): The dataset to use (currently only 'full_data', i.e. UNSW-NB15, is supported).
            - size (float): Fraction of the cleaned rows used for EACH of the validation and test slices.
            - rs (int): The random seed. Drives both the Normal downsampling and the slice offsets,
              so the same seed always produces the same three sets.

        Returns:
            - (train, val, test): Three DataFrames. The same objects are also stored on
              self.train, self.val and self.test.

        Raises:
            - ValueError: If data is not 'full_data', or if size is too large to fit two
              non-overlapping slices into the cleaned data.
        """
        if data != 'full_data':
            raise ValueError(f"Unsupported dataset {data!r}; only 'full_data' is available.")

        balanced_sets = [self._downsample_normal(frame, rs) for frame in self._load_raw_sets()]
        full_data = self._clean(pd.concat(balanced_sets).reset_index(drop=True))

        self.train, self.val, self.test = self._split(full_data, size, rs)
        return self.train, self.val, self.test

    @staticmethod
    def _load_raw_sets():
        """Read the four raw CSV parts, attach the column names and label unlabelled rows as 'Normal'."""
        column_names = pd.read_csv('features2.csv')['Name'].tolist()
        frames = []
        for part in range(1, 5):
            frame = pd.read_csv(f'UNSW-NB15_{part}.csv', header=None)
            frame.columns = column_names
            # A missing attack category is treated as benign traffic.
            frame['attack_cat'] = frame['attack_cat'].fillna('Normal')
            frames.append(frame)
        return frames

    @staticmethod
    def _downsample_normal(frame, rs):
        """Drop surplus 'Normal' rows so their count approaches the 'Generic' count of the same set."""
        is_normal = frame['attack_cat'] == 'Normal'
        generic_count = int((frame['attack_cat'] == 'Generic').sum())
        surplus = int(is_normal.sum()) - generic_count

        # Only Normal rows whose NEXT row is not labelled 1 may be removed.
        next_row_not_attack = frame['Label'].shift(-1) != 1
        candidates = frame[is_normal & next_row_not_attack]

        # Clamp the request: DataFrame.sample raises on a negative n (more Generic than
        # Normal rows) and on an n larger than the number of candidate rows.
        drop_count = max(0, min(surplus, len(candidates)))
        removed = candidates.sample(n=drop_count, random_state=rs)
        return frame.drop(removed.index)

    @staticmethod
    def _clean(full_data):
        """Normalise the attack labels, drop the sparse columns and make the port columns numeric."""
        # Clean Labels.
        full_data['attack_cat'] = (
            full_data['attack_cat']
            .str.replace(r'\s+', '', regex=True)
            .str.replace('Backdoors', 'Backdoor', regex=False)
        )
        # Drop Sparse Data.
        # NOTE: These values could also be transformed by adding 1 and setting nulls to 0.
        full_data = full_data.drop(columns=['ct_ftp_cmd', 'ct_flw_http_mthd', 'is_ftp_login'])
        # Remove error values created by nulls.
        full_data['sport'] = pd.to_numeric(full_data['sport'])
        has_hex_port = full_data['dsport'].astype(str).str.startswith('0x')
        # Copy so the column assignment below writes to an independent frame, not a slice.
        full_data = full_data[~has_hex_port].copy()
        full_data['dsport'] = pd.to_numeric(full_data['dsport'])
        return full_data

    @staticmethod
    def _split(full_data, size, rs):
        """Cut a validation slice and a test slice at seeded random offsets; the remainder is train."""
        slice_size = int(size * len(full_data))
        val_span = len(full_data) - 2 * slice_size
        if val_span <= 0:
            raise ValueError(
                f"size={size!r} is too large: two slices of {slice_size} rows do not fit "
                f"into {len(full_data)} rows."
            )

        # A private generator seeded with rs keeps the offsets reproducible and
        # leaves the global random state untouched.
        rng = random.Random(rs)

        val_start = rng.randrange(0, val_span)
        val_data = full_data.iloc[val_start:val_start + slice_size]
        remainder = full_data.drop(val_data.index)

        test_start = rng.randrange(0, len(remainder) - slice_size)
        test_data = remainder.iloc[test_start:test_start + slice_size]
        train_data = remainder.drop(test_data.index)

        return train_data, val_data, test_data

# 3. Explore Data

# 4. Preprocess Data
# Handle missing values, encode categorical variables, etc.
# Example: Fill missing values with the mean or another defined method.

# Split data into features and target

# Standardize features

# 5. Train Model

Model 1, for example: LSTM

Model 2, for example: Random Forest

# 6. Evaluate Model

# 7. Make Predictions
# Example: Predict on new data
# Replace 'data.csv' with the path to your new data

# 8. Visualize Results
# Example: Plot feature importances

# 9. Model Interpretability with SHAP

# Summary plot

# 10. Model Interpretability with LIME