# FTE-HARM Complete Validation Pipeline v3

## With Dataset-Specific Hypotheses, Threshold Testing, and MITRE Corroboration

---

## Overview

This notebook implements the complete FTE-HARM validation with:

1. **Dataset-Specific Hypotheses** - Target AITv2 labels directly
2. **MITRE ATT&CK Metadata** - For forensic corroboration
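3. **Threshold Testing** - Sweep candidate thresholds and pick the one with the highest precision among those meeting the recall target (HIGH RECALL first; default recall ≥ 0.90)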
4. **Label Mapping** - Automatic matching between hypothesis target labels and ground-truth labels
5. **Two P_Score Methods** - Option A (Binary) and Option B3 (Confidence-Weighted)

**Validation Goal:** HIGH RECALL first (miss as little malicious evidence as possible), acceptable precision second.

## Cell 1: Imports and Paths

In [None]:
# =============================================================================
# FTE-HARM COMPLETE VALIDATION v3: IMPORTS AND PATHS
# =============================================================================

import os
import re
import json
import numpy as np
from datetime import datetime
from collections import defaultdict, Counter
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# Mount Google Drive so the dataset, checkpoint and output paths below resolve.
from google.colab import drive
drive.mount('/content/drive')

# -----------------------------------------------------------------------------
# PATH CONFIGURATION
# -----------------------------------------------------------------------------

# Inputs (AITv2 dataset folders) and outputs all live under thesis/ on Drive.
DATASET_BASE_PATH = '/content/drive/My Drive/thesis/dataset'
OUTPUT_PATH = '/content/drive/My Drive/thesis/hypotheses_validation'
SUMMARY_PATH = '/content/drive/My Drive/thesis/hypotheses_validation/summary'
THRESHOLD_TEST_PATH = '/content/drive/My Drive/thesis/hypotheses_validation/threshold_test'
MITRE_PATH = '/content/drive/My Drive/thesis/hypotheses_validation/mitre_att&ck'

# Ensure every output directory exists (exist_ok makes re-runs harmless).
for path in (OUTPUT_PATH, SUMMARY_PATH, THRESHOLD_TEST_PATH, MITRE_PATH):
    os.makedirs(path, exist_ok=True)

# Fine-tuned token-classification checkpoints, keyed by short model name.
MODELS = dict(
    distilbert='/content/drive/My Drive/thesis/transformer/distilberta_base_uncased/results/checkpoint-5245',
    distilroberta='/content/drive/My Drive/thesis/transformer/distilroberta_base/results/checkpoint-5275',
    roberta_large='/content/drive/My Drive/thesis/transformer/roberta_large/results/checkpoint-2772',
    xlm_roberta_base='/content/drive/My Drive/thesis/transformer/xlm_roberta_base/results/checkpoint-12216',
    xlm_roberta_large='/content/drive/My Drive/thesis/transformer/xlm_roberta_large/results/checkpoint-12240',
)

# Key into MODELS selecting the checkpoint used for the whole run.
SELECTED_MODEL = 'xlm_roberta_large'

print("Configuration loaded")
print("Threshold test output: " + THRESHOLD_TEST_PATH)

## Cell 2: Dataset-Specific Hypotheses (With MITRE Metadata)

In [None]:
# =============================================================================
# DATASET-SPECIFIC HYPOTHESES FOR AITv2
# With MITRE ATT&CK Metadata for Forensic Corroboration
# =============================================================================
#
# Each hypothesis defines:
#   target_labels    - ground-truth labels it is meant to detect
#   mitre_technique  - ATT&CK technique ID written to the corroboration table;
#   mitre_tactic       the technique must actually belong to this tactic
#   weights          - per-entity-type contribution to the P_Score (each set
#                      sums to 1.0)
#   critical_entity  - entity type whose absence scales the P_Score by
#                      (1 - penalty_factor)
#   threshold,       - descriptive metadata only: the scoring and threshold
#   log_sources        sweep in this notebook do not read these keys

DATASET_HYPOTHESES = {
    
    # -------------------------------------------------------------------------
    # PRIVILEGE ESCALATION
    # -------------------------------------------------------------------------
    
    'H1_attacker_change_user': {
        'name': 'H1_attacker_change_user',
        'description': 'Attacker switches user identity via su/sudo',
        'target_labels': ['attacker_change_user', 'escalate'],
        'mitre_technique': 'T1548.003',
        'mitre_tactic': 'Privilege Escalation',
        'weights': {
            'Username': 0.30, 'Process': 0.25, 'Action': 0.20,
            'DateTime': 0.15, 'Status': 0.10
        },
        'critical_entity': 'Username',
        'penalty_factor': 0.25,
        'threshold': 0.50,
        'log_sources': ['auth.log', 'audit.log', 'secure']
    },
    
    'H2_escalate': {
        'name': 'H2_escalate',
        'description': 'Generic privilege escalation',
        'target_labels': ['escalate', 'attacker_change_user'],
        'mitre_technique': 'T1068',
        'mitre_tactic': 'Privilege Escalation',
        'weights': {
            'Process': 0.30, 'Username': 0.25, 'Action': 0.20,
            'Status': 0.15, 'DateTime': 0.10
        },
        'critical_entity': 'Process',
        'penalty_factor': 0.25,
        'threshold': 0.50,
        'log_sources': ['auth.log', 'audit.log', 'syslog']
    },
    
    # -------------------------------------------------------------------------
    # DISCOVERY / SCANNING
    # -------------------------------------------------------------------------
    
    'H3_scan': {
        'name': 'H3_scan',
        'description': 'Network or host scanning activity',
        'target_labels': ['scan'],
        'mitre_technique': 'T1046',
        'mitre_tactic': 'Discovery',
        'weights': {
            'IPAddress': 0.30, 'Port': 0.25, 'Process': 0.20,
            'Action': 0.15, 'DateTime': 0.10
        },
        'critical_entity': 'IPAddress',
        'penalty_factor': 0.20,
        'threshold': 0.45,
        'log_sources': ['suricata', 'firewall', 'apache_access']
    },
    
    # -------------------------------------------------------------------------
    # CREDENTIAL ACCESS
    # -------------------------------------------------------------------------
    
    'H4_crack': {
        'name': 'H4_crack',
        'description': 'Password cracking or brute-force attempts',
        'target_labels': ['crack'],
        # NOTE(review): T1110.001 is Password Guessing (online brute force);
        # if the 'crack' label marks offline hash cracking, T1110.002
        # (Password Cracking) may be the closer mapping -- confirm.
        'mitre_technique': 'T1110.001',
        'mitre_tactic': 'Credential Access',
        'weights': {
            'Username': 0.30, 'Action': 0.25, 'Status': 0.20,
            'IPAddress': 0.15, 'DateTime': 0.10
        },
        'critical_entity': 'Username',
        'penalty_factor': 0.25,
        'threshold': 0.50,
        'log_sources': ['auth.log', 'secure', 'audit.log']
    },
    
    # -------------------------------------------------------------------------
    # PERSISTENCE
    # -------------------------------------------------------------------------
    
    'H5_webshell': {
        'name': 'H5_webshell',
        'description': 'Webshell upload or execution',
        'target_labels': ['webshell'],
        'mitre_technique': 'T1505.003',
        'mitre_tactic': 'Persistence',
        'weights': {
            'FilePath': 0.30, 'Process': 0.25, 'Action': 0.20,
            'IPAddress': 0.15, 'DateTime': 0.10
        },
        'critical_entity': 'FilePath',
        'penalty_factor': 0.25,
        'threshold': 0.50,
        'log_sources': ['apache_access', 'apache_error', 'nginx']
    },
    
    # -------------------------------------------------------------------------
    # EXFILTRATION
    # -------------------------------------------------------------------------
    
    'H6_exfiltrate': {
        'name': 'H6_exfiltrate',
        'description': 'Data exfiltration over network',
        'target_labels': ['exfiltrate', 'dns_exfil'],
        'mitre_technique': 'T1048',
        'mitre_tactic': 'Exfiltration',
        'weights': {
            'IPAddress': 0.30, 'DNSName': 0.25, 'Process': 0.20,
            'Action': 0.15, 'DateTime': 0.10
        },
        'critical_entity': 'IPAddress',
        'penalty_factor': 0.20,
        'threshold': 0.55,
        'log_sources': ['dns', 'firewall', 'suricata', 'netflow']
    },
    
    'H7_dns_exfil': {
        'name': 'H7_dns_exfil',
        'description': 'DNS tunneling for exfiltration',
        'target_labels': ['dns_exfil', 'exfiltrate'],
        # T1071.004 (Application Layer Protocol: DNS) belongs to the Command
        # and Control tactic, so pairing it with 'Exfiltration' made the
        # corroboration row inconsistent. T1048.003 (Exfiltration Over
        # Unencrypted Non-C2 Protocol) is the Exfiltration-tactic technique
        # covering exfiltration over DNS.
        'mitre_technique': 'T1048.003',
        'mitre_tactic': 'Exfiltration',
        'weights': {
            'DNSName': 0.35, 'IPAddress': 0.25, 'Action': 0.20,
            'Process': 0.10, 'DateTime': 0.10
        },
        'critical_entity': 'DNSName',
        'penalty_factor': 0.20,
        'threshold': 0.55,
        'log_sources': ['dns', 'named', 'bind']
    },
    
    # -------------------------------------------------------------------------
    # LATERAL MOVEMENT
    # -------------------------------------------------------------------------
    
    'H8_lateral': {
        'name': 'H8_lateral',
        'description': 'Lateral movement between systems',
        'target_labels': ['lateral'],
        'mitre_technique': 'T1021.004',
        'mitre_tactic': 'Lateral Movement',
        'weights': {
            'IPAddress': 0.30, 'Username': 0.25, 'Process': 0.20,
            'Action': 0.15, 'DateTime': 0.10
        },
        'critical_entity': 'IPAddress',
        'penalty_factor': 0.25,
        'threshold': 0.50,
        'log_sources': ['auth.log', 'secure', 'sshd']
    },
    
    # -------------------------------------------------------------------------
    # EXECUTION / COMMAND & CONTROL
    # -------------------------------------------------------------------------
    
    'H9_rce': {
        'name': 'H9_rce',
        'description': 'Remote command execution',
        'target_labels': ['rce', 'c2'],
        'mitre_technique': 'T1059',
        'mitre_tactic': 'Execution',
        'weights': {
            'Process': 0.30, 'Action': 0.25, 'Username': 0.20,
            'IPAddress': 0.15, 'DateTime': 0.10
        },
        'critical_entity': 'Process',
        'penalty_factor': 0.25,
        'threshold': 0.50,
        'log_sources': ['audit.log', 'syslog', 'apache_access']
    },
    
    'H10_c2': {
        'name': 'H10_c2',
        'description': 'Command and control communication',
        'target_labels': ['c2', 'rce'],
        'mitre_technique': 'T1071.001',
        # Official ATT&CK tactic name (TA0011).
        'mitre_tactic': 'Command and Control',
        'weights': {
            'IPAddress': 0.30, 'DNSName': 0.25, 'Process': 0.20,
            'Action': 0.15, 'DateTime': 0.10
        },
        'critical_entity': 'IPAddress',
        'penalty_factor': 0.20,
        'threshold': 0.55,
        'log_sources': ['firewall', 'suricata', 'dns', 'proxy']
    }
}

# Print summary
print(f"\nLoaded {len(DATASET_HYPOTHESES)} dataset-specific hypotheses")
print("\nHypotheses Summary:")
print("-"*80)
for name, hyp in DATASET_HYPOTHESES.items():
    targets = ', '.join(hyp['target_labels'])
    print(f"  {name}: targets=[{targets}] | MITRE={hyp['mitre_technique']}")

## Cell 3: Thresholds and Triage Configuration

In [None]:
# =============================================================================
# THRESHOLDS AND TRIAGE PRIORITIES
# =============================================================================

# Inclusive lower bounds of the P_Score confidence bands; a score below LOW is
# reported as 'INSUFFICIENT'.
THRESHOLDS = dict(HIGH=0.65, MEDIUM=0.50, LOW=0.35)

# P_Score at or above which a line is flagged malicious. Deliberately low to
# favour recall; the threshold sweep reports a data-driven alternative.
MALICIOUS_THRESHOLD = 0.35

# Analyst action attached to each confidence band.
TRIAGE_PRIORITIES = dict(
    HIGH='Priority 1: Investigate immediately',
    MEDIUM='Priority 2: Queue for investigation',
    LOW='Priority 3: Review when possible',
    INSUFFICIENT='Priority 4: Archive for later',
)

# Candidate malicious thresholds for the sweep: 0.20 .. 0.70 in 0.05 steps.
# round(..., 2) snaps each value onto the exact 2-decimal float literal.
THRESHOLDS_TO_TEST = [round(0.20 + 0.05 * step, 2) for step in range(11)]

print(f"Initial malicious threshold: {MALICIOUS_THRESHOLD}")
print(f"Will test {len(THRESHOLDS_TO_TEST)} threshold values")

## Cell 4: Data Loaders

In [None]:
# =============================================================================
# DATA LOADERS
# =============================================================================

def find_log_label_pair(folder_path):
    """Find log and label files with variable naming.

    Two naming schemes are recognised, in this order:

    * ``log.log`` + ``label.log``;
    * ``log_<suffix>.log`` + ``label_<suffix>.log``.

    Candidates are examined in sorted order so the same pair is returned on
    every run (``os.listdir`` yields entries in arbitrary order).

    Args:
        folder_path: directory to inspect.

    Returns:
        ``(log_path, label_path)`` joined onto *folder_path*, or
        ``(None, None)`` if *folder_path* is not a directory or holds no pair.
    """
    if not os.path.isdir(folder_path):
        return None, None

    log_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.log'))
    available = set(log_files)

    if 'log.log' in available and 'label.log' in available:
        return os.path.join(folder_path, 'log.log'), os.path.join(folder_path, 'label.log')

    for name in log_files:
        if name.startswith('log_'):
            label = f'label_{name[4:]}'
            if label in available:
                return os.path.join(folder_path, name), os.path.join(folder_path, label)

    return None, None


def scan_all_datasets(base_path):
    """Scan for all valid dataset pairs.

    Walks *base_path* recursively. Every directory holding a log/label pair
    (see find_log_label_pair) becomes one dataset record.

    Returns:
        list of dicts with 'name' (directory path relative to *base_path*),
        'folder', 'log_path' and 'label_path', in a stable, sorted traversal
        order. The notebook uses ``all_datasets[0]`` as "the first dataset",
        so the order must not depend on filesystem enumeration. Returns an
        empty list if *base_path* does not exist.
    """
    datasets = []
    for root, dirs, _files in os.walk(base_path):
        # Sorting in place makes os.walk descend into subdirectories in a
        # deterministic order (top-down walk honours edits to `dirs`).
        dirs.sort()
        log_path, label_path = find_log_label_pair(root)
        if log_path and label_path:
            datasets.append({
                'name': os.path.relpath(root, base_path),
                'folder': root,
                'log_path': log_path,
                'label_path': label_path
            })
    return datasets


def load_ground_truth(label_path):
    """Load ground truth labels from a JSON-lines label file.

    Each non-blank line should be a JSON object such as
    ``{"line": 12, "labels": [...], "rules": {...}}``. Lines are skipped when
    they are not valid JSON, are valid JSON but not an object, or have no
    usable ``line`` value.

    Args:
        label_path: path to the label file.

    Returns:
        dict mapping line number -> ``{'labels': list, 'rules': dict}``.
        Numeric-string line numbers are converted to int so they match the
        integer line numbers produced by load_raw_logs. Returns an empty dict
        if *label_path* does not exist.
    """
    ground_truth = {}
    if not os.path.exists(label_path):
        return ground_truth

    with open(label_path, 'r', encoding='utf-8', errors='ignore') as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            try:
                entry = json.loads(raw)
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an object (a bare list, number, ...)
            # would otherwise raise AttributeError on .get() and abort the load.
            if not isinstance(entry, dict):
                continue

            line_num = entry.get('line')
            if isinstance(line_num, str):
                line_num = int(line_num) if line_num.strip().isdecimal() else None
            if line_num is None:
                continue

            # Normalise labels to a list: `"labels": null` or a bare string
            # would break callers that iterate over the labels.
            labels = entry.get('labels')
            if labels is None:
                labels = []
            elif isinstance(labels, str):
                labels = [labels]
            rules = entry.get('rules')

            ground_truth[line_num] = {
                'labels': labels,
                'rules': rules if rules is not None else {}
            }
    return ground_truth


def load_raw_logs(log_path):
    """Read *log_path* and return its non-blank lines.

    Returns:
        list of ``(line_number, text)`` tuples. ``line_number`` is the 1-based
        physical line in the file; blank lines are dropped but still counted.
        ``text`` has surrounding whitespace stripped.
    """
    with open(log_path, 'r', encoding='utf-8', errors='ignore') as handle:
        numbered = ((idx, raw.strip()) for idx, raw in enumerate(handle, start=1))
        return [(idx, text) for idx, text in numbered if text]


def load_dataset(dataset_info):
    """Load a dataset's log lines, ground truth and summary statistics.

    Args:
        dataset_info: record from scan_all_datasets with at least 'name',
            'log_path' and 'label_path'.

    Returns:
        (logs, ground_truth, stats). ``stats`` contains:
        - 'total_lines': non-blank log lines loaded;
        - 'malicious_lines': loaded lines that have a ground-truth entry;
        - 'benign_lines': the remaining loaded lines;
        - 'unique_labels': sorted distinct labels;
        - 'label_counts': occurrences of each label across all ground-truth
          entries.
    """
    logs = load_raw_logs(dataset_info['log_path'])
    ground_truth = load_ground_truth(dataset_info['label_path'])

    # Count only labelled lines that were actually loaded, matching what
    # evaluate_threshold scores. Labels for blank (skipped) or out-of-range
    # lines would otherwise inflate malicious_lines and understate (or even
    # make negative) benign_lines.
    loaded_line_numbers = {line_num for line_num, _ in logs}
    malicious_lines = sum(1 for line_num in ground_truth if line_num in loaded_line_numbers)

    label_counts = Counter(
        label for data in ground_truth.values() for label in data['labels']
    )

    stats = {
        'name': dataset_info['name'],
        'total_lines': len(logs),
        'malicious_lines': malicious_lines,
        'benign_lines': len(logs) - malicious_lines,
        # Sorted so printed/saved label lists are reproducible across runs.
        'unique_labels': sorted(label_counts, key=str),
        'label_counts': dict(label_counts)
    }

    return logs, ground_truth, stats


# Discover every log/label pair under DATASET_BASE_PATH and preview the first few.
print("\nScanning for datasets...")
all_datasets = scan_all_datasets(DATASET_BASE_PATH)
dataset_count = len(all_datasets)
print(f"Found {dataset_count} datasets")
for ds in all_datasets[:5]:
    print("  - " + ds['name'])

## Cell 5: Model and Entity Extraction

In [None]:
# =============================================================================
# MODEL AND ENTITY EXTRACTION
# =============================================================================

def get_model_pipeline(model_path):
    """Build a token-classification ("ner") pipeline from a local checkpoint.

    Loads both the tokenizer and the model from *model_path* and groups
    predictions with aggregation_strategy="simple". The resulting entities
    are consumed by extract_entities, which reads their 'entity_group',
    'score', 'start' and 'end' keys.
    """
    ner_options = dict(aggregation_strategy="simple")
    tok = AutoTokenizer.from_pretrained(model_path)
    mdl = AutoModelForTokenClassification.from_pretrained(model_path)
    return pipeline("ner", model=mdl, tokenizer=tok, **ner_options)


# When several model entities overlap one physical token, the highest priority
# wins. Unlisted entity types rank 0, and the model score breaks ties.
_ENTITY_PRIORITY = {'IPAddress': 4, 'DNSName': 3, 'Port': 2, 'Object': 1}

# A "physical token" is a maximal run of characters that are neither
# whitespace nor square brackets.
_PHYSICAL_TOKEN_RE = re.compile(r'[^\[\]\s]+')


def extract_entities(log_text, nlp_pipeline):
    """Extract entities with Physical Token Quantization.

    The log line is split into physical tokens. Each token takes the label of
    the highest-priority model entity whose character span overlaps it
    (ties broken by model score); tokens with no overlapping entity are 'O'.

    Args:
        log_text: one raw log line.
        nlp_pipeline: callable returning a list of dicts with 'entity_group',
            'score', 'start' and 'end' keys (e.g. from get_model_pipeline).

    Returns:
        (entity_types, token_labels):
        - entity_types maps each non-'O' label to a list of
          {'value', 'confidence'} dicts in token order;
        - token_labels has one {'text', 'label', 'confidence'} dict per
          physical token.
    """
    physical_tokens = [(m.group(), m.start(), m.end())
                       for m in _PHYSICAL_TOKEN_RE.finditer(log_text)]
    if not physical_tokens:
        # Nothing could be labelled, so skip the (expensive) model call.
        return {}, []

    raw_entities = nlp_pipeline(log_text)

    token_labels = []
    for token_text, start, end in physical_tokens:
        overlapping = [e for e in raw_entities if not (e['end'] <= start or e['start'] >= end)]
        if overlapping:
            best = max(overlapping,
                       key=lambda e: (_ENTITY_PRIORITY.get(e['entity_group'], 0), e['score']))
            # float() normalises scores in case the pipeline returns NumPy scalars.
            token_labels.append({'text': token_text, 'label': best['entity_group'],
                                 'confidence': float(best['score'])})
        else:
            token_labels.append({'text': token_text, 'label': 'O', 'confidence': 1.0})

    # Aggregate by entity type
    entity_types = defaultdict(list)
    for token in token_labels:
        if token['label'] != 'O':
            entity_types[token['label']].append({
                'value': token['text'],
                'confidence': token['confidence']
            })

    return dict(entity_types), token_labels


# Build the NER pipeline for the checkpoint selected in the configuration cell.
print(f"\nLoading model: {SELECTED_MODEL}...")
nlp = get_model_pipeline(model_path=MODELS[SELECTED_MODEL])
print("Model loaded")

## Cell 6: P_Score Calculation (Option A & B3)

In [None]:
# =============================================================================
# P_SCORE CALCULATION
# =============================================================================

def calculate_pscore_option_a(entity_types, hypothesis):
    """P_Score using BINARY entity presence.

    Every weighted entity type that is present (non-empty) in *entity_types*
    contributes its full weight. If the hypothesis's critical entity is
    missing, the sum is scaled by ``1 - penalty_factor``.

    Returns:
        (p_score rounded to 4 decimals, critical_present bool)
    """
    critical = hypothesis.get('critical_entity', 'Process')
    penalty = hypothesis.get('penalty_factor', 0.20)

    total = sum(
        (w for etype, w in hypothesis.get('weights', {}).items() if entity_types.get(etype)),
        0.0,
    )

    has_critical = bool(entity_types.get(critical))
    score = total if has_critical else total * (1 - penalty)
    return round(score, 4), has_critical


def calculate_pscore_option_b3(entity_types, hypothesis):
    """P_Score using CONFIDENCE-WEIGHTED entity presence.

    Every weighted entity type present in *entity_types* contributes its
    weight multiplied by the mean 'confidence' of its extracted entities. If
    the hypothesis's critical entity is missing, the sum is scaled by
    ``1 - penalty_factor``.

    Returns:
        (p_score rounded to 4 decimals, critical_present bool)
    """
    critical = hypothesis.get('critical_entity', 'Process')
    penalty = hypothesis.get('penalty_factor', 0.20)

    total = 0.0
    for etype, w in hypothesis.get('weights', {}).items():
        found = entity_types.get(etype)
        if not found:
            continue
        mean_conf = sum(item['confidence'] for item in found) / len(found)
        total += w * mean_conf

    has_critical = bool(entity_types.get(critical))
    score = total if has_critical else total * (1 - penalty)
    return round(score, 4), has_critical


def get_confidence_level(p_score):
    """Map a P_Score onto a confidence band.

    Bands are tried from strictest to loosest against the module-level
    THRESHOLDS ('HIGH', then 'MEDIUM', then 'LOW'). A score below every
    band is 'INSUFFICIENT'.
    """
    for band in ('HIGH', 'MEDIUM', 'LOW'):
        if p_score >= THRESHOLDS[band]:
            return band
    return 'INSUFFICIENT'


print("P_Score calculation functions defined")

## Cell 7: Hypothesis Matching and Scoring

In [None]:
# =============================================================================
# HYPOTHESIS MATCHING AND MULTI-HYPOTHESIS SCORING
# =============================================================================

def check_hypothesis_label_match(hypothesis_name, ground_truth_labels, hypotheses=None):
    """Check if hypothesis targets match ground truth labels.

    Matching is case-insensitive and whitespace-trimmed. A ground-truth label
    matches a target label when they are equal or one is a substring of the
    other (e.g. target 'scan' matches label 'dns_scan'). Empty labels are
    ignored: '' is a substring of every string and would otherwise match
    every hypothesis.

    Args:
        hypothesis_name: key of the hypothesis to check. Unknown names,
            including None, never match.
        ground_truth_labels: iterable of label strings for one log line.
            None is treated as no labels; non-string entries are ignored.
        hypotheses: registry to look *hypothesis_name* up in. Defaults to
            DATASET_HYPOTHESES.

    Returns:
        True if any ground-truth label matches any of the hypothesis's
        'target_labels', else False.
    """
    if hypotheses is None:
        hypotheses = DATASET_HYPOTHESES
    if hypothesis_name not in hypotheses:
        return False

    targets = [
        t.lower().strip()
        for t in hypotheses[hypothesis_name].get('target_labels', [])
        if isinstance(t, str)
    ]
    targets = [t for t in targets if t]

    for gt_label in ground_truth_labels or []:
        if not isinstance(gt_label, str):
            continue
        gt_lower = gt_label.lower().strip()
        if not gt_lower:
            continue
        if any(gt_lower == t or gt_lower in t or t in gt_lower for t in targets):
            return True
    return False


def score_all_hypotheses(entity_types, method='option_a'):
    """Score against ALL hypotheses, return best match.

    Args:
        entity_types: mapping of entity type -> list of extracted entities,
            as returned by extract_entities.
        method: 'option_a' (binary presence) or 'option_b3'
            (confidence-weighted).

    Returns:
        dict with:
        - 'top_hypothesis', 'p_score', 'confidence_level', 'critical_present'
          for the highest-scoring hypothesis;
        - 'all_scores' covering every hypothesis;
        - 'ranking', the top five (name, p_score) pairs, highest first.
        Ties keep DATASET_HYPOTHESES order because sorted() is stable.

    Raises:
        ValueError: if *method* is not a known scoring method. Any unknown
            string previously fell back silently to Option B3.
    """
    score_functions = {
        'option_a': calculate_pscore_option_a,
        'option_b3': calculate_pscore_option_b3,
    }
    try:
        score_func = score_functions[method]
    except KeyError:
        raise ValueError(
            f"Unknown scoring method {method!r}; expected 'option_a' or 'option_b3'"
        ) from None

    scores = {}
    for name, hyp in DATASET_HYPOTHESES.items():
        p_score, critical_present = score_func(entity_types, hyp)
        scores[name] = {
            'p_score': p_score,
            'critical_present': critical_present,
            'confidence_level': get_confidence_level(p_score)
        }

    # Rank by score
    ranking = sorted(scores.items(), key=lambda item: item[1]['p_score'], reverse=True)

    if not ranking:
        return {
            'top_hypothesis': None,
            'p_score': 0.0,
            'confidence_level': 'INSUFFICIENT',
            'critical_present': False,
            'all_scores': {},
            'ranking': []
        }

    top_name, top_data = ranking[0]
    return {
        'top_hypothesis': top_name,
        'p_score': top_data['p_score'],
        'confidence_level': top_data['confidence_level'],
        'critical_present': top_data['critical_present'],
        'all_scores': scores,
        'ranking': [(n, s['p_score']) for n, s in ranking[:5]]
    }


def process_log_complete(line_number, log_text, nlp_pipeline, method='option_a',
                         malicious_threshold=None):
    """Complete processing for a single log line.

    Extracts entities and scores them against every hypothesis.

    Args:
        line_number: 1-indexed line number of *log_text* in its file.
        log_text: one raw log line.
        nlp_pipeline: NER callable forwarded to extract_entities.
        method: P_Score method, 'option_a' or 'option_b3'.
        malicious_threshold: P_Score at or above which 'is_malicious' is
            True. None (the default) means the current value of
            MALICIOUS_THRESHOLD, read at call time. A default argument would
            freeze the value from when this cell ran, so re-tuning the global
            would silently have no effect.

    Returns:
        dict with:
        - 'line_number' and 'log_text' (first 100 chars plus '...' when
          longer);
        - 'entity_types', 'top_hypothesis', 'p_score', 'confidence_level';
        - 'is_malicious', 'triage_priority' and 'method'.
    """
    if malicious_threshold is None:
        malicious_threshold = MALICIOUS_THRESHOLD

    entity_types, _token_labels = extract_entities(log_text, nlp_pipeline)
    scoring = score_all_hypotheses(entity_types, method)

    return {
        'line_number': line_number,
        'log_text': log_text[:100] + '...' if len(log_text) > 100 else log_text,
        'entity_types': entity_types,
        'top_hypothesis': scoring['top_hypothesis'],
        'p_score': scoring['p_score'],
        'confidence_level': scoring['confidence_level'],
        'is_malicious': scoring['p_score'] >= malicious_threshold,
        'triage_priority': TRIAGE_PRIORITIES.get(scoring['confidence_level'], 'Unknown'),
        'method': method
    }


print("Hypothesis matching functions defined")

## Cell 8: Threshold Testing

In [None]:
# =============================================================================
# THRESHOLD TESTING
# =============================================================================

def evaluate_threshold(predictions, ground_truth, threshold):
    """Confusion-matrix metrics for flagging lines with p_score >= *threshold*.

    A prediction counts as actually malicious when its 'line_number' is a key
    of *ground_truth*. Only lines present in *predictions* are counted.

    Returns:
        dict with:
        - 'threshold';
        - raw 'tp', 'fp', 'tn', 'fn' counts;
        - 'precision', 'recall', 'f1_score' and 'fnr' (false-negative rate),
          each rounded to 4 decimals and 0.0 when the denominator is zero.
    """
    tp = fp = tn = fn = 0
    for pred in predictions:
        labelled = pred['line_number'] in ground_truth
        flagged = pred['p_score'] >= threshold
        if flagged:
            if labelled:
                tp += 1
            else:
                fp += 1
        elif labelled:
            fn += 1
        else:
            tn += 1

    def _ratio(num, den):
        return num / den if den > 0 else 0.0

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * (precision * recall), precision + recall)

    return {
        'threshold': threshold,
        'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn,
        'precision': round(precision, 4),
        'recall': round(recall, 4),
        'f1_score': round(f1, 4),
        'fnr': round(_ratio(fn, fn + tp), 4),
    }


def find_optimal_threshold(predictions, ground_truth, min_recall=0.90):
    """Find optimal threshold prioritizing HIGH RECALL.

    Evaluates every value in THRESHOLDS_TO_TEST. Among thresholds whose
    recall is >= *min_recall*, picks the one with the highest precision. If
    none qualifies, picks the highest recall. On ties, max() keeps the
    earliest entry in THRESHOLDS_TO_TEST.

    Returns:
        dict with:
        - 'optimal_for_recall', the chosen metrics row;
        - 'selection_method', which rule was applied;
        - 'best_f1', the row with the highest F1;
        - 'all_results', every evaluated row.
    """
    results = [evaluate_threshold(predictions, ground_truth, t) for t in THRESHOLDS_TO_TEST]
    recall_ok = [row for row in results if row['recall'] >= min_recall]

    if not recall_ok:
        # No tested threshold reaches the recall target: take the best recall.
        chosen = max(results, key=lambda row: row['recall'])
        how = "Highest recall available"
    else:
        chosen = max(recall_ok, key=lambda row: row['precision'])
        how = f"Highest precision with recall >= {min_recall}"

    return {
        'optimal_for_recall': chosen,
        'selection_method': how,
        'best_f1': max(results, key=lambda row: row['f1_score']),
        'all_results': results
    }


def print_threshold_results(analysis):
    """Pretty-print a find_optimal_threshold() result to stdout.

    One row per tested threshold. The recall-optimal row is tagged
    '<- OPTIMAL'; the best-F1 row is tagged '<- BEST F1' unless it is also
    the optimal row. A recommendation summary follows the table.
    """
    rule = "=" * 90

    print("\n" + rule)
    print("THRESHOLD ANALYSIS (Goal: HIGH RECALL)")
    print(rule)
    print(f"{'Threshold':<12} {'Precision':<12} {'Recall':<12} {'F1':<12} {'FN (Missed)':<12}")
    print("-" * 90)

    for row in analysis['all_results']:
        if row['threshold'] == analysis['optimal_for_recall']['threshold']:
            tag = " <- OPTIMAL"
        elif row['threshold'] == analysis['best_f1']['threshold']:
            tag = " <- BEST F1"
        else:
            tag = ""
        print(f"{row['threshold']:<12} {row['precision']:<12} {row['recall']:<12} {row['f1_score']:<12} {row['fn']:<12}{tag}")

    print("\n" + rule)
    chosen = analysis['optimal_for_recall']
    print(f"RECOMMENDED: Threshold={chosen['threshold']} | Recall={chosen['recall']} | Precision={chosen['precision']}")
    print(f"Selection: {analysis['selection_method']}")
    print(f"False Negatives (MISSED EVIDENCE): {chosen['fn']}")


print("Threshold testing functions defined")

## Cell 9: Validation with Hypothesis Matching

In [None]:
# =============================================================================
# VALIDATION WITH HYPOTHESIS MATCHING
# =============================================================================

def validate_binary(predictions, ground_truth, threshold):
    """Stage-A (detection-only) validation at *threshold*.

    Thin alias of evaluate_threshold so callers can name the validation
    stage explicitly; returns the same metrics dict.
    """
    metrics = evaluate_threshold(predictions, ground_truth, threshold)
    return metrics


def validate_two_stage(predictions, ground_truth, threshold):
    """Two-stage: detection + hypothesis matching.

    Stage A is binary detection (evaluate_threshold). Stage B checks each
    true positive (flagged and present in *ground_truth*) to see whether its
    top hypothesis matches one of that line's ground-truth labels.

    Returns:
        dict with:
        - 'stage_a', the detection metrics;
        - 'stage_b', with 'correct', 'total' and 'accuracy';
        - 'combined_score', detection F1 x hypothesis accuracy, rounded to
          4 decimals.
    """
    detection = evaluate_threshold(predictions, ground_truth, threshold)

    # Hypothesis attribution is only judged on true positives.
    true_positives = [
        pred for pred in predictions
        if pred['p_score'] >= threshold and pred['line_number'] in ground_truth
    ]
    matched = sum(
        1 for pred in true_positives
        if check_hypothesis_label_match(pred['top_hypothesis'],
                                        ground_truth[pred['line_number']]['labels'])
    )
    evaluated = len(true_positives)
    accuracy = matched / evaluated if evaluated > 0 else 0.0

    return {
        'stage_a': detection,
        'stage_b': {
            'correct': matched,
            'total': evaluated,
            'accuracy': round(accuracy, 4)
        },
        'combined_score': round(detection['f1_score'] * accuracy, 4)
    }


print("Validation functions defined")

## Cell 10: Complete Validation Pipeline

In [None]:
# =============================================================================
# COMPLETE VALIDATION PIPELINE
# =============================================================================

def run_complete_validation(dataset_info, nlp_pipeline, methods=('option_a', 'option_b3')):
    """Run complete FTE-HARM validation with threshold testing.

    For each scoring method:
    1. score every log line;
    2. sweep the candidate thresholds (find_optimal_threshold) to pick the
       recall-oriented optimum;
    3. run binary and two-stage validation at that threshold.
    Progress and metrics are printed along the way.

    Args:
        dataset_info: dataset record from scan_all_datasets ('name',
            'log_path', 'label_path').
        nlp_pipeline: NER callable passed through to extract_entities.
        methods: scoring methods to evaluate ('option_a' and/or 'option_b3').
            A tuple avoids the shared-mutable-default pitfall.

    Returns:
        dict with 'dataset', 'stats' and 'timestamp', plus one entry per
        method. Each method entry holds:
        - 'predictions' and 'threshold_analysis';
        - 'optimal_threshold';
        - 'binary_validation' and 'two_stage_validation'.
    """
    print(f"\n{'='*80}")
    print("FTE-HARM COMPLETE VALIDATION")
    print(f"Dataset: {dataset_info['name']}")
    print(f"{'='*80}")

    # Load dataset
    logs, ground_truth, stats = load_dataset(dataset_info)
    print(f"\nLoaded: {stats['total_lines']} lines ({stats['malicious_lines']} malicious)")
    print(f"Labels found: {stats['unique_labels']}")

    results = {
        'dataset': dataset_info['name'],
        'stats': stats,
        'timestamp': datetime.now().isoformat()
    }

    # NER output for a line does not depend on the scoring method. Cache it so
    # the transformer runs once per distinct line instead of once per line
    # per method. Assumes the pipeline is deterministic at inference time.
    ner_cache = {}

    def cached_pipeline(text):
        if text not in ner_cache:
            ner_cache[text] = nlp_pipeline(text)
        return ner_cache[text]

    total = len(logs)
    for method in methods:
        print(f"\n--- Processing with {method.upper()} ---")

        # Process all logs
        predictions = []
        for done, (line_num, log_text) in enumerate(logs, 1):
            predictions.append(process_log_complete(line_num, log_text, cached_pipeline, method))
            if done % 500 == 0:
                print(f"  Processed: {done}/{total}")

        print(f"Processed {len(predictions)} logs")

        # Threshold analysis
        print("\n  Running threshold analysis...")
        threshold_analysis = find_optimal_threshold(predictions, ground_truth)
        print_threshold_results(threshold_analysis)

        # Get optimal threshold
        optimal_threshold = threshold_analysis['optimal_for_recall']['threshold']

        # Run validations at optimal threshold
        binary = validate_binary(predictions, ground_truth, optimal_threshold)
        two_stage = validate_two_stage(predictions, ground_truth, optimal_threshold)

        results[method] = {
            'predictions': predictions,
            'threshold_analysis': threshold_analysis,
            'optimal_threshold': optimal_threshold,
            'binary_validation': binary,
            'two_stage_validation': two_stage
        }

        print(f"\n  Results at optimal threshold ({optimal_threshold}):")
        print(f"    Precision: {binary['precision']}")
        print(f"    Recall: {binary['recall']} <- PRIMARY METRIC")
        print(f"    F1: {binary['f1_score']}")
        print(f"    Hypothesis Accuracy: {two_stage['stage_b']['accuracy']}")

    return results


print("Complete validation pipeline defined")

## Cell 11: Save Results with MITRE Corroboration Table

In [None]:
# =============================================================================
# SAVE RESULTS WITH MITRE CORROBORATION TABLE
# =============================================================================

# Scoring methods whose reports are written, in output order.
_SAVED_METHODS = ('option_a', 'option_b3')


def _write_method_report(path, results, method):
    """Write the binary and two-stage validation summary for one method."""
    data = results[method]
    binary = data['binary_validation']
    two_stage = data['two_stage_validation']
    # Build the full text first so a missing key cannot leave a truncated file.
    lines = [
        "="*80 + "\n",
        "FTE-HARM VALIDATION RESULTS\n",
        "="*80 + "\n\n",
        f"Dataset: {results['dataset']}\n",
        f"Method: {method}\n",
        f"Timestamp: {results['timestamp']}\n",
        f"Optimal Threshold: {data['optimal_threshold']}\n\n",
        "VALIDATION GOAL: HIGH RECALL (Forensic Triage Priority)\n\n",
        "BINARY VALIDATION:\n",
        f"  TP: {binary['tp']}, FP: {binary['fp']}, TN: {binary['tn']}, FN: {binary['fn']}\n",
        f"  Precision: {binary['precision']}\n",
        f"  Recall: {binary['recall']} <- PRIMARY METRIC\n",
        f"  F1: {binary['f1_score']}\n\n",
        "TWO-STAGE VALIDATION:\n",
        f"  Detection F1: {two_stage['stage_a']['f1_score']}\n",
        f"  Hypothesis Correct: {two_stage['stage_b']['correct']}/{two_stage['stage_b']['total']}\n",
        f"  Hypothesis Accuracy: {two_stage['stage_b']['accuracy']}\n",
    ]
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def _write_mitre_table(path):
    """Write the AITv2 label -> hypothesis -> MITRE ATT&CK mapping table."""
    # Group hypotheses by each label they target.
    label_to_hyp = defaultdict(list)
    for hyp in DATASET_HYPOTHESES.values():
        for label in hyp['target_labels']:
            label_to_hyp[label].append(hyp)

    lines = [
        "="*100 + "\n",
        "MITRE ATT&CK CORROBORATION TABLE\n",
        "Dataset-Specific Labels -> Standardised Threat Taxonomy\n",
        "="*100 + "\n\n",
        f"{'AITv2 Label':<25} {'Hypothesis':<30} {'MITRE ID':<15} {'MITRE Tactic':<25}\n",
        "-"*100 + "\n",
    ]
    for label in sorted(label_to_hyp):
        for hyp in label_to_hyp[label]:
            lines.append(f"{label:<25} {hyp['name']:<30} {hyp['mitre_technique']:<15} {hyp['mitre_tactic']:<25}\n")
    lines += [
        "\n" + "="*100 + "\n",
        "This table demonstrates corroboration between dataset-specific attack labels\n",
        "and the standardised MITRE ATT&CK threat taxonomy.\n",
    ]
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def _write_threshold_report(path, results):
    """Write the per-threshold precision/recall/F1/FN sweep for each saved method."""
    lines = [
        "THRESHOLD TEST RESULTS\n",
        "="*80 + "\n",
        f"Dataset: {results['dataset']}\n\n",
    ]
    for method in _SAVED_METHODS:
        if method not in results:
            continue
        analysis = results[method]['threshold_analysis']
        lines.append(f"\n{method.upper()}:\n")
        lines.append("-"*80 + "\n")
        lines.append(f"{'Threshold':<12} {'Precision':<12} {'Recall':<12} {'F1':<12} {'FN':<12}\n")
        for r in analysis['all_results']:
            lines.append(f"{r['threshold']:<12} {r['precision']:<12} {r['recall']:<12} {r['f1_score']:<12} {r['fn']:<12}\n")
        opt = analysis['optimal_for_recall']
        lines.append(f"\nOptimal: {opt['threshold']} (Recall={opt['recall']}, FN={opt['fn']})\n")
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def save_complete_results(results, output_path=SUMMARY_PATH):
    """Save all results including MITRE corroboration table.

    Writes, with a shared timestamp suffix:
    - one '<method>_results_<ts>.txt' per scoring method present in *results*,
      in ``<output_path>/<dataset name with '/' -> '_'>/``;
    - 'mitre_corroboration_table_<ts>.txt' in that same folder;
    - 'threshold_test_<dataset>_<ts>.txt' in THRESHOLD_TEST_PATH.
    All files are written as UTF-8 so non-ASCII dataset names or labels do
    not depend on the platform's default encoding.

    Args:
        results: dict returned by run_complete_validation.
        output_path: root folder for the per-dataset summary folder.

    Returns:
        the per-dataset summary folder path.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    dataset_name = results['dataset'].replace('/', '_')

    # Create dataset folder
    folder = os.path.join(output_path, dataset_name)
    os.makedirs(folder, exist_ok=True)

    for method in _SAVED_METHODS:
        if method in results:
            _write_method_report(
                os.path.join(folder, f'{method}_results_{timestamp}.txt'), results, method)

    _write_mitre_table(os.path.join(folder, f'mitre_corroboration_table_{timestamp}.txt'))

    threshold_path = os.path.join(THRESHOLD_TEST_PATH, f'threshold_test_{dataset_name}_{timestamp}.txt')
    _write_threshold_report(threshold_path, results)

    print(f"\nResults saved to: {folder}")
    print(f"Threshold test saved to: {threshold_path}")
    return folder


# Validate and save the first discovered dataset, if any were found.
# all_datasets[:1] yields zero or one element, so the body runs at most once.
for selected in all_datasets[:1]:
    print(f"\nSelected dataset: {selected['name']}")

    results = run_complete_validation(selected, nlp)
    save_complete_results(results)

## Cell 12: Run All Datasets

In [None]:
# =============================================================================
# RUN ALL DATASETS
# =============================================================================

def run_all_datasets(datasets, nlp_pipeline):
    """Process all datasets.

    Runs run_complete_validation and save_complete_results on each dataset in
    turn. A failure in one dataset is reported and recorded, and the batch
    continues.

    Args:
        datasets: dataset records from scan_all_datasets.
        nlp_pipeline: NER callable passed to run_complete_validation.

    Returns:
        list of per-dataset records, either
        ``{'dataset', 'status': 'success', 'results'}`` or
        ``{'dataset', 'status': 'error', 'error', 'traceback'}``. The
        traceback keeps the failure diagnosable: ``str(e)`` alone is often
        just a bare key name for KeyError.
    """
    import traceback

    all_results = []

    for i, ds in enumerate(datasets, 1):
        print(f"\n[{i}/{len(datasets)}] {ds['name']}")
        try:
            results = run_complete_validation(ds, nlp_pipeline)
            save_complete_results(results)
        except Exception as e:  # batch boundary: one bad dataset must not stop the rest
            print(f"  Error: {e}")
            all_results.append({
                'dataset': ds['name'],
                'status': 'error',
                'error': str(e),
                'traceback': traceback.format_exc(),
            })
        else:
            all_results.append({'dataset': ds['name'], 'status': 'success', 'results': results})

    return all_results

# Uncomment to run all:
# all_results = run_all_datasets(all_datasets, nlp)

---

## Checklist

- [ ] Model loaded
- [ ] Datasets discovered
- [ ] Threshold testing completed
- [ ] Optimal threshold identified (prioritizing RECALL)
- [ ] Hypothesis matching validated
- [ ] MITRE corroboration table generated
- [ ] Results saved to summary folder