# Multi-Class Classification with Machine Learning
In this notebook, we will explore various machine learning models to solve a multi-class classification problem. We will evaluate and compare the performance of different algorithms on the dataset.


In [41]:
import ast
import json
import random
from collections import Counter
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.tree import DecisionTreeClassifier
from skmultilearn.model_selection import iterative_train_test_split
from tqdm import tqdm
from xgboost import XGBClassifier

from preprocess_functions import build_tree, extract_keys, merge_all_trees_with_counts, preprocess_texts
from utils import CalibratedLabelRankClassifier, ChainOfClassifiers, LabelPowersetClassifier, \
    assess_models, prune_and_subsample, ConditionalDependencyNetwork, MetaBinaryRelevance


In [2]:
OVERWRITE = True
RANDOM_STATE = 42

np.random.seed(RANDOM_STATE)
random.seed(RANDOM_STATE)


In [3]:
TEST_SIZE = 2e-1

BASE_CLASSIFIERS = {
    'logistic_regression': LogisticRegression(solver='liblinear', random_state=RANDOM_STATE),
    'gaussian_nb': GaussianNB(),
    'decision_tree': DecisionTreeClassifier(random_state=RANDOM_STATE),
    'random_forest': RandomForestClassifier(random_state=RANDOM_STATE),
    'xgb': XGBClassifier(random_state=RANDOM_STATE)
}

COLAB_PATH = Path('/content/drive/MyDrive')
KAGGLE_PATH = Path('/kaggle/input')
LOCAL_PATH = Path('./')

# Step 1: Check if running in Google Colab
try:
    import google.colab

    DATA_PATH = COLAB_PATH / Path('data')
    MODELS_PATH = COLAB_PATH / Path('models')
except ImportError:
    # Step 2: Check if running in Kaggle
    try:
        import kaggle_secrets

        DATA_PATH = KAGGLE_PATH
        MODELS_PATH = KAGGLE_PATH
    except ImportError:
        # Step 3: Default to local Jupyter Notebook
        DATA_PATH = LOCAL_PATH / Path('data')
        MODELS_PATH = LOCAL_PATH / Path('models')

GLOVE_6B_PATH = MODELS_PATH / Path('glove-embeddings')
THREAT_TWEETS_PATH = DATA_PATH / Path('tweets-dataset-for-cyberattack-detection')

GLOVE_6B_300D_TXT = GLOVE_6B_PATH / Path('glove.6B.300d.txt')
THREAT_TWEETS_CSV = THREAT_TWEETS_PATH / Path('tweets_final.csv')


## 1. Data Loading and Preprocessing

Load the dataset, inspect its structure, and preprocess it for machine learning models.


In [4]:
# Read the CSV file and process columns in one step
threat_tweets = (
    pd.read_csv(filepath_or_buffer=THREAT_TWEETS_CSV)
    .assign(
        tweet=lambda df: df['tweet'].apply(func=ast.literal_eval),
        watson=lambda df: df['watson'].apply(func=ast.literal_eval)
        .apply(func=lambda x: x.get('categories', []))
        .apply(func=build_tree),
        watson_list=lambda df: df['watson'].apply(func=extract_keys),
    )
    .query(expr='relevant == True')
    .drop(labels=[
        'relevant', '_id', 'date',
        'id', 'tweet', 'type',
        'annotation', 'urls', 'destination_url',
        'valid_certificate'
    ], axis=1)
    .dropna(subset=['text'], ignore_index=True)
)

threat_tweets.head()


Unnamed: 0,text,watson,watson_list
0,Protect your customers access Prestashop Ant...,{'technology and computing': {'internet techno...,"[technology and computing, internet technology..."
1,Data leak from Huazhu Hotels may affect 130 mi...,"{'travel': {'hotels': {}}, 'home and garden': ...","[travel, hotels, home and garden, home improve..."
2,Instagram App 41.1788.50991.0 #Denial Of #Serv...,{'science': {'weather': {'meteorological disas...,"[science, weather, meteorological disaster, hu..."
3,(good slides): \n\nThe Advanced Exploitation o...,{'business and industrial': {'business operati...,"[business and industrial, business operations,..."
4,CVE-2018-1000532 (beep)\nhttps://t.co/CaKbo38U...,{'technology and computing': {'computer securi...,"[technology and computing, computer security, ..."


In [5]:
print(f"Number of CS related tweets:\t{len(threat_tweets)}")


Number of CS related tweets:	11112


In [6]:
general_tree, visit_count = merge_all_trees_with_counts(trees=threat_tweets['watson'])


In [7]:
print("The subcategories in 'technology and computing' are:")
for category in list(general_tree['technology and computing'].keys()):
    print(f'· {category}')


The subcategories in 'technology and computing' are:
· computer security
· internet technology
· software
· hardware
· operating systems
· data centers
· mp3 and midi
· computer reviews
· programming languages
· consumer electronics
· tech news
· networking
· electronic components
· computer crime
· enterprise technology
· computer certification
· technological innovation
· technical support


In [8]:
sorted_visit_count = dict(sorted(visit_count.items(), key=lambda item: item[1], reverse=True))

with open('general_tree.json', 'w') as file:
    file.write(json.dumps(general_tree, indent=4))

with open('general_tree_visit_counts.json', 'w') as file:
    file.write(json.dumps(sorted_visit_count, indent=4))


## 2. Exploratory Data Analysis (EDA)

Analyze the dataset and gain insights into its distribution.


In [9]:
print('At macro categories are:')
for category in list(general_tree.keys()):
    print(f'· {category}')


At macro categories are:
· technology and computing
· health and fitness
· home and garden
· travel
· art and entertainment
· science
· business and industrial
· sports
· finance
· law, govt and politics
· society
· real estate
· pets
· style and fashion
· news
· hobbies and interests
· food and drink
· education
· shopping
· family and parenting
· religion and spirituality
· automotive and vehicles
· careers


For the goal of the project, the categories of interest are:
1. computer security/network security
2. computer security/antivirus and malware
3. operating systems/mac os
4. operating systems/windows
5. operating systems/unix
6. operating systems/linux
7. software
8. programming languages, included in software
9. software/databases
10. hardware
11. electronic components, included in hardware
12. hardware/computer/servers
13. hardware/computer/portable computer
14. hardware/computer/desktop computer
15. hardware/computer components
16. hardware/computer networking/router
17. hardware/computer networking/wireless technology
18. networking
19. internet technology, included in networking


In [10]:
FIX_TARGETS = {
    'computer security': 'computer security',
    'operating systems': 'operating systems',
    'software': 'software',
    'programming languages': 'software',
    'hardware': 'hardware',
    'electronic components': 'hardware',
    'networking': 'networking',
    'internet technology': 'networking'
}

chosen_categories = [
    list(set(FIX_TARGETS.keys()) & set(s))
    for s in threat_tweets['watson_list']
]

for i, watson_list in enumerate(chosen_categories):
    temp = list(set([FIX_TARGETS[c] for c in watson_list]))
    if len(temp) < 1:
        temp = ['other']
    chosen_categories[i] = temp

threat_tweets['target'] = chosen_categories

threat_tweets.head()


Unnamed: 0,text,watson,watson_list,target
0,Protect your customers access Prestashop Ant...,{'technology and computing': {'internet techno...,"[technology and computing, internet technology...","[networking, software, computer security]"
1,Data leak from Huazhu Hotels may affect 130 mi...,"{'travel': {'hotels': {}}, 'home and garden': ...","[travel, hotels, home and garden, home improve...",[other]
2,Instagram App 41.1788.50991.0 #Denial Of #Serv...,{'science': {'weather': {'meteorological disas...,"[science, weather, meteorological disaster, hu...",[hardware]
3,(good slides): \n\nThe Advanced Exploitation o...,{'business and industrial': {'business operati...,"[business and industrial, business operations,...",[operating systems]
4,CVE-2018-1000532 (beep)\nhttps://t.co/CaKbo38U...,{'technology and computing': {'computer securi...,"[technology and computing, computer security, ...","[hardware, software, computer security]"


In [11]:
counts_classes = Counter([item[0] for item in chosen_categories])
counts_classes


Counter({'computer security': 2930,
         'software': 2830,
         'other': 2255,
         'hardware': 2216,
         'networking': 738,
         'operating systems': 143})

In [12]:
counts_targets = Counter(tuple(item) for item in chosen_categories)
counts_targets


Counter({('computer security',): 2798,
         ('other',): 2255,
         ('software', 'computer security'): 1439,
         ('software',): 803,
         ('hardware',): 611,
         ('hardware', 'computer security'): 540,
         ('hardware', 'software'): 441,
         ('software', 'operating systems'): 392,
         ('hardware', 'software', 'computer security'): 269,
         ('networking',): 208,
         ('software', 'operating systems', 'computer security'): 196,
         ('networking', 'software', 'computer security'): 179,
         ('networking', 'software'): 179,
         ('networking', 'computer security'): 148,
         ('operating systems',): 143,
         ('computer security', 'operating systems'): 132,
         ('hardware', 'software', 'operating systems'): 77,
         ('hardware', 'computer security', 'networking'): 70,
         ('hardware', 'networking'): 68,
         ('hardware', 'operating systems'): 66,
         ('hardware', 'software', 'networking'): 38,
         (

In [13]:
X = preprocess_texts(
    list_str=threat_tweets['text'],
    model_path=GLOVE_6B_300D_TXT,
    embedding_dim=300
)


## 4. Model Training

We will now train different models and evaluate their performance.


In [14]:
br = None
clr = None
cc = None
lp = None
pst = None
cdn = None
mbr = None

PATH_BR_GENERAL = Path('models/binary_problems/br_general.pkl')
PATH_CLR_GENERAL = Path('models/binary_problems/clr_general.pkl')
PATH_CC_GENERAL = Path('models/binary_problems/cc_general.pkl')

PATH_LP_GENERAL = Path('models/multiclass_problems/lp_general.pkl')
PATH_PST_GENERAL = Path('models/multiclass_problems/pst_general.pkl')

PATH_CDN_GENERAL = Path('models/ensembles/cdn_general.pkl')
PATH_MBR_GENERAL = Path('models/ensembles/mbr_general.pkl')

if PATH_BR_GENERAL.exists():
    br = joblib.load(PATH_BR_GENERAL)

if PATH_CLR_GENERAL.exists():
    clr = joblib.load(PATH_CLR_GENERAL)

if PATH_CC_GENERAL.exists():
    cc = joblib.load(PATH_CC_GENERAL)

if PATH_LP_GENERAL.exists():
    lp = joblib.load(PATH_LP_GENERAL)

if PATH_PST_GENERAL.exists():
    pst = joblib.load(PATH_PST_GENERAL)

if PATH_CDN_GENERAL.exists():
    cdn = joblib.load(PATH_CDN_GENERAL)

if PATH_MBR_GENERAL.exists():
    mbr = joblib.load(PATH_MBR_GENERAL)


In [15]:
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(y=threat_tweets['target'])

unique_label_sets, y_lp = np.unique(ar=y, axis=0, return_inverse=True)
label_map_lp = {i: tuple(lbl_set) for i, lbl_set in enumerate(unique_label_sets)}

X_pst, y_pst, label_map_pst, _ = prune_and_subsample(
    x=X,
    y=y,
    pruning_threshold=np.median(np.array(list(counts_targets.values()))) * .25,
    max_sub_samples=round(np.median(np.array(list(counts_targets.values()))) * .25)
)


In [16]:
X_train_val, y_train_val, X_test, y_test = iterative_train_test_split(
    X=X,
    y=y,
    test_size=TEST_SIZE
)

X_train, y_train, X_val, y_val = iterative_train_test_split(
    X=X_train_val,
    y=y_train_val,
    test_size=TEST_SIZE
)

X_train_val_lp, X_test_lp, y_train_val_lp, y_test_lp = train_test_split(
    X, y_lp,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
    shuffle=True,
    stratify=y_lp
)

X_train_lp, X_val_lp, y_train_lp, y_val_lp = train_test_split(
    X_train_val_lp, y_train_val_lp,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
    shuffle=True,
    stratify=y_train_val_lp
)

X_train_val_pst, X_test_pst, y_train_val_pst, y_test_pst = train_test_split(
    X_pst, y_pst,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
    shuffle=True,
    stratify=y_pst
)

X_train_pst, X_val_pst, y_train_pst, y_val_pst = train_test_split(
    X_train_val_pst, y_train_val_pst,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
    shuffle=True,
    stratify=y_train_val_pst
)


### 4.1. Binary Problems


#### 4.1.1. BR (Binary Relevance)


In [17]:
if not br or OVERWRITE:
    br = {}

    for k in tqdm(BASE_CLASSIFIERS.keys()):
        br[k] = OneVsRestClassifier(estimator=BASE_CLASSIFIERS[k]).fit(
            X=X_train,
            y=y_train
        )

    joblib.dump(br, PATH_BR_GENERAL, compress=9)


100%|██████████| 5/5 [02:39<00:00, 31.97s/it]


#### 4.1.2. CLR (Calibrated Label Ranking)


In [18]:
if not clr or OVERWRITE:
    clr = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        model = CalibratedLabelRankClassifier(
            classifier=v,
            classes=mlb.classes_,
            random_state=RANDOM_STATE
        )

        clr[k] = model.fit(
            x=X_train,
            y=[list(mlb.classes_[np.where(row == 1)[0]]) for row in y_train]
        )

    joblib.dump(clr, PATH_CLR_GENERAL, compress=9)


100%|██████████| 5/5 [04:47<00:00, 57.59s/it]


#### 4.1.3. CC (Classifier Chains)


In [19]:
if not cc or OVERWRITE:
    cc = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        model = ChainOfClassifiers(
            classifier=v,
            classes=mlb.classes_,
            random_state=RANDOM_STATE
        )

        cc[k] = model.fit(
            x=X_train,
            y=y_train
        )

    joblib.dump(cc, PATH_CC_GENERAL, compress=9)


100%|██████████| 5/5 [02:05<00:00, 25.01s/it]


### 4.2. Multi-class Problems



#### 4.2.1. LP (Label Powerset)


In [20]:
if not lp or OVERWRITE:
    lp = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        model = LabelPowersetClassifier(
            classifier=v,
            label_map=label_map_lp,
            random_state=RANDOM_STATE
        )

        lp[k] = model.fit(
            x=X_train_lp,
            y=y_train_lp
        )

    joblib.dump(lp, PATH_LP_GENERAL, compress=9)


100%|██████████| 5/5 [03:21<00:00, 40.31s/it]


#### 4.2.2. PSt (Pruned Sets)


In [21]:
if not pst or OVERWRITE:
    pst = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        model = LabelPowersetClassifier(
            classifier=v,
            label_map=label_map_pst,
            random_state=RANDOM_STATE
        )

        pst[k] = model.fit(
            x=X_train_pst,
            y=y_train_pst
        )

    joblib.dump(pst, PATH_PST_GENERAL, compress=9)


100%|██████████| 5/5 [04:20<00:00, 52.14s/it] 


### 4.3. Ensembles


#### 4.3.1. CDN (Conditional Dependency Network)


In [22]:
if not cdn or OVERWRITE:
    cdn = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        model = ConditionalDependencyNetwork(
            classifier=v,
            num_iterations=100,
            burn_in=10
        )

        cdn[k] = model.fit(
            x=X_train,
            y=y_train
        )

    joblib.dump(cdn, PATH_CDN_GENERAL, compress=9)


100%|██████████| 5/5 [02:25<00:00, 29.17s/it]


#### 4.3.2. MBR (Meta-Binary Relevance)


In [23]:
if not mbr or OVERWRITE:
    mbr = {}

    for k, v in tqdm(BASE_CLASSIFIERS.items()):
        model = MetaBinaryRelevance(
            classifier=v,
            use_cross_val=True,
            n_splits=5
        )

        mbr[k] = model.fit(
            x=X_train,
            y=y_train
        )

    joblib.dump(mbr, PATH_MBR_GENERAL, compress=9)


100%|██████████| 5/5 [16:41<00:00, 200.21s/it]


## 5. Model Evaluation

Now that we've trained the models, let's evaluate them in more detail.


In [26]:
[str(yp) for yp in label_map_pst.values()]

['(1, 0, 1, 0, 0, 1)',
 '(0, 0, 0, 0, 1, 0)',
 '(0, 1, 0, 0, 0, 0)',
 '(0, 0, 0, 1, 0, 0)',
 '(1, 1, 0, 0, 0, 1)',
 '(1, 1, 0, 0, 0, 0)',
 '(1, 0, 0, 1, 0, 0)',
 '(1, 0, 0, 0, 0, 0)',
 '(1, 0, 0, 0, 0, 1)',
 '(0, 0, 0, 1, 0, 1)',
 '(0, 0, 0, 0, 0, 1)',
 '(1, 0, 1, 0, 0, 0)',
 '(0, 0, 1, 0, 0, 0)',
 '(0, 1, 0, 0, 0, 1)',
 '(0, 1, 1, 0, 0, 0)',
 '(0, 0, 1, 0, 0, 1)',
 '(1, 1, 1, 0, 0, 0)',
 '(1, 0, 0, 1, 0, 1)',
 '(0, 1, 0, 1, 0, 0)',
 '(0, 1, 0, 1, 0, 1)']

In [29]:
evaluation = {
    'BR': assess_models(
        x=X_val,
        y=y_val,
        technique=br,
        classes=mlb.classes_
    ),
    'CLR': assess_models(
        x=X_val,
        y=y_val,
        technique=clr,
        classes=mlb.classes_
    ),
    'CC': assess_models(
        x=X_val,
        y=y_val,
        technique=cc,
        classes=mlb.classes_
    ),
    'LP': assess_models(
        x=X_val_lp,
        y=np.array([list(label_map_lp[yp]) for yp in y_val_lp]),
        technique=lp,
        classes=mlb.classes_
    ),
    'PST': assess_models(
        x=X_val_pst,
        y=np.array([list(label_map_pst[yp]) for yp in y_val_pst]),
        technique=pst,
        classes=mlb.classes_
    ),
    'CDN': assess_models(
        x=X_val,
        y=y_val,
        technique=cdn,
        classes=mlb.classes_
    ),
    'MBR': assess_models(
        x=X_val,
        y=y_val,
        technique=mbr,
        classes=mlb.classes_
    )
}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize

In [30]:
if OVERWRITE:
    joblib.dump(mbr, 'models/evaluation_general.pkl', compress=9)


In [31]:
performances = pd.DataFrame(evaluation).T
performances


Unnamed: 0,Accuracy,Classifier,Model,Precision example-based,Recall example-based,F1 example-based,Hamming loss,Micro precision,Micro recall,Micro F1,Macro precision,Macro recall,Macro F1,Coverage,Classification
BR,0.59322,xgb,"OneVsRestClassifier(estimator=XGBClassifier(base_score=None, booster=None,\n callbacks=None,\n colsample_bylevel=None,\n colsample_bynode=None,\n colsample_bytree=None, device=None,\n early_stopping_rounds=None,\n enable_categorical=False,\n eval_metric=None,\n feature_types=None, gamma=None,\n grow_policy=None,\n importance_type=None,\n interaction_constraints=None,\n learning_rate=None, max_bin=None,\n max_cat_threshold=None,\n max_cat_to_onehot=None,\n max_delta_step=None, max_depth=None,\n max_leaves=None,\n min_child_weight=None, missing=nan,\n monotone_constraints=None,\n multi_strategy=None,\n n_estimators=None, n_jobs=None,\n num_parallel_tree=None,\n random_state=42, ...))",0.742844,0.703766,0.710414,0.099718,0.860736,0.708589,0.777287,0.88861,0.61961,0.715834,3.129379,precision recall f1-score support\n\ncomputer security 0.85 0.87 0.86 931\n hardware 0.87 0.50 0.63 355\n networking 0.95 0.42 0.58 146\noperating systems 0.92 0.60 0.73 171\n other 0.90 0.56 0.69 361\n software 0.84 0.78 0.81 644\n\n micro avg 0.86 0.71 0.78 2608\n macro avg 0.89 0.62 0.72 2608\n weighted avg 0.87 0.71 0.77 2608\n samples avg 0.74 0.70 0.71 2608\n
CLR,0.615254,xgb,"CalibratedLabelRankClassifier(classes=array(['computer security', 'hardware', 'networking', 'operating systems',\n 'other', 'software'], dtype=object),\n classifier=XGBClassifier(base_score=None,\n booster=None,\n callbacks=None,\n colsample_bylevel=None,\n colsample_bynode=None,\n colsample_bytree=None,\n device=None,\n early_stopping_rounds=None,\n enable_categorical=False,\n eval_metr...\n grow_policy=None,\n importance_type=None,\n interaction_constraints=None,\n learning_rate=None,\n max_bin=None,\n max_cat_threshold=None,\n max_cat_to_onehot=None,\n max_delta_step=None,\n max_depth=None,\n max_leaves=None,\n min_child_weight=None,\n missing=nan,\n monotone_constraints=None,\n multi_strategy=None,\n n_estimators=None,\n n_jobs=None,\n num_parallel_tree=None,\n random_state=42, ...),\n random_state=42)",0.767363,0.749341,0.744641,0.095292,0.843078,0.751917,0.794893,0.848015,0.674025,0.74303,2.934463,precision recall f1-score support\n\ncomputer security 0.86 0.88 0.87 931\n hardware 0.82 0.57 0.67 355\n networking 0.85 0.47 0.61 146\noperating systems 0.87 0.65 0.74 171\n other 0.87 0.66 0.75 361\n software 0.81 0.81 0.81 644\n\n micro avg 0.84 0.75 0.79 2608\n macro avg 0.85 0.67 0.74 2608\n weighted avg 0.84 0.75 0.79 2608\n samples avg 0.77 0.75 0.74 2608\n
CC,0.668362,xgb,"ChainOfClassifiers(classes=array(['computer security', 'hardware', 'networking', 'operating systems',\n 'other', 'software'], dtype=object),\n classifier=XGBClassifier(base_score=None, booster=None,\n callbacks=None,\n colsample_bylevel=None,\n colsample_bynode=None,\n colsample_bytree=None, device=None,\n early_stopping_rounds=None,\n enable_categorical=False,\n eval_metric=None,\n fea...\n grow_policy=None,\n importance_type=None,\n interaction_constraints=None,\n learning_rate=None, max_bin=None,\n max_cat_threshold=None,\n max_cat_to_onehot=None,\n max_delta_step=None, max_depth=None,\n max_leaves=None,\n min_child_weight=None, missing=nan,\n monotone_constraints=None,\n multi_strategy=None,\n n_estimators=None, n_jobs=None,\n num_parallel_tree=None,\n random_state=42, ...),\n random_state=42)",0.816902,0.790678,0.789699,0.094539,0.826282,0.778758,0.801816,0.825839,0.704932,0.750153,2.716384,precision recall f1-score support\n\ncomputer security 0.86 0.89 0.87 931\n hardware 0.81 0.56 0.66 355\n networking 0.85 0.49 0.62 146\noperating systems 0.87 0.65 0.74 171\n other 0.74 0.83 0.78 361\n software 0.83 0.81 0.82 644\n\n micro avg 0.83 0.78 0.80 2608\n macro avg 0.83 0.70 0.75 2608\n weighted avg 0.83 0.78 0.80 2608\n samples avg 0.82 0.79 0.79 2608\n
LP,0.693476,xgb,"LabelPowersetClassifier(classifier=XGBClassifier(base_score=None, booster=None,\n callbacks=None,\n colsample_bylevel=None,\n colsample_bynode=None,\n colsample_bytree=None,\n device=None,\n early_stopping_rounds=None,\n enable_categorical=False,\n eval_metric=None,\n feature_types=None, gamma=None,\n grow_policy=None,\n importance_type=None,\n interaction_constraints=None,\n learning_rat...\n 4: (0, 0, 1, 0, 0, 0), 5: (0, 0, 1, 0, 0, 1),\n 6: (0, 0, 1, 1, 0, 0), 7: (0, 0, 1, 1, 0, 1),\n 8: (0, 1, 0, 0, 0, 0), 9: (0, 1, 0, 0, 0, 1),\n 10: (0, 1, 0, 1, 0, 0),\n 11: (0, 1, 0, 1, 0, 1),\n 12: (0, 1, 1, 0, 0, 0),\n 13: (0, 1, 1, 0, 0, 1),\n 14: (1, 0, 0, 0, 0, 0),\n 15: (1, 0, 0, 0, 0, 1),\n 16: (1, 0, 0, 1, 0, 0),\n 17: (1, 0, 0, 1, 0, 1),\n 18: (1, 0, 1, 0, 0, 0),\n 19: (1, 0, 1, 0, 0, 1),\n 20: (1, 0, 1, 1, 0, 0),\n 21: (1, 1, 0, 0, 0, 0),\n 22: (1, 1, 0, 0, 0, 1),\n 23: (1, 1, 0, 1, 0, 0),\n 24: (1, 1, 1, 0, 0, 0)},\n random_state=42)",0.810086,0.782902,0.785883,0.094488,0.828466,0.773129,0.799841,0.851387,0.71253,0.764076,2.676603,precision recall f1-score support\n\ncomputer security 0.84 0.90 0.87 930\n hardware 0.90 0.54 0.68 355\n networking 0.88 0.58 0.69 146\noperating systems 0.95 0.68 0.79 170\n other 0.69 0.82 0.75 361\n software 0.85 0.76 0.81 643\n\n micro avg 0.83 0.77 0.80 2605\n macro avg 0.85 0.71 0.76 2605\n weighted avg 0.84 0.77 0.80 2605\n samples avg 0.81 0.78 0.79 2605\n
PST,0.532133,xgb,"LabelPowersetClassifier(classifier=XGBClassifier(base_score=None, booster=None,\n callbacks=None,\n colsample_bylevel=None,\n colsample_bynode=None,\n colsample_bytree=None,\n device=None,\n early_stopping_rounds=None,\n enable_categorical=False,\n eval_metric=None,\n feature_types=None, gamma=None,\n grow_policy=None,\n importance_type=None,\n interaction_constraints=None,\n learning_rat...\n label_map={0: (1, 0, 1, 0, 0, 1), 1: (0, 0, 0, 0, 1, 0),\n 2: (0, 1, 0, 0, 0, 0), 3: (0, 0, 0, 1, 0, 0),\n 4: (1, 1, 0, 0, 0, 1), 5: (1, 1, 0, 0, 0, 0),\n 6: (1, 0, 0, 1, 0, 0), 7: (1, 0, 0, 0, 0, 0),\n 8: (1, 0, 0, 0, 0, 1), 9: (0, 0, 0, 1, 0, 1),\n 10: (0, 0, 0, 0, 0, 1),\n 11: (1, 0, 1, 0, 0, 0),\n 12: (0, 0, 1, 0, 0, 0),\n 13: (0, 1, 0, 0, 0, 1),\n 14: (0, 1, 1, 0, 0, 0),\n 15: (0, 0, 1, 0, 0, 1),\n 16: (1, 1, 1, 0, 0, 0),\n 17: (1, 0, 0, 1, 0, 1),\n 18: (0, 1, 0, 1, 0, 0),\n 19: (0, 1, 0, 1, 0, 1)},\n random_state=42)",0.715582,0.680993,0.681771,0.138559,0.730488,0.676263,0.70233,0.707463,0.614256,0.645762,3.292264,precision recall f1-score support\n\ncomputer security 0.78 0.86 0.82 1094\n hardware 0.62 0.55 0.58 615\n networking 0.56 0.46 0.51 333\noperating systems 0.76 0.34 0.47 342\n other 0.74 0.76 0.75 361\n software 0.78 0.72 0.75 798\n\n micro avg 0.73 0.68 0.70 3543\n macro avg 0.71 0.61 0.65 3543\n weighted avg 0.73 0.68 0.69 3543\n samples avg 0.72 0.68 0.68 3543\n
CDN,0.560452,xgb,"ConditionalDependencyNetwork(classifier=XGBClassifier(base_score=None,\n booster=None,\n callbacks=None,\n colsample_bylevel=None,\n colsample_bynode=None,\n colsample_bytree=None,\n device=None,\n early_stopping_rounds=None,\n enable_categorical=False,\n eval_metric=None,\n feature_types=None,\n gamma=None,\n grow_policy=None,\n importance_type=None,\n interaction_constraints=None,\n learning_rate=None,\n max_bin=None,\n max_cat_threshold=None,\n max_cat_to_onehot=None,\n max_delta_step=None,\n max_depth=None,\n max_leaves=None,\n min_child_weight=None,\n missing=nan,\n monotone_constraints=None,\n multi_strategy=None,\n n_estimators=None,\n n_jobs=None,\n num_parallel_tree=None,\n random_state=42, ...))",0.707957,0.681827,0.681601,0.136347,0.73751,0.690567,0.713267,0.729757,0.609992,0.653994,3.246328,precision recall f1-score support\n\ncomputer security 0.79 0.80 0.79 931\n hardware 0.71 0.50 0.59 355\n networking 0.73 0.40 0.52 146\noperating systems 0.82 0.57 0.67 171\n other 0.53 0.62 0.57 361\n software 0.80 0.78 0.79 644\n\n micro avg 0.74 0.69 0.71 2608\n macro avg 0.73 0.61 0.65 2608\n weighted avg 0.74 0.69 0.71 2608\n samples avg 0.71 0.68 0.68 2608\n
MBR,0.623164,xgb,"MetaBinaryRelevance(classifier=XGBClassifier(base_score=None, booster=None,\n callbacks=None,\n colsample_bylevel=None,\n colsample_bynode=None,\n colsample_bytree=None, device=None,\n early_stopping_rounds=None,\n enable_categorical=False,\n eval_metric=None,\n feature_types=None, gamma=None,\n grow_policy=None,\n importance_type=None,\n interaction_constraints=None,\n learning_rate=None, max_bin=None,\n max_cat_threshold=None,\n max_cat_to_onehot=None,\n max_delta_step=None,\n max_depth=None, max_leaves=None,\n min_child_weight=None, missing=nan,\n monotone_constraints=None,\n multi_strategy=None,\n n_estimators=None, n_jobs=None,\n num_parallel_tree=None,\n random_state=42, ...),\n use_cross_val=True)",0.767608,0.740772,0.7412,0.092938,0.858787,0.743865,0.797206,0.86605,0.675409,0.750002,2.951412,precision recall f1-score support\n\ncomputer security 0.86 0.85 0.86 931\n hardware 0.83 0.55 0.66 355\n networking 0.90 0.45 0.60 146\noperating systems 0.88 0.71 0.79 171\n other 0.86 0.69 0.77 361\n software 0.85 0.80 0.82 644\n\n micro avg 0.86 0.74 0.80 2608\n macro avg 0.87 0.68 0.75 2608\n weighted avg 0.86 0.74 0.79 2608\n samples avg 0.77 0.74 0.74 2608\n


In [45]:
best_general_classifier_name = max(evaluation, key=lambda t: evaluation[t]['Accuracy'])
best_general_classifier = evaluation[best_general_classifier_name]
y_val_hat = best_general_classifier['Model'].predict(X_val_lp)
if best_general_classifier_name == 'LP':
    y_true = np.array([list(label_map_lp[yp]) for yp in y_val_lp])
elif best_general_classifier_name == 'PST':
    y_true = np.array([list(label_map_pst[yp]) for yp in y_val_pst])
else:
    y_true = y_val

print(f"\nBest Method:\t{best_general_classifier_name}\n")
print(classification_report(
    y_true=np.array([list(label_map_lp[yp]) for yp in y_val_lp]),
    y_pred=y_val_hat,
    target_names=mlb.classes_,
    zero_division=0
))

print(f"Accuracy:\t{accuracy_score(y_true=y_true, y_pred=y_val_hat):.4f}")
print(f"AUC:\t{roc_auc_score(y_true=y_true, y_score=y_val_hat):.4f}")



Best Method:	LP

                   precision    recall  f1-score   support

computer security       0.84      0.90      0.87       930
         hardware       0.90      0.54      0.68       355
       networking       0.88      0.58      0.69       146
operating systems       0.95      0.68      0.79       170
            other       0.69      0.82      0.75       361
         software       0.85      0.76      0.81       643

        micro avg       0.83      0.77      0.80      2605
        macro avg       0.85      0.71      0.76      2605
     weighted avg       0.84      0.77      0.80      2605
      samples avg       0.81      0.78      0.79      2605

Accuracy:	0.6935
AUC:	0.8243
