# Finetuning by adding Classification Head

**This use case is an extension of [Classification_cybersecurity_descriptions](https://github.com/RobustIntelligence/foundation-ai-cookbook/blob/main/2_examples/Classification_cybersecurity_descriptions.ipynb), shown in 2_examples.**

For this demo, we use CTI-HAL, a human-annotated dataset for cyber threat intelligence analysis, to predict which MITRE ATT&CK technique ID each excerpt from security blogs, reports, and similar sources is associated with. **The dataset was NOT used to train the Foundation-Sec-8B model.**

For details on the dataset, see:
- Paper: https://arxiv.org/abs/2504.05866
- GitHub: https://github.com/dessertlab/CTI-HAL

We'll finetune Foundation-Sec-8B as well as the original Llama model to show how finetuning works and how Foundation-Sec-8B outperforms the original model.

### Hardware
This finetuning was run on 8x NVIDIA A100 (80GB) GPUs. It can be done with a single GPU, but it will be slower. If you don't have enough GPU memory, consider enabling QLoRA, which reduces memory usage at the cost of a small performance degradation.
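
As a rough guide to the memory trade-off mentioned above, the sketch below estimates the memory taken by the model weights alone at different precisions. It is back-of-the-envelope arithmetic under stated assumptions: the parameter count is the "all params" figure minus the LoRA parameters printed later in this notebook for Llama-3.1-8B, and the estimate ignores activations, gradients, optimizer state, and quantization overhead.

```python
# Rough weight-only memory estimate for an ~8B-parameter model.
# Assumption: 8,033,669,120 total params minus 3,407,872 LoRA params (figures printed later).
N_PARAMS = 8_033_669_120 - 3_407_872

# Approximate storage cost per parameter for each format.
BYTES_PER_PARAM = {"fp32": 4.0, "bf16/fp16": 2.0, "nf4 (4-bit)": 0.5}


def weight_memory_gib(n_params: int, bytes_per_param: float) -> float:
    """Memory needed to hold the weights, in GiB."""
    return n_params * bytes_per_param / 1024**3


for name, nbytes in BYTES_PER_PARAM.items():
    print(f"{name:>12}: {weight_memory_gib(N_PARAMS, nbytes):.1f} GiB")
```

Training needs noticeably more than these figures, which is why 4-bit loading can make the difference on smaller GPUs.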

**Caution: The dataset used for finetuning is too small to yield strong results. If you plan to finetune for a real use case, consider using a larger dataset.**

# Setup

In [1]:
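# Set the allocator option in the kernel's own environment, before CUDA is initialized
# (`!export` would only affect a temporary subshell)
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True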

In [2]:
import random
import numpy as np
import torch

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

import warnings
warnings.simplefilter('ignore')

In [3]:
DEVICE = "cuda:0"

print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0)}")

CUDA Available: True
GPU: NVIDIA A100-SXM4-80GB


# Model Download & Test

In [4]:
import os

HF_TOKEN = os.environ.get("HF_TOKEN")
WB_PROJECT_NAME = "finetuning_demo"

LLAMA_MODEL_ID = "meta-llama/Llama-3.1-8B"
FOUNDATION_SEC_8B_MODEL_ID = "fdtn-ai/Foundation-Sec-8B"

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# The same tokenizer is shared by both models and by every step below

tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [6]:
def load_model(model_id):

    # Uncomment below to enable QLoRA instead of standard LoRA
    # bnb_config = BitsAndBytesConfig(
    #     load_in_4bit = True,
    #     bnb_4bit_quant_type = "nf4",
    #     bnb_4bit_compute_dtype = "float16",
    #     bnb_4bit_use_double_quant = True
    # )

    # device_map already places the weights on DEVICE, so no extra .to(DEVICE) is needed
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path = model_id,
        device_map = DEVICE,
        # quantization_config = bnb_config,
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1
    model.generation_config.top_p = None
    model.generation_config.temperature = None
    model.generation_config.pad_token_id = tokenizer.eos_token_id    

    return model

In [7]:
splitter = "technique: "
MAX_LENGTH = 256
MAX_NEW_TOKEN = 5

def inference(prompt, model):
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        _output = model.generate(
            **inputs,
            max_new_tokens = MAX_NEW_TOKEN,
            do_sample = False,
            repetition_penalty = 1.2,
        )
    output = tokenizer.decode(_output[0], skip_special_tokens = True)
    response = output.split(splitter)[-1].strip()
    return response
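
To make the parsing step in `inference` concrete, here is a minimal sketch of what `split(splitter)[-1]` does to a decoded string. The decoded text is a hypothetical example, not real model output, and `extract_response` is an illustrative name that does not appear in the notebook.

```python
splitter = "technique: "


def extract_response(decoded: str) -> str:
    """Keep only the text that follows the last 'technique: ' marker."""
    return decoded.split(splitter)[-1].strip()


# Hypothetical decoded text: the tail of the prompt plus a few generated tokens.
decoded = "context: uses WMI to persist\ntechnique: T1047\n\ncontext"
response = extract_response(decoded)
print(repr(response))
```

Because the prompt ends with `technique: T`, the extracted text starts with the `T` followed by whatever the model generated. It can also contain trailing tokens, so a later step still has to pick the technique ID out of it.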

Let's see how each model behaves on an example.

We give the same few-shot prompt to each model and compare the outputs. <br>
The correct answer is T1047 (Windows Management Instrumentation). The original Llama model answers incorrectly, while Foundation-Sec-8B answers correctly.

In [8]:
prompt = '''
context: This downloader is unique per system and contains a customized backdoor written in Assembler
technique: T1059

context: This malware was capable of stealing significant system and network information
technique: T1082

context: Email phishing credential theft
technique: T1566

context: they are served a ZIP archive containing a malicious LNK file.
technique: T1204

context: download and deploy Trickbot on the user's machine
technique: T1105

context: POSHSPY's use of WMI to both store and persist the backdoor code makes it nearly invisible to anyone not familiar with the intricacies of WMI.
technique: T'''
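
The prompt above follows a fixed few-shot pattern: several `context` / `technique` pairs, then the query context, ending on `technique: T` so that the model only has to complete the ID. A minimal sketch of a helper that assembles such a prompt is shown below; `build_few_shot_prompt` is a hypothetical name introduced for illustration.

```python
def build_few_shot_prompt(examples, query):
    """Build a few-shot prompt.

    examples: list of (context, technique_id) pairs used as demonstrations.
    query: the context the model should classify.
    """
    shots = [f"context: {c}\ntechnique: {t}\n" for c, t in examples]
    # End on 'technique: T' so the model continues with the digits of the ID.
    return "\n" + "\n".join(shots) + f"\ncontext: {query}\ntechnique: T"


demo = build_few_shot_prompt(
    [("Email phishing credential theft", "T1566"),
     ("download and deploy Trickbot on the user's machine", "T1105")],
    "uses WMI to store and persist the backdoor code",
)
print(demo)
```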

In [9]:
llama_model = load_model(LLAMA_MODEL_ID)
llama_output_test = inference(prompt, llama_model)

# To avoid OOM errors, load one model at a time and release models that are no longer in use
import gc

llama_model = None
gc.collect()

print(llama_output_test)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

T1547.001


In [10]:
foundation_sec_8b_model = load_model(FOUNDATION_SEC_8B_MODEL_ID)
foundation_sec_8b_output_test = inference(prompt, foundation_sec_8b_model)

foundation_sec_8b_model = None
gc.collect()

print(foundation_sec_8b_output_test)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

T1047


# Data Preparation

Let's download datasets and pre-process them for evaluation and finetuning.


In [11]:
prompt_template = '''context: {context}
technique: T'''

finetuning_prompt_template = '''context: {context}
reason: {description}
technique: T{label}'''
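
To show what the two templates produce, the sketch below renders both for a single record. The record is a hypothetical example written in the same shape as the dataset columns (`context`, `description`, `label`), not a row taken from CTI-HAL.

```python
prompt_template = '''context: {context}
technique: T'''

finetuning_prompt_template = '''context: {context}
reason: {description}
technique: T{label}'''

# Hypothetical record with the three fields used in this notebook.
record = {
    "context": "uses WMI to store and persist the backdoor code",
    "description": "WMI is used for execution",
    "label": "T1047",
}

train_text = finetuning_prompt_template.format(
    context=record["context"],
    description=record["description"],
    label=record["label"][1:],  # drop the leading 'T'; the template supplies it
)
eval_text = prompt_template.format(context=record["context"])

print(train_text)
print("---")
print(eval_text)
```

Note that the training text contains a `reason:` line, while the evaluation prompt goes straight from `context:` to `technique:`.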

In [12]:
# Download the data from https://github.com/dessertlab/CTI-HAL first
# It is assumed here that CTI-HAL has been downloaded into the current directory
from pathlib import Path

PATH_TO_CTI_HAL = Path("CTI-HAL")

In [13]:
import json
import pandas as pd
import csv
from datasets import Dataset, load_from_disk
import re
from urllib.request import urlopen

def _collect_all_files(path: Path, extension: str):
    files_names = []
    for file_path in path.rglob(f"*.{extension}"):  # rglob searches recursively
        if file_path.is_file():
            files_names.append(file_path)
    return files_names


def get_attack_ids():
    """Get all attack IDs from the MITRE ATT&CK website."""
    URL = "https://attack.mitre.org/techniques/enterprise/"
    page = urlopen(URL)
    html = page.read().decode("utf-8")
    pattern = r'\bT\d{4}\b'
    # De-duplicate the matched IDs and return them in sorted order
    attack_ids = sorted(set(re.findall(pattern, html)))

    EXPECTED_NUM_ATTACK_IDS = 211
    assert len(attack_ids) == EXPECTED_NUM_ATTACK_IDS, f"Expected {EXPECTED_NUM_ATTACK_IDS} attack IDs (as of End of April 2025), but got {len(attack_ids)}"

    return attack_ids


def make_datasets(path_to_datasets: Path, csv_path: str):
    files_names = _collect_all_files(path_to_datasets, extension = "json")
    rows = []
    for json_file_name in files_names:
        with open(json_file_name, "r") as f:
            json_data = json.load(f)        
            for item in json_data:
                try:
                    context = item["context"]
                    description = item["metadata"]["description"]
                    label = item["technique"]
                    if context == "" or label is None:
                        continue
                    rows.append([context, description, label])
                except KeyError:
                    print(f"KeyError in file {json_file_name}: {item}")
                    continue

    header = ['context', 'description', 'label']
    with open(csv_path, 'wt') as f:
        csv_writer = csv.writer(f)
        csv_writer.writerow(header)
        for i, row in enumerate(rows):
            try:
                csv_writer.writerow(row)
            except Exception:
                print(f"Error writing row {i}, skipping")
                continue

    df = pd.read_csv(csv_path)
    df = df.drop_duplicates().set_index("context")
    attack_ids = get_attack_ids()
    df = df[df["label"].isin(attack_ids)]    
    return df

In [14]:
SPLIT_SIZE = 0.2

CSV_NAME = "datasets.csv"
HF_DATASET_NAME = "hf_dataset"

In [15]:
df = make_datasets(PATH_TO_CTI_HAL / "data", csv_path = PATH_TO_CTI_HAL / CSV_NAME)
dataset = Dataset.from_pandas(df).train_test_split(test_size = SPLIT_SIZE, shuffle = True, seed = SEED)
dataset.save_to_disk(PATH_TO_CTI_HAL / HF_DATASET_NAME)

Error writing row 596, skipping
Error writing row 599, skipping


Saving the dataset (0/1 shards):   0%|          | 0/1247 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/312 [00:00<?, ? examples/s]

In [16]:
def load_dataset_and_preprocess():

    def _preprocess_data(examples):

        assert(len(examples["description"]) == len(examples["context"]) == len(examples["label"])), "Length of description, context and label must be the same"
        total_len = len(examples["description"])

        prompts = [
            finetuning_prompt_template.format(
                description = description,
                context = context,
                label = label[1:] # Drop the leading "T"; the template already supplies it
            )
            for description, context, label in zip(
                examples["description"],
                examples["context"],
                examples["label"]
            )
        ]

        return tokenizer(prompts, truncation = True, padding = "max_length", max_length = MAX_LENGTH)

    hf_datasets = load_from_disk(PATH_TO_CTI_HAL /  HF_DATASET_NAME)
    train_data = hf_datasets["train"]
    test_data = hf_datasets["test"]
    tokenized_train = train_data.map(_preprocess_data, batched=True, remove_columns=["context", "description", "label"])
    tokenized_test = test_data.map(_preprocess_data, batched=True, remove_columns=["context", "description", "label"])
    print(f"Train samples: {len(tokenized_train)}, Test samples: {len(tokenized_test)}")

    return tokenized_train, tokenized_test

In [17]:
tokenized_train, tokenized_test = load_dataset_and_preprocess()

Map:   0%|          | 0/1247 [00:00<?, ? examples/s]

Map:   0%|          | 0/312 [00:00<?, ? examples/s]

Train samples: 1247, Test samples: 312


# Evaluation (before finetuning)

Let's see how the models perform before finetuning.

In [18]:
def get_prompts_and_labels():
    hf_dataset = load_from_disk(PATH_TO_CTI_HAL / HF_DATASET_NAME)
    df = hf_dataset["test"].to_pandas()
    prompts = []
    labels = []
    for row in df.iterrows():
        row = row[1]
        context = row['context']
        description = row['description']
        label = row['label']
        prompt = prompt_template.format(description=description, context=context)
        prompts.append(prompt)
        labels.append(label)
    return prompts, labels

In [19]:
import re

def _reg_check(label, pred, idx):
    pattern = r'\bT\d{4}\b'
    matches = re.findall(pattern, pred)
    if matches and matches[idx] == label:
        return True
    return False

def evaluate_pred(prompts, labels, model):
    preds = [inference(prompt, model) for prompt in prompts]
    num_exist = sum(1 for label, pred in zip(labels, preds) if _reg_check(str(label), pred, 0))
    print(f"{num_exist} out of total {len(labels)}")
    return round(num_exist/len(labels), 4)

def eval(model_id):
    model = load_model(model_id)
    prompts, labels = get_prompts_and_labels()
    result = evaluate_pred(prompts, labels, model)
    print(f"Accuracy: {result}")
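
The scoring rule in `_reg_check` takes the first four-digit technique ID found in the model's response and compares it with the label. The sketch below isolates that behavior on a few hypothetical responses; `first_technique_id` is an illustrative helper, not part of the notebook.

```python
import re

ID_PATTERN = r"\bT\d{4}\b"


def first_technique_id(text: str):
    """Return the first 4-digit technique ID in text, or None if there is none."""
    matches = re.findall(ID_PATTERN, text)
    return matches[0] if matches else None


print(first_technique_id("T1047\n\ncontext"))  # T1047
print(first_technique_id("T1547.001"))         # T1547
print(first_technique_id("T10471"))            # None
```

One consequence is that a sub-technique such as `T1547.001` is scored by its parent ID, because the pattern stops at the dot.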

In [20]:
eval(LLAMA_MODEL_ID)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

46 out of total 312
Accuracy: 0.1474


In [21]:
eval(FOUNDATION_SEC_8B_MODEL_ID)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

135 out of total 312
Accuracy: 0.4327


Even before finetuning, Foundation-Sec-8B (43.3% accuracy) already outperforms the original Llama model (14.7%).

# Finetuning as CausalLM & Evaluation

Let's finetune the models with LoRA, keeping each model as a causal language model (CausalLM).

In [None]:
from transformers import TrainingArguments
from peft import LoraConfig, get_peft_model, PeftConfig
from trl import SFTTrainer

OUTPUT_DIR = "./checkpoints"

def train(model_id):

    _output_dir = Path(OUTPUT_DIR) / str(f"{model_id}").replace("/", "_")

    model = load_model(model_id)

    training_args = TrainingArguments(
        output_dir = _output_dir,
        label_names = ["labels"],
        run_name = "finetuning_demo",
        num_train_epochs = 50,
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 8,
        eval_strategy = "no",
        logging_steps = 50,
        learning_rate = 4.e-5,
        weight_decay = 0.001,
        fp16 = False,
        bf16 = False,
        max_grad_norm = 0.3,
        max_steps = -1,
        group_by_length = True,
        lr_scheduler_type = "constant",
        seed = SEED,
        report_to = ["none"],
    )
    
    peft_parameters = LoraConfig(
        lora_alpha = 8,
        lora_dropout = 0.1,
        r = 8,
        bias = "none",
        task_type = "CAUSAL_LM",
    )
    peft_model = get_peft_model(model, peft_parameters)
    peft_model.print_trainable_parameters()

    trainer = SFTTrainer(
        model = peft_model,
        train_dataset = tokenized_train,
        eval_dataset = tokenized_test,
        peft_config = peft_parameters,
        args = training_args,
    )

    trainer.train()

In [23]:
from transformers.trainer_utils import get_last_checkpoint
from peft import PeftModel

def load_finetuned_model(original_model_id):
    _dir = Path(OUTPUT_DIR) / str(f"{original_model_id}").replace("/", "_")
    last_checkpoint = get_last_checkpoint(_dir)
    print("last_checkpoint:", last_checkpoint)
    peft_config = PeftConfig.from_pretrained(last_checkpoint)
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_id,
        torch_dtype = torch.float16,
        device_map = DEVICE,
    )
    peft_model = PeftModel.from_pretrained(original_model, last_checkpoint, is_trainable=True)
    model = peft_model.merge_and_unload()
    model.generation_config.top_p = None
    model.generation_config.temperature = None
    model.generation_config.pad_token_id = tokenizer.eos_token_id
    
    return model


def eval_finetuning(original_model_id):
    model = load_finetuned_model(original_model_id)
    prompts, labels = get_prompts_and_labels()
    result = evaluate_pred(prompts, labels, model)
    print(f"metrics: {result}")

In [24]:
train(LLAMA_MODEL_ID)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424


Truncating train dataset:   0%|          | 0/1247 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/312 [00:00<?, ? examples/s]

| Step | Training Loss |
|------|---------------|
| 50   | 0.738         |
| 100  | 0.0879        |
| 150  | 0.0781        |
| 200  | 0.0687        |
| 250  | 0.062         |
| 300  | 0.0626        |
| 350  | 0.0628        |
| 400  | 0.0638        |
| 450  | 0.0607        |
| 500  | 0.0577        |
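
The `trainable params: 3,407,872` figure printed above can be reproduced by hand. The sketch below assumes that, since `LoraConfig` sets no `target_modules`, PEFT falls back to its default for Llama models (`q_proj` and `v_proj`), and that the shapes are those of the public Llama-3.1-8B config. Both are assumptions, not something this notebook states.

```python
# Assumed Llama-3.1-8B shapes (from the public model config).
HIDDEN = 4096      # hidden_size
N_LAYERS = 32      # num_hidden_layers
KV_DIM = 8 * 128   # num_key_value_heads * head_dim (output size of v_proj)
R = 8              # LoRA rank, as in LoraConfig(r = 8)


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """A LoRA adapter adds A (r x d_in) and B (d_out x r) to one linear layer."""
    return r * (d_in + d_out)


# Assumed target modules: q_proj (4096 -> 4096) and v_proj (4096 -> 1024).
per_layer = lora_params(HIDDEN, HIDDEN, R) + lora_params(HIDDEN, KV_DIM, R)
total = per_layer * N_LAYERS
print(f"{total:,}")  # 3,407,872
```

The result matches the printed count, which suggests that only the query and value projections receive adapters in this run.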


In [25]:
eval_finetuning(LLAMA_MODEL_ID)

last_checkpoint: checkpoints/meta-llama_Llama-3.1-8B/checkpoint-950


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

150 out of total 312
metrics: 0.4808


In [26]:
train(FOUNDATION_SEC_8B_MODEL_ID)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

trainable params: 3,407,872 || all params: 8,034,717,696 || trainable%: 0.0424


| Step | Training Loss |
|------|---------------|
| 50   | 0.1183        |
| 100  | 0.0852        |
| 150  | 0.0709        |
| 200  | 0.0659        |
| 250  | 0.0596        |
| 300  | 0.0587        |
| 350  | 0.0577        |
| 400  | 0.057         |
| 450  | 0.0534        |
| 500  | 0.0507        |


In [27]:
eval_finetuning(FOUNDATION_SEC_8B_MODEL_ID)

last_checkpoint: checkpoints/fdtn-ai_Foundation-Sec-8B/checkpoint-950


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

168 out of total 312
metrics: 0.5385


Both the original model and Foundation-Sec-8B improve after finetuning (Llama: 14.7% to 48.1%; Foundation-Sec-8B: 43.3% to 53.8%), and the finetuned Foundation-Sec-8B still outperforms the finetuned Llama.
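
For reference, the counts reported so far can be collected in one place and converted to accuracies. This is plain arithmetic on the numbers printed above (correct predictions out of 312 test samples); no new results are introduced.

```python
N_TEST = 312

# Correct predictions reported by the evaluation cells above.
correct = {
    ("Llama-3.1-8B", "before"): 46,
    ("Llama-3.1-8B", "after LoRA"): 150,
    ("Foundation-Sec-8B", "before"): 135,
    ("Foundation-Sec-8B", "after LoRA"): 168,
}

summary = {key: round(n / N_TEST, 4) for key, n in correct.items()}
for (model, stage), acc in summary.items():
    print(f"{model:<18} {stage:<11} {acc:.4f}")
```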

# Finetuning with Classification Head
Another approach to finetuning is to replace the LM head with a classification head.
Although this approach is more common for encoder models, it also works for decoder models such as Foundation-Sec-8B.

To load the models, use `AutoModelForSequenceClassification` instead.
The labels must also be mapped to indices and vice versa.
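
As a small illustration of the label mapping and of how little the new head adds, the sketch below builds the two dictionaries for a toy list of IDs and computes the size of the head. It assumes the head is a single bias-free linear layer from the hidden state to the label logits, and uses a hidden size of 4096 from the public Llama-3.1-8B config; the three-ID list is a stand-in for the output of `get_attack_ids()`.

```python
# Toy stand-in for get_attack_ids(); the real list is expected to hold 211 IDs.
attack_ids = ["T1047", "T1059", "T1082"]

id2label = {i: attack_id for i, attack_id in enumerate(attack_ids)}
label2id = {attack_id: i for i, attack_id in id2label.items()}

# Round trip: label -> index -> label.
print(label2id["T1059"], id2label[label2id["T1059"]])  # 1 T1059

# Size of the classification head, assuming one bias-free linear layer.
HIDDEN_SIZE = 4096  # assumed Llama-3.1-8B hidden size
NUM_LABELS = 211    # EXPECTED_NUM_ATTACK_IDS in get_attack_ids()
head_params = HIDDEN_SIZE * NUM_LABELS
print(f"{head_params:,}")  # 864,256
```

Compared with the roughly 8 billion parameters of the backbone, the head itself is tiny; it is the only part that starts from random initialization.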

In [28]:
from transformers import AutoModelForSequenceClassification


def map_id_and_label():
    """Collect all attack IDs from the MITRE ATT&CK website and map them to labels."""

    attack_ids = get_attack_ids()
    id2label = {}
    label2id = {}
    for i, attack_id in enumerate(attack_ids):
        id2label[i] = attack_id
        label2id[attack_id] = i

    return id2label, label2id


def load_model_with_classification_head(model_id):
    """Load the model with a newly initialized classification head sized to the label set."""
    id2label, label2id = map_id_and_label()

    model = AutoModelForSequenceClassification.from_pretrained(
        model_id, 
        num_labels = len(id2label), 
        id2label = id2label, 
        label2id = label2id, 
        device_map = "auto",
        torch_dtype = torch.bfloat16,
    )
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache = False
    model.config.pretraining_tp = 1
    return model

In [29]:
_, label2id = map_id_and_label()

def map_labels(example):
    """Map the labels to their corresponding IDs."""
    example["label"] = label2id[example["label"]]
    return example


def load_dataset_and_preprocess_for_classification_head():

    hf_datasets = load_from_disk(PATH_TO_CTI_HAL /  HF_DATASET_NAME)
    train_data = hf_datasets["train"]
    eval_data = hf_datasets["test"]

    train_data = train_data.map(map_labels)
    eval_data = eval_data.map(map_labels)

    def _preprocess_data(examples):
        return tokenizer(examples["context"], truncation = True, padding = "max_length", max_length = MAX_LENGTH)

    tokenized_train = train_data.map(_preprocess_data, batched=True, remove_columns=["context", "description"])
    tokenized_eval = eval_data.map(_preprocess_data, batched=True, remove_columns=["context", "description"])
    print(f"Train samples: {len(tokenized_train)}, Eval samples: {len(tokenized_eval)}")    

    return tokenized_train, tokenized_eval


tokenized_train_ch, tokenized_eval_ch = load_dataset_and_preprocess_for_classification_head()    

Map:   0%|          | 0/1247 [00:00<?, ? examples/s]

Map:   0%|          | 0/312 [00:00<?, ? examples/s]

Map:   0%|          | 0/1247 [00:00<?, ? examples/s]

Map:   0%|          | 0/312 [00:00<?, ? examples/s]

Train samples: 1247, Eval samples: 312


In [30]:
import evaluate
accuracy = evaluate.load("accuracy")
recall = evaluate.load('recall')
precision = evaluate.load("precision")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    _accuracy = accuracy.compute(predictions=predictions, references=labels)
    _recall = recall.compute(predictions=predictions, references=labels, average="macro")
    _precision = precision.compute(predictions=predictions, references=labels, average="macro")

    return {
        "accuracy": _accuracy["accuracy"],
        "recall": _recall["recall"],
        "precision": _precision["precision"],
    }
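
Because `recall` and `precision` use `average="macro"`, every class counts equally no matter how many samples it has, so these values can sit well below accuracy on an imbalanced label set. The sketch below shows the effect on toy data. It is a simplified re-implementation for illustration (it averages only over classes present in the labels) and not the code behind the `evaluate` metrics.

```python
def toy_accuracy(preds, labels):
    """Fraction of samples predicted correctly."""
    return sum(p == l for p, l in zip(preds, labels)) / len(labels)


def toy_macro_recall(preds, labels):
    """Unweighted mean of per-class recall over the classes present in labels."""
    classes = sorted(set(labels))
    recalls = []
    for c in classes:
        idx = [i for i, l in enumerate(labels) if l == c]
        recalls.append(sum(preds[i] == c for i in idx) / len(idx))
    return sum(recalls) / len(classes)


# Imbalanced toy data: the model always predicts the majority class.
labels = [0, 0, 0, 0, 1, 2]
preds = [0, 0, 0, 0, 0, 0]

print(round(toy_accuracy(preds, labels), 3))      # 0.667
print(round(toy_macro_recall(preds, labels), 3))  # 0.333
```

With a long tail of rarely seen technique IDs, a gap of this kind between accuracy and macro-averaged metrics is expected.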

In [None]:
from transformers import Trainer

def train_with_classification_head(model_id):

    _output_dir = Path(OUTPUT_DIR) / (str(f"{model_id}").replace("/", "_") + "_classification_head")

    model = load_model_with_classification_head(model_id)    

    training_args = TrainingArguments(
        output_dir = _output_dir,
        learning_rate = 4e-5,
        per_device_train_batch_size = 16,
        per_device_eval_batch_size = 256,
        gradient_accumulation_steps = 8,      
        num_train_epochs = 3,
        weight_decay = 0.01,
        logging_steps = 3,
        eval_strategy = "steps",
        eval_steps = 3,
        lr_scheduler_type = "cosine",
        seed = 42,
        report_to = ["none"],
    )

    trainer = Trainer(
        model = model,
        args = training_args,
        train_dataset = tokenized_train_ch,
        eval_dataset = tokenized_eval_ch,
        compute_metrics = compute_metrics,
    )

    trainer.train()
    trainer.evaluate()

In [32]:
train_with_classification_head(LLAMA_MODEL_ID)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-3.1-8B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


| Step | Training Loss | Validation Loss | Accuracy | Recall   | Precision |
|------|---------------|-----------------|----------|----------|-----------|
| 3    | 9.5312        | 7.760417        | 0.051282 | 0.03587  | 0.00773   |
| 6    | 7.138         | 5.779647        | 0.051282 | 0.018855 | 0.013339  |
| 9    | 6.5599        | 4.685096        | 0.205128 | 0.061085 | 0.057106  |
| 12   | 3.5833        | 3.597756        | 0.323718 | 0.133727 | 0.14653   |
| 15   | 2.9062        | 3.223558        | 0.429487 | 0.170217 | 0.162027  |
| 18   | 2.7646        | 2.939503        | 0.423077 | 0.171264 | 0.149596  |
| 21   | 2.0485        | 2.860176        | 0.426282 | 0.17045  | 0.153106  |
| 24   | 1.7891        | 2.844551        | 0.423077 | 0.165133 | 0.151304  |
| 27   | 1.7477        | 2.844551        | 0.429487 | 0.170395 | 0.155997  |
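
The final step numbers in this notebook (27 here, and `checkpoint-950` in the LoRA run) follow from the batch settings. The sketch below reproduces them under two assumptions that the notebook does not state: the installed `Trainer` version floors the number of batches per epoch by `gradient_accumulation_steps`, and the LoRA run replicated batches across all 8 GPUs while the `device_map = "auto"` run behaved as a single data-parallel device.

```python
import math


def total_optimizer_steps(n_samples, per_device_bs, n_devices, grad_accum, epochs):
    """Estimate optimizer steps, assuming batches per epoch are floored by grad_accum."""
    batches_per_epoch = math.ceil(n_samples / (per_device_bs * n_devices))
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return updates_per_epoch * epochs


# Classification-head run: batch 16, accumulation 8, 3 epochs, one data-parallel device (assumed).
print(total_optimizer_steps(1247, 16, 1, 8, 3))   # 27
# LoRA causal-LM run: batch 1 per device, 8 devices (assumed), accumulation 8, 50 epochs.
print(total_optimizer_steps(1247, 1, 8, 8, 50))   # 950
```

Other library versions may round differently, so treat this as a consistency check on the logs rather than a general formula.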


In [33]:
train_with_classification_head(FOUNDATION_SEC_8B_MODEL_ID)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at fdtn-ai/Foundation-Sec-8B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


| Step | Training Loss | Validation Loss | Accuracy | Recall   | Precision |
|------|---------------|-----------------|----------|----------|-----------|
| 3    | 9.8242        | 6.604968        | 0.099359 | 0.058172 | 0.05437   |
| 6    | 5.5039        | 3.835737        | 0.301282 | 0.107163 | 0.115338  |
| 9    | 3.3307        | 3.082532        | 0.416667 | 0.163168 | 0.157744  |
| 12   | 1.9006        | 2.888221        | 0.451923 | 0.199301 | 0.224316  |
| 15   | 1.5676        | 2.691907        | 0.477564 | 0.209952 | 0.236418  |
| 18   | 1.3639        | 2.67508         | 0.49359  | 0.213282 | 0.241888  |
| 21   | 1.091         | 2.687901        | 0.496795 | 0.213918 | 0.255955  |
| 24   | 0.6136        | 2.687901        | 0.5      | 0.214362 | 0.256059  |
| 27   | 0.6337        | 2.687901        | 0.496795 | 0.213918 | 0.255789  |
