**FINE-TUNE A QWEN MODEL ON CTI REPORT DATA (target: Qwen3-Next-80B-A3B-Instruct; prototyped below on Qwen1.5-0.5B)**

In [None]:
!pip install transformers datasets accelerate peft bitsandbytes --quiet


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 MB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Run in Colab cell
!pip -q install -U "transformers>=4.33.0" "safetensors" "huggingface_hub" "evaluate"


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.1/40.1 kB[0m [31m459.3 kB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m45.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m564.3/564.3 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25h

# Fine-tuning (multi-task CTI Bench)
**Objective:** instruction-tune a Qwen model on the CTI Bench datasets (MCQ, RCM, VSP, TAA, ATE) using a single multi-task dataset with explicit task tags and balancing via oversampling (or a weighted sampler). We use LoRA (PEFT) + bitsandbytes 4-bit quantization to make training affordable on a single GPU.


**High level decisions & why:**
- **Single multi-task model with task tags** (`[MCQ]`, `[RCM]`, ...) so the model learns all tasks but can be directed by tags at inference time — simpler infra than 4 separate models.
- **LoRA (PEFT)** to fine-tune only a small number of adapter weights (fast, cheap, saves memory).
- **bitsandbytes 4-bit quantization (NF4)** to reduce memory usage, enabling larger models on commodity GPUs. Use NF4 for training as recommended.
- **Balancing:** handle dataset-size imbalance either by (A) oversampling small sets to match the largest dataset or (B) weighted sampler. I provide both approaches.


# ---------------------------
# Step 1: Imports
# ---------------------------

In [None]:

from datasets import load_dataset, concatenate_datasets
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from collections import Counter
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType



# ---------------------------
# Step 2: Dataset loading and formatting (LLMs need a consistent input/output structure)
# ---------------------------

In [None]:
from datasets import load_dataset

# Subsets of the AI4Sec/cti-bench benchmark whose schemas are inspected below.
SUBSETS = ["cti-mcq", "cti-rcm", "cti-vsp", "cti-taa", "cti-ate"]

# Print each subset's column names so the formatting step knows what fields exist.
for subset in SUBSETS:
    # Most subsets only have a 'test' split, so that is the one loaded here.
    ds = load_dataset(path="AI4Sec/cti-bench", name=subset, split="test")
    header = f"\n=== Subset: {subset} ==="
    print(header)
    print("Columns:", ds.column_names)

README.md: 0.00B [00:00, ?B/s]

cti-mcq.tsv: 0.00B [00:00, ?B/s]

Generating test split:   0%|          | 0/2500 [00:00<?, ? examples/s]


=== Subset: cti-mcq ===
Columns: ['URL', 'Question', 'Option A', 'Option B', 'Option C', 'Option D', 'Prompt', 'GT']


cti-rcm.tsv: 0.00B [00:00, ?B/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]


=== Subset: cti-rcm ===
Columns: ['URL', 'Description', 'Prompt', 'GT']


cti-vsp.tsv: 0.00B [00:00, ?B/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]


=== Subset: cti-vsp ===
Columns: ['URL', 'Description', 'Prompt', 'GT']


cti-taa.tsv: 0.00B [00:00, ?B/s]

Generating test split:   0%|          | 0/50 [00:00<?, ? examples/s]


=== Subset: cti-taa ===
Columns: ['URL', 'Text', 'Prompt']


cti-ate.tsv: 0.00B [00:00, ?B/s]

Generating test split:   0%|          | 0/60 [00:00<?, ? examples/s]


=== Subset: cti-ate ===
Columns: ['URL', 'Platform', 'Description', 'Prompt', 'GT']


In [None]:
from datasets import load_dataset

# Benchmark subsets to load, plus per-subset holders for the formatted
# train / validation splits (keyed by subset name).
SUBSETS = ["cti-mcq", "cti-rcm", "cti-vsp", "cti-taa", "cti-ate"]
ds_train_parts, ds_val_parts = {}, {}
def format_example_safe(batch, task):
    """Project a batched mapping of columns onto the three fields used for training.

    Args:
        batch: Mapping from column name to a list of per-row values (the shape
            passed by ``Dataset.map(..., batched=True)``).
        task: Subset name used as the ``source`` value when the batch has no
            ``source`` column.

    Returns:
        dict with three equally long lists:
            ``input_text``  - the ``Prompt`` column, or ``""`` per row if absent;
            ``target_text`` - the ``GT`` column, or ``""`` per row if absent;
            ``source``      - the ``source`` column, or ``task`` per row if absent.
    """
    # Row count must come from a column: len(batch) is the number of columns.
    if "Prompt" in batch:
        input_texts = list(batch["Prompt"])
        num_rows = len(input_texts)
    else:
        num_rows = len(next(iter(batch.values()), []))
        input_texts = [""] * num_rows

    # Some subsets ship without ground truth; they get an empty target.
    target_texts = list(batch["GT"]) if "GT" in batch else [""] * num_rows
    sources = list(batch["source"]) if "source" in batch else [task] * num_rows

    return {"input_text": input_texts, "target_text": target_texts, "source": sources}

for subset in SUBSETS:
    # Only the 'test' split exists upstream, so it feeds both of our splits.
    ds = load_dataset("AI4Sec/cti-bench", subset, split="test")

    # Tag every row with the subset it came from unless the column is already there.
    if "source" not in ds.column_names:
        ds = ds.add_column("source", [subset] * len(ds))

    # Fixed-seed 80/20 split into train and validation.
    split_ds = ds.train_test_split(test_size=0.2, seed=42)

    # Add the input_text / target_text fields to both halves.
    train_ds = split_ds["train"].map(
        format_example_safe, batched=True, fn_kwargs={"task": subset}
    )
    val_ds = split_ds["test"].map(
        format_example_safe, batched=True, fn_kwargs={"task": subset}
    )

    # Keep the formatted halves, keyed by subset name.
    ds_train_parts[subset] = train_ds
    ds_val_parts[subset] = val_ds


Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/40 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/48 [00:00<?, ? examples/s]

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

In [None]:
# Show the second training example of every subset as a quick sanity check.
for subset_name in ("cti-mcq", "cti-taa", "cti-vsp", "cti-ate", "cti-rcm"):
    print(ds_train_parts[subset_name][1])

{'URL': 'https://attack.mitre.org/techniques/T1586/001/', 'Question': 'Which of the following methods is NOT typically used by adversaries to compromise social media accounts under the technique T1586.001?', 'Option A': 'Phishing for Information', 'Option B': 'Brute forcing credentials', 'Option C': 'Purchasing credentials from third-party sites', 'Option D': 'Exploiting zero-day vulnerabilities in social media platforms', 'Prompt': 'You are given a multiple-choice question (MCQ) from a Cyber Threat Intelligence (CTI) knowledge benchmark dataset. Your task is to choose the best option among the four provided. Return your answer as a single uppercase letter: A, B, C, or D.  **Question:** Which of the following methods is NOT typically used by adversaries to compromise social media accounts under the technique T1586.001?  **Options:** A) Phishing for Information B) Brute forcing credentials C) Purchasing credentials from third-party sites D) Exploiting zero-day vulnerabilities in social 

# ---------------------------
# Step 3: combine all subsets and apply weighted sampling
# ---------------------------

In [None]:
from datasets import concatenate_datasets
from torch.utils.data import DataLoader, WeightedRandomSampler
import torch

# Stack the per-subset splits into one training set and one validation set.
train_parts = [ds_train_parts[name] for name in ds_train_parts]
val_parts = [ds_val_parts[name] for name in ds_val_parts]
full_train_ds = concatenate_datasets(train_parts)
full_val_ds = concatenate_datasets(val_parts)

print("Combined train size:", len(full_train_ds))
print("Combined validation size:", len(full_val_ds))

Combined train size: 3688
Combined validation size: 922


In [None]:
full_val_ds[1]

{'URL': 'STIX_part1.txt',
 'Question': 'What is the primary focus of STIX Patterning as described in STIX 2.1?',
 'Option A': 'Automating threat actor communication',
 'Option B': 'Enhancing data storage and serialization',
 'Option C': 'Supporting STIX Indicators',
 'Option D': 'Facilitating secure transport of threat data',
 'Prompt': 'You are given a multiple-choice question (MCQ) from a Cyber Threat Intelligence (CTI) knowledge benchmark dataset. Your task is to choose the best option among the four provided. Return your answer as a single uppercase letter: A, B, C, or D.  **Question:** What is the primary focus of STIX Patterning as described in STIX 2.1?  **Options:** A) Automating threat actor communication B) Enhancing data storage and serialization C) Supporting STIX Indicators D) Facilitating secure transport of threat data  **Important:** The last line of your answer should contain only the single letter corresponding to the best option, with no additional text. ',
 'GT': 'C

In [None]:
full_train_ds[1]

{'URL': 'https://attack.mitre.org/techniques/T1586/001/',
 'Question': 'Which of the following methods is NOT typically used by adversaries to compromise social media accounts under the technique T1586.001?',
 'Option A': 'Phishing for Information',
 'Option B': 'Brute forcing credentials',
 'Option C': 'Purchasing credentials from third-party sites',
 'Option D': 'Exploiting zero-day vulnerabilities in social media platforms',
 'Prompt': 'You are given a multiple-choice question (MCQ) from a Cyber Threat Intelligence (CTI) knowledge benchmark dataset. Your task is to choose the best option among the four provided. Return your answer as a single uppercase letter: A, B, C, or D.  **Question:** Which of the following methods is NOT typically used by adversaries to compromise social media accounts under the technique T1586.001?  **Options:** A) Phishing for Information B) Brute forcing credentials C) Purchasing credentials from third-party sites D) Exploiting zero-day vulnerabilities in s

In [None]:
full_val_ds.column_names

['URL',
 'Question',
 'Option A',
 'Option B',
 'Option C',
 'Option D',
 'Prompt',
 'GT',
 'source',
 'input_text',
 'target_text',
 'Description',
 'Text',
 'Platform']

In [None]:
full_train_ds.column_names


['URL',
 'Question',
 'Option A',
 'Option B',
 'Option C',
 'Option D',
 'Prompt',
 'GT',
 'source',
 'input_text',
 'target_text',
 'Description',
 'Text',
 'Platform']

In [None]:
from datasets import Dataset

def _core_records(dataset):
    """Return one dict per row holding only input_text, target_text and source."""
    wanted = ("input_text", "target_text", "source")
    return [{field: row[field] for field in wanted} for row in dataset]


# Training split: rebuild a dataset that carries only the three core fields.
train_data = _core_records(full_train_ds)
clean_train_ds = Dataset.from_list(train_data)

# Validation split: same projection.
val_data = _core_records(full_val_ds)
clean_val_ds = Dataset.from_list(val_data)


In [None]:
clean_train_ds.column_names

['input_text', 'target_text', 'source']

In [None]:
clean_val_ds.column_names

['input_text', 'target_text', 'source']

In [None]:
clean_val_ds[2]

{'input_text': 'You are given a multiple-choice question (MCQ) from a Cyber Threat Intelligence (CTI) knowledge benchmark dataset. Your task is to choose the best option among the four provided. Return your answer as a single uppercase letter: A, B, C, or D.  **Question:** Which tool listed in the document can be used to collect information about domain users, including identification of domain admin accounts?  **Options:** A) dsquery B) AdFind C) BloodHound D) PowerShell  **Important:** The last line of your answer should contain only the single letter corresponding to the best option, with no additional text. ',
 'target_text': 'C',
 'source': 'cti-mcq'}

In [None]:
clean_train_ds[2]


{'input_text': "You are given a multiple-choice question (MCQ) from a Cyber Threat Intelligence (CTI) knowledge benchmark dataset. Your task is to choose the best option among the four provided. Return your answer as a single uppercase letter: A, B, C, or D.  **Question:** Which MITRE ATT&CK technique involves using the command 'rundll32.exe keymgr.dll KRShowKeyMgr' to access credential backups and restorations?  **Options:** A) T1078: Valid Accounts B) T1003: Credential Dumping C) T1555.004: Credentials from Password Stores: Windows Credential Manager D) T1081: Credentials in Files  **Important:** The last line of your answer should contain only the single letter corresponding to the best option, with no additional text. ",
 'target_text': 'C',
 'source': 'cti-mcq'}

In [None]:
#  Compute weights per-example based on source
# Count number of examples per source
from collections import Counter

# How many training examples each subset contributes.
source_column = clean_train_ds["source"]
source_counts = Counter(source_column)

# Reciprocal of the subset size: the smaller the subset, the larger the weight.
inv_counts = {name: 1.0 / size for name, size in source_counts.items()}

# One weight per training example, looked up from its subset.
example_weights = list(map(inv_counts.__getitem__, source_column))


In [None]:
# Sample training examples (with replacement) using the per-example weights, so
# that small subsets are drawn as often as large ones. The generator is seeded
# so the sampling order is reproducible across runs.
sampler = WeightedRandomSampler(
    weights=example_weights,
    num_samples=len(example_weights),
    replacement=True,
    generator=torch.Generator().manual_seed(42),
)

from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForSeq2Seq

# The tokenizer must come from the same checkpoint as the model that is
# fine-tuned further down (Qwen/Qwen1.5-0.5B); mixing checkpoints gives
# mismatched special tokens between training data and model.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B")

# The collator needs a pad token for dynamic padding; fall back to EOS if unset.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

In [None]:
def flatten_strings(batch):
    """Coerce every ``input_text`` / ``target_text`` value in the batch to a plain string.

    List values are unwrapped to their first element (an empty list becomes
    ``""``), ``None`` becomes ``""`` instead of the literal text ``"None"``,
    and anything else is passed through ``str()``.

    Args:
        batch: Mapping with ``input_text`` and ``target_text`` columns (lists
            of per-row values). It is updated in place.

    Returns:
        The same ``batch`` mapping with both columns replaced by lists of ``str``.
    """

    def _as_text(value):
        # Unwrap (possibly nested) lists down to their first element.
        while isinstance(value, list):
            value = value[0] if value else ""
        return "" if value is None else str(value)

    for column in ("input_text", "target_text"):
        batch[column] = [_as_text(value) for value in batch[column]]
    return batch

# Normalise the text columns of both splits (train first, then validation).
clean_train_ds, clean_val_ds = (
    split.map(flatten_strings, batched=True)
    for split in (clean_train_ds, clean_val_ds)
)

Map:   0%|          | 0/3688 [00:00<?, ? examples/s]

Map:   0%|          | 0/922 [00:00<?, ? examples/s]

In [None]:
clean_train_ds

Dataset({
    features: ['input_text', 'target_text', 'source'],
    num_rows: 3688
})

In [None]:
def tokenize_function(examples, max_length=512):
    """Build causal-LM training features from prompt/answer pairs.

    A causal LM computes its loss by comparing position ``i`` of ``labels``
    with the prediction made from positions ``< i`` of ``input_ids``, so both
    must describe the same sequence. Each example therefore becomes
    ``prompt + answer + EOS``, with the prompt positions set to ``-100`` in
    ``labels`` so that only the answer contributes to the loss.

    No padding is applied here; sequences keep their natural length and are
    expected to be padded per batch by the data collator.

    Args:
        examples: Batched mapping with ``input_text`` and ``target_text``
            columns (lists of strings).
        max_length: Maximum number of tokens per sequence. The answer is kept
            in full where possible and the prompt is cut at its end to fit.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``, each a
        list holding one list of ints per example.
    """
    prompt_ids_batch = tokenizer(examples["input_text"], add_special_tokens=False)["input_ids"]
    target_ids_batch = tokenizer(examples["target_text"], add_special_tokens=False)["input_ids"]
    eos_id = tokenizer.eos_token_id

    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
    for prompt_ids, target_ids in zip(prompt_ids_batch, target_ids_batch):
        # Terminate the answer with EOS so the model learns where to stop.
        answer_ids = list(target_ids)
        if eos_id is not None:
            answer_ids.append(eos_id)
        answer_ids = answer_ids[:max_length]

        # Give the remaining token budget to the prompt.
        prompt_ids = list(prompt_ids)[: max_length - len(answer_ids)]

        model_inputs["input_ids"].append(prompt_ids + answer_ids)
        model_inputs["attention_mask"].append([1] * (len(prompt_ids) + len(answer_ids)))
        # -100 is the index ignored by the cross-entropy loss.
        model_inputs["labels"].append([-100] * len(prompt_ids) + answer_ids)
    return model_inputs

# Tokenize both splits batch-wise; the original text columns are still present afterwards.
tokenized_train_ds = clean_train_ds.map(function=tokenize_function, batched=True)
tokenized_val_ds = clean_val_ds.map(function=tokenize_function, batched=True)


Map:   0%|          | 0/3688 [00:00<?, ? examples/s]

Map:   0%|          | 0/922 [00:00<?, ? examples/s]

In [None]:
tokenized_train_ds.column_names

['input_text',
 'target_text',
 'source',
 'input_ids',
 'attention_mask',
 'labels']

In [None]:
def _keep_model_columns(dataset):
    """Drop every column except input_ids, attention_mask and labels."""
    extra = [
        name
        for name in dataset.column_names
        if name not in ("input_ids", "attention_mask", "labels")
    ]
    return dataset.remove_columns(extra)


tokenized_train_ds = _keep_model_columns(tokenized_train_ds)
tokenized_val_ds = _keep_model_columns(tokenized_val_ds)


In [None]:
tokenized_train_ds.column_names

['input_ids', 'attention_mask', 'labels']

In [None]:

# Collator that pads each batch and returns PyTorch tensors.
data_collator = DataCollatorForSeq2Seq(tokenizer, padding=True, return_tensors="pt")

# Training batches are drawn through the weighted sampler, one example at a time.
train_dataloader = DataLoader(
    dataset=tokenized_train_ds,
    sampler=sampler,
    collate_fn=data_collator,
    batch_size=1,
)

# Validation batches are read in dataset order, one example at a time.
val_dataloader = DataLoader(
    dataset=tokenized_val_ds,
    collate_fn=data_collator,
    shuffle=False,
    batch_size=1,
)

print("Weighted sampling DataLoader ready.")


Weighted sampling DataLoader ready.


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

# 4-bit NF4 quantization with double quantization; matmuls run in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="bfloat16"
)
# Small checkpoint used for prototyping; the tokenizer should be loaded from
# this same checkpoint.
#model_name = "Qwen/Qwen3-4B-Instruct-2507"
model_name = "Qwen/Qwen1.5-0.5B"
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map="auto")

# Without adapters, the only parameters with requires_grad=True on a 4-bit
# model are the unquantized base weights (embeddings, norms, ...), so those
# would be what gets trained. Freeze the base model and prepare it for k-bit
# training (this also enables gradient checkpointing and input gradients).
model = prepare_model_for_kbit_training(model)

# Attach LoRA adapters to the attention and MLP projections; only these
# adapter weights are trainable afterwards.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


config.json:   0%|          | 0.00/661 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.24G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

In [None]:
import torch
from tqdm import tqdm
from transformers import get_scheduler
from torch.optim import AdamW

# Gradient checkpointing trades compute for memory; the KV cache is
# incompatible with it, so switch the cache off explicitly.
model.gradient_checkpointing_enable()
model.config.use_cache = False

# Optimize only the parameters that currently require gradients
# (the adapter weights when the model is wrapped with PEFT).
trainable_params = [p for p in model.parameters() if p.requires_grad]

print(f"Trainable params: {sum(p.numel() for p in trainable_params)} "
      f"out of {sum(p.numel() for p in model.parameters())} total")

optimizer = AdamW(trainable_params, lr=2e-4)

# Gradient accumulation: one optimizer step per `grad_accum_steps` micro-batches.
num_epochs = 3
grad_accum_steps = 16

# The scheduler advances once per optimizer step, not once per micro-batch,
# so its length must be counted in optimizer steps (ceil division so the
# trailing partial window is included).
num_batches = len(train_dataloader)
num_update_steps_per_epoch = (num_batches + grad_accum_steps - 1) // grad_accum_steps
max_train_steps = num_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=max_train_steps,
)

# The model was placed by `device_map="auto"` at load time; a bitsandbytes
# quantized model should not be moved with `.to()`. Batches are sent to
# wherever the model already lives.
device = model.device

for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0.0

    optimizer.zero_grad()
    for step, batch in enumerate(tqdm(train_dataloader, desc=f"Training Epoch {epoch+1}")):
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)
        total_train_loss += outputs.loss.item()

        # Scale so the accumulated gradient is the mean over the window.
        loss = outputs.loss / grad_accum_steps
        loss.backward()

        # Step at the end of every full window and on the last batch of the
        # epoch, so gradients from a trailing partial window are not thrown
        # away by the next zero_grad(). (That last window is still scaled by
        # 1/grad_accum_steps, i.e. slightly down-weighted.)
        is_last_batch = (step + 1) == num_batches
        if (step + 1) % grad_accum_steps == 0 or is_last_batch:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

    avg_train_loss = total_train_loss / len(train_dataloader)
    print(f"Epoch {epoch+1} | Avg training loss: {avg_train_loss:.4f}")

    # Validation: mean per-batch loss with gradients disabled.
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for batch in val_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            total_val_loss += outputs.loss.item()
    avg_val_loss = total_val_loss / len(val_dataloader)
    print(f"Epoch {epoch+1} | Avg validation loss: {avg_val_loss:.4f}")

    # Checkpoint after every epoch (adapter weights only when the model is a
    # PEFT model), together with the tokenizer.
    save_path = f"./qwen-0.5B-lora-epoch{epoch+1}"
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    print(f"Checkpoint saved at {save_path}")


Trainable params: 155632640 out of 309847040 total


Training Epoch 1:   0%|          | 0/3688 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Training Epoch 1:   1%|          | 20/3688 [11:53:16<2279:30:08, 2237.24s/it]