# Multi-Label Classification with Machine Learning
In this notebook, we explore several machine learning models for a multi-label classification problem: each tweet can carry more than one target category. We evaluate and compare the performance of the different algorithms on the dataset.


In [1]:
import ast
import json
import random
import re
import string
from collections import defaultdict
from copy import copy
from itertools import combinations
from pathlib import Path
from typing import Union, Any

import numpy as np
import pandas
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.tree import DecisionTreeClassifier
from skmultilearn.model_selection import iterative_train_test_split
from xgboost import XGBClassifier


In [2]:
# Single seed shared by every seeded component of the notebook.
RANDOM_STATE = 42

# Seed both global generators (standard library and NumPy) so that shuffles
# drawn from them repeat across "Restart & Run All".
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)


In [3]:
# NOTE(review): INIT_POINTS and N_ITER are not used in the visible part of the
# notebook; presumably a search budget for a later tuning step — confirm.
INIT_POINTS = 1
N_ITER = 5
# Fraction held out by each call to iterative_train_test_split.
TEST_SIZE = 0.2

# Unfitted prototype estimators, looked up by name by the training helpers,
# which copy them before fitting.
BASE_CLASSIFIERS = {
    'logistic_regression': LogisticRegression(
        solver='liblinear',
        random_state=RANDOM_STATE,
    ),
    'gaussian_naive_bayes': GaussianNB(),
    'decision_tree': DecisionTreeClassifier(random_state=RANDOM_STATE),
    'random_forest': RandomForestClassifier(random_state=RANDOM_STATE),
    'xgb': XGBClassifier(random_state=RANDOM_STATE),
}

# Root folders for the three supported runtimes.
COLAB_PATH = Path('/content/drive/MyDrive')
KAGGLE_PATH = Path('/kaggle/input')
LOCAL_PATH = Path('./')

# Pick the data/model roots by probing for runtime-specific packages:
# Google Colab first, then Kaggle, otherwise a local Jupyter session.
try:
    import google.colab
except ImportError:
    try:
        import kaggle_secrets
    except ImportError:
        DATA_PATH = LOCAL_PATH / 'data'
        MODELS_PATH = LOCAL_PATH / 'models'
    else:
        # On Kaggle both datasets and models live directly under the input root.
        DATA_PATH = KAGGLE_PATH
        MODELS_PATH = KAGGLE_PATH
else:
    DATA_PATH = COLAB_PATH / 'data'
    MODELS_PATH = COLAB_PATH / 'models'

# Dataset and embedding folders, followed by the concrete files read below.
GLOVE_6B_PATH = MODELS_PATH / 'glove-embeddings'
THREAT_TWEETS_PATH = DATA_PATH / 'tweets-dataset-for-cyberattack-detection'

GLOVE_6B_300D_TXT = GLOVE_6B_PATH / 'glove.6B.300d.txt'
THREAT_TWEETS_CSV = THREAT_TWEETS_PATH / 'tweets_final.csv'


## Functions


### Preprocessing


In [4]:
def extract_keys(d, path=None):
    """
    Collect every key of a nested dictionary into one flat list.

    Keys are gathered depth-first in dictionary iteration order: a key is
    followed by the keys of its sub-dictionary, then by its next sibling.

    Parameters
    ----------
    d : dict or any
        The (possibly nested) dictionary to walk. A non-dict value is
        treated as a leaf.
    path : list, optional
        Keys accumulated so far; used by the recursion. Default is None,
        which starts from an empty list.

    Returns
    -------
    list
        Flat list of the keys encountered. Note that when a non-dict leaf is
        reached the result is ``[leaf]`` only, i.e. everything accumulated
        up to that point is replaced by the leaf value.
    """
    collected = [] if path is None else path

    # A non-dict leaf discards the accumulated keys and stands alone.
    if not isinstance(d, dict):
        return [d]

    for name, child in d.items():
        # The recursive call returns the running list extended by `name`
        # and by every key found below it.
        collected = extract_keys(child, collected + [name])

    return collected


def build_tree(categories):
    """
    Turn slash-separated category labels into a nested dictionary.

    Parameters
    ----------
    categories : list of dict
        Category records; each must have a 'label' entry holding a
        path-like string such as '/a/b/c'.

    Returns
    -------
    dict
        Nested dictionary with one level per path segment. Labels sharing a
        prefix share the corresponding branch; leaves are empty dicts.
    """
    root = {}
    for entry in categories:
        # Leading/trailing slashes are dropped before splitting into segments.
        segments = entry['label'].strip('/').split('/')
        node = root
        for segment in segments:
            if segment not in node:
                node[segment] = {}
            node = node[segment]
    return root


def merge_trees_with_counts(tree1, tree2, visit_count):
    """
    Merge ``tree2`` into ``tree1`` recursively and count the visits to each node.

    Every key of ``tree2`` (at any depth) increments its counter exactly once
    per call, whether or not the key already existed in ``tree1``. Counters are
    keyed by the node name only, so equally named nodes under different
    parents share one counter.

    Parameters
    ----------
    tree1 : dict
        The tree that receives the merge; it is modified in place.
    tree2 : dict
        The tree to merge in. It is only read: new branches are rebuilt inside
        ``tree1`` instead of being attached by reference, so later merges never
        modify ``tree2``.
    visit_count : defaultdict(int)
        A dictionary that tracks the visit count for each node name; updated
        in place.

    Returns
    -------
    dict
        ``tree1`` after the merge.
    """
    for key, value in tree2.items():
        if isinstance(value, dict):
            # Create a fresh branch for unseen keys and descend into it, so
            # that the descendants are counted too and tree2 is not aliased.
            if key not in tree1:
                tree1[key] = {}
            if isinstance(tree1[key], dict):
                merge_trees_with_counts(tree1[key], value, visit_count)
        elif key not in tree1:
            # Non-dict leaves are stored as-is; an existing entry wins.
            tree1[key] = value

        # Count visits for the node
        visit_count[key] += 1
    return tree1


def merge_all_trees_with_counts(trees):
    """
    Fold a collection of trees into a single tree while counting node visits.

    Parameters
    ----------
    trees : list of dict
        The trees (nested dictionaries) to combine. They must be
        JSON-serialisable, because each one is copied through a JSON
        round-trip before merging.

    Returns
    -------
    (dict, defaultdict)
        The combined tree, and the mapping from node name to the visit count
        accumulated by ``merge_trees_with_counts``.

    """
    counts = defaultdict(int)
    merged = {}

    for source in trees:
        # JSON round-trip: detaches the copy from the caller's tree and
        # orders its keys alphabetically at every level.
        detached = json.loads(json.dumps(source, sort_keys=True))
        merged = merge_trees_with_counts(tree1=merged, tree2=detached, visit_count=counts)

    return merged, counts


In [5]:
def load_word2vec_dict(model_path: Path, embedding_dim: int) -> dict[str, np.ndarray]:
    """
    Load a whitespace-separated text embedding file into a dictionary.

    Each line is expected to hold a token followed by ``embedding_dim``
    numbers. The last ``embedding_dim`` fields are parsed as the vector and
    everything before them is the token, so tokens that themselves contain
    spaces are re-joined with a single space.

    Parameters
    ----------
    model_path : Path
        Location of the UTF-8 encoded embedding text file.
    embedding_dim : int
        Number of vector components stored on each line.

    Returns
    -------
    dict
        Mapping from token to its vector (1-D float array of length
        ``embedding_dim``). Blank lines and lines with too few fields to hold
        both a token and a full vector are skipped.
    """
    embeddings_dict = {}

    # The context manager closes the file even if a line fails to parse.
    with open(model_path, 'r', encoding='utf-8') as embeddings_file:
        for line in embeddings_file:
            values = line.split()

            # Need at least one token field in front of the vector.
            if len(values) <= embedding_dim:
                continue

            word = ' '.join(values[:-embedding_dim])
            vector = np.asarray([float(val) for val in values[-embedding_dim:]])
            embeddings_dict[word] = vector

    return embeddings_dict


def preprocess_texts(list_str, model_path, embedding_dim, word2vec_dict=None):
    """
    Embed each text as the sum of the vectors of its known tokens.

    A text is split into runs of word characters and single punctuation
    characters; each token is lower-cased and looked up in the embedding
    dictionary. Tokens without an embedding contribute nothing, so a text
    with no known token maps to the zero vector.

    Parameters
    ----------
    list_str : sequence of str
        The texts to embed (anything supporting ``len`` and iteration).
    model_path : Path
        Embedding file passed to ``load_word2vec_dict``; ignored when
        ``word2vec_dict`` is given.
    embedding_dim : int
        Length of the embedding vectors.
    word2vec_dict : dict, optional
        Pre-loaded mapping from token to vector. When omitted (the default),
        the embeddings are loaded from ``model_path``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(list_str), embedding_dim)`` with one summed
        embedding per text.

    Raises
    ------
    TypeError
        If an element of ``list_str`` is not a string (e.g. a missing value).
        The position and value are reported instead of silently returning
        ``None``.
    """
    if word2vec_dict is None:
        word2vec_dict = load_word2vec_dict(
            model_path=model_path,
            embedding_dim=embedding_dim
        )

    # Compiled once: the pattern does not depend on the text being processed.
    token_pattern = re.compile(r'\w+|[{}]'.format(re.escape(string.punctuation)))

    list_embedded_str = np.zeros((len(list_str), embedding_dim))
    for i, text in enumerate(list_str):
        if not isinstance(text, str):
            raise TypeError(
                f"preprocess_texts expects strings, got {type(text).__name__} "
                f"at position {i}: {text!r}"
            )
        for token in token_pattern.findall(text):
            vector = word2vec_dict.get(token.lower())
            if vector is not None:
                list_embedded_str[i] += vector
    return list_embedded_str


### Training


In [6]:
def train_clr_model(X, y, base_learner):
    """
    Train a CLR model using a base learner on a multi-label dataset.

    Two families of binary classifiers are fitted:
    one per label pair, trained only on samples that carry exactly one of the
    two labels (target 1 when the first label of the pair is present), and one
    per label, trained on all samples to predict whether that label is present.

    Parameters:
    - X: Feature matrix (n_samples x n_features)
    - y: List of lists (each sublist contains the labels for a sample)
    - base_learner: Key into BASE_CLASSIFIERS selecting the prototype estimator;
      a copy of it is fitted for every classifier.

    Returns:
    - pairwise_classifiers: Dictionary of trained pairwise classifiers {(l_i, l_j): model, ...}
    - artificial_classifiers: Dictionary of artificial label classifiers {l: model, ...}

    Label indices refer to positions in the sorted list of unique labels found in y.
    Pairs for which no sample separates the two labels get no classifier.
    """
    X = np.asarray(X)

    # Get all unique labels
    all_labels = sorted(set(label for labels in y for label in labels))
    num_labels = len(all_labels)

    # indicator[s, l] is True when sample s carries all_labels[l]. Built once
    # and reused by every classifier below; the reshape keeps it 2-D even when
    # there are no samples or no labels.
    indicator = np.array(
        [[label in labels for label in all_labels] for labels in y],
        dtype=bool
    ).reshape(len(y), num_labels)

    pairwise_classifiers = {}
    artificial_classifiers = {}

    # Pairwise classifiers
    for (l_i, l_j) in combinations(range(num_labels), 2):
        # Keep samples where exactly one label of the pair is present.
        separates = indicator[:, l_i] != indicator[:, l_j]

        if separates.any():  # Train only if dataset is non-empty
            X_pair = X[separates]
            # Among the kept samples, l_i present means l_i is preferred.
            y_pair = indicator[separates, l_i].astype(int)
            model = copy(BASE_CLASSIFIERS[base_learner])
            pairwise_classifiers[(l_i, l_j)] = model.fit(X=X_pair, y=y_pair)

    # Artificial label classifiers
    for l_idx in range(num_labels):
        y_artificial = indicator[:, l_idx].astype(int)
        model = copy(BASE_CLASSIFIERS[base_learner])
        artificial_classifiers[l_idx] = model.fit(X=X, y=y_artificial)

    return pairwise_classifiers, artificial_classifiers


In [7]:
def train_cc_model(X, y, base_learner, random_order=True):
    """
    Fit a chain of binary classifiers, one per label column.

    Each link is trained on the original features plus one extra column per
    earlier link, holding that link's own predictions on the training data.

    Parameters:
    - X: Feature matrix (n_samples x n_features)
    - y: Label matrix (n_samples x n_labels)
    - base_learner: Key into BASE_CLASSIFIERS selecting the estimator copied for each link
    - random_order: Whether to shuffle the label order (uses the global NumPy generator)

    Returns:
    - chain: List of fitted classifiers, in chain order
    - label_order: Array with the label column handled by each link
    """
    label_order = np.arange(y.shape[1])
    features = X.copy()

    if random_order:
        np.random.shuffle(label_order)

    chain = []
    for label_idx in label_order:
        link = copy(BASE_CLASSIFIERS[base_learner])
        link.fit(features, y[:, label_idx])
        chain.append(link)

        # The link's predictions become an additional input column for
        # every later link.
        predicted_column = link.predict(features).reshape(-1, 1)
        features = np.hstack((features, predicted_column))

    return chain, label_order


### Evaluation


In [8]:
def predict_clr_model(X, pairwise_classifiers, artificial_classifiers, num_labels):
    """
    Predict rankings and bipartitions using the trained CLR model.

    Every pairwise classifier casts one vote per instance (for l_i when it
    predicts 1, otherwise for l_j); labels are ranked by descending vote count.
    Relevance is decided by the artificial classifiers alone: a label is
    relevant when its classifier predicts 1. The ranking only fixes the order
    of the labels inside the relevant and irrelevant groups.

    Parameters:
    - X: Feature matrix for new instances (n_samples x n_features)
    - pairwise_classifiers: Trained pairwise classifiers {(l_i, l_j): model, ...}
    - artificial_classifiers: Trained artificial classifiers {l: model, ...}
    - num_labels: Total number of unique labels

    Returns:
    - Dictionary with two entries:
      'rankings': list with one array of label indices per instance, best first
      'bipartitions': list with one (relevant, irrelevant) tuple of label lists per instance

    Each classifier predicts the whole batch at once, and relevance is stored
    under the label index used as dictionary key, so the result does not depend
    on the iteration order of artificial_classifiers. A label without an
    artificial classifier is treated as irrelevant.
    """
    X = np.asarray(X)
    n_samples = X.shape[0]
    if n_samples == 0:
        # Nothing to predict; avoids calling estimators on an empty batch.
        return {'rankings': [], 'bipartitions': []}

    # 1. Predict pairwise preferences: one vote per pair and instance
    scores = np.zeros((n_samples, num_labels))
    for (l_i, l_j), clf in pairwise_classifiers.items():
        prefers_i = np.asarray(clf.predict(X)) == 1
        scores[prefers_i, l_i] += 1  # l_i is more relevant
        scores[~prefers_i, l_j] += 1  # l_j is more relevant

    # 2. Predict relevance using artificial classifiers, indexed by label
    relevance = np.zeros((n_samples, num_labels), dtype=int)
    for l_idx, clf in artificial_classifiers.items():
        relevance[:, l_idx] = np.asarray(clf.predict(X))

    rankings = []
    bipartitions = []
    for sample_scores, sample_relevance in zip(scores, relevance):
        # 3. Rank labels by descending scores
        label_ranking = np.argsort(-sample_scores)
        rankings.append(label_ranking)

        # 4. Split ranking into relevant and irrelevant labels
        relevant = [label for label in label_ranking if sample_relevance[label] == 1]
        irrelevant = [label for label in label_ranking if sample_relevance[label] == 0]
        bipartitions.append((relevant, irrelevant))

    return {'rankings': rankings, 'bipartitions': bipartitions}


def reconstruct_predictions(rankings, bipartitions, num_labels):
    """
    Build a binary prediction matrix from CLR rankings and bipartitions.

    Parameters:
    - rankings: List of rankings for each instance; only its length (number of rows) is used.
    - bipartitions: List of (relevant, irrelevant) label-index tuples, one per instance.
    - num_labels: Total number of labels (number of columns).

    Returns:
    - y_pred: Integer matrix (n_samples x num_labels), 1 where a label is in the
      instance's relevant set and 0 elsewhere.
    """
    y_pred = np.zeros((len(rankings), num_labels), dtype=int)

    # Rows are paired positionally; the ranking itself is not needed here.
    for row, (_ranking, (relevant_labels, _irrelevant)) in enumerate(zip(rankings, bipartitions)):
        for label in relevant_labels:
            y_pred[row, label] = 1

    return y_pred


In [9]:
def predict_cc_model(X, chain, label_order):
    """
    Run a trained classifier chain over new instances.

    Parameters:
    - X: Feature matrix (n_samples x n_features)
    - chain: List of trained classifiers, in chain order
    - label_order: Label column predicted by each link of the chain

    Returns:
    - y_pred: Predicted label matrix (n_samples x n_labels), columns in the
      original label positions
    """
    predictions = np.zeros((X.shape[0], len(chain)), dtype=int)
    features = X.copy()  # grows by one column per chain link

    for position, link in enumerate(chain):
        column = link.predict(features).reshape(-1, 1)

        # The link at `position` was trained for label_order[position].
        predictions[:, label_order[position]] = column.ravel()

        # Later links expect the earlier predictions as extra features.
        features = np.hstack((features, column))

    return predictions


## 1. Introduction

In this notebook, we solve a multi-label classification problem using different machine learning models. Our goal is to predict the set of target categories of each tweet from an embedding of its text.


## 2. Data Loading and Preprocessing
We will load the dataset, inspect its structure, and preprocess it for machine learning models.


In [10]:
# Read the CSV file and process columns in one step
threat_tweets = (
    pandas.read_csv(filepath_or_buffer=THREAT_TWEETS_CSV)
    .assign(
        # 'tweet' and 'watson' are parsed from their string form with
        # ast.literal_eval (literals only, no code execution).
        tweet=lambda df: df['tweet'].apply(func=ast.literal_eval),
        # Keep only the 'categories' entry of each watson record (empty list
        # when absent) and turn its labels into a nested category tree.
        watson=lambda df: df['watson'].apply(func=ast.literal_eval)
        .apply(func=lambda x: x.get('categories', []))
        .apply(func=build_tree),
        # assign evaluates keyword arguments in order, so df['watson'] here is
        # already the tree built above; flatten it into a list of node names.
        watson_list=lambda df: df['watson'].apply(func=extract_keys),
    )
    # Keep relevant tweets only, then drop the now-constant flag column.
    .query(expr='relevant == True')
    .drop(labels=['relevant'], axis=1)
    # Rows without text cannot be embedded; renumber the index afterwards.
    .dropna(subset=['text'], ignore_index=True)
)

threat_tweets.head()


Unnamed: 0,_id,date,id,text,tweet,type,watson,annotation,urls,destination_url,valid_certificate,watson_list
0,b'5b8876f9bb325e65fa7e78e4',2018-08-30 23:00:08+00:00,1035301167952211969,Protect your customers access Prestashop Ant...,{'created_at': 'Thu Aug 30 23:00:08 +0000 2018...,ddos,{'technology and computing': {'internet techno...,threat,['http://addons.prestashop.com/en/23513-anti-d...,https://addons.prestashop.com/en/23513-anti-dd...,True,"[technology and computing, internet technology..."
1,b'5b8876f9bb325e65fa7e78e5',2018-08-30 23:00:09+00:00,1035301173178249217,Data leak from Huazhu Hotels may affect 130 mi...,{'created_at': 'Thu Aug 30 23:00:09 +0000 2018...,leak,"{'travel': {'hotels': {}}, 'home and garden': ...",threat,['http://www.hotelmanagement.net/tech/data-lea...,http://www.hotelmanagement.net/tech/data-leak-...,True,"[travel, hotels, home and garden, home improve..."
2,b'5b8876fabb325e65fa7e78e6',2018-08-30 23:00:09+00:00,1035301174583353344,Instagram App 41.1788.50991.0 #Denial Of #Serv...,{'created_at': 'Thu Aug 30 23:00:09 +0000 2018...,general,{'science': {'weather': {'meteorological disas...,threat,['https://packetstormsecurity.com/files/149120...,https://packetstormsecurity.com/files/149120/i...,True,"[science, weather, meteorological disaster, hu..."
3,b'5b88770abb325e65fa7e78e7',2018-08-30 23:00:25+00:00,1035301242271096832,(good slides): \n\nThe Advanced Exploitation o...,{'created_at': 'Thu Aug 30 23:00:25 +0000 2018...,vulnerability,{'business and industrial': {'business operati...,threat,['https://twitter.com/i/web/status/10353012422...,https://twitter.com/i/web/status/1035301242271...,True,"[business and industrial, business operations,..."
4,b'5b887713bb325e65fa7e78e8',2018-08-30 23:00:35+00:00,1035301282095853569,CVE-2018-1000532 (beep)\nhttps://t.co/CaKbo38U...,{'created_at': 'Thu Aug 30 23:00:35 +0000 2018...,vulnerability,{'technology and computing': {'computer securi...,threat,['https://web.nvd.nist.gov/view/vuln/detail?vu...,https://nvd.nist.gov/vuln/detail/CVE-2018-1000532,True,"[technology and computing, computer security, ..."


In [11]:
# Rows left after keeping relevant tweets that have a text.
print(f"Number of CS related tweets:\t{len(threat_tweets)}")


Number of CS related tweets:	11112


In [12]:
# Fold the per-tweet category trees into one tree and tally node visits by name.
general_tree, visit_count = merge_all_trees_with_counts(threat_tweets['watson'])


In [13]:
# List the direct children of the 'technology and computing' branch.
print("The subcategories in 'technology and computing' are:")
for category in general_tree['technology and computing']:
    print(f'· {category}')


The subcategories in 'technology and computing' are:
· computer security
· internet technology
· software
· hardware
· operating systems
· data centers
· mp3 and midi
· computer reviews
· programming languages
· consumer electronics
· tech news
· networking
· electronic components
· computer crime
· enterprise technology
· computer certification
· technological innovation
· technical support


In [14]:
# Node names ordered from most to least visited (ties keep insertion order).
sorted_visit_count = dict(sorted(visit_count.items(), key=lambda pair: -pair[1]))

# Persist the merged tree and the ranked counts next to the notebook.
with open('general_tree.json', 'w') as tree_file:
    json.dump(general_tree, tree_file, indent=4)

with open('general_tree_visit_counts.json', 'w') as counts_file:
    json.dump(sorted_visit_count, counts_file, indent=4)


## 3. Exploratory Data Analysis (EDA)
Let's analyze the dataset and gain insights into its distribution.


In [15]:
# List the top-level branches of the merged category tree.
print('At macro categories are:')
for category in general_tree:
    print(f'· {category}')


At macro categories are:
· technology and computing
· health and fitness
· home and garden
· travel
· art and entertainment
· science
· business and industrial
· sports
· finance
· law, govt and politics
· society
· real estate
· pets
· style and fashion
· news
· hobbies and interests
· food and drink
· education
· shopping
· family and parenting
· religion and spirituality
· automotive and vehicles
· careers


For the goal of the project, the categories of interest are:
1. computer security/network security
2. computer security/antivirus and malware
3. operating systems/mac os
4. operating systems/windows
5. operating systems/unix
6. operating systems/linux
7. software
8. programming languages, included in software
9. software/databases
10. hardware
11. electronic components, included in hardware
12. hardware/computer/servers
13. hardware/computer/portable computer
14. hardware/computer/desktop computer
15. hardware/computer components
16. hardware/computer networking/router
17. hardware/computer networking/wireless technology
18. networking
19. internet technology, included in networking


In [16]:
# Maps a category node name to the target label it contributes to. Several
# nodes fold into the same target (e.g. 'programming languages' -> 'software').
FIX_TARGETS = {
    'computer security': 'computer security',
    'operating systems': 'operating systems',
    'software': 'software',
    'programming languages': 'software',
    'hardware': 'hardware',
    'electronic components': 'hardware',
    'networking': 'networking',
    'internet technology': 'networking'
}

# One label list per tweet: the distinct targets of the node names it contains,
# or ['other'] when none of them is of interest. Sorting gives every list a
# fixed order, independent of set iteration order.
chosen_categories = [
    sorted({FIX_TARGETS[key] for key in keys if key in FIX_TARGETS}) or ['other']
    for keys in threat_tweets['watson_list']
]

threat_tweets['target'] = chosen_categories
threat_tweets.head()


Unnamed: 0,_id,date,id,text,tweet,type,watson,annotation,urls,destination_url,valid_certificate,watson_list,target
0,b'5b8876f9bb325e65fa7e78e4',2018-08-30 23:00:08+00:00,1035301167952211969,Protect your customers access Prestashop Ant...,{'created_at': 'Thu Aug 30 23:00:08 +0000 2018...,ddos,{'technology and computing': {'internet techno...,threat,['http://addons.prestashop.com/en/23513-anti-d...,https://addons.prestashop.com/en/23513-anti-dd...,True,"[technology and computing, internet technology...","[computer security, networking, software]"
1,b'5b8876f9bb325e65fa7e78e5',2018-08-30 23:00:09+00:00,1035301173178249217,Data leak from Huazhu Hotels may affect 130 mi...,{'created_at': 'Thu Aug 30 23:00:09 +0000 2018...,leak,"{'travel': {'hotels': {}}, 'home and garden': ...",threat,['http://www.hotelmanagement.net/tech/data-lea...,http://www.hotelmanagement.net/tech/data-leak-...,True,"[travel, hotels, home and garden, home improve...",[other]
2,b'5b8876fabb325e65fa7e78e6',2018-08-30 23:00:09+00:00,1035301174583353344,Instagram App 41.1788.50991.0 #Denial Of #Serv...,{'created_at': 'Thu Aug 30 23:00:09 +0000 2018...,general,{'science': {'weather': {'meteorological disas...,threat,['https://packetstormsecurity.com/files/149120...,https://packetstormsecurity.com/files/149120/i...,True,"[science, weather, meteorological disaster, hu...",[hardware]
3,b'5b88770abb325e65fa7e78e7',2018-08-30 23:00:25+00:00,1035301242271096832,(good slides): \n\nThe Advanced Exploitation o...,{'created_at': 'Thu Aug 30 23:00:25 +0000 2018...,vulnerability,{'business and industrial': {'business operati...,threat,['https://twitter.com/i/web/status/10353012422...,https://twitter.com/i/web/status/1035301242271...,True,"[business and industrial, business operations,...",[operating systems]
4,b'5b887713bb325e65fa7e78e8',2018-08-30 23:00:35+00:00,1035301282095853569,CVE-2018-1000532 (beep)\nhttps://t.co/CaKbo38U...,{'created_at': 'Thu Aug 30 23:00:35 +0000 2018...,vulnerability,{'technology and computing': {'computer securi...,threat,['https://web.nvd.nist.gov/view/vuln/detail?vu...,https://nvd.nist.gov/vuln/detail/CVE-2018-1000532,True,"[technology and computing, computer security, ...","[computer security, hardware, software]"


In [17]:
# Embed every tweet text as the sum of its token vectors.
# embedding_dim has to match the vector length stored in the embedding file
# (the file name, glove.6B.300d.txt, indicates 300).
X = preprocess_texts(
    list_str=threat_tweets['text'],
    model_path=GLOVE_6B_300D_TXT,
    embedding_dim=300
)


## 4. Model Training

We will now train different models and evaluate their performance.


In [18]:
# Encode the per-tweet label lists as a binary indicator matrix
# (one column per entry of mlb.classes_).
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(threat_tweets['target'])


In [19]:
# First split: hold out TEST_SIZE of the data as the test set.
# Note the return order used here: train X, train y, held-out X, held-out y.
X_train_val, y_train_val, X_test, y_test = iterative_train_test_split(
    X, y,
    test_size=TEST_SIZE
)

# Second split: hold out TEST_SIZE of the remaining data for validation, so
# the validation fraction is relative to the train+val portion, not the whole set.
X_train, y_train, X_val, y_val = iterative_train_test_split(
    X_train_val, y_train_val,
    test_size=TEST_SIZE
)


### 4.1. Logistic Regression


In [20]:
# Binary Relevance (BR)
classifier_lr = OneVsRestClassifier(estimator=BASE_CLASSIFIERS['logistic_regression'])
classifier_lr.fit(X=X_train, y=y_train)


In [21]:
# Calibrated Label Ranking (CLR)
pairwise_classifiers_lr, artificial_classifiers_lr = train_clr_model(
    X=X_train,
    y=[list(mlb.classes_[np.where(row == 1)[0]]) for row in y_train],
    base_learner='logistic_regression'
)

print(f"Trained {len(pairwise_classifiers_lr)} pairwise classifiers.")
print(f"Trained {len(artificial_classifiers_lr)} artificial classifiers.")


Trained 15 pairwise classifiers.
Trained 6 artificial classifiers.


In [22]:
# Classifier Chains (CC)
chain_lr, label_order_lr = train_cc_model(
    X=X_train,
    y=y_train,
    base_learner='logistic_regression'
)

print("Trained chain classifiers:", len(chain_lr))
print("Label order in the chain:", label_order_lr)


Trained chain classifiers: 6
Label order in the chain: [3 1 0 5 2 4]


### 4.2. Gaussian Naïve Bayes


In [23]:
# Binary Relevance (BR)
classifier_gnb = OneVsRestClassifier(estimator=BASE_CLASSIFIERS['gaussian_naive_bayes'])
classifier_gnb.fit(X=X_train, y=y_train)


In [24]:
# Calibrated Label Ranking (CLR)
pairwise_classifiers_gnb, artificial_classifiers_gnb = train_clr_model(
    X=X_train,
    y=[list(mlb.classes_[np.where(row == 1)[0]]) for row in y_train],
    base_learner='gaussian_naive_bayes'
)

print(f"Trained {len(pairwise_classifiers_gnb)} pairwise classifiers.")
print(f"Trained {len(artificial_classifiers_gnb)} artificial classifiers.")


Trained 15 pairwise classifiers.
Trained 6 artificial classifiers.


In [25]:
# Classifier Chains (CC)
chain_gnb, label_order_gnb = train_cc_model(
    X=X_train,
    y=y_train,
    base_learner='gaussian_naive_bayes'
)

print("Trained chain classifiers:", len(chain_gnb))
print("Label order in the chain:", label_order_gnb)


Trained chain classifiers: 6
Label order in the chain: [5 1 2 0 4 3]


### 4.3. Decision Tree Classifier


In [26]:
# Binary Relevance (BR)
classifier_dt = OneVsRestClassifier(estimator=BASE_CLASSIFIERS['decision_tree'])
classifier_dt.fit(X=X_train, y=y_train)


In [27]:
# Calibrated Label Ranking (CLR)
pairwise_classifiers_dt, artificial_classifiers_dt = train_clr_model(
    X=X_train,
    y=[list(mlb.classes_[np.where(row == 1)[0]]) for row in y_train],
    base_learner='decision_tree'
)

print(f"Trained {len(pairwise_classifiers_dt)} pairwise classifiers.")
print(f"Trained {len(artificial_classifiers_dt)} artificial classifiers.")


Trained 15 pairwise classifiers.
Trained 6 artificial classifiers.


In [28]:
# Classifier Chains (CC)
chain_dt, label_order_dt = train_cc_model(
    X=X_train,
    y=y_train,
    base_learner='decision_tree'
)

print("Trained chain classifiers:", len(chain_dt))
print("Label order in the chain:", label_order_dt)


Trained chain classifiers: 6
Label order in the chain: [5 4 2 1 3 0]


### 4.4. Random Forest Classifier


In [29]:
# Binary Relevance (BR)
classifier_rf = OneVsRestClassifier(estimator=BASE_CLASSIFIERS['random_forest'])
classifier_rf.fit(X=X_train, y=y_train)


In [30]:
# Calibrated Label Ranking (CLR)
pairwise_classifiers_rf, artificial_classifiers_rf = train_clr_model(
    X=X_train,
    y=[list(mlb.classes_[np.where(row == 1)[0]]) for row in y_train],
    base_learner='random_forest'
)

print(f"Trained {len(pairwise_classifiers_rf)} pairwise classifiers.")
print(f"Trained {len(artificial_classifiers_rf)} artificial classifiers.")


Trained 15 pairwise classifiers.
Trained 6 artificial classifiers.


In [31]:
# Classifier Chains (CC)
chain_rf, label_order_rf = train_cc_model(
    X=X_train,
    y=y_train,
    base_learner='random_forest'
)

print("Trained chain classifiers:", len(chain_rf))
print("Label order in the chain:", label_order_rf)


Trained chain classifiers: 6
Label order in the chain: [0 5 4 1 2 3]


### 4.5. eXtreme Gradient Boosting Classifier


In [32]:
# Binary Relevance (BR)
classifier_xgb = OneVsRestClassifier(estimator=BASE_CLASSIFIERS['xgb'])
classifier_xgb.fit(X=X_train, y=y_train)


In [33]:
# Calibrated Label Ranking (CLR)
pairwise_classifiers_xgb, artificial_classifiers_xgb = train_clr_model(
    X=X_train,
    y=[list(mlb.classes_[np.where(row == 1)[0]]) for row in y_train],
    base_learner='xgb'
)

print(f"Trained {len(pairwise_classifiers_xgb)} pairwise classifiers.")
print(f"Trained {len(artificial_classifiers_xgb)} artificial classifiers.")


Trained 15 pairwise classifiers.
Trained 6 artificial classifiers.


In [34]:
# Classifier Chains (CC)
chain_xgb, label_order_xgb = train_cc_model(
    X=X_train,
    y=y_train,
    base_learner='xgb'
)

print("Trained chain classifiers:", len(chain_xgb))
print("Label order in the chain:", label_order_xgb)


Trained chain classifiers: 6
Label order in the chain: [3 1 0 2 5 4]


## 5. Model Evaluation

Now that we've trained the models, let's evaluate them in more detail.


### 5.1. Logistic Regression


In [35]:
# Binary Relevance (BR)
# Score the one-vs-rest logistic-regression model on the validation split.
y_pred = classifier_lr.predict(X=X_val)
# NOTE(review): with a multilabel indicator matrix, accuracy_score is expected
# to count a sample as correct only when every label matches — confirm this
# exact-match metric is the one intended.
accuracy = accuracy_score(y_true=y_val, y_pred=y_pred)
# NOTE(review): y_score receives the hard 0/1 predictions, not probabilities
# or decision scores, so the reported AUC reflects a single operating point
# per label — confirm whether scores were intended here.
rocauc = roc_auc_score(
    y_true=y_val,
    y_score=y_pred,
    average='weighted'
)
# zero_division=0 reports 0 instead of warning when a metric is undefined.
classification = classification_report(
    y_true=y_val,
    y_pred=y_pred,
    target_names=mlb.classes_,
    zero_division=0
)

print(f"Accuracy:\t{accuracy:.2f}")
print(f"AUC:\t{rocauc:.2f}")
print(classification)


Accuracy:	0.50
AUC:	0.78
                   precision    recall  f1-score   support

computer security       0.83      0.83      0.83       931
         hardware       0.65      0.42      0.51       355
       networking       0.58      0.36      0.44       146
operating systems       0.76      0.65      0.70       171
            other       0.77      0.63      0.69       361
         software       0.77      0.71      0.74       644

        micro avg       0.77      0.68      0.72      2608
        macro avg       0.73      0.60      0.65      2608
     weighted avg       0.76      0.68      0.72      2608
      samples avg       0.72      0.69      0.68      2608



In [36]:
# Calibrated Label Ranking (CLR)
# predict_clr_model returns a dict whose keys ('rankings', 'bipartitions')
# match the parameters of reconstruct_predictions, hence the ** unpacking.
# The label count is taken from the fitted binarizer instead of a literal.
y_pred = reconstruct_predictions(
    **predict_clr_model(
        X=X_val,
        pairwise_classifiers=pairwise_classifiers_lr,
        artificial_classifiers=artificial_classifiers_lr,
        num_labels=len(mlb.classes_)
    ),
    num_labels=len(mlb.classes_)
)
accuracy = accuracy_score(y_true=y_val, y_pred=y_pred)
# NOTE(review): y_score receives the hard 0/1 predictions, not probabilities
# or decision scores, so the reported AUC reflects a single operating point
# per label — confirm whether scores were intended here.
rocauc = roc_auc_score(
    y_true=y_val,
    y_score=y_pred,
    average='weighted'
)
classification = classification_report(
    y_true=y_val,
    y_pred=y_pred,
    target_names=mlb.classes_,
    zero_division=0
)

print(f"Accuracy:\t{accuracy:.2f}")
print(f"AUC:\t{rocauc:.2f}")
print(classification)


Accuracy:	0.50
AUC:	0.78
                   precision    recall  f1-score   support

computer security       0.83      0.83      0.83       931
         hardware       0.65      0.42      0.51       355
       networking       0.58      0.36      0.44       146
operating systems       0.76      0.65      0.70       171
            other       0.77      0.63      0.69       361
         software       0.77      0.71      0.74       644

        micro avg       0.77      0.68      0.72      2608
        macro avg       0.73      0.60      0.65      2608
     weighted avg       0.76      0.68      0.72      2608
      samples avg       0.72      0.69      0.68      2608



In [37]:
# Classifier Chains (CC)
# Score the logistic-regression chain on the validation split; label_order_lr
# places each link's output in its original label column.
y_pred = predict_cc_model(X_val, chain_lr, label_order_lr)
accuracy = accuracy_score(y_true=y_val, y_pred=y_pred)
# NOTE(review): y_score receives the hard 0/1 predictions, not probabilities
# or decision scores, so the reported AUC reflects a single operating point
# per label — confirm whether scores were intended here.
rocauc = roc_auc_score(
    y_true=y_val,
    y_score=y_pred,
    average='weighted'
)
classification = classification_report(
    y_true=y_val,
    y_pred=y_pred,
    target_names=mlb.classes_,
    zero_division=0
)

print(f"Accuracy:\t{accuracy:.2f}")
print(f"AUC:\t{rocauc:.2f}")
print(classification)


Accuracy:	0.51
AUC:	0.78
                   precision    recall  f1-score   support

computer security       0.83      0.84      0.83       931
         hardware       0.65      0.40      0.50       355
       networking       0.60      0.38      0.46       146
operating systems       0.76      0.65      0.70       171
            other       0.77      0.65      0.70       361
         software       0.77      0.70      0.74       644

        micro avg       0.77      0.68      0.72      2608
        macro avg       0.73      0.60      0.65      2608
     weighted avg       0.76      0.68      0.72      2608
      samples avg       0.73      0.69      0.69      2608



### 5.2. Gaussian Naïve Bayes


In [38]:
# Binary Relevance (BR)
# Score the one-vs-rest Gaussian naive Bayes model on the validation split.
y_pred = classifier_gnb.predict(X=X_val)
accuracy = accuracy_score(y_true=y_val, y_pred=y_pred)
# NOTE(review): y_score receives the hard 0/1 predictions, not probabilities
# or decision scores, so the reported AUC reflects a single operating point
# per label — confirm whether scores were intended here.
rocauc = roc_auc_score(
    y_true=y_val,
    y_score=y_pred,
    average='weighted'
)
classification = classification_report(
    y_true=y_val,
    y_pred=y_pred,
    target_names=mlb.classes_,
    zero_division=0
)

print(f"Accuracy:\t{accuracy:.2f}")
print(f"AUC:\t{rocauc:.2f}")
print(classification)


Accuracy:	0.12
AUC:	0.62
                   precision    recall  f1-score   support

computer security       0.60      0.79      0.68       931
         hardware       0.24      0.72      0.36       355
       networking       0.14      0.59      0.23       146
operating systems       0.22      0.73      0.33       171
            other       0.46      0.27      0.34       361
         software       0.48      0.81      0.60       644

        micro avg       0.38      0.70      0.49      2608
        macro avg       0.36      0.65      0.42      2608
     weighted avg       0.45      0.70      0.52      2608
      samples avg       0.41      0.67      0.48      2608



In [39]:
# Calibrated Label Ranking (CLR)
# predict_clr_model returns a dict whose keys ('rankings', 'bipartitions')
# match the parameters of reconstruct_predictions, hence the ** unpacking.
# The label count is taken from the fitted binarizer instead of a literal.
y_pred = reconstruct_predictions(
    **predict_clr_model(
        X=X_val,
        pairwise_classifiers=pairwise_classifiers_gnb,
        artificial_classifiers=artificial_classifiers_gnb,
        num_labels=len(mlb.classes_)
    ),
    num_labels=len(mlb.classes_)
)
accuracy = accuracy_score(y_true=y_val, y_pred=y_pred)
# NOTE(review): y_score receives the hard 0/1 predictions, not probabilities
# or decision scores, so the reported AUC reflects a single operating point
# per label — confirm whether scores were intended here.
rocauc = roc_auc_score(
    y_true=y_val,
    y_score=y_pred,
    average='weighted'
)
classification = classification_report(
    y_true=y_val,
    y_pred=y_pred,
    target_names=mlb.classes_,
    zero_division=0
)

print(f"Accuracy:\t{accuracy:.2f}")
print(f"AUC:\t{rocauc:.2f}")
print(classification)


Accuracy:	0.12
AUC:	0.62
                   precision    recall  f1-score   support

computer security       0.60      0.79      0.68       931
         hardware       0.24      0.72      0.36       355
       networking       0.14      0.59      0.23       146
operating systems       0.22      0.73      0.33       171
            other       0.46      0.27      0.34       361
         software       0.48      0.81      0.60       644

        micro avg       0.38      0.70      0.49      2608
        macro avg       0.36      0.65      0.42      2608
     weighted avg       0.45      0.70      0.52      2608
      samples avg       0.41      0.67      0.48      2608



In [40]:
# Classifier Chains (CC): score the Gaussian Naive Bayes chain on X_val / y_val,
# using chain_gnb together with label_order_gnb.
y_pred = predict_cc_model(X_val, chain_gnb, label_order_gnb)

# Subset accuracy, support-weighted ROC AUC and the per-label report.
# NOTE(review): the AUC receives the same predictions as the other metrics
# rather than probability scores -- confirm this is intended.
accuracy = accuracy_score(y_val, y_pred)
rocauc = roc_auc_score(y_val, y_pred, average='weighted')
classification = classification_report(
    y_val, y_pred, target_names=mlb.classes_, zero_division=0
)

print(
    f"Accuracy:\t{accuracy:.2f}",
    f"AUC:\t{rocauc:.2f}",
    classification,
    sep="\n",
)


Accuracy:	0.12
AUC:	0.62
                   precision    recall  f1-score   support

computer security       0.60      0.79      0.68       931
         hardware       0.24      0.72      0.36       355
       networking       0.14      0.58      0.23       146
operating systems       0.21      0.73      0.33       171
            other       0.46      0.28      0.35       361
         software       0.48      0.81      0.60       644

        micro avg       0.38      0.70      0.49      2608
        macro avg       0.36      0.65      0.42      2608
     weighted avg       0.45      0.70      0.52      2608
      samples avg       0.41      0.67      0.48      2608



### 5.3. Decision Tree Classifier


In [41]:
# Binary Relevance (BR): score the Decision Tree model on X_val / y_val.
y_pred = classifier_dt.predict(X=X_val)

# Subset accuracy, support-weighted ROC AUC and the per-label report.
# NOTE(review): the AUC receives the same hard predictions as the other
# metrics rather than probability scores -- confirm this is intended.
accuracy = accuracy_score(y_val, y_pred)
rocauc = roc_auc_score(y_val, y_pred, average='weighted')
classification = classification_report(
    y_val, y_pred, target_names=mlb.classes_, zero_division=0
)

print(
    f"Accuracy:\t{accuracy:.2f}",
    f"AUC:\t{rocauc:.2f}",
    classification,
    sep="\n",
)


Accuracy:	0.45
AUC:	0.75
                   precision    recall  f1-score   support

computer security       0.77      0.76      0.76       931
         hardware       0.53      0.54      0.54       355
       networking       0.51      0.50      0.50       146
operating systems       0.49      0.58      0.53       171
            other       0.59      0.60      0.60       361
         software       0.70      0.72      0.71       644

        micro avg       0.66      0.67      0.66      2608
        macro avg       0.60      0.62      0.61      2608
     weighted avg       0.66      0.67      0.67      2608
      samples avg       0.64      0.67      0.63      2608



In [42]:
# Calibrated Label Ranking (CLR): score the Decision Tree pairwise and
# artificial classifiers on X_val / y_val.
# The label count is taken from mlb.classes_ (the same label names used for
# the report below) instead of a hard-coded 6, so the cell keeps working if
# the label set changes.
y_pred = reconstruct_predictions(
    **predict_clr_model(
        X=X_val,
        pairwise_classifiers=pairwise_classifiers_dt,
        artificial_classifiers=artificial_classifiers_dt,
        num_labels=len(mlb.classes_)
    ),
    num_labels=len(mlb.classes_)
)

# Subset accuracy, support-weighted ROC AUC and the per-label report.
# NOTE(review): the AUC receives the reconstructed predictions, not
# probability scores -- confirm this is intended.
accuracy = accuracy_score(y_true=y_val, y_pred=y_pred)
rocauc = roc_auc_score(
    y_true=y_val,
    y_score=y_pred,
    average='weighted'
)
classification = classification_report(
    y_true=y_val,
    y_pred=y_pred,
    target_names=mlb.classes_,
    zero_division=0
)

print(f"Accuracy:\t{accuracy:.2f}")
print(f"AUC:\t{rocauc:.2f}")
print(classification)


Accuracy:	0.45
AUC:	0.75
                   precision    recall  f1-score   support

computer security       0.77      0.76      0.76       931
         hardware       0.53      0.54      0.54       355
       networking       0.51      0.50      0.50       146
operating systems       0.49      0.58      0.53       171
            other       0.59      0.60      0.60       361
         software       0.70      0.72      0.71       644

        micro avg       0.66      0.67      0.66      2608
        macro avg       0.60      0.62      0.61      2608
     weighted avg       0.66      0.67      0.67      2608
      samples avg       0.64      0.67      0.63      2608



In [43]:
# Classifier Chains (CC): score the Decision Tree chain on X_val / y_val,
# using chain_dt together with label_order_dt.
y_pred = predict_cc_model(X_val, chain_dt, label_order_dt)

# Subset accuracy, support-weighted ROC AUC and the per-label report.
# NOTE(review): the AUC receives the same predictions as the other metrics
# rather than probability scores -- confirm this is intended.
accuracy = accuracy_score(y_val, y_pred)
rocauc = roc_auc_score(y_val, y_pred, average='weighted')
classification = classification_report(
    y_val, y_pred, target_names=mlb.classes_, zero_division=0
)

print(
    f"Accuracy:\t{accuracy:.2f}",
    f"AUC:\t{rocauc:.2f}",
    classification,
    sep="\n",
)


Accuracy:	0.52
AUC:	0.75
                   precision    recall  f1-score   support

computer security       0.74      0.78      0.76       931
         hardware       0.56      0.54      0.55       355
       networking       0.40      0.50      0.45       146
operating systems       0.51      0.58      0.54       171
            other       0.67      0.58      0.62       361
         software       0.70      0.72      0.71       644

        micro avg       0.66      0.68      0.67      2608
        macro avg       0.60      0.62      0.60      2608
     weighted avg       0.66      0.68      0.67      2608
      samples avg       0.67      0.67      0.66      2608



### 5.4. Random Forest Classifier


In [44]:
# Binary Relevance (BR): score the Random Forest model on X_val / y_val.
y_pred = classifier_rf.predict(X=X_val)

# Subset accuracy, support-weighted ROC AUC and the per-label report.
# NOTE(review): the AUC receives the same hard predictions as the other
# metrics rather than probability scores -- confirm this is intended.
accuracy = accuracy_score(y_val, y_pred)
rocauc = roc_auc_score(y_val, y_pred, average='weighted')
classification = classification_report(
    y_val, y_pred, target_names=mlb.classes_, zero_division=0
)

print(
    f"Accuracy:\t{accuracy:.2f}",
    f"AUC:\t{rocauc:.2f}",
    classification,
    sep="\n",
)


Accuracy:	0.55
AUC:	0.78
                   precision    recall  f1-score   support

computer security       0.82      0.87      0.84       931
         hardware       0.96      0.40      0.56       355
       networking       0.95      0.37      0.53       146
operating systems       0.96      0.40      0.57       171
            other       0.94      0.43      0.59       361
         software       0.89      0.73      0.80       644

        micro avg       0.87      0.65      0.74      2608
        macro avg       0.92      0.53      0.65      2608
     weighted avg       0.89      0.65      0.72      2608
      samples avg       0.71      0.65      0.67      2608



In [45]:
# Calibrated Label Ranking (CLR): score the Random Forest pairwise and
# artificial classifiers on X_val / y_val.
# The label count is taken from mlb.classes_ (the same label names used for
# the report below) instead of a hard-coded 6, so the cell keeps working if
# the label set changes.
y_pred = reconstruct_predictions(
    **predict_clr_model(
        X=X_val,
        pairwise_classifiers=pairwise_classifiers_rf,
        artificial_classifiers=artificial_classifiers_rf,
        num_labels=len(mlb.classes_)
    ),
    num_labels=len(mlb.classes_)
)

# Subset accuracy, support-weighted ROC AUC and the per-label report.
# NOTE(review): the AUC receives the reconstructed predictions, not
# probability scores -- confirm this is intended.
accuracy = accuracy_score(y_true=y_val, y_pred=y_pred)
rocauc = roc_auc_score(
    y_true=y_val,
    y_score=y_pred,
    average='weighted'
)
classification = classification_report(
    y_true=y_val,
    y_pred=y_pred,
    target_names=mlb.classes_,
    zero_division=0
)

print(f"Accuracy:\t{accuracy:.2f}")
print(f"AUC:\t{rocauc:.2f}")
print(classification)


Accuracy:	0.55
AUC:	0.78
                   precision    recall  f1-score   support

computer security       0.82      0.87      0.84       931
         hardware       0.96      0.40      0.56       355
       networking       0.95      0.37      0.53       146
operating systems       0.96      0.40      0.57       171
            other       0.94      0.43      0.59       361
         software       0.89      0.73      0.80       644

        micro avg       0.87      0.65      0.74      2608
        macro avg       0.92      0.53      0.65      2608
     weighted avg       0.89      0.65      0.72      2608
      samples avg       0.71      0.65      0.67      2608



In [46]:
# Classifier Chains (CC): score the Random Forest chain on X_val / y_val,
# using chain_rf together with label_order_rf.
y_pred = predict_cc_model(X_val, chain_rf, label_order_rf)

# Subset accuracy, support-weighted ROC AUC and the per-label report.
# NOTE(review): the AUC receives the same predictions as the other metrics
# rather than probability scores -- confirm this is intended.
accuracy = accuracy_score(y_val, y_pred)
rocauc = roc_auc_score(y_val, y_pred, average='weighted')
classification = classification_report(
    y_val, y_pred, target_names=mlb.classes_, zero_division=0
)

print(
    f"Accuracy:\t{accuracy:.2f}",
    f"AUC:\t{rocauc:.2f}",
    classification,
    sep="\n",
)


Accuracy:	0.59
AUC:	0.79
                   precision    recall  f1-score   support

computer security       0.82      0.87      0.84       931
         hardware       0.96      0.39      0.56       355
       networking       0.93      0.35      0.51       146
operating systems       0.90      0.42      0.57       171
            other       0.87      0.60      0.71       361
         software       0.89      0.74      0.81       644

        micro avg       0.86      0.68      0.76      2608
        macro avg       0.89      0.56      0.67      2608
     weighted avg       0.87      0.68      0.74      2608
      samples avg       0.74      0.68      0.70      2608



### 5.5. eXtreme Gradient Boosting Classifier


In [47]:
# Binary Relevance (BR): score the XGBoost model on X_val / y_val.
y_pred = classifier_xgb.predict(X=X_val)

# Subset accuracy, support-weighted ROC AUC and the per-label report.
# NOTE(review): the AUC receives the same hard predictions as the other
# metrics rather than probability scores -- confirm this is intended.
accuracy = accuracy_score(y_val, y_pred)
rocauc = roc_auc_score(y_val, y_pred, average='weighted')
classification = classification_report(
    y_val, y_pred, target_names=mlb.classes_, zero_division=0
)

print(
    f"Accuracy:\t{accuracy:.2f}",
    f"AUC:\t{rocauc:.2f}",
    classification,
    sep="\n",
)


Accuracy:	0.59
AUC:	0.81
                   precision    recall  f1-score   support

computer security       0.85      0.87      0.86       931
         hardware       0.87      0.50      0.63       355
       networking       0.95      0.42      0.58       146
operating systems       0.92      0.60      0.73       171
            other       0.90      0.56      0.69       361
         software       0.84      0.78      0.81       644

        micro avg       0.86      0.71      0.78      2608
        macro avg       0.89      0.62      0.72      2608
     weighted avg       0.87      0.71      0.77      2608
      samples avg       0.74      0.70      0.71      2608



In [48]:
# Calibrated Label Ranking (CLR): score the XGBoost pairwise and artificial
# classifiers on X_val / y_val.
# The label count is taken from mlb.classes_ (the same label names used for
# the report below) instead of a hard-coded 6, so the cell keeps working if
# the label set changes.
y_pred = reconstruct_predictions(
    **predict_clr_model(
        X=X_val,
        pairwise_classifiers=pairwise_classifiers_xgb,
        artificial_classifiers=artificial_classifiers_xgb,
        num_labels=len(mlb.classes_)
    ),
    num_labels=len(mlb.classes_)
)

# Subset accuracy, support-weighted ROC AUC and the per-label report.
# NOTE(review): the AUC receives the reconstructed predictions, not
# probability scores -- confirm this is intended.
accuracy = accuracy_score(y_true=y_val, y_pred=y_pred)
rocauc = roc_auc_score(
    y_true=y_val,
    y_score=y_pred,
    average='weighted'
)
classification = classification_report(
    y_true=y_val,
    y_pred=y_pred,
    target_names=mlb.classes_,
    zero_division=0
)

print(f"Accuracy:\t{accuracy:.2f}")
print(f"AUC:\t{rocauc:.2f}")
print(classification)


Accuracy:	0.59
AUC:	0.81
                   precision    recall  f1-score   support

computer security       0.85      0.87      0.86       931
         hardware       0.87      0.50      0.63       355
       networking       0.95      0.42      0.58       146
operating systems       0.92      0.60      0.73       171
            other       0.90      0.56      0.69       361
         software       0.84      0.78      0.81       644

        micro avg       0.86      0.71      0.78      2608
        macro avg       0.89      0.62      0.72      2608
     weighted avg       0.87      0.71      0.77      2608
      samples avg       0.74      0.70      0.71      2608



In [49]:
# Classifier Chains (CC): score the XGBoost chain on X_val / y_val,
# using chain_xgb together with label_order_xgb.
y_pred = predict_cc_model(X_val, chain_xgb, label_order_xgb)

# Subset accuracy, support-weighted ROC AUC and the per-label report.
# NOTE(review): the AUC receives the same predictions as the other metrics
# rather than probability scores -- confirm this is intended.
accuracy = accuracy_score(y_val, y_pred)
rocauc = roc_auc_score(y_val, y_pred, average='weighted')
classification = classification_report(
    y_val, y_pred, target_names=mlb.classes_, zero_division=0
)

print(
    f"Accuracy:\t{accuracy:.2f}",
    f"AUC:\t{rocauc:.2f}",
    classification,
    sep="\n",
)


Accuracy:	0.66
AUC:	0.83
                   precision    recall  f1-score   support

computer security       0.85      0.88      0.87       931
         hardware       0.86      0.50      0.63       355
       networking       0.94      0.46      0.62       146
operating systems       0.92      0.60      0.73       171
            other       0.70      0.83      0.76       361
         software       0.82      0.79      0.80       644

        micro avg       0.82      0.76      0.79      2608
        macro avg       0.85      0.68      0.73      2608
     weighted avg       0.83      0.76      0.78      2608
      samples avg       0.80      0.77      0.77      2608

