# Transformer NER Model Loader for Forensic Log Entity Extraction

This notebook demonstrates how to load fine-tuned transformer models and extract entities (22 BIO labels covering 17 entity types) from forensic log data for the FTE-HARM framework.

## Purpose
- Load transformer models (DistilBERT, DistilRoBERTa, RoBERTa, XLM-RoBERTa) from Google Drive
- Extract structured entities from raw log text
- Bridge the gap between unstructured logs and FTE-HARM's hypothesis-aligned reasoning

## Entity Types
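The models predict 22 BIO labels: `O` plus `B-`/`I-` tags for 17 entity types. Only DNSName, DateTime, Error, and Status have `I-` (continuation) tags; the other types carry a `B-` tag only.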
- **DateTime**: Timestamps (Jan 24 10:30:45)
- **IPAddress**: IP addresses (192.168.1.100)
- **DNSName**: Domain names (example.com)
- **Process**: Process names (sshd, dnsmasq)
- **Username**: User identifiers (admin, root)
- **Action**: Actions/verbs (login, failed, accept)
- And 11 more (listed in `ENTITY_LABELS`, Step 3)
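The 22 labels are BIO tags rather than distinct types: `O` marks non-entities, and four types have both a `B-` and an `I-` label. A minimal sketch that derives the type inventory from the label list (copied here from `ENTITY_LABELS` in Step 3 so the cell runs on its own):

```python
# Label list copied from ENTITY_LABELS in Step 3
LABELS = [
    'O', 'B-Action', 'B-ApplicationSpecific', 'B-AuthenticationType',
    'B-DNSName', 'I-DNSName', 'B-DateTime', 'I-DateTime', 'B-Error', 'I-Error',
    'B-IPAddress', 'B-Object', 'B-Port', 'B-Process', 'B-Protocol', 'B-Service',
    'B-SessionID', 'B-Severity', 'B-Status', 'I-Status', 'B-System', 'B-Username',
]

def entity_types(labels):
    """Return the sorted entity types named by the B-/I- labels."""
    return sorted({label[2:] for label in labels if label != 'O'})

def multi_token_types(labels):
    """Return the sorted types that have an I- (continuation) label."""
    return sorted({label[2:] for label in labels if label.startswith('I-')})

print(len(LABELS))                # 22
print(len(entity_types(LABELS)))  # 17
print(multi_token_types(LABELS))  # ['DNSName', 'DateTime', 'Error', 'Status']
```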

---
## Step 1: Mount Google Drive

The trained models are stored in Google Drive. We need to mount it first.
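After mounting, a wrong or not-yet-synced Drive path only shows up later as a `from_pretrained` error. A small pre-flight check can catch that earlier. This is a sketch rather than part of the original pipeline: `checkpoint_problems` is a hypothetical helper, and it assumes the usual Hugging Face layout of a `config.json` next to a weights file (tokenizer files are not checked).

```python
import os

# Assumed Hugging Face checkpoint layout: config.json plus one weights file
# (or a shard index for large models). Tokenizer files are not checked.
WEIGHT_FILES = (
    'model.safetensors',
    'pytorch_model.bin',
    'model.safetensors.index.json',
    'pytorch_model.bin.index.json',
)

def checkpoint_problems(path):
    """Return a list of problems found in a checkpoint directory (empty if none)."""
    if not os.path.isdir(path):
        return [f'directory not found: {path}']
    problems = []
    if not os.path.isfile(os.path.join(path, 'config.json')):
        problems.append('missing config.json')
    if not any(os.path.isfile(os.path.join(path, name)) for name in WEIGHT_FILES):
        problems.append('missing model weights')
    return problems

# Usage once TRANSFORMER_MODELS is defined (Step 3):
# for name, path in TRANSFORMER_MODELS.items():
#     print(f'{name}: {checkpoint_problems(path) or "OK"}')
```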

In [None]:
from google.colab import drive
drive.mount('/content/drive')
print('Google Drive mounted successfully')

---
## Step 2: Install Dependencies

In [None]:
!pip install -q transformers torch

import torch
print(f'PyTorch version: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')

---
## Step 3: Configuration

Define model paths and entity labels.

In [None]:
# Model checkpoint paths on Google Drive
TRANSFORMER_MODELS = {
    # DistilBERT (Distilled BERT - Fast, efficient)
    'distilbert': '/content/drive/My Drive/thesis/transformer/distilberta_base_uncased/results/checkpoint-5245',
    
    # DistilRoBERTa (RECOMMENDED - Best balance of speed and accuracy)
    'distilroberta': '/content/drive/My Drive/thesis/transformer/distilroberta_base/results/checkpoint-5275',
    
    # RoBERTa Large (High accuracy, slower)
    'roberta_large': '/content/drive/My Drive/thesis/transformer/roberta_large/results/checkpoint-2772',
    
    # XLM-RoBERTa Base (Multilingual capability)
    'xlm_roberta_base': '/content/drive/My Drive/thesis/transformer/xlm_roberta_base/results/checkpoint-12216',
    
    # XLM-RoBERTa Large (Best accuracy, slowest)
    'xlm_roberta_large': '/content/drive/My Drive/thesis/transformer/xlm_roberta_large/results/checkpoint-12240',
}

# 22 labels in the BIO scheme: 'O' plus B-/I- tags for 17 entity types
ENTITY_LABELS = [
    'O',                        # 0  - Outside (not an entity)
    'B-Action',                 # 1  - Action/verb (login, failed, accept)
    'B-ApplicationSpecific',    # 2  - App-specific terms
    'B-AuthenticationType',     # 3  - Auth methods (password, publickey)
    'B-DNSName',                # 4  - Domain names (begin)
    'I-DNSName',                # 5  - Domain names (continuation)
    'B-DateTime',               # 6  - Timestamps (begin)
    'I-DateTime',               # 7  - Timestamps (continuation)
    'B-Error',                  # 8  - Error messages (begin)
    'I-Error',                  # 9  - Error messages (continuation)
    'B-IPAddress',              # 10 - IP addresses (begin only)
    'B-Object',                 # 11 - File/object names
    'B-Port',                   # 12 - Port numbers
    'B-Process',                # 13 - Process names (sshd, su, dnsmasq)
    'B-Protocol',               # 14 - Network protocols (TCP, UDP)
    'B-Service',                # 15 - Service names
    'B-SessionID',              # 16 - Session identifiers
    'B-Severity',               # 17 - Log severity (error, warn, info)
    'B-Status',                 # 18 - Status indicators (begin)
    'I-Status',                 # 19 - Status indicators (continuation)
    'B-System',                 # 20 - Hostnames/systems
    'B-Username',               # 21 - User identifiers
]

# Create bidirectional mappings
id2label = {i: label for i, label in enumerate(ENTITY_LABELS)}
label2id = {label: i for i, label in enumerate(ENTITY_LABELS)}

print(f'Entity labels defined: {len(ENTITY_LABELS)}')
print(f'Available models: {list(TRANSFORMER_MODELS.keys())}')

---
## Step 4: Load Model and Tokenizer

Select and load a transformer model. **DistilRoBERTa** is recommended as the best balance of speed and accuracy.
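`extract_entities_bio` decodes predictions with the hand-written `id2label` from Step 3, which is only correct if it reproduces the label order used during fine-tuning. The checkpoint stores its own mapping in `model.config.id2label` (Transformers fills in generic `LABEL_0`, `LABEL_1`, ... when none was saved). A minimal comparison sketch, with `label_mismatches` as a hypothetical helper:

```python
def label_mismatches(config_id2label, expected_labels):
    """Compare a checkpoint's id2label mapping with an expected label list.

    Returns (index, expected, found) triples for every position that differs,
    including positions present in only one of the two. Keys are converted
    with int(), so both int keys and JSON-style string keys are accepted.
    """
    found = {int(k): v for k, v in config_id2label.items()}
    size = max(len(expected_labels), max(found, default=-1) + 1)
    mismatches = []
    for i in range(size):
        expected = expected_labels[i] if i < len(expected_labels) else None
        actual = found.get(i)
        if expected != actual:
            mismatches.append((i, expected, actual))
    return mismatches

# Usage after the model is loaded in this step:
# problems = label_mismatches(model.config.id2label, ENTITY_LABELS)
# print('Label order matches' if not problems else problems[:5])
```

If the checkpoint does carry real label names, one option is to decode with `model.config.id2label` directly, so there is only one copy of the list to keep in sync.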

In [None]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

# Select model (change this to test different models)
SELECTED_MODEL = 'distilroberta'  # Options: distilbert, distilroberta, roberta_large, xlm_roberta_base, xlm_roberta_large
MODEL_PATH = TRANSFORMER_MODELS[SELECTED_MODEL]

print(f'Loading {SELECTED_MODEL}...')
print(f'Path: {MODEL_PATH}')

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForTokenClassification.from_pretrained(MODEL_PATH)

# Move to GPU if available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()

print(f'\nModel loaded: {type(model).__name__}')
print(f'Number of labels: {model.config.num_labels}')
print(f'Architecture: {model.config.model_type}')
print(f'Device: {device}')

---
## Step 5: Define Hybrid Extraction Function

This function combines model predictions with regex post-processing to handle entities that fragment at token boundaries.

### Why Hybrid?
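- **Model handles**: DateTime, Error, Status (which have I- tags) and the remaining single-token types such as Action, Port, and Severity
- **Regex handles**: IPAddress, Process, Username, System (no I- tags, so they fragment) and DNSName (which fragments despite its I- tag), plus ProcessID, which only the regex produces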
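Steps 3 and 4 of the function below turn per-token labels and character offsets into entity strings. The same idea on hand-made data, with no model involved: the labels and offsets are invented for illustration, roughly what a subword tokenizer might produce for an IP address that the model tags as a run of `B-IPAddress`.

```python
def decode_bio(text, labels, offsets):
    """Merge BIO-labelled tokens into (entity_type, text) spans.

    Consecutive tokens with the same type are merged, whether tagged B- or I-,
    mirroring the B-tag correction in extract_entities_bio. Offsets of (0, 0)
    mark special tokens and are skipped.
    """
    spans = []
    current_type, start, end = None, None, None

    def flush():
        # Record the open span (first token start to last token end), if any
        if current_type is not None:
            value = text[start:end].strip()
            if value:
                spans.append((current_type, value))

    for label, (tok_start, tok_end) in zip(labels, offsets):
        if (tok_start, tok_end) == (0, 0):
            continue
        entity_type = None if label == 'O' else label[2:]
        if entity_type is not None and entity_type == current_type:
            end = tok_end  # extend the open span
        else:
            flush()
            current_type, start, end = entity_type, tok_start, tok_end
    flush()
    return spans

# Invented tokenization of "from 192.168.1.100 port 22" with special tokens at both ends
text = 'from 192.168.1.100 port 22'
labels = ['O', 'O', 'B-IPAddress', 'B-IPAddress', 'B-IPAddress', 'B-IPAddress', 'O', 'B-Port', 'O']
offsets = [(0, 0), (0, 4), (5, 8), (8, 12), (12, 14), (14, 18), (19, 23), (24, 26), (0, 0)]
print(decode_bio(text, labels, offsets))  # [('IPAddress', '192.168.1.100'), ('Port', '22')]
```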

In [None]:
import re

def extract_entities_bio(log_line):
    """
    HYBRID extraction: Model predictions + regex post-processing
    
    Args:
        log_line (str): Raw log entry text
    
    Returns:
        list of tuples: [(entity_type, value), ...]
    """
    
    # ========== STEP 1: TOKENIZATION WITH OFFSET MAPPING ==========
    inputs = tokenizer(
        log_line,
        return_tensors='pt',
        truncation=True,
        padding=True,
        return_offsets_mapping=True
    )
    
    offset_mapping = inputs.pop('offset_mapping')[0].tolist()
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # ========== STEP 2: MODEL PREDICTION ==========
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=-1)
    
    pred_labels = [id2label[p.item()] for p in predictions[0]]
    
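    # ========== STEP 3: FIX MULTIPLE B- TAGS (CRITICAL) ==========
    # Types without I- tags are predicted as a run of B- tags, one per subword
    # (e.g. "192", ".168", ...). Relabel a B- tag as I- when it repeats the
    # previous token's type so the pieces merge into one span in Step 4.
    # Caveat: two genuinely separate adjacent entities of one type also merge.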
    corrected_labels = []
    prev_entity_type = None
    
    for label in pred_labels:
        if label.startswith('B-'):
            entity_type = label[2:]
            if entity_type == prev_entity_type:
                corrected_labels.append(f'I-{entity_type}')
            else:
                corrected_labels.append(label)
                prev_entity_type = entity_type
        elif label.startswith('I-'):
            corrected_labels.append(label)
            prev_entity_type = label[2:]
        else:
            corrected_labels.append(label)
            prev_entity_type = None
    
    pred_labels = corrected_labels
    
    # ========== STEP 4: EXTRACT ENTITIES FROM MODEL ==========
    model_entities = []
    current_entity_type = None
    entity_spans = []

    def flush():
        """Record the open entity (first token start to last token end), if any."""
        if current_entity_type and entity_spans:
            entity_value = log_line[entity_spans[0][0]:entity_spans[-1][1]].strip()
            if entity_value:
                model_entities.append((current_entity_type, entity_value))

    for label, (start, end) in zip(pred_labels, offset_mapping):
        if start == 0 and end == 0:  # Special token (<s>, </s>, [CLS], [SEP])
            continue

        tag, _, entity_type = label.partition('-')  # 'B-Port' -> ('B', '-', 'Port')
        if tag == 'I' and entity_type == current_entity_type:
            # Continuation of the open entity
            entity_spans.append((start, end))
        elif tag in ('B', 'I'):
            # New entity: a B- tag, or an I- tag that does not continue the
            # open entity (orphan I- tags start an entity instead of being dropped)
            flush()
            current_entity_type = entity_type
            entity_spans = [(start, end)]
        else:  # 'O'
            flush()
            current_entity_type = None
            entity_spans = []

    flush()  # Close an entity that runs to the end of the line
    
    # ========== STEP 5: HYBRID POST-PROCESSING ==========
    # Discard model output for types that fragment at subword boundaries
    # (DNSName fragments despite having I- tags); Step 6 re-extracts them by regex.
    fragmented_types = {'IPAddress', 'DNSName', 'Process', 'Username', 'System'}
    entities = [e for e in model_entities if e[0] not in fragmented_types]
    
    # ========== STEP 6: REGEX EXTRACTIONS ==========
    
    # IP addresses (IPv4 only; octets are not range-checked, so 999.1.1.1 also matches)
    ip_pattern = r'\b(?:\d{1,3}\.){3}\d{1,3}\b'
    for match in re.finditer(ip_pattern, log_line):
        entities.append(('IPAddress', match.group()))
    
    # DNS names; the pattern also matches IPv4 addresses, which the digit.digit check below skips
    dns_pattern = r'\b([a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)+)\b'
    for match in re.finditer(dns_pattern, log_line, re.IGNORECASE):
        domain = match.group()
        if not re.match(r'^\d+\.\d+', domain):
            entities.append(('DNSName', domain))
    
    # Process names with PIDs, e.g. sshd[1234] (ProcessID is a regex-only type, not a model label)
    process_pattern = r'\b([a-zA-Z_][a-zA-Z0-9_-]*)\[(\d+)\]'
    for match in re.finditer(process_pattern, log_line):
        entities.append(('Process', match.group(1)))
        entities.append(('ProcessID', match.group(2)))
    
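    # Usernames after "for", "user", or "by" (e.g. "for invalid user admin",
    # "su for jhall by www-data"). Common non-username words are excluded;
    # 'root' is not excluded, since root activity matters forensically.
    username_patterns = [
        r'\bfor\s+([a-z_][a-z0-9_-]*)\b',
        r'\buser\s+([a-z_][a-z0-9_-]*)\b',
        r'\bby\s+([a-z_][a-z0-9_-]*)\b',
    ]
    excluded = {'unknown', 'invalid', 'port', 'from', 'to', 'on', 'at'}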
    
    for pattern in username_patterns:
        for match in re.finditer(pattern, log_line, re.IGNORECASE):
            username = match.group(1).lower()
            if username not in excluded and len(username) > 1:
                entities.append(('Username', match.group(1)))
    
    # System/hostname: assumes classic syslog layout 'Mon DD HH:MM:SS host ...', so the 4th field is the host
    parts = log_line.split()
    if len(parts) > 3:
        potential_host = parts[3].rstrip(':')
        if re.match(r'^[a-zA-Z][a-zA-Z0-9-]*$', potential_host) and len(potential_host) > 2:
            entities.append(('System', potential_host))
    
    # Deduplicate
    seen = set()
    unique = []
    for e in entities:
        if e not in seen:
            seen.add(e)
            unique.append(e)
    
    return unique

print('Hybrid extraction function defined')

---
## Step 6: Test Extraction

Test the extraction on sample forensic logs.
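The regex half of the hybrid extractor can be exercised on its own, which helps when tuning patterns without loading a model. A minimal sketch that reuses the IP, process/PID, and username patterns from `extract_entities_bio`; `regex_entities` is a hypothetical helper, and 'root' is not in its exclusion set, so root logins are reported.

```python
import re

# Patterns copied from extract_entities_bio
IP_PATTERN = re.compile(r'\b(?:\d{1,3}\.){3}\d{1,3}\b')
PROCESS_PATTERN = re.compile(r'\b([a-zA-Z_][a-zA-Z0-9_-]*)\[(\d+)\]')
USERNAME_PATTERNS = [
    re.compile(r'\bfor\s+([a-z_][a-z0-9_-]*)\b', re.IGNORECASE),
    re.compile(r'\buser\s+([a-z_][a-z0-9_-]*)\b', re.IGNORECASE),
    re.compile(r'\bby\s+([a-z_][a-z0-9_-]*)\b', re.IGNORECASE),
]
EXCLUDED = {'unknown', 'invalid', 'port', 'from', 'to', 'on', 'at'}

def regex_entities(log_line):
    """Apply only the regex rules; returns [(entity_type, value), ...]."""
    found = [('IPAddress', m.group()) for m in IP_PATTERN.finditer(log_line)]
    for m in PROCESS_PATTERN.finditer(log_line):
        found += [('Process', m.group(1)), ('ProcessID', m.group(2))]
    for pattern in USERNAME_PATTERNS:
        for m in pattern.finditer(log_line):
            name = m.group(1)
            if name.lower() not in EXCLUDED and len(name) > 1:
                found.append(('Username', name))
    return found

log = ("Jan 24 10:15:32 server sshd[5678]: Failed password for invalid "
       "user admin from 10.0.0.50 port 22 ssh2")
print(regex_entities(log))
# [('IPAddress', '10.0.0.50'), ('Process', 'sshd'), ('ProcessID', '5678'), ('Username', 'admin')]
```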

In [None]:
# Test 1: DNS Query Log
dns_log = "Jan 24 10:30:45 dnsmasq[1234]: query[A] example.com from 192.168.1.100"
entities = extract_entities_bio(dns_log)

print('=' * 60)
print('TEST 1: DNS Query Log')
print('=' * 60)
print(f'Input: {dns_log}')
print('\nExtracted Entities:')
for entity_type, value in entities:
    print(f'  [{entity_type}] {value}')

In [None]:
# Test 2: su Privilege Switch Log
ssh_log = "Jan 24 04:37:40 intranet-server su[27950]: Successful su for jhall by www-data"
entities = extract_entities_bio(ssh_log)

print('=' * 60)
print('TEST 2: su Privilege Switch Log')
print('=' * 60)
print(f'Input: {ssh_log}')
print('\nExtracted Entities:')
for entity_type, value in entities:
    print(f'  [{entity_type}] {value}')

In [None]:
# Test 3: Failed Login Attempt
failed_log = "Jan 24 10:15:32 server sshd[5678]: Failed password for invalid user admin from 10.0.0.50 port 22 ssh2"
entities = extract_entities_bio(failed_log)

print('=' * 60)
print('TEST 3: Failed Login Attempt')
print('=' * 60)
print(f'Input: {failed_log}')
print('\nExtracted Entities:')
for entity_type, value in entities:
    print(f'  [{entity_type}] {value}')

---
## Step 7: FTE-HARM Integration

Convert extracted entities to formats compatible with FTE-HARM.
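If the FTE-HARM side reads records from disk, one option (an assumption; this notebook does not define FTE-HARM's input interface) is JSON Lines, one record per log line. JSON has no tuple type, so `entity_list` pairs come back as two-item lists. `to_jsonl_record` and `write_jsonl` are hypothetical helpers, and the sample record is hand-written, not model output.

```python
import json

def to_jsonl_record(record):
    """Serialize one FTE-HARM input dict to a single JSON line (no trailing newline)."""
    return json.dumps(record, ensure_ascii=False)

def write_jsonl(records, path):
    """Write records as JSON Lines, one object per line; returns the number written."""
    count = 0
    with open(path, 'w', encoding='utf-8') as f:
        for record in records:
            f.write(to_jsonl_record(record) + '\n')
            count += 1
    return count

# Hand-written sample in the format_for_fte_harm shape (not real model output)
record = {
    'raw_log': 'Jan 24 10:30:45 server sshd[1234]: Failed login for admin from 192.168.1.100',
    'entities': {'Process': ['sshd'], 'IPAddress': ['192.168.1.100']},
    'entity_list': [('Process', 'sshd'), ('IPAddress', '192.168.1.100')],
    'tagged_text': '[Process: sshd] [IPAddress: 192.168.1.100]',
}
restored = json.loads(to_jsonl_record(record))
# JSON has no tuples: entity_list pairs come back as lists
restored['entity_list'] = [tuple(pair) for pair in restored['entity_list']]
print(restored == record)  # True
```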

In [None]:
def entities_to_dict(entities):
    """Convert entity list to dictionary format."""
    entity_dict = {}
    for entity_type, value in entities:
        if entity_type not in entity_dict:
            entity_dict[entity_type] = []
        entity_dict[entity_type].append(value)
    return entity_dict

def entities_to_tagged_string(entities):
    """Convert entities to tagged string format."""
    return ' '.join(f'[{t}: {v}]' for t, v in entities)

def format_for_fte_harm(log_line, entities):
    """Format extracted entities for FTE-HARM input."""
    return {
        'raw_log': log_line,
        'entities': entities_to_dict(entities),
        'entity_list': entities,
        'tagged_text': entities_to_tagged_string(entities),
    }

# Example
log = "Jan 24 10:30:45 server sshd[1234]: Failed login for admin from 192.168.1.100"
entities = extract_entities_bio(log)
fte_harm_input = format_for_fte_harm(log, entities)

print('FTE-HARM Input Format:')
print('-' * 40)
for key, value in fte_harm_input.items():
    print(f'{key}: {value}')

---
## Step 8: Batch Processing

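Run the extractor over a list of log entries. Each line still gets its own forward pass; the loop is a convenience wrapper, not batched inference.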
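Per-line results are often more useful aggregated, for example to see which IPs or domains recur across a batch. A minimal sketch using `collections.Counter`; `summarize_entities` is a hypothetical helper that takes entity lists in the `[(entity_type, value), ...]` shape `extract_entities_bio` returns. Since that function deduplicates within a line, each count is the number of lines mentioning a value. The sample input is hand-written, not model output.

```python
from collections import Counter, defaultdict

def summarize_entities(results, types=None):
    """Count how often each value occurs per entity type across many logs.

    results: iterable of entity lists, one per log line.
    types: optional set of entity types to keep (all types if None).
    Returns {entity_type: Counter({value: count})}.
    """
    summary = defaultdict(Counter)
    for entities in results:
        for entity_type, value in entities:
            if types is None or entity_type in types:
                summary[entity_type][value] += 1
    return dict(summary)

# Hand-written sample in the extractor's output shape
results = [
    [('DNSName', 'malware-c2.evil.com'), ('IPAddress', '192.168.1.50')],
    [('Username', 'root'), ('IPAddress', '10.0.0.100')],
    [('IPAddress', '192.168.1.50'), ('IPAddress', '8.8.8.8')],
    [('DNSName', 'malware-c2.evil.com'), ('IPAddress', '198.51.100.1')],
]
summary = summarize_entities(results, types={'IPAddress', 'DNSName'})
print(summary['IPAddress'].most_common(1))        # [('192.168.1.50', 2)]
print(summary['DNSName']['malware-c2.evil.com'])  # 2
```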

In [None]:
# Sample batch of logs
sample_logs = [
    "Jan 24 10:30:45 dnsmasq[1234]: query[A] malware-c2.evil.com from 192.168.1.50",
    "Jan 24 10:31:02 server sshd[5678]: Failed password for root from 10.0.0.100 port 22",
    "Jan 24 10:31:15 firewall kernel: DENY TCP 192.168.1.50:45678 -> 8.8.8.8:53",
    "Jan 24 10:31:30 intranet su[9876]: Successful su for admin by www-data",
    "Jan 24 10:32:00 dnsmasq[1234]: reply malware-c2.evil.com is 198.51.100.1",
]

print('Batch Processing Results')
print('=' * 70)

for i, log in enumerate(sample_logs, 1):
    entities = extract_entities_bio(log)
    preview = log if len(log) <= 60 else log[:60] + '...'
    print(f'\nLog {i}: {preview}')
    print('Entities:')
    for entity_type, value in entities:
        print(f'  [{entity_type}] {value}')

---
## Validation Checklist

Verify that all requirements are met.
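The checks below exercise a single log line. A table of cases makes it easier to grow the checklist into a small regression suite. A sketch, assuming each case lists the `(entity_type, value)` pairs that must appear (extra entities are allowed); `failing_cases` and `CASES` are hypothetical, and the expected pairs are what the regex layer should produce for these lines.

```python
def failing_cases(extract_fn, cases):
    """Run extract_fn on each log and report expected entities it missed.

    cases: list of (log_line, expected) where expected is a set of
    (entity_type, value) pairs that must appear in the output.
    Returns a list of (log_line, sorted_missing_pairs) for failing cases.
    """
    failures = []
    for log_line, expected in cases:
        missing = set(expected) - set(extract_fn(log_line))
        if missing:
            failures.append((log_line, sorted(missing)))
    return failures

# Hypothetical regression cases
CASES = [
    ("Jan 24 10:30:45 server sshd[1234]: Failed login from 192.168.1.100",
     {('Process', 'sshd'), ('ProcessID', '1234'), ('IPAddress', '192.168.1.100')}),
    ("Jan 24 04:37:40 intranet-server su[27950]: Successful su for jhall by www-data",
     {('System', 'intranet-server'), ('Username', 'jhall'), ('Username', 'www-data')}),
]

# Usage in the notebook:
# print(failing_cases(extract_entities_bio, CASES) or 'all cases pass')
```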

In [None]:
print('VALIDATION CHECKLIST')
print('=' * 50)

# Check model loaded
model_loaded = model is not None and tokenizer is not None
print(f'[{"PASS" if model_loaded else "FAIL"}] Model loads without errors')

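# Check entity labels (22 BIO labels, matching the model's output size)
labels_correct = len(ENTITY_LABELS) == 22 and model.config.num_labels == len(ENTITY_LABELS)
print(f'[{"PASS" if labels_correct else "FAIL"}] 22 BIO labels defined and matching model output size')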

# Check mappings
mappings_ok = len(id2label) == 22 and len(label2id) == 22
print(f'[{"PASS" if mappings_ok else "FAIL"}] id2label and label2id mappings created')

# Test extraction
test_log = "Jan 24 10:30:45 server sshd[1234]: Failed login from 192.168.1.100"
test_entities = extract_entities_bio(test_log)
extraction_works = len(test_entities) > 0
print(f'[{"PASS" if extraction_works else "FAIL"}] Hybrid extraction function executes')

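# Check IP not fragmented: the full address is present and every IPAddress value is a complete dotted quad
ip_entities = [value for entity_type, value in test_entities if entity_type == 'IPAddress']
ip_complete = '192.168.1.100' in ip_entities and all(
    re.fullmatch(r'(?:\d{1,3}\.){3}\d{1,3}', ip) for ip in ip_entities
)
print(f'[{"PASS" if ip_complete else "FAIL"}] No fragmented IP addresses')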

# Check process extracted together with its PID
process_ok = ('Process', 'sshd') in test_entities and ('ProcessID', '1234') in test_entities
print(f'[{"PASS" if process_ok else "FAIL"}] Complete process names with PIDs')

# Output format check
format_ok = all(isinstance(e, tuple) and len(e) == 2 for e in test_entities)
print(f'[{"PASS" if format_ok else "FAIL"}] Output format: [(entity_type, value), ...]')

print('\n' + '=' * 50)
all_pass = all([model_loaded, labels_correct, mappings_ok, extraction_works, ip_complete, process_ok, format_ok])
print(f'Overall: {"ALL CHECKS PASSED" if all_pass else "SOME CHECKS FAILED"}')

---
## Model Comparison (Optional)

Compare different models on the same log data.
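The comparison cell averages ten calls timed with wall-clock differences. A slightly more robust pattern, sketched below, uses `time.perf_counter()` (a monotonic, high-resolution clock), discards warm-up calls, and reports the median so one slow call does not skew the result. `time_call` is a hypothetical helper; on GPU, `extract_entities_bio` should already wait for the forward pass because it calls `.item()` on the predictions.

```python
import statistics
import time

def time_call(fn, *args, repeats=10, warmup=1):
    """Return the median wall-clock time of fn(*args) in milliseconds.

    repeats must be at least 1; warm-up calls are run first and not timed.
    """
    for _ in range(warmup):
        fn(*args)  # first calls may include one-off setup costs
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        samples.append((time.perf_counter() - start) * 1000)
    return statistics.median(samples)

# Usage with the notebook's extractor (after loading a model):
# print(f'{time_call(extract_entities_bio, test_log):.2f} ms')
ms = time_call(sum, range(10_000), repeats=5)
print(ms >= 0)  # True
```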

In [None]:
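# To compare models, remove the surrounding triple quotes (loading each model takes time).
# Note: this reassigns the global `model` and `tokenizer` used by extract_entities_bio.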
"""
import time

test_log = "Jan 24 10:30:45 server sshd[1234]: Failed login for admin from 192.168.1.100"

for model_name in ['distilbert', 'distilroberta', 'roberta_large']:
    print(f'\n{"="*50}')
    print(f'Testing: {model_name}')
    print(f'{"="*50}')
    
    # Load model
    path = TRANSFORMER_MODELS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForTokenClassification.from_pretrained(path)
    model.to(device)
    model.eval()
    
    # Time inference: one warm-up call, then the mean of 10 timed runs
    extract_entities_bio(test_log)
    start = time.perf_counter()
    for _ in range(10):
        entities = extract_entities_bio(test_log)
    elapsed = (time.perf_counter() - start) / 10 * 1000
    
    print(f'Avg inference time: {elapsed:.2f}ms')
    print('Entities:')
    for entity_type, value in entities:
        print(f'  [{entity_type}] {value}')
"""
print('Model comparison code available - uncomment to run')

---
## Summary

### What We Achieved
- Loaded fine-tuned transformer models from Google Drive
- Implemented hybrid entity extraction (model + regex)
- Extracted entities of 17 types (22 BIO labels) from forensic log data, plus a regex-derived ProcessID
- Created FTE-HARM compatible output formats

### Key Findings
- Pure model extraction fragments entities without I- tags (IPAddress, Process, etc.)
- The hybrid approach avoids this fragmentation by re-extracting those types with regex post-processing
- DistilRoBERTa provides the best balance of speed and accuracy

### Next Steps
1. Integrate with Dataset Loader to process log files
2. Feed extracted entities to FTE-HARM for hypothesis scoring
3. Validate on real forensic datasets
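Validating on labelled forensic data needs a score. One common choice, assumed here rather than taken from the FTE-HARM work, is exact-match precision, recall, and F1 over `(entity_type, value)` pairs per log line. A minimal sketch; `entity_prf` is a hypothetical helper and the example data is hand-made.

```python
def entity_prf(predicted, gold):
    """Exact-match precision, recall, and F1 over (entity_type, value) pairs.

    predicted, gold: parallel lists with one entity list per log line.
    Pairs are compared as sets within each line, so duplicates count once.
    """
    if len(predicted) != len(gold):
        raise ValueError('predicted and gold must cover the same log lines')
    tp = fp = fn = 0
    for pred_entities, gold_entities in zip(predicted, gold):
        pred_set, gold_set = set(pred_entities), set(gold_entities)
        tp += len(pred_set & gold_set)   # correct pairs
        fp += len(pred_set - gold_set)   # extracted but not in gold
        fn += len(gold_set - pred_set)   # in gold but missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Hand-made example: one wrong username out of three entities
predicted = [[('Process', 'sshd'), ('IPAddress', '10.0.0.50'), ('Username', 'invalid')]]
gold = [[('Process', 'sshd'), ('IPAddress', '10.0.0.50'), ('Username', 'admin')]]
print(tuple(round(x, 3) for x in entity_prf(predicted, gold)))  # (0.667, 0.667, 0.667)
```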