# AI Chatbot for MITRE ATT&CK Threat Classification and Organizational Impact Analysis

In [1]:
from tqdm.auto import tqdm
import torch, torchtext
from torch import nn
import torch.nn.functional as F
import random, math, time
from datasets import load_dataset
import pandas as pd
import numpy as np
import re

# Prefer the second GPU (index 1) when the machine actually has one; otherwise
# fall back to the first GPU, then to the CPU, so the notebook also runs on
# single-GPU and CPU-only hosts.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')
print(device)

# Make runs comparable across kernel restarts: seed every RNG imported above,
# not only torch.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
# Autotuned cuDNN kernels may differ between runs; keep the selection fixed.
torch.backends.cudnn.benchmark = False

cuda:1


# 1. Load Dataset

In [2]:
torch.__version__

'2.3.0+cu121'

In [3]:
from datasets import load_dataset

# Load the 'tumeteor/Security-TTP-Mapping' dataset. The summary printed below
# shows train / validation / test splits, each with 'text1' and 'labels' columns.
dataset = load_dataset('tumeteor/Security-TTP-Mapping')

# # Optionally keep a single split instead of the whole DatasetDict
# dataset = dataset['train']  # or 'test', depending on the split

# # Optionally keep only the first rows of that split
# dataset = dataset.select(range(10000))

# Display the dataset
print(dataset)


DatasetDict({
    train: Dataset({
        features: ['text1', 'labels'],
        num_rows: 14936
    })
    validation: Dataset({
        features: ['text1', 'labels'],
        num_rows: 2630
    })
    test: Dataset({
        features: ['text1', 'labels'],
        num_rows: 3170
    })
})


## EDA

In [6]:
# Inspect one training row: the report sentence and the label value stored with it
print(dataset['train'][1000]['text1'])
print(dataset['train'][1000]['labels'])

PLEAD also dabbled with a short-lived, fileless version of their malware when it obtained an exploit for a Flash vulnerability (CVE-2015-5119) that was leaked during the Hacking Team breach
['T1203']


In [7]:
# Inspect a second training row (index 5) the same way
print(dataset['train'][5]['text1'])
print(dataset['train'][5]['labels'])

In older versions, Valak downloads the second stage JS and uses only one obfuscation technique: Base64. The newer versions use XOR in addition to Base64
['T1027']


In [8]:
# Slicing a split returns columns: print the first five texts, then their labels
print(dataset['train'][:5]['text1'])
print(dataset['train'][:5]['labels'])

['The command processing function starts by substituting the main module name and path in the hosting process PEB, with the one of the default internet browser. The path of the main browser of the workstation is obtained by reading the registry value', 'Along the way, HermeticWiper’s more mundane operations provide us with further IOCs to monitor for. These include the momentary creation of the abused driver as well as a system service. It also modifies several registry keys, including setting the SYSTEM\\CurrentControlSet\\Control\\CrashControl CrashDumpEnabled key to 0, effectively disabling crash dumps before the abused driver’s execution starts', 'These Microsoft Office templates are hosted on a command and control server and the downloaded link is embedded in the first stage malicious document', 'Additionally, the IP 211[.]72 [.]242[.]120 is one of the hosts for the domain microsoftmse[.]com, which has been used by several KIVARS variants', 'When communicating with its C2 server, 

In [9]:
# Pair every training text with its label value as (text, labels) tuples
train = list(zip(dataset['train']['text1'], dataset['train']['labels']))

In [10]:
# Take a look at the first (text, labels) pair of train
sample = next(iter(train))
sample

('The command processing function starts by substituting the main module name and path in the hosting process PEB, with the one of the default internet browser. The path of the main browser of the workstation is obtained by reading the registry value',
 "['T1057']")

# 2. Preprocess Data

In [11]:
from datasets import load_dataset
from transformers import BertTokenizer

# Load the pretrained 'bert-base-uncased' tokenizer (used by tokenize_function below)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Function to tokenize the text data
def tokenize_function(examples, max_length=64):
    """Tokenize the 'text1' entries of a batch with the module-level tokenizer.

    Args:
        examples: Batch mapping as handed over by ``Dataset.map(..., batched=True)``;
            only ``examples['text1']`` is read.
        max_length: Length every sequence is padded or truncated to. Defaults to
            64, the value that used to be hard-coded here.

    Returns:
        Whatever ``tokenizer`` returns for the batch (the mapped datasets in this
        notebook show ``input_ids``, ``token_type_ids`` and ``attention_mask``).
    """
    return tokenizer(
        examples['text1'],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

# Apply tokenization to the train, validation and test splits
train_data = dataset['train'].map(tokenize_function, batched=True)
val_data = dataset['validation'].map(tokenize_function, batched=True)
test_data = dataset['test'].map(tokenize_function, batched=True)

# Print an example to verify
print(train_data[0]['text1'])  # Prints the raw 'text1' string; the token ids are in train_data[0]['input_ids']

# # Decode the tokenized text back to human-readable text
# decoded_text = tokenizer.decode(train_data[0]['input_ids'], skip_special_tokens=True)
# print(f"Decoded Text: {decoded_text}")




The command processing function starts by substituting the main module name and path in the hosting process PEB, with the one of the default internet browser. The path of the main browser of the workstation is obtained by reading the registry value


In [12]:
from transformers import BertTokenizer

# Load the BERT tokenizer
# NOTE(review): this repeats the 'bert-base-uncased' load from the previous cell
# and rebinds `tokenizer`.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Example sentence
text = "The command processing function starts by substituting the main module name and path in the hosting process PEB, with the one of the default internet browser. The path of the main browser of the workstation is obtained by reading the registry value."

# Split the text into the tokenizer's lower-cased sub-word tokens
# (strings only; no ids and no special tokens, as the output below shows)
tokens = tokenizer.tokenize(text)
print(tokens)


['the', 'command', 'processing', 'function', 'starts', 'by', 'sub', '##stituting', 'the', 'main', 'module', 'name', 'and', 'path', 'in', 'the', 'hosting', 'process', 'pe', '##b', ',', 'with', 'the', 'one', 'of', 'the', 'default', 'internet', 'browser', '.', 'the', 'path', 'of', 'the', 'main', 'browser', 'of', 'the', 'works', '##tation', 'is', 'obtained', 'by', 'reading', 'the', 'registry', 'value', '.']


# 3. Tokenizer and Model

In [13]:
# Extract unique labels (MITRE techniques) from both train and validation datasets.
# The labels are sorted on purpose: iterating a set of strings follows hash order,
# which changes between Python processes, so an unsorted list would give every
# kernel restart a different label -> id mapping than the one a saved model was
# trained with. A checkpoint trained with the earlier unsorted order has to be
# retrained before its class ids line up with this list.
labels = sorted(set(dataset['train']['labels']) | set(dataset['validation']['labels']))
label_map = {label: i for i, label in enumerate(labels)}

# Function to encode labels into integers
def encode_labels(examples, unknown_id=-100):
    """Replace each label in a batch with its integer class id from ``label_map``.

    Args:
        examples: Batch mapping with a ``'labels'`` list; it is updated in place.
        unknown_id: Id assigned to labels that are missing from ``label_map``.
            Defaults to -100, the default ``ignore_index`` of
            ``torch.nn.CrossEntropyLoss`` (presumably the loss used by the
            classifier, confirm), instead of the previous -1, which is not a
            valid class index.

    Returns:
        The same ``examples`` mapping, whose ``'labels'`` entry is now a list of ints.
    """
    examples['labels'] = [label_map.get(label, unknown_id) for label in examples['labels']]
    return examples

# Apply label encoding to every split. test_data is included because it is later
# passed to trainer.evaluate(); without this step its 'labels' column would still
# hold the raw label strings. Labels that do not occur in label_map receive the
# fallback id chosen by encode_labels.
train_data = train_data.map(encode_labels, batched=True)
val_data = val_data.map(encode_labels, batched=True)
test_data = test_data.map(encode_labels, batched=True)

# Print an example to verify
print(train_data[0])  # It should show tokenized input along with the encoded label


Map:   0%|          | 0/14936 [00:00<?, ? examples/s]

Map:   0%|          | 0/2630 [00:00<?, ? examples/s]

{'text1': 'The command processing function starts by substituting the main module name and path in the hosting process PEB, with the one of the default internet browser. The path of the main browser of the workstation is obtained by reading the registry value', 'labels': 1375, 'input_ids': [101, 1996, 3094, 6364, 3853, 4627, 2011, 4942, 21532, 1996, 2364, 11336, 2171, 1998, 4130, 1999, 1996, 9936, 2832, 21877, 2497, 1010, 2007, 1996, 2028, 1997, 1996, 12398, 4274, 16602, 1012, 1996, 4130, 1997, 1996, 2364, 16602, 1997, 1996, 2573, 12516, 2003, 4663, 2011, 3752, 1996, 15584, 3643, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [17]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(p):
    """Compute accuracy and weighted precision / recall / F1 for the Trainer.

    Args:
        p: Pair ``(predictions, label_ids)``. Predictions carry one score per
            class along the last axis; label_ids are integer class ids.

    Returns:
        dict with the keys 'accuracy', 'precision', 'recall' and 'f1'.
    """
    preds, label_ids = p
    # The highest score along the class axis is the predicted class id
    preds = np.asarray(preds).argmax(axis=-1)
    label_ids = np.asarray(label_ids)

    # Negative ids are placeholders for labels without a class id; they cannot
    # be predicted, so leave them out of the scores.
    valid = label_ids >= 0
    preds, label_ids = preds[valid], label_ids[valid]
    if label_ids.size == 0:
        return {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

    # zero_division=0 keeps the score of classes that were never predicted at 0
    # without emitting an UndefinedMetricWarning on every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        label_ids, preds, average='weighted', zero_division=0
    )
    accuracy = accuracy_score(label_ids, preds)

    # Return the metrics as a dictionary
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }


# 4. Training the Model with bert-base-uncased

In [21]:
from transformers import Trainer, TrainingArguments, BertForSequenceClassification

# Load the pre-trained BERT model for sequence classification.
# The id <-> label mapping is stored in the model config so that it is written
# into the checkpoint instead of living only in this kernel's `labels` list.
id2label = dict(enumerate(labels))
label2id = {label: idx for idx, label in id2label.items()}
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

# Define training arguments
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` - confirm against the installed version.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    evaluation_strategy="epoch",
)

# Initialize the Trainer
trainer = Trainer(
    model=model,                         # The pre-trained model
    args=training_args,                  # The training arguments
    train_dataset=train_data,            # The training dataset
    eval_dataset=val_data,               # The validation dataset
    compute_metrics=compute_metrics      # Add the compute_metrics function
)

# Train the model
trainer.train()

# Save the final model explicitly: the inference cells load it from './final_model',
# which is a different directory than the Trainer's output_dir ('./results').
trainer.save_model("./final_model")  # You can specify any directory you prefer

# Save the tokenizer next to the model so the directory is self-contained
tokenizer.save_pretrained("./final_model")

2025-04-09 07:16:09.938239: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744182969.958876 2084154 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744182969.967835 2084154 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-09 07:16:09.999871: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are new



Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,4.249989,0.319392,0.162067,0.319392,0.201676
2,5.094200,3.614886,0.414829,0.271528,0.414829,0.309434
3,3.575500,3.438241,0.437643,0.292122,0.437643,0.333476


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


('./final_model/tokenizer_config.json',
 './final_model/special_tokens_map.json',
 './final_model/vocab.txt',
 './final_model/added_tokens.json')

# 5. Evaluation

In [None]:
# Evaluate the model on the validation dataset; `trainer` was built with
# compute_metrics, so its scores are part of the returned results
eval_results = trainer.evaluate(eval_dataset=val_data)

# Print the evaluation results
print("Evaluation results:", eval_results)


In [None]:
# Evaluate on the held-out test split (test_data is created in the tokenization cell).
# NOTE(review): confirm that test_data['labels'] holds integer class ids, like
# train_data and val_data after encode_labels, before running this cell.
test_results = trainer.evaluate(eval_dataset=test_data)

# Print the test results
print("Test results:", test_results)


# 6. Inference

In [18]:
from transformers import BertForSequenceClassification, BertTokenizer
import torch

# Load the fine-tuned model and the tokenizer that the training cell saved
# together in ./final_model, so inference only depends on that directory
model = BertForSequenceClassification.from_pretrained('./final_model')
tokenizer = BertTokenizer.from_pretrained('./final_model')

# Function to get prediction for manual input text
def predict(input_text):
    """Classify one text with the module-level ``tokenizer`` and ``model``.

    Args:
        input_text: Text to classify; it is truncated to 512 tokens.

    Returns:
        Tuple ``(predicted_label, probabilities)``: the index of the most
        probable class as a Python number, and the softmax of the logits over
        the last dimension. ``.item()`` needs a single prediction, so pass one
        text per call.
    """
    encoded = tokenizer(
        input_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

    # Inference only: no gradient bookkeeping for the forward pass
    with torch.no_grad():
        logits = model(**encoded).logits

    probabilities = torch.softmax(logits, dim=-1)
    predicted_label = probabilities.argmax(dim=-1).item()
    return predicted_label, probabilities

# Test with manual input (this cell waits for text typed by the user)
input_text = input("Enter text for classification: ")

predicted_label, probabilities = predict(input_text)

# `labels` is the list built in the label-encoding cell (from dataset['train']['labels']
# and dataset['validation']['labels']); entry i is printed as the name of class i.
# NOTE(review): this is only correct if `labels` has the same order as when the
# model was trained - confirm.
predicted_label_name = labels[predicted_label]

print(f"Predicted Label: {predicted_label_name}")
print(f"Prediction Probabilities: {probabilities}")




Enter text for classification:  An Iranian state-sponsored actor has been observed scanning and attempting to abuse the Log4Shell flaw in publicly-exposed Java applications to deploy a hitherto undocumented PowerShell-based modular backdoor dubbed "CharmPower" for follow-on post-exploitation. "The actor's attack setup was obviously rushed, as they used the basic open-source tool for the exploitation and based their operations on previous infrastructure, which made the attack easier to detect and attribute," researchers from Check Point said in a report published this week. The Israeli cybersecurity company linked the attack to a group known as APT35, which is also tracked using the codenames Charming Kitten, Phosphorus, and TA453, citing overlaps with toolsets previously identified as infrastructure used by the threat actor. Cybersecurity Log4Shell aka CVE-2021-44228 (CVSS score: 10.0) concerns a critical security vulnerability in the popular Log4j logging library that, if successfully

Predicted Label: ['T1518.001']
Prediction Probabilities: tensor([[0.0001, 0.0002, 0.0002,  ..., 0.0001, 0.0001, 0.0006]])


In [19]:
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import nltk
# Fetch the sentence-splitting data used by nltk.sent_tokenize below
nltk.download('punkt_tab')

# Load the fine-tuned model and the tokenizer that the training cell saved
# together in ./final_model, so inference only depends on that directory
model = BertForSequenceClassification.from_pretrained('./final_model')
tokenizer = BertTokenizer.from_pretrained('./final_model')

# `labels` (class index -> label string) comes from the label-encoding cell

# Function to get prediction for manual input text
def predict(input_text):
    """Return the most probable class index and all class probabilities.

    Uses the ``tokenizer`` and ``model`` loaded in this cell.

    Args:
        input_text: A single text; longer inputs are truncated to 512 tokens.

    Returns:
        ``(class_index, probabilities)`` where ``probabilities`` is the softmax
        of the model logits along the last dimension.
    """
    tokenizer_options = dict(
        return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    batch = tokenizer(input_text, **tokenizer_options)

    # Forward pass without tracking gradients
    with torch.no_grad():
        model_output = model(**batch)

    # Logits -> probabilities -> index of the largest probability
    class_probs = torch.nn.functional.softmax(model_output.logits, dim=-1)
    best_class = torch.argmax(class_probs, dim=-1).item()

    return best_class, class_probs

# Function to break text into sentences using nltk
def sentence_tokenizer(text):
    """Split ``text`` into sentences with ``nltk.sent_tokenize``.

    Every sentence is stripped of surrounding whitespace, as the spaCy-based
    splitter in this notebook does, and whitespace-only results are dropped.
    This keeps a leading newline of the input out of the first sentence.

    Args:
        text: Text to split.

    Returns:
        List of non-empty, stripped sentence strings.
    """
    stripped = (sentence.strip() for sentence in nltk.sent_tokenize(text))
    return [sentence for sentence in stripped if sentence]

# Test with article input (the triple-quoted string starts and ends with a newline)
input_text = """
An Iranian state-sponsored actor has been observed scanning and attempting to abuse the Log4Shell flaw in publicly-exposed Java applications to deploy a hitherto undocumented PowerShell-based modular backdoor dubbed "CharmPower" for follow-on post-exploitation. "The actor's attack setup was obviously rushed, as they used the basic open-source tool for the exploitation and based their operations on previous infrastructure, which made the attack easier to detect and attribute," researchers from Check Point said in a report published this week. The Israeli cybersecurity company linked the attack to a group known as APT35, which is also tracked using the codenames Charming Kitten, Phosphorus, and TA453, citing overlaps with toolsets previously identified as infrastructure used by the threat actor. Cybersecurity Log4Shell aka CVE-2021-44228 (CVSS score: 10.0) concerns a critical security vulnerability in the popular Log4j logging library that, if successfully exploited, could lead to remote execution of arbitrary code on compromised systems.
"""

# Break the article into sentences
sentences = sentence_tokenizer(input_text)

# Classify each sentence on its own (predict handles one text per call)
for sentence in sentences:
    predicted_label, probabilities = predict(sentence)
    
    # Map the predicted class index to its label via the `labels` list from the
    # label-encoding cell.
    # NOTE(review): only correct if `labels` has the training-time order - confirm.
    predicted_label_name = labels[predicted_label]
    
    print(f"Sentence: {sentence}")
    print(f"Predicted Label: {predicted_label_name}")
    print(f"Prediction Probabilities: {probabilities}\n")


[nltk_data] Downloading package punkt_tab to /home/jupyter-
[nltk_data]     st124945/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


Sentence: 
An Iranian state-sponsored actor has been observed scanning and attempting to abuse the Log4Shell flaw in publicly-exposed Java applications to deploy a hitherto undocumented PowerShell-based modular backdoor dubbed "CharmPower" for follow-on post-exploitation.
Predicted Label: ['T1580']
Prediction Probabilities: tensor([[0.0003, 0.0004, 0.0005,  ..., 0.0003, 0.0004, 0.0011]])

Sentence: "The actor's attack setup was obviously rushed, as they used the basic open-source tool for the exploitation and based their operations on previous infrastructure, which made the attack easier to detect and attribute," researchers from Check Point said in a report published this week.
Predicted Label: ['T1064', 'T1547.001']
Prediction Probabilities: tensor([[5.4603e-05, 8.6680e-05, 8.7713e-05,  ..., 7.2819e-05, 6.7793e-05,
         4.5937e-04]])

Sentence: The Israeli cybersecurity company linked the attack to a group known as APT35, which is also tracked using the codenames Charming Kitten,

In [20]:
import spacy
import torch
from transformers import BertForSequenceClassification, BertTokenizer

# Load the fine-tuned model and the tokenizer that the training cell saved
# together in ./final_model, so inference only depends on that directory
model = BertForSequenceClassification.from_pretrained('./final_model')
tokenizer = BertTokenizer.from_pretrained('./final_model')

# `labels` (class index -> label string) comes from the label-encoding cell
# Function to get prediction for a given sentence
def predict(input_text):
    """Run one sentence through ``tokenizer`` and ``model``.

    Args:
        input_text: Sentence to classify (truncated to 512 tokens).

    Returns:
        ``(predicted_label, probabilities)``: the index with the highest
        probability, and the softmax of the logits over the last dimension.
    """
    inputs = tokenizer(
        input_text, return_tensors="pt", padding=True, truncation=True, max_length=512
    )

    # Disable gradients for inference
    with torch.no_grad():
        probabilities = model(**inputs).logits.softmax(dim=-1)

    return probabilities.argmax(dim=-1).item(), probabilities

# Function to break text into sentences using spaCy
def sentence_tokenizer_spacy(text):
    """Split ``text`` into sentences with spaCy's "en_core_web_sm" pipeline.

    The pipeline is loaded on every call.

    Args:
        text: Text to split.

    Returns:
        List with the whitespace-stripped text of every sentence, in order.
    """
    pipeline = spacy.load("en_core_web_sm")
    stripped_sentences = []
    for span in pipeline(text).sents:
        # Remove leading/trailing whitespace such as the newlines around the article
        stripped_sentences.append(span.text.strip())
    return stripped_sentences

# Example news article (same text as in the nltk cell above)
input_text = """
An Iranian state-sponsored actor has been observed scanning and attempting to abuse the Log4Shell flaw in publicly-exposed Java applications to deploy a hitherto undocumented PowerShell-based modular backdoor dubbed "CharmPower" for follow-on post-exploitation. "The actor's attack setup was obviously rushed, as they used the basic open-source tool for the exploitation and based their operations on previous infrastructure, which made the attack easier to detect and attribute," researchers from Check Point said in a report published this week. The Israeli cybersecurity company linked the attack to a group known as APT35, which is also tracked using the codenames Charming Kitten, Phosphorus, and TA453, citing overlaps with toolsets previously identified as infrastructure used by the threat actor. Cybersecurity Log4Shell aka CVE-2021-44228 (CVSS score: 10.0) concerns a critical security vulnerability in the popular Log4j logging library that, if successfully exploited, could lead to remote execution of arbitrary code on compromised systems.
"""

# Break the article into sentences using spaCy
sentences = sentence_tokenizer_spacy(input_text)

# Classify each sentence on its own (predict handles one text per call)
for sentence in sentences:
    predicted_label, probabilities = predict(sentence)
    
    # Map the predicted class index to its label via the `labels` list from the
    # label-encoding cell.
    # NOTE(review): only correct if `labels` has the training-time order - confirm.
    predicted_label_name = labels[predicted_label]
    
    print(f"Sentence: {sentence}")
    print(f"Predicted Label: {predicted_label_name}")
    print(f"Prediction Probabilities: {probabilities}\n")




Sentence: An Iranian state-sponsored actor has been observed scanning and attempting to abuse the Log4Shell flaw in publicly-exposed Java applications to deploy a hitherto undocumented PowerShell-based modular backdoor dubbed "CharmPower" for follow-on post-exploitation.
Predicted Label: ['T1580']
Prediction Probabilities: tensor([[0.0003, 0.0004, 0.0005,  ..., 0.0003, 0.0004, 0.0011]])

Sentence: "The actor's attack setup was obviously rushed, as they used the basic open-source tool for the exploitation and based their operations on previous infrastructure, which made the attack easier to detect and attribute," researchers from Check Point said in a report published this week.
Predicted Label: ['T1064', 'T1547.001']
Prediction Probabilities: tensor([[5.4603e-05, 8.6680e-05, 8.7713e-05,  ..., 7.2819e-05, 6.7793e-05,
         4.5937e-04]])

Sentence: The Israeli cybersecurity company linked the attack to a group known as APT35, which is also tracked using the codenames Charming Kitten, 