In [1]:
import os
import json
import pandas as pd
import numpy as np
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification,
    Trainer, 
    TrainingArguments
)
from datasets import Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
import logging
import joblib

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set paths to your AnnoCTR folders
BASE_PATH = "anno-ctr-lrec-coling-2024/AnnoCTR"
TEXT_PATH = os.path.join(BASE_PATH, "text")
NER_PATH = os.path.join(BASE_PATH, "ner_json")
LINKING_PATH = os.path.join(BASE_PATH, "linking")

# Define the splits
SPLITS = ["train", "dev", "test"]

# Define threat classes (only keep phishing, APT, other, ransomware)
THREAT_CLASSES = ['phishing', 'APT', 'other', 'ransomware']

# Load and process data for threat classification
def load_data_for_threat_classification():
    datasets = {}
    
    for split in SPLITS:
        text_path = os.path.join(TEXT_PATH, split)
        ner_path = os.path.join(NER_PATH, split)
        
        # Check if directories exist
        if not os.path.exists(text_path):
            print(f"Warning: Text path {text_path} does not exist")
            continue
        
        print(f"Processing {split} split...")
        print(f"Text directory: {text_path}")
        
        texts = []
        threat_labels = []
        
        # Count files for progress reporting
        txt_files = [f for f in os.listdir(text_path) if f.endswith('.txt')]
        print(f"Found {len(txt_files)} text files in {split} split")
        
        for filename in txt_files:
            # Load text
            txt_file = os.path.join(text_path, filename)
            try:
                with open(txt_file, 'r', encoding='utf-8') as f:
                    raw_text = f.read()
            except Exception as e:
                print(f"Error reading {txt_file}: {e}")
                continue
            
            # Load NER entities to help determine threat type
            ner_file = os.path.join(ner_path, filename.replace('.txt', '.json'))
            entity_types = set()
            
            if os.path.exists(ner_file):
                try:
                    with open(ner_file, 'r', encoding='utf-8') as f:
                        ner_data_file = json.load(f)
                        
                        for entity in ner_data_file.get('entities', []):
                            entity_type = entity.get('label', '').lower()
                            entity_types.add(entity_type)
                except Exception as e:
                    print(f"Error processing NER file {ner_file}: {e}")
            
            # Determine threat type based on entity types and text content
            threat_type = 'other'  # Default
            
            # Simple heuristic: check for keywords in text and entity types
            text_lower = raw_text.lower()
            # Updated to only check for the classes we want
            if 'phishing' in text_lower:
                threat_type = 'phishing'
            elif 'apt' in entity_types or 'APT' in raw_text:
                threat_type = 'APT'
            elif 'ransomware' in text_lower:
                threat_type = 'ransomware'
            # Remove checks for spyware, botnet, and exploit classes
            
            texts.append(raw_text)
            threat_labels.append(threat_type)
        
        # Print stats
        print(f"Processed {len(texts)} texts for threat classification")
        
        # Verify we have data
        if not texts:
            print(f"Warning: No text data found for {split} split")
            continue
        
        # Create threat classification dataframe
        threat_df = pd.DataFrame({
            'text': texts,
            'label': threat_labels
        })
        
        datasets[split] = threat_df
    
    return datasets

# Fixed: Updated WeightedLossTrainer class to handle both cases with and without num_items_in_batch
class WeightedLossTrainer(Trainer):
    def __init__(self, class_weights=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
        
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        
        if self.class_weights is not None:
            # Explicitly convert to float32 to match the expected datatype
            weight = torch.tensor(self.class_weights, device=logits.device, dtype=torch.float32)
            loss_fct = torch.nn.CrossEntropyLoss(weight=weight)
        else:
            loss_fct = torch.nn.CrossEntropyLoss()
            
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        
        return (loss, outputs) if return_outputs else loss

def train_threat_classifier(datasets, model_name='bert-base-uncased'):
    """
    Train a threat classification model using the provided datasets
    """
    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Initialize label encoder
    threat_encoder = LabelEncoder()
    threat_encoder.fit(THREAT_CLASSES)
    
    # Save encoder
    os.makedirs("./threat_model_v2", exist_ok=True)
    joblib.dump(threat_encoder, "./threat_model_v2/threat_encoder.joblib")
    
    # Process datasets for threat classification
    processed_datasets = {}
    
    # Set the device to CPU explicitly if needed due to dtype issues
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    for split, threat_df in datasets.items():
        threat_labels = threat_encoder.transform(threat_df['label'])
        
        # Print distribution
        print(f"\n{split} threat class distribution:")
        print(pd.Series(threat_df['label']).value_counts())
        
        # Create dataset
        examples = []
        
        for i in range(len(threat_df)):
            text = threat_df['text'].iloc[i]
            label = int(threat_labels[i])
            
            # Tokenize
            # Tokenize with explicit float32 conversion
            tokenized = tokenizer(
                text,
                truncation=True,
                max_length=512,
                padding="max_length",
                return_tensors="pt"
            )
            
            # Ensure attention mask is properly formatted
            attention_mask = tokenized["attention_mask"][0].tolist()
            
            examples.append({
                "input_ids": tokenized["input_ids"][0].tolist(),
                "attention_mask": tokenized["attention_mask"][0].tolist(),
                "labels": label
            })
        
        processed_datasets[split] = Dataset.from_list(examples)
        print(f"Processed {len(examples)} examples for {split}")
    
    # Calculate class weights for balancing
    train_labels = datasets['train']['label'].values
    # Make sure all classes are represented in the weight calculation
    unique_classes = np.array(THREAT_CLASSES)
    train_label_counts = pd.Series(train_labels).value_counts().reindex(unique_classes, fill_value=0)
    
    # Manually calculate weights - inverse of frequency
    total_samples = len(train_labels)
    class_weights = {}
    for i, cls in enumerate(THREAT_CLASSES):
        count = train_label_counts[cls]
        if count > 0:
            weight = total_samples / (len(THREAT_CLASSES) * count)
        else:
            weight = 1.0  # Default weight for classes not in training set
        class_weights[i] = weight
    
    # Convert to list in the order of class indices
    weight_list = [class_weights[i] for i in range(len(THREAT_CLASSES))]
    
    print("Class weights to balance the dataset:")
    for i, weight in enumerate(weight_list):
        print(f"  {THREAT_CLASSES[i]}: {weight:.4f}")
    
    # Initialize model
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(THREAT_CLASSES)
    )
    
    # Define training arguments with backward compatibility
    try:
        # Option 1: Try with simpler arguments first
        training_args = TrainingArguments(
            output_dir="./threat_model_v2",
            num_train_epochs=3,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            learning_rate=2e-5,
            weight_decay=0.01,
            logging_dir="./threat_model_v2/logs",
            logging_steps=100
        )
    except Exception as e:
        print(f"First training args approach failed: {e}")
        # Option 2: Try with explicit evaluation but no best model loading
        training_args = TrainingArguments(
            output_dir="./threat_model_v2",
            num_train_epochs=3,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            learning_rate=2e-5,
            weight_decay=0.01,
            eval_steps=100,
            save_steps=100,
            logging_steps=100
        )
    
    # Define compute_metrics
    def compute_threat_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        
        accuracy = (predictions == labels).mean()
        
        # Add class-wise metrics
        class_metrics = {}
        for i, class_name in enumerate(THREAT_CLASSES):
            # Calculate precision and recall for each class
            class_preds = (predictions == i)
            class_labels = (labels == i)
            
            true_pos = (class_preds & class_labels).sum()
            false_pos = (class_preds & ~class_labels).sum()
            false_neg = (~class_preds & class_labels).sum()
            
            precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) > 0 else 0
            recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
            
            class_metrics[f"{class_name}_precision"] = precision
            class_metrics[f"{class_name}_recall"] = recall
            class_metrics[f"{class_name}_f1"] = f1
        
        # Calculate macro F1 (average of F1 scores across all classes)
        f1_scores = [class_metrics[f"{class_name}_f1"] for class_name in THREAT_CLASSES]
        macro_f1 = sum(f1_scores) / len(f1_scores)
        
        return {
            "accuracy": accuracy,
            "macro_f1": macro_f1,
            **class_metrics
        }
    
    # Initialize Weighted Trainer
    trainer = WeightedLossTrainer(
        class_weights=weight_list,
        model=model,
        args=training_args,
        train_dataset=processed_datasets["train"],
        eval_dataset=processed_datasets["dev"],
        compute_metrics=compute_threat_metrics
    )
    
    # Train the model
    print("Starting threat classifier training...")
    trainer.train()
    
    # Evaluate on test set
    if "test" in processed_datasets:
        test_results = trainer.evaluate(processed_datasets["test"])
        print(f"Test Results: {test_results}")
    
    # Save model and tokenizer
    model.save_pretrained("./threat_model_v2")
    tokenizer.save_pretrained("./threat_model_v2")
    
    return model, tokenizer, threat_encoder

def predict_threat(text, model=None, tokenizer=None, threat_encoder=None):
    """
    Make a threat classification prediction for a single text
    
    Args:
        text (str): The cyber threat report text to classify
        model: The trained threat classification model (will load if None)
        tokenizer: The tokenizer for the model (will load if None)
        threat_encoder: The LabelEncoder for mapping class indices to names (will load if None)
    
    Returns:
        dict: Prediction results with class probabilities
    """
    # Load model components if not provided
    if model is None or tokenizer is None or threat_encoder is None:
        model_path = "./threat_model_v2"
        
        if not os.path.exists(model_path):
            raise ValueError("Model directory not found. Please train the model first.")
        
        if model is None:
            model = AutoModelForSequenceClassification.from_pretrained(model_path)
        
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
        
        if threat_encoder is None:
            encoder_path = os.path.join(model_path, "threat_encoder.joblib")
            if not os.path.exists(encoder_path):
                raise ValueError("Threat encoder not found. Please train the model first.")
            threat_encoder = joblib.load(encoder_path)
    
    # Tokenize input
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )
    
    # Get model predictions
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Process logits
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
    
    # Get prediction
    predicted_class_idx = torch.argmax(logits, dim=1).item()
    predicted_class = threat_encoder.inverse_transform([predicted_class_idx])[0]
    confidence = probabilities[predicted_class_idx].item()
    
    # Prepare detailed results
    class_probabilities = {
        threat_encoder.inverse_transform([i])[0]: prob.item()
        for i, prob in enumerate(probabilities)
    }
    
    # Sort classes by probability (descending)
    sorted_probs = sorted(
        class_probabilities.items(),
        key=lambda x: x[1],
        reverse=True
    )
    
    return {
        "predicted_class": predicted_class,
        "confidence": confidence,
        "class_probabilities": class_probabilities,
        "top_classes": sorted_probs[:3]  # Top 3 most likely classes
    }

def main():
    """
    Main function to run the threat classification pipeline
    """
    print("Loading data for cyber threat classification...")
    try:
        datasets = load_data_for_threat_classification()
        print("Data loaded successfully!")
    except Exception as e:
        print(f"Error loading data: {e}")
        import traceback
        traceback.print_exc()
        return
    
    print("Training threat classifier...")
    try:
        model, tokenizer, threat_encoder = train_threat_classifier(datasets)
        print("Threat classifier training complete!")
        
        # Example prediction
        example_text = """
        A new ransomware campaign has been detected targeting healthcare organizations.
        The malware encrypts critical patient data and demands payment in cryptocurrency.
        Initial infection vectors include phishing emails with malicious attachments and
        exploitation of vulnerabilities in outdated VPN software.
        """
        
        result = predict_threat(example_text, model, tokenizer, threat_encoder)
        
        print("\nExample Prediction Results:")
        print(f"Text: {example_text[:100]}...")
        print(f"Predicted Threat Class: {result['predicted_class']}")
        print(f"Confidence: {result['confidence']:.4f}")
        
        print("\nTop 3 Classes:")
        for cls, prob in result['top_classes']:
            print(f"  {cls}: {prob:.4f}")
            
        return model, tokenizer, threat_encoder
    
    except Exception as e:
        print(f"Error during training or prediction: {e}")
        import traceback
        traceback.print_exc()
        return

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm



Loading data for cyber threat classification...
Processing train split...
Text directory: anno-ctr-lrec-coling-2024/AnnoCTR\text\train
Found 70 text files in train split
Processed 70 texts for threat classification
Processing dev split...
Text directory: anno-ctr-lrec-coling-2024/AnnoCTR\text\dev
Found 16 text files in dev split
Processed 16 texts for threat classification
Processing test split...
Text directory: anno-ctr-lrec-coling-2024/AnnoCTR\text\test
Found 34 text files in test split
Processed 34 texts for threat classification
Data loaded successfully!
Training threat classifier...

train threat class distribution:
label
phishing      23
other         19
APT           16
ransomware    12
Name: count, dtype: int64
Processed 70 examples for train

dev threat class distribution:
label
other         4
ransomware    4
APT           4
phishing      4
Name: count, dtype: int64
Processed 16 examples for dev

test threat class distribution:
label
phishing      13
ransomware    11
APT   

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Starting threat classifier training...


Step,Training Loss


Test Results: {'eval_loss': 1.323173999786377, 'eval_accuracy': 0.3235294117647059, 'eval_macro_f1': 0.20730994152046783, 'eval_phishing_precision': 0, 'eval_phishing_recall': 0.0, 'eval_phishing_f1': 0, 'eval_APT_precision': 0.2, 'eval_APT_recall': 0.25, 'eval_APT_f1': 0.22222222222222224, 'eval_other_precision': 0.5, 'eval_other_recall': 0.07692307692307693, 'eval_other_f1': 0.13333333333333336, 'eval_ransomware_precision': 0.3333333333333333, 'eval_ransomware_recall': 0.8181818181818182, 'eval_ransomware_f1': 0.4736842105263157, 'eval_runtime': 13.0445, 'eval_samples_per_second': 2.606, 'eval_steps_per_second': 0.383, 'epoch': 3.0}
Threat classifier training complete!

Example Prediction Results:
Text: 
        A new ransomware campaign has been detected targeting healthcare organizations.
        The...
Predicted Threat Class: ransomware
Confidence: 0.3823

Top 3 Classes:
  ransomware: 0.3823
  APT: 0.2196
  other: 0.2124
