# Language Model and MITRE ATT&CK


## Instructions

* Use "Fine-tuning a masked language model" as the template to create your own language model.
  * https://huggingface.co/learn/nlp-course/en/chapter7/3
* Select a built-in language model and try to fine-tune it with an additional corpus.
* We want the fine-tuned model to learn cybersecurity knowledge, so we use professional cybersecurity documents from the MITRE website.
  * https://attack.mitre.org/resources/attack-data-and-tools/
* On the MITRE data and tools page, find the two Excel files that contain the definitions of attack tactics and attack techniques.
  * enterprise-attack-v15.1-tactics.xlsx
  * enterprise-attack-v15.1-techniques.xlsx
* Parse the xlsx files and extract the 'name' and 'description' columns as your additional corpus.
* Try to fine-tune your model.
* Note that you do not have to push your model to Hugging Face; instead, keep it in your Colab session and use/test it there directly.
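
The corpus-extraction step can be sketched without a spreadsheet library, assuming each row has already been read into a dict with `name` and `description` keys. The helper `build_corpus` and the sample rows are hypothetical and only illustrate one way to combine the two fields; the notebook itself reads the files with `pd.read_excel`.

```python
def build_corpus(rows):
    """Join each row's name and description into one training string.

    Rows whose description is missing or empty are skipped.
    """
    corpus = []
    for row in rows:
        name = (row.get("name") or "").strip()
        description = (row.get("description") or "").strip()
        if not description:
            continue
        corpus.append(f"{name}: {description}" if name else description)
    return corpus


# Hypothetical rows standing in for the parsed spreadsheet.
sample_rows = [
    {"name": "Collection", "description": "placeholder description text"},
    {"name": "Empty Example", "description": None},
]
corpus = build_corpus(sample_rows)
print(corpus)
```

Prefixing the name is a design choice: it keeps the tactic or technique title in the same context window as its definition.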

In [1]:
!wget https://attack.mitre.org/docs/enterprise-attack-v15.1/enterprise-attack-v15.1-tactics.xlsx
!wget https://attack.mitre.org/docs/enterprise-attack-v15.1/enterprise-attack-v15.1-techniques.xlsx

--2024-06-03 11:27:56--  https://attack.mitre.org/docs/enterprise-attack-v15.1/enterprise-attack-v15.1-tactics.xlsx
Resolving attack.mitre.org (attack.mitre.org)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to attack.mitre.org (attack.mitre.org)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10109 (9.9K) [application/vnd.openxmlformats-officedocument.spreadsheetml.sheet]
Saving to: ‘enterprise-attack-v15.1-tactics.xlsx’


2024-06-03 11:27:56 (84.0 MB/s) - ‘enterprise-attack-v15.1-tactics.xlsx’ saved [10109/10109]

--2024-06-03 11:27:56--  https://attack.mitre.org/docs/enterprise-attack-v15.1/enterprise-attack-v15.1-techniques.xlsx
Resolving attack.mitre.org (attack.mitre.org)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to attack.mitre.org (attack.mitre.org)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2615585 (2.5M) [application/vnd.openxmlformats-off

## Corpus

In [2]:
!pip install datasets
from datasets import Dataset, DatasetDict
import pandas as pd

Collecting datasets
  Downloading datasets-2.19.2-py3-none-any.whl (542 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Collecting requests>=2.32.1 (from datasets)
  Downloading requests-2.32.3-py3-none-any.whl (64 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)

In [3]:
tactics_df = pd.read_excel('enterprise-attack-v15.1-tactics.xlsx')
techniques_df = pd.read_excel('enterprise-attack-v15.1-techniques.xlsx')

In [4]:
tactics_df

Unnamed: 0,ID,STIX ID,name,description,url,created,last modified,domain,version
0,TA0009,x-mitre-tactic--d108ce10-2419-4cf9-a774-46161d...,Collection,The adversary is trying to gather data of inte...,https://attack.mitre.org/tactics/TA0009,17 October 2018,19 July 2019,enterprise-attack,1.0
1,TA0011,x-mitre-tactic--f72804c5-f15a-449e-a5da-2eecd1...,Command and Control,The adversary is trying to communicate with co...,https://attack.mitre.org/tactics/TA0011,17 October 2018,19 July 2019,enterprise-attack,1.0
2,TA0006,x-mitre-tactic--2558fd61-8c75-4730-94c4-11926d...,Credential Access,The adversary is trying to steal account names...,https://attack.mitre.org/tactics/TA0006,17 October 2018,19 July 2019,enterprise-attack,1.0
3,TA0005,x-mitre-tactic--78b23412-0651-46d7-a540-170a1c...,Defense Evasion,The adversary is trying to avoid being detecte...,https://attack.mitre.org/tactics/TA0005,17 October 2018,19 July 2019,enterprise-attack,1.0
4,TA0007,x-mitre-tactic--c17c5845-175e-4421-9713-829d05...,Discovery,The adversary is trying to figure out your env...,https://attack.mitre.org/tactics/TA0007,17 October 2018,19 July 2019,enterprise-attack,1.0
5,TA0002,x-mitre-tactic--4ca45d45-df4d-4613-8980-bac22d...,Execution,The adversary is trying to run malicious code....,https://attack.mitre.org/tactics/TA0002,17 October 2018,19 July 2019,enterprise-attack,1.0
6,TA0010,x-mitre-tactic--9a4e74ab-5008-408c-84bf-a10dfb...,Exfiltration,The adversary is trying to steal data.\n\nExfi...,https://attack.mitre.org/tactics/TA0010,17 October 2018,19 July 2019,enterprise-attack,1.0
7,TA0040,x-mitre-tactic--5569339b-94c2-49ee-afb3-222293...,Impact,"The adversary is trying to manipulate, interru...",https://attack.mitre.org/tactics/TA0040,14 March 2019,25 July 2019,enterprise-attack,1.0
8,TA0001,x-mitre-tactic--ffd5bcee-6e16-4dd2-8eca-7b3bee...,Initial Access,The adversary is trying to get into your netwo...,https://attack.mitre.org/tactics/TA0001,17 October 2018,19 July 2019,enterprise-attack,1.0
9,TA0008,x-mitre-tactic--7141578b-e50b-4dcc-bfa4-08a8dd...,Lateral Movement,The adversary is trying to move through your e...,https://attack.mitre.org/tactics/TA0008,17 October 2018,19 July 2019,enterprise-attack,1.0


In [5]:
techniques_df

Unnamed: 0,ID,STIX ID,name,description,url,created,last modified,domain,version,tactics,...,is sub-technique,sub-technique of,defenses bypassed,contributors,permissions required,supports remote,system requirements,impact type,effective permissions,relationship citations
0,T1548,attack-pattern--67720091-eee3-4d2d-ae16-826456...,Abuse Elevation Control Mechanism,Adversaries may circumvent mechanisms designed...,https://attack.mitre.org/techniques/T1548,30 January 2020,15 April 2024,enterprise-attack,1.3,"Defense Evasion, Privilege Escalation",...,False,,,,"Administrator, User",,,,,",(Citation: Github UACMe)"
1,T1548.002,attack-pattern--120d5519-3098-4e1c-9191-2aa612...,Abuse Elevation Control Mechanism: Bypass User...,Adversaries may bypass UAC mechanisms to eleva...,https://attack.mitre.org/techniques/T1548/002,30 January 2020,21 April 2023,enterprise-attack,2.1,"Defense Evasion, Privilege Escalation",...,True,T1548,Windows User Account Control,Casey Smith; Stefan Kanthak,"Administrator, User",,,,Administrator,"(Citation: Mandiant No Easy Breach),(Citation:..."
2,T1548.004,attack-pattern--b84903f0-c7d5-435d-a69e-de47cc...,Abuse Elevation Control Mechanism: Elevated Ex...,Adversaries may leverage the <code>Authorizati...,https://attack.mitre.org/techniques/T1548/004,30 January 2020,19 October 2022,enterprise-attack,1.0,"Defense Evasion, Privilege Escalation",...,True,T1548,,"Erika Noerenberg, @gutterchurl, Carbon Black; ...","Administrator, User",,,,root,"(Citation: Carbon Black Shlayer Feb 2019),"
3,T1548.001,attack-pattern--6831414d-bb70-42b7-8030-d4e06b...,Abuse Elevation Control Mechanism: Setuid and ...,An adversary may abuse configurations where an...,https://attack.mitre.org/techniques/T1548/001,30 January 2020,15 March 2023,enterprise-attack,1.1,"Defense Evasion, Privilege Escalation",...,True,T1548,,,User,,,,,"(Citation: OSX Keydnap malware),(Citation: ANS..."
4,T1548.003,attack-pattern--1365fe3b-0f50-455d-b4da-266ce3...,Abuse Elevation Control Mechanism: Sudo and Su...,Adversaries may perform sudo caching and/or us...,https://attack.mitre.org/techniques/T1548/003,30 January 2020,14 March 2022,enterprise-attack,1.0,"Defense Evasion, Privilege Escalation",...,True,T1548,,,User,,,,root,(Citation: Cobalt Strike Manual 4.3 November 2...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
632,T1102.002,attack-pattern--be055942-6e63-49d7-9fa1-9cb7d8...,Web Service: Bidirectional Communication,"Adversaries may use an existing, legitimate ex...",https://attack.mitre.org/techniques/T1102/002,14 March 2020,26 March 2020,enterprise-attack,1.0,Command and Control,...,True,T1102,,,User,,,,,(Citation: Trend Micro DRBControl February 202...
633,T1102.001,attack-pattern--f7827069-0bf2-4764-af4f-23fae0...,Web Service: Dead Drop Resolver,"Adversaries may use an existing, legitimate ex...",https://attack.mitre.org/techniques/T1102/001,14 March 2020,26 March 2020,enterprise-attack,1.0,Command and Control,...,True,T1102,,,User,,,,,(Citation: Securelist Brazilian Banking Malwar...
634,T1102.003,attack-pattern--9c99724c-a483-4d60-ad9d-7f004e...,Web Service: One-Way Communication,"Adversaries may use an existing, legitimate ex...",https://attack.mitre.org/techniques/T1102/003,14 March 2020,26 March 2020,enterprise-attack,1.0,Command and Control,...,True,T1102,,,User,,,,,"(Citation: Fortinet Metamorfo Feb 2020),(Citat..."
635,T1047,attack-pattern--01a5a209-b94c-450b-b7f9-946497...,Windows Management Instrumentation,Adversaries may abuse Windows Management Instr...,https://attack.mitre.org/techniques/T1047,31 May 2017,11 April 2024,enterprise-attack,1.5,Execution,...,False,,,"@ionstorm; Olaf Hartong, Falcon Force; Tristan...",,1.0,,,,(Citation: Crowdstrike TELCO BPO Campaign Dece...


## Now on your own

Write your code here. There should be a lot of it.

In [6]:
from transformers import TFAutoModelForMaskedLM

model_checkpoint = "distilbert-base-uncased"
model = TFAutoModelForMaskedLM.from_pretrained(model_checkpoint)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

All PyTorch model weights were used when initializing TFDistilBertForMaskedLM.

All the weights of TFDistilBertForMaskedLM were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForMaskedLM for predictions without further training.


In [7]:
model.summary()

Model: "tf_distil_bert_for_masked_lm"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 distilbert (TFDistilBertMa  multiple                  66362880  
 inLayer)                                                        
                                                                 
 vocab_transform (Dense)     multiple                  590592    
                                                                 
 vocab_layer_norm (LayerNor  multiple                  1536      
 malization)                                                     
                                                                 
 vocab_projector (TFDistilB  multiple                  23866170  
 ertLMHead)                                                      
                                                                 
Total params: 66985530 (255.53 MB)
Trainable params: 66985530 (255.53 MB)
Non-trainable params: 0 (0.00 

In [8]:
text = "This is a great [MASK]."

In [9]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [10]:
import numpy as np
import tensorflow as tf

inputs = tokenizer(text, return_tensors="np")
token_logits = model(**inputs).logits
# Find the location of [MASK] and extract its logits
mask_token_index = np.argwhere(inputs["input_ids"] == tokenizer.mask_token_id)[0, 1]
mask_token_logits = token_logits[0, mask_token_index, :]
# Pick the [MASK] candidates with the highest logits
# We negate the array before argsort to get the largest, not the smallest, logits
top_5_tokens = np.argsort(-mask_token_logits)[:5].tolist()

for token in top_5_tokens:
    print(f">>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")

>>> This is a great deal.
>>> This is a great success.
>>> This is a great adventure.
>>> This is a great idea.
>>> This is a great feat.
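
The ranking step in the cell above (negate, then `argsort`) can be sketched in plain Python. This is a minimal illustration with made-up logits; `top_k_with_probs` is a hypothetical helper, and the softmax is added only to show how logits relate to probabilities.

```python
import math


def top_k_with_probs(logits, k):
    """Return (index, probability) pairs for the k largest logits."""
    # Softmax with max-subtraction for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort indices by descending logit, mirroring np.argsort(-logits)[:k].
    order = sorted(range(len(logits)), key=lambda i: -logits[i])[:k]
    return [(i, probs[i]) for i in order]


toy_logits = [0.5, 2.0, -1.0, 3.0]  # made-up values, not model output
ranked = top_k_with_probs(toy_logits, 2)
print(ranked)
```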


In [11]:
# Merge the two DataFrames, keeping only the name and description columns
merged_df = pd.concat([tactics_df[['name', 'description']], techniques_df[['name', 'description']]], ignore_index=True)
merged_df = merged_df.rename(columns={'name': 'label', 'description': 'text'})
# Reorder the columns so that text comes before label
merged_df = merged_df[['text', 'label']]

# Split the data into train and test halves
total_rows = len(merged_df)
train_size = total_rows // 2
test_size = total_rows // 2

# Shuffle the rows
shuffled_df = merged_df.sample(frac=1, random_state=0).reset_index(drop=True)

train_dataset = Dataset.from_pandas(shuffled_df.iloc[:train_size])
test_dataset = Dataset.from_pandas(shuffled_df.iloc[train_size:train_size + test_size])

# DatasetDict
dataset_dict = DatasetDict({
    "train": train_dataset,
    "test": test_dataset,
})

dataset_dict

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 325
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 325
    })
})
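
The split above yields 325 rows on each side. Because both sizes are computed as `total_rows // 2`, an odd row count would leave one row out of both splits. A small sketch of a variant that assigns every row, using a hypothetical helper `half_split_sizes`:

```python
def half_split_sizes(total_rows):
    """Return (train_size, test_size) that together cover every row."""
    train_size = total_rows // 2
    test_size = total_rows - train_size  # the extra row, if any, goes to test
    return train_size, test_size


print(half_split_sizes(650))
print(half_split_sizes(651))
```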

In [12]:
sample = dataset_dict["train"].shuffle(seed=42).select(range(3))

for row in sample:
    print(f"\n'>>> Review: {row['text']}'")
    print(f"'>>> Label: {row['label']}'")


'>>> Review: Adversaries may manipulate products or product delivery mechanisms prior to receipt by a final consumer for the purpose of data or system compromise.

Supply chain compromise can take place at any stage of the supply chain including:

* Manipulation of development tools
* Manipulation of a development environment
* Manipulation of source code repositories (public or private)
* Manipulation of source code in open-source dependencies
* Manipulation of software update/distribution mechanisms
* Compromised/infected system images (multiple cases of removable media infected at the factory)(Citation: IBM Storwize)(Citation: Schneider Electric USB Malware) 
* Replacement of legitimate software with modified versions
* Sales of modified/counterfeit products to legitimate distributors
* Shipment interdiction

While supply chain compromise can impact any component of hardware or software, adversaries looking to gain execution have often focused on malicious additions to legitimate s

In [13]:
def tokenize_function(examples):
    result = tokenizer(examples["text"])
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result


# Use batched=True to activate fast multithreading!
tokenized_datasets = dataset_dict.map(
    tokenize_function, batched=True, remove_columns=["text", "label"]
)
tokenized_datasets

Map:   0%|          | 0/325 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (870 > 512). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/325 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids'],
        num_rows: 325
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids'],
        num_rows: 325
    })
})

In [14]:
chunk_size = 128

In [15]:
# Slicing produces a list of lists for each feature
tokenized_samples = tokenized_datasets["train"][:3]

for idx, sample in enumerate(tokenized_samples["input_ids"]):
    print(f"'>>> Review {idx} length: {len(sample)}'")

'>>> Review 0 length: 300'
'>>> Review 1 length: 238'
'>>> Review 2 length: 284'


In [16]:
concatenated_examples = {
    k: sum(tokenized_samples[k], []) for k in tokenized_samples.keys()
}
total_length = len(concatenated_examples["input_ids"])
print(f"'>>> Concatenated reviews length: {total_length}'")

'>>> Concatenated reviews length: 822'


In [17]:
chunks = {
    k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
    for k, t in concatenated_examples.items()
}

for chunk in chunks["input_ids"]:
    print(f"'>>> Chunk length: {len(chunk)}'")

'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 54'


In [18]:
def group_texts(examples):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute length of concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than chunk_size
    total_length = (total_length // chunk_size) * chunk_size
    # Split by chunks of max_len
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result
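
The concatenate-then-chunk logic of `group_texts` can be checked on toy data. The sketch below uses dummy token ids with the three lengths printed earlier (300, 238, 284); `group_into_chunks` is a hypothetical stand-alone version that handles a single feature rather than a dict of features.

```python
def group_into_chunks(sequences, chunk_size):
    """Concatenate token lists and cut them into equal chunks.

    The trailing remainder shorter than chunk_size is dropped.
    """
    flat = [tok for seq in sequences for tok in seq]
    usable = (len(flat) // chunk_size) * chunk_size
    return [flat[i:i + chunk_size] for i in range(0, usable, chunk_size)]


# Dummy token ids; only the lengths matter here.
toy_sequences = [[0] * 300, [1] * 238, [2] * 284]
toy_chunks = group_into_chunks(toy_sequences, 128)
print(len(toy_chunks), {len(c) for c in toy_chunks})
```

With 822 tokens in total, six full chunks of 128 survive and the last 54 tokens are discarded, which matches the chunk lengths printed in the earlier cell minus the short final chunk.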

In [19]:
lm_datasets = tokenized_datasets.map(group_texts, batched=True)
lm_datasets

Map:   0%|          | 0/325 [00:00<?, ? examples/s]

Map:   0%|          | 0/325 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 811
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 788
    })
})

In [20]:
tokenizer.decode(lm_datasets["train"][1]["input_ids"])

'opportunities for other forms of reconnaissance ( ex : [ phishing for information ] ( https : / / attack. mitre. org / techniques / t1598 ) or [ search open technical databases ] ( https : / / attack. mitre. org / techniques / t1596 ) ), establishing operational resources ( ex : [ establish accounts ] ( https : / / attack. mitre. org / techniques / t1585 ) or [ compromise accounts ] ( https : / / attack. mitre. org / techniques / t1586 ) ), and / or initial access ( ex : [ trusted relationship ] ('

In [21]:
tokenizer.decode(lm_datasets["train"][1]["labels"])

'opportunities for other forms of reconnaissance ( ex : [ phishing for information ] ( https : / / attack. mitre. org / techniques / t1598 ) or [ search open technical databases ] ( https : / / attack. mitre. org / techniques / t1596 ) ), establishing operational resources ( ex : [ establish accounts ] ( https : / / attack. mitre. org / techniques / t1585 ) or [ compromise accounts ] ( https : / / attack. mitre. org / techniques / t1586 ) ), and / or initial access ( ex : [ trusted relationship ] ('

In [22]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

In [23]:
samples = [lm_datasets["train"][i] for i in range(2)]
for sample in samples:
    _ = sample.pop("word_ids")

for chunk in data_collator(samples)["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")


'>>> [CLS] adversaries may search websites owned by the victim for information that can be used [MASK] targeting. victim - owned websites may contain a variety of details, including names of [MASK] / divisions, physical locations, and data about [MASK] employees [MASK] as [MASK], roles, and contact info ( ex : [ email addresses [MASK] ( https : / / attack. mitre. org [MASK] techniques / t15 [MASK]9 / 002 ) ) [MASK] these sites may also have details highlighting [MASK] operations and [MASK]. [MASK] citation : [MASK]paritech leak ) adversaries may search victim - [MASK] websites to gather actionable information. [MASK] from these sources may reveal'

'>>> opportunities for [unused992] forms of reconnaissance ( ex : [ phishing for information ] ( https : / / attack. mitre. org / techniques / gail1598 [MASK] or [ search open [MASK] databases [MASK] ( [MASK] : / / attack. [MASK]re. org ゆ techniques [MASK] t15 [MASK]6 [MASK] ), establishing operational resources ( [MASK] : [ establish accou
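
The stray tokens in the output above (for example `[unused992]`) are consistent with BERT-style masking, where a selected position is usually replaced by `[MASK]` but is sometimes replaced by a random token or left unchanged. The sketch below is a simplified stand-alone version of that 80/10/10 rule, not the library's implementation: `mask_tokens` is a hypothetical helper, the `mask_id` and `vocab_size` values are illustrative, and special tokens are not excluded.

```python
import random


def mask_tokens(token_ids, mask_id, vocab_size, mlm_probability=0.15, rng=None):
    """Apply BERT-style random masking to a list of token ids.

    Returns (masked_inputs, labels). Labels are -100 at positions that
    were not selected, so the loss ignores them.
    """
    if rng is None:
        rng = random.Random()
    inputs = list(token_ids)
    labels = [-100] * len(token_ids)
    for i, tok in enumerate(token_ids):
        if rng.random() >= mlm_probability:
            continue  # position not selected for prediction
        labels[i] = tok
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = mask_id  # 80%: replace with the mask token
        elif roll < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels


rng = random.Random(0)
toy_ids = list(range(1000, 1200))  # illustrative ids
masked, labels = mask_tokens(toy_ids, mask_id=103, vocab_size=30522, rng=rng)
print(sum(lab != -100 for lab in labels), "of", len(toy_ids), "positions selected")
```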

In [24]:
import collections
import numpy as np

from transformers import default_data_collator

wwm_probability = 0.2


def whole_word_masking_data_collator(features):
    for feature in features:
        word_ids = feature.pop("word_ids")

        # Create a map between words and corresponding token indices
        mapping = collections.defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)

        # Randomly mask words
        mask = np.random.binomial(1, wwm_probability, (len(mapping),))
        input_ids = feature["input_ids"]
        labels = feature["labels"]
        new_labels = [-100] * len(labels)
        for word_id in np.where(mask)[0]:
            word_id = word_id.item()
            for idx in mapping[word_id]:
                new_labels[idx] = labels[idx]
                input_ids[idx] = tokenizer.mask_token_id
        feature["labels"] = new_labels

    return default_data_collator(features)
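
The word-to-token mapping built inside the collator can be examined on its own. In this sketch, `words_to_token_indices` is a hypothetical extraction of that loop, and the `word_ids` list is made up: `None` marks special tokens, and a repeated id marks sub-word pieces of the same word.

```python
import collections


def words_to_token_indices(word_ids):
    """Map each word (numbered in order of appearance) to its token positions."""
    mapping = collections.defaultdict(list)
    current_word = None
    current_word_index = -1
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            continue  # special tokens belong to no word
        if word_id != current_word:
            current_word = word_id
            current_word_index += 1
        mapping[current_word_index].append(idx)
    return dict(mapping)


toy_word_ids = [None, 0, 1, 1, 1, 2, None]
toy_mapping = words_to_token_indices(toy_word_ids)
print(toy_mapping)  # {0: [1], 1: [2, 3, 4], 2: [5]}
```

Masking is then decided once per key of this mapping, so all pieces of a word are masked together.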

In [25]:
samples = [lm_datasets["train"][i] for i in range(2)]
batch = whole_word_masking_data_collator(samples)

for chunk in batch["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")


'>>> [CLS] [MASK] [MASK] [MASK] may [MASK] websites [MASK] [MASK] the victim [MASK] information [MASK] can be [MASK] during targeting. victim - owned [MASK] may [MASK] a variety [MASK] details [MASK] including names of departments / divisions, physical locations, and data about key [MASK] such as names [MASK] roles [MASK] and contact [MASK] ( ex : [ email [MASK] ] ( https [MASK] / / attack. mitre. [MASK] / techniques / t1589 [MASK] 002 ) ). these [MASK] may also have details highlighting business operations and relationships. [MASK] citation [MASK] comparitech leak ) adversaries may [MASK] victim - owned websites to gather actionable information. information from these sources may [MASK]'

'>>> opportunities for other forms [MASK] reconnaissance ( ex : [MASK] phishing for information ] ( https : / / attack. mitre [MASK] [MASK] / techniques [MASK] t1598 [MASK] or [ search open technical databases ] [MASK] [MASK] : / / attack [MASK] mitre. org / [MASK] [MASK] t1596 ) ), establishing ope

In [26]:
train_size = 600
test_size = int(0.1 * train_size)

downsampled_dataset = lm_datasets["train"].train_test_split(
    train_size=train_size, test_size=test_size, seed=42
)
downsampled_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 600
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 60
    })
})

In [27]:
# from huggingface_hub import notebook_login

# notebook_login()

In [28]:
tf_train_dataset = model.prepare_tf_dataset(
    downsampled_dataset["train"],
    collate_fn=data_collator,
    shuffle=True,
    batch_size=32,
)

tf_eval_dataset = model.prepare_tf_dataset(
    downsampled_dataset["test"],
    collate_fn=data_collator,
    shuffle=False,
    batch_size=32,
)

In [29]:
from transformers import create_optimizer
from transformers.keras_callbacks import PushToHubCallback
import tensorflow as tf

num_train_steps = len(tf_train_dataset)
optimizer, schedule = create_optimizer(
    init_lr=2e-5,
    num_warmup_steps=1_000,
    num_train_steps=num_train_steps,
    weight_decay_rate=0.01,
)
model.compile(optimizer=optimizer)

# Train in mixed-precision float16
# tf.keras.mixed_precision.set_global_policy("mixed_float16")

# model_name = model_checkpoint.split("/")[-1]
# callback = PushToHubCallback(
#     output_dir=f"{model_name}-finetuned-cyber", tokenizer=tokenizer
# )
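
In the cell above, `num_train_steps` is the number of batches in one pass over the training data, while `num_warmup_steps` is 1,000, so the warm-up is longer than the step count handed to the schedule. One common alternative is to derive both values from the planned number of epochs. The sketch below assumes a 10% warm-up fraction, and `schedule_steps` is a hypothetical helper. Whether the last partial batch is dropped depends on how the dataset is built, so it is left as a parameter.

```python
import math


def schedule_steps(num_examples, batch_size, epochs,
                   warmup_fraction=0.1, drop_remainder=True):
    """Return (steps_per_epoch, total_steps, warmup_steps)."""
    if drop_remainder:
        steps_per_epoch = num_examples // batch_size
    else:
        steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * epochs
    warmup_steps = int(warmup_fraction * total_steps)
    return steps_per_epoch, total_steps, warmup_steps


# 600 training chunks and batch size 32 as in the notebook; 42 epochs is illustrative.
print(schedule_steps(600, 32, 42))
```

The resulting `total_steps` and `warmup_steps` could then be passed as `num_train_steps` and `num_warmup_steps`.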

## Perplexity

Show the perplexity of the newly trained model.

In [30]:
import math

eval_loss = model.evaluate(tf_eval_dataset)
print(f"Perplexity: {math.exp(eval_loss):.2f}")

Cause: for/else statement not yet supported


Cause: for/else statement not yet supported
Perplexity: 21.65


In [31]:
from tensorflow.keras.callbacks import EarlyStopping

# early stopping callback
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True
)

# Train the model
model.fit(
    tf_train_dataset,
    validation_data=tf_eval_dataset,
    epochs=200,
    callbacks=[early_stopping_callback]
)

Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200


<tf_keras.src.callbacks.History at 0x7d4a0f34c880>

In [32]:
eval_loss = model.evaluate(tf_eval_dataset)
print(f"Perplexity: {math.exp(eval_loss):.2f}")

Perplexity: 6.88
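
Perplexity here is the exponential of the mean cross-entropy loss on the masked positions, so the two reported values can be converted back to losses and compared. A small sketch using only the numbers printed in this notebook. Note that the evaluation set is masked randomly each time it is read, so the exact figures may vary between runs.

```python
import math


def perplexity(mean_cross_entropy):
    """Perplexity is exp of the mean cross-entropy (natural log base)."""
    return math.exp(mean_cross_entropy)


before, after = 21.65, 6.88  # values printed before and after fine-tuning
loss_before = math.log(before)
loss_after = math.log(after)
relative_drop = (before - after) / before
print(f"loss before: {loss_before:.3f}, loss after: {loss_after:.3f}")
print(f"relative perplexity reduction: {relative_drop:.1%}")
```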


In [33]:
model.save_pretrained('./fine_tuned_model')
tokenizer.save_pretrained('./fine_tuned_model')

('./fine_tuned_model/tokenizer_config.json',
 './fine_tuned_model/special_tokens_map.json',
 './fine_tuned_model/vocab.txt',
 './fine_tuned_model/added_tokens.json',
 './fine_tuned_model/tokenizer.json')

## Downstream Task Test

* You should now have two models: the original one downloaded from Hugging Face and the fine-tuned one.

* Let's try a downstream task to see whether the classification rate changes after your fine-tuned model has learned some additional cybersecurity knowledge.

* In the 'Fine-tuning a masked language model' example, the 'Using our fine-tuned model' section tests the new model with a "fill-mask" pipeline.

* "Transformers, what can they do?" (https://huggingface.co/learn/nlp-course/en/chapter1/3) lists several pipelines. Let's try 'Zero-shot classification'.

* Please prepare several sentences (> 100) from the website (not from the downloaded xlsx files) as your test examples.

* Feed these sentences into the original model and your fine-tuned model, and ask each which 'tactics' and 'techniques' the sentence belongs to.

* Show whether the classification rate for 'tactics' and 'techniques' increases (or not) when the fine-tuned model is used.

* Show some examples whose 'tactics' or 'techniques' label actually changes when the new model is used.

In [34]:
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

In [35]:
import io
import pandas as pd
from google.colab import files

In [36]:
uploaded = files.upload()

Saving H_LM_attack_data.xlsx to H_LM_attack_data.xlsx


In [37]:
data_df = pd.read_excel(io.BytesIO(uploaded.get('H_LM_attack_data.xlsx')))

In [38]:
data_df

Unnamed: 0,name,description
0,Active Scanning,Monitor and analyze traffic patterns and packe...
1,Active Scanning,Consider correlation with process monitoring a...
2,Active Scanning,Monitor network data for uncommon data flows. ...
3,Gather Victim Host Information,Internet scanners may be used to look for patt...
4,Gather Victim Host Information,Much of this activity may have a very high occ...
...,...,...
98,Direct Volume Access,Monitor handle opens on volumes that are made ...
99,Direct Volume Access,Monitor for the creation of volume shadow copy...
100,Deobfuscate/Decode Files or Information,Monitor for changes made to files for unexpect...
101,Deobfuscate/Decode Files or Information,Monitor for newly executed processes that atte...


In [39]:
# get the candidate labels, descriptions, and the actual labels
candidate_labels = data_df['name'].unique().tolist()
descriptions = data_df['description'].tolist()
actual_labels = data_df['name'].tolist()

In [40]:
# Load the original model
original_model_name = 'distilbert-base-uncased'
original_tokenizer = AutoTokenizer.from_pretrained(original_model_name)
original_model = AutoModelForSequenceClassification.from_pretrained(original_model_name)
original_classifier = pipeline("zero-shot-classification", model=original_model, tokenizer=original_tokenizer)

# Load the fine-tuned model from the local directory
fine_tuned_model_path = './fine_tuned_model'  # Path to the directory where your fine-tuned model is saved
fine_tuned_tokenizer = AutoTokenizer.from_pretrained(fine_tuned_model_path)
fine_tuned_model = AutoModelForSequenceClassification.from_pretrained(fine_tuned_model_path, from_tf=True)
fine_tuned_classifier = pipeline("zero-shot-classification", model=fine_tuned_model, tokenizer=fine_tuned_tokenizer)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to -1. Define a descriptive label2id mapping in the model config to ensure correct outputs.
All TF 2.0 model weights were used when initializing DistilBertForSequenceClassification.

All the weights of DistilBertForSequenceClassification were initialized from the TF 2.0 model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use DistilBertForSequenceClassification for predictions without further training.
Failed to determine 'entailment' label id from the label2id mapping in the model config. Setti

In [41]:
def classify_descriptions(classifier, descriptions, candidate_labels):
    # Keep only the top-ranked label for each description
    results = []
    for description in descriptions:
        classification = classifier(description, candidate_labels=candidate_labels)
        results.append(classification['labels'][0])
    return results

In [42]:
# Calculate accuracy
def calculate_accuracy(predictions, actual_labels):
    correct = sum(p == a for p, a in zip(predictions, actual_labels))
    return correct / len(actual_labels)
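
When reading the accuracy figures, it helps to compare them with a chance baseline. The warnings printed when the classifiers were loaded report that the classification head is newly initialized and that no 'entailment' label could be found, so scores near chance would not be surprising. The sketch below uses made-up labels; `accuracy` and `chance_accuracy` are hypothetical helpers, and the baseline assumes uniform guessing over candidate labels that include every true label.

```python
def accuracy(predictions, actual_labels):
    """Fraction of predictions that match the actual labels."""
    if not actual_labels:
        raise ValueError("actual_labels must not be empty")
    if len(predictions) != len(actual_labels):
        raise ValueError("predictions and actual_labels differ in length")
    correct = sum(p == a for p, a in zip(predictions, actual_labels))
    return correct / len(actual_labels)


def chance_accuracy(candidate_labels):
    """Expected accuracy of guessing uniformly among the candidate labels."""
    return 1 / len(candidate_labels)


# Made-up labels for illustration only.
toy_actual = ["A", "A", "B", "C"]
toy_predictions = ["A", "B", "B", "A"]
toy_candidates = sorted(set(toy_actual))
print(accuracy(toy_predictions, toy_actual), chance_accuracy(toy_candidates))
```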

In [43]:
# distilbert-base-uncased predictions
original_predictions = classify_descriptions(original_classifier, descriptions, candidate_labels)

# my model predictions
fine_tuned_predictions = classify_descriptions(fine_tuned_classifier, descriptions, candidate_labels)

original_accuracy = calculate_accuracy(original_predictions, actual_labels)
fine_tuned_accuracy = calculate_accuracy(fine_tuned_predictions, actual_labels)

In [44]:
# results
print(f"Original Model Accuracy: {original_accuracy * 100:.2f}%")
print(f"Fine-Tuned Model Accuracy: {fine_tuned_accuracy * 100:.2f}%")

Original Model Accuracy: 0.97%
Fine-Tuned Model Accuracy: 3.88%


In [46]:
# List the examples whose predicted label changed
changes = []
for description, orig, fine, actual in zip(descriptions, original_predictions, fine_tuned_predictions, actual_labels):
    if orig != fine:
        changes.append((description, orig, fine, actual))

print("\nExamples where classification changed:")
for change in changes:
    print(f"Description: {change[0]}\nOriginal: {change[1]}\nFine-tuned: {change[2]}\nActual label: {change[3]}\n")


Examples where classification changed:
Description: Monitor and analyze traffic patterns and packet inspection associated to protocol(s) that do not follow the expected protocol standards and traffic flows (e.g extraneous packets that do not belong to established flows, gratuitous or anomalous traffic patterns, anomalous syntax, or structure).
Original: Content Injection
Fine-tuned: Container Administration Command
Actual label: Active Scanning

Description: Consider correlation with process monitoring and command line to detect anomalous processes execution and command line arguments associated to traffic patterns (e.g. monitor anomalies in use of files that do not normally initiate connections for respective protocol(s)).
Original: Protocol Tunneling
Fine-tuned: Domain or Tenant Policy Modification
Actual label: Active Scanning

Description: Monitor network data for uncommon data flows. Processes utilizing the network that do not normally have network communication or have never bee