In [95]:
from datasets import load_dataset
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTConfig, SFTTrainer
import gc
import wandb

In [96]:
wandb.init(project="Qwen-fine-tuning", name="14b-ioc-extraction")

[34m[1mwandb[0m: Currently logged in as: [33mt-p-angevare[0m ([33mt-p-angevare-university-of-twente[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [97]:
gc.collect()
torch.cuda.empty_cache()

In [98]:
dataset = load_dataset("ai4privacy/pii-masking-300k")
dataset = dataset.filter(lambda x: x['language'] == 'English')
dataset = dataset.select_columns(["source_text", "privacy_mask", "id"])
dataset

DatasetDict({
    train: Dataset({
        features: ['source_text', 'privacy_mask', 'id'],
        num_rows: 29908
    })
    validation: Dataset({
        features: ['source_text', 'privacy_mask', 'id'],
        num_rows: 7946
    })
})

In [99]:
dataset_entity_mapping = {
    'USERNAME' : 'USERNAME',
    'EMAIL' : 'EMAIL',
    'LASTNAME1' : 'PERSON',
    'IP' : 'IP',
    'GIVENNAME1' : 'PERSON',
    'TEL' : 'PHONE',
    'CITY' : 'LOCATION',
    'POSTCODE' : 'LOCATION',
    'STREET': 'LOCATION',
    'STATE' : 'LOCATION',
    'BUILDING' : 'LOCATION',
    'COUNTRY' : 'LOCATION',
    'SECADDRESS' : 'LOCATION',
    'LASTNAME2' : 'PERSON',
    'GIVENNAME2' : 'PERSON',
    'GEOCOORD' : 'LOCATION',
    'LASTNAME3' : 'PERSON'
}

In [100]:
def clean_entities(privacy_mask):
    new_entities = []
    for ent in privacy_mask:
        if ent['label'] in dataset_entity_mapping.keys():
            new_entities.append({
                'type': dataset_entity_mapping.get(ent['label']),
                'text' : ent['value'],
                'start_pos' : ent['start'],
                'end_pos' : ent['end']
            })
    return new_entities

In [101]:
dataset = dataset.map(lambda x: {'privacy_mask': clean_entities(x['privacy_mask'])})

In [102]:
dataset['train'][0]

{'source_text': 'Subject: Group Messaging for Admissions Process\n\nGood morning, everyone,\n\nI hope this message finds you well. As we continue our admissions processes, I would like to update you on the latest developments and key information. Please find below the timeline for our upcoming meetings:\n\n- wynqvrh053 - Meeting at 10:20am\n- luka.burg - Meeting at 21\n- qahil.wittauer - Meeting at quarter past 13\n- gholamhossein.ruschke - Meeting at 9:47 PM\n- pdmjrsyoz1460 ',
 'privacy_mask': [{'end_pos': 297,
   'start_pos': 287,
   'text': 'wynqvrh053',
   'type': 'USERNAME'},
  {'end_pos': 330, 'start_pos': 321, 'text': 'luka.burg', 'type': 'USERNAME'},
  {'end_pos': 363,
   'start_pos': 349,
   'text': 'qahil.wittauer',
   'type': 'USERNAME'},
  {'end_pos': 416,
   'start_pos': 395,
   'text': 'gholamhossein.ruschke',
   'type': 'USERNAME'},
  {'end_pos': 453,
   'start_pos': 440,
   'text': 'pdmjrsyoz1460',
   'type': 'USERNAME'}],
 'id': '40767A'}

In [103]:
prompt = """
You are a cyber intelligence analyst with 20 years of experience in the the field.

Your task is to extract any entity from the input text. For each entity found you MUST indicate the type in UPPERCASE. ONLY extract entities if literal entity is present in input text.
The expected entity types are the following:

- EMAIL: email addresses format (user@domain.tld)
- IP: IP addresses (IPv4 x.x.x.x or IPv6)
- BTC: ONLY Bitcoin wallet addresses (26-35 alphanumeric, starting with 1, 3, or bc1) EXCLUDE the word bitcoin or values (for example 2.0 BTC)
- IBAN: iban bank account number
- PERSON: Human names (John Smith, John, Catalina) EXCLUDE initials (for example A.H.) 
- LOCATION: cities, countries, geographic locations, regions
- PHONE: phone numbers in any format
- WEB_RESOURCE: URLs and web addresses EXCLUDE filenames
 
**Output**:
The output MUST be in a JSON object with key 'entities' and the value a list of dictionaries including every entity found. For each entity you MUST indicate the type in UPPERCASE.

**OUTPUT EXAMPLE**:
{
  "entities": [
    {"entity": "target123@darkmail.org", "type": "EMAIL"},
    {"entity": "10.45.67.89", "type": "IP"},
    {"entity": "Thompson", "type": "PERSON"},
    {"entity": "Helsinki", "type": "LOCATION"},
    {"entity": "Tim", "type": "PERSON"}
  ]
}

Return empty array if no entities found in the input text.
PAY ATTENTION to sentences that begin with entity type PERSON, for example Anna.
PAY ATTENTION to when the sentences begin with possesive forms of entity type PERSON, for example Catalina's
PAY ATTENTION to when the sentences contain a FULL NAME, the FULL NAME MUST be extracted as ONE entity.
DO NOT include any entities from the example or the system prompt in your answer.

**Verification**:
1. profide your answer in valid JSON format.
2. verify that all extracted entities are present in the input text.
3. verify that no entities from the example or system prompt are included in your answer.
4. verify that extracted entities match the expected formats for their types.
5. provide your final revised andwer based on the verifications above.
"""

In [104]:
model_name = "Qwen/Qwen2.5-14B" 
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

In [105]:
import json

def convert_to_chatml(source_text, privacy_mask):
    # Convert to the JSON format expected by the prompt
    entities_json = {
        "entities": [
            {"entity": ent['text'], "type": ent['type']} 
            for ent in privacy_mask
        ]
    }
    
    return [
        {"role": "system", "content": prompt},
        {"role": "user", "content": source_text},
        {"role": "assistant", "content": json.dumps(entities_json, indent=2)}
    ]

In [106]:

dataset = dataset.map(lambda x: {"messages": convert_to_chatml(x['source_text'], x['privacy_mask'])})

Map:   0%|          | 0/29908 [00:00<?, ? examples/s]

Map:   0%|          | 0/7946 [00:00<?, ? examples/s]

In [107]:
# Reduced dataset for faster training (~2-3 hours instead of 12+)
train = dataset['train'].select(range(5000))  # 5k samples (was 30k)
val = dataset['validation'].select(range(500))

print(f"Training samples: {len(train)}")
print(f"Validation samples: {len(val)}")

Training samples: 5000
Validation samples: 500


In [108]:
# verify format with thinking tokens
example = train[0]
print("=== Source Text (truncated) ===")
print(example['source_text'][:200] + "...")
print("\n=== User Message (prompt + input) ===")
print(example['messages'][0]['content'][:300] + "...")
print("\n=== Assistant Response (with thinking tokens) ===")
print(example['messages'][1]['content'][:500] + "...")
print("\n=== Format: <think>reasoning</think> + JSON output ===")

=== Source Text (truncated) ===
Subject: Group Messaging for Admissions Process

Good morning, everyone,

I hope this message finds you well. As we continue our admissions processes, I would like to update you on the latest developm...

=== User Message (prompt + input) ===

You are a cyber intelligence analyst with 20 years of experience in the the field.

Your task is to extract any entity from the input text. For each entity found you MUST indicate the type in UPPERCASE. ONLY extract entities if literal entity is present in input text.
The expected entity types are ...

=== Assistant Response (with thinking tokens) ===
Subject: Group Messaging for Admissions Process

Good morning, everyone,

I hope this message finds you well. As we continue our admissions processes, I would like to update you on the latest developments and key information. Please find below the timeline for our upcoming meetings:

- wynqvrh053 - Meeting at 10:20am
- luka.burg - Meeting at 21
- qahil.wittauer - Meet

In [109]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

In [110]:
model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
        use_cache=False
    )

model.gradient_checkpointing_enable()

config.json:   0%|          | 0.00/664 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 8 files:   0%|          | 0/8 [00:00<?, ?it/s]

model-00006-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00002-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00008-of-00008.safetensors:   0%|          | 0.00/1.70G [00:00<?, ?B/s]

model-00004-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00005-of-00008.safetensors:   0%|          | 0.00/3.98G [00:00<?, ?B/s]

model-00003-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00007-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00001-of-00008.safetensors:   0%|          | 0.00/3.89G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

In [111]:
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=8,                   
    lora_alpha=16,           
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

In [114]:
from transformers import EarlyStoppingCallback

training_args = SFTConfig(
    output_dir="./sft_qwen_14b_output",

    num_train_epochs=3,                 
    
    max_length=512,
    per_device_train_batch_size=1,       
    gradient_accumulation_steps=16,      
    

    learning_rate=1e-5,                  
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    weight_decay=0.01,
    max_grad_norm=0.5,                   
    
    logging_steps=10,
    save_steps=100,
    eval_strategy="steps",
    eval_steps=50,                      
    

    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=3,
    

    packing=False,
    report_to="wandb",
    run_name="qwen-14b-pii",
    bf16=True,
    optim="adamw_8bit",
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=val,
    peft_config=lora_config,
    processing_class=tokenizer,
    callbacks=[
        EarlyStoppingCallback(
            early_stopping_patience=2,
            early_stopping_threshold=0.005
        )
    ]
)

print("Training configuration for DeepSeek-R1-Distill-Qwen-14B:")
print(f"  - Train samples: {len(train)}")
print(f"  - Epochs: {training_args.num_train_epochs}")
print(f"  - Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  - Learning rate: {training_args.learning_rate}")
print(f"  - Eval every {training_args.eval_steps} steps")
print(f"  - Max grad norm: {training_args.max_grad_norm}")
print(f"  - Early stopping patience: 2 evals")



Training configuration for DeepSeek-R1-Distill-Qwen-14B:
  - Train samples: 5000
  - Epochs: 3
  - Effective batch size: 16
  - Learning rate: 1e-05
  - Eval every 50 steps
  - Max grad norm: 0.5
  - Early stopping patience: 2 evals


In [115]:
trainer.train()

Step,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
50,2.1315,2.11563,1.958243,409600.0,0.577299
100,1.9718,1.944338,2.058485,819200.0,0.587022
150,1.5434,1.468676,1.553811,1228800.0,0.628149
200,0.5826,0.451372,0.602679,1638400.0,0.892368
250,0.0065,0.00492,0.034833,2048000.0,1.0
300,0.0037,0.003593,0.027035,2457600.0,1.0
350,0.0036,0.003484,0.026362,2863104.0,1.0


TrainOutput(global_step=350, training_loss=1.0126837648664202, metrics={'train_runtime': 47300.4894, 'train_samples_per_second': 0.317, 'train_steps_per_second': 0.02, 'total_flos': 2.405702904619991e+17, 'train_loss': 1.0126837648664202, 'epoch': 1.1184})

In [117]:
# Save the fine-tuned model
trainer.save_model("./sft_qwen_14b_output/final_model")
print("Model saved to ./sft_qwen_14b_output/final_model")

# Log final metrics
if wandb.run:
    wandb.finish()

Model saved to ./sft_qwen_14b_output/final_model


In [118]:
api = wandb.Api()
run = api.run("/t-p-angevare-university-of-twente/transformer-fine-tuning/runs/rjx3lp23")
history = run.history()
print(history)

    eval/loss  train/entropy  train/epoch  eval/entropy  \
0         NaN       1.284441        0.016           NaN   
1         NaN       1.269152        0.032           NaN   
2         NaN       1.277675        0.048           NaN   
3         NaN       1.298521        0.064           NaN   
4         NaN       1.296141        0.080           NaN   
..        ...            ...          ...           ...   
56        NaN       0.450221        0.768           NaN   
57        NaN       0.438981        0.784           NaN   
58        NaN       0.450000        0.800           NaN   
59   0.418354            NaN        0.800       0.43036   
60        NaN            NaN        0.800           NaN   

    eval/steps_per_second    _timestamp  eval/num_tokens  train/num_tokens  \
0                     NaN  1.767459e+09              NaN          129032.0   
1                     NaN  1.767460e+09              NaN          257738.0   
2                     NaN  1.767461e+09              NaN 

In [None]:
from peft import PeftModel
import json

gc.collect()
torch.cuda.empty_cache()


print("Loading base model for inference...")
inference_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)


print("Loading fine-tuned LoRA adapters...")
inference_model = PeftModel.from_pretrained(inference_model, "./sft_14b_output/final_model")
inference_model.eval()

print("Model loaded successfully!")

Loading base model for inference...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading fine-tuned LoRA adapters...
Model loaded successfully!


In [None]:
import re

def extract_entities(text, max_new_tokens=512):
    """Run entity extraction on input text using the fine-tuned R1-Distill model."""
    
    # Format as chat message - R1-Distill: no system prompt
    messages = [
        {"role": "user", "content": prompt + "\n\nInput text:\n" + text}
    ]
    
    # Tokenize
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(input_text, return_tensors="pt").to(inference_model.device)
    
    # Generate with recommended R1-Distill settings
    with torch.no_grad():
        outputs = inference_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.6,  # Recommended for R1-Distill (0.5-0.7)
            top_p=0.95,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    
    # Decode only the generated part
    generated = outputs[0][inputs['input_ids'].shape[1]:]
    response = tokenizer.decode(generated, skip_special_tokens=True)
    
    # Strip thinking tokens before parsing JSON
    response_clean = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL).strip()
    
    # Try to parse as JSON
    try:
        # Find JSON in response
        start = response_clean.find('{')
        end = response_clean.rfind('}') + 1
        if start != -1 and end > start:
            result = json.loads(response_clean[start:end])
            return result, response
    except json.JSONDecodeError:
        pass
    
    return None, response

In [None]:
# Test on various examples
test_cases = [
    # Test 1: Email and IP
    "Please contact John Smith at john.smith@company.com. The server IP is 192.168.1.100.",
    
    # Test 2: Location and person  
    "Maria Garcia traveled from Madrid to New York last week.",
    
    # Test 3: Phone number
    "Call me at +1-555-123-4567 or email support@example.org",
    
    # Test 4: Mixed entities (username)
    "The suspect known as darkh4cker_99 was traced to IP 10.0.0.1 in Berlin, Germany.",
]

print("=" * 80)
print("TESTING FINE-TUNED MODEL")
print("=" * 80)

for i, test_text in enumerate(test_cases):
    print(f"\n{'='*80}")
    print(f"TEST {i+1}")
    print(f"{'='*80}")
    print(f"INPUT: {test_text[:200]}{'...' if len(test_text) > 200 else ''}")
    print("-" * 40)
    
    result, raw_response = extract_entities(test_text)
    
    if result and 'entities' in result:
        print(f"EXTRACTED ENTITIES ({len(result['entities'])} found):")
        for ent in result['entities']:
            print(f"  - {ent.get('type', 'UNKNOWN')}: {ent.get('entity', 'N/A')}")
    else:
        print(f"RAW RESPONSE: {raw_response[:500]}")

print("\n" + "=" * 80)
print("TESTING COMPLETE")
print("=" * 80)

TESTING FINE-TUNED MODEL

TEST 1
INPUT: Please contact John Smith at john.smith@company.com. The server IP is 192.168.1.100.
----------------------------------------




EXTRACTED ENTITIES (2 found):
  - EMAIL: john.smith@company.com
  - IP: 192.168.1.100

TEST 2
INPUT: Maria Garcia traveled from Madrid to New York last week.
----------------------------------------
EXTRACTED ENTITIES (3 found):
  - PERSON: Maria Garcia
  - LOCATION: Madrid
  - LOCATION: New York

TEST 3
INPUT: Call me at +1-555-123-4567 or email support@example.org
----------------------------------------
EXTRACTED ENTITIES (2 found):
  - PHONE: +1-555-123-4567
  - EMAIL: support@example.org

TEST 4
INPUT: The suspect known as darkh4cker_99 was traced to IP 10.0.0.1 in Berlin, Germany.
----------------------------------------
EXTRACTED ENTITIES (3 found):
  - PERSON: darkh4cker_99
  - IP: 10.0.0.1
  - LOCATION: Berlin

TESTING COMPLETE


In [None]:
# Compare with ground truth from validation set
print("\n" + "=" * 80)
print("COMPARISON WITH GROUND TRUTH (Validation Sample)")
print("=" * 80)

sample = val[0]
print(f"\nINPUT TEXT:\n{sample['source_text'][:300]}...")

# Get model prediction
result, _ = extract_entities(sample['source_text'])


ground_truth = sample['privacy_mask']

print(f"\n{'GROUND TRUTH':-^40}")
for ent in ground_truth[:10]: 
    print(f"  - {ent['type']}: {ent['text']}")
if len(ground_truth) > 10:
    print(f"  ... and {len(ground_truth) - 10} more")

print(f"\n{'MODEL PREDICTION':-^40}")
if result and 'entities' in result:
    for ent in result['entities'][:10]:
        print(f"  - {ent.get('type', 'UNKNOWN')}: {ent.get('entity', 'N/A')}")
    if len(result['entities']) > 10:
        print(f"  ... and {len(result['entities']) - 10} more")
else:
    print("  No entities extracted or invalid JSON response")