# Multi-Label Classification with Machine Learning
In this notebook, we explore several machine learning approaches to a multi-label classification problem: each cybersecurity-related tweet can belong to several technology categories at once. We evaluate and compare the performance of the different techniques on the dataset.


In [1]:
import ast
import json
import random
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.tree import DecisionTreeClassifier
from skmultilearn.model_selection import iterative_train_test_split
from tqdm import tqdm
from xgboost import XGBClassifier

from preprocess_functions import build_tree, extract_keys, merge_all_trees_with_counts, preprocess_texts
from utils import CalibratedLabelRankClassifier, ChainOfClassifiers, LabelPowersetClassifier, \
    assess_models, prune_and_subsample, ConditionalDependencyNetwork, MetaBinaryRelevance


In [2]:
# Global switches. Set OVERWRITE to True to retrain (and re-save) every model
# instead of reusing the cached pickles.
OVERWRITE = False
RANDOM_STATE = 42

# Seed both global RNGs used in this notebook so a full re-run is reproducible.
for _seed_fn in (np.random.seed, random.seed):
    _seed_fn(RANDOM_STATE)


In [3]:
# Search budget and hold-out fraction.
INIT_POINTS = 1
N_ITER = 5
TEST_SIZE = 0.2  # fraction held out at each split stage

# Untrained base estimators shared by every multi-label technique below.
BASE_CLASSIFIERS = {
    'logistic_regression': LogisticRegression(solver='liblinear', random_state=RANDOM_STATE),
    'gaussian_nb': GaussianNB(),
    'decision_tree': DecisionTreeClassifier(random_state=RANDOM_STATE),
    'random_forest': RandomForestClassifier(random_state=RANDOM_STATE),
    'xgb': XGBClassifier(random_state=RANDOM_STATE)
}

COLAB_PATH = Path('/content/drive/MyDrive')
KAGGLE_PATH = Path('/kaggle/input')
LOCAL_PATH = Path('./')

# Resolve the data / model roots for the current runtime. Google Colab is
# detected first, then Kaggle; otherwise fall back to the local working directory.
try:
    import google.colab  # noqa: F401 -- imported only to probe for Colab
except ImportError:
    try:
        import kaggle_secrets  # noqa: F401 -- imported only to probe for Kaggle
    except ImportError:
        DATA_PATH, MODELS_PATH = LOCAL_PATH / 'data', LOCAL_PATH / 'models'
    else:
        DATA_PATH = MODELS_PATH = KAGGLE_PATH
else:
    DATA_PATH, MODELS_PATH = COLAB_PATH / 'data', COLAB_PATH / 'models'

GLOVE_6B_PATH = MODELS_PATH / 'glove-embeddings'
THREAT_TWEETS_PATH = DATA_PATH / 'tweets-dataset-for-cyberattack-detection'

GLOVE_6B_300D_TXT = GLOVE_6B_PATH / 'glove.6B.300d.txt'
THREAT_TWEETS_CSV = THREAT_TWEETS_PATH / 'tweets_final.csv'


## 1. Introduction

In this notebook, we solve a multi-label classification problem using different machine learning models. Our goal is to predict the set of technology categories each tweet belongs to, using an embedding of its text as input features.


## 2. Data Loading and Preprocessing
We will load the dataset, inspect its structure, and preprocess it for machine learning models.


In [4]:
# Read the CSV file and process columns in one step
threat_tweets = (
    pd.read_csv(filepath_or_buffer=THREAT_TWEETS_CSV)
    .assign(
        tweet=lambda df: df['tweet'].apply(func=ast.literal_eval),
        watson=lambda df: df['watson'].apply(func=ast.literal_eval)
        .apply(func=lambda x: x.get('categories', []))
        .apply(func=build_tree),
        watson_list=lambda df: df['watson'].apply(func=extract_keys),
    )
    .query(expr='relevant == True')
    .drop(labels=['relevant'], axis=1)
    .dropna(subset=['text'], ignore_index=True)
)

threat_tweets.head()


Unnamed: 0,_id,date,id,text,tweet,type,watson,annotation,urls,destination_url,valid_certificate,watson_list
0,b'5b8876f9bb325e65fa7e78e4',2018-08-30 23:00:08+00:00,1035301167952211969,Protect your customers access Prestashop Ant...,{'created_at': 'Thu Aug 30 23:00:08 +0000 2018...,ddos,{'technology and computing': {'internet techno...,threat,['http://addons.prestashop.com/en/23513-anti-d...,https://addons.prestashop.com/en/23513-anti-dd...,True,"[technology and computing, internet technology..."
1,b'5b8876f9bb325e65fa7e78e5',2018-08-30 23:00:09+00:00,1035301173178249217,Data leak from Huazhu Hotels may affect 130 mi...,{'created_at': 'Thu Aug 30 23:00:09 +0000 2018...,leak,"{'travel': {'hotels': {}}, 'home and garden': ...",threat,['http://www.hotelmanagement.net/tech/data-lea...,http://www.hotelmanagement.net/tech/data-leak-...,True,"[travel, hotels, home and garden, home improve..."
2,b'5b8876fabb325e65fa7e78e6',2018-08-30 23:00:09+00:00,1035301174583353344,Instagram App 41.1788.50991.0 #Denial Of #Serv...,{'created_at': 'Thu Aug 30 23:00:09 +0000 2018...,general,{'science': {'weather': {'meteorological disas...,threat,['https://packetstormsecurity.com/files/149120...,https://packetstormsecurity.com/files/149120/i...,True,"[science, weather, meteorological disaster, hu..."
3,b'5b88770abb325e65fa7e78e7',2018-08-30 23:00:25+00:00,1035301242271096832,(good slides): \n\nThe Advanced Exploitation o...,{'created_at': 'Thu Aug 30 23:00:25 +0000 2018...,vulnerability,{'business and industrial': {'business operati...,threat,['https://twitter.com/i/web/status/10353012422...,https://twitter.com/i/web/status/1035301242271...,True,"[business and industrial, business operations,..."
4,b'5b887713bb325e65fa7e78e8',2018-08-30 23:00:35+00:00,1035301282095853569,CVE-2018-1000532 (beep)\nhttps://t.co/CaKbo38U...,{'created_at': 'Thu Aug 30 23:00:35 +0000 2018...,vulnerability,{'technology and computing': {'computer securi...,threat,['https://web.nvd.nist.gov/view/vuln/detail?vu...,https://nvd.nist.gov/vuln/detail/CVE-2018-1000532,True,"[technology and computing, computer security, ..."


In [5]:
n_cs_tweets = len(threat_tweets)  # rows left after the relevance / text filters
print(f"Number of CS related tweets:\t{n_cs_tweets}")


Number of CS related tweets:	11112


In [6]:
# Merge all per-tweet Watson category trees into one global tree, together with
# per-node visit counts.
watson_trees = threat_tweets['watson']
general_tree, visit_count = merge_all_trees_with_counts(trees=watson_trees)


In [7]:
print("The subcategories in 'technology and computing' are:")
tech_subcategories = general_tree['technology and computing']
for subcategory in tech_subcategories:  # iterating a dict yields its keys
    print(f'· {subcategory}')


The subcategories in 'technology and computing' are:
· computer security
· internet technology
· software
· hardware
· operating systems
· data centers
· mp3 and midi
· computer reviews
· programming languages
· consumer electronics
· tech news
· networking
· electronic components
· computer crime
· enterprise technology
· computer certification
· technological innovation
· technical support


In [8]:
# Persist the merged category tree and its visit counts (most visited first) as
# pretty-printed JSON for offline inspection.
sorted_visit_count = {
    node: count
    for node, count in sorted(visit_count.items(), key=lambda pair: pair[1], reverse=True)
}

Path('general_tree.json').write_text(json.dumps(general_tree, indent=4))
Path('general_tree_visit_counts.json').write_text(json.dumps(sorted_visit_count, indent=4))


## 3. Exploratory Data Analysis (EDA)
Let's analyze the dataset and gain insights into its distribution.


In [9]:
# List the top-level (macro) Watson categories present in the merged tree.
print('The macro categories are:')
for category in general_tree:  # iterating a dict yields its keys
    print(f'· {category}')


At macro categories are:
· technology and computing
· health and fitness
· home and garden
· travel
· art and entertainment
· science
· business and industrial
· sports
· finance
· law, govt and politics
· society
· real estate
· pets
· style and fashion
· news
· hobbies and interests
· food and drink
· education
· shopping
· family and parenting
· religion and spirituality
· automotive and vehicles
· careers


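For this project, the categories of interest are listed below. In the labelling step they are collapsed onto their top-level category (computer security, operating systems, software, hardware, networking), and tweets matching none of them are labelled `other`: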
1. computer security/network security
2. computer security/antivirus and malware
3. operating systems/mac os
4. operating systems/windows
5. operating systems/unix
6. operating systems/linux
7. software
8. programming languages, included in software
9. software/databases
10. hardware
11. electronic components, included in hardware
12. hardware/computer/servers
13. hardware/computer/portable computer
14. hardware/computer/desktop computer
15. hardware/computer components
16. hardware/computer networking/router
17. hardware/computer networking/wireless technology
18. networking
19. internet technology, included in networking


In [10]:
# Coarse target labels: each Watson sub-category of interest is mapped onto the
# label it is trained as (e.g. 'programming languages' -> 'software').
FIX_TARGETS = {
    'computer security': 'computer security',
    'operating systems': 'operating systems',
    'software': 'software',
    'programming languages': 'software',
    'hardware': 'hardware',
    'electronic components': 'hardware',
    'networking': 'networking',
    'internet technology': 'networking'
}


def _to_targets(watson_categories):
    """Return the sorted, de-duplicated coarse targets for one tweet.

    Args:
        watson_categories: the tweet's flattened Watson category names
            (one entry of the ``watson_list`` column).

    Returns:
        A sorted list of FIX_TARGETS values, or ``['other']`` when none of the
        tweet's categories is of interest.
    """
    # Sorting (instead of list(set(...))) makes the label order independent of
    # per-process string-hash randomization, so the column is stable across restarts.
    targets = sorted({FIX_TARGETS[c] for c in watson_categories if c in FIX_TARGETS})
    return targets if targets else ['other']


threat_tweets['target'] = [_to_targets(cats) for cats in threat_tweets['watson_list']]

threat_tweets.head()


Unnamed: 0,_id,date,id,text,tweet,type,watson,annotation,urls,destination_url,valid_certificate,watson_list,target
0,b'5b8876f9bb325e65fa7e78e4',2018-08-30 23:00:08+00:00,1035301167952211969,Protect your customers access Prestashop Ant...,{'created_at': 'Thu Aug 30 23:00:08 +0000 2018...,ddos,{'technology and computing': {'internet techno...,threat,['http://addons.prestashop.com/en/23513-anti-d...,https://addons.prestashop.com/en/23513-anti-dd...,True,"[technology and computing, internet technology...","[computer security, networking, software]"
1,b'5b8876f9bb325e65fa7e78e5',2018-08-30 23:00:09+00:00,1035301173178249217,Data leak from Huazhu Hotels may affect 130 mi...,{'created_at': 'Thu Aug 30 23:00:09 +0000 2018...,leak,"{'travel': {'hotels': {}}, 'home and garden': ...",threat,['http://www.hotelmanagement.net/tech/data-lea...,http://www.hotelmanagement.net/tech/data-leak-...,True,"[travel, hotels, home and garden, home improve...",[other]
2,b'5b8876fabb325e65fa7e78e6',2018-08-30 23:00:09+00:00,1035301174583353344,Instagram App 41.1788.50991.0 #Denial Of #Serv...,{'created_at': 'Thu Aug 30 23:00:09 +0000 2018...,general,{'science': {'weather': {'meteorological disas...,threat,['https://packetstormsecurity.com/files/149120...,https://packetstormsecurity.com/files/149120/i...,True,"[science, weather, meteorological disaster, hu...",[hardware]
3,b'5b88770abb325e65fa7e78e7',2018-08-30 23:00:25+00:00,1035301242271096832,(good slides): \n\nThe Advanced Exploitation o...,{'created_at': 'Thu Aug 30 23:00:25 +0000 2018...,vulnerability,{'business and industrial': {'business operati...,threat,['https://twitter.com/i/web/status/10353012422...,https://twitter.com/i/web/status/1035301242271...,True,"[business and industrial, business operations,...",[operating systems]
4,b'5b887713bb325e65fa7e78e8',2018-08-30 23:00:35+00:00,1035301282095853569,CVE-2018-1000532 (beep)\nhttps://t.co/CaKbo38U...,{'created_at': 'Thu Aug 30 23:00:35 +0000 2018...,vulnerability,{'technology and computing': {'computer securi...,threat,['https://web.nvd.nist.gov/view/vuln/detail?vu...,https://nvd.nist.gov/vuln/detail/CVE-2018-1000532,True,"[technology and computing, computer security, ...","[computer security, hardware, software]"


In [11]:
# Embed each tweet's text with the GloVe 6B vectors; the dimension must match
# the vector file referenced by GLOVE_6B_300D_TXT.
GLOVE_EMBEDDING_DIM = 300
X = preprocess_texts(
    list_str=threat_tweets['text'],
    model_path=GLOVE_6B_300D_TXT,
    embedding_dim=GLOVE_EMBEDDING_DIM,
)


## 4. Model Training

We will now train different models and evaluate their performance.


In [12]:
def _load_if_cached(path):
    """Return the joblib-pickled object stored at ``path``, or None if it does not exist."""
    # NOTE: joblib.load unpickles arbitrary objects -- only load files you produced yourself.
    return joblib.load(path) if Path(path).exists() else None


# A None entry means "not cached": the corresponding training cell below will fit it.
br = _load_if_cached('models/br.pkl')
clr = _load_if_cached('models/clr.pkl')
cc = _load_if_cached('models/cc.pkl')
lp = _load_if_cached('models/lp.pkl')
pst = _load_if_cached('models/pst.pkl')
cdn = _load_if_cached('models/cdn.pkl')
mbr = _load_if_cached('models/ensembles/meta_binary_relevance.pkl')


In [13]:
# Binary indicator matrix for the multi-label techniques; columns follow mlb.classes_.
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(y=threat_tweets['target'])

# Label Powerset encoding: every distinct row of `y` (label combination) becomes one
# class id. reshape(-1) keeps the inverse 1-D on every NumPy version: NumPy 2.0.0
# changed its shape when `axis` is given (reverted in 2.0.1), and a 2-D inverse
# cannot be stored in a single DataFrame column.
unique_label_sets, lp_class_ids = np.unique(ar=y, axis=0, return_inverse=True)
threat_tweets['target mcc'] = lp_class_ids.reshape(-1)

# Pruned Sets encoding: rare label combinations are pruned / sub-sampled by the helper
# (see prune_and_subsample for the exact rule behind these thresholds).
threat_tweets['target mcc pruned'], label_map_pst = prune_and_subsample(y, pruning_threshold=5, max_sub_samples=3)

y_lp = threat_tweets['target mcc']
y_pst = threat_tweets['target mcc pruned']

# Decode LP class id -> tuple of 0/1 indicators (same column order as mlb.classes_).
label_map_lp = {i: tuple(lbl_set) for i, lbl_set in enumerate(unique_label_sets)}


In [14]:
# Two-stage multi-label stratified split: TEST_SIZE of all samples is held out as the
# test set, then TEST_SIZE of the remainder becomes the validation set.
# NOTE(review): no seed is passed here; presumably any randomness comes from the global
# NumPy RNG seeded in the config cell -- confirm in skmultilearn.
X_train_val, y_train_val, X_test, y_test = iterative_train_test_split(X=X, y=y, test_size=TEST_SIZE)
X_train, y_train, X_val, y_val = iterative_train_test_split(X=X_train_val, y=y_train_val, test_size=TEST_SIZE)


In [15]:
# Label Powerset split: stratify on the powerset class id so each label combination
# keeps its share in train / validation / test. Note this partition differs from the
# iterative split used by BR / CLR / CC / CDN / MBR.
_split_kwargs = dict(test_size=TEST_SIZE, random_state=RANDOM_STATE, shuffle=True)

X_train_val_mcc, X_test_mcc, y_train_val_mcc, y_test_mcc = train_test_split(
    X, y_lp, stratify=y_lp, **_split_kwargs
)
X_train_mcc, X_val_mcc, y_train_mcc, y_val_mcc = train_test_split(
    X_train_val_mcc, y_train_val_mcc, stratify=y_train_val_mcc, **_split_kwargs
)


In [16]:
# Pruned Sets split: two seeded stages (test, then validation), each stratified on the
# pruned class ids produced by prune_and_subsample.
X_train_val_mcc_pruned, X_test_mcc_pruned, y_train_val_mcc_pruned, y_test_mcc_pruned = train_test_split(
    X, y_pst,
    stratify=y_pst,
    shuffle=True,
    random_state=RANDOM_STATE,
    test_size=TEST_SIZE,
)

X_train_mcc_pruned, X_val_mcc_pruned, y_train_mcc_pruned, y_val_mcc_pruned = train_test_split(
    X_train_val_mcc_pruned, y_train_val_mcc_pruned,
    stratify=y_train_val_mcc_pruned,
    shuffle=True,
    random_state=RANDOM_STATE,
    test_size=TEST_SIZE,
)


### 4.1. Binary Problems


#### 4.1.1. BR (Binary Relevance)


In [17]:
# Binary Relevance: one independent binary classifier per label. Skipped when a cached
# model dict was loaded, unless OVERWRITE is set.
if not br or OVERWRITE:
    br = {
        name: OneVsRestClassifier(estimator=base).fit(X=X_train, y=y_train)
        for name, base in tqdm(BASE_CLASSIFIERS.items())
    }

    # Create the cache directory on first use (fresh checkout / new Colab runtime);
    # otherwise joblib.dump fails only after all the training work is done.
    cache_file = 'models/br.pkl'
    Path(cache_file).parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(br, cache_file, compress=9)


#### 4.1.2. CLR (Calibrated Label Ranking)


In [18]:
# Calibrated Label Ranking, one model per base classifier. Skipped when a cached model
# dict was loaded, unless OVERWRITE is set.
if not clr or OVERWRITE:
    clr = {}

    # CLR is fitted on label-name lists rather than the indicator matrix; this does not
    # depend on the base classifier, so build it once instead of once per iteration.
    y_train_label_lists = [list(mlb.classes_[np.where(row == 1)[0]]) for row in y_train]

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        # NOTE(review): the same BASE_CLASSIFIERS instance is handed to every technique;
        # confirm the wrapper clones it, or models in different dicts may share state.
        model = CalibratedLabelRankClassifier(
            classifier=v,
            classes=mlb.classes_,
            random_state=RANDOM_STATE
        )

        clr[k] = model.fit(x=X_train, y=y_train_label_lists)

    # Create the cache directory on first use so the dump cannot fail after training.
    cache_file = 'models/clr.pkl'
    Path(cache_file).parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(clr, cache_file, compress=9)


#### 4.1.3. CC (Classifier Chains)


In [19]:
# Classifier Chains, one chain per base classifier. Skipped when a cached model dict
# was loaded, unless OVERWRITE is set.
if not cc or OVERWRITE:
    cc = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        # NOTE(review): the same BASE_CLASSIFIERS instance is shared across techniques;
        # confirm ChainOfClassifiers clones it before fitting.
        model = ChainOfClassifiers(
            classifier=v,
            classes=mlb.classes_,
            random_state=RANDOM_STATE
        )

        cc[k] = model.fit(x=X_train, y=y_train)

    # Create the cache directory on first use so the dump cannot fail after training.
    cache_file = 'models/cc.pkl'
    Path(cache_file).parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(cc, cache_file, compress=9)


### 4.2. Multi-class Problems



#### 4.2.1. LP (Label Powerset)


In [20]:
# Label Powerset: each label combination is one class (ids from `target mcc`), decoded
# back to indicator rows through label_map_lp. Skipped when a cached model dict was
# loaded, unless OVERWRITE is set.
if not lp or OVERWRITE:
    lp = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        # NOTE(review): the same BASE_CLASSIFIERS instance is shared across techniques;
        # confirm LabelPowersetClassifier clones it before fitting.
        model = LabelPowersetClassifier(
            classifier=v,
            label_map=label_map_lp,
            random_state=RANDOM_STATE
        )

        lp[k] = model.fit(x=X_train_mcc, y=y_train_mcc)

    # Create the cache directory on first use so the dump cannot fail after training.
    cache_file = 'models/lp.pkl'
    Path(cache_file).parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(lp, cache_file, compress=9)


#### 4.2.2. PSt (Pruned Sets)


In [21]:
# Pruned Sets: Label Powerset trained on the pruned / sub-sampled class ids and decoded
# through label_map_pst. Skipped when a cached model dict was loaded, unless OVERWRITE is set.
if not pst or OVERWRITE:
    pst = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        # NOTE(review): the same BASE_CLASSIFIERS instance is shared across techniques;
        # confirm LabelPowersetClassifier clones it before fitting.
        model = LabelPowersetClassifier(
            classifier=v,
            label_map=label_map_pst,
            random_state=RANDOM_STATE
        )

        pst[k] = model.fit(x=X_train_mcc_pruned, y=y_train_mcc_pruned)

    # Create the cache directory on first use so the dump cannot fail after training.
    cache_file = 'models/pst.pkl'
    Path(cache_file).parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(pst, cache_file, compress=9)


### 4.3. Ensembles


#### 4.3.1. CDN (Conditional Dependency Network)


In [22]:
# Conditional Dependency Network, one per base classifier (100 iterations, first 10
# discarded as burn-in). Skipped when a cached model dict was loaded, unless OVERWRITE is set.
if not cdn or OVERWRITE:
    cdn = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        # NOTE(review): the same BASE_CLASSIFIERS instance is shared across techniques;
        # confirm ConditionalDependencyNetwork clones it before fitting.
        model = ConditionalDependencyNetwork(
            classifier=v,
            num_iterations=100,
            burn_in=10
        )

        cdn[k] = model.fit(x=X_train, y=y_train)

    # Create the cache directory on first use so the dump cannot fail after training.
    cache_file = 'models/cdn.pkl'
    Path(cache_file).parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(cdn, cache_file, compress=9)


#### 4.3.2. MBR (Meta-Binary Relevance)


In [23]:
# Meta-Binary Relevance with 5-fold cross-validation, one per base classifier. Skipped
# when a cached model dict was loaded, unless OVERWRITE is set.
if not mbr or OVERWRITE:
    mbr = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        # NOTE(review): the same BASE_CLASSIFIERS instance is shared across techniques;
        # confirm MetaBinaryRelevance clones it before fitting.
        model = MetaBinaryRelevance(
            classifier=v,
            use_cross_val=True,
            n_splits=5
        )

        mbr[k] = model.fit(x=X_train, y=y_train)

    # The nested models/ensembles/ directory is missing on a fresh runtime; create it so
    # the dump cannot fail after the (long) training loop.
    cache_file = 'models/ensembles/meta_binary_relevance.pkl'
    Path(cache_file).parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(mbr, cache_file, compress=9)


100%|██████████| 5/5 [15:45<00:00, 189.09s/it]


## 5. Model Evaluation

Now that we've trained the models, let's evaluate them in more detail.


In [24]:
# LP and PSt predict class ids; decode the validation targets back into indicator rows
# so every technique is scored against the same kind of target.
# Caveat: LP / PSt are validated on their own train_test_split partitions, not on the
# iterative split used by the other techniques, so the samples scored are not the same.
y_val_lp_indicator = np.array([list(label_map_lp[cls_id]) for cls_id in y_val_mcc])
y_val_pst_indicator = np.array([list(label_map_pst[cls_id]) for cls_id in y_val_mcc_pruned])

# technique name -> (validation features, validation targets, dict of fitted models)
_eval_inputs = {
    'BR': (X_val, y_val, br),
    'CLR': (X_val, y_val, clr),
    'CC': (X_val, y_val, cc),
    'LP': (X_val_mcc, y_val_lp_indicator, lp),
    'PST': (X_val_mcc_pruned, y_val_pst_indicator, pst),
    'CDN': (X_val, y_val, cdn),
    'MBR': (X_val, y_val, mbr),
}

evaluation = {
    name: assess_models(x=x_eval, y=y_eval, technique=models)
    for name, (x_eval, y_eval, models) in _eval_inputs.items()
}


In [28]:
# Persist the per-technique evaluation results alongside the cached models.
# (Previously this dumped `mbr`, so evaluation.pkl silently contained the MBR models.)
if OVERWRITE:
    cache_file = 'models/evaluation.pkl'
    Path(cache_file).parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(evaluation, cache_file, compress=9)


In [25]:
# One row per technique. Transposing a frame whose columns mix floats, strings and model
# objects leaves every resulting column with dtype object, so describe() /
# select_dtypes('number') see no numeric metrics; infer_objects() restores float64 for them.
performances = pd.DataFrame(evaluation).T.infer_objects()
performances


Unnamed: 0,Accuracy,Classifier,Model,Precision example-based,Recall example-based,F1 example-based,Hamming loss,Micro precision,Micro recall,Micro F1,Macro precision,Macro recall,Macro F1,Coverage
BR,0.59322,xgb,OneVsRestClassifier(estimator=XGBClassifier(ba...,0.742844,0.703766,0.710414,0.099718,0.860736,0.708589,0.777287,0.88861,0.61961,0.715834,3.129379
CLR,0.615254,xgb,CalibratedLabelRankClassifier(classes=array(['...,0.767363,0.749341,0.744641,0.095292,0.843078,0.751917,0.794893,0.848015,0.674025,0.74303,2.934463
CC,0.668362,xgb,ChainOfClassifiers(classes=array(['computer se...,0.816902,0.790678,0.789699,0.094539,0.826282,0.778758,0.801816,0.825839,0.704932,0.750153,2.716384
LP,0.693476,xgb,LabelPowersetClassifier(classifier=XGBClassifi...,0.810086,0.782902,0.785883,0.094488,0.828466,0.773129,0.799841,0.851387,0.71253,0.764076,2.676603
PST,0.677165,xgb,LabelPowersetClassifier(classifier=XGBClassifi...,0.799775,0.770247,0.773641,0.099644,0.819123,0.759693,0.788289,0.840448,0.68634,0.741929,2.745782
CDN,0.559322,xgb,ConditionalDependencyNetwork(classifier=XGBCla...,0.7129,0.684746,0.685386,0.13371,0.743443,0.695552,0.7187,0.749439,0.621048,0.669159,3.240678
MBR,0.623164,xgb,MetaBinaryRelevance(classifier=XGBClassifier(b...,0.767608,0.740772,0.7412,0.092938,0.858787,0.743865,0.797206,0.86605,0.675409,0.750002,2.951412


In [26]:
# Per-label report for the best-scoring LP model on its validation split.
lp_model = evaluation['LP']['Model']
z = lp_model.predict(X_val_mcc)

# Decode the LP class ids once and reuse them for both metrics.
y_val_lp_true = np.array([list(label_map_lp[cls_id]) for cls_id in y_val_mcc])
acc = accuracy_score(y_val_lp_true, z)  # subset accuracy for indicator targets

report = classification_report(y_true=y_val_lp_true, y_pred=z, target_names=mlb.classes_, zero_division=0)
print(report)
print(acc)


                   precision    recall  f1-score   support

computer security       0.84      0.90      0.87       930
         hardware       0.90      0.54      0.68       355
       networking       0.88      0.58      0.69       146
operating systems       0.95      0.68      0.79       170
            other       0.69      0.82      0.75       361
         software       0.85      0.76      0.81       643

        micro avg       0.83      0.77      0.80      2605
        macro avg       0.85      0.71      0.76      2605
     weighted avg       0.84      0.77      0.80      2605
      samples avg       0.81      0.78      0.79      2605

0.6934758155230596
