# Multi-Label Classification with Machine Learning
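In this notebook, we explore several multi-label classification techniques on a dataset of cyber-threat tweets: Binary Relevance, Calibrated Label Ranking, Classifier Chains, Label Powerset, Pruned Sets, Conditional Dependency Networks and Meta-Binary Relevance. Each technique is combined with several base classifiers, and we compare their performance on a validation split.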


In [1]:
import ast
import random
from collections import Counter
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.tree import DecisionTreeClassifier
from skmultilearn.model_selection import iterative_train_test_split
from tqdm import tqdm

from preprocess_functions import build_tree, extract_keys, preprocess_texts
from utils import CalibratedLabelRankClassifier, ChainOfClassifiers, LabelPowersetClassifier, \
    assess_models, prune_and_subsample, ConditionalDependencyNetwork, MetaBinaryRelevance


In [2]:
# Single seed shared by every stochastic component of the notebook.
RANDOM_STATE = 42
# When True, the training cells refit every technique even if a cached .pkl was loaded.
OVERWRITE = False

# Seed both the stdlib and the legacy NumPy global generators.
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(RANDOM_STATE)


In [3]:
# NOTE(review): INIT_POINTS and N_ITER are not used in the visible cells; presumably
# leftovers from a hyper-parameter search. Confirm before removing.
INIT_POINTS = 1
N_ITER = 5
# Fraction held out at each split (test from the full data, then validation from the rest).
TEST_SIZE = 0.2

# Unfitted base estimators, combined with every multi-label technique below.
BASE_CLASSIFIERS = {
    'logistic_regression': LogisticRegression(solver='liblinear', random_state=RANDOM_STATE),
    'gaussian_nb': GaussianNB(),
    'decision_tree': DecisionTreeClassifier(random_state=RANDOM_STATE),
    'random_forest': RandomForestClassifier(random_state=RANDOM_STATE),
    'xgb': xgb.XGBClassifier(random_state=RANDOM_STATE),
}

COLAB_PATH = Path('/content/drive/MyDrive')
KAGGLE_PATH = Path('/kaggle/input')
LOCAL_PATH = Path('./')

# Detect the runtime by probing for platform-specific modules: Colab first, then
# Kaggle, otherwise a local Jupyter server.
try:
    import google.colab  # only probed for availability

    _RUNTIME = 'colab'
except ImportError:
    try:
        import kaggle_secrets  # only probed for availability

        _RUNTIME = 'kaggle'
    except ImportError:
        _RUNTIME = 'local'

# NOTE(review): on Kaggle /kaggle/input is read-only, so MODELS_PATH can only be read from there.
DATA_PATH, MODELS_PATH = {
    'colab': (COLAB_PATH / 'data', COLAB_PATH / 'models'),
    'kaggle': (KAGGLE_PATH, KAGGLE_PATH),
    'local': (LOCAL_PATH / 'data', LOCAL_PATH / 'models'),
}[_RUNTIME]

GLOVE_6B_PATH = MODELS_PATH / 'glove-embeddings'
THREAT_TWEETS_PATH = DATA_PATH / 'tweets-dataset-for-cyberattack-detection'

GLOVE_6B_300D_TXT = GLOVE_6B_PATH / 'glove.6B.300d.txt'
THREAT_TWEETS_CSV = THREAT_TWEETS_PATH / 'tweets_final.csv'


## 1. Introduction

In this notebook, we solve a multi-label classification problem: a tweet can belong to several categories at once. Our goal is to predict the set of categories for each tweet from its text.


## 2. Data Loading and Preprocessing
We will load the dataset, inspect its structure, and preprocess it for machine learning models.


In [4]:
# Read the CSV file and process columns in one step
threat_tweets = (
    pd.read_csv(filepath_or_buffer=THREAT_TWEETS_CSV)
    .assign(
        tweet=lambda df: df['tweet'].apply(func=ast.literal_eval),
        watson=lambda df: df['watson'].apply(func=ast.literal_eval)
        .apply(func=lambda x: x.get('categories', []))
        .apply(func=build_tree),
        watson_list=lambda df: df['watson'].apply(func=extract_keys),
    )
    .query(expr='relevant == True')
    .drop(labels=['relevant'], axis=1)
    .dropna(subset=['text'], ignore_index=True)
)

threat_tweets.head()


Unnamed: 0,_id,date,id,text,tweet,type,watson,annotation,urls,destination_url,valid_certificate,watson_list
0,b'5b8876f9bb325e65fa7e78e4',2018-08-30 23:00:08+00:00,1035301167952211969,Protect your customers access Prestashop Ant...,{'created_at': 'Thu Aug 30 23:00:08 +0000 2018...,ddos,{'technology and computing': {'internet techno...,threat,['http://addons.prestashop.com/en/23513-anti-d...,https://addons.prestashop.com/en/23513-anti-dd...,True,"[technology and computing, internet technology..."
1,b'5b8876f9bb325e65fa7e78e5',2018-08-30 23:00:09+00:00,1035301173178249217,Data leak from Huazhu Hotels may affect 130 mi...,{'created_at': 'Thu Aug 30 23:00:09 +0000 2018...,leak,"{'travel': {'hotels': {}}, 'home and garden': ...",threat,['http://www.hotelmanagement.net/tech/data-lea...,http://www.hotelmanagement.net/tech/data-leak-...,True,"[travel, hotels, home and garden, home improve..."
2,b'5b8876fabb325e65fa7e78e6',2018-08-30 23:00:09+00:00,1035301174583353344,Instagram App 41.1788.50991.0 #Denial Of #Serv...,{'created_at': 'Thu Aug 30 23:00:09 +0000 2018...,general,{'science': {'weather': {'meteorological disas...,threat,['https://packetstormsecurity.com/files/149120...,https://packetstormsecurity.com/files/149120/i...,True,"[science, weather, meteorological disaster, hu..."
3,b'5b88770abb325e65fa7e78e7',2018-08-30 23:00:25+00:00,1035301242271096832,(good slides): \n\nThe Advanced Exploitation o...,{'created_at': 'Thu Aug 30 23:00:25 +0000 2018...,vulnerability,{'business and industrial': {'business operati...,threat,['https://twitter.com/i/web/status/10353012422...,https://twitter.com/i/web/status/1035301242271...,True,"[business and industrial, business operations,..."
4,b'5b887713bb325e65fa7e78e8',2018-08-30 23:00:35+00:00,1035301282095853569,CVE-2018-1000532 (beep)\nhttps://t.co/CaKbo38U...,{'created_at': 'Thu Aug 30 23:00:35 +0000 2018...,vulnerability,{'technology and computing': {'computer securi...,threat,['https://web.nvd.nist.gov/view/vuln/detail?vu...,https://nvd.nist.gov/vuln/detail/CVE-2018-1000532,True,"[technology and computing, computer security, ..."


For this project, the Watson categories of interest are listed below. An entry marked "included in X" is merged into category X:
1. computer security/network security
2. computer security/antivirus and malware
3. operating systems/mac os
4. operating systems/windows
5. operating systems/unix
6. operating systems/linux
7. software
8. programming languages, included in software
9. software/databases
10. hardware
11. electronic components, included in hardware
12. hardware/computer/servers
13. hardware/computer/portable computer
14. hardware/computer/desktop computer
15. hardware/computer components
16. hardware/computer networking/router
17. hardware/computer networking/wireless technology
18. networking
19. internet technology, included in networking


In [5]:
# Map every Watson category of interest onto one coarse project target.
# Sub-topics such as "programming languages" fold into their parent target.
FIX_TARGETS = {
    'computer security': 'computer security',
    'operating systems': 'operating systems',
    'software': 'software',
    'programming languages': 'software',
    'hardware': 'hardware',
    'electronic components': 'hardware',
    'networking': 'networking',
    'internet technology': 'networking'
}

_target_keys = set(FIX_TARGETS)

# For each tweet: the coarse targets hit by its Watson categories, or ['other'] if none.
# Using sorted() instead of list(set(...)) makes the label order independent of the
# per-process string-hash seed. The displayed targets, and any code that reads a
# label by position, are then reproducible across kernel restarts.
chosen_categories = [
    sorted({FIX_TARGETS[category] for category in _target_keys.intersection(watson_list)})
    or ['other']
    for watson_list in threat_tweets['watson_list']
]

threat_tweets['target'] = chosen_categories

threat_tweets.head()


Unnamed: 0,_id,date,id,text,tweet,type,watson,annotation,urls,destination_url,valid_certificate,watson_list,target
0,b'5b8876f9bb325e65fa7e78e4',2018-08-30 23:00:08+00:00,1035301167952211969,Protect your customers access Prestashop Ant...,{'created_at': 'Thu Aug 30 23:00:08 +0000 2018...,ddos,{'technology and computing': {'internet techno...,threat,['http://addons.prestashop.com/en/23513-anti-d...,https://addons.prestashop.com/en/23513-anti-dd...,True,"[technology and computing, internet technology...","[computer security, networking, software]"
1,b'5b8876f9bb325e65fa7e78e5',2018-08-30 23:00:09+00:00,1035301173178249217,Data leak from Huazhu Hotels may affect 130 mi...,{'created_at': 'Thu Aug 30 23:00:09 +0000 2018...,leak,"{'travel': {'hotels': {}}, 'home and garden': ...",threat,['http://www.hotelmanagement.net/tech/data-lea...,http://www.hotelmanagement.net/tech/data-leak-...,True,"[travel, hotels, home and garden, home improve...",[other]
2,b'5b8876fabb325e65fa7e78e6',2018-08-30 23:00:09+00:00,1035301174583353344,Instagram App 41.1788.50991.0 #Denial Of #Serv...,{'created_at': 'Thu Aug 30 23:00:09 +0000 2018...,general,{'science': {'weather': {'meteorological disas...,threat,['https://packetstormsecurity.com/files/149120...,https://packetstormsecurity.com/files/149120/i...,True,"[science, weather, meteorological disaster, hu...",[hardware]
3,b'5b88770abb325e65fa7e78e7',2018-08-30 23:00:25+00:00,1035301242271096832,(good slides): \n\nThe Advanced Exploitation o...,{'created_at': 'Thu Aug 30 23:00:25 +0000 2018...,vulnerability,{'business and industrial': {'business operati...,threat,['https://twitter.com/i/web/status/10353012422...,https://twitter.com/i/web/status/1035301242271...,True,"[business and industrial, business operations,...",[operating systems]
4,b'5b887713bb325e65fa7e78e8',2018-08-30 23:00:35+00:00,1035301282095853569,CVE-2018-1000532 (beep)\nhttps://t.co/CaKbo38U...,{'created_at': 'Thu Aug 30 23:00:35 +0000 2018...,vulnerability,{'technology and computing': {'computer securi...,threat,['https://web.nvd.nist.gov/view/vuln/detail?vu...,https://nvd.nist.gov/vuln/detail/CVE-2018-1000532,True,"[technology and computing, computer security, ...","[computer security, hardware, software]"


In [6]:
# Second-level targets, used only inside the "software" subset.
FIX_TARGETS_SOFTWARE = {
    'databases': 'databases'
}

# Keep only tweets whose Watson categories include "software".
threat_tweets_software = threat_tweets[
    threat_tweets['watson_list'].apply(lambda categories: 'software' in categories)
].reset_index(drop=True)

_software_keys = set(FIX_TARGETS_SOFTWARE)

# For each tweet: its software sub-targets, or ['other'] if none match.
# sorted() keeps the label order deterministic across interpreter runs; set
# iteration order depends on the string-hash seed.
chosen_categories = [
    sorted({FIX_TARGETS_SOFTWARE[category] for category in _software_keys.intersection(watson_list)})
    or ['other']
    for watson_list in threat_tweets_software['watson_list']
]

print(len(chosen_categories))
threat_tweets_software['target'] = chosen_categories

threat_tweets_software.head()


3689


Unnamed: 0,_id,date,id,text,tweet,type,watson,annotation,urls,destination_url,valid_certificate,watson_list,target
0,b'5b8876f9bb325e65fa7e78e4',2018-08-30 23:00:08+00:00,1035301167952211969,Protect your customers access Prestashop Ant...,{'created_at': 'Thu Aug 30 23:00:08 +0000 2018...,ddos,{'technology and computing': {'internet techno...,threat,['http://addons.prestashop.com/en/23513-anti-d...,https://addons.prestashop.com/en/23513-anti-dd...,True,"[technology and computing, internet technology...",[other]
1,b'5b887713bb325e65fa7e78e8',2018-08-30 23:00:35+00:00,1035301282095853569,CVE-2018-1000532 (beep)\nhttps://t.co/CaKbo38U...,{'created_at': 'Thu Aug 30 23:00:35 +0000 2018...,vulnerability,{'technology and computing': {'computer securi...,threat,['https://web.nvd.nist.gov/view/vuln/detail?vu...,https://nvd.nist.gov/vuln/detail/CVE-2018-1000532,True,"[technology and computing, computer security, ...",[databases]
2,b'5b887a29bb325e65fa7e78fa',2018-08-30 23:13:45+00:00,1035304594677460992,'Insight' into Home Automation Reveals Vulnera...,{'created_at': 'Thu Aug 30 23:13:45 +0000 2018...,vulnerability,{'technology and computing': {'computer securi...,threat,"['http://bit.ly/2wyOWAp', 'https://twitter.com...",https://www.mcafee.com:443/blogs/?utm_content=...,True,"[technology and computing, computer security, ...",[other]
3,b'5b887d7abb325e65fa7e790e',2018-08-30 23:27:53+00:00,1035308152747712512,@BitcoinGuruInfo The only way to prove a vulne...,{'created_at': 'Thu Aug 30 23:27:53 +0000 2018...,vulnerability,"{'technology and computing': {'software': {}, ...",threat,['https://twitter.com/i/web/status/10353081527...,https://twitter.com/i/web/status/1035308152747...,True,"[technology and computing, software, operating...",[other]
4,b'5b887ebfbb325e65fa7e791d',2018-08-30 23:33:18+00:00,1035309516756320257,"""Immunity Debugger 1.85 Denial Of Service"" htt...",{'created_at': 'Thu Aug 30 23:33:18 +0000 2018...,general,{'technology and computing': {'computer securi...,threat,['https://ift.tt/2wqE48u'],https://packetstormsecurity.com/files/149164/i...,True,"[technology and computing, computer security, ...",[other]


In [7]:
# Per-label frequency. Every label of a multi-label sample is counted, not just the first.
Counter(label for labels in chosen_categories for label in labels)


Counter({'other': 2187, 'databases': 1502})

In [8]:
# Label-set (powerset) frequency. Sorting makes equal sets map to the same tuple
# regardless of the order their labels were listed in.
Counter(tuple(sorted(item)) for item in chosen_categories)


Counter({('other',): 2187, ('databases',): 1502})

In [9]:
# Turn each software tweet's text into features using the 300-d GloVe vectors
# (presumably one feature row per tweet; see preprocess_functions.preprocess_texts).
software_texts = threat_tweets_software['text']
X = preprocess_texts(
    embedding_dim=300,
    model_path=GLOVE_6B_300D_TXT,
    list_str=software_texts,
)


## 4. Model Training

We will now train different models and evaluate their performance.


In [10]:
# One joblib archive per technique (paths are relative to the working directory).
# NOTE(review): these paths ignore MODELS_PATH from the config cell, so on Colab the
# cache lives in the session's working directory rather than on Drive. Confirm intent.
PATH_BR_SOFTWARE = Path('models/binary_problems/br_software.pkl')
PATH_CLR_SOFTWARE = Path('models/binary_problems/clr_software.pkl')
PATH_CC_SOFTWARE = Path('models/binary_problems/cc_software.pkl')
PATH_LP_SOFTWARE = Path('models/multiclass_problems/lp_software.pkl')
PATH_PST_SOFTWARE = Path('models/multiclass_problems/pst_software.pkl')
PATH_CDN_SOFTWARE = Path('models/ensembles/cdn_software.pkl')
PATH_MBR_SOFTWARE = Path('models/ensembles/mbr_software.pkl')


def _load_cached(path: Path):
    """Return the object joblib-pickled at ``path``, or ``None`` if the file is absent.

    ``joblib.load`` unpickles arbitrary objects, so only point it at files this
    notebook produced itself.
    """
    return joblib.load(path) if path.exists() else None


# ``None`` means "not cached yet"; the training cells then fit and save the technique.
br = _load_cached(PATH_BR_SOFTWARE)
clr = _load_cached(PATH_CLR_SOFTWARE)
cc = _load_cached(PATH_CC_SOFTWARE)
lp = _load_cached(PATH_LP_SOFTWARE)
pst = _load_cached(PATH_PST_SOFTWARE)
cdn = _load_cached(PATH_CDN_SOFTWARE)
mbr = _load_cached(PATH_MBR_SOFTWARE)

# joblib.dump (used by the training cells) does not create missing directories.
# Create every cache folder up front, or a fresh checkout fails on the first save.
for _cache_path in (
    PATH_BR_SOFTWARE, PATH_CLR_SOFTWARE, PATH_CC_SOFTWARE,
    PATH_LP_SOFTWARE, PATH_PST_SOFTWARE, PATH_CDN_SOFTWARE, PATH_MBR_SOFTWARE,
):
    _cache_path.parent.mkdir(parents=True, exist_ok=True)


In [11]:
# Binary indicator matrix: one column per label, in the order of mlb.classes_.
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(y=threat_tweets_software['target'])

# Label Powerset encoding: every distinct indicator row becomes one class id.
unique_label_sets, target_lp_ids = np.unique(
    ar=y,
    axis=0,
    return_inverse=True
)
# NumPy 2.0.0 returned a non-1-D inverse when `axis` is given (reverted in 2.0.1).
# reshape(-1) gives the flat per-row class ids on every NumPy version.
threat_tweets_software['target_lp'] = target_lp_ids.reshape(-1)

# Pruned Sets data: label sets seen fewer than `pruning_threshold` times are pruned,
# with subsampling capped at `max_sub_samples` (see utils.prune_and_subsample).
X_pst, y_pst, label_map_pst, _ = prune_and_subsample(
    x=X,
    y=y,
    pruning_threshold=10,
    max_sub_samples=3000
)

y_lp = threat_tweets_software['target_lp']

# class id -> indicator tuple, used to decode LP targets/predictions back to rows of `y`.
label_map_lp = {i: tuple(lbl_set) for i, lbl_set in enumerate(unique_label_sets)}


In [12]:
# BR, CDR, CC, CDN, MBR
X_train_val, y_train_val, X_test, y_test = iterative_train_test_split(
    X=X,
    y=y,
    test_size=TEST_SIZE
)

X_train, y_train, X_val, y_val = iterative_train_test_split(
    X=X_train_val,
    y=y_train_val,
    test_size=TEST_SIZE
)

# LP
X_train_val_lp, X_test_lp, y_train_val_lp, y_test_lp = train_test_split(
    X, y_lp,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
    shuffle=True,
    stratify=y_lp
)

X_train_lp, X_val_lp, y_train_lp, y_val_lp = train_test_split(
    X_train_val_lp, y_train_val_lp,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
    shuffle=True,
    stratify=y_train_val_lp
)

# PSt
X_train_val_pst, X_test_pst, y_train_val_pst, y_test_pst = train_test_split(
    X_pst, y_pst,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
    shuffle=True,
    stratify=y_pst
)

X_train_pst, X_val_pst, y_train_pst, y_val_pst = train_test_split(
    X_train_val_pst, y_train_val_pst,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
    shuffle=True,
    stratify=y_train_val_pst
)


### 4.1. Binary Problems


#### 4.1.1. BR (Binary Relevance)


In [13]:
# Binary Relevance: one independent binary model per label, for each base classifier.
if not br or OVERWRITE:
    br = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        # OneVsRestClassifier clones `estimator` for each label, so the shared
        # BASE_CLASSIFIERS instance itself is never fitted here.
        br[k] = OneVsRestClassifier(estimator=v).fit(
            X=X_train,
            y=y_train
        )

    # joblib.dump does not create missing folders; make sure the cache directory exists.
    PATH_BR_SOFTWARE.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(br, PATH_BR_SOFTWARE, compress=9)


100%|██████████| 5/5 [00:16<00:00,  3.30s/it]


#### 4.1.2. CLR (Calibrated Label Ranking)


In [14]:
# Calibrated Label Ranking, trained once per base classifier.
if not clr or OVERWRITE:
    clr = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        # Hand each technique its own unfitted copy of the base estimator. If the
        # wrapper fits `classifier` in place (not visible here), the shared
        # BASE_CLASSIFIERS instances would otherwise be aliased/refitted across techniques.
        model = CalibratedLabelRankClassifier(
            classifier=clone(v),
            classes=mlb.classes_,
            random_state=RANDOM_STATE
        )

        # CLR takes label-name lists, so decode each indicator row into its class names.
        clr[k] = model.fit(
            x=X_train,
            y=[list(mlb.classes_[np.where(row == 1)[0]]) for row in y_train]
        )

    # joblib.dump does not create missing folders; make sure the cache directory exists.
    PATH_CLR_SOFTWARE.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(clr, PATH_CLR_SOFTWARE, compress=9)


100%|██████████| 5/5 [00:24<00:00,  4.95s/it]


#### 4.1.3. CC (Classifier Chains)


In [15]:
# Classifier Chains, trained once per base classifier.
if not cc or OVERWRITE:
    cc = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        # Pass an unfitted copy so no BASE_CLASSIFIERS instance can be shared between
        # the stored models of different techniques.
        model = ChainOfClassifiers(
            classifier=clone(v),
            classes=mlb.classes_,
            random_state=RANDOM_STATE
        )

        cc[k] = model.fit(
            x=X_train,
            y=y_train
        )

    # joblib.dump does not create missing folders; make sure the cache directory exists.
    PATH_CC_SOFTWARE.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(cc, PATH_CC_SOFTWARE, compress=9)


100%|██████████| 5/5 [00:12<00:00,  2.50s/it]


### 4.2. Multi-class Problems



#### 4.2.1. LP (Label Powerset)


In [16]:
# Label Powerset: one multi-class model over the label-set ids, per base classifier.
if not lp or OVERWRITE:
    lp = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        # An unfitted copy keeps LP's fitted estimator separate from PSt's, which is
        # built by the same wrapper class from the same BASE_CLASSIFIERS entry.
        model = LabelPowersetClassifier(
            classifier=clone(v),
            label_map=label_map_lp,
            random_state=RANDOM_STATE
        )

        lp[k] = model.fit(
            x=X_train_lp,
            y=y_train_lp
        )

    # joblib.dump does not create missing folders; make sure the cache directory exists.
    PATH_LP_SOFTWARE.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(lp, PATH_LP_SOFTWARE, compress=9)


100%|██████████| 5/5 [00:08<00:00,  1.79s/it]


#### 4.2.2. PSt (Pruned Sets)


In [17]:
# Pruned Sets: Label Powerset trained on the pruned/subsampled data.
if not pst or OVERWRITE:
    pst = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        # An unfitted copy keeps PSt's fitted estimator separate from LP's, which is
        # built by the same wrapper class from the same BASE_CLASSIFIERS entry.
        model = LabelPowersetClassifier(
            classifier=clone(v),
            label_map=label_map_pst,
            random_state=RANDOM_STATE
        )

        pst[k] = model.fit(
            x=X_train_pst,
            y=y_train_pst
        )

    # joblib.dump does not create missing folders; make sure the cache directory exists.
    PATH_PST_SOFTWARE.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(pst, PATH_PST_SOFTWARE, compress=9)


100%|██████████| 5/5 [00:08<00:00,  1.76s/it]


### 4.3. Ensembles


#### 4.3.1. CDN (Conditional Dependency Network)


In [18]:
# Conditional Dependency Network, trained once per base classifier.
if not cdn or OVERWRITE:
    cdn = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        # Unfitted copy so no BASE_CLASSIFIERS instance is shared across techniques.
        # num_iterations/burn_in are passed straight to the project wrapper
        # (presumably Gibbs-sampling settings; see utils.ConditionalDependencyNetwork).
        model = ConditionalDependencyNetwork(
            classifier=clone(v),
            num_iterations=100,
            burn_in=10
        )

        cdn[k] = model.fit(
            x=X_train,
            y=y_train
        )

    # joblib.dump does not create missing folders; make sure the cache directory exists.
    PATH_CDN_SOFTWARE.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(cdn, PATH_CDN_SOFTWARE, compress=9)


100%|██████████| 5/5 [00:07<00:00,  1.42s/it]


#### 4.3.2. MBR (Meta-Binary Relevance)


In [19]:
# Meta-Binary Relevance, trained once per base classifier.
if not mbr or OVERWRITE:
    mbr = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        # Unfitted copy so no BASE_CLASSIFIERS instance is shared across techniques.
        model = MetaBinaryRelevance(
            classifier=clone(v),
            use_cross_val=True,
            n_splits=5
        )

        mbr[k] = model.fit(
            x=X_train,
            y=y_train
        )

    # joblib.dump does not create missing folders; make sure the cache directory exists.
    PATH_MBR_SOFTWARE.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(mbr, PATH_MBR_SOFTWARE, compress=9)


100%|██████████| 5/5 [01:42<00:00, 20.59s/it]


## 5. Model Evaluation

Now that we've trained the models, let's evaluate them in more detail.


In [20]:
# The powerset techniques predict class ids, so decode their validation targets
# back into label-indicator rows before scoring.
lp_val_indicators = np.array([list(label_map_lp[yp]) for yp in y_val_lp])
pst_val_indicators = np.array([list(label_map_pst[yp]) for yp in y_val_pst])

# technique name -> (validation features, indicator targets, fitted models per base classifier)
evaluation_inputs = {
    'BR': (X_val, y_val, br),
    'CLR': (X_val, y_val, clr),
    'CC': (X_val, y_val, cc),
    'LP': (X_val_lp, lp_val_indicators, lp),
    'PST': (X_val_pst, pst_val_indicators, pst),
    'CDN': (X_val, y_val, cdn),
    'MBR': (X_val, y_val, mbr),
}

evaluation = {
    name: assess_models(x=x_eval, y=y_eval, technique=fitted)
    for name, (x_eval, y_eval, fitted) in evaluation_inputs.items()
}


In [21]:
# One row per technique. Each per-technique column mixes floats, strings and model
# objects, so after the transpose every column is object dtype. infer_objects()
# restores float dtype for the metric columns, which enables numeric operations
# such as idxmax/styling.
performances = pd.DataFrame(evaluation).T.infer_objects()
performances


Unnamed: 0,Accuracy,Classifier,Model,Precision example-based,Recall example-based,F1 example-based,Hamming loss,Micro precision,Micro recall,Micro F1,Macro precision,Macro recall,Macro F1,Coverage
BR,0.886441,random_forest,OneVsRestClassifier(estimator=RandomForestClas...,0.886441,0.886441,0.886441,0.111017,0.890971,0.886441,0.8887,0.899267,0.870893,0.881337,1.113559
CLR,0.898305,random_forest,CalibratedLabelRankClassifier(classes=array(['...,0.898305,0.898305,0.898305,0.097458,0.905983,0.898305,0.902128,0.920446,0.880893,0.894921,1.101695
CC,0.901695,random_forest,"ChainOfClassifiers(classes=array(['databases',...",0.901695,0.901695,0.901695,0.098305,0.901695,0.901695,0.901695,0.917422,0.88375,0.894508,1.098305
LP,0.901861,logistic_regression,LabelPowersetClassifier(classifier=LogisticReg...,0.901861,0.901861,0.901861,0.098139,0.901861,0.901861,0.901861,0.897185,0.90099,0.898909,1.098139
PST,0.901861,logistic_regression,LabelPowersetClassifier(classifier=LogisticReg...,0.901861,0.901861,0.901861,0.098139,0.901861,0.901861,0.901861,0.897185,0.90099,0.898909,1.098139
CDN,0.832203,logistic_regression,ConditionalDependencyNetwork(classifier=Logist...,0.833051,0.833898,0.833333,0.164407,0.836735,0.833898,0.835314,0.829861,0.831845,0.830436,1.167797
MBR,0.901695,xgb,MetaBinaryRelevance(classifier=XGBClassifier(b...,0.901695,0.901695,0.901695,0.098305,0.901695,0.901695,0.901695,0.904115,0.891607,0.896679,1.098305


In [23]:
# Per-label report for the model assess_models selected for LP
# (presumably its best-scoring base classifier), on the LP validation rows.
lp_selected_model = evaluation['LP']['Model']
z = lp_selected_model.predict(X_val_lp)
lp_val_truth = np.array([list(label_map_lp[yp]) for yp in y_val_lp])

lp_report = classification_report(
    y_true=lp_val_truth,
    y_pred=z,
    target_names=mlb.classes_,
    zero_division=0,
)
print(lp_report)


              precision    recall  f1-score   support

   databases       0.87      0.90      0.88       241
       other       0.93      0.91      0.92       350

   micro avg       0.90      0.90      0.90       591
   macro avg       0.90      0.90      0.90       591
weighted avg       0.90      0.90      0.90       591
 samples avg       0.90      0.90      0.90       591

